diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index 423b87a20ca71505938bff4216fa13fd4574f03b..b44549f67558da5da2a089d9772b3f729dec58d6 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -16,6 +16,13 @@ class TestWorkflowCreation(SimpleTestCase): with self.assertRaises(ValueError): Workflow(name) + def test_can_iterate_through_workflow(self): + stage1 = Stage('stage_one') + stage2 = Stage('stage_two') + workflow = Workflow('two_stage', stage1, stage2) + for stage, check in zip(workflow, [stage1, stage2]): + self.assertEqual(stage, check) + class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 173156685ba691355edba24787820d60f9f72d93..001ce395697976c1bc21619e11d9380835be9b4a 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,10 +1,16 @@ -class Workflow: - def __init__(self, name: str, *stages: Stage) -> None: +from typing import Iterator, Iterable + + +class Workflow(Iterable['Stage']): + def __init__(self, name: str, *stages: 'Stage') -> None: self.name = name if not stages: raise ValueError('Stages must be supplied') self.stages = stages + def __iter__(self) -> Iterator['Stage']: + yield from self.stages + class Stage: def __init__(self, name: str) -> None: