diff --git a/opentech/apply/factories.py b/opentech/apply/factories.py
index a36ce4560094c3a282cc98847d35263f9bfeef89..a20b3729a0e7c5f48c13d8e87d099e9b168fd066 100644
--- a/opentech/apply/factories.py
+++ b/opentech/apply/factories.py
@@ -11,10 +11,11 @@ class StageFactory(factory.Factory):
     name = factory.Faker('word')
     form = factory.LazyFunction(Form)
 
-.
+
 class WorkflowFactory(factory.Factory):
     class Meta:
         model = Workflow
+        inline_args = ('name', 'stages',)
 
     class Params:
         num_stages = factory.Faker('random_int', min=1, max=3)
diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py
index f34052fe655768d56a0da5c945233470ef674010..cc8a68758021da7c6838e696d8dfb915962d6ec0 100644
--- a/opentech/apply/tests.py
+++ b/opentech/apply/tests.py
@@ -24,8 +24,9 @@ class TestWorkflowCreation(SimpleTestCase):
         for stage, check in zip(workflow, stages):
             self.assertEqual(stage, check)
 
-    # def test_returns_none_if_no_next_stage(self):
-    #     workflow = Workflow('two_stage', stage1, stage2)
+    def test_returns_first_stage_if_no_arg(self):
+        workflow = WorkflowFactory(num_stages=1)
+        self.assertEqual(workflow.next(), workflow.stages[0])
 
 
 class TestStageCreation(SimpleTestCase):
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index d24e47187a5ee8fc2d6545d71b49f57928ef00ab..9b624227e4d40278992a9c015b46b26536a30786 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -13,6 +13,9 @@ class Workflow(Iterable['Stage']):
     def __iter__(self) -> Iterator['Stage']:
         yield from self.stages
 
+    def next(self):
+        return self.stages[0]
+
 
 class Stage:
     def __init__(self, name: str, form: Form) -> None: