diff --git a/opentech/apply/tests/factories.py b/opentech/apply/tests/factories.py
index cbaf9722bc7a2051998e0c05ef6b3001a3f4678b..ccd012e5ed52c8e312af47d5ba872258df653dc5 100644
--- a/opentech/apply/tests/factories.py
+++ b/opentech/apply/tests/factories.py
@@ -42,6 +42,9 @@ class PhaseFactory(factory.Factory):
 
     name = factory.Faker('word')
     actions = ListSubFactory(ActionFactory, count=factory.SelfAttribute('num_actions'))
+    stage = factory.PostGeneration(
+            lambda obj, create, extracted, **kwargs: StageFactory.build(phases=[obj])
+    )
 
     @classmethod
     def _create(cls, model_class, *args, **kwargs):
diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py
index 2186455e059fe526c3cfa69479238f0a15f83af8..c239f7cc46211f82c9cc4bc1b05b5ffe3a9ff7a2 100644
--- a/opentech/apply/tests/test_workflow.py
+++ b/opentech/apply/tests/test_workflow.py
@@ -1,7 +1,14 @@
 from django.test import SimpleTestCase
 from django.forms import Form
 
-from opentech.apply.workflow import Action, Phase, Stage, Workflow
+from opentech.apply.workflow import (
+    Action,
+    ChangePhaseAction,
+    NextPhaseAction,
+    Phase,
+    Stage,
+    Workflow,
+)
 
 from .factories import ActionFactory, PhaseFactory, StageFactory, WorkflowFactory
 
@@ -100,3 +107,15 @@ class TestActions(SimpleTestCase):
         action = ActionFactory()
         with self.assertRaises(NotImplementedError):
             action.process('')
+
+
+class TestCustomActions(SimpleTestCase):
+    def test_next_phase_action_returns_none_if_no_next(self):
+        action = NextPhaseAction('the next!')
+        phase = PhaseFactory(actions=[action])
+        self.assertEqual(phase.process(action.name), None)
+
+    def test_next_phase_action_returns_next_phase(self):
+        action = NextPhaseAction('the next!')
+        stage = StageFactory.build(num_phases=2, phases__actions=[action])
+        self.assertEqual(stage.phases[0].process(action.name), stage.phases[1])