diff --git a/opentech/apply/tests/test_workflow.py b/opentech/apply/tests/test_workflow.py
index 54b5caa6ec87a23df9364b4f33e59f7e005d3f7d..2186455e059fe526c3cfa69479238f0a15f83af8 100644
--- a/opentech/apply/tests/test_workflow.py
+++ b/opentech/apply/tests/test_workflow.py
@@ -9,6 +9,7 @@ from .factories import ActionFactory, PhaseFactory, StageFactory, WorkflowFactor
 class TestWorkflowCreation(SimpleTestCase):
     def test_can_create_workflow(self):
         stage = StageFactory()
+
         class NewWorkflow(Workflow):
             name = 'single_stage'
             stage_classes = [stage]
diff --git a/opentech/apply/views.py b/opentech/apply/views.py
index f5ec9072c54d46f6265064a329f64e0af568dfd7..9e4e9a3189efbfcc4323329d12df0c78269e0e5e 100644
--- a/opentech/apply/views.py
+++ b/opentech/apply/views.py
@@ -1,5 +1,4 @@
 from django import forms
-from django.shortcuts import render
 from django.template.response import TemplateResponse
 
 from .workflow import SingleStage, DoubleStage
@@ -18,7 +17,7 @@ class BasicSubmissionForm(forms.Form):
 
 def demo_workflow(request, wf_id):
     wf = int(wf_id)
-    workflow_class = workflows[wf-1]
+    workflow_class = workflows[wf - 1]
     workflow = workflow_class([BasicSubmissionForm] * wf)
 
     current_phase = request.POST.get('current')
@@ -59,5 +58,3 @@ def demo_workflow(request, wf_id):
         'form': form,
     }
     return TemplateResponse(request, 'apply/demo_workflow.html', context)
-
-
diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py
index 73655be0feab36607eb89e9279d5cd4163fe0a89..ff774fc820c1e7fd550f9d221d6b8059eeed0ef7 100644
--- a/opentech/apply/workflow.py
+++ b/opentech/apply/workflow.py
@@ -45,7 +45,7 @@ class Workflow:
         for i, stage in enumerate(self.stages):
             if stage == current_stage:
                 try:
-                    return self.stages[i+1]
+                    return self.stages[i + 1]
                 except IndexError:
                     pass
 
@@ -69,7 +69,6 @@ class Workflow:
             return stage.next()
         return None
 
-
     def __str__(self) -> str:
         return self.name
 
@@ -112,17 +111,18 @@ class Stage:
         for i, phase in enumerate(self.phases):
             if phase == current_phase:
                 try:
-                    return self.phases[i+1]
+                    return self.phases[i + 1]
                 except IndexError:
                     pass
         return None
 
+
 class Phase:
     actions: Sequence['Action'] = list()
     name: str = ''
     public_name: str = ''
 
-    def __init__(self, name: str='', public_name:str = '') -> None:
+    def __init__(self, name: str='', public_name: str ='') -> None:
         if name:
             self.name = name