diff --git a/opentech/apply/tests/factories.py b/opentech/apply/tests/factories.py index c98bd05c71ffb6461bbb4d681e4bf19e0b4afe67..406c735adb05f186553c013226f1595972454d74 100644 --- a/opentech/apply/tests/factories.py +++ b/opentech/apply/tests/factories.py @@ -4,6 +4,28 @@ import factory from opentech.apply.workflow import Phase, Stage, Workflow +class ListSubFactory(factory.SubFactory): + def __init__(self, *args, count=0, **kwargs): + self.count = count + super().__init__(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + if isinstance(self.count, factory.declarations.BaseDeclaration): + self.evaluated_count = self.count.evaluate(*args, **kwargs) + else: + self.evaluated_count = self.count + + return super().evaluate(*args, **kwargs) + + def generate(self, step, params): + subfactory = self.get_factory() + force_sequence = step.sequence if self.FORCE_SEQUENCE else None + return [ + step.recurse(subfactory, params, force_sequence=force_sequence) + for _ in range(self.evaluated_count) + ] + + class PhaseFactory(factory.Factory): class Meta: model = Phase @@ -21,7 +43,7 @@ class StageFactory(factory.Factory): name = factory.Faker('word') form = factory.LazyFunction(Form) - phases = factory.LazyAttribute(lambda o: [PhaseFactory() for _ in range(o.num_phases)]) + phases = ListSubFactory(PhaseFactory, count=factory.SelfAttribute('num_phases')) class WorkflowFactory(factory.Factory): @@ -33,4 +55,4 @@ class WorkflowFactory(factory.Factory): num_stages = factory.Faker('random_int', min=1, max=3) name = factory.Faker('word') - stages = factory.LazyAttribute(lambda o: [StageFactory() for _ in range(o.num_stages)]) + stages = ListSubFactory(StageFactory, count=factory.SelfAttribute('num_stages')) diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py index 213420444ec4cff22b87e566ca1ea46f38abb79c..5300cc4464ac9954f9a52a67cbad57eef40954f2 100644 --- a/opentech/apply/tests/test_workflow.py +++ b/opentech/apply/tests/test_workflow.py @@ -1,3 +1,4 @@ +import itertools from django.test import SimpleTestCase from django.forms import Form @@ -14,23 +15,17 @@ class TestWorkflowCreation(SimpleTestCase): self.assertEqual(workflow.name, name) self.assertCountEqual(workflow.stages, [stage]) - def test_can_iterate_through_workflow(self): - stages = StageFactory.create_batch(2) - workflow = Workflow('two_stage', stages) - for stage, check in zip(workflow, stages): - self.assertEqual(stage, check) - - def test_returns_first_stage_if_no_arg(self): - workflow = WorkflowFactory(num_stages=1) - self.assertEqual(workflow.next(), workflow.stages[0]) + def test_returns_first_phase_if_no_arg(self): + workflow = WorkflowFactory(num_stages=1, stages__num_phases=1) + self.assertEqual(workflow.next(), workflow.stages[0].phases[0]) def test_returns_none_if_no_next(self): - workflow = WorkflowFactory(num_stages=1) - self.assertEqual(workflow.next(workflow.stages[0]), None) + workflow = WorkflowFactory(num_stages=1, stages__num_phases=1) + self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None) - def test_returns_next_stage(self): - workflow = WorkflowFactory(num_stages=2) - self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1]) + def test_returns_next_phase(self): + workflow = WorkflowFactory(num_stages=2, stages__num_phases=1) + self.assertEqual(workflow.next(workflow.stages[0].phases[0]), workflow.stages[1].phases[0]) class TestStageCreation(SimpleTestCase): diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 414978eb59b430dcc82ce704020180e2acaa8311..07aad2966c331df9f0e87ebd7633f7aec6cfd081 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -3,38 +3,55 @@ from typing import Iterator, Iterable, Sequence, Union from django.forms import Form -class Workflow(Iterable['Stage']): +class Workflow: def __init__(self, name: str, stages: Sequence['Stage']) -> None: self.name = name self.stages = stages - - def __iter__(self) -> Iterator['Stage']: - yield from self.stages - - def next(self, current_stage: Union['Stage', None]=None) -> Union['Stage', None]: - if not current_stage: - return self.stages[0] - - for i, stage in enumerate(self): - if stage == current_stage: - try: - return self.stages[i + 1] - except IndexError: - pass - - return None - - -class Stage: - def __init__(self, name: str, form: Form, phases: Sequence['Phase']) -> None: + self.mapping = { + str(phase): (i, j) + for i, stage in enumerate(stages) + for j, phase in enumerate(stage) + } + + def current_index(self, phase: Union['Phase', str, None]): + if isinstance(phase, Phase): + phase = str(phase) + try: + return self.mapping[phase] + except KeyError: + return 0, -1 + + def next(self, current_phase: Union['Phase', str, None]=None) -> Union['Phase', None]: + stage_idx, phase_idx = self.current_index(current_phase) + try: + return self.stages[stage_idx].phases[phase_idx + 1] + except IndexError: + try: + return self.stages[stage_idx + 1].phases[0] + except IndexError: + return None + + +class Stage(Iterable['Phase']): + def __init__(self, name: str, form: Form, phases: Sequence['Phase'], + current_phase: Union['Phase', None]=None + ) -> None: self.name = name self.form = form + for phase in phases: + phase.stage = self self.phases = phases def __iter__(self) -> Iterator['Phase']: yield from self.phases + def __str__(self): + return self.name + class Phase: def __init__(self, name: str) -> None: self.name = name + + def __str__(self): + return '__'.join([self.stage.name, self.name])