Skip to content
Snippets Groups Projects
Commit a711b60b authored by Todd Dembrey's avatar Todd Dembrey
Browse files

Update the manager to work with the oauth tests

parent 891bf1f8
No related branches found
No related tags found
No related merge requests found
......@@ -5,17 +5,51 @@ from django.utils.translation import gettext_lazy as _
from .utils import send_activation_email
def convert_full_name_to_parts(full_name):
def convert_full_name_to_parts(defaults):
full_name = defaults.pop('full_name', ' ')
first_name, last_name = full_name.split(' ', 1)
return first_name, last_name
if first_name:
defaults.update(first_name=first_name)
if last_name:
defaults.update(last_name=last_name)
return defaults
class UserManager(BaseUserManager):
use_in_migrations = True
def _create_user(self, email, password, **extra_fields):
"""
Creates and saves a User with the given username, email and password.
"""
if not email:
raise ValueError('The given email must be set')
email = self.normalize_email(email)
extra_fields = convert_full_name_to_parts(extra_fields)
user = self.model(email=email, **extra_fields)
user.set_password(password)
user.save(using=self._db)
return user
def create_user(self, email, password=None, **extra_fields):
extra_fields.setdefault('is_staff', False)
extra_fields.setdefault('is_superuser', False)
return self._create_user(email, password, **extra_fields)
def create_superuser(self, email, password, **extra_fields):
extra_fields.setdefault('is_staff', True)
extra_fields.setdefault('is_superuser', True)
if extra_fields.get('is_staff') is not True:
raise ValueError('Superuser must have is_staff=True.')
if extra_fields.get('is_superuser') is not True:
raise ValueError('Superuser must have is_superuser=True.')
return self._create_user(email, password, **extra_fields)
def get_or_create(self, defaults, **kwargs):
# Allow passing of 'full_name' but replace it with actual database fields
first_name, last_name = convert_full_name_to_parts(defaults.pop('full_name', ''))
defaults.update(first_name=first_name, last_name=last_name)
defaults = convert_full_name_to_parts(defaults)
return super().get_or_create(defaults=defaults, **kwargs)
def get_or_create_and_notify(self, defaults=dict(), **kwargs):
......
......@@ -5,6 +5,13 @@ from django.urls import reverse
class TestOAuthAccess(TestCase):
def login(self):
email = 'test@email.com'
password = 'password'
user = get_user_model().objects.create_user(email=email, password=password)
logged_in = self.client.login(email=email, password=password)
self.assertTrue(logged_in)
return user
def test_oauth_page_requires_login(self):
"""
......@@ -48,11 +55,3 @@ class TestOAuthAccess(TestCase):
self.assertNotContains(response, 'Disconnect Google OAuth')
self.assertTemplateUsed(response, 'users/oauth.html')
def login(self):
user = get_user_model().objects.create_user(username='test', email='test@email.com', password='password')
self.assertTrue(
self.client.login(username='test', password='password')
)
return user
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment