From c7e5201b65437ed9d7bbe3e0e98bea341ea2db23 Mon Sep 17 00:00:00 2001
From: Todd Dembrey <todd.dembrey@torchbox.com>
Date: Fri, 15 Dec 2017 15:27:23 +0000
Subject: [PATCH] Allow the action to be callable

---
 opentech/apply/tests/test_workflow.py | 18 +++++++++++++++++-
 opentech/apply/workflow.py            | 24 ++++++++++++++++++++----
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py
index d529959ed..f313d7173 100644
--- a/opentech/apply/tests/test_workflow.py
+++ b/opentech/apply/tests/test_workflow.py
@@ -18,6 +18,11 @@ class TestWorkflowCreation(SimpleTestCase):
         workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
         self.assertEqual(workflow.next(), workflow.stages[0].phases[0])
 
+    def test_can_get_the_current_phase(self):
+        workflow = WorkflowFactory(num_stages=1, stages__num_phases=2)
+        phase = workflow.stages[0].phases[0]
+        self.assertEqual(workflow.current(str(phase)), phase)
+
     def test_returns_none_if_no_next(self):
         workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
         self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None)
@@ -53,7 +58,13 @@ class TestPhaseCreation(SimpleTestCase):
         name = 'the_phase'
         phase = Phase(name, actions)
         self.assertEqual(phase.name, name)
-        self.assertEqual(phase.actions, actions)
+        self.assertEqual(phase.actions, {action.name: action for action in actions})
+
+    def test_can_get_action_from_phase(self):
+        actions = ActionFactory.create_batch(3)
+        action = actions[1]
+        phase = PhaseFactory(actions=actions)
+        self.assertEqual(phase[action.name], action)
 
 
 class TestActions(SimpleTestCase):
@@ -61,3 +72,8 @@ class TestActions(SimpleTestCase):
         name = 'action stations'
         action = Action(name)
         self.assertEqual(action.name, name)
+
+    def test_calling_processes_the_action(self):
+        action = ActionFactory()
+        with self.assertRaises(NotImplementedError):
+            action()
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index 4ea3bf407..100393b30 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -26,13 +26,19 @@ class Workflow:
         except KeyError:
             return 0, -1
 
+    def __getitem__(self, value):
+        return self.stages[value[0]].phases[value[1]]
+
+    def current(self, current_phase: str):
+        return self[self.current_index(current_phase)]
+
     def next(self, current_phase: Union['Phase', str]=None) -> Union['Phase', None]:
         stage_idx, phase_idx = self.current_index(current_phase)
         try:
-            return self.stages[stage_idx].phases[phase_idx + 1]
+            return self[stage_idx, phase_idx + 1]
         except IndexError:
             try:
-                return self.stages[stage_idx + 1].phases[0]
+                return self[stage_idx + 1, 0]
             except IndexError:
                 return None
 
@@ -57,17 +63,27 @@ class Phase:
     def __init__(self, name: str, actions: Sequence['Action']) -> None:
         self.name = name
         self.stage: Union['Stage', None] = None
-        self.actions = actions
-        self.occurance = 0
+        self.actions = {action.name: action for action in actions}
+        self.occurance: int = 0
 
     def __str__(self):
         return '__'.join([self.stage.name, self.name, str(self.occurance)])
 
+    def __getitem__(self, value):
+        return self.actions[value]
+
 
 class Action:
     def __init__(self, name: str) -> None:
         self.name = name
 
+    def __call__(self, *args, **kwargs):
+        return self.process(*args, **kwargs)
+
+    def process(self, *args, **kwargs):
+        # Use this to define the behaviour of the action
+        raise NotImplementedError
+
 
 # --- OTF Workflow ---
 
-- 
GitLab