diff --git a/opentech/apply/tests/factories.py b/opentech/apply/tests/factories.py
index 51aceef9f7e32ab7c26fb2eb037ad8d72de0b41a..79fa5ed29e034c1c2639a52a0245e08cf20f661a 100644
--- a/opentech/apply/tests/factories.py
+++ b/opentech/apply/tests/factories.py
@@ -43,6 +43,11 @@ class PhaseFactory(factory.Factory):
     name = factory.Faker('word')
     actions = ListSubFactory(ActionFactory, count=factory.SelfAttribute('num_actions'))
 
+    @classmethod
+    def _create(cls, model_class, *args, **kwargs):
+        actions = kwargs.pop('actions')
+        new_class = type(model_class.__name__, (model_class,), {'actions': actions})
+        return new_class(*args, **kwargs)
 
 class StageFactory(factory.Factory):
     class Meta:
diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py
index bd424b30ad375ae98324efbaa0efa8e39a06f301..aab4f8d83c27c1e057e538ba471dc76a98c290c5 100644
--- a/opentech/apply/tests/test_workflow.py
+++ b/opentech/apply/tests/test_workflow.py
@@ -23,6 +23,10 @@ class TestWorkflowCreation(SimpleTestCase):
         phase = workflow.stages[0].phases[0]
         self.assertEqual(workflow.current(str(phase)), phase)
 
+    def test_returns_next_stage(self):
+        workflow = WorkflowFactory(num_stages=2, stages__num_phases=1)
+        self.assertEqual(workflow.next_stage(workflow.stages[0]), workflow.stages[1])
+
     def test_returns_none_if_no_next(self):
         workflow = WorkflowFactory(num_stages=1, stages__num_phases=1)
         self.assertEqual(workflow.next(workflow.stages[0].phases[0]), None)
@@ -46,19 +50,20 @@ class TestStageCreation(SimpleTestCase):
         self.assertEqual(stage.form, form)
         self.assertEqual(stage.phases, phases)
 
-    def test_can_iterate_through_phases(self):
-        stage = StageFactory()
-        for phase, check in zip(stage, stage.phases):  # type: ignore # spurious error
-            self.assertEqual(phase, check)
+    def test_can_get_next_phase(self):
+        stage = StageFactory(num_phases=2)
+        self.assertEqual(stage.next(stage.phases[0]), stage.phases[1])
+
+    def test_get_none_if_no_next_phase(self):
+        stage = StageFactory(num_phases=1)
+        self.assertEqual(stage.next(stage.phases[0]), None)
 
 
 class TestPhaseCreation(SimpleTestCase):
     def test_can_create_phase(self):
-        actions = ActionFactory.create_batch(2)
         name = 'the_phase'
-        phase = Phase(name, actions)
+        phase = Phase(name)
         self.assertEqual(phase.name, name)
-        self.assertEqual(phase.actions, [action.name for action in actions])
 
     def test_can_get_action_from_phase(self):
         actions = ActionFactory.create_batch(3)
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index ff74e5bf08722125f340bf84b4c530b01e699a38..12f4777754c0b74a6ed4f72b0c048ed917f71999 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -7,43 +7,65 @@ class Workflow:
     def __init__(self, name: str, stages: Sequence['Stage']) -> None:
         self.name = name
         self.stages = stages
-        self.mapping = self.build_mapping(stages)
-
-    def build_mapping(self, stages: Sequence['Stage']) -> Dict[str, Tuple[int, int]]:
-        mapping: Dict[str, Tuple[int, int]] = {}
-        for i, stage in enumerate(stages):
-            for j, phase in enumerate(stage):
-                while str(phase) in mapping:
-                    phase.occurance += 1
-                mapping[str(phase)] = (i, j)
-        return mapping
-
-    def current_index(self, phase: Union['Phase', str, None]):
-        if isinstance(phase, Phase):
-            phase = str(phase)
-        try:
-            return self.mapping[phase]
-        except KeyError:
-            return 0, -1
 
-    def __getitem__(self, value):
-        return self.stages[value[0]].phases[value[1]]
+    def current(self, current_phase: Union[str, 'Phase']) -> Union['Phase', None]:
+        if isinstance(current_phase, Phase):
+            return current_phase
+
+        if not current_phase:
+            return self.first()
+
+        stage_name, phase_name, _ = current_phase.split('__')
+        for stage in self.stages:
+            if stage == stage_name:
+                return stage.current(phase_name)
+        return None
+
+    def first(self):
+        return self.stages[0].next()
+
+    def process(self, current_phase: str, action: str) -> Union['Phase', None]:
+        phase = self.current(current_phase)
+        new_phase = phase.process(action)
+        if not new_phase:
+            new_stage = self.next_stage(phase.stage)
+            return new_stage.first()
+        return new_phase
+
+    def next_stage(self, current_stage: 'Stage') -> 'Stage':
+        for i, stage in enumerate(self.stages):
+            if stage == current_stage:
+                try:
+                    return self.stages[i+1]
+                except IndexError:
+                    pass
+
+        return None
+
+    def next(self, current_phase: Union[str, 'Phase']=None) -> Union['Phase', None]:
+        if not current_phase:
+            return self.first()
 
