diff --git a/project/api/location/schemas.py b/project/api/location/schemas.py index 13d59f1..131bcc3 100644 --- a/project/api/location/schemas.py +++ b/project/api/location/schemas.py @@ -1,22 +1,50 @@ -from marshmallow import fields, validate +from marshmallow import validate, validates_schema, ValidationError from project.api import marshmallow from project.models import Location from project.api.fields import NumericStr +from project.api.schemas import ( + SQLAlchemyBaseSchema, + PostSchema, + PatchSchema, +) -class LocationIdSchema(marshmallow.SQLAlchemySchema): +class LocationModelSchema(SQLAlchemyBaseSchema): class Meta: model = Location -class LocationSchema(LocationIdSchema): +class LocationBaseSchemaMixin(object): street = marshmallow.auto_field() - postalCode = marshmallow.auto_field() + postalCode = marshmallow.auto_field(validate=validate.Length(max=10)) city = marshmallow.auto_field() state = marshmallow.auto_field() country = marshmallow.auto_field() - longitude = NumericStr() - latitude = NumericStr() + latitude = NumericStr( + validate=validate.Range(-90, 90, min_inclusive=False, max_inclusive=False), + metadata={"description": "Latitude between (-90, 90)"}, + allow_none=True, + ) + longitude = NumericStr( + validate=validate.Range(-180, 180, min_inclusive=False, max_inclusive=False), + metadata={"description": "Longitude between (-180, 180)"}, + allow_none=True, + ) + + @validates_schema + def validate_location(self, data, **kwargs): + lat_set = "latitude" in data and data["latitude"] is not None + lon_set = "longitude" in data and data["longitude"] is not None + + if lat_set and not lon_set: + raise ValidationError("If latitude is given, longitude is required.") + + if lon_set and not lat_set: + raise ValidationError("If longitude is given, latitude is required.") + + +class LocationSchema(LocationModelSchema, LocationBaseSchemaMixin): + pass class LocationDumpSchema(LocationSchema): @@ -27,39 +55,13 @@ class LocationSearchItemSchema(LocationSchema): pass -class LocationPostRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = Location - - street = fields.Str(validate=validate.Length(max=255), missing=None) - postalCode = fields.Str(validate=validate.Length(max=10), missing=None) - city = fields.Str(validate=validate.Length(max=255), missing=None) - state = fields.Str(validate=validate.Length(max=255), missing=None) - country = fields.Str(validate=validate.Length(max=255), missing=None) - longitude = NumericStr(validate=validate.Range(-180, 180), missing=None) - latitude = NumericStr(validate=validate.Range(-90, 90), missing=None) +class LocationPostRequestSchema( + PostSchema, LocationModelSchema, LocationBaseSchemaMixin +): + pass -class LocationPostRequestLoadSchema(LocationPostRequestSchema): - class Meta: - model = Location - load_instance = True - - -class LocationPatchRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = Location - - street = fields.Str(validate=validate.Length(max=255), allow_none=True) - postalCode = fields.Str(validate=validate.Length(max=10), allow_none=True) - city = fields.Str(validate=validate.Length(max=255), allow_none=True) - state = fields.Str(validate=validate.Length(max=255), allow_none=True) - country = fields.Str(validate=validate.Length(max=255), allow_none=True) - longitude = NumericStr(validate=validate.Range(-180, 180), allow_none=True) - latitude = NumericStr(validate=validate.Range(-90, 90), allow_none=True) - - -class LocationPatchRequestLoadSchema(LocationPatchRequestSchema): - class Meta: - model = Location - load_instance = True +class LocationPatchRequestSchema( + PatchSchema, LocationModelSchema, LocationBaseSchemaMixin +): + pass diff --git a/project/api/organization/resources.py b/project/api/organization/resources.py index 8b49202..9484ca7 100644 --- a/project/api/organization/resources.py +++ b/project/api/organization/resources.py @@ -20,7 +20,6 @@ from project.api.organizer.schemas import ( OrganizerListResponseSchema, OrganizerIdSchema, OrganizerPostRequestSchema, - OrganizerPostRequestLoadSchema, ) from project.api.event_reference.schemas import ( EventReferenceListRequestSchema, @@ -35,7 +34,6 @@ from project.api.place.schemas import ( PlaceListResponseSchema, PlaceIdSchema, PlacePostRequestSchema, - PlacePostRequestLoadSchema, ) from project.services.event import get_event_dates_query, get_events_query from project.services.event_search import EventSearchParams @@ -131,7 +129,9 @@ class OrganizationOrganizerListResource(BaseResource): admin_unit = get_admin_unit_for_manage_or_404(id) access_or_401(admin_unit, "organizer:create") - organizer = OrganizerPostRequestLoadSchema().load(kwargs, session=db.session) + organizer = OrganizerPostRequestSchema(load_instance=True).load( + kwargs, session=db.session + ) organizer.admin_unit_id = admin_unit.id db.session.add(organizer) db.session.commit() @@ -163,7 +163,9 @@ class OrganizationPlaceListResource(BaseResource): admin_unit = get_admin_unit_for_manage_or_404(id) access_or_401(admin_unit, "place:create") - place = PlacePostRequestLoadSchema().load(kwargs, session=db.session) + place = PlacePostRequestSchema(load_instance=True).load( + kwargs, session=db.session + ) place.admin_unit_id = admin_unit.id db.session.add(place) db.session.commit() diff --git a/project/api/organizer/resources.py b/project/api/organizer/resources.py index ccc0b59..61edd5f 100644 --- a/project/api/organizer/resources.py +++ b/project/api/organizer/resources.py @@ -5,9 +5,7 @@ from project.api.resources import BaseResource from project.api.organizer.schemas import ( OrganizerSchema, OrganizerPostRequestSchema, - OrganizerPostRequestLoadSchema, OrganizerPatchRequestSchema, - OrganizerPatchRequestLoadSchema, ) from project.models import EventOrganizer from project.oauth2 import require_oauth @@ -35,7 +33,7 @@ class OrganizerResource(BaseResource): organizer = EventOrganizer.query.get_or_404(id) access_or_401(organizer.adminunit, "organizer:update") - organizer = OrganizerPostRequestLoadSchema().load( + organizer = OrganizerPostRequestSchema(load_instance=True).load( kwargs, session=db.session, instance=organizer ) db.session.commit() @@ -55,7 +53,7 @@ class OrganizerResource(BaseResource): organizer = EventOrganizer.query.get_or_404(id) access_or_401(organizer.adminunit, "organizer:update") - organizer = OrganizerPatchRequestLoadSchema().load( + organizer = OrganizerPatchRequestSchema(load_instance=True).load( kwargs, session=db.session, instance=organizer ) db.session.commit() diff --git a/project/api/organizer/schemas.py b/project/api/organizer/schemas.py index 0058c95..bec2112 100644 --- a/project/api/organizer/schemas.py +++ b/project/api/organizer/schemas.py @@ -4,39 +4,49 @@ from project.models import EventOrganizer from project.api.location.schemas import ( LocationSchema, LocationPostRequestSchema, - LocationPostRequestLoadSchema, LocationPatchRequestSchema, - LocationPatchRequestLoadSchema, ) from project.api.image.schemas import ImageSchema from project.api.organization.schemas import OrganizationRefSchema -from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema +from project.api.schemas import ( + SQLAlchemyBaseSchema, + IdSchemaMixin, + TrackableSchemaMixin, + PostSchema, + PatchSchema, + PaginationRequestSchema, + PaginationResponseSchema, +) -class OrganizerIdSchema(marshmallow.SQLAlchemySchema): +class OrganizerModelSchema(SQLAlchemyBaseSchema): class Meta: model = EventOrganizer - id = marshmallow.auto_field() + +class OrganizerIdSchema(OrganizerModelSchema, IdSchemaMixin): + pass -class OrganizerBaseSchema(OrganizerIdSchema): - created_at = marshmallow.auto_field() - updated_at = marshmallow.auto_field() - name = marshmallow.auto_field() - url = marshmallow.auto_field() - email = marshmallow.auto_field() +class OrganizerBaseSchemaMixin(TrackableSchemaMixin): + name = marshmallow.auto_field( + required=True, validate=validate.Length(min=3, max=255) + ) + url = marshmallow.auto_field(validate=[validate.URL(), validate.Length(max=255)]) + email = marshmallow.auto_field( + validate=[validate.Email(), validate.Length(max=255)] + ) phone = marshmallow.auto_field() fax = marshmallow.auto_field() -class OrganizerSchema(OrganizerBaseSchema): +class OrganizerSchema(OrganizerIdSchema, OrganizerBaseSchemaMixin): location = fields.Nested(LocationSchema) logo = fields.Nested(ImageSchema) organization = fields.Nested(OrganizationRefSchema, attribute="adminunit") -class OrganizerDumpSchema(OrganizerBaseSchema): +class OrganizerDumpSchema(OrganizerIdSchema, OrganizerBaseSchemaMixin): location_id = fields.Int() logo_id = fields.Int() organization_id = fields.Int(attribute="admin_unit_id") @@ -58,48 +68,13 @@ class OrganizerListResponseSchema(PaginationResponseSchema): ) -class OrganizerPostRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = EventOrganizer - - name = fields.Str(required=True, validate=validate.Length(min=3, max=255)) - url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None) - email = fields.Str( - validate=[validate.Email(), validate.Length(max=255)], missing=None - ) - phone = fields.Str(validate=validate.Length(max=255), missing=None) - fax = fields.Str(validate=validate.Length(max=255), missing=None) - +class OrganizerPostRequestSchema( + PostSchema, OrganizerModelSchema, OrganizerBaseSchemaMixin +): location = fields.Nested(LocationPostRequestSchema, missing=None) -class OrganizerPostRequestLoadSchema(OrganizerPostRequestSchema): - class Meta: - model = EventOrganizer - load_instance = True - - location = fields.Nested(LocationPostRequestLoadSchema, missing=None) - - -class OrganizerPatchRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = EventOrganizer - - name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True) - url = fields.Str( - validate=[validate.URL(), validate.Length(max=255)], allow_none=True - ) - email = fields.Str( - validate=[validate.Email(), validate.Length(max=255)], allow_none=True - ) - phone = fields.Str(validate=validate.Length(max=255), allow_none=True) - fax = fields.Str(validate=validate.Length(max=255), allow_none=True) +class OrganizerPatchRequestSchema( + PatchSchema, OrganizerModelSchema, OrganizerBaseSchemaMixin +): location = fields.Nested(LocationPatchRequestSchema, allow_none=True) - - -class OrganizerPatchRequestLoadSchema(OrganizerPatchRequestSchema): - class Meta: - model = EventOrganizer - load_instance = True - - location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True) diff --git a/project/api/place/resources.py b/project/api/place/resources.py index 1573dbc..fd6f496 100644 --- a/project/api/place/resources.py +++ b/project/api/place/resources.py @@ -5,9 +5,7 @@ from project.api.resources import BaseResource from project.api.place.schemas import ( PlaceSchema, PlacePostRequestSchema, - PlacePostRequestLoadSchema, PlacePatchRequestSchema, - PlacePatchRequestLoadSchema, ) from project.models import EventPlace from project.oauth2 import require_oauth @@ -33,7 +31,7 @@ class PlaceResource(BaseResource): place = EventPlace.query.get_or_404(id) access_or_401(place.adminunit, "place:update") - place = PlacePostRequestLoadSchema().load( + place = PlacePostRequestSchema(load_instance=True).load( kwargs, session=db.session, instance=place ) db.session.commit() @@ -49,7 +47,7 @@ class PlaceResource(BaseResource): place = EventPlace.query.get_or_404(id) access_or_401(place.adminunit, "place:update") - place = PlacePatchRequestLoadSchema().load( + place = PlacePatchRequestSchema(load_instance=True).load( kwargs, session=db.session, instance=place ) db.session.commit() diff --git a/project/api/place/schemas.py b/project/api/place/schemas.py index c8dc4b4..eba519e 100644 --- a/project/api/place/schemas.py +++ b/project/api/place/schemas.py @@ -6,36 +6,44 @@ from project.api.location.schemas import ( LocationSchema, LocationSearchItemSchema, LocationPostRequestSchema, - LocationPostRequestLoadSchema, LocationPatchRequestSchema, - LocationPatchRequestLoadSchema, ) from project.api.organization.schemas import OrganizationRefSchema -from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema +from project.api.schemas import ( + SQLAlchemyBaseSchema, + IdSchemaMixin, + TrackableSchemaMixin, + PostSchema, + PatchSchema, + PaginationRequestSchema, + PaginationResponseSchema, +) -class PlaceIdSchema(marshmallow.SQLAlchemySchema): +class PlaceModelSchema(SQLAlchemyBaseSchema): class Meta: model = EventPlace - id = marshmallow.auto_field() + +class PlaceIdSchema(PlaceModelSchema, IdSchemaMixin): + pass -class PlaceBaseSchema(PlaceIdSchema): - created_at = marshmallow.auto_field() - updated_at = marshmallow.auto_field() - name = marshmallow.auto_field() - url = marshmallow.auto_field() +class PlaceBaseSchemaMixin(TrackableSchemaMixin): + name = marshmallow.auto_field( + required=True, validate=validate.Length(min=3, max=255) + ) + url = marshmallow.auto_field(validate=[validate.URL(), validate.Length(max=255)]) description = marshmallow.auto_field() -class PlaceSchema(PlaceBaseSchema): +class PlaceSchema(PlaceIdSchema, PlaceBaseSchemaMixin): location = fields.Nested(LocationSchema) photo = fields.Nested(ImageSchema) organization = fields.Nested(OrganizationRefSchema, attribute="adminunit") -class PlaceDumpSchema(PlaceBaseSchema): +class PlaceDumpSchema(PlaceIdSchema, PlaceBaseSchemaMixin): location_id = fields.Int() photo_id = fields.Int() organization_id = fields.Int(attribute="admin_unit_id") @@ -64,39 +72,9 @@ class PlaceListResponseSchema(PaginationResponseSchema): ) -class PlacePostRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = EventPlace - - name = fields.Str(required=True, validate=validate.Length(min=3, max=255)) - url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None) - description = fields.Str(missing=None) +class PlacePostRequestSchema(PostSchema, PlaceModelSchema, PlaceBaseSchemaMixin): location = fields.Nested(LocationPostRequestSchema, missing=None) -class PlacePostRequestLoadSchema(PlacePostRequestSchema): - class Meta: - model = EventPlace - load_instance = True - - location = fields.Nested(LocationPostRequestLoadSchema, missing=None) - - -class PlacePatchRequestSchema(marshmallow.SQLAlchemySchema): - class Meta: - model = EventPlace - - name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True) - url = fields.Str( - validate=[validate.URL(), validate.Length(max=255)], allow_none=True - ) - description = fields.Str(allow_none=True) +class PlacePatchRequestSchema(PatchSchema, PlaceModelSchema, PlaceBaseSchemaMixin): location = fields.Nested(LocationPatchRequestSchema, allow_none=True) - - -class PlacePatchRequestLoadSchema(PlacePatchRequestSchema): - class Meta: - model = EventPlace - load_instance = True - - location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True) diff --git a/project/api/schemas.py b/project/api/schemas.py index 12a4a2d..5ff3f7d 100644 --- a/project/api/schemas.py +++ b/project/api/schemas.py @@ -1,5 +1,37 @@ from project.api import marshmallow -from marshmallow import fields, validate +from marshmallow import fields, validate, missing + + +class SQLAlchemyBaseSchema(marshmallow.SQLAlchemySchema): + def __init__(self, *args, **kwargs): + load_instance = kwargs.pop("load_instance", False) + super().__init__(*args, **kwargs) + self.opts.load_instance = load_instance + + +class PostSchema(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for name, field in self._declared_fields.items(): + if not field.required: + field.missing = None + + +class PatchSchema(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for name, field in self._declared_fields.items(): + field.required = False + field.allow_none = True + + +class IdSchemaMixin(object): + id = marshmallow.auto_field(dump_only=True, default=missing) + + +class TrackableSchemaMixin(object): + created_at = marshmallow.auto_field(dump_only=True) + updated_at = marshmallow.auto_field(dump_only=True) class ErrorResponseSchema(marshmallow.Schema): diff --git a/tests/api/test_fields.py b/tests/api/test_fields.py index 8456657..560863c 100644 --- a/tests/api/test_fields.py +++ b/tests/api/test_fields.py @@ -1,3 +1,6 @@ +import pytest + + def test_numeric_str_serialize(client, seeder, utils): from project.api.location.schemas import LocationSchema from project.models import Location @@ -16,32 +19,35 @@ def test_numeric_str_serialize(client, seeder, utils): assert data["longitude"] == "10.4333312" -def test_numeric_str_deserialize(client, seeder, utils): - from project.api.location.schemas import LocationPostRequestLoadSchema - - data = { - "latitude": "51.9077888", - "longitude": "10.4333312", - } - - schema = LocationPostRequestLoadSchema() - location = schema.load(data) - - assert location.latitude == 51.9077888 - assert location.longitude == 10.4333312 - - -def test_numeric_str_deserialize_invalid(client, seeder, utils): - from project.api.location.schemas import LocationPostRequestLoadSchema - import pytest +@pytest.mark.parametrize( + "latitude, longitude, valid", + [ + ("51.9077888", "10.4333312", True), + ("-89.9", "0", True), + ("-90", "0", False), + ("0", "179.9", True), + ("0", "180", False), + ("0", None, False), + (None, "0", False), + ("Quatsch", "Quatsch", False), + ], +) +def test_numeric_str_deserialize(latitude, longitude, valid): + from project.api.location.schemas import LocationPostRequestSchema from marshmallow import ValidationError data = { - "latitude": "Quatsch", - "longitude": "Quatsch", + "latitude": latitude, + "longitude": longitude, } - schema = LocationPostRequestLoadSchema() + schema = LocationPostRequestSchema(load_instance=True) + + if valid: + location = schema.load(data) + assert location.latitude == float(latitude) + assert location.longitude == float(longitude) + return with pytest.raises(ValidationError): schema.load(data)