diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py
index cc8a68758021da7c6838e696d8dfb915962d6ec0..fbbf0a09965d0ed7cee6ec14545273b1e0e0c06c 100644
--- a/opentech/apply/tests.py
+++ b/opentech/apply/tests.py
@@ -28,6 +28,10 @@ class TestWorkflowCreation(SimpleTestCase):
         workflow = WorkflowFactory(num_stages=1)
         self.assertEqual(workflow.next(), workflow.stages[0])
 
+    def test_returns_none_if_no_next(self):
+        workflow = WorkflowFactory(num_stages=1)
+        self.assertEqual(workflow.next(workflow.stages[0]), None)
+
 
 class TestStageCreation(SimpleTestCase):
     def test_can_create_stage(self):
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index 9b624227e4d40278992a9c015b46b26536a30786..fe2f35860f10b176edb813caefd6eeec95c649d6 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -1,4 +1,4 @@
-from typing import Iterator, Iterable
+from typing import Iterator, Iterable, Union
 
 from django.forms import Form
 
@@ -13,8 +13,11 @@ class Workflow(Iterable['Stage']):
     def __iter__(self) -> Iterator['Stage']:
         yield from self.stages
 
-    def next(self):
-        return self.stages[0]
+    def next(self, current_stage: Union['Stage', None]=None) -> Union['Stage', None]:
+        if not current_stage:
+            return self.stages[0]
+
+        return None
 
 
 class Stage: