diff --git a/opentech/apply/tests/factories.py b/opentech/apply/tests/factories.py
index c98bd05c71ffb6461bbb4d681e4bf19e0b4afe67..406c735adb05f186553c013226f1595972454d74 100644
--- a/opentech/apply/tests/factories.py
+++ b/opentech/apply/tests/factories.py
@@ -4,6 +4,28 @@ import factory
 from opentech.apply.workflow import Phase, Stage, Workflow
 
 
+class ListSubFactory(factory.SubFactory):
+    def __init__(self, *args, count=0, **kwargs):
+        self.count = count
+        super().__init__(*args, **kwargs)
+
+    def evaluate(self, *args, **kwargs):
+        if isinstance(self.count, factory.declarations.BaseDeclaration):
+            self.evaluated_count = self.count.evaluate(*args, **kwargs)
+        else:
+            self.evaluated_count = self.count
+
+        return super().evaluate(*args, **kwargs)
+
+    def generate(self, step, params):
+        subfactory = self.get_factory()
+        force_sequence = step.sequence if self.FORCE_SEQUENCE else None
+        return [
+            step.recurse(subfactory, params, force_sequence=force_sequence)
+            for _ in range(self.evaluated_count)
+        ]
+
+
 class PhaseFactory(factory.Factory):
     class Meta:
         model = Phase
@@ -21,7 +43,7 @@ class StageFactory(factory.Factory):
 
     name = factory.Faker('word')
     form = factory.LazyFunction(Form)
-    phases = factory.LazyAttribute(lambda o: [PhaseFactory() for _ in range(o.num_phases)])
+    phases = ListSubFactory(PhaseFactory, count=factory.SelfAttribute('num_phases'))
 
 
 class WorkflowFactory(factory.Factory):
@@ -33,4 +55,4 @@ class WorkflowFactory(factory.Factory):
         num_stages = factory.Faker('random_int', min=1, max=3)
 
     name = factory.Faker('word')
-    stages = factory.LazyAttribute(lambda o: [StageFactory() for _ in range(o.num_stages)])
+    stages = ListSubFactory(StageFactory, count=factory.SelfAttribute('num_stages'))
diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py
index 213420444ec4cff22b87e566ca1ea46f38abb79c..5300cc4464ac9954f9a52a67cbad57eef40954f2 100644
--- a/opentech/apply/tests/test_workflow.py
+++ b/opentech/apply/tests/test_workflow.py
@@ -1,3 +1,4 @@
+import itertools
 from django.test import SimpleTestCase
 from django.forms import Form
 
@@ -14,23 +15,17 @@ class TestWorkflowCreation(SimpleTestCase):
         self.assertEqual(workflow.name, name)
         self.assertCountEqual(workflow.stages, [stage])
 
-    def test_can_iterate_through_workflow(self):
-        stages = StageFactory.create_batch(2)
-        workflow = Workflow('two_stage', stages)
-        for stage, check in zip(workflow, stages):
-            self.assertEqual(stage, check)
-
-    def test_returns_first_stage_if_no_arg(self):
-        workflow = WorkflowFactory(num_stages=1)
-        self.assertEqual(workflow.next(), workflow.stages[0])
+    def test_returns_first_phase_if_no_arg(self):
+        workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
+        self.assertEqual(workflow.next(), workflow.stages[0].phases[0])
 
     def test_returns_none_if_no_next(self):
-        workflow = WorkflowFactory(num_stages=1)
-        self.assertEqual(workflow.next(workflow.stages[0]), None)
+        workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
+        self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None)
 
-    def test_returns_next_stage(self):
-        workflow = WorkflowFactory(num_stages=2)
-        self.assertEqual(workflow.next(workflow.stages[0]), workflow.stages[1])
+    def test_returns_next_phase(self):
+        workflow = WorkflowFactory(num_stages=2, stages__num_phases=1)
+        self.assertEqual(workflow.next(workflow.stages[0].phases[0]), workflow.stages[1].phases[0])
 
 
 class TestStageCreation(SimpleTestCase):
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index 414978eb59b430dcc82ce704020180e2acaa8311..07aad2966c331df9f0e87ebd7633f7aec6cfd081 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -3,38 +3,55 @@ from typing import Iterator, Iterable, Sequence, Union
 from django.forms import Form
 
 
-class Workflow(Iterable['Stage']):
+class Workflow:
     def __init__(self, name: str, stages: Sequence['Stage']) -> None:
         self.name = name
         self.stages = stages
-
-    def __iter__(self) -> Iterator['Stage']:
-        yield from self.stages
-
-    def next(self, current_stage: Union['Stage', None]=None) -> Union['Stage', None]:
-        if not current_stage:
-            return self.stages[0]
-
-        for i, stage in enumerate(self):
-            if stage == current_stage:
-                try:
-                    return self.stages[i + 1]
-                except IndexError:
-                    pass
-
-        return None
-
-
-class Stage:
-    def __init__(self, name: str, form: Form, phases: Sequence['Phase']) -> None:
+        self.mapping = {
+            str(phase): (i, j)
+            for i, stage in enumerate(stages)
+            for j, phase in enumerate(stage)
+        }
+
+    def current_index(self, phase: Union['Phase', str, None]):
+        if isinstance(phase, Phase):
+            phase = str(phase)
+        try:
+            return self.mapping[phase]
+        except KeyError:
+            return 0, -1
+
+    def next(self, current_phase: Union['Phase', str, None]=None) -> Union['Phase', None]:
+        stage_idx, phase_idx = self.current_index(current_phase)
+        try:
+            return self.stages[stage_idx].phases[phase_idx + 1]
+        except IndexError:
+            try:
+                return self.stages[stage_idx + 1].phases[0]
+            except IndexError:
+                return None
+
+
+class Stage(Iterable['Phase']):
+    def __init__(self, name: str, form: Form, phases: Sequence['Phase'],
+                 current_phase: Union['Phase', None]=None
+    ) -> None:
         self.name = name
         self.form = form
+        for phase in phases:
+            phase.stage = self
         self.phases = phases
 
     def __iter__(self) -> Iterator['Phase']:
         yield from self.phases
 
+    def __str__(self):
+        return self.name
+
 
 class Phase:
     def __init__(self, name: str) -> None:
         self.name = name
+
+    def __str__(self):
+        return '__'.join([self.stage.name, self.name])