From fee558be3331f17206a0456839cf716b70ceec77 Mon Sep 17 00:00:00 2001
From: Todd Dembrey <todd.dembrey@torchbox.com>
Date: Tue, 6 Feb 2018 12:19:31 +0000
Subject: [PATCH] Add basic tests for required fields

---
 addressfield/fields.py |  5 +++--
 addressfield/tests.py  | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)
 create mode 100644 addressfield/tests.py

diff --git a/addressfield/fields.py b/addressfield/fields.py
index 1420c2529..0f6015279 100644
--- a/addressfield/fields.py
+++ b/addressfield/fields.py
@@ -29,17 +29,18 @@ def flatten_data(data):
 
 class AddressField(forms.CharField):
     widget = AddressWidget
+    data = VALIDATION_DATA
 
     def clean(self, value, **kwargs):
         country = value['country']
         try:
-            country_data = VALIDATION_DATA[country]
+            country_data = self.data[country]
         except KeyError:
             raise ValidationError('Invalid country selected')
 
         fields = flatten_data(country_data['fields'])
 
-        missing_fields = set(country_data['required']) - set(value.keys())
+        missing_fields = set(country_data['required']) - set(field for field, value in value.items() if value)
         if missing_fields:
             missing_field_name = [fields[field]['label'] for field in missing_fields]
             raise ValidationError('Please provide data for: {}'.format(', '.join(missing_field_name)))
diff --git a/addressfield/tests.py b/addressfield/tests.py
new file mode 100644
index 000000000..adbf1dadc
--- /dev/null
+++ b/addressfield/tests.py
@@ -0,0 +1,40 @@
+from django.core.exceptions import ValidationError
+from django.test import TestCase
+
+from .fields import AddressField
+
+
+class TestRequiredFields(TestCase):
+    def build_validation_data(self, fields=list(), required=list()):
+        fields = set(fields + required)
+        return {'COUNTRY': {
+            'fields': [{field: {'label': field}} for field in fields],
+            'required': required,
+        }}
+
+    def test_non_required(self):
+        field = AddressField()
+        field.data = self.build_validation_data(fields=['postalcode'])
+        field.clean({'country': 'COUNTRY'})
+
+    def test_non_required_blank_data(self):
+        field = AddressField()
+        field.data = self.build_validation_data(fields=['postalcode'])
+        field.clean({'country': 'COUNTRY', 'postalcode': ''})
+
+    def test_one_field_required(self):
+        field = AddressField()
+        field.data = self.build_validation_data(required=['postalcode'])
+        with self.assertRaises(ValidationError):
+            field.clean({'country': 'COUNTRY'})
+
+    def test_one_field_required_blank_data(self):
+        field = AddressField()
+        field.data = self.build_validation_data(required=['postalcode'])
+        with self.assertRaises(ValidationError):
+            field.clean({'country': 'COUNTRY', 'postalcode': ''})
+
+    def test_one_field_required_supplied_data(self):
+        field = AddressField()
+        field.data = self.build_validation_data(required=['postalcode'])
+        field.clean({'country': 'COUNTRY', 'postalcode': 'BS1 2AB'})
-- 
GitLab