diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index ce2f455..0000000 --- a/.coveragerc +++ /dev/null @@ -1,2 +0,0 @@ -[run] -relative_files = True \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index bf81191..eb199a5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,7 +5,7 @@ "python.linting.pylintEnabled": false, "python.linting.flake8Enabled": true, "python.testing.pytestArgs": [ - "tests" + "tests", "--capture=sys" ], "python.testing.unittestEnabled": false, "python.testing.nosetestsEnabled": false, diff --git a/migrations/versions/ddb85cb1c21e_.py b/migrations/versions/ddb85cb1c21e_.py new file mode 100644 index 0000000..3aa09b9 --- /dev/null +++ b/migrations/versions/ddb85cb1c21e_.py @@ -0,0 +1,91 @@ +"""empty message + +Revision ID: ddb85cb1c21e +Revises: b1a6e7630185 +Create Date: 2021-02-02 15:17:21.988363 + +""" +from alembic import op +import sqlalchemy as sa +import sqlalchemy_utils +from project import dbtypes + + +# revision identifiers, used by Alembic. +revision = "ddb85cb1c21e" +down_revision = "b1a6e7630185" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "oauth2_client", + sa.Column("client_id", sa.String(length=48), nullable=True), + sa.Column("client_secret", sa.String(length=120), nullable=True), + sa.Column("client_id_issued_at", sa.Integer(), nullable=False), + sa.Column("client_secret_expires_at", sa.Integer(), nullable=False), + sa.Column("client_metadata", sa.Text(), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_oauth2_client_client_id"), "oauth2_client", ["client_id"], unique=False + ) + op.create_table( + "oauth2_code", + sa.Column("code", sa.String(length=120), nullable=False), + sa.Column("client_id", sa.String(length=48), nullable=True), + sa.Column("redirect_uri", sa.Text(), nullable=True), + sa.Column("response_type", sa.Text(), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("nonce", sa.Text(), nullable=True), + sa.Column("auth_time", sa.Integer(), nullable=False), + sa.Column("code_challenge", sa.Text(), nullable=True), + sa.Column("code_challenge_method", sa.String(length=48), nullable=True), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("code"), + ) + op.create_table( + "oauth2_token", + sa.Column("client_id", sa.String(length=48), nullable=True), + sa.Column("token_type", sa.String(length=40), nullable=True), + sa.Column("access_token", sa.String(length=255), nullable=False), + sa.Column("refresh_token", sa.String(length=255), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("revoked", sa.Boolean(), nullable=True), + sa.Column("issued_at", sa.Integer(), nullable=False), + sa.Column("expires_in", sa.Integer(), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("access_token"), + ) + op.create_index( + op.f("ix_oauth2_token_refresh_token"), + "oauth2_token", + ["refresh_token"], + unique=False, + ) + op.create_unique_constraint( + "eventplace_name_admin_unit_id", "eventplace", ["name", "admin_unit_id"] + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("eventplace_name_admin_unit_id", "eventplace", type_="unique") + op.drop_index(op.f("ix_oauth2_token_refresh_token"), table_name="oauth2_token") + op.drop_table("oauth2_token") + op.drop_table("oauth2_code") + op.drop_index(op.f("ix_oauth2_client_client_id"), table_name="oauth2_client") + op.drop_table("oauth2_client") + # ### end Alembic commands ### diff --git a/project/__init__.py b/project/__init__.py index 3c87974..3692115 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -1,5 +1,5 @@ import os -from flask import Flask +from flask import Flask, url_for, redirect, request, jsonify from flask_sqlalchemy import SQLAlchemy from flask_security import ( Security, @@ -10,12 +10,8 @@ from flask_cors import CORS from flask_qrcode import QRcode from flask_mail import Mail, email_dispatched from flask_migrate import Migrate -from flask_marshmallow import Marshmallow -from flask_restful import Api -from apispec import APISpec -from apispec.ext.marshmallow import MarshmallowPlugin -from flask_apispec.extension import FlaskApiSpec from flask_gzip import Gzip +from webargs import flaskparser # Create app app = Flask(__name__) @@ -59,25 +55,6 @@ babel = Babel(app) # cors cors = CORS(app, resources={r"/api/*", "/swagger/"}) -# API -rest_api = Api(app, "/api/v1") -marshmallow = Marshmallow(app) -marshmallow_plugin = MarshmallowPlugin() -app.config.update( - { - "APISPEC_SPEC": APISpec( - title="Oveda API", - version="0.1.0", - plugins=[marshmallow_plugin], - openapi_version="2.0", - info=dict( - description="This API provides endpoints to interact with the Oveda data. At the moment, there is no authorization needed." - ), - ), - } -) -api_docs = FlaskApiSpec(app) - # Mail mail_server = os.getenv("MAIL_SERVER") @@ -108,6 +85,9 @@ if app.config["MAIL_SUPPRESS_SEND"]: db = SQLAlchemy(app) migrate = Migrate(app, db) +# API +from project.api import RestApi + # qr code QRcode(app) @@ -123,6 +103,13 @@ from project.forms.security import ExtendedRegisterForm user_datastore = SQLAlchemySessionUserDatastore(db.session, User, Role) security = Security(app, user_datastore, register_form=ExtendedRegisterForm) +# OAuth2 +from project.oauth2 import config_oauth + +config_oauth(app) + +# Init misc modules + from project import i10n from project import jinja_filters from project import init_data @@ -142,6 +129,9 @@ from project.views import ( image, manage, organizer, + oauth, + oauth2_client, + oauth2_token, planing, reference, reference_request, diff --git a/project/access.py b/project/access.py index 6835864..a116324 100644 --- a/project/access.py +++ b/project/access.py @@ -1,4 +1,5 @@ from flask import abort +from flask_login import login_user from flask_security import current_user from flask_security.utils import FsPermNeed from flask_principal import Permission @@ -12,6 +13,20 @@ def has_current_user_permission(permission): return user_perm.can() +def has_owner_access(user_id): + return user_id == current_user.id + + +def owner_access_or_401(user_id): + if not has_owner_access(user_id): + abort(401) + + +def login_api_user_or_401(user): + if not login_user(user): + abort(401) + + def has_admin_unit_member_role(admin_unit_member, role_name): for role in admin_unit_member.roles: if role.name == role_name: diff --git a/project/api/__init__.py b/project/api/__init__.py index 6df71ae..8bd96ca 100644 --- a/project/api/__init__.py +++ b/project/api/__init__.py @@ -1,4 +1,107 @@ -from project import rest_api, api_docs +from flask_restful import Api +from sqlalchemy.exc import IntegrityError +from psycopg2.errorcodes import UNIQUE_VIOLATION +from werkzeug.exceptions import HTTPException, UnprocessableEntity +from marshmallow import ValidationError +from project.utils import get_localized_scope +from project import app +from flask_marshmallow import Marshmallow +from apispec import APISpec +from apispec.ext.marshmallow import MarshmallowPlugin +from flask_apispec.extension import FlaskApiSpec + + +class RestApi(Api): + def handle_error(self, err): + from project.api.schemas import ( + ErrorResponseSchema, + UnprocessableEntityResponseSchema, + ) + + schema = None + data = {} + code = 500 + + if ( + isinstance(err, IntegrityError) + and err.orig + and err.orig.pgcode == UNIQUE_VIOLATION + ): + data["name"] = "Unique Violation" + data[ + "message" + ] = "An entry with the entered values ​​already exists. Duplicate entries are not allowed." + code = 400 + schema = ErrorResponseSchema() + elif isinstance(err, HTTPException): + data["name"] = err.name + data["message"] = err.description + code = err.code + + if ( + isinstance(err, UnprocessableEntity) + and err.exc + and isinstance(err.exc, ValidationError) + ): + data["name"] = err.name + data["message"] = err.description + code = err.code + schema = UnprocessableEntityResponseSchema() + + if ( + getattr(err.exc, "args", None) + and isinstance(err.exc.args, tuple) + and len(err.exc.args) > 0 + ): + arg = err.exc.args[0] + if isinstance(arg, dict): + errors = [] + for field, messages in arg.items(): + if isinstance(messages, list): + for message in messages: + error = {"field": field, "message": message} + errors.append(error) + + if len(errors) > 0: + data["errors"] = errors + else: + schema = ErrorResponseSchema() + + # Call default error handler that propagates error further + try: + super().handle_error(err) + except Exception: + if not schema: + raise + + return schema.dump(data), code + + +scope_list = [ + "organizer:write", + "place:write", + "event:write", +] +scopes = {k: get_localized_scope(k) for v, k in enumerate(scope_list)} + +rest_api = RestApi(app, "/api/v1", catch_all_404s=True) +marshmallow = Marshmallow(app) +marshmallow_plugin = MarshmallowPlugin() +app.config.update( + { + "APISPEC_SPEC": APISpec( + title="Oveda API", + version="0.1.0", + plugins=[marshmallow_plugin], + openapi_version="2.0", + info=dict( + description="This API provides endpoints to interact with the Oveda data." + ), + ), + } +) + +api_docs = FlaskApiSpec(app) def enum_to_properties(self, field, **kwargs): @@ -17,8 +120,6 @@ def add_api_resource(resource, url, endpoint): api_docs.register(resource, endpoint=endpoint) -from project import marshmallow_plugin - marshmallow_plugin.converter.add_attribute_function(enum_to_properties) import project.api.event.resources @@ -26,8 +127,6 @@ import project.api.event_category.resources import project.api.event_date.resources import project.api.event_reference.resources import project.api.dump.resources -import project.api.image.resources -import project.api.location.resources import project.api.organization.resources import project.api.organizer.resources import project.api.place.resources diff --git a/project/api/dump/schemas.py b/project/api/dump/schemas.py index c78e973..649e802 100644 --- a/project/api/dump/schemas.py +++ b/project/api/dump/schemas.py @@ -1,4 +1,4 @@ -from project import marshmallow +from project.api import marshmallow from marshmallow import fields from project.api.event.schemas import EventDumpSchema from project.api.place.schemas import PlaceDumpSchema diff --git a/project/api/event/schemas.py b/project/api/event/schemas.py index c2f435b..7b904f1 100644 --- a/project/api/event/schemas.py +++ b/project/api/event/schemas.py @@ -1,4 +1,4 @@ -from project import marshmallow +from project.api import marshmallow from marshmallow import fields, validate from marshmallow_enum import EnumField from project.models import ( @@ -10,7 +10,7 @@ from project.models import ( from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema from project.api.organization.schemas import OrganizationRefSchema from project.api.organizer.schemas import OrganizerRefSchema -from project.api.image.schemas import ImageRefSchema +from project.api.image.schemas import ImageSchema from project.api.place.schemas import PlaceRefSchema, PlaceSearchItemSchema from project.api.event_category.schemas import ( EventCategoryRefSchema, @@ -55,7 +55,7 @@ class EventSchema(EventBaseSchema): organization = fields.Nested(OrganizationRefSchema, attribute="admin_unit") organizer = fields.Nested(OrganizerRefSchema) place = fields.Nested(PlaceRefSchema, attribute="event_place") - photo = fields.Nested(ImageRefSchema) + photo = fields.Nested(ImageSchema) categories = fields.List(fields.Nested(EventCategoryRefSchema)) @@ -85,7 +85,7 @@ class EventSearchItemSchema(EventRefSchema): start = marshmallow.auto_field() end = marshmallow.auto_field() recurrence_rule = marshmallow.auto_field() - photo = fields.Nested(ImageRefSchema) + photo = fields.Nested(ImageSchema) place = fields.Nested(PlaceSearchItemSchema, attribute="event_place") status = EnumField(EventStatus) booked_up = marshmallow.auto_field() diff --git a/project/api/event_category/schemas.py b/project/api/event_category/schemas.py index a354a06..0bd66bd 100644 --- a/project/api/event_category/schemas.py +++ b/project/api/event_category/schemas.py @@ -1,5 +1,5 @@ from marshmallow import fields -from project import marshmallow +from project.api import marshmallow from project.models import EventCategory from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema diff --git a/project/api/event_date/schemas.py b/project/api/event_date/schemas.py index 159972f..3efa0f5 100644 --- a/project/api/event_date/schemas.py +++ b/project/api/event_date/schemas.py @@ -1,4 +1,4 @@ -from project import marshmallow +from project.api import marshmallow from marshmallow import fields from project.models import EventDate from project.api.event.schemas import ( diff --git a/project/api/event_reference/schemas.py b/project/api/event_reference/schemas.py index 14cdf20..02808c7 100644 --- a/project/api/event_reference/schemas.py +++ b/project/api/event_reference/schemas.py @@ -1,5 +1,5 @@ from marshmallow import fields -from project import marshmallow +from project.api import marshmallow from project.models import EventReference from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema from project.api.event.schemas import EventRefSchema diff --git a/project/api/fields.py b/project/api/fields.py new file mode 100644 index 0000000..90c9bd5 --- /dev/null +++ b/project/api/fields.py @@ -0,0 +1,15 @@ +from marshmallow import fields, ValidationError + + +class NumericStr(fields.String): + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + + return str(value) + + def _deserialize(self, value, attr, data, **kwargs): + try: + return float(value) + except ValueError as error: + raise ValidationError("Must be a numeric value.") from error diff --git a/project/api/image/resources.py b/project/api/image/resources.py deleted file mode 100644 index 25cb5b6..0000000 --- a/project/api/image/resources.py +++ /dev/null @@ -1,15 +0,0 @@ -from project.api import add_api_resource -from flask_apispec import marshal_with, doc -from project.api.resources import BaseResource -from project.api.image.schemas import ImageSchema -from project.models import Image - - -class ImageResource(BaseResource): - @doc(summary="Get image", tags=["Images"]) - @marshal_with(ImageSchema) - def get(self, id): - return Image.query.get_or_404(id) - - -add_api_resource(ImageResource, "/images/", "api_v1_image") diff --git a/project/api/image/schemas.py b/project/api/image/schemas.py index 9a34c00..02f9a49 100644 --- a/project/api/image/schemas.py +++ b/project/api/image/schemas.py @@ -1,4 +1,4 @@ -from project import marshmallow +from project.api import marshmallow from project.models import Image @@ -10,8 +10,6 @@ class ImageIdSchema(marshmallow.SQLAlchemySchema): class ImageBaseSchema(ImageIdSchema): - created_at = marshmallow.auto_field() - updated_at = marshmallow.auto_field() copyright_text = marshmallow.auto_field() @@ -27,13 +25,3 @@ class ImageSchema(ImageBaseSchema): class ImageDumpSchema(ImageBaseSchema): pass - - -class ImageRefSchema(ImageIdSchema): - image_url = marshmallow.URLFor( - "image", - values=dict(id="", s=500), - metadata={ - "description": "Append query arguments w for width, h for height or s for size(width and height)." - }, - ) diff --git a/project/api/location/resources.py b/project/api/location/resources.py deleted file mode 100644 index 2e3d2e9..0000000 --- a/project/api/location/resources.py +++ /dev/null @@ -1,15 +0,0 @@ -from project.api import add_api_resource -from flask_apispec import marshal_with, doc -from project.api.resources import BaseResource -from project.api.location.schemas import LocationSchema -from project.models import Location - - -class LocationResource(BaseResource): - @doc(summary="Get location", tags=["Locations"]) - @marshal_with(LocationSchema) - def get(self, id): - return Location.query.get_or_404(id) - - -add_api_resource(LocationResource, "/locations/", "api_v1_location") diff --git a/project/api/location/schemas.py b/project/api/location/schemas.py index c16cb89..13d59f1 100644 --- a/project/api/location/schemas.py +++ b/project/api/location/schemas.py @@ -1,38 +1,65 @@ -from marshmallow import fields -from project import marshmallow +from marshmallow import fields, validate +from project.api import marshmallow from project.models import Location +from project.api.fields import NumericStr class LocationIdSchema(marshmallow.SQLAlchemySchema): class Meta: model = Location - id = marshmallow.auto_field() - class LocationSchema(LocationIdSchema): - created_at = marshmallow.auto_field() - updated_at = marshmallow.auto_field() street = marshmallow.auto_field() postalCode = marshmallow.auto_field() city = marshmallow.auto_field() state = marshmallow.auto_field() country = marshmallow.auto_field() - longitude = fields.Str() - latitude = fields.Str() + longitude = NumericStr() + latitude = NumericStr() class LocationDumpSchema(LocationSchema): pass -class LocationRefSchema(LocationIdSchema): +class LocationSearchItemSchema(LocationSchema): pass -class LocationSearchItemSchema(LocationRefSchema): +class LocationPostRequestSchema(marshmallow.SQLAlchemySchema): class Meta: model = Location - longitude = fields.Str() - latitude = fields.Str() + street = fields.Str(validate=validate.Length(max=255), missing=None) + postalCode = fields.Str(validate=validate.Length(max=10), missing=None) + city = fields.Str(validate=validate.Length(max=255), missing=None) + state = fields.Str(validate=validate.Length(max=255), missing=None) + country = fields.Str(validate=validate.Length(max=255), missing=None) + longitude = NumericStr(validate=validate.Range(-180, 180), missing=None) + latitude = NumericStr(validate=validate.Range(-90, 90), missing=None) + + +class LocationPostRequestLoadSchema(LocationPostRequestSchema): + class Meta: + model = Location + load_instance = True + + +class LocationPatchRequestSchema(marshmallow.SQLAlchemySchema): + class Meta: + model = Location + + street = fields.Str(validate=validate.Length(max=255), allow_none=True) + postalCode = fields.Str(validate=validate.Length(max=10), allow_none=True) + city = fields.Str(validate=validate.Length(max=255), allow_none=True) + state = fields.Str(validate=validate.Length(max=255), allow_none=True) + country = fields.Str(validate=validate.Length(max=255), allow_none=True) + longitude = NumericStr(validate=validate.Range(-180, 180), allow_none=True) + latitude = NumericStr(validate=validate.Range(-90, 90), allow_none=True) + + +class LocationPatchRequestLoadSchema(LocationPatchRequestSchema): + class Meta: + model = Location + load_instance = True diff --git a/project/api/organization/resources.py b/project/api/organization/resources.py index fa8dea6..ed356ab 100644 --- a/project/api/organization/resources.py +++ b/project/api/organization/resources.py @@ -27,7 +27,13 @@ from project.services.reference import ( get_reference_incoming_query, get_reference_outgoing_query, ) -from project.api.place.schemas import PlaceListRequestSchema, PlaceListResponseSchema +from project.api.place.schemas import ( + PlaceListRequestSchema, + PlaceListResponseSchema, + PlaceIdSchema, + PlacePostRequestSchema, + PlacePostRequestLoadSchema, +) from project.services.event import get_event_dates_query, get_events_query from project.services.event_search import EventSearchParams from project.services.admin_unit import ( @@ -35,6 +41,14 @@ from project.services.admin_unit import ( get_organizer_query, get_place_query, ) +from project.oauth2 import require_oauth +from authlib.integrations.flask_oauth2 import current_token +from project import db +from project.access import ( + access_or_401, + get_admin_unit_for_manage_or_404, + login_api_user_or_401, +) class OrganizationResource(BaseResource): @@ -113,6 +127,26 @@ class OrganizationPlaceListResource(BaseResource): pagination = get_place_query(admin_unit.id, name).paginate() return pagination + @doc( + summary="Add new place", + tags=["Organizations", "Places"], + security=[{"oauth2": ["place:write"]}], + ) + @use_kwargs(PlacePostRequestSchema, location="json") + @marshal_with(PlaceIdSchema, 201) + @require_oauth("place:write") + def post(self, id, **kwargs): + login_api_user_or_401(current_token.user) + admin_unit = get_admin_unit_for_manage_or_404(id) + access_or_401(admin_unit, "place:create") + + place = PlacePostRequestLoadSchema().load(kwargs, session=db.session) + place.admin_unit_id = admin_unit.id + db.session.add(place) + db.session.commit() + + return place, 201 + class OrganizationIncomingEventReferenceListResource(BaseResource): @doc( diff --git a/project/api/organization/schemas.py b/project/api/organization/schemas.py index 4cbea2d..2587432 100644 --- a/project/api/organization/schemas.py +++ b/project/api/organization/schemas.py @@ -1,8 +1,8 @@ from marshmallow import fields -from project import marshmallow +from project.api import marshmallow from project.models import AdminUnit -from project.api.location.schemas import LocationRefSchema -from project.api.image.schemas import ImageRefSchema +from project.api.location.schemas import LocationSchema +from project.api.image.schemas import ImageSchema from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema @@ -25,8 +25,8 @@ class OrganizationBaseSchema(OrganizationIdSchema): class OrganizationSchema(OrganizationBaseSchema): - location = fields.Nested(LocationRefSchema) - logo = fields.Nested(ImageRefSchema) + location = fields.Nested(LocationSchema) + logo = fields.Nested(ImageSchema) class OrganizationDumpSchema(OrganizationBaseSchema): diff --git a/project/api/organizer/schemas.py b/project/api/organizer/schemas.py index bcbf222..f732aae 100644 --- a/project/api/organizer/schemas.py +++ b/project/api/organizer/schemas.py @@ -1,8 +1,8 @@ from marshmallow import fields -from project import marshmallow +from project.api import marshmallow from project.models import EventOrganizer -from project.api.location.schemas import LocationRefSchema -from project.api.image.schemas import ImageRefSchema +from project.api.location.schemas import LocationSchema +from project.api.image.schemas import ImageSchema from project.api.organization.schemas import OrganizationRefSchema from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema @@ -25,8 +25,8 @@ class OrganizerBaseSchema(OrganizerIdSchema): class OrganizerSchema(OrganizerBaseSchema): - location = fields.Nested(LocationRefSchema) - logo = fields.Nested(ImageRefSchema) + location = fields.Nested(LocationSchema) + logo = fields.Nested(ImageSchema) organization = fields.Nested(OrganizationRefSchema, attribute="adminunit") diff --git a/project/api/place/resources.py b/project/api/place/resources.py index 912151b..1573dbc 100644 --- a/project/api/place/resources.py +++ b/project/api/place/resources.py @@ -1,8 +1,19 @@ from project.api import add_api_resource -from flask_apispec import marshal_with, doc +from flask import make_response +from flask_apispec import marshal_with, doc, use_kwargs from project.api.resources import BaseResource -from project.api.place.schemas import PlaceSchema +from project.api.place.schemas import ( + PlaceSchema, + PlacePostRequestSchema, + PlacePostRequestLoadSchema, + PlacePatchRequestSchema, + PlacePatchRequestLoadSchema, +) from project.models import EventPlace +from project.oauth2 import require_oauth +from authlib.integrations.flask_oauth2 import current_token +from project import db +from project.access import access_or_401, login_api_user_or_401 class PlaceResource(BaseResource): @@ -11,5 +22,54 @@ class PlaceResource(BaseResource): def get(self, id): return EventPlace.query.get_or_404(id) + @doc( + summary="Update place", tags=["Places"], security=[{"oauth2": ["place:write"]}] + ) + @use_kwargs(PlacePostRequestSchema, location="json") + @marshal_with(None, 204) + @require_oauth("place:write") + def put(self, id, **kwargs): + login_api_user_or_401(current_token.user) + place = EventPlace.query.get_or_404(id) + access_or_401(place.adminunit, "place:update") + + place = PlacePostRequestLoadSchema().load( + kwargs, session=db.session, instance=place + ) + db.session.commit() + + return make_response("", 204) + + @doc(summary="Patch place", tags=["Places"], security=[{"oauth2": ["place:write"]}]) + @use_kwargs(PlacePatchRequestSchema, location="json") + @marshal_with(None, 204) + @require_oauth("place:write") + def patch(self, id, **kwargs): + login_api_user_or_401(current_token.user) + place = EventPlace.query.get_or_404(id) + access_or_401(place.adminunit, "place:update") + + place = PlacePatchRequestLoadSchema().load( + kwargs, session=db.session, instance=place + ) + db.session.commit() + + return make_response("", 204) + + @doc( + summary="Delete place", tags=["Places"], security=[{"oauth2": ["place:write"]}] + ) + @marshal_with(None, 204) + @require_oauth("place:write") + def delete(self, id): + login_api_user_or_401(current_token.user) + place = EventPlace.query.get_or_404(id) + access_or_401(place.adminunit, "place:delete") + + db.session.delete(place) + db.session.commit() + + return make_response("", 204) + add_api_resource(PlaceResource, "/places/", "api_v1_place") diff --git a/project/api/place/schemas.py b/project/api/place/schemas.py index a862374..c8dc4b4 100644 --- a/project/api/place/schemas.py +++ b/project/api/place/schemas.py @@ -1,8 +1,15 @@ -from marshmallow import fields -from project import marshmallow +from marshmallow import fields, validate +from project.api import marshmallow from project.models import EventPlace -from project.api.image.schemas import ImageRefSchema -from project.api.location.schemas import LocationRefSchema, LocationSearchItemSchema +from project.api.image.schemas import ImageSchema +from project.api.location.schemas import ( + LocationSchema, + LocationSearchItemSchema, + LocationPostRequestSchema, + LocationPostRequestLoadSchema, + LocationPatchRequestSchema, + LocationPatchRequestLoadSchema, +) from project.api.organization.schemas import OrganizationRefSchema from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema @@ -23,8 +30,8 @@ class PlaceBaseSchema(PlaceIdSchema): class PlaceSchema(PlaceBaseSchema): - location = fields.Nested(LocationRefSchema) - photo = fields.Nested(ImageRefSchema) + location = fields.Nested(LocationSchema) + photo = fields.Nested(ImageSchema) organization = fields.Nested(OrganizationRefSchema, attribute="adminunit") @@ -55,3 +62,41 @@ class PlaceListResponseSchema(PaginationResponseSchema): items = fields.List( fields.Nested(PlaceRefSchema), metadata={"description": "Places"} ) + + +class PlacePostRequestSchema(marshmallow.SQLAlchemySchema): + class Meta: + model = EventPlace + + name = fields.Str(required=True, validate=validate.Length(min=3, max=255)) + url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None) + description = fields.Str(missing=None) + location = fields.Nested(LocationPostRequestSchema, missing=None) + + +class PlacePostRequestLoadSchema(PlacePostRequestSchema): + class Meta: + model = EventPlace + load_instance = True + + location = fields.Nested(LocationPostRequestLoadSchema, missing=None) + + +class PlacePatchRequestSchema(marshmallow.SQLAlchemySchema): + class Meta: + model = EventPlace + + name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True) + url = fields.Str( + validate=[validate.URL(), validate.Length(max=255)], allow_none=True + ) + description = fields.Str(allow_none=True) + location = fields.Nested(LocationPatchRequestSchema, allow_none=True) + + +class PlacePatchRequestLoadSchema(PlacePatchRequestSchema): + class Meta: + model = EventPlace + load_instance = True + + location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True) diff --git a/project/api/resources.py b/project/api/resources.py index 49cba35..add8bef 100644 --- a/project/api/resources.py +++ b/project/api/resources.py @@ -1,6 +1,8 @@ from flask import request +from flask_apispec import marshal_with from flask_apispec.views import MethodResource from functools import wraps +from project.api.schemas import ErrorResponseSchema, UnprocessableEntityResponseSchema def etag_cache(func): @@ -13,5 +15,7 @@ def etag_cache(func): return wrapper +@marshal_with(ErrorResponseSchema, 400, "Bad Request") +@marshal_with(UnprocessableEntityResponseSchema, 422, "Unprocessable Entity") class BaseResource(MethodResource): decorators = [etag_cache] diff --git a/project/api/schemas.py b/project/api/schemas.py index 7b554a2..12a4a2d 100644 --- a/project/api/schemas.py +++ b/project/api/schemas.py @@ -1,7 +1,21 @@ -from project import marshmallow +from project.api import marshmallow from marshmallow import fields, validate +class ErrorResponseSchema(marshmallow.Schema): + name = fields.Str() + message = fields.Str() + + +class UnprocessableEntityErrorSchema(marshmallow.Schema): + field = fields.Str() + message = fields.Str() + + +class UnprocessableEntityResponseSchema(ErrorResponseSchema): + errors = fields.List(fields.Nested(UnprocessableEntityErrorSchema)) + + class PaginationRequestSchema(marshmallow.Schema): page = fields.Integer( required=False, diff --git a/project/forms/oauth2_client.py b/project/forms/oauth2_client.py new file mode 100644 index 0000000..8f4cec5 --- /dev/null +++ b/project/forms/oauth2_client.py @@ -0,0 +1,92 @@ +from flask_wtf import FlaskForm +from flask_babelex import lazy_gettext +from wtforms import StringField, TextAreaField, SubmitField, SelectField +from wtforms.validators import Optional, DataRequired +from project.forms.widgets import MultiCheckboxField +from project.api import scopes +from project.utils import split_by_crlf +import os + + +class BaseOAuth2ClientForm(FlaskForm): + client_name = StringField(lazy_gettext("Client name"), validators=[DataRequired()]) + redirect_uris = TextAreaField( + lazy_gettext("Redirect URIs"), validators=[Optional()] + ) + grant_types = MultiCheckboxField( + lazy_gettext("Grant types"), + validators=[DataRequired()], + choices=[ + ("authorization_code", lazy_gettext("Authorization Code")), + ("refresh_token", lazy_gettext("Refresh Token")), + ], + default=["authorization_code", "refresh_token"], + ) + response_types = MultiCheckboxField( + lazy_gettext("Response types"), + validators=[DataRequired()], + choices=[ + ("code", "code"), + ], + default=["code"], + ) + scope = MultiCheckboxField( + lazy_gettext("Scopes"), + validators=[DataRequired()], + choices=[(k, k) for k, v in scopes.items()], + ) + token_endpoint_auth_method = SelectField( + lazy_gettext("Token endpoint auth method"), + validators=[DataRequired()], + choices=[ + ("client_secret_post", lazy_gettext("Client secret post")), + ("client_secret_basic", lazy_gettext("Client secret basic")), + ], + ) + + submit = SubmitField(lazy_gettext("Save")) + + def populate_obj(self, obj): + meta_keys = [ + "client_name", + "client_uri", + "grant_types", + "redirect_uris", + "response_types", + "scope", + "token_endpoint_auth_method", + ] + metadata = dict() + for name, field in self._fields.items(): + if name in meta_keys: + if name == "redirect_uris": + metadata[name] = split_by_crlf(field.data) + elif name == "scope": + metadata[name] = " ".join(field.data) + else: + metadata[name] = field.data + else: + field.populate_obj(obj, name) + obj.set_client_metadata(metadata) + + def process(self, formdata=None, obj=None, data=None, **kwargs): + super().process(formdata, obj, data, **kwargs) + + if not obj: + return + + self.redirect_uris.data = os.linesep.join(obj.redirect_uris) + self.scope.data = obj.scope.split(" ") + + +class CreateOAuth2ClientForm(BaseOAuth2ClientForm): + pass + + +class UpdateOAuth2ClientForm(BaseOAuth2ClientForm): + pass + + +class DeleteOAuth2ClientForm(FlaskForm): + submit = SubmitField(lazy_gettext("Delete OAuth2 client")) + name = StringField(lazy_gettext("Name"), validators=[DataRequired()]) diff --git a/project/forms/oauth2_token.py b/project/forms/oauth2_token.py new file mode 100644 index 0000000..5562294 --- /dev/null +++ b/project/forms/oauth2_token.py @@ -0,0 +1,7 @@ +from flask_wtf import FlaskForm +from flask_babelex import lazy_gettext +from wtforms import SubmitField + + +class RevokeOAuth2TokenForm(FlaskForm): + submit = SubmitField(lazy_gettext("Revoke OAuth2 token")) diff --git a/project/forms/security.py b/project/forms/security.py index 503123e..a665240 100644 --- a/project/forms/security.py +++ b/project/forms/security.py @@ -1,7 +1,9 @@ from flask_security.forms import RegisterForm, EqualTo, get_form_field_label -from wtforms import BooleanField, PasswordField +from wtforms import BooleanField, PasswordField, SubmitField from wtforms.validators import DataRequired from project.forms.common import get_accept_tos_markup +from flask_wtf import FlaskForm +from flask_babelex import lazy_gettext class ExtendedRegisterForm(RegisterForm): @@ -20,3 +22,8 @@ class ExtendedRegisterForm(RegisterForm): def __init__(self, *args, **kwargs): super(ExtendedRegisterForm, self).__init__(*args, **kwargs) self._fields["accept_tos"].label.text = get_accept_tos_markup() + + +class AuthorizeForm(FlaskForm): + allow = SubmitField(lazy_gettext("Allow")) + deny = SubmitField(lazy_gettext("Deny")) diff --git a/project/gsevpt.sqlite b/project/gsevpt.sqlite deleted file mode 100644 index 1601e2e..0000000 Binary files a/project/gsevpt.sqlite and /dev/null differ diff --git a/project/gsevpt.sqlite3 b/project/gsevpt.sqlite3 deleted file mode 100644 index d5bce6e..0000000 Binary files a/project/gsevpt.sqlite3 and /dev/null differ diff --git a/project/i10n.py b/project/i10n.py index fcf85ba..92ddd7f 100644 --- a/project/i10n.py +++ b/project/i10n.py @@ -39,3 +39,8 @@ def print_dynamic_texts(): gettext("EventReviewStatus.inbox") gettext("EventReviewStatus.verified") gettext("EventReviewStatus.rejected") + gettext("read") + gettext("write") + gettext("Event") + gettext("Organizer") + gettext("Place") diff --git a/project/init_data.py b/project/init_data.py index 3d8609c..28c6500 100644 --- a/project/init_data.py +++ b/project/init_data.py @@ -1,8 +1,27 @@ from project import app, db +from project.api import api_docs, scopes from project.services.user import upsert_user_role from project.services.admin_unit import upsert_admin_unit_member_role from project.services.event import upsert_event_category from project.models import Location +from flask import url_for +from apispec.exceptions import DuplicateComponentNameError + + +@app.before_first_request +def add_oauth2_scheme(): + oauth2_scheme = { + "type": "oauth2", + "authorizationUrl": url_for("authorize", _external=True), + "tokenUrl": url_for("issue_token", _external=True), + "flow": "accessCode", + "scopes": scopes, + } + + try: + api_docs.spec.components.security_scheme("oauth2", oauth2_scheme) + except DuplicateComponentNameError: # pragma: no cover + pass @app.before_first_request @@ -37,12 +56,23 @@ def create_initial_data(): "reference_request:delete", "reference_request:verify", ] + early_adopter_permissions = [ + "oauth2_client:create", + "oauth2_client:read", + "oauth2_client:update", + "oauth2_client:delete", + "oauth2_token:create", + "oauth2_token:read", + "oauth2_token:update", + "oauth2_token:delete", + ] upsert_admin_unit_member_role("admin", "Administrator", admin_permissions) upsert_admin_unit_member_role("event_verifier", "Event expert", event_permissions) upsert_user_role("admin", "Administrator", admin_permissions) upsert_user_role("event_verifier", "Event expert", event_permissions) + upsert_user_role("early_adopter", "Early Adopter", early_adopter_permissions) Location.update_coordinates() diff --git a/project/jinja_filters.py b/project/jinja_filters.py index c901a7f..258462a 100644 --- a/project/jinja_filters.py +++ b/project/jinja_filters.py @@ -1,5 +1,9 @@ from project import app -from project.utils import get_event_category_name, get_localized_enum_name +from project.utils import ( + get_event_category_name, + get_localized_enum_name, + get_localized_scope, +) from urllib.parse import quote_plus import os @@ -8,10 +12,16 @@ def env_override(value, key): return os.getenv(key, value) +def is_list(value): + return isinstance(value, list) + + app.jinja_env.filters["event_category_name"] = lambda u: get_event_category_name(u) app.jinja_env.filters["loc_enum"] = lambda u: get_localized_enum_name(u) +app.jinja_env.filters["loc_scope"] = lambda s: get_localized_scope(s) app.jinja_env.filters["env_override"] = env_override app.jinja_env.filters["quote_plus"] = lambda u: quote_plus(u) +app.jinja_env.filters["is_list"] = is_list @app.context_processor diff --git a/project/models.py b/project/models.py index 2eaeb46..19e9280 100644 --- a/project/models.py +++ b/project/models.py @@ -1,7 +1,7 @@ from project import db from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import relationship, backref, deferred +from sqlalchemy.orm import relationship, backref, deferred, object_session from sqlalchemy.schema import CheckConstraint from sqlalchemy.event import listens_for from sqlalchemy import ( @@ -24,6 +24,12 @@ import datetime from project.dbtypes import IntegerEnum from geoalchemy2 import Geometry from sqlalchemy import and_ +from authlib.integrations.sqla_oauth2 import ( + OAuth2ClientMixin, + OAuth2AuthorizationCodeMixin, + OAuth2TokenMixin, +) +import time # Base @@ -156,6 +162,12 @@ class User(db.Model, UserMixin): "Role", secondary="roles_users", backref=backref("users", lazy="dynamic") ) + def get_user_id(self): + return self.id + + +# OAuth Consumer: Wenn wir OAuth consumen und sich ein Nutzer per Google oder Facebook anmelden möchte + class OAuth(OAuthConsumerMixin, db.Model): provider_user_id = Column(String(256), unique=True, nullable=False) @@ -163,6 +175,51 @@ class OAuth(OAuthConsumerMixin, db.Model): user = db.relationship("User") +# OAuth Server: Wir bieten an, dass sich ein Nutzer per OAuth2 auf unserer Seite anmeldet + + +class OAuth2Client(db.Model, OAuth2ClientMixin): + __tablename__ = "oauth2_client" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + def check_redirect_uri(self, redirect_uri): + return True + + +class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): + __tablename__ = "oauth2_code" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + +class OAuth2Token(db.Model, OAuth2TokenMixin): + __tablename__ = "oauth2_token" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE")) + user = db.relationship("User") + + @property + def client(self): + return ( + object_session(self) + .query(OAuth2Client) + .filter(OAuth2Client.client_id == self.client_id) + .first() + ) + + def is_refresh_token_active(self): + if self.revoked: + return False + expires_at = self.issued_at + self.expires_in * 2 + return expires_at >= time.time() + + # Admin Unit @@ -298,6 +355,7 @@ def update_location_coordinate(mapper, connect, self): # Events class EventPlace(db.Model, TrackableMixin): __tablename__ = "eventplace" + __table_args__ = (UniqueConstraint("name", "admin_unit_id"),) id = Column(Integer(), primary_key=True) name = Column(Unicode(255), nullable=False) location_id = db.Column(db.Integer, db.ForeignKey("location.id")) diff --git a/project/oauth2.py b/project/oauth2.py new file mode 100644 index 0000000..6189cbf --- /dev/null +++ b/project/oauth2.py @@ -0,0 +1,114 @@ +from authlib.integrations.flask_oauth2 import ( + AuthorizationServer, + ResourceProtector, +) +from authlib.integrations.sqla_oauth2 import ( + create_query_client_func, + create_save_token_func, + create_bearer_token_validator, + create_query_token_func, +) +from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc7636 import CodeChallenge +from project import db +from project.models import User, OAuth2Client, OAuth2AuthorizationCode, OAuth2Token + + +class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + "client_secret_basic", + "client_secret_post", + "none", + ] + + def save_authorization_code(self, code, request): + code_challenge = request.data.get("code_challenge") + code_challenge_method = request.data.get("code_challenge_method") + auth_code = OAuth2AuthorizationCode( + code=code, + client_id=request.client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=request.user.id, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + ) + db.session.add(auth_code) + db.session.commit() + return auth_code + + def query_authorization_code(self, code, client): + auth_code = OAuth2AuthorizationCode.query.filter_by( + code=code, client_id=client.client_id + ).first() + if auth_code and not auth_code.is_expired(): + return auth_code + + def delete_authorization_code(self, authorization_code): + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return User.query.get(authorization_code.user_id) + + +class RefreshTokenGrant(grants.RefreshTokenGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"] + + def authenticate_refresh_token(self, refresh_token): + token = OAuth2Token.query.filter_by(refresh_token=refresh_token).first() + if token and token.is_refresh_token_active(): + return token + + def authenticate_user(self, credential): + return User.query.get(credential.user_id) + + def revoke_old_credential(self, credential): + credential.revoked = True + db.session.add(credential) + db.session.commit() + + +query_client = create_query_client_func(db.session, OAuth2Client) +save_token = create_save_token_func(db.session, OAuth2Token) +authorization = AuthorizationServer( + query_client=query_client, + save_token=save_token, +) +require_oauth = ResourceProtector() + + +def create_revocation_endpoint(session, token_model): + from authlib.oauth2.rfc7009 import RevocationEndpoint + + query_token = create_query_token_func(session, token_model) + + class _RevocationEndpoint(RevocationEndpoint): + CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"] + + def query_token(self, token, token_type_hint, client): + return query_token(token, token_type_hint, client) + + def revoke_token(self, token): + token.revoked = True + session.add(token) + session.commit() + + return _RevocationEndpoint + + +def config_oauth(app): + app.config["OAUTH2_REFRESH_TOKEN_GENERATOR"] = True + authorization.init_app(app) + + # support grants + authorization.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)]) + authorization.register_grant(RefreshTokenGrant) + + # support revocation + revocation_cls = create_revocation_endpoint(db.session, OAuth2Token) + authorization.register_endpoint(revocation_cls) + + # protect resource + bearer_cls = create_bearer_token_validator(db.session, OAuth2Token) + require_oauth.register_token_validator(bearer_cls()) diff --git a/project/services/oauth2_client.py b/project/services/oauth2_client.py new file mode 100644 index 0000000..539a781 --- /dev/null +++ b/project/services/oauth2_client.py @@ -0,0 +1,12 @@ +from project.models import OAuth2Client +import time +from werkzeug.security import gen_salt + + +def complete_oauth2_client(oauth2_client: OAuth2Client) -> None: + if not oauth2_client.id: + oauth2_client.client_id = gen_salt(24) + oauth2_client.client_id_issued_at = int(time.time()) + + if oauth2_client.client_secret is None: + oauth2_client.client_secret = gen_salt(48) diff --git a/project/services/user.py b/project/services/user.py index 9467329..38af892 100644 --- a/project/services/user.py +++ b/project/services/user.py @@ -17,7 +17,7 @@ def add_roles_to_user(email, roles): def add_admin_roles_to_user(email): - add_roles_to_user(email, ["admin", "event_verifier"]) + add_roles_to_user(email, ["admin", "event_verifier", "early_adopter"]) def remove_roles_from_user(email, roles): diff --git a/project/templates/_macros.html b/project/templates/_macros.html index 0c689d5..1d1c44d 100644 --- a/project/templates/_macros.html +++ b/project/templates/_macros.html @@ -160,6 +160,21 @@ {% endif %} {% endmacro %} +{% macro render_kv_begin() %} +
+{% endmacro %} + +{% macro render_kv_end() %} +
+{% endmacro %} + +{% macro render_kv_prop(prop, label_key = None) %} +{% if prop %} +
{{ _(label_key) }}
+
{% if prop|is_list %}{{ prop|join(', ') }}{% else %}{{ prop }}{% endif %}
+{% endif %} +{% endmacro %} + {% macro render_string_prop(prop, icon = None, label_key = None) %} {% if prop %}
diff --git a/project/templates/oauth2_client/create.html b/project/templates/oauth2_client/create.html new file mode 100644 index 0000000..29c946d --- /dev/null +++ b/project/templates/oauth2_client/create.html @@ -0,0 +1,26 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_field_with_errors, render_field %} +{% block title %} +{{ _('Create OAuth2 client') }} +{% endblock %} +{% block content %} + +

