From fab0b6c8d13dd550b2b8c169abe5ffc76a7f32e2 Mon Sep 17 00:00:00 2001
From: Todd Dembrey <todd.dembrey@torchbox.com>
Date: Wed, 13 Dec 2017 16:18:42 +0000
Subject: [PATCH] Provide a sequence of stages for the workflow

---
 opentech/apply/tests.py    | 12 +++++-------
 opentech/apply/workflow.py | 13 +++++++++----
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py
index fbbf0a099..36cbe10a6 100644
--- a/opentech/apply/tests.py
+++ b/opentech/apply/tests.py
@@ -9,18 +9,13 @@ class TestWorkflowCreation(SimpleTestCase):
     def test_can_create_workflow(self):
         name = 'single_stage'
         stage = StageFactory()
-        workflow = Workflow(name, stage)
+        workflow = Workflow(name, [stage])
         self.assertEqual(workflow.name, name)
         self.assertCountEqual(workflow.stages, [stage])
 
-    def test_stages_required_for_workflow(self):
-        name = 'single_stage'
-        with self.assertRaises(ValueError):
-            Workflow(name)
-
     def test_can_iterate_through_workflow(self):
         stages = StageFactory.create_batch(2)
-        workflow = Workflow('two_stage', *stages)
+        workflow = Workflow('two_stage', stages)
         for stage, check in zip(workflow, stages):
             self.assertEqual(stage, check)
 
@@ -32,6 +27,9 @@ class TestWorkflowCreation(SimpleTestCase):
         workflow = WorkflowFactory(num_stages=1)
         self.assertEqual(workflow.next(workflow.stages[0]), None)
 
+    def test_returns_next_stage(self):
+        workflow = WorkflowFactory(num_stages=2)
+        self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1])
 
 class TestStageCreation(SimpleTestCase):
     def test_can_create_stage(self):
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index fe2f35860..7232ef447 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -1,13 +1,11 @@
-from typing import Iterator, Iterable, Union
+from typing import Iterator, Iterable, Sequence, Union
 
 from django.forms import Form
 
 
 class Workflow(Iterable['Stage']):
-    def __init__(self, name: str, *stages: 'Stage') -> None:
+    def __init__(self, name: str, stages: Sequence['Stage']) -> None:
         self.name = name
-        if not stages:
-            raise ValueError('Stages must be supplied')
         self.stages = stages
 
     def __iter__(self) -> Iterator['Stage']:
@@ -17,6 +15,13 @@ class Workflow(Iterable['Stage']):
         if not current_stage:
             return self.stages[0]
 
+        for i, stage in enumerate(self):
+            if stage == current_stage:
+                try:
+                    return self.stages[i+1]
+                except IndexError:
+                    pass
+
         return None
 
 
-- 
GitLab