API Write Access with OAuth2 #104

This commit is contained in:
Daniel Grams 2021-02-05 14:50:50 +01:00
parent bc9d2aae3c
commit 6c2384e678
66 changed files with 1945 additions and 157 deletions

View File

@ -1,2 +0,0 @@
[run]
relative_files = True

View File

@ -5,7 +5,7 @@
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.testing.pytestArgs": [
"tests"
"tests", "--capture=sys"
],
"python.testing.unittestEnabled": false,
"python.testing.nosetestsEnabled": false,

View File

@ -0,0 +1,91 @@
"""empty message
Revision ID: ddb85cb1c21e
Revises: b1a6e7630185
Create Date: 2021-02-02 15:17:21.988363
"""
from alembic import op
import sqlalchemy as sa
import sqlalchemy_utils
from project import dbtypes
# revision identifiers, used by Alembic.
revision = "ddb85cb1c21e"
down_revision = "b1a6e7630185"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"oauth2_client",
sa.Column("client_id", sa.String(length=48), nullable=True),
sa.Column("client_secret", sa.String(length=120), nullable=True),
sa.Column("client_id_issued_at", sa.Integer(), nullable=False),
sa.Column("client_secret_expires_at", sa.Integer(), nullable=False),
sa.Column("client_metadata", sa.Text(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth2_client_client_id"), "oauth2_client", ["client_id"], unique=False
)
op.create_table(
"oauth2_code",
sa.Column("code", sa.String(length=120), nullable=False),
sa.Column("client_id", sa.String(length=48), nullable=True),
sa.Column("redirect_uri", sa.Text(), nullable=True),
sa.Column("response_type", sa.Text(), nullable=True),
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("nonce", sa.Text(), nullable=True),
sa.Column("auth_time", sa.Integer(), nullable=False),
sa.Column("code_challenge", sa.Text(), nullable=True),
sa.Column("code_challenge_method", sa.String(length=48), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("code"),
)
op.create_table(
"oauth2_token",
sa.Column("client_id", sa.String(length=48), nullable=True),
sa.Column("token_type", sa.String(length=40), nullable=True),
sa.Column("access_token", sa.String(length=255), nullable=False),
sa.Column("refresh_token", sa.String(length=255), nullable=True),
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("revoked", sa.Boolean(), nullable=True),
sa.Column("issued_at", sa.Integer(), nullable=False),
sa.Column("expires_in", sa.Integer(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("access_token"),
)
op.create_index(
op.f("ix_oauth2_token_refresh_token"),
"oauth2_token",
["refresh_token"],
unique=False,
)
op.create_unique_constraint(
"eventplace_name_admin_unit_id", "eventplace", ["name", "admin_unit_id"]
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("eventplace_name_admin_unit_id", "eventplace", type_="unique")
op.drop_index(op.f("ix_oauth2_token_refresh_token"), table_name="oauth2_token")
op.drop_table("oauth2_token")
op.drop_table("oauth2_code")
op.drop_index(op.f("ix_oauth2_client_client_id"), table_name="oauth2_client")
op.drop_table("oauth2_client")
# ### end Alembic commands ###

View File

@ -1,5 +1,5 @@
import os
from flask import Flask
from flask import Flask, url_for, redirect, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from flask_security import (
Security,
@ -10,12 +10,8 @@ from flask_cors import CORS
from flask_qrcode import QRcode
from flask_mail import Mail, email_dispatched
from flask_migrate import Migrate
from flask_marshmallow import Marshmallow
from flask_restful import Api
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_apispec.extension import FlaskApiSpec
from flask_gzip import Gzip
from webargs import flaskparser
# Create app
app = Flask(__name__)
@ -59,25 +55,6 @@ babel = Babel(app)
# cors
cors = CORS(app, resources={r"/api/*", "/swagger/"})
# API
rest_api = Api(app, "/api/v1")
marshmallow = Marshmallow(app)
marshmallow_plugin = MarshmallowPlugin()
app.config.update(
{
"APISPEC_SPEC": APISpec(
title="Oveda API",
version="0.1.0",
plugins=[marshmallow_plugin],
openapi_version="2.0",
info=dict(
description="This API provides endpoints to interact with the Oveda data. At the moment, there is no authorization needed."
),
),
}
)
api_docs = FlaskApiSpec(app)
# Mail
mail_server = os.getenv("MAIL_SERVER")
@ -108,6 +85,9 @@ if app.config["MAIL_SUPPRESS_SEND"]:
db = SQLAlchemy(app)
migrate = Migrate(app, db)
# API
from project.api import RestApi
# qr code
QRcode(app)
@ -123,6 +103,13 @@ from project.forms.security import ExtendedRegisterForm
user_datastore = SQLAlchemySessionUserDatastore(db.session, User, Role)
security = Security(app, user_datastore, register_form=ExtendedRegisterForm)
# OAuth2
from project.oauth2 import config_oauth
config_oauth(app)
# Init misc modules
from project import i10n
from project import jinja_filters
from project import init_data
@ -142,6 +129,9 @@ from project.views import (
image,
manage,
organizer,
oauth,
oauth2_client,
oauth2_token,
planing,
reference,
reference_request,

View File

@ -1,4 +1,5 @@
from flask import abort
from flask_login import login_user
from flask_security import current_user
from flask_security.utils import FsPermNeed
from flask_principal import Permission
@ -12,6 +13,20 @@ def has_current_user_permission(permission):
return user_perm.can()
def has_owner_access(user_id):
return user_id == current_user.id
def owner_access_or_401(user_id):
if not has_owner_access(user_id):
abort(401)
def login_api_user_or_401(user):
if not login_user(user):
abort(401)
def has_admin_unit_member_role(admin_unit_member, role_name):
for role in admin_unit_member.roles:
if role.name == role_name:

View File

@ -1,4 +1,107 @@
from project import rest_api, api_docs
from flask_restful import Api
from sqlalchemy.exc import IntegrityError
from psycopg2.errorcodes import UNIQUE_VIOLATION
from werkzeug.exceptions import HTTPException, UnprocessableEntity
from marshmallow import ValidationError
from project.utils import get_localized_scope
from project import app
from flask_marshmallow import Marshmallow
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_apispec.extension import FlaskApiSpec
class RestApi(Api):
def handle_error(self, err):
from project.api.schemas import (
ErrorResponseSchema,
UnprocessableEntityResponseSchema,
)
schema = None
data = {}
code = 500
if (
isinstance(err, IntegrityError)
and err.orig
and err.orig.pgcode == UNIQUE_VIOLATION
):
data["name"] = "Unique Violation"
data[
"message"
] = "An entry with the entered values already exists. Duplicate entries are not allowed."
code = 400
schema = ErrorResponseSchema()
elif isinstance(err, HTTPException):
data["name"] = err.name
data["message"] = err.description
code = err.code
if (
isinstance(err, UnprocessableEntity)
and err.exc
and isinstance(err.exc, ValidationError)
):
data["name"] = err.name
data["message"] = err.description
code = err.code
schema = UnprocessableEntityResponseSchema()
if (
getattr(err.exc, "args", None)
and isinstance(err.exc.args, tuple)
and len(err.exc.args) > 0
):
arg = err.exc.args[0]
if isinstance(arg, dict):
errors = []
for field, messages in arg.items():
if isinstance(messages, list):
for message in messages:
error = {"field": field, "message": message}
errors.append(error)
if len(errors) > 0:
data["errors"] = errors
else:
schema = ErrorResponseSchema()
# Call default error handler that propagates error further
try:
super().handle_error(err)
except Exception:
if not schema:
raise
return schema.dump(data), code
scope_list = [
"organizer:write",
"place:write",
"event:write",
]
scopes = {k: get_localized_scope(k) for v, k in enumerate(scope_list)}
rest_api = RestApi(app, "/api/v1", catch_all_404s=True)
marshmallow = Marshmallow(app)
marshmallow_plugin = MarshmallowPlugin()
app.config.update(
{
"APISPEC_SPEC": APISpec(
title="Oveda API",
version="0.1.0",
plugins=[marshmallow_plugin],
openapi_version="2.0",
info=dict(
description="This API provides endpoints to interact with the Oveda data."
),
),
}
)
api_docs = FlaskApiSpec(app)
def enum_to_properties(self, field, **kwargs):
@ -17,8 +120,6 @@ def add_api_resource(resource, url, endpoint):
api_docs.register(resource, endpoint=endpoint)
from project import marshmallow_plugin
marshmallow_plugin.converter.add_attribute_function(enum_to_properties)
import project.api.event.resources
@ -26,8 +127,6 @@ import project.api.event_category.resources
import project.api.event_date.resources
import project.api.event_reference.resources
import project.api.dump.resources
import project.api.image.resources
import project.api.location.resources
import project.api.organization.resources
import project.api.organizer.resources
import project.api.place.resources

View File

@ -1,4 +1,4 @@
from project import marshmallow
from project.api import marshmallow
from marshmallow import fields
from project.api.event.schemas import EventDumpSchema
from project.api.place.schemas import PlaceDumpSchema

View File

@ -1,4 +1,4 @@
from project import marshmallow
from project.api import marshmallow
from marshmallow import fields, validate
from marshmallow_enum import EnumField
from project.models import (
@ -10,7 +10,7 @@ from project.models import (
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
from project.api.organization.schemas import OrganizationRefSchema
from project.api.organizer.schemas import OrganizerRefSchema
from project.api.image.schemas import ImageRefSchema
from project.api.image.schemas import ImageSchema
from project.api.place.schemas import PlaceRefSchema, PlaceSearchItemSchema
from project.api.event_category.schemas import (
EventCategoryRefSchema,
@ -55,7 +55,7 @@ class EventSchema(EventBaseSchema):
organization = fields.Nested(OrganizationRefSchema, attribute="admin_unit")
organizer = fields.Nested(OrganizerRefSchema)
place = fields.Nested(PlaceRefSchema, attribute="event_place")
photo = fields.Nested(ImageRefSchema)
photo = fields.Nested(ImageSchema)
categories = fields.List(fields.Nested(EventCategoryRefSchema))
@ -85,7 +85,7 @@ class EventSearchItemSchema(EventRefSchema):
start = marshmallow.auto_field()
end = marshmallow.auto_field()
recurrence_rule = marshmallow.auto_field()
photo = fields.Nested(ImageRefSchema)
photo = fields.Nested(ImageSchema)
place = fields.Nested(PlaceSearchItemSchema, attribute="event_place")
status = EnumField(EventStatus)
booked_up = marshmallow.auto_field()

View File

@ -1,5 +1,5 @@
from marshmallow import fields
from project import marshmallow
from project.api import marshmallow
from project.models import EventCategory
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema

View File

@ -1,4 +1,4 @@
from project import marshmallow
from project.api import marshmallow
from marshmallow import fields
from project.models import EventDate
from project.api.event.schemas import (

View File

@ -1,5 +1,5 @@
from marshmallow import fields
from project import marshmallow
from project.api import marshmallow
from project.models import EventReference
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
from project.api.event.schemas import EventRefSchema

15
project/api/fields.py Normal file
View File

@ -0,0 +1,15 @@
from marshmallow import fields, ValidationError
class NumericStr(fields.String):
def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return str(value)
def _deserialize(self, value, attr, data, **kwargs):
try:
return float(value)
except ValueError as error:
raise ValidationError("Must be a numeric value.") from error

View File

@ -1,15 +0,0 @@
from project.api import add_api_resource
from flask_apispec import marshal_with, doc
from project.api.resources import BaseResource
from project.api.image.schemas import ImageSchema
from project.models import Image
class ImageResource(BaseResource):
@doc(summary="Get image", tags=["Images"])
@marshal_with(ImageSchema)
def get(self, id):
return Image.query.get_or_404(id)
add_api_resource(ImageResource, "/images/<int:id>", "api_v1_image")

View File

@ -1,4 +1,4 @@
from project import marshmallow
from project.api import marshmallow
from project.models import Image
@ -10,8 +10,6 @@ class ImageIdSchema(marshmallow.SQLAlchemySchema):
class ImageBaseSchema(ImageIdSchema):
created_at = marshmallow.auto_field()
updated_at = marshmallow.auto_field()
copyright_text = marshmallow.auto_field()
@ -27,13 +25,3 @@ class ImageSchema(ImageBaseSchema):
class ImageDumpSchema(ImageBaseSchema):
pass
class ImageRefSchema(ImageIdSchema):
image_url = marshmallow.URLFor(
"image",
values=dict(id="<id>", s=500),
metadata={
"description": "Append query arguments w for width, h for height or s for size(width and height)."
},
)

View File

@ -1,15 +0,0 @@
from project.api import add_api_resource
from flask_apispec import marshal_with, doc
from project.api.resources import BaseResource
from project.api.location.schemas import LocationSchema
from project.models import Location
class LocationResource(BaseResource):
@doc(summary="Get location", tags=["Locations"])
@marshal_with(LocationSchema)
def get(self, id):
return Location.query.get_or_404(id)
add_api_resource(LocationResource, "/locations/<int:id>", "api_v1_location")

View File

@ -1,38 +1,65 @@
from marshmallow import fields
from project import marshmallow
from marshmallow import fields, validate
from project.api import marshmallow
from project.models import Location
from project.api.fields import NumericStr
class LocationIdSchema(marshmallow.SQLAlchemySchema):
class Meta:
model = Location
id = marshmallow.auto_field()
class LocationSchema(LocationIdSchema):
created_at = marshmallow.auto_field()
updated_at = marshmallow.auto_field()
street = marshmallow.auto_field()
postalCode = marshmallow.auto_field()
city = marshmallow.auto_field()
state = marshmallow.auto_field()
country = marshmallow.auto_field()
longitude = fields.Str()
latitude = fields.Str()
longitude = NumericStr()
latitude = NumericStr()
class LocationDumpSchema(LocationSchema):
pass
class LocationRefSchema(LocationIdSchema):
class LocationSearchItemSchema(LocationSchema):
pass
class LocationSearchItemSchema(LocationRefSchema):
class LocationPostRequestSchema(marshmallow.SQLAlchemySchema):
class Meta:
model = Location
longitude = fields.Str()
latitude = fields.Str()
street = fields.Str(validate=validate.Length(max=255), missing=None)
postalCode = fields.Str(validate=validate.Length(max=10), missing=None)
city = fields.Str(validate=validate.Length(max=255), missing=None)
state = fields.Str(validate=validate.Length(max=255), missing=None)
country = fields.Str(validate=validate.Length(max=255), missing=None)
longitude = NumericStr(validate=validate.Range(-180, 180), missing=None)
latitude = NumericStr(validate=validate.Range(-90, 90), missing=None)
class LocationPostRequestLoadSchema(LocationPostRequestSchema):
class Meta:
model = Location
load_instance = True
class LocationPatchRequestSchema(marshmallow.SQLAlchemySchema):
class Meta:
model = Location
street = fields.Str(validate=validate.Length(max=255), allow_none=True)
postalCode = fields.Str(validate=validate.Length(max=10), allow_none=True)
city = fields.Str(validate=validate.Length(max=255), allow_none=True)
state = fields.Str(validate=validate.Length(max=255), allow_none=True)
country = fields.Str(validate=validate.Length(max=255), allow_none=True)
longitude = NumericStr(validate=validate.Range(-180, 180), allow_none=True)
latitude = NumericStr(validate=validate.Range(-90, 90), allow_none=True)
class LocationPatchRequestLoadSchema(LocationPatchRequestSchema):
class Meta:
model = Location
load_instance = True

View File

@ -27,7 +27,13 @@ from project.services.reference import (
get_reference_incoming_query,
get_reference_outgoing_query,
)
from project.api.place.schemas import PlaceListRequestSchema, PlaceListResponseSchema
from project.api.place.schemas import (
PlaceListRequestSchema,
PlaceListResponseSchema,
PlaceIdSchema,
PlacePostRequestSchema,
PlacePostRequestLoadSchema,
)
from project.services.event import get_event_dates_query, get_events_query
from project.services.event_search import EventSearchParams
from project.services.admin_unit import (
@ -35,6 +41,14 @@ from project.services.admin_unit import (
get_organizer_query,
get_place_query,
)
from project.oauth2 import require_oauth
from authlib.integrations.flask_oauth2 import current_token
from project import db
from project.access import (
access_or_401,
get_admin_unit_for_manage_or_404,
login_api_user_or_401,
)
class OrganizationResource(BaseResource):
@ -113,6 +127,26 @@ class OrganizationPlaceListResource(BaseResource):
pagination = get_place_query(admin_unit.id, name).paginate()
return pagination
@doc(
summary="Add new place",
tags=["Organizations", "Places"],
security=[{"oauth2": ["place:write"]}],
)
@use_kwargs(PlacePostRequestSchema, location="json")
@marshal_with(PlaceIdSchema, 201)
@require_oauth("place:write")
def post(self, id, **kwargs):
login_api_user_or_401(current_token.user)
admin_unit = get_admin_unit_for_manage_or_404(id)
access_or_401(admin_unit, "place:create")
place = PlacePostRequestLoadSchema().load(kwargs, session=db.session)
place.admin_unit_id = admin_unit.id
db.session.add(place)
db.session.commit()
return place, 201
class OrganizationIncomingEventReferenceListResource(BaseResource):
@doc(

View File

@ -1,8 +1,8 @@
from marshmallow import fields
from project import marshmallow
from project.api import marshmallow
from project.models import AdminUnit
from project.api.location.schemas import LocationRefSchema
from project.api.image.schemas import ImageRefSchema
from project.api.location.schemas import LocationSchema
from project.api.image.schemas import ImageSchema
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
@ -25,8 +25,8 @@ class OrganizationBaseSchema(OrganizationIdSchema):
class OrganizationSchema(OrganizationBaseSchema):
location = fields.Nested(LocationRefSchema)
logo = fields.Nested(ImageRefSchema)
location = fields.Nested(LocationSchema)
logo = fields.Nested(ImageSchema)
class OrganizationDumpSchema(OrganizationBaseSchema):

View File

@ -1,8 +1,8 @@
from marshmallow import fields
from project import marshmallow
from project.api import marshmallow
from project.models import EventOrganizer
from project.api.location.schemas import LocationRefSchema
from project.api.image.schemas import ImageRefSchema
from project.api.location.schemas import LocationSchema
from project.api.image.schemas import ImageSchema
from project.api.organization.schemas import OrganizationRefSchema
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
@ -25,8 +25,8 @@ class OrganizerBaseSchema(OrganizerIdSchema):
class OrganizerSchema(OrganizerBaseSchema):
location = fields.Nested(LocationRefSchema)
logo = fields.Nested(ImageRefSchema)
location = fields.Nested(LocationSchema)
logo = fields.Nested(ImageSchema)
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")

View File

@ -1,8 +1,19 @@
from project.api import add_api_resource
from flask_apispec import marshal_with, doc
from flask import make_response
from flask_apispec import marshal_with, doc, use_kwargs
from project.api.resources import BaseResource
from project.api.place.schemas import PlaceSchema
from project.api.place.schemas import (
PlaceSchema,
PlacePostRequestSchema,
PlacePostRequestLoadSchema,
PlacePatchRequestSchema,
PlacePatchRequestLoadSchema,
)
from project.models import EventPlace
from project.oauth2 import require_oauth
from authlib.integrations.flask_oauth2 import current_token
from project import db
from project.access import access_or_401, login_api_user_or_401
class PlaceResource(BaseResource):
@ -11,5 +22,54 @@ class PlaceResource(BaseResource):
def get(self, id):
return EventPlace.query.get_or_404(id)
@doc(
summary="Update place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
)
@use_kwargs(PlacePostRequestSchema, location="json")
@marshal_with(None, 204)
@require_oauth("place:write")
def put(self, id, **kwargs):
login_api_user_or_401(current_token.user)
place = EventPlace.query.get_or_404(id)
access_or_401(place.adminunit, "place:update")
place = PlacePostRequestLoadSchema().load(
kwargs, session=db.session, instance=place
)
db.session.commit()
return make_response("", 204)
@doc(summary="Patch place", tags=["Places"], security=[{"oauth2": ["place:write"]}])
@use_kwargs(PlacePatchRequestSchema, location="json")
@marshal_with(None, 204)
@require_oauth("place:write")
def patch(self, id, **kwargs):
login_api_user_or_401(current_token.user)
place = EventPlace.query.get_or_404(id)
access_or_401(place.adminunit, "place:update")
place = PlacePatchRequestLoadSchema().load(
kwargs, session=db.session, instance=place
)
db.session.commit()
return make_response("", 204)
@doc(
summary="Delete place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
)
@marshal_with(None, 204)
@require_oauth("place:write")
def delete(self, id):
login_api_user_or_401(current_token.user)
place = EventPlace.query.get_or_404(id)
access_or_401(place.adminunit, "place:delete")
db.session.delete(place)
db.session.commit()
return make_response("", 204)
add_api_resource(PlaceResource, "/places/<int:id>", "api_v1_place")

View File

@ -1,8 +1,15 @@
from marshmallow import fields
from project import marshmallow
from marshmallow import fields, validate
from project.api import marshmallow
from project.models import EventPlace
from project.api.image.schemas import ImageRefSchema
from project.api.location.schemas import LocationRefSchema, LocationSearchItemSchema
from project.api.image.schemas import ImageSchema
from project.api.location.schemas import (
LocationSchema,
LocationSearchItemSchema,
LocationPostRequestSchema,
LocationPostRequestLoadSchema,
LocationPatchRequestSchema,
LocationPatchRequestLoadSchema,
)
from project.api.organization.schemas import OrganizationRefSchema
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
@ -23,8 +30,8 @@ class PlaceBaseSchema(PlaceIdSchema):
class PlaceSchema(PlaceBaseSchema):
location = fields.Nested(LocationRefSchema)
photo = fields.Nested(ImageRefSchema)
location = fields.Nested(LocationSchema)
photo = fields.Nested(ImageSchema)
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
@ -55,3 +62,41 @@ class PlaceListResponseSchema(PaginationResponseSchema):
items = fields.List(
fields.Nested(PlaceRefSchema), metadata={"description": "Places"}
)
class PlacePostRequestSchema(marshmallow.SQLAlchemySchema):
class Meta:
model = EventPlace
name = fields.Str(required=True, validate=validate.Length(min=3, max=255))
url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None)
description = fields.Str(missing=None)
location = fields.Nested(LocationPostRequestSchema, missing=None)
class PlacePostRequestLoadSchema(PlacePostRequestSchema):
class Meta:
model = EventPlace
load_instance = True
location = fields.Nested(LocationPostRequestLoadSchema, missing=None)
class PlacePatchRequestSchema(marshmallow.SQLAlchemySchema):
class Meta:
model = EventPlace
name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True)
url = fields.Str(
validate=[validate.URL(), validate.Length(max=255)], allow_none=True
)
description = fields.Str(allow_none=True)
location = fields.Nested(LocationPatchRequestSchema, allow_none=True)
class PlacePatchRequestLoadSchema(PlacePatchRequestSchema):
class Meta:
model = EventPlace
load_instance = True
location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True)

View File

@ -1,6 +1,8 @@
from flask import request
from flask_apispec import marshal_with
from flask_apispec.views import MethodResource
from functools import wraps
from project.api.schemas import ErrorResponseSchema, UnprocessableEntityResponseSchema
def etag_cache(func):
@ -13,5 +15,7 @@ def etag_cache(func):
return wrapper
@marshal_with(ErrorResponseSchema, 400, "Bad Request")
@marshal_with(UnprocessableEntityResponseSchema, 422, "Unprocessable Entity")
class BaseResource(MethodResource):
decorators = [etag_cache]

View File

@ -1,7 +1,21 @@
from project import marshmallow
from project.api import marshmallow
from marshmallow import fields, validate
class ErrorResponseSchema(marshmallow.Schema):
name = fields.Str()
message = fields.Str()
class UnprocessableEntityErrorSchema(marshmallow.Schema):
field = fields.Str()
message = fields.Str()
class UnprocessableEntityResponseSchema(ErrorResponseSchema):
errors = fields.List(fields.Nested(UnprocessableEntityErrorSchema))
class PaginationRequestSchema(marshmallow.Schema):
page = fields.Integer(
required=False,

View File

@ -0,0 +1,92 @@
from flask_wtf import FlaskForm
from flask_babelex import lazy_gettext
from wtforms import StringField, TextAreaField, SubmitField, SelectField
from wtforms.validators import Optional, DataRequired
from project.forms.widgets import MultiCheckboxField
from project.api import scopes
from project.utils import split_by_crlf
import os
class BaseOAuth2ClientForm(FlaskForm):
client_name = StringField(lazy_gettext("Client name"), validators=[DataRequired()])
redirect_uris = TextAreaField(
lazy_gettext("Redirect URIs"), validators=[Optional()]
)
grant_types = MultiCheckboxField(
lazy_gettext("Grant types"),
validators=[DataRequired()],
choices=[
("authorization_code", lazy_gettext("Authorization Code")),
("refresh_token", lazy_gettext("Refresh Token")),
],
default=["authorization_code", "refresh_token"],
)
response_types = MultiCheckboxField(
lazy_gettext("Response types"),
validators=[DataRequired()],
choices=[
("code", "code"),
],
default=["code"],
)
scope = MultiCheckboxField(
lazy_gettext("Scopes"),
validators=[DataRequired()],
choices=[(k, k) for k, v in scopes.items()],
)
token_endpoint_auth_method = SelectField(
lazy_gettext("Token endpoint auth method"),
validators=[DataRequired()],
choices=[
("client_secret_post", lazy_gettext("Client secret post")),
("client_secret_basic", lazy_gettext("Client secret basic")),
],
)
submit = SubmitField(lazy_gettext("Save"))
def populate_obj(self, obj):
meta_keys = [
"client_name",
"client_uri",
"grant_types",
"redirect_uris",
"response_types",
"scope",
"token_endpoint_auth_method",
]
metadata = dict()
for name, field in self._fields.items():
if name in meta_keys:
if name == "redirect_uris":
metadata[name] = split_by_crlf(field.data)
elif name == "scope":
metadata[name] = " ".join(field.data)
else:
metadata[name] = field.data
else:
field.populate_obj(obj, name)
obj.set_client_metadata(metadata)
def process(self, formdata=None, obj=None, data=None, **kwargs):
super().process(formdata, obj, data, **kwargs)
if not obj:
return
self.redirect_uris.data = os.linesep.join(obj.redirect_uris)
self.scope.data = obj.scope.split(" ")
class CreateOAuth2ClientForm(BaseOAuth2ClientForm):
pass
class UpdateOAuth2ClientForm(BaseOAuth2ClientForm):
pass
class DeleteOAuth2ClientForm(FlaskForm):
submit = SubmitField(lazy_gettext("Delete OAuth2 client"))
name = StringField(lazy_gettext("Name"), validators=[DataRequired()])

View File

@ -0,0 +1,7 @@
from flask_wtf import FlaskForm
from flask_babelex import lazy_gettext
from wtforms import SubmitField
class RevokeOAuth2TokenForm(FlaskForm):
submit = SubmitField(lazy_gettext("Revoke OAuth2 token"))

View File

@ -1,7 +1,9 @@
from flask_security.forms import RegisterForm, EqualTo, get_form_field_label
from wtforms import BooleanField, PasswordField
from wtforms import BooleanField, PasswordField, SubmitField
from wtforms.validators import DataRequired
from project.forms.common import get_accept_tos_markup
from flask_wtf import FlaskForm
from flask_babelex import lazy_gettext
class ExtendedRegisterForm(RegisterForm):
@ -20,3 +22,8 @@ class ExtendedRegisterForm(RegisterForm):
def __init__(self, *args, **kwargs):
super(ExtendedRegisterForm, self).__init__(*args, **kwargs)
self._fields["accept_tos"].label.text = get_accept_tos_markup()
class AuthorizeForm(FlaskForm):
allow = SubmitField(lazy_gettext("Allow"))
deny = SubmitField(lazy_gettext("Deny"))

Binary file not shown.

Binary file not shown.

View File

@ -39,3 +39,8 @@ def print_dynamic_texts():
gettext("EventReviewStatus.inbox")
gettext("EventReviewStatus.verified")
gettext("EventReviewStatus.rejected")
gettext("read")
gettext("write")
gettext("Event")
gettext("Organizer")
gettext("Place")

View File

@ -1,8 +1,27 @@
from project import app, db
from project.api import api_docs, scopes
from project.services.user import upsert_user_role
from project.services.admin_unit import upsert_admin_unit_member_role
from project.services.event import upsert_event_category
from project.models import Location
from flask import url_for
from apispec.exceptions import DuplicateComponentNameError
@app.before_first_request
def add_oauth2_scheme():
oauth2_scheme = {
"type": "oauth2",
"authorizationUrl": url_for("authorize", _external=True),
"tokenUrl": url_for("issue_token", _external=True),
"flow": "accessCode",
"scopes": scopes,
}
try:
api_docs.spec.components.security_scheme("oauth2", oauth2_scheme)
except DuplicateComponentNameError: # pragma: no cover
pass
@app.before_first_request
@ -37,12 +56,23 @@ def create_initial_data():
"reference_request:delete",
"reference_request:verify",
]
early_adopter_permissions = [
"oauth2_client:create",
"oauth2_client:read",
"oauth2_client:update",
"oauth2_client:delete",
"oauth2_token:create",
"oauth2_token:read",
"oauth2_token:update",
"oauth2_token:delete",
]
upsert_admin_unit_member_role("admin", "Administrator", admin_permissions)
upsert_admin_unit_member_role("event_verifier", "Event expert", event_permissions)
upsert_user_role("admin", "Administrator", admin_permissions)
upsert_user_role("event_verifier", "Event expert", event_permissions)
upsert_user_role("early_adopter", "Early Adopter", early_adopter_permissions)
Location.update_coordinates()

View File

@ -1,5 +1,9 @@
from project import app
from project.utils import get_event_category_name, get_localized_enum_name
from project.utils import (
get_event_category_name,
get_localized_enum_name,
get_localized_scope,
)
from urllib.parse import quote_plus
import os
@ -8,10 +12,16 @@ def env_override(value, key):
return os.getenv(key, value)
def is_list(value):
return isinstance(value, list)
app.jinja_env.filters["event_category_name"] = lambda u: get_event_category_name(u)
app.jinja_env.filters["loc_enum"] = lambda u: get_localized_enum_name(u)
app.jinja_env.filters["loc_scope"] = lambda s: get_localized_scope(s)
app.jinja_env.filters["env_override"] = env_override
app.jinja_env.filters["quote_plus"] = lambda u: quote_plus(u)
app.jinja_env.filters["is_list"] = is_list
@app.context_processor

View File

@ -1,7 +1,7 @@
from project import db
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship, backref, deferred
from sqlalchemy.orm import relationship, backref, deferred, object_session
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.event import listens_for
from sqlalchemy import (
@ -24,6 +24,12 @@ import datetime
from project.dbtypes import IntegerEnum
from geoalchemy2 import Geometry
from sqlalchemy import and_
from authlib.integrations.sqla_oauth2 import (
OAuth2ClientMixin,
OAuth2AuthorizationCodeMixin,
OAuth2TokenMixin,
)
import time
# Base
@ -156,6 +162,12 @@ class User(db.Model, UserMixin):
"Role", secondary="roles_users", backref=backref("users", lazy="dynamic")
)
def get_user_id(self):
return self.id
# OAuth Consumer: Wenn wir OAuth consumen und sich ein Nutzer per Google oder Facebook anmelden möchte
class OAuth(OAuthConsumerMixin, db.Model):
provider_user_id = Column(String(256), unique=True, nullable=False)
@ -163,6 +175,51 @@ class OAuth(OAuthConsumerMixin, db.Model):
user = db.relationship("User")
# OAuth Server: Wir bieten an, dass sich ein Nutzer per OAuth2 auf unserer Seite anmeldet
class OAuth2Client(db.Model, OAuth2ClientMixin):
__tablename__ = "oauth2_client"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
user = db.relationship("User")
def check_redirect_uri(self, redirect_uri):
return True
class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
__tablename__ = "oauth2_code"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
user = db.relationship("User")
class OAuth2Token(db.Model, OAuth2TokenMixin):
__tablename__ = "oauth2_token"
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
user = db.relationship("User")
@property
def client(self):
return (
object_session(self)
.query(OAuth2Client)
.filter(OAuth2Client.client_id == self.client_id)
.first()
)
def is_refresh_token_active(self):
if self.revoked:
return False
expires_at = self.issued_at + self.expires_in * 2
return expires_at >= time.time()
# Admin Unit
@ -298,6 +355,7 @@ def update_location_coordinate(mapper, connect, self):
# Events
class EventPlace(db.Model, TrackableMixin):
__tablename__ = "eventplace"
__table_args__ = (UniqueConstraint("name", "admin_unit_id"),)
id = Column(Integer(), primary_key=True)
name = Column(Unicode(255), nullable=False)
location_id = db.Column(db.Integer, db.ForeignKey("location.id"))

114
project/oauth2.py Normal file
View File

@ -0,0 +1,114 @@
from authlib.integrations.flask_oauth2 import (
AuthorizationServer,
ResourceProtector,
)
from authlib.integrations.sqla_oauth2 import (
create_query_client_func,
create_save_token_func,
create_bearer_token_validator,
create_query_token_func,
)
from authlib.oauth2.rfc6749 import grants
from authlib.oauth2.rfc7636 import CodeChallenge
from project import db
from project.models import User, OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = [
"client_secret_basic",
"client_secret_post",
"none",
]
def save_authorization_code(self, code, request):
code_challenge = request.data.get("code_challenge")
code_challenge_method = request.data.get("code_challenge_method")
auth_code = OAuth2AuthorizationCode(
code=code,
client_id=request.client.client_id,
redirect_uri=request.redirect_uri,
scope=request.scope,
user_id=request.user.id,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
)
db.session.add(auth_code)
db.session.commit()
return auth_code
def query_authorization_code(self, code, client):
auth_code = OAuth2AuthorizationCode.query.filter_by(
code=code, client_id=client.client_id
).first()
if auth_code and not auth_code.is_expired():
return auth_code
def delete_authorization_code(self, authorization_code):
db.session.delete(authorization_code)
db.session.commit()
def authenticate_user(self, authorization_code):
return User.query.get(authorization_code.user_id)
class RefreshTokenGrant(grants.RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
def authenticate_refresh_token(self, refresh_token):
token = OAuth2Token.query.filter_by(refresh_token=refresh_token).first()
if token and token.is_refresh_token_active():
return token
def authenticate_user(self, credential):
return User.query.get(credential.user_id)
def revoke_old_credential(self, credential):
credential.revoked = True
db.session.add(credential)
db.session.commit()
query_client = create_query_client_func(db.session, OAuth2Client)
save_token = create_save_token_func(db.session, OAuth2Token)
authorization = AuthorizationServer(
query_client=query_client,
save_token=save_token,
)
require_oauth = ResourceProtector()
def create_revocation_endpoint(session, token_model):
from authlib.oauth2.rfc7009 import RevocationEndpoint
query_token = create_query_token_func(session, token_model)
class _RevocationEndpoint(RevocationEndpoint):
CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
def query_token(self, token, token_type_hint, client):
return query_token(token, token_type_hint, client)
def revoke_token(self, token):
token.revoked = True
session.add(token)
session.commit()
return _RevocationEndpoint
def config_oauth(app):
app.config["OAUTH2_REFRESH_TOKEN_GENERATOR"] = True
authorization.init_app(app)
# support grants
authorization.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])
authorization.register_grant(RefreshTokenGrant)
# support revocation
revocation_cls = create_revocation_endpoint(db.session, OAuth2Token)
authorization.register_endpoint(revocation_cls)
# protect resource
bearer_cls = create_bearer_token_validator(db.session, OAuth2Token)
require_oauth.register_token_validator(bearer_cls())

View File

@ -0,0 +1,12 @@
from project.models import OAuth2Client
import time
from werkzeug.security import gen_salt
def complete_oauth2_client(oauth2_client: OAuth2Client) -> None:
if not oauth2_client.id:
oauth2_client.client_id = gen_salt(24)
oauth2_client.client_id_issued_at = int(time.time())
if oauth2_client.client_secret is None:
oauth2_client.client_secret = gen_salt(48)

View File

@ -17,7 +17,7 @@ def add_roles_to_user(email, roles):
def add_admin_roles_to_user(email):
add_roles_to_user(email, ["admin", "event_verifier"])
add_roles_to_user(email, ["admin", "event_verifier", "early_adopter"])
def remove_roles_from_user(email, roles):

View File

@ -160,6 +160,21 @@
{% endif %}
{% endmacro %}
{% macro render_kv_begin() %}
<dl class="row">
{% endmacro %}
{% macro render_kv_end() %}
</dl>
{% endmacro %}
{% macro render_kv_prop(prop, label_key = None) %}
{% if prop %}
<dt class="col-sm-3">{{ _(label_key) }}</dt>
<dd class="col-sm-9">{% if prop|is_list %}{{ prop|join(', ') }}{% else %}{{ prop }}{% endif %}</dd>
{% endif %}
{% endmacro %}
{% macro render_string_prop(prop, icon = None, label_key = None) %}
{% if prop %}
<div>

View File

@ -0,0 +1,26 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_field_with_errors, render_field %}
{% block title %}
{{ _('Create OAuth2 client') }}
{% endblock %}
{% block content %}
<h1>{{ _('Create OAuth2 client') }}</h1>
<form action="" method="POST">
{{ form.hidden_tag() }}
<div class="card mb-4">
<div class="card-body">
{{ render_field_with_errors(form.client_name) }}
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
{{ render_field_with_errors(form.redirect_uris) }}
{{ render_field(form.submit) }}
</form>
{% endblock %}

View File

@ -0,0 +1,24 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_field_with_errors, render_field %}
{% block content %}
<h1>{{ _('Delete OAuth2 client') }} &quot;{{ oauth2_client.client_name }}&quot;</h1>
<form action="" method="POST">
{{ form.hidden_tag() }}
<div class="card mb-4">
<div class="card-header">
{{ _('OAuth2 client') }}
</div>
<div class="card-body">
{{ render_field_with_errors(form.name) }}
</div>
</div>
{{ render_field(form.submit) }}
</form>
{% endblock %}

View File

@ -0,0 +1,44 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_pagination %}
{% block title %}
{{ _('OAuth2 clients') }}
{% endblock %}
{% block content %}
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 clients') }}</li>
</ol>
</nav>
{% if current_user.has_permission('oauth2_client:create') %}
<div class="my-4">
<a class="btn btn-outline-secondary my-1" href="{{ url_for('oauth2_client_create') }}" role="button"><i class="fa fa-plus"></i> {{ _('Create OAuth2 client') }}</a>
</div>
{% endif %}
<div class="table-responsive">
<table class="table table-sm table-bordered table-hover table-striped">
<thead>
<tr>
<th>{{ _('Name') }}</th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
{% for oauth2_client in oauth2_clients %}
<tr>
<td><a href="{{ url_for('oauth2_client', id=oauth2_client.id) }}">{{ oauth2_client.client_name }}</a></td>
<td><a href="{{ url_for('oauth2_client_update', id=oauth2_client.id) }}">{{ _('Edit') }}</a></td>
<td><a href="{{ url_for('oauth2_client_delete', id=oauth2_client.id) }}">{{ _('Delete') }}</a></td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<div class="my-4">{{ render_pagination(pagination) }}</div>
{% endblock %}

View File

@ -0,0 +1,34 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_kv_begin, render_kv_end, render_kv_prop %}
{% block title %}
{{ oauth2_client.client_name }}
{% endblock %}
{% block header %}
<script type="application/ld+json">
{{ structured_data | safe }}
</script>
{% endblock %}
{% block content %}
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
<li class="breadcrumb-item"><a href="{{ url_for('oauth2_clients') }}">{{ _('OAuth2 clients') }}</a></li>
<li class="breadcrumb-item active" aria-current="page">{{ oauth2_client.client_name }}</li>
</ol>
</nav>
<div class="w-normal">
{{ render_kv_begin() }}
{{ render_kv_prop(oauth2_client.client_id, 'Client ID') }}
{{ render_kv_prop(oauth2_client.client_secret, 'Client secret') }}
{{ render_kv_prop(oauth2_client.client_uri, 'Client URI') }}
{{ render_kv_prop(oauth2_client.grant_types, 'Grant types') }}
{{ render_kv_prop(oauth2_client.redirect_uris, 'Redirect URIs') }}
{{ render_kv_prop(oauth2_client.response_types, 'Response types') }}
{{ render_kv_prop(oauth2_client.scope, 'Scope') }}
{{ render_kv_prop(oauth2_client.token_endpoint_auth_method, 'Token endpoint auth method') }}
{{ render_kv_end() }}
</div>
{% endblock %}

View File

@ -0,0 +1,27 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_field_with_errors, render_field %}
{% block title %}
{{ _('Update OAuth2 client') }}
{% endblock %}
{% block content %}
<h1>{{ _('Update OAuth2 client') }}</h1>
<form action="" method="POST">
{{ form.hidden_tag() }}
<div class="card mb-4">
<div class="card-body">
{{ render_field_with_errors(form.client_name) }}
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
{{ render_field_with_errors(form.redirect_uris) }}
</div>
</div>
{{ render_field(form.submit) }}
</form>
{% endblock %}

View File

@ -0,0 +1,40 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_pagination %}
{% block title %}
{{ _('OAuth2 tokens') }}
{% endblock %}
{% block content %}
<nav aria-label="breadcrumb">
<ol class="breadcrumb">
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 tokens') }}</li>
</ol>
</nav>
<div class="table-responsive">
<table class="table table-sm table-bordered table-hover table-striped">
<thead>
<tr>
<th>{{ _('Client') }}</th>
<th>{{ _('Scopes') }}</th>
<th>{{ _('Status') }}</th>
<th></th>
</tr>
</thead>
<tbody>
{% for oauth2_token in oauth2_tokens %}
<tr>
<td>{{ oauth2_token.client.client_name }}</td>
<td>{{ oauth2_token.client.scope }}</td>
<td>{% if oauth2_token.revoked %}{{ _('Revoked') }}{% else %}{{ _('Active') }}{% endif %}</td>
<td>{% if not oauth2_token.revoked %}<a href="{{ url_for('oauth2_token_revoke', id=oauth2_token.id) }}">{{ _('Revoke') }}</a>{% endif %}</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<div class="my-4">{{ render_pagination(pagination) }}</div>
{% endblock %}

View File

@ -0,0 +1,16 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_field_with_errors, render_field %}
{% block content %}
<h1>{{ _('Revoke OAuth2 token') }}</h1>
<h2>{{ oauth2_token.client.client_name }}</h2>
<form action="" method="POST">
{{ form.hidden_tag() }}
{{ render_field(form.submit) }}
</form>
{% endblock %}

View File

@ -8,7 +8,23 @@
<h1>{{ current_user.email }}</h1>
<h2>{{ _('Profile') }}</h2>
<p><a href="{{ url_for_security('change_password') }}">{{ _fsdomain('Change password') }}</a></p>
<div class="list-group">
<a href="{{ url_for('security.change_password') }}" class="list-group-item">
{{ _fsdomain('Change password') }}
<i class="fa fa-caret-right"></i>
</a>
{% if current_user.has_permission('oauth2_client:read') %}
<a href="{{ url_for('oauth2_clients') }}" class="list-group-item">
{{ _('OAuth2 clients') }}
<i class="fa fa-caret-right"></i>
</a>
{% endif %}
<a href="{{ url_for('oauth2_tokens') }}" class="list-group-item">
{{ _('OAuth2 tokens') }}
<i class="fa fa-caret-right"></i>
</a>
</div>
{% if invitations %}
<h2>{{ _('Invitations') }}</h2>

View File

@ -0,0 +1,40 @@
{% extends "layout.html" %}
{% from "_macros.html" import render_field %}
{% block content %}
<div class="w-normal d-flex flex-column">
<div class="card mx-auto">
<div class="card-body">
<h5>{{ _('"%(client_name)s" wants to access your account', client_name=grant.client.client_name) }}</h5>
<p class="text-center"><strong>{{ user.email }}</strong></p>
<p class="mb-1">{{ _('This will allow "%(client_name)s" to:', client_name=grant.client.client_name) }}</p>
<ul>
{% for key, value in scopes.items() %}
<li>{{ value }}</li>
{% endfor %}
</ul>
<form action="" method="POST">
{{ form.hidden_tag() }}
<div class="d-flex flex-column mt-5">
<div class="mx-auto">
{{ render_field(form.allow, class="btn btn-success mx-auto") }}
</div>
<div class="mx-auto">
{{ render_field(form.deny, class="btn btn-light mx-auto") }}
</div>
</div>
</form>
</div>
</div>
</div>
{% endblock %}

View File

@ -4,7 +4,8 @@
{% block content %}
<h1>{{ _fsdomain('Login') }}</h1>
<form action="{{ url_for_security('login', next='manage') }}" method="POST" name="login_user_form">
{% set next = request.args['next'] if 'next' in request.args and 'authorize' in request.args['next'] else 'manage' %}
<form action="{{ url_for_security('login', next=next) }}" method="POST" name="login_user_form">
{{ login_user_form.hidden_tag() }}
{{ render_field_with_errors(login_user_form.email) }}
{{ render_field_with_errors(login_user_form.password) }}

View File

@ -11,9 +11,20 @@ def get_localized_enum_name(enum):
return lazy_gettext(enum.__class__.__name__ + "." + enum.name)
def get_localized_scope(scope: str) -> str:
type_name, action = scope.split(":")
loc_lazy_gettext = lazy_gettext(type_name.capitalize())
loc_action = lazy_gettext(action)
return f"{loc_lazy_gettext} ({loc_action})"
def make_dir(path):
try:
original_umask = os.umask(0)
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
finally:
os.umask(original_umask)
def split_by_crlf(s):
return [v for v in s.splitlines() if v]

53
project/views/oauth.py Normal file
View File

@ -0,0 +1,53 @@
from authlib.oauth2 import OAuth2Error
from project import app
from project.api import scopes
from flask_security import current_user
from flask import redirect, request, url_for, render_template
from project.forms.security import AuthorizeForm
from project.oauth2 import authorization
@app.route("/oauth/authorize", methods=["GET", "POST"])
def authorize():
user = current_user
if not user or not user.is_authenticated:
return redirect(url_for("security.login", next=request.url))
form = AuthorizeForm()
if form.validate_on_submit():
grant_user = user if form.allow.data else None
return authorization.create_authorization_response(grant_user=grant_user)
else:
try:
grant = authorization.validate_consent_request(end_user=user)
except OAuth2Error as error:
return error.error
grant_scopes = grant.request.scope.split(" ")
filtered_scopes = {k: scopes[k] for k in grant_scopes}
return render_template(
"security/authorize.html",
form=form,
scopes=filtered_scopes,
user=user,
grant=grant,
)
@app.route("/oauth/token", methods=["POST"])
def issue_token():
return authorization.create_token_response()
@app.route("/oauth/revoke", methods=["POST"])
def revoke_token():
return authorization.create_endpoint_response("revocation")
@app.route("/oauth2-redirect.html")
def swagger_oauth2_redirect():
return redirect(
url_for("flask-apispec.static", filename="oauth2-redirect.html", **request.args)
)

View File

@ -0,0 +1,125 @@
from project import app, db
from flask import render_template, redirect, flash, url_for
from flask_babelex import gettext
from flask_security import permissions_required, current_user
from project.models import OAuth2Client
from project.views.utils import (
get_pagination_urls,
handleSqlError,
flash_errors,
non_match_for_deletion,
)
from project.forms.oauth2_client import (
CreateOAuth2ClientForm,
UpdateOAuth2ClientForm,
DeleteOAuth2ClientForm,
)
from project.services.oauth2_client import complete_oauth2_client
from sqlalchemy.exc import SQLAlchemyError
from project.access import owner_access_or_401
@app.route("/oauth2_client/create", methods=("GET", "POST"))
@permissions_required("oauth2_client:create")
def oauth2_client_create():
form = CreateOAuth2ClientForm()
if form.validate_on_submit():
oauth2_client = OAuth2Client()
form.populate_obj(oauth2_client)
oauth2_client.user_id = current_user.id
complete_oauth2_client(oauth2_client)
try:
db.session.add(oauth2_client)
db.session.commit()
flash(gettext("OAuth2 client successfully created"), "success")
return redirect(url_for("oauth2_client", id=oauth2_client.id))
except SQLAlchemyError as e:
db.session.rollback()
flash(handleSqlError(e), "danger")
else:
flash_errors(form)
return render_template("oauth2_client/create.html", form=form)
@app.route("/oauth2_client/<int:id>/update", methods=("GET", "POST"))
@permissions_required("oauth2_client:update")
def oauth2_client_update(id):
oauth2_client = OAuth2Client.query.get_or_404(id)
owner_access_or_401(oauth2_client.user_id)
form = UpdateOAuth2ClientForm(obj=oauth2_client)
if form.validate_on_submit():
form.populate_obj(oauth2_client)
complete_oauth2_client(oauth2_client)
try:
db.session.commit()
flash(gettext("OAuth2 client successfully updated"), "success")
return redirect(url_for("oauth2_client", id=oauth2_client.id))
except SQLAlchemyError as e:
db.session.rollback()
flash(handleSqlError(e), "danger")
else:
flash_errors(form)
return render_template(
"oauth2_client/update.html", form=form, oauth2_client=oauth2_client
)
@app.route("/oauth2_client/<int:id>/delete", methods=("GET", "POST"))
@permissions_required("oauth2_client:delete")
def oauth2_client_delete(id):
oauth2_client = OAuth2Client.query.get_or_404(id)
owner_access_or_401(oauth2_client.user_id)
form = DeleteOAuth2ClientForm()
if form.validate_on_submit():
if non_match_for_deletion(form.name.data, oauth2_client.client_name):
flash(gettext("Entered name does not match OAuth2 client name"), "danger")
else:
try:
db.session.delete(oauth2_client)
db.session.commit()
flash(gettext("OAuth2 client successfully deleted"), "success")
return redirect(url_for("oauth2_clients"))
except SQLAlchemyError as e:
db.session.rollback()
flash(handleSqlError(e), "danger")
else:
flash_errors(form)
return render_template(
"oauth2_client/delete.html", form=form, oauth2_client=oauth2_client
)
@app.route("/oauth2_client/<int:id>")
@permissions_required("oauth2_client:read")
def oauth2_client(id):
oauth2_client = OAuth2Client.query.get_or_404(id)
owner_access_or_401(oauth2_client.user_id)
return render_template(
"oauth2_client/read.html",
oauth2_client=oauth2_client,
)
@app.route("/oauth2_clients")
@permissions_required("oauth2_client:read")
def oauth2_clients():
oauth2_clients = OAuth2Client.query.filter(
OAuth2Client.user_id == current_user.id
).paginate()
return render_template(
"oauth2_client/list.html",
oauth2_clients=oauth2_clients.items,
pagination=get_pagination_urls(oauth2_clients),
)

View File

@ -0,0 +1,53 @@
from project import app, db
from flask import render_template, redirect, flash, url_for
from flask_babelex import gettext
from flask_security import current_user
from project.models import OAuth2Token
from project.views.utils import (
get_pagination_urls,
handleSqlError,
flash_errors,
)
from project.forms.oauth2_token import RevokeOAuth2TokenForm
from sqlalchemy.exc import SQLAlchemyError
from project.access import owner_access_or_401
@app.route("/oauth2_token/<int:id>/revoke", methods=("GET", "POST"))
def oauth2_token_revoke(id):
oauth2_token = OAuth2Token.query.get_or_404(id)
owner_access_or_401(oauth2_token.user_id)
if oauth2_token.revoked:
return redirect(url_for("oauth2_tokens"))
form = RevokeOAuth2TokenForm()
if form.validate_on_submit():
try:
oauth2_token.revoked = True
db.session.commit()
flash(gettext("OAuth2 token successfully revoked"), "success")
return redirect(url_for("oauth2_tokens"))
except SQLAlchemyError as e:
db.session.rollback()
flash(handleSqlError(e), "danger")
else:
flash_errors(form)
return render_template(
"oauth2_token/revoke.html", form=form, oauth2_token=oauth2_token
)
@app.route("/oauth2_tokens")
def oauth2_tokens():
oauth2_tokens = OAuth2Token.query.filter(
OAuth2Token.user_id == current_user.id
).paginate()
return render_template(
"oauth2_token/list.html",
oauth2_tokens=oauth2_tokens.items,
pagination=get_pagination_urls(oauth2_tokens),
)

View File

@ -5,6 +5,7 @@ apispec-webframeworks==0.5.2
appdirs==1.4.4
argh==0.26.2
attrs==20.3.0
Authlib==0.15.3
Babel==2.9.0
bcrypt==3.2.0
beautifulsoup4==4.9.3
@ -18,6 +19,7 @@ click==7.1.2
colour==0.1.5
coverage==5.3
coveralls==2.2.0
cryptography==3.3.1
distlib==0.3.1
dnspython==2.0.0
docopt==0.6.2

View File

@ -0,0 +1,55 @@
import pytest
from project.api import RestApi
class Psycog2Error(object):
def __init__(self, pgcode):
self.pgcode = pgcode
def test_handle_error_unique(app):
from sqlalchemy.exc import IntegrityError
from psycopg2.errorcodes import UNIQUE_VIOLATION
orig = Psycog2Error(UNIQUE_VIOLATION)
error = IntegrityError("Select", list(), orig)
api = RestApi(app)
(data, code) = api.handle_error(error)
assert code == 400
assert data["name"] == "Unique Violation"
def test_handle_error_httpException(app):
from werkzeug.exceptions import InternalServerError
error = InternalServerError()
api = RestApi(app)
(data, code) = api.handle_error(error)
assert code == 500
def test_handle_error_unprocessableEntity(app):
from werkzeug.exceptions import UnprocessableEntity
from marshmallow import ValidationError
args = {"name": ["Required"]}
validation_error = ValidationError(args)
error = UnprocessableEntity()
error.exc = validation_error
api = RestApi(app)
(data, code) = api.handle_error(error)
assert code == 422
assert data["errors"][0]["field"] == "name"
assert data["errors"][0]["message"] == "Required"
def test_handle_error_unspecificRaises(app):
error = Exception()
api = RestApi(app)
with pytest.raises(Exception):
api.handle_error(error)

47
tests/api/test_fields.py Normal file
View File

@ -0,0 +1,47 @@
def test_numeric_str_serialize(client, seeder, utils):
from project.api.location.schemas import LocationSchema
from project.models import Location
location = Location()
location.street = "Markt 7"
location.postalCode = "38640"
location.city = "Goslar"
location.latitude = 51.9077888
location.longitude = 10.4333312
schema = LocationSchema()
data = schema.dump(location)
assert data["latitude"] == "51.9077888"
assert data["longitude"] == "10.4333312"
def test_numeric_str_deserialize(client, seeder, utils):
from project.api.location.schemas import LocationPostRequestLoadSchema
data = {
"latitude": "51.9077888",
"longitude": "10.4333312",
}
schema = LocationPostRequestLoadSchema()
location = schema.load(data)
assert location.latitude == 51.9077888
assert location.longitude == 10.4333312
def test_numeric_str_deserialize_invalid(client, seeder, utils):
from project.api.location.schemas import LocationPostRequestLoadSchema
import pytest
from marshmallow import ValidationError
data = {
"latitude": "Quatsch",
"longitude": "Quatsch",
}
schema = LocationPostRequestLoadSchema()
with pytest.raises(ValidationError):
schema.load(data)

View File

@ -1,6 +0,0 @@
def test_read(client, seeder, utils):
user_id, admin_unit_id = seeder.setup_base()
image_id = seeder.upsert_default_image()
url = utils.get_url("api_v1_image", id=image_id)
utils.get_ok(url)

View File

@ -1,20 +0,0 @@
def test_read(client, app, db, seeder, utils):
user_id, admin_unit_id = seeder.setup_base()
with app.app_context():
from project.models import Location
location = Location()
location.street = "Markt 7"
location.postalCode = "38640"
location.city = "Goslar"
location.latitude = 51.9077888
location.longitude = 10.4333312
db.session.add(location)
db.session.commit()
location_id = location.id
url = utils.get_url("api_v1_location", id=location_id)
response = utils.get_ok(url)
assert response.json["latitude"] == "51.9077888000000000"

View File

@ -46,6 +46,25 @@ def test_places(client, seeder, utils):
utils.get_ok(url)
def test_places_post(client, seeder, utils, app):
user_id, admin_unit_id = seeder.setup_api_access()
url = utils.get_url("api_v1_organization_place_list", id=admin_unit_id, name="crew")
response = utils.post_json(url, {"name": "Neuer Ort"})
utils.assert_response_created(response)
assert "id" in response.json
with app.app_context():
from project.models import EventPlace
place = (
EventPlace.query.filter(EventPlace.admin_unit_id == admin_unit_id)
.filter(EventPlace.name == "Neuer Ort")
.first()
)
assert place is not None
def test_references_incoming(client, seeder, utils):
user_id, admin_unit_id = seeder.setup_base()
(

View File

@ -4,3 +4,65 @@ def test_read(client, app, db, seeder, utils):
url = utils.get_url("api_v1_place", id=place_id)
utils.get_ok(url)
def test_put(client, seeder, utils, app):
user_id, admin_unit_id = seeder.setup_api_access()
place_id = seeder.upsert_default_event_place(admin_unit_id)
url = utils.get_url("api_v1_place", id=place_id)
response = utils.put_json(url, {"name": "Neuer Name"})
utils.assert_response_no_content(response)
with app.app_context():
from project.models import EventPlace
place = EventPlace.query.get(place_id)
assert place.name == "Neuer Name"
def test_put_nonActiveReturnsUnauthorized(client, seeder, db, utils, app):
user_id, admin_unit_id = seeder.setup_api_access()
place_id = seeder.upsert_default_event_place(admin_unit_id)
with app.app_context():
from project.models import User
user = User.query.get(user_id)
user.active = False
db.session.commit()
url = utils.get_url("api_v1_place", id=place_id)
response = utils.put_json(url, {"name": "Neuer Name"})
utils.assert_response_unauthorized(response)
def test_patch(client, seeder, utils, app):
user_id, admin_unit_id = seeder.setup_api_access()
place_id = seeder.upsert_default_event_place(admin_unit_id)
url = utils.get_url("api_v1_place", id=place_id)
response = utils.patch_json(url, {"description": "Klasse"})
utils.assert_response_no_content(response)
with app.app_context():
from project.models import EventPlace
place = EventPlace.query.get(place_id)
assert place.name == "Meine Crew"
assert place.description == "Klasse"
def test_delete(client, seeder, utils, app):
user_id, admin_unit_id = seeder.setup_api_access()
place_id = seeder.upsert_default_event_place(admin_unit_id)
url = utils.get_url("api_v1_place", id=place_id)
response = utils.delete(url)
utils.assert_response_no_content(response)
with app.app_context():
from project.models import EventPlace
place = EventPlace.query.get(place_id)
assert place is None

View File

@ -8,6 +8,7 @@ def pytest_generate_tests(metafunc):
os.environ["DATABASE_URL"] = os.environ.get(
"TEST_DATABASE_URL", "postgresql://postgres@localhost/gsevpt_tests"
)
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "1"
@pytest.fixture

View File

@ -4,8 +4,9 @@ class Seeder(object):
self._db = db
self._utils = utils
def setup_base(self, admin=False):
def setup_base(self, admin=False, log_in=True):
user_id = self.create_user(admin=admin)
if log_in:
self._utils.login()
admin_unit_id = self.create_admin_unit(user_id)
return (user_id, admin_unit_id)
@ -126,6 +127,45 @@ class Seeder(object):
return organizer_id
def insert_default_oauth2_client(self, user_id):
from project.api import scope_list
from project.models import OAuth2Client
from project.services.oauth2_client import complete_oauth2_client
with self._app.app_context():
client = OAuth2Client()
client.user_id = user_id
complete_oauth2_client(client)
metadata = dict()
metadata["client_name"] = "Mein Client"
metadata["scope"] = " ".join(scope_list)
metadata["grant_types"] = ["authorization_code", "refresh_token"]
metadata["response_types"] = ["code"]
metadata["token_endpoint_auth_method"] = "client_secret_post"
client.set_client_metadata(metadata)
self._db.session.add(client)
self._db.session.commit()
client_id = client.id
return client_id
def setup_api_access(self):
user_id, admin_unit_id = self.setup_base(admin=True)
oauth2_client_id = self.insert_default_oauth2_client(user_id)
with self._app.app_context():
from project.models import OAuth2Client
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
client_id = oauth2_client.client_id
client_secret = oauth2_client.client_secret
scope = oauth2_client.scope
self._utils.authorize(client_id, client_secret, scope)
return (user_id, admin_unit_id)
def create_event(self, admin_unit_id, recurrence_rule=None):
from project.models import Event
from project.services.event import insert_event, upsert_event_category

View File

@ -21,3 +21,16 @@ def test_event_category(client, app, db, seeder):
db.session.commit()
assert event.category is None
def test_oauth2_token(client, app):
from project.models import OAuth2Token
token = OAuth2Token()
token.revoked = True
assert not token.is_refresh_token_active()
token.revoked = False
token.issued_at = 0
token.expires_in = 0
assert not token.is_refresh_token_active()

7
tests/test_oauth2.py Normal file
View File

@ -0,0 +1,7 @@
def test_authorization_code(seeder):
user_id, admin_unit_id = seeder.setup_api_access()
def test_refresh_token(seeder, utils):
user_id, admin_unit_id = seeder.setup_api_access()
utils.refresh_token()

View File

@ -2,12 +2,17 @@ import re
from flask import g, url_for
from sqlalchemy.exc import IntegrityError
from bs4 import BeautifulSoup
from urllib.parse import urlsplit, parse_qs
class UtilActions(object):
def __init__(self, client, app):
self._client = client
self._app = app
self._access_token = None
self._refresh_token = None
self._client_id = None
self._client_secret = None
def register(self, email="test@test.de", password="MeinPasswortIstDasBeste"):
response = self._client.get("/register")
@ -77,14 +82,56 @@ class UtilActions(object):
form = Form(soup.find("form"))
return form.fill(values)
def post_form(self, url, response, values: dict):
data = self.create_form_data(response, values)
def post_form_data(self, url, data: dict):
return self._client.post(url, data=data)
def post_form(self, url, response, values: dict):
data = self.create_form_data(response, values)
return self.post_form_data(url, data=data)
def get_headers(self):
headers = dict()
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
def log_request(self, url):
print(url)
def log_json_request(self, url, data: dict):
self.log_request(url)
print(data)
def log_response(self, response):
print(response.status_code)
print(response.data)
print(response.json)
def post_json(self, url, data: dict):
response = self._client.post(url, json=data)
assert response.content_type == "application/json"
return response.json
self.log_json_request(url, data)
response = self._client.post(url, json=data, headers=self.get_headers())
self.log_response(response)
return response
def put_json(self, url, data: dict):
self.log_json_request(url, data)
response = self._client.put(url, json=data, headers=self.get_headers())
self.log_response(response)
return response
def patch_json(self, url, data: dict):
self.log_json_request(url, data)
response = self._client.patch(url, json=data, headers=self.get_headers())
self.log_response(response)
return response
def delete(self, url):
self.log_request(url)
response = self._client.delete(url, headers=self.get_headers())
self.log_response(response)
return response
def mock_db_commit(self, mocker, orig=None):
mocked_commit = mocker.patch("project.db.session.commit")
@ -105,14 +152,23 @@ class UtilActions(object):
url = url_for(endpoint, **values, _external=False)
return url
def get(self, url):
return self._client.get(url)
def get_ok(self, url):
response = self._client.get(url)
response = self.get(url)
self.assert_response_ok(response)
return response
def assert_response_ok(self, response):
assert response.status_code == 200
def assert_response_created(self, response):
assert response.status_code == 201
def assert_response_no_content(self, response):
assert response.status_code == 204
def get_unauthorized(self, url):
response = self._client.get(url)
self.assert_response_unauthorized(response)
@ -146,3 +202,96 @@ class UtilActions(object):
def assert_response_permission_missing(self, response, endpoint, **values):
self.assert_response_redirect(response, endpoint, **values)
def parse_query_parameters(self, url):
query = urlsplit(url).query
params = parse_qs(query)
return {k: v[0] for k, v in params.items()}
def authorize(self, client_id, client_secret, scope):
# Authorize-Seite öffnen
redirect_uri = self.get_url("swagger_oauth2_redirect")
url = self.get_url(
"authorize",
response_type="code",
client_id=client_id,
scope=scope,
redirect_uri=redirect_uri,
)
response = self.get_ok(url)
# Authorisieren
response = self.post_form(
url,
response,
{},
)
assert response.status_code == 302
assert redirect_uri in response.headers["Location"]
# Code aus der Redirect-Antwort lesen
params = self.parse_query_parameters(response.headers["Location"])
assert "code" in params
code = params["code"]
# Mit dem Code den Access-Token abfragen
token_url = self.get_url("issue_token")
response = self.post_form_data(
token_url,
data={
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "authorization_code",
"scope": scope,
"code": code,
"redirect_uri": redirect_uri,
},
)
self.assert_response_ok(response)
assert response.content_type == "application/json"
assert "access_token" in response.json
assert "expires_in" in response.json
assert "refresh_token" in response.json
assert response.json["scope"] == scope
assert response.json["token_type"] == "Bearer"
self._client_id = client_id
self._client_secret = client_secret
self._access_token = response.json["access_token"]
self._refresh_token = response.json["refresh_token"]
def refresh_token(self):
token_url = self.get_url("issue_token")
response = self.post_form_data(
token_url,
data={
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self._client_id,
"client_secret": self._client_secret,
},
)
self.assert_response_ok(response)
assert response.content_type == "application/json"
assert response.json["token_type"] == "Bearer"
assert "access_token" in response.json
assert "expires_in" in response.json
self._access_token = response.json["access_token"]
def revoke_token(self):
url = self.get_url("revoke_token")
response = self.post_form_data(
url,
data={
"token": self._access_token,
"token_type_hint": "access_token",
"client_id": self._client_id,
"client_secret": self._client_secret,
},
)
self.assert_response_ok(response)

View File

@ -430,7 +430,7 @@ def test_delete_nameDoesNotMatch(client, seeder, utils, app, mocker):
def test_rrule(client, seeder, utils, app):
url = utils.get_url("event_rrule")
json = utils.post_json(
response = utils.post_json(
url,
{
"year": 2020,
@ -440,6 +440,7 @@ def test_rrule(client, seeder, utils, app):
"start": 0,
},
)
json = response.json
assert json["batch"]["batch_size"] == 10

25
tests/views/test_oauth.py Normal file
View File

@ -0,0 +1,25 @@
def test_authorize_unauthorizedRedirects(seeder, utils):
url = utils.get_url("authorize")
response = utils.get(url)
assert response.status_code == 302
assert "login" in response.headers["Location"]
def test_authorize_validateThrowsError(seeder, utils):
seeder.setup_base()
url = utils.get_url("authorize")
response = utils.get(url)
utils.assert_response_error_message(response, b"invalid_grant")
def test_revoke_token(seeder, utils):
seeder.setup_api_access()
utils.revoke_token()
def test_swagger_redirect(utils):
url = utils.get_url("swagger_oauth2_redirect")
response = utils.get(url)
assert response.status_code == 302

View File

@ -0,0 +1,136 @@
import pytest
def test_read(client, seeder, utils):
user_id, admin_unit_id = seeder.setup_base(True)
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
url = utils.get_url("oauth2_client", id=oauth2_client_id)
utils.get_ok(url)
def test_read_notOwner(client, seeder, utils):
user_id = seeder.create_user(email="other@other.de", admin=True)
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
seeder.setup_base(True)
url = utils.get_url("oauth2_client", id=oauth2_client_id)
utils.get_unauthorized(url)
def test_list(client, seeder, utils):
user_id, admin_unit_id = seeder.setup_base(True)
url = utils.get_url("oauth2_clients")
utils.get_ok(url)
@pytest.mark.parametrize("db_error", [True, False])
def test_create_authorization_code(client, app, utils, seeder, mocker, db_error):
from project.api import scope_list
user_id, admin_unit_id = seeder.setup_base(True)
url = utils.get_url("oauth2_client_create")
response = utils.get_ok(url)
if db_error:
utils.mock_db_commit(mocker)
response = utils.post_form(
url,
response,
{
"client_name": "Mein Client",
"scope": scope_list,
},
)
if db_error:
utils.assert_response_db_error(response)
return
with app.app_context():
from project.models import OAuth2Client
oauth2_client = OAuth2Client.query.filter(
OAuth2Client.user_id == user_id
).first()
assert oauth2_client is not None
client_id = oauth2_client.id
utils.assert_response_redirect(response, "oauth2_client", id=client_id)
@pytest.mark.parametrize("db_error", [True, False])
def test_update(client, seeder, utils, app, mocker, db_error):
user_id, admin_unit_id = seeder.setup_base(True)
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
url = utils.get_url("oauth2_client_update", id=oauth2_client_id)
response = utils.get_ok(url)
if db_error:
utils.mock_db_commit(mocker)
response = utils.post_form(
url,
response,
{
"client_name": "Neuer Name",
},
)
if db_error:
utils.assert_response_db_error(response)
return
utils.assert_response_redirect(response, "oauth2_client", id=oauth2_client_id)
with app.app_context():
from project.models import OAuth2Client
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
assert oauth2_client.client_name == "Neuer Name"
@pytest.mark.parametrize("db_error", [True, False])
@pytest.mark.parametrize("non_match", [True, False])
def test_delete(client, seeder, utils, app, mocker, db_error, non_match):
user_id, admin_unit_id = seeder.setup_base(True)
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
url = utils.get_url("oauth2_client_delete", id=oauth2_client_id)
response = utils.get_ok(url)
if db_error:
utils.mock_db_commit(mocker)
form_name = "Mein Client"
if non_match:
form_name = "Falscher Name"
response = utils.post_form(
url,
response,
{
"name": form_name,
},
)
if non_match:
utils.assert_response_error_message(response)
return
if db_error:
utils.assert_response_db_error(response)
return
utils.assert_response_redirect(response, "oauth2_clients")
with app.app_context():
from project.models import OAuth2Client
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
assert oauth2_client is None

View File

@ -0,0 +1,47 @@
import pytest
def test_list(client, seeder, utils):
user_id, admin_unit_id = seeder.setup_api_access()
url = utils.get_url("oauth2_tokens")
utils.get_ok(url)
@pytest.mark.parametrize("db_error", [True, False])
def test_revoke(client, seeder, utils, app, mocker, db_error):
user_id, admin_unit_id = seeder.setup_api_access()
with app.app_context():
from project.models import OAuth2Token
oauth2_token = OAuth2Token.query.filter(OAuth2Token.user_id == user_id).first()
oauth2_token_id = oauth2_token.id
url = utils.get_url("oauth2_token_revoke", id=oauth2_token_id)
response = utils.get_ok(url)
if db_error:
utils.mock_db_commit(mocker)
response = utils.post_form(
url,
response,
{},
)
if db_error:
utils.assert_response_db_error(response)
return
utils.assert_response_redirect(response, "oauth2_tokens")
with app.app_context():
from project.models import OAuth2Token
oauth2_token = OAuth2Token.query.get(oauth2_token_id)
assert oauth2_token.revoked
# Kann nicht zweimal revoked werden
response = utils.get(url)
utils.assert_response_redirect(response, "oauth2_tokens")