{{ _('Create OAuth2 client') }}

+ +
+ {{ form.hidden_tag() }} + +
+
+ {{ render_field_with_errors(form.client_name) }} + {{ render_field_with_errors(form.grant_types, ri="multicheckbox") }} + {{ render_field_with_errors(form.response_types, ri="multicheckbox") }} + {{ render_field_with_errors(form.scope, ri="multicheckbox") }} + {{ render_field_with_errors(form.token_endpoint_auth_method) }} + {{ render_field_with_errors(form.redirect_uris) }} + + + {{ render_field(form.submit) }} + + +{% endblock %} diff --git a/project/templates/oauth2_client/delete.html b/project/templates/oauth2_client/delete.html new file mode 100644 index 0000000..890b13b --- /dev/null +++ b/project/templates/oauth2_client/delete.html @@ -0,0 +1,24 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_field_with_errors, render_field %} + +{% block content %} + +

{{ _('Delete OAuth2 client') }} "{{ oauth2_client.client_name }}"

+ +
+ {{ form.hidden_tag() }} + +
+
+ {{ _('OAuth2 client') }} +
+
+ {{ render_field_with_errors(form.name) }} +
+
+ + {{ render_field(form.submit) }} + +
+ +{% endblock %} diff --git a/project/templates/oauth2_client/list.html b/project/templates/oauth2_client/list.html new file mode 100644 index 0000000..218b451 --- /dev/null +++ b/project/templates/oauth2_client/list.html @@ -0,0 +1,44 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_pagination %} +{% block title %} +{{ _('OAuth2 clients') }} +{% endblock %} +{% block content %} + + + +{% if current_user.has_permission('oauth2_client:create') %} + +{% endif %} + +
+ + + + + + + + + + {% for oauth2_client in oauth2_clients %} + + + + + + {% endfor %} + +
{{ _('Name') }}
{{ oauth2_client.client_name }}{{ _('Edit') }}{{ _('Delete') }}
+
+ +
{{ render_pagination(pagination) }}
+ +{% endblock %} \ No newline at end of file diff --git a/project/templates/oauth2_client/read.html b/project/templates/oauth2_client/read.html new file mode 100644 index 0000000..daec152 --- /dev/null +++ b/project/templates/oauth2_client/read.html @@ -0,0 +1,34 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_kv_begin, render_kv_end, render_kv_prop %} +{% block title %} +{{ oauth2_client.client_name }} +{% endblock %} +{% block header %} + +{% endblock %} +{% block content %} + + + +
+ {{ render_kv_begin() }} + {{ render_kv_prop(oauth2_client.client_id, 'Client ID') }} + {{ render_kv_prop(oauth2_client.client_secret, 'Client secret') }} + {{ render_kv_prop(oauth2_client.client_uri, 'Client URI') }} + {{ render_kv_prop(oauth2_client.grant_types, 'Grant types') }} + {{ render_kv_prop(oauth2_client.redirect_uris, 'Redirect URIs') }} + {{ render_kv_prop(oauth2_client.response_types, 'Response types') }} + {{ render_kv_prop(oauth2_client.scope, 'Scope') }} + {{ render_kv_prop(oauth2_client.token_endpoint_auth_method, 'Token endpoint auth method') }} + {{ render_kv_end() }} +
+ +{% endblock %} \ No newline at end of file diff --git a/project/templates/oauth2_client/update.html b/project/templates/oauth2_client/update.html new file mode 100644 index 0000000..f6b8b3a --- /dev/null +++ b/project/templates/oauth2_client/update.html @@ -0,0 +1,27 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_field_with_errors, render_field %} +{% block title %} +{{ _('Update OAuth2 client') }} +{% endblock %} +{% block content %} + +

{{ _('Update OAuth2 client') }}

+ +
+ {{ form.hidden_tag() }} + +
+
+ {{ render_field_with_errors(form.client_name) }} + {{ render_field_with_errors(form.grant_types, ri="multicheckbox") }} + {{ render_field_with_errors(form.response_types, ri="multicheckbox") }} + {{ render_field_with_errors(form.scope, ri="multicheckbox") }} + {{ render_field_with_errors(form.token_endpoint_auth_method) }} + {{ render_field_with_errors(form.redirect_uris) }} +
+
+ + {{ render_field(form.submit) }} +
+ +{% endblock %} diff --git a/project/templates/oauth2_token/list.html b/project/templates/oauth2_token/list.html new file mode 100644 index 0000000..9f669c2 --- /dev/null +++ b/project/templates/oauth2_token/list.html @@ -0,0 +1,40 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_pagination %} +{% block title %} +{{ _('OAuth2 tokens') }} +{% endblock %} +{% block content %} + + + +
+ + + + + + + + + + + {% for oauth2_token in oauth2_tokens %} + + + + + + + {% endfor %} + +
{{ _('Client') }}{{ _('Scopes') }}{{ _('Status') }}
{{ oauth2_token.client.client_name }}{{ oauth2_token.client.scope }}{% if oauth2_token.revoked %}{{ _('Revoked') }}{% else %}{{ _('Active') }}{% endif %}{% if not oauth2_token.revoked %}{{ _('Revoke') }}{% endif %}
+
+ +
{{ render_pagination(pagination) }}
+ +{% endblock %} \ No newline at end of file diff --git a/project/templates/oauth2_token/revoke.html b/project/templates/oauth2_token/revoke.html new file mode 100644 index 0000000..0ac45df --- /dev/null +++ b/project/templates/oauth2_token/revoke.html @@ -0,0 +1,16 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_field_with_errors, render_field %} + +{% block content %} + +

