From f8cd3d823d3d596a6e8b6ba5808f7b26033f1261 Mon Sep 17 00:00:00 2001 From: Todd Dembrey <todd.dembrey@torchbox.com> Date: Tue, 12 Dec 2017 12:04:09 +0000 Subject: [PATCH] Make the workflow iterable --- opentech/apply/tests.py | 7 +++++++ opentech/apply/workflow.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index 423b87a20..b44549f67 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -16,6 +16,13 @@ class TestWorkflowCreation(SimpleTestCase): with self.assertRaises(ValueError): Workflow(name) + def test_can_iterate_through_workflow(self): + stage1 = Stage('stage_one') + stage2 = Stage('stage_two') + workflow = Workflow('two_stage', stage1, stage2) + for stage, check in zip(workflow, [stage1, stage2]): + self.assertEqual(stage, check) + class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 173156685..001ce3956 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,10 +1,16 @@ -class Workflow: - def __init__(self, name: str, *stages: Stage) -> None: +from typing import Iterator, Iterable + + +class Workflow(Iterable['Stage']): + def __init__(self, name: str, *stages: 'Stage') -> None: self.name = name if not stages: raise ValueError('Stages must be supplied') self.stages = stages + def __iter__(self) -> Iterator['Stage']: + yield from self.stages + class Stage: def __init__(self, name: str) -> None: -- GitLab