diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index b44549f67558da5da2a089d9772b3f729dec58d6..3435937db3b0e5018a646e9629613c5f3a934ebb 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -1,4 +1,5 @@ from django.test import SimpleTestCase +from django.forms import Form from .workflow import Workflow, Stage @@ -6,7 +7,7 @@ from .workflow import Workflow, Stage class TestWorkflowCreation(SimpleTestCase): def test_can_create_workflow(self): name = 'single_stage' - stage = Stage('stage_name') + stage = Stage('stage_name', Form()) workflow = Workflow(name, stage) self.assertEqual(workflow.name, name) self.assertCountEqual(workflow.stages, [stage]) @@ -17,8 +18,8 @@ class TestWorkflowCreation(SimpleTestCase): Workflow(name) def test_can_iterate_through_workflow(self): - stage1 = Stage('stage_one') - stage2 = Stage('stage_two') + stage1 = Stage('stage_one', Form()) + stage2 = Stage('stage_two', Form()) workflow = Workflow('two_stage', stage1, stage2) for stage, check in zip(workflow, [stage1, stage2]): self.assertEqual(stage, check) @@ -27,5 +28,7 @@ class TestWorkflowCreation(SimpleTestCase): class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): name = 'the_stage' - stage = Stage(name) + form = Form() + stage = Stage(name, form) self.assertEqual(stage.name, name) + self.assertEqual(stage.form, form) diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 001ce395697976c1bc21619e11d9380835be9b4a..d24e47187a5ee8fc2d6545d71b49f57928ef00ab 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,5 +1,7 @@ from typing import Iterator, Iterable +from django.forms import Form + class Workflow(Iterable['Stage']): def __init__(self, name: str, *stages: 'Stage') -> None: @@ -13,5 +15,6 @@ class Workflow(Iterable['Stage']): class Stage: - def __init__(self, name: str) -> None: + def __init__(self, name: str, form: Form) -> None: self.name = name + self.form = form