{{ _('Revoke OAuth2 token') }}

+

{{ oauth2_token.client.client_name }}

+ +
+ {{ form.hidden_tag() }} + + {{ render_field(form.submit) }} + +
+ +{% endblock %} diff --git a/project/templates/profile.html b/project/templates/profile.html index 56899bd..a9af017 100644 --- a/project/templates/profile.html +++ b/project/templates/profile.html @@ -8,7 +8,23 @@

{{ current_user.email }}

{{ _('Profile') }}

-

{{ _fsdomain('Change password') }}

+ +
+ + {{ _fsdomain('Change password') }} + + + {% if current_user.has_permission('oauth2_client:read') %} + + {{ _('OAuth2 clients') }} + + + {% endif %} + + {{ _('OAuth2 tokens') }} + + +
{% if invitations %}

{{ _('Invitations') }}

diff --git a/project/templates/security/authorize.html b/project/templates/security/authorize.html new file mode 100644 index 0000000..f01c3e1 --- /dev/null +++ b/project/templates/security/authorize.html @@ -0,0 +1,40 @@ +{% extends "layout.html" %} +{% from "_macros.html" import render_field %} + +{% block content %} + +
+ +
+
+
{{ _('"%(client_name)s" wants to access your account', client_name=grant.client.client_name) }}
+ +