-    def current(self, current_phase: str):
-        return self[self.current_index(current_phase)]
+        phase = self.current(current_phase)
 
-    def next(self, current_phase: Union['Phase', str]=None) -> Union['Phase', None]:
-        stage_idx, phase_idx = self.current_index(current_phase)
-        try:
-            return self[stage_idx, phase_idx + 1]
-        except IndexError:
-            try:
-                return self[stage_idx + 1, 0]
-            except IndexError:
-                return None
+        for stage in self.stages:
+            if stage == phase.stage:
+                next_phase = stage.next(phase)
+                if not next_phase:
+                    continue
+                return next_phase
 
+        next_stage = self.next_stage(phase.stage)
+        if next_stage:
+            return stage.next()
+        return None
 
-class Stage(Iterable['Phase']):
+
+    def __str__(self):
+        return self.name
+
+
+class Stage:
     def __init__(self, name: str, form: Form, phases: Sequence['Phase'],
                  current_phase: Union['Phase', None]=None) -> None:
         self.name = name
@@ -52,22 +74,51 @@ class Stage(Iterable['Phase']):
             phase.stage = self
         self.phases = phases
 
-    def __iter__(self) -> Iterator['Phase']:
-        yield from self.phases
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.name == other
+        return super().__eq__(other)
 
     def __str__(self):
         return self.name
 
+    def current(self, phase_name: str) -> 'Phase':
+        for phase in self.phases:
+            if phase == phase_name:
+                return phase
+        return None
+
+    def first(self) -> 'Phase':
+        return self.phases[0]
+
+    def next(self, current_phase: 'Phase'=None) -> 'Phase':
+        if not current_phase:
+            return self.first()
+
+        for i, phase in enumerate(self.phases):
+            if phase == current_phase:
+                try:
+                    return self.phases[i+1]
+                except IndexError:
+                    pass
+        return None
 
 class Phase:
-    def __init__(self, name: str, actions: Sequence['Action']) -> None:
+    actions: Sequence['Action'] = list()
+
+    def __init__(self, name: str) -> None:
         self.name = name
         self.stage: Union['Stage', None] = None
-        self._actions = {action.name: action for action in actions}
+        self._actions = {action.name: action for action in self.actions}
         self.occurance: int = 0
 
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.name == other
+        return super().__eq__(other)
+
     @property
-    def actions(self):
+    def action_names(self):
         return list(self._actions.keys())
 
     def __str__(self):
@@ -76,6 +127,9 @@ class Phase:
     def __getitem__(self, value):
         return self._actions[value]
 
+    def process(self, action):
+        return self[action]()
+
 
 class Action:
     def __init__(self, name: str) -> None:
@@ -84,35 +138,61 @@ class Action:
     def __call__(self, *args, **kwargs):
         return self.process(*args, **kwargs)
 
-    def process(self, *args, **kwargs):
+    def process(self, *args, **kwargs) -> 'Phase':
         # Use this to define the behaviour of the action
         raise NotImplementedError
 
 
 # --- OTF Workflow ---
 
+class ChangePhaseAction(Action):
+    def __init__(self, phase, *args, **kwargs):
+        self.target_phase = phase
+        super().__init__(*args, **kwargs)
+
+    def process(self):
+        if isinstance(self.target_phase, str):
+            phase = globals()[self.target_phase]
+        else:
+            phase = self.target_phase
+        return phase
+
+
+reject_action = ChangePhaseAction('rejected', 'Reject')
+
+accept_action = ChangePhaseAction('accepted', 'Accept')
+
+progress_external = ChangePhaseAction('external_review', 'Progress')
+
+progress_stage = ChangePhaseAction(None, 'Progress Stage')
+
+
 class ReviewPhase(Phase):
-    pass
+    actions = [progress_stage, reject_action]
+
+
+class ProposalReviewPhase(Phase):
+    actions = [progress_external, reject_action]
 
 
-next_phase = Action('next')
+class FinalReviewPhase(Phase):
+    actions = [accept_action, reject_action]
 
-internal_review = ReviewPhase('Under Review', [next_phase])
 
-ac_review = ReviewPhase('Under Review', [next_phase])
+concept_review = ReviewPhase('Internal Review')
 
-response = Phase('Ready to Respond', [next_phase])
+proposal_review = ProposalReviewPhase('Internal Review')
 
-rejected = Phase('Rejected', [])
+external_review = FinalReviewPhase('AC Review')
 
-accepted = Phase('Accepted', [next_phase])
+rejected = Phase('Rejected')
 
-progress = Phase('Progress', [next_phase])
+accepted = Phase('Accepted')
 
-standard_stage = Stage('Standard', Form(), [internal_review, response, ac_review, response, accepted, rejected])
+concept_note = Stage('Concept', Form(), [concept_review, accepted, rejected])
 
-first_stage = Stage('Standard', Form(), [internal_review, response, progress, rejected])
+proposal = Stage('Proposal', Form(), [proposal_review, external_review, accepted, rejected])
 
-single_stage = Workflow('Single Stage', [standard_stage])
+single_stage = Workflow('Single Stage', [proposal])
 
-two_stage = Workflow('Two Stage', [first_stage, standard_stage])
+two_stage = Workflow('Two Stage', [concept_note, proposal])