From f51fc9303593e84ee6b002488bf2d0889056db60 Mon Sep 17 00:00:00 2001
From: Todd Dembrey <todd.dembrey@torchbox.com>
Date: Thu, 15 Mar 2018 12:29:38 +0000
Subject: [PATCH] Copy the new form fields onto the submission

---
 opentech/apply/funds/models.py                | 10 +++++--
 .../apply/funds/tests/factories/models.py     | 28 ++++++++++++++++++-
 .../apply/funds/tests/factories/workflows.py  |  3 ++
 opentech/apply/funds/tests/test_models.py     | 10 ++++---
 opentech/apply/funds/tests/test_workflow.py   |  4 +--
 5 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/opentech/apply/funds/models.py b/opentech/apply/funds/models.py
index 556e0bb66..34d9dc523 100644
--- a/opentech/apply/funds/models.py
+++ b/opentech/apply/funds/models.py
@@ -111,9 +111,12 @@ class WorkflowStreamForm(WorkflowHelpers, AbstractStreamForm):  # type: ignore
     class Meta:
         abstract = True
 
-    def get_defined_fields(self):
-        # Only return the first form, will need updating for when working with 2 stage WF
-        return self.forms.all()[0].fields
+    def get_defined_fields(self, stage=None):
+        if not stage:
+            form_index = 0
+        else:
+            form_index = self.workflow.stages.index(stage)
+        return self.forms.all()[form_index].fields
 
     content_panels = AbstractStreamForm.content_panels + [
         FieldPanel('workflow_name'),
@@ -601,6 +604,7 @@ class ApplicationSubmission(WorkflowHelpers, AbstractFormSubmission):
 
             self.id = None
             self.status = str(self.workflow.next(self.status))
+            self.form_fields = self.round.get_defined_fields(self.stage)
 
             super().save(*args, **kwargs)
 
diff --git a/opentech/apply/funds/tests/factories/models.py b/opentech/apply/funds/tests/factories/models.py
index 87d4ffcf5..edceb91f9 100644
--- a/opentech/apply/funds/tests/factories/models.py
+++ b/opentech/apply/funds/tests/factories/models.py
@@ -104,6 +104,17 @@ class RoundFactory(wagtail_factories.PageFactory):
     end_date = factory.LazyFunction(lambda: datetime.date.today() + datetime.timedelta(days=7))
     lead = factory.SubFactory(UserFactory, groups__name=STAFF_GROUP_NAME)
 
+    @factory.post_generation
+    def forms(self, create, extracted, **kwargs):
+        if create:
+            fields = build_form(kwargs, prefix='form')
+            for _ in range(len(self.workflow_class.stage_classes)):
+                # Generate a form based on all defined fields on the model
+                RoundFormFactory(
+                    round=self,
+                    **fields,
+                )
+
 
 class RoundFormFactory(AbstractRelatedFormFactory):
     class Meta:
@@ -123,6 +134,17 @@ class LabFactory(wagtail_factories.PageFactory):
     workflow_name = factory.LazyAttribute(lambda o: list(FundType.WORKFLOWS.keys())[o.workflow_stages - 1])
     lead = factory.SubFactory(UserFactory, groups__name=STAFF_GROUP_NAME)
 
+    @factory.post_generation
+    def forms(self, create, extracted, **kwargs):
+        if create:
+            fields = build_form(kwargs, prefix='form')
+            for _ in range(len(self.workflow_class.stage_classes)):
+                # Generate a form based on all defined fields on the model
+                LabFormFactory(
+                    lab=self,
+                    **fields,
+                )
+
 
 class LabFormFactory(AbstractRelatedFormFactory):
     class Meta:
@@ -163,10 +185,14 @@ class ApplicationSubmissionFactory(factory.DjangoModelFactory):
     class Meta:
         model = ApplicationSubmission
 
+    class Params:
+        workflow_stages = 1
+
     form_fields = blocks.CustomFormFieldsFactory
     form_data = factory.SubFactory(FormDataFactory, form_fields=factory.SelfAttribute('..form_fields'))
     page = factory.SubFactory(FundTypeFactory)
-    round = factory.SubFactory(RoundFactory)
+    workflow_name = factory.LazyAttribute(lambda o: list(FundType.WORKFLOWS.keys())[o.workflow_stages - 1])
+    round = factory.SubFactory(RoundFactory, workflow_name=factory.SelfAttribute('..workflow_name'))
     user = factory.SubFactory(UserFactory)
 
     @classmethod
diff --git a/opentech/apply/funds/tests/factories/workflows.py b/opentech/apply/funds/tests/factories/workflows.py
index 009922275..31815c1b7 100644
--- a/opentech/apply/funds/tests/factories/workflows.py
+++ b/opentech/apply/funds/tests/factories/workflows.py
@@ -90,6 +90,9 @@ class StageFactory(factory.Factory):
         phases = kwargs.pop('phases')
         name = kwargs.pop('name')
         new_class = type(model_class.__name__, (model_class,), {'phases': phases, 'name': name})
+
+        # Pretend we have a workflow object, only used for __le__
+        kwargs['workflow'] = None
         return new_class(*args, **kwargs)
 
 
diff --git a/opentech/apply/funds/tests/test_models.py b/opentech/apply/funds/tests/test_models.py
index c429bb6b9..644632f2b 100644
--- a/opentech/apply/funds/tests/test_models.py
+++ b/opentech/apply/funds/tests/test_models.py
@@ -201,9 +201,7 @@ class TestFormSubmission(TestCase):
         self.site.save()
 
         self.round_page = RoundFactory(parent=fund)
-        RoundFormFactory(round=self.round_page, form=form)
         self.lab_page = LabFactory(lead=self.round_page.lead)
-        LabFormFactory(lab=self.lab_page, form=form)
 
     def submit_form(self, page=None, email=None, name=None, user=AnonymousUser()):
         if email is None:
@@ -213,7 +211,7 @@ class TestFormSubmission(TestCase):
 
         page = page or self.round_page
         fields = page.get_form_fields()
-        data = {k: v for k, v in zip(fields, [email, name, 'project'])}
+        data = {k: v for k, v in zip(fields, ['project', email, name])}
 
         request = self.request_factory.post('', data)
         request.user = user
@@ -378,7 +376,7 @@ class TestApplicationSubmission(TestCase):
 
 class TestApplicationProgression(TestCase):
     def test_new_submission_created(self):
-        submission = ApplicationSubmissionFactory(round__workflow_name='double')
+        submission = ApplicationSubmissionFactory(workflow_name='double')
         self.assertEqual(ApplicationSubmission.objects.count(), 1)
         old_id = submission.id
 
@@ -391,3 +389,7 @@ class TestApplicationProgression(TestCase):
         self.assertEqual(ApplicationSubmission.objects.count(), 2)
         self.assertEqual(submission.previous, old_submission)
         self.assertEqual(old_submission.next, submission)
+
+        form_fields = submission.round.forms.all()[1].fields
+
+        self.assertEqual(submission.form_fields, form_fields)
diff --git a/opentech/apply/funds/tests/test_workflow.py b/opentech/apply/funds/tests/test_workflow.py
index 7fd8a08ed..30a0226ea 100644
--- a/opentech/apply/funds/tests/test_workflow.py
+++ b/opentech/apply/funds/tests/test_workflow.py
@@ -54,7 +54,7 @@ class TestStageCreation(SimpleTestCase):
     def test_can_create_stage(self):
         name = 'the_stage'
         form = Form()
-        stage = Stage(form, name=name)
+        stage = Stage(form, None, name=name)
         self.assertEqual(stage.name, name)
         self.assertEqual(stage.form, form)
 
@@ -73,7 +73,7 @@ class TestStageCreation(SimpleTestCase):
                 [first_phase, second_phase],
             ]
 
-        stage = MultiPhaseStep(None)
+        stage = MultiPhaseStep(None, None)
         self.assertEqual(stage.steps, 2)
 
         current_phase = stage.phases[0]
-- 
GitLab