Skip to content
Snippets Groups Projects
Commit 7c2baf02 authored by Todd Dembrey's avatar Todd Dembrey
Browse files

Update tests to handle the cache busting

parent b99b3100
No related branches found
No related tags found
No related merge requests found
from urllib import parse
from unittest import mock from unittest import mock
import re import re
...@@ -17,6 +18,15 @@ class TestNewsletterView(TestCase): ...@@ -17,6 +18,15 @@ class TestNewsletterView(TestCase):
self.origin = 'https://testserver/' self.origin = 'https://testserver/'
self.client.defaults = {'HTTP_ORIGIN': self.origin} self.client.defaults = {'HTTP_ORIGIN': self.origin}
def assertNewsletterRedirects(self, response, target_url, *args, **kwargs):
url = response.redirect_chain[0][0]
parts = parse.urlsplit(url)
self.assertTrue(parts.query.startswith('newsletter-'))
target_url = target_url + '?' + parts.query
return self.assertRedirects(response, target_url, *args, **kwargs)
def test_redirected_home_if_get(self): def test_redirected_home_if_get(self):
response = self.client.get(self.url, secure=True, follow=True) response = self.client.get(self.url, secure=True, follow=True)
request = response.request request = response.request
...@@ -31,7 +41,7 @@ class TestNewsletterView(TestCase): ...@@ -31,7 +41,7 @@ class TestNewsletterView(TestCase):
responses.add(responses.POST, any_url, json={'id': '1234'}, status=200) responses.add(responses.POST, any_url, json={'id': '1234'}, status=200)
response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True)
self.assertRedirects(response, self.origin) self.assertNewsletterRedirects(response, self.origin)
messages = list(response.context['messages']) messages = list(response.context['messages'])
self.assertEqual(len(messages), 1) self.assertEqual(len(messages), 1)
...@@ -39,7 +49,7 @@ class TestNewsletterView(TestCase): ...@@ -39,7 +49,7 @@ class TestNewsletterView(TestCase):
def test_error_in_form(self): def test_error_in_form(self):
response = self.client.post(self.url, data={'email': 'email_is_bad.com'}, secure=True, follow=True) response = self.client.post(self.url, data={'email': 'email_is_bad.com'}, secure=True, follow=True)
self.assertRedirects(response, self.origin) self.assertNewsletterRedirects(response, self.origin)
messages = list(response.context['messages']) messages = list(response.context['messages'])
self.assertEqual(len(messages), 1) self.assertEqual(len(messages), 1)
...@@ -61,7 +71,7 @@ class TestNewsletterView(TestCase): ...@@ -61,7 +71,7 @@ class TestNewsletterView(TestCase):
responses.add(responses.POST, any_url, json=response_data, status=400) responses.add(responses.POST, any_url, json=response_data, status=400)
response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True) response = self.client.post(self.url, data={'email': 'email@email.com'}, secure=True, follow=True)
self.assertRedirects(response, self.origin) self.assertNewsletterRedirects(response, self.origin)
messages = list(response.context['messages']) messages = list(response.context['messages'])
self.assertEqual(len(messages), 1) self.assertEqual(len(messages), 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment