Skip to content
Snippets Groups Projects
Commit db3b84f5 authored by Todd Dembrey's avatar Todd Dembrey
Browse files

Improve how phases are accessed in the workflow and fix the tests

parent dc732308
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,28 @@ import factory
from opentech.apply.workflow import Phase, Stage, Workflow
class ListSubFactory(factory.SubFactory):
def __init__(self, *args, count=0, **kwargs):
self.count = count
super().__init__(*args, **kwargs)
def evaluate(self, *args, **kwargs):
if isinstance(self.count, factory.declarations.BaseDeclaration):
self.evaluated_count = self.count.evaluate(*args, **kwargs)
else:
self.evaluated_count = self.count
return super().evaluate(*args, **kwargs)
def generate(self, step, params):
subfactory = self.get_factory()
force_sequence = step.sequence if self.FORCE_SEQUENCE else None
return [
step.recurse(subfactory, params, force_sequence=force_sequence)
for _ in range(self.evaluated_count)
]
class PhaseFactory(factory.Factory):
class Meta:
model = Phase
......@@ -21,7 +43,7 @@ class StageFactory(factory.Factory):
name = factory.Faker('word')
form = factory.LazyFunction(Form)
phases = factory.LazyAttribute(lambda o: [PhaseFactory() for _ in range(o.num_phases)])
phases = ListSubFactory(PhaseFactory, count=factory.SelfAttribute('num_phases'))
class WorkflowFactory(factory.Factory):
......@@ -33,4 +55,4 @@ class WorkflowFactory(factory.Factory):
num_stages = factory.Faker('random_int', min=1, max=3)
name = factory.Faker('word')
stages = factory.LazyAttribute(lambda o: [StageFactory() for _ in range(o.num_stages)])
stages = ListSubFactory(StageFactory, count=factory.SelfAttribute('num_stages'))
import itertools
from django.test import SimpleTestCase
from django.forms import Form
......@@ -14,23 +15,17 @@ class TestWorkflowCreation(SimpleTestCase):
self.assertEqual(workflow.name, name)
self.assertCountEqual(workflow.stages, [stage])
def test_can_iterate_through_workflow(self):
stages = StageFactory.create_batch(2)
workflow = Workflow('two_stage', stages)
for stage, check in zip(workflow, stages):
self.assertEqual(stage, check)
def test_returns_first_stage_if_no_arg(self):
workflow = WorkflowFactory(num_stages=1)
self.assertEqual(workflow.next(), workflow.stages[0])
def test_returns_first_phase_if_no_arg(self):
workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
self.assertEqual(workflow.next(), workflow.stages[0].phases[0])
def test_returns_none_if_no_next(self):
workflow = WorkflowFactory(num_stages=1)
self.assertEqual(workflow.next(workflow.stages[0]), None)
workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None)
def test_returns_next_stage(self):
workflow = WorkflowFactory(num_stages=2)
self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1])
def test_returns_next_phase(self):
workflow = WorkflowFactory(num_stages=2, stages__num_phases=1)
self.assertEqual(workflow.next(workflow.stages[0].phases[0]), workflow.stages[1].phases[0])
class TestStageCreation(SimpleTestCase):
......
......@@ -3,38 +3,55 @@ from typing import Iterator, Iterable, Sequence, Union
from django.forms import Form
class Workflow(Iterable['Stage']):
class Workflow:
def __init__(self, name: str, stages: Sequence['Stage']) -> None:
self.name = name
self.stages = stages
def __iter__(self) -> Iterator['Stage']:
yield from self.stages
def next(self, current_stage: Union['Stage', None]=None) -> Union['Stage', None]:
if not current_stage:
return self.stages[0]
for i, stage in enumerate(self):
if stage == current_stage:
try:
return self.stages[i + 1]
except IndexError:
pass
return None
class Stage:
def __init__(self, name: str, form: Form, phases: Sequence['Phase']) -> None:
self.mapping = {
str(phase): (i, j)
for i, stage in enumerate(stages)
for j, phase in enumerate(stage)
}
def current_index(self, phase: Union['Phase', str, None]):
if isinstance(phase, Phase):
phase = str(phase)
try:
return self.mapping[phase]
except KeyError:
return 0, -1
def next(self, current_phase: Union['Phase', str, None]=None) -> Union['Phase', None]:
stage_idx, phase_idx = self.current_index(current_phase)
try:
return self.stages[stage_idx].phases[phase_idx + 1]
except IndexError:
try:
return self.stages[stage_idx + 1].phases[0]
except IndexError:
return None
class Stage(Iterable['Phase']):
def __init__(self, name: str, form: Form, phases: Sequence['Phase'],
current_phase: Union['Phase', None]=None
) -> None:
self.name = name
self.form = form
for phase in phases:
phase.stage = self
self.phases = phases
def __iter__(self) -> Iterator['Phase']:
yield from self.phases
def __str__(self):
return self.name
class Phase:
def __init__(self, name: str) -> None:
self.name = name
def __str__(self):
return '__'.join([self.stage.name, self.name])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment