From fab0b6c8d13dd550b2b8c169abe5ffc76a7f32e2 Mon Sep 17 00:00:00 2001 From: Todd Dembrey <todd.dembrey@torchbox.com> Date: Wed, 13 Dec 2017 16:18:42 +0000 Subject: [PATCH] Provide a sequence of stages for the workflow --- opentech/apply/tests.py | 12 +++++------- opentech/apply/workflow.py | 13 +++++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py index fbbf0a099..36cbe10a6 100644 --- a/opentech/apply/tests.py +++ b/opentech/apply/tests.py @@ -9,18 +9,13 @@ class TestWorkflowCreation(SimpleTestCase): def test_can_create_workflow(self): name = 'single_stage' stage = StageFactory() - workflow = Workflow(name, stage) + workflow = Workflow(name, [stage]) self.assertEqual(workflow.name, name) self.assertCountEqual(workflow.stages, [stage]) - def test_stages_required_for_workflow(self): - name = 'single_stage' - with self.assertRaises(ValueError): - Workflow(name) - def test_can_iterate_through_workflow(self): stages = StageFactory.create_batch(2) - workflow = Workflow('two_stage', *stages) + workflow = Workflow('two_stage', stages) for stage, check in zip(workflow, stages): self.assertEqual(stage, check) @@ -32,6 +27,9 @@ class TestWorkflowCreation(SimpleTestCase): workflow = WorkflowFactory(num_stages=1) self.assertEqual(workflow.next(workflow.stages[0]), None) + def test_returns_next_stage(self): + workflow = WorkflowFactory(num_stages=2) + self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1]) class TestStageCreation(SimpleTestCase): def test_can_create_stage(self): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index fe2f35860..7232ef447 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,13 +1,11 @@ -from typing import Iterator, Iterable, Union +from typing import Iterator, Iterable, Sequence, Union from django.forms import Form class Workflow(Iterable['Stage']): - def __init__(self, name: str, *stages: 'Stage') -> None: + def __init__(self, name: str, stages: Sequence['Stage']) -> None: self.name = name - if not stages: - raise ValueError('Stages must be supplied') self.stages = stages def __iter__(self) -> Iterator['Stage']: @@ -17,6 +15,13 @@ class Workflow(Iterable['Stage']): if not current_stage: return self.stages[0] + for i, stage in enumerate(self): + if stage == current_stage: + try: + return self.stages[i+1] + except IndexError: + pass + return None -- GitLab