From 65fd03007f3584f91ce4a8c410b330b1d4b0e274 Mon Sep 17 00:00:00 2001
From: Todd Dembrey <todd.dembrey@torchbox.com>
Date: Thu, 20 Dec 2018 12:01:02 +0000
Subject: [PATCH] Add test to cover accessing non lab/round pages

---
 opentech/apply/funds/tests/factories/models.py | 11 ++++++-----
 opentech/apply/funds/tests/test_views.py       | 12 ++++++++++++
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/opentech/apply/funds/tests/factories/models.py b/opentech/apply/funds/tests/factories/models.py
index e678dd8d7..ecb96df5b 100644
--- a/opentech/apply/funds/tests/factories/models.py
+++ b/opentech/apply/funds/tests/factories/models.py
@@ -76,11 +76,7 @@ class AbstractApplicationFactory(wagtail_factories.PageFactory):
             if extracted_parent and parent_kwargs:
                 raise ValueError('Cant pass a parent instance and attributes')
 
-            if not extracted_parent:
-                parent = ApplyHomePageFactory(**parent_kwargs)
-            else:
-                # Assume root node if no parent passed
-                parent = extracted_parent
+            parent = extracted_parent or ApplyHomePageFactory(**parent_kwargs)
 
             parent.add_child(instance=self)
 
@@ -144,6 +140,11 @@ class RoundFactory(wagtail_factories.PageFactory):
     end_date = factory.Sequence(lambda n: datetime.date.today() + datetime.timedelta(days=7 * (n + 1)))
     lead = factory.SubFactory(StaffFactory)
 
+    @factory.post_generation
+    def parent(self, create, extracted_parent, **parent_kwargs):
+        parent = extracted_parent or FundTypeFactory(**parent_kwargs)
+        parent.add_child(instance=self)
+
     @factory.post_generation
     def forms(self, create, extracted, **kwargs):
         if create:
diff --git a/opentech/apply/funds/tests/test_views.py b/opentech/apply/funds/tests/test_views.py
index 5f6643b3b..2b1cc1c3b 100644
--- a/opentech/apply/funds/tests/test_views.py
+++ b/opentech/apply/funds/tests/test_views.py
@@ -572,6 +572,12 @@ class TestStaffSubmissionByRound(ByRoundTestCase):
         response = self.get_page(new_lab)
         self.assertContains(response, new_lab.title)
 
+    def test_cant_access_normal_page(self):
+        new_round = RoundFactory()
+        page = new_round.get_site().root_page
+        response = self.get_page(page)
+        self.assertEqual(response.status_code, 404)
+
 
 class TestApplicantSubmissionByRound(ByRoundTestCase):
     user_factory = UserFactory
@@ -585,3 +591,9 @@ class TestApplicantSubmissionByRound(ByRoundTestCase):
         new_lab = LabFactory()
         response = self.get_page(new_lab)
         self.assertEqual(response.status_code, 403)
+
+    def test_cant_access_normal_page(self):
+        new_round = RoundFactory()
+        page = new_round.get_site().root_page
+        response = self.get_page(page)
+        self.assertEqual(response.status_code, 403)
-- 
GitLab