diff --git a/opentech/apply/tests.py b/opentech/apply/tests.py
index b44549f67558da5da2a089d9772b3f729dec58d6..3435937db3b0e5018a646e9629613c5f3a934ebb 100644
--- a/opentech/apply/tests.py
+++ b/opentech/apply/tests.py
@@ -1,4 +1,5 @@
 from django.test import SimpleTestCase
+from django.forms import Form
 
 from .workflow import Workflow, Stage
 
@@ -6,7 +7,7 @@ from .workflow import Workflow, Stage
 class TestWorkflowCreation(SimpleTestCase):
     def test_can_create_workflow(self):
         name = 'single_stage'
-        stage = Stage('stage_name')
+        stage = Stage('stage_name', Form())
         workflow = Workflow(name, stage)
         self.assertEqual(workflow.name, name)
         self.assertCountEqual(workflow.stages, [stage])
@@ -17,8 +18,8 @@ class TestWorkflowCreation(SimpleTestCase):
             Workflow(name)
 
     def test_can_iterate_through_workflow(self):
-        stage1 = Stage('stage_one')
-        stage2 = Stage('stage_two')
+        stage1 = Stage('stage_one', Form())
+        stage2 = Stage('stage_two', Form())
         workflow = Workflow('two_stage', stage1, stage2)
         for stage, check in zip(workflow, [stage1, stage2]):
             self.assertEqual(stage, check)
@@ -27,5 +28,7 @@ class TestWorkflowCreation(SimpleTestCase):
 class TestStageCreation(SimpleTestCase):
     def test_can_create_stage(self):
         name = 'the_stage'
-        stage = Stage(name)
+        form = Form()
+        stage = Stage(name, form)
         self.assertEqual(stage.name, name)
+        self.assertEqual(stage.form, form)
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index 001ce395697976c1bc21619e11d9380835be9b4a..d24e47187a5ee8fc2d6545d71b49f57928ef00ab 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -1,5 +1,7 @@
 from typing import Iterator, Iterable
 
+from django.forms import Form
+
 
 class Workflow(Iterable['Stage']):
     def __init__(self, name: str, *stages: 'Stage') -> None:
@@ -13,5 +15,6 @@ class Workflow(Iterable['Stage']):
 
 
 class Stage:
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, form: Form) -> None:
         self.name = name
+        self.form = form