diff --git a/opentech/public/mailchimp/tests.py b/opentech/public/mailchimp/tests.py index e5ca132f3f77284af7b35d8e192d027e3dc2b9af..5dd089387611991a1f8642eb17883523614e799f 100644 --- a/opentech/public/mailchimp/tests.py +++ b/opentech/public/mailchimp/tests.py @@ -1,3 +1,4 @@ +from urllib import parse from unittest import mock import re @@ -17,6 +18,15 @@ class TestNewsletterView(TestCase): self.origin = 'https://testserver/' self.client.defaults = {'HTTP_ORIGIN': self.origin} + def assertNewsletterRedirects(self, response, target_url, *args, **kwargs): + url = response.redirect_chain[0][0] + parts = parse.urlsplit(url) + self.assertTrue(parts.query.startswith('newsletter-')) + + target_url = target_url + '?' + parts.query + + return self.assertRedirects(response, target_url, *args, **kwargs) + def test_redirected_home_if_get(self): response = self.client.get(self.url, secure=True, follow=True) request = response.request @@ -31,7 +41,7 @@ class TestNewsletterView(TestCase): responses.add(responses.POST, any_url, json={'id': '1234'}, status=200) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1) @@ -39,7 +49,7 @@ class TestNewsletterView(TestCase): def test_error_in_form(self): response = self.client.post(self.url, data={'email': 'email_is_bad.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1) @@ -61,7 +71,7 @@ class TestNewsletterView(TestCase): responses.add(responses.POST, any_url, json=response_data, status=400) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1)