{{ user.email }}

+ +

{{ _('This will allow "%(client_name)s" to:', client_name=grant.client.client_name) }}

+ +
    + {% for key, value in scopes.items() %} +
  • {{ value }}
  • + {% endfor %} +
+ +
+ {{ form.hidden_tag() }} + +
+
+ {{ render_field(form.allow, class="btn btn-success mx-auto") }} +
+
+ {{ render_field(form.deny, class="btn btn-light mx-auto") }} +
+
+
+ +
+
+ +
+ +{% endblock %} diff --git a/project/templates/security/login_user.html b/project/templates/security/login_user.html index 0b2e46c..3deaf4a 100644 --- a/project/templates/security/login_user.html +++ b/project/templates/security/login_user.html @@ -4,7 +4,8 @@ {% block content %}

{{ _fsdomain('Login') }}

-
+{% set next = request.args['next'] if 'next' in request.args and 'authorize' in request.args['next'] else 'manage' %} + {{ login_user_form.hidden_tag() }} {{ render_field_with_errors(login_user_form.email) }} {{ render_field_with_errors(login_user_form.password) }} diff --git a/project/utils.py b/project/utils.py index aea28ef..80628b6 100644 --- a/project/utils.py +++ b/project/utils.py @@ -11,9 +11,20 @@ def get_localized_enum_name(enum): return lazy_gettext(enum.__class__.__name__ + "." + enum.name) +def get_localized_scope(scope: str) -> str: + type_name, action = scope.split(":") + loc_lazy_gettext = lazy_gettext(type_name.capitalize()) + loc_action = lazy_gettext(action) + return f"{loc_lazy_gettext} ({loc_action})" + + def make_dir(path): try: original_umask = os.umask(0) pathlib.Path(path).mkdir(parents=True, exist_ok=True) finally: os.umask(original_umask) + + +def split_by_crlf(s): + return [v for v in s.splitlines() if v] diff --git a/project/views/oauth.py b/project/views/oauth.py new file mode 100644 index 0000000..17ab000 --- /dev/null +++ b/project/views/oauth.py @@ -0,0 +1,53 @@ +from authlib.oauth2 import OAuth2Error +from project import app +from project.api import scopes +from flask_security import current_user +from flask import redirect, request, url_for, render_template +from project.forms.security import AuthorizeForm +from project.oauth2 import authorization + + +@app.route("/oauth/authorize", methods=["GET", "POST"]) +def authorize(): + user = current_user + + if not user or not user.is_authenticated: + return redirect(url_for("security.login", next=request.url)) + + form = AuthorizeForm() + + if form.validate_on_submit(): + grant_user = user if form.allow.data else None + return authorization.create_authorization_response(grant_user=grant_user) + else: + try: + grant = authorization.validate_consent_request(end_user=user) + except OAuth2Error as error: + return error.error + + grant_scopes = grant.request.scope.split(" ") + filtered_scopes = {k: scopes[k] for k in grant_scopes} + return render_template( + "security/authorize.html", + form=form, + scopes=filtered_scopes, + user=user, + grant=grant, + ) + + +@app.route("/oauth/token", methods=["POST"]) +def issue_token(): + return authorization.create_token_response() + + +@app.route("/oauth/revoke", methods=["POST"]) +def revoke_token(): + return authorization.create_endpoint_response("revocation") + + +@app.route("/oauth2-redirect.html") +def swagger_oauth2_redirect(): + return redirect( + url_for("flask-apispec.static", filename="oauth2-redirect.html", **request.args) + ) diff --git a/project/views/oauth2_client.py b/project/views/oauth2_client.py new file mode 100644 index 0000000..d415566 --- /dev/null +++ b/project/views/oauth2_client.py @@ -0,0 +1,125 @@ +from project import app, db +from flask import render_template, redirect, flash, url_for +from flask_babelex import gettext +from flask_security import permissions_required, current_user +from project.models import OAuth2Client +from project.views.utils import ( + get_pagination_urls, + handleSqlError, + flash_errors, + non_match_for_deletion, +) +from project.forms.oauth2_client import ( + CreateOAuth2ClientForm, + UpdateOAuth2ClientForm, + DeleteOAuth2ClientForm, +) +from project.services.oauth2_client import complete_oauth2_client +from sqlalchemy.exc import SQLAlchemyError +from project.access import owner_access_or_401 + + +@app.route("/oauth2_client/create", methods=("GET", "POST")) +@permissions_required("oauth2_client:create") +def oauth2_client_create(): + form = CreateOAuth2ClientForm() + + if form.validate_on_submit(): + oauth2_client = OAuth2Client() + form.populate_obj(oauth2_client) + oauth2_client.user_id = current_user.id + complete_oauth2_client(oauth2_client) + + try: + db.session.add(oauth2_client) + db.session.commit() + flash(gettext("OAuth2 client successfully created"), "success") + return redirect(url_for("oauth2_client", id=oauth2_client.id)) + except SQLAlchemyError as e: + db.session.rollback() + flash(handleSqlError(e), "danger") + else: + flash_errors(form) + + return render_template("oauth2_client/create.html", form=form) + + +@app.route("/oauth2_client//update", methods=("GET", "POST")) +@permissions_required("oauth2_client:update") +def oauth2_client_update(id): + oauth2_client = OAuth2Client.query.get_or_404(id) + owner_access_or_401(oauth2_client.user_id) + + form = UpdateOAuth2ClientForm(obj=oauth2_client) + + if form.validate_on_submit(): + form.populate_obj(oauth2_client) + complete_oauth2_client(oauth2_client) + + try: + db.session.commit() + flash(gettext("OAuth2 client successfully updated"), "success") + return redirect(url_for("oauth2_client", id=oauth2_client.id)) + except SQLAlchemyError as e: + db.session.rollback() + flash(handleSqlError(e), "danger") + else: + flash_errors(form) + + return render_template( + "oauth2_client/update.html", form=form, oauth2_client=oauth2_client + ) + + +@app.route("/oauth2_client//delete", methods=("GET", "POST")) +@permissions_required("oauth2_client:delete") +def oauth2_client_delete(id): + oauth2_client = OAuth2Client.query.get_or_404(id) + owner_access_or_401(oauth2_client.user_id) + + form = DeleteOAuth2ClientForm() + + if form.validate_on_submit(): + if non_match_for_deletion(form.name.data, oauth2_client.client_name): + flash(gettext("Entered name does not match OAuth2 client name"), "danger") + else: + try: + db.session.delete(oauth2_client) + db.session.commit() + flash(gettext("OAuth2 client successfully deleted"), "success") + return redirect(url_for("oauth2_clients")) + except SQLAlchemyError as e: + db.session.rollback() + flash(handleSqlError(e), "danger") + else: + flash_errors(form) + + return render_template( + "oauth2_client/delete.html", form=form, oauth2_client=oauth2_client + ) + + +@app.route("/oauth2_client/") +@permissions_required("oauth2_client:read") +def oauth2_client(id): + oauth2_client = OAuth2Client.query.get_or_404(id) + owner_access_or_401(oauth2_client.user_id) + + return render_template( + "oauth2_client/read.html", + oauth2_client=oauth2_client, + ) + + +@app.route("/oauth2_clients") +@permissions_required("oauth2_client:read") +def oauth2_clients(): + oauth2_clients = OAuth2Client.query.filter( + OAuth2Client.user_id == current_user.id + ).paginate() + + return render_template( + "oauth2_client/list.html", + oauth2_clients=oauth2_clients.items, + pagination=get_pagination_urls(oauth2_clients), + ) diff --git a/project/views/oauth2_token.py b/project/views/oauth2_token.py new file mode 100644 index 0000000..6825c3d --- /dev/null +++ b/project/views/oauth2_token.py @@ -0,0 +1,53 @@ +from project import app, db +from flask import render_template, redirect, flash, url_for +from flask_babelex import gettext +from flask_security import current_user +from project.models import OAuth2Token +from project.views.utils import ( + get_pagination_urls, + handleSqlError, + flash_errors, +) +from project.forms.oauth2_token import RevokeOAuth2TokenForm +from sqlalchemy.exc import SQLAlchemyError +from project.access import owner_access_or_401 + + +@app.route("/oauth2_token//revoke", methods=("GET", "POST")) +def oauth2_token_revoke(id): + oauth2_token = OAuth2Token.query.get_or_404(id) + owner_access_or_401(oauth2_token.user_id) + + if oauth2_token.revoked: + return redirect(url_for("oauth2_tokens")) + + form = RevokeOAuth2TokenForm() + + if form.validate_on_submit(): + try: + oauth2_token.revoked = True + db.session.commit() + flash(gettext("OAuth2 token successfully revoked"), "success") + return redirect(url_for("oauth2_tokens")) + except SQLAlchemyError as e: + db.session.rollback() + flash(handleSqlError(e), "danger") + else: + flash_errors(form) + + return render_template( + "oauth2_token/revoke.html", form=form, oauth2_token=oauth2_token + ) + + +@app.route("/oauth2_tokens") +def oauth2_tokens(): + oauth2_tokens = OAuth2Token.query.filter( + OAuth2Token.user_id == current_user.id + ).paginate() + + return render_template( + "oauth2_token/list.html", + oauth2_tokens=oauth2_tokens.items, + pagination=get_pagination_urls(oauth2_tokens), + ) diff --git a/requirements.txt b/requirements.txt index e6f74d2..65a2a19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ apispec-webframeworks==0.5.2 appdirs==1.4.4 argh==0.26.2 attrs==20.3.0 +Authlib==0.15.3 Babel==2.9.0 bcrypt==3.2.0 beautifulsoup4==4.9.3 @@ -18,6 +19,7 @@ click==7.1.2 colour==0.1.5 coverage==5.3 coveralls==2.2.0 +cryptography==3.3.1 distlib==0.3.1 dnspython==2.0.0 docopt==0.6.2 diff --git a/tests/api/test___init__.py b/tests/api/test___init__.py new file mode 100644 index 0000000..55bd8da --- /dev/null +++ b/tests/api/test___init__.py @@ -0,0 +1,55 @@ +import pytest +from project.api import RestApi + + +class Psycog2Error(object): + def __init__(self, pgcode): + self.pgcode = pgcode + + +def test_handle_error_unique(app): + from sqlalchemy.exc import IntegrityError + from psycopg2.errorcodes import UNIQUE_VIOLATION + + orig = Psycog2Error(UNIQUE_VIOLATION) + error = IntegrityError("Select", list(), orig) + + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 400 + assert data["name"] == "Unique Violation" + + +def test_handle_error_httpException(app): + from werkzeug.exceptions import InternalServerError + + error = InternalServerError() + + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 500 + + +def test_handle_error_unprocessableEntity(app): + from werkzeug.exceptions import UnprocessableEntity + from marshmallow import ValidationError + + args = {"name": ["Required"]} + validation_error = ValidationError(args) + + error = UnprocessableEntity() + error.exc = validation_error + + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 422 + assert data["errors"][0]["field"] == "name" + assert data["errors"][0]["message"] == "Required" + + +def test_handle_error_unspecificRaises(app): + error = Exception() + api = RestApi(app) + + with pytest.raises(Exception): + api.handle_error(error) diff --git a/tests/api/test_fields.py b/tests/api/test_fields.py new file mode 100644 index 0000000..8456657 --- /dev/null +++ b/tests/api/test_fields.py @@ -0,0 +1,47 @@ +def test_numeric_str_serialize(client, seeder, utils): + from project.api.location.schemas import LocationSchema + from project.models import Location + + location = Location() + location.street = "Markt 7" + location.postalCode = "38640" + location.city = "Goslar" + location.latitude = 51.9077888 + location.longitude = 10.4333312 + + schema = LocationSchema() + data = schema.dump(location) + + assert data["latitude"] == "51.9077888" + assert data["longitude"] == "10.4333312" + + +def test_numeric_str_deserialize(client, seeder, utils): + from project.api.location.schemas import LocationPostRequestLoadSchema + + data = { + "latitude": "51.9077888", + "longitude": "10.4333312", + } + + schema = LocationPostRequestLoadSchema() + location = schema.load(data) + + assert location.latitude == 51.9077888 + assert location.longitude == 10.4333312 + + +def test_numeric_str_deserialize_invalid(client, seeder, utils): + from project.api.location.schemas import LocationPostRequestLoadSchema + import pytest + from marshmallow import ValidationError + + data = { + "latitude": "Quatsch", + "longitude": "Quatsch", + } + + schema = LocationPostRequestLoadSchema() + + with pytest.raises(ValidationError): + schema.load(data) diff --git a/tests/api/test_image.py b/tests/api/test_image.py deleted file mode 100644 index e33af36..0000000 --- a/tests/api/test_image.py +++ /dev/null @@ -1,6 +0,0 @@ -def test_read(client, seeder, utils): - user_id, admin_unit_id = seeder.setup_base() - image_id = seeder.upsert_default_image() - - url = utils.get_url("api_v1_image", id=image_id) - utils.get_ok(url) diff --git a/tests/api/test_location.py b/tests/api/test_location.py deleted file mode 100644 index 241cc9b..0000000 --- a/tests/api/test_location.py +++ /dev/null @@ -1,20 +0,0 @@ -def test_read(client, app, db, seeder, utils): - user_id, admin_unit_id = seeder.setup_base() - - with app.app_context(): - from project.models import Location - - location = Location() - location.street = "Markt 7" - location.postalCode = "38640" - location.city = "Goslar" - location.latitude = 51.9077888 - location.longitude = 10.4333312 - - db.session.add(location) - db.session.commit() - location_id = location.id - - url = utils.get_url("api_v1_location", id=location_id) - response = utils.get_ok(url) - assert response.json["latitude"] == "51.9077888000000000" diff --git a/tests/api/test_organization.py b/tests/api/test_organization.py index 55d34fc..cc072a3 100644 --- a/tests/api/test_organization.py +++ b/tests/api/test_organization.py @@ -46,6 +46,25 @@ def test_places(client, seeder, utils): utils.get_ok(url) +def test_places_post(client, seeder, utils, app): + user_id, admin_unit_id = seeder.setup_api_access() + + url = utils.get_url("api_v1_organization_place_list", id=admin_unit_id, name="crew") + response = utils.post_json(url, {"name": "Neuer Ort"}) + utils.assert_response_created(response) + assert "id" in response.json + + with app.app_context(): + from project.models import EventPlace + + place = ( + EventPlace.query.filter(EventPlace.admin_unit_id == admin_unit_id) + .filter(EventPlace.name == "Neuer Ort") + .first() + ) + assert place is not None + + def test_references_incoming(client, seeder, utils): user_id, admin_unit_id = seeder.setup_base() ( diff --git a/tests/api/test_place.py b/tests/api/test_place.py index ce33bce..72de894 100644 --- a/tests/api/test_place.py +++ b/tests/api/test_place.py @@ -4,3 +4,65 @@ def test_read(client, app, db, seeder, utils): url = utils.get_url("api_v1_place", id=place_id) utils.get_ok(url) + + +def test_put(client, seeder, utils, app): + user_id, admin_unit_id = seeder.setup_api_access() + place_id = seeder.upsert_default_event_place(admin_unit_id) + + url = utils.get_url("api_v1_place", id=place_id) + response = utils.put_json(url, {"name": "Neuer Name"}) + utils.assert_response_no_content(response) + + with app.app_context(): + from project.models import EventPlace + + place = EventPlace.query.get(place_id) + assert place.name == "Neuer Name" + + +def test_put_nonActiveReturnsUnauthorized(client, seeder, db, utils, app): + user_id, admin_unit_id = seeder.setup_api_access() + place_id = seeder.upsert_default_event_place(admin_unit_id) + + with app.app_context(): + from project.models import User + + user = User.query.get(user_id) + user.active = False + db.session.commit() + + url = utils.get_url("api_v1_place", id=place_id) + response = utils.put_json(url, {"name": "Neuer Name"}) + utils.assert_response_unauthorized(response) + + +def test_patch(client, seeder, utils, app): + user_id, admin_unit_id = seeder.setup_api_access() + place_id = seeder.upsert_default_event_place(admin_unit_id) + + url = utils.get_url("api_v1_place", id=place_id) + response = utils.patch_json(url, {"description": "Klasse"}) + utils.assert_response_no_content(response) + + with app.app_context(): + from project.models import EventPlace + + place = EventPlace.query.get(place_id) + assert place.name == "Meine Crew" + assert place.description == "Klasse" + + +def test_delete(client, seeder, utils, app): + user_id, admin_unit_id = seeder.setup_api_access() + place_id = seeder.upsert_default_event_place(admin_unit_id) + + url = utils.get_url("api_v1_place", id=place_id) + response = utils.delete(url) + utils.assert_response_no_content(response) + + with app.app_context(): + from project.models import EventPlace + + place = EventPlace.query.get(place_id) + assert place is None diff --git a/tests/conftest.py b/tests/conftest.py index fd0d923..7572e7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ def pytest_generate_tests(metafunc): os.environ["DATABASE_URL"] = os.environ.get( "TEST_DATABASE_URL", "postgresql://postgres@localhost/gsevpt_tests" ) + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "1" @pytest.fixture diff --git a/tests/seeder.py b/tests/seeder.py index 080b73b..6218ac4 100644 --- a/tests/seeder.py +++ b/tests/seeder.py @@ -4,9 +4,10 @@ class Seeder(object): self._db = db self._utils = utils - def setup_base(self, admin=False): + def setup_base(self, admin=False, log_in=True): user_id = self.create_user(admin=admin) - self._utils.login() + if log_in: + self._utils.login() admin_unit_id = self.create_admin_unit(user_id) return (user_id, admin_unit_id) @@ -126,6 +127,45 @@ class Seeder(object): return organizer_id + def insert_default_oauth2_client(self, user_id): + from project.api import scope_list + from project.models import OAuth2Client + from project.services.oauth2_client import complete_oauth2_client + + with self._app.app_context(): + client = OAuth2Client() + client.user_id = user_id + complete_oauth2_client(client) + + metadata = dict() + metadata["client_name"] = "Mein Client" + metadata["scope"] = " ".join(scope_list) + metadata["grant_types"] = ["authorization_code", "refresh_token"] + metadata["response_types"] = ["code"] + metadata["token_endpoint_auth_method"] = "client_secret_post" + client.set_client_metadata(metadata) + + self._db.session.add(client) + self._db.session.commit() + client_id = client.id + + return client_id + + def setup_api_access(self): + user_id, admin_unit_id = self.setup_base(admin=True) + oauth2_client_id = self.insert_default_oauth2_client(user_id) + + with self._app.app_context(): + from project.models import OAuth2Client + + oauth2_client = OAuth2Client.query.get(oauth2_client_id) + client_id = oauth2_client.client_id + client_secret = oauth2_client.client_secret + scope = oauth2_client.scope + + self._utils.authorize(client_id, client_secret, scope) + return (user_id, admin_unit_id) + def create_event(self, admin_unit_id, recurrence_rule=None): from project.models import Event from project.services.event import insert_event, upsert_event_category diff --git a/tests/test_models.py b/tests/test_models.py index b3751aa..4d87ed1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -21,3 +21,16 @@ def test_event_category(client, app, db, seeder): db.session.commit() assert event.category is None + + +def test_oauth2_token(client, app): + from project.models import OAuth2Token + + token = OAuth2Token() + token.revoked = True + assert not token.is_refresh_token_active() + + token.revoked = False + token.issued_at = 0 + token.expires_in = 0 + assert not token.is_refresh_token_active() diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py new file mode 100644 index 0000000..a9ff3e6 --- /dev/null +++ b/tests/test_oauth2.py @@ -0,0 +1,7 @@ +def test_authorization_code(seeder): + user_id, admin_unit_id = seeder.setup_api_access() + + +def test_refresh_token(seeder, utils): + user_id, admin_unit_id = seeder.setup_api_access() + utils.refresh_token() diff --git a/tests/utils.py b/tests/utils.py index e405712..c7f45e2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,12 +2,17 @@ import re from flask import g, url_for from sqlalchemy.exc import IntegrityError from bs4 import BeautifulSoup +from urllib.parse import urlsplit, parse_qs class UtilActions(object): def __init__(self, client, app): self._client = client self._app = app + self._access_token = None + self._refresh_token = None + self._client_id = None + self._client_secret = None def register(self, email="test@test.de", password="MeinPasswortIstDasBeste"): response = self._client.get("/register") @@ -77,14 +82,56 @@ class UtilActions(object): form = Form(soup.find("form")) return form.fill(values) - def post_form(self, url, response, values: dict): - data = self.create_form_data(response, values) + def post_form_data(self, url, data: dict): return self._client.post(url, data=data) + def post_form(self, url, response, values: dict): + data = self.create_form_data(response, values) + return self.post_form_data(url, data=data) + + def get_headers(self): + headers = dict() + + if self._access_token: + headers["Authorization"] = f"Bearer {self._access_token}" + + return headers + + def log_request(self, url): + print(url) + + def log_json_request(self, url, data: dict): + self.log_request(url) + print(data) + + def log_response(self, response): + print(response.status_code) + print(response.data) + print(response.json) + def post_json(self, url, data: dict): - response = self._client.post(url, json=data) - assert response.content_type == "application/json" - return response.json + self.log_json_request(url, data) + response = self._client.post(url, json=data, headers=self.get_headers()) + self.log_response(response) + return response + + def put_json(self, url, data: dict): + self.log_json_request(url, data) + response = self._client.put(url, json=data, headers=self.get_headers()) + self.log_response(response) + return response + + def patch_json(self, url, data: dict): + self.log_json_request(url, data) + response = self._client.patch(url, json=data, headers=self.get_headers()) + self.log_response(response) + return response + + def delete(self, url): + self.log_request(url) + response = self._client.delete(url, headers=self.get_headers()) + self.log_response(response) + return response def mock_db_commit(self, mocker, orig=None): mocked_commit = mocker.patch("project.db.session.commit") @@ -105,14 +152,23 @@ class UtilActions(object): url = url_for(endpoint, **values, _external=False) return url + def get(self, url): + return self._client.get(url) + def get_ok(self, url): - response = self._client.get(url) + response = self.get(url) self.assert_response_ok(response) return response def assert_response_ok(self, response): assert response.status_code == 200 + def assert_response_created(self, response): + assert response.status_code == 201 + + def assert_response_no_content(self, response): + assert response.status_code == 204 + def get_unauthorized(self, url): response = self._client.get(url) self.assert_response_unauthorized(response) @@ -146,3 +202,96 @@ class UtilActions(object): def assert_response_permission_missing(self, response, endpoint, **values): self.assert_response_redirect(response, endpoint, **values) + + def parse_query_parameters(self, url): + query = urlsplit(url).query + params = parse_qs(query) + return {k: v[0] for k, v in params.items()} + + def authorize(self, client_id, client_secret, scope): + # Authorize-Seite öffnen + redirect_uri = self.get_url("swagger_oauth2_redirect") + url = self.get_url( + "authorize", + response_type="code", + client_id=client_id, + scope=scope, + redirect_uri=redirect_uri, + ) + response = self.get_ok(url) + + # Authorisieren + response = self.post_form( + url, + response, + {}, + ) + + assert response.status_code == 302 + assert redirect_uri in response.headers["Location"] + + # Code aus der Redirect-Antwort lesen + params = self.parse_query_parameters(response.headers["Location"]) + assert "code" in params + code = params["code"] + + # Mit dem Code den Access-Token abfragen + token_url = self.get_url("issue_token") + response = self.post_form_data( + token_url, + data={ + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "authorization_code", + "scope": scope, + "code": code, + "redirect_uri": redirect_uri, + }, + ) + + self.assert_response_ok(response) + assert response.content_type == "application/json" + assert "access_token" in response.json + assert "expires_in" in response.json + assert "refresh_token" in response.json + assert response.json["scope"] == scope + assert response.json["token_type"] == "Bearer" + + self._client_id = client_id + self._client_secret = client_secret + self._access_token = response.json["access_token"] + self._refresh_token = response.json["refresh_token"] + + def refresh_token(self): + token_url = self.get_url("issue_token") + response = self.post_form_data( + token_url, + data={ + "grant_type": "refresh_token", + "refresh_token": self._refresh_token, + "client_id": self._client_id, + "client_secret": self._client_secret, + }, + ) + + self.assert_response_ok(response) + assert response.content_type == "application/json" + assert response.json["token_type"] == "Bearer" + assert "access_token" in response.json + assert "expires_in" in response.json + + self._access_token = response.json["access_token"] + + def revoke_token(self): + url = self.get_url("revoke_token") + response = self.post_form_data( + url, + data={ + "token": self._access_token, + "token_type_hint": "access_token", + "client_id": self._client_id, + "client_secret": self._client_secret, + }, + ) + + self.assert_response_ok(response) diff --git a/tests/views/test_event.py b/tests/views/test_event.py index ddaaa93..c37d2a9 100644 --- a/tests/views/test_event.py +++ b/tests/views/test_event.py @@ -430,7 +430,7 @@ def test_delete_nameDoesNotMatch(client, seeder, utils, app, mocker): def test_rrule(client, seeder, utils, app): url = utils.get_url("event_rrule") - json = utils.post_json( + response = utils.post_json( url, { "year": 2020, @@ -440,6 +440,7 @@ def test_rrule(client, seeder, utils, app): "start": 0, }, ) + json = response.json assert json["batch"]["batch_size"] == 10 diff --git a/tests/views/test_oauth.py b/tests/views/test_oauth.py new file mode 100644 index 0000000..dc41d9b --- /dev/null +++ b/tests/views/test_oauth.py @@ -0,0 +1,25 @@ +def test_authorize_unauthorizedRedirects(seeder, utils): + url = utils.get_url("authorize") + response = utils.get(url) + + assert response.status_code == 302 + assert "login" in response.headers["Location"] + + +def test_authorize_validateThrowsError(seeder, utils): + seeder.setup_base() + url = utils.get_url("authorize") + response = utils.get(url) + + utils.assert_response_error_message(response, b"invalid_grant") + + +def test_revoke_token(seeder, utils): + seeder.setup_api_access() + utils.revoke_token() + + +def test_swagger_redirect(utils): + url = utils.get_url("swagger_oauth2_redirect") + response = utils.get(url) + assert response.status_code == 302 diff --git a/tests/views/test_oauth2_client.py b/tests/views/test_oauth2_client.py new file mode 100644 index 0000000..330acb8 --- /dev/null +++ b/tests/views/test_oauth2_client.py @@ -0,0 +1,136 @@ +import pytest + + +def test_read(client, seeder, utils): + user_id, admin_unit_id = seeder.setup_base(True) + oauth2_client_id = seeder.insert_default_oauth2_client(user_id) + + url = utils.get_url("oauth2_client", id=oauth2_client_id) + utils.get_ok(url) + + +def test_read_notOwner(client, seeder, utils): + user_id = seeder.create_user(email="other@other.de", admin=True) + oauth2_client_id = seeder.insert_default_oauth2_client(user_id) + + seeder.setup_base(True) + url = utils.get_url("oauth2_client", id=oauth2_client_id) + utils.get_unauthorized(url) + + +def test_list(client, seeder, utils): + user_id, admin_unit_id = seeder.setup_base(True) + + url = utils.get_url("oauth2_clients") + utils.get_ok(url) + + +@pytest.mark.parametrize("db_error", [True, False]) +def test_create_authorization_code(client, app, utils, seeder, mocker, db_error): + from project.api import scope_list + + user_id, admin_unit_id = seeder.setup_base(True) + + url = utils.get_url("oauth2_client_create") + response = utils.get_ok(url) + + if db_error: + utils.mock_db_commit(mocker) + + response = utils.post_form( + url, + response, + { + "client_name": "Mein Client", + "scope": scope_list, + }, + ) + + if db_error: + utils.assert_response_db_error(response) + return + + with app.app_context(): + from project.models import OAuth2Client + + oauth2_client = OAuth2Client.query.filter( + OAuth2Client.user_id == user_id + ).first() + assert oauth2_client is not None + client_id = oauth2_client.id + + utils.assert_response_redirect(response, "oauth2_client", id=client_id) + + +@pytest.mark.parametrize("db_error", [True, False]) +def test_update(client, seeder, utils, app, mocker, db_error): + user_id, admin_unit_id = seeder.setup_base(True) + oauth2_client_id = seeder.insert_default_oauth2_client(user_id) + + url = utils.get_url("oauth2_client_update", id=oauth2_client_id) + response = utils.get_ok(url) + + if db_error: + utils.mock_db_commit(mocker) + + response = utils.post_form( + url, + response, + { + "client_name": "Neuer Name", + }, + ) + + if db_error: + utils.assert_response_db_error(response) + return + + utils.assert_response_redirect(response, "oauth2_client", id=oauth2_client_id) + + with app.app_context(): + from project.models import OAuth2Client + + oauth2_client = OAuth2Client.query.get(oauth2_client_id) + assert oauth2_client.client_name == "Neuer Name" + + +@pytest.mark.parametrize("db_error", [True, False]) +@pytest.mark.parametrize("non_match", [True, False]) +def test_delete(client, seeder, utils, app, mocker, db_error, non_match): + user_id, admin_unit_id = seeder.setup_base(True) + oauth2_client_id = seeder.insert_default_oauth2_client(user_id) + + url = utils.get_url("oauth2_client_delete", id=oauth2_client_id) + response = utils.get_ok(url) + + if db_error: + utils.mock_db_commit(mocker) + + form_name = "Mein Client" + + if non_match: + form_name = "Falscher Name" + + response = utils.post_form( + url, + response, + { + "name": form_name, + }, + ) + + if non_match: + utils.assert_response_error_message(response) + return + + if db_error: + utils.assert_response_db_error(response) + return + + utils.assert_response_redirect(response, "oauth2_clients") + + with app.app_context(): + from project.models import OAuth2Client + + oauth2_client = OAuth2Client.query.get(oauth2_client_id) + assert oauth2_client is None diff --git a/tests/views/test_oauth2_token.py b/tests/views/test_oauth2_token.py new file mode 100644 index 0000000..e319c80 --- /dev/null +++ b/tests/views/test_oauth2_token.py @@ -0,0 +1,47 @@ +import pytest + + +def test_list(client, seeder, utils): + user_id, admin_unit_id = seeder.setup_api_access() + + url = utils.get_url("oauth2_tokens") + utils.get_ok(url) + + +@pytest.mark.parametrize("db_error", [True, False]) +def test_revoke(client, seeder, utils, app, mocker, db_error): + user_id, admin_unit_id = seeder.setup_api_access() + + with app.app_context(): + from project.models import OAuth2Token + + oauth2_token = OAuth2Token.query.filter(OAuth2Token.user_id == user_id).first() + oauth2_token_id = oauth2_token.id + + url = utils.get_url("oauth2_token_revoke", id=oauth2_token_id) + response = utils.get_ok(url) + + if db_error: + utils.mock_db_commit(mocker) + + response = utils.post_form( + url, + response, + {}, + ) + + if db_error: + utils.assert_response_db_error(response) + return + + utils.assert_response_redirect(response, "oauth2_tokens") + + with app.app_context(): + from project.models import OAuth2Token + + oauth2_token = OAuth2Token.query.get(oauth2_token_id) + assert oauth2_token.revoked + + # Kann nicht zweimal revoked werden + response = utils.get(url) + utils.assert_response_redirect(response, "oauth2_tokens")