diff --git a/opentech/public/mailchimp/tests.py b/opentech/public/mailchimp/tests.py index e5ca132f3f77284af7b35d8e192d027e3dc2b9af..5dd089387611991a1f8642eb17883523614e799f 100644 --- a/opentech/public/mailchimp/tests.py +++ b/opentech/public/mailchimp/tests.py @@ -1,3 +1,4 @@ +from urllib import parse from unittest import mock import re @@ -17,6 +18,15 @@ class TestNewsletterView(TestCase): self.origin = 'https://testserver/' self.client.defaults = {'HTTP_ORIGIN': self.origin} + def assertNewsletterRedirects(self, response, target_url, *args, **kwargs): + url = response.redirect_chain[0][0] + parts = parse.urlsplit(url) + self.assertTrue(parts.query.startswith('newsletter-')) + + target_url = target_url + '?' + parts.query + + return self.assertRedirects(response, target_url, *args, **kwargs) + def test_redirected_home_if_get(self): response = self.client.get(self.url, secure=True, follow=True) request = response.request @@ -31,7 +41,7 @@ class TestNewsletterView(TestCase): responses.add(responses.POST, any_url, json={'id': '1234'}, status=200) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1) @@ -39,7 +49,7 @@ class TestNewsletterView(TestCase): def test_error_in_form(self): response = self.client.post(self.url, data={'email': 'email_is_bad.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1) @@ -61,7 +71,7 @@ class TestNewsletterView(TestCase): responses.add(responses.POST, any_url, json=response_data, status=400) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) - self.assertRedirects(response, self.origin) + self.assertNewsletterRedirects(response, self.origin) messages = list(response.context['messages']) self.assertEqual(len(messages), 1) diff --git a/opentech/public/mailchimp/views.py b/opentech/public/mailchimp/views.py index ed0ede580472b056eed7383c59ded72c7577e5a6..0c1bdf80585333d81864c7e9ae10b4a5f28f7678 100644 --- a/opentech/public/mailchimp/views.py +++ b/opentech/public/mailchimp/views.py @@ -79,11 +79,7 @@ class MailchimpSubscribeView(FormMixin, RedirectView): def get_success_url(self): # Go back to where you came from, default to front page. - origin = '/' - if 'HTTP_ORIGIN' in self.request.META: - origin = self.request.META['HTTP_ORIGIN'] - elif 'HTTP_REFERER' in self.request.META: - origin = self.request.META['HTTP_REFERER'] + origin = self.request.META.get('HTTP_ORIGIN') or self.request.META.get('HTTP_REFERER') or '/' # Add cache busting query string. return origin + '?newsletter-' + uuid.uuid4().hex