diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py index d529959edb2e508153f9d32ce6f009098c794b40..f313d7173c23b887829634e22e4e1e31e327c40d 100644 --- a/opentech/apply/tests/test_workflow.py +++ b/opentech/apply/tests/test_workflow.py @@ -18,6 +18,11 @@ class TestWorkflowCreation(SimpleTestCase): workflow = WorkflowFactory(num_stages=1, stages__num_phases=1) self.assertEqual(workflow.next(), workflow.stages[0].phases[0]) + def test_can_get_the_current_phase(self): + workflow = WorkflowFactory(num_stages=1, stages__num_phases=2) + phase = workflow.stages[0].phases[0] + self.assertEqual(workflow.current(str(phase)), phase) + def test_returns_none_if_no_next(self): workflow = WorkflowFactory(num_stages=1, stages__num_phases=1) self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None) @@ -53,7 +58,13 @@ class TestPhaseCreation(SimpleTestCase): name = 'the_phase' phase = Phase(name, actions) self.assertEqual(phase.name, name) - self.assertEqual(phase.actions, actions) + self.assertEqual(phase.actions, {action.name: action for action in actions}) + + def test_can_get_action_from_phase(self): + actions = ActionFactory.create_batch(3) + action = actions[1] + phase = PhaseFactory(actions=actions) + self.assertEqual(phase[action.name], action) class TestActions(SimpleTestCase): @@ -61,3 +72,8 @@ class TestActions(SimpleTestCase): name = 'action stations' action = Action(name) self.assertEqual(action.name, name) + + def test_calling_processes_the_action(self): + action = ActionFactory() + with self.assertRaises(NotImplementedError): + action() diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 4ea3bf407e3f4fee8749a8268b534f210e9d07c0..100393b30803857fa6d0282c12b8658bb0f7779d 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -26,13 +26,19 @@ class Workflow: except KeyError: return 0, -1 + def __getitem__(self, value): + return self.stages[value[0]].phases[value[1]] + + def current(self, current_phase: str): + return self[self.current_index(current_phase)] + def next(self, current_phase: Union['Phase', str]=None) -> Union['Phase', None]: stage_idx, phase_idx = self.current_index(current_phase) try: - return self.stages[stage_idx].phases[phase_idx + 1] + return self[stage_idx, phase_idx + 1] except IndexError: try: - return self.stages[stage_idx + 1].phases[0] + return self[stage_idx + 1, 0] except IndexError: return None @@ -57,17 +63,27 @@ class Phase: def __init__(self, name: str, actions: Sequence['Action']) -> None: self.name = name self.stage: Union['Stage', None] = None - self.actions = actions - self.occurance = 0 + self.actions = {action.name: action for action in actions} + self.occurance: int = 0 def __str__(self): return '__'.join([self.stage.name, self.name, str(self.occurance)]) + def __getitem__(self, value): + return self.actions[value] + class Action: def __init__(self, name: str) -> None: self.name = name + def __call__(self, *args, **kwargs): + return self.process(*args, **kwargs) + + def process(self, *args, **kwargs): + # Use this to define the behaviour of the action + raise NotImplementedError + # --- OTF Workflow ---