diff --git a/opentech/apply/workflow.py b/opentech/apply/workflow.py index 12f4777754c0b74a6ed4f72b0c048ed917f71999..633ce7707e03f984091e6ec59491bb148515bae0 100644 --- a/opentech/apply/workflow.py +++ b/opentech/apply/workflow.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterator, Iterable, Sequence, Tuple, Union +from typing import Dict, Iterator, Iterable, List, Sequence, Tuple, Union from django.forms import Form @@ -17,11 +17,11 @@ class Workflow: stage_name, phase_name, _ = current_phase.split('__') for stage in self.stages: - if stage == stage_name: + if stage.name == stage_name: return stage.current(phase_name) return None - def first(self): + def first(self) -> 'Phase': return self.stages[0].next() def process(self, current_phase: str, action: str) -> Union['Phase', None]: @@ -61,7 +61,7 @@ class Workflow: return None - def __str__(self): + def __str__(self) -> str: return self.name @@ -74,17 +74,12 @@ class Stage: phase.stage = self self.phases = phases - def __eq__(self, other): - if isinstance(other, str): - return self.name == other - return super().__eq__(other) - - def __str__(self): + def __str__(self) -> str: return self.name def current(self, phase_name: str) -> 'Phase': for phase in self.phases: - if phase == phase_name: + if phase.name == phase_name: return phase return None @@ -112,22 +107,17 @@ class Phase: self._actions = {action.name: action for action in self.actions} self.occurance: int = 0 - def __eq__(self, other): - if isinstance(other, str): - return self.name == other - return super().__eq__(other) - @property - def action_names(self): + def action_names(self) -> List[str]: return list(self._actions.keys()) - def __str__(self): + def __str__(self) -> str: return '__'.join([self.stage.name, self.name, str(self.occurance)]) - def __getitem__(self, value): + def __getitem__(self, value: str) -> 'Action': return self._actions[value] - def process(self, action): + def process(self, action: str) -> Union['Phase', None]: return self[action]() @@ -135,10 +125,10 @@ class Action: def __init__(self, name: str) -> None: self.name = name - def __call__(self, *args, **kwargs): - return self.process(*args, **kwargs) + def __call__(self) -> Union['Phase', None]: + return self.process() - def process(self, *args, **kwargs) -> 'Phase': + def process(self) -> Union['Phase', None]: # Use this to define the behaviour of the action raise NotImplementedError @@ -146,11 +136,11 @@ class Action: # --- OTF Workflow --- class ChangePhaseAction(Action): - def __init__(self, phase, *args, **kwargs): + def __init__(self, phase: Union['Phase', str], *args: str, **kwargs: str) -> None: self.target_phase = phase super().__init__(*args, **kwargs) - def process(self): + def process(self) -> Union['Phase', None]: if isinstance(self.target_phase, str): phase = globals()[self.target_phase] else: diff --git a/setup.cfg b/setup.cfg index c6e0c156292251fb1e830d6de120b0fb0ad45a29..4ab5cc3deff7bec1cd567a0cbb99204f57f8bd0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,9 @@ ignore_errors = True check_untyped_defs = True ignore_errors = False +[mypy-opentech.apply.workflow*] +disallow_untyped_defs = True + [flake8] ignore=E501,F405 exclude=*/migrations/*