diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index fbbf0a09965d0ed7cee6ec14545273b1e0e0c06c..36cbe10a62650091d53e361920c24dcc421b6400 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -9,18 +9,13 @@ class TestWorkflowCreation(SimpleTestCase): def test_can_create_workflow(self): name = 'single_stage' stage = StageFactory() - workflow = Workflow(name, stage) + workflow = Workflow(name, [stage]) self.assertEqual(workflow.name, name) self.assertCountEqual(workflow.stages, [stage]) - def test_stages_required_for_workflow(self): - name = 'single_stage' - with self.assertRaises(ValueError): - Workflow(name) - def test_can_iterate_through_workflow(self): stages = StageFactory.create_batch(2) - workflow = Workflow('two_stage', *stages) + workflow = Workflow('two_stage', stages) for stage, check in zip(workflow, stages): self.assertEqual(stage, check) @@ -32,6 +27,9 @@ class TestWorkflowCreation(SimpleTestCase): workflow = WorkflowFactory(num_stages=1) self.assertEqual(workflow.next(workflow.stages[0]), None) + def test_returns_next_stage(self): + workflow = WorkflowFactory(num_stages=2) + self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1]) class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index fe2f35860f10b176edb813caefd6eeec95c649d6..7232ef4472e9a16bcfc4878ebedfa7ac35ed53ca 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,13 +1,11 @@ -from typing import Iterator, Iterable, Union +from typing import Iterator, Iterable, Sequence, Union from django.forms import Form class Workflow(Iterable['Stage']): - def __init__(self, name: str, *stages: 'Stage') -> None: + def __init__(self, name: str, stages: Sequence['Stage']) -> None: self.name = name - if not stages: - raise ValueError('Stages must be supplied') self.stages = stages def __iter__(self) -> Iterator['Stage']: @@ -17,6 +15,13 @@ class Workflow(Iterable['Stage']): if not current_stage: return self.stages[0] + for i, stage in enumerate(self): + if stage == current_stage: + try: + return self.stages[i+1] + except IndexError: + pass + return None