diff --git a/opentech/apply/funds/tests/factories/models.py b/opentech/apply/funds/tests/factories/models.py index 116b07094810dbda0b7bf52cc8cf1aff315337c1..6bd87ec6f614655c395f2463aec99e15db40b296 100644 --- a/opentech/apply/funds/tests/factories/models.py +++ b/opentech/apply/funds/tests/factories/models.py @@ -13,6 +13,7 @@ from opentech.apply.funds.models import ( Round, RoundForm, ) +from opentech.apply.users.tests.factories import UserFactory from . import blocks @@ -82,6 +83,7 @@ class RoundFactory(wagtail_factories.PageFactory): title = factory.Sequence('Round {}'.format) start_date = factory.LazyFunction(datetime.date.today) end_date = factory.LazyFunction(lambda: datetime.date.today() + datetime.timedelta(days=7)) + lead = factory.SubFactory(UserFactory, is_staff=True) class RoundFormFactory(AbstractRelatedFormFactory): diff --git a/opentech/apply/funds/tests/test_models.py b/opentech/apply/funds/tests/test_models.py index 86c173891275de97020ffb14b3543f80f4fb73e3..1dc83e941a21ec513f05dd7cd15ba503e0c8e0ee 100644 --- a/opentech/apply/funds/tests/test_models.py +++ b/opentech/apply/funds/tests/test_models.py @@ -147,6 +147,7 @@ class TestRoundModelWorkflowAndForms(TestCase): self.round = RoundFactory.build() self.round.parent_page = self.fund + self.round.lead = RoundFactory.lead.get_factory()(**RoundFactory.lead.defaults) self.fund.add_child(instance=self.round) @@ -235,7 +236,8 @@ class TestFormSubmission(TestCase): def test_can_submit_if_new(self): self.submit_form() - self.assertEqual(self.User.objects.count(), 1) + # Lead + applicant + self.assertEqual(self.User.objects.count(), 2) new_user = self.User.objects.get(email=self.email) self.assertEqual(new_user.get_full_name(), self.name) @@ -246,7 +248,8 @@ class TestFormSubmission(TestCase): self.submit_form() self.submit_form() - self.assertEqual(self.User.objects.count(), 1) + # Lead + applicant + self.assertEqual(self.User.objects.count(), 2) user = self.User.objects.get(email=self.email) self.assertEqual(ApplicationSubmission.objects.count(), 2) @@ -257,9 +260,10 @@ class TestFormSubmission(TestCase): # Someone else submits a form self.submit_form(email='another@email.com') - self.assertEqual(self.User.objects.count(), 2) + # Lead + 2 x applicant + self.assertEqual(self.User.objects.count(), 3) - first_user, second_user = self.User.objects.all() + _, first_user, second_user = self.User.objects.all() self.assertEqual(ApplicationSubmission.objects.count(), 2) self.assertEqual(ApplicationSubmission.objects.first().user, first_user) self.assertEqual(ApplicationSubmission.objects.last().user, second_user) @@ -267,11 +271,13 @@ class TestFormSubmission(TestCase): def test_associated_if_logged_in(self): user, _ = self.User.objects.get_or_create(email=self.email, defaults={'full_name': self.name}) - self.assertEqual(self.User.objects.count(), 1) + # Lead + Applicant + self.assertEqual(self.User.objects.count(), 2) self.submit_form(email=self.email, name=self.name, user=user) - self.assertEqual(self.User.objects.count(), 1) + # Lead + Applicant + self.assertEqual(self.User.objects.count(), 2) self.assertEqual(ApplicationSubmission.objects.count(), 1) self.assertEqual(ApplicationSubmission.objects.first().user, user) @@ -280,12 +286,14 @@ class TestFormSubmission(TestCase): def test_errors_if_blank_user_data_even_if_logged_in(self): user, _ = self.User.objects.get_or_create(email=self.email, defaults={'full_name': self.name}) - self.assertEqual(self.User.objects.count(), 1) + # Lead + applicant + self.assertEqual(self.User.objects.count(), 2) response = self.submit_form(email='', name='', user=user) self.assertContains(response, 'This field is required') - self.assertEqual(self.User.objects.count(), 1) + # Lead + applicant + self.assertEqual(self.User.objects.count(), 2) self.assertEqual(ApplicationSubmission.objects.count(), 0) diff --git a/opentech/apply/users/tests/factories.py b/opentech/apply/users/tests/factories.py new file mode 100644 index 0000000000000000000000000000000000000000..cadd5f4ec98b4dcd82b9cdc05be9e852e3b82677 --- /dev/null +++ b/opentech/apply/users/tests/factories.py @@ -0,0 +1,10 @@ +from django.contrib.auth import get_user_model + +import factory + + +class UserFactory(factory.DjangoModelFactory): + class Meta: + model = get_user_model() + + email = factory.Sequence('email{}@email.com'.format)