diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index cc8a68758021da7c6838e696d8dfb915962d6ec0..fbbf0a09965d0ed7cee6ec14545273b1e0e0c06c 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -28,6 +28,10 @@ class TestWorkflowCreation(SimpleTestCase): workflow = WorkflowFactory(num_stages=1) self.assertEqual(workflow.next(), workflow.stages[0]) + def test_returns_none_if_no_next(self): + workflow = WorkflowFactory(num_stages=1) + self.assertEqual(workflow.next(workflow.stages[0]), None) + class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 9b624227e4d40278992a9c015b46b26536a30786..fe2f35860f10b176edb813caefd6eeec95c649d6 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,4 +1,4 @@ -from typing import Iterator, Iterable +from typing import Iterator, Iterable, Union from django.forms import Form @@ -13,8 +13,11 @@ class Workflow(Iterable['Stage']): def __iter__(self) -> Iterator['Stage']: yield from self.stages - def next(self): - return self.stages[0] + def next(self, current_stage: Union['Stage', None]=None) -> Union['Stage', None]: + if not current_stage: + return self.stages[0] + + return None class Stage: