mirror of
https://github.com/lucaspalomodevelop/eventcally.git
synced 2026-03-13 00:07:22 +00:00
API Write Access with OAuth2 #104
This commit is contained in:
parent
bc9d2aae3c
commit
6c2384e678
@ -1,2 +0,0 @@
|
||||
[run]
|
||||
relative_files = True
|
||||
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@ -5,7 +5,7 @@
|
||||
"python.linting.pylintEnabled": false,
|
||||
"python.linting.flake8Enabled": true,
|
||||
"python.testing.pytestArgs": [
|
||||
"tests"
|
||||
"tests", "--capture=sys"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.nosetestsEnabled": false,
|
||||
|
||||
91
migrations/versions/ddb85cb1c21e_.py
Normal file
91
migrations/versions/ddb85cb1c21e_.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: ddb85cb1c21e
|
||||
Revises: b1a6e7630185
|
||||
Create Date: 2021-02-02 15:17:21.988363
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy_utils
|
||||
from project import dbtypes
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ddb85cb1c21e"
|
||||
down_revision = "b1a6e7630185"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"oauth2_client",
|
||||
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||
sa.Column("client_secret", sa.String(length=120), nullable=True),
|
||||
sa.Column("client_id_issued_at", sa.Integer(), nullable=False),
|
||||
sa.Column("client_secret_expires_at", sa.Integer(), nullable=False),
|
||||
sa.Column("client_metadata", sa.Text(), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth2_client_client_id"), "oauth2_client", ["client_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"oauth2_code",
|
||||
sa.Column("code", sa.String(length=120), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||
sa.Column("redirect_uri", sa.Text(), nullable=True),
|
||||
sa.Column("response_type", sa.Text(), nullable=True),
|
||||
sa.Column("scope", sa.Text(), nullable=True),
|
||||
sa.Column("nonce", sa.Text(), nullable=True),
|
||||
sa.Column("auth_time", sa.Integer(), nullable=False),
|
||||
sa.Column("code_challenge", sa.Text(), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(length=48), nullable=True),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("code"),
|
||||
)
|
||||
op.create_table(
|
||||
"oauth2_token",
|
||||
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||
sa.Column("token_type", sa.String(length=40), nullable=True),
|
||||
sa.Column("access_token", sa.String(length=255), nullable=False),
|
||||
sa.Column("refresh_token", sa.String(length=255), nullable=True),
|
||||
sa.Column("scope", sa.Text(), nullable=True),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=True),
|
||||
sa.Column("issued_at", sa.Integer(), nullable=False),
|
||||
sa.Column("expires_in", sa.Integer(), nullable=False),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("access_token"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth2_token_refresh_token"),
|
||||
"oauth2_token",
|
||||
["refresh_token"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
"eventplace_name_admin_unit_id", "eventplace", ["name", "admin_unit_id"]
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("eventplace_name_admin_unit_id", "eventplace", type_="unique")
|
||||
op.drop_index(op.f("ix_oauth2_token_refresh_token"), table_name="oauth2_token")
|
||||
op.drop_table("oauth2_token")
|
||||
op.drop_table("oauth2_code")
|
||||
op.drop_index(op.f("ix_oauth2_client_client_id"), table_name="oauth2_client")
|
||||
op.drop_table("oauth2_client")
|
||||
# ### end Alembic commands ###
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from flask import Flask
|
||||
from flask import Flask, url_for, redirect, request, jsonify
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_security import (
|
||||
Security,
|
||||
@ -10,12 +10,8 @@ from flask_cors import CORS
|
||||
from flask_qrcode import QRcode
|
||||
from flask_mail import Mail, email_dispatched
|
||||
from flask_migrate import Migrate
|
||||
from flask_marshmallow import Marshmallow
|
||||
from flask_restful import Api
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask_apispec.extension import FlaskApiSpec
|
||||
from flask_gzip import Gzip
|
||||
from webargs import flaskparser
|
||||
|
||||
# Create app
|
||||
app = Flask(__name__)
|
||||
@ -59,25 +55,6 @@ babel = Babel(app)
|
||||
# cors
|
||||
cors = CORS(app, resources={r"/api/*", "/swagger/"})
|
||||
|
||||
# API
|
||||
rest_api = Api(app, "/api/v1")
|
||||
marshmallow = Marshmallow(app)
|
||||
marshmallow_plugin = MarshmallowPlugin()
|
||||
app.config.update(
|
||||
{
|
||||
"APISPEC_SPEC": APISpec(
|
||||
title="Oveda API",
|
||||
version="0.1.0",
|
||||
plugins=[marshmallow_plugin],
|
||||
openapi_version="2.0",
|
||||
info=dict(
|
||||
description="This API provides endpoints to interact with the Oveda data. At the moment, there is no authorization needed."
|
||||
),
|
||||
),
|
||||
}
|
||||
)
|
||||
api_docs = FlaskApiSpec(app)
|
||||
|
||||
# Mail
|
||||
mail_server = os.getenv("MAIL_SERVER")
|
||||
|
||||
@ -108,6 +85,9 @@ if app.config["MAIL_SUPPRESS_SEND"]:
|
||||
db = SQLAlchemy(app)
|
||||
migrate = Migrate(app, db)
|
||||
|
||||
# API
|
||||
from project.api import RestApi
|
||||
|
||||
# qr code
|
||||
QRcode(app)
|
||||
|
||||
@ -123,6 +103,13 @@ from project.forms.security import ExtendedRegisterForm
|
||||
user_datastore = SQLAlchemySessionUserDatastore(db.session, User, Role)
|
||||
security = Security(app, user_datastore, register_form=ExtendedRegisterForm)
|
||||
|
||||
# OAuth2
|
||||
from project.oauth2 import config_oauth
|
||||
|
||||
config_oauth(app)
|
||||
|
||||
# Init misc modules
|
||||
|
||||
from project import i10n
|
||||
from project import jinja_filters
|
||||
from project import init_data
|
||||
@ -142,6 +129,9 @@ from project.views import (
|
||||
image,
|
||||
manage,
|
||||
organizer,
|
||||
oauth,
|
||||
oauth2_client,
|
||||
oauth2_token,
|
||||
planing,
|
||||
reference,
|
||||
reference_request,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from flask import abort
|
||||
from flask_login import login_user
|
||||
from flask_security import current_user
|
||||
from flask_security.utils import FsPermNeed
|
||||
from flask_principal import Permission
|
||||
@ -12,6 +13,20 @@ def has_current_user_permission(permission):
|
||||
return user_perm.can()
|
||||
|
||||
|
||||
def has_owner_access(user_id):
|
||||
return user_id == current_user.id
|
||||
|
||||
|
||||
def owner_access_or_401(user_id):
|
||||
if not has_owner_access(user_id):
|
||||
abort(401)
|
||||
|
||||
|
||||
def login_api_user_or_401(user):
|
||||
if not login_user(user):
|
||||
abort(401)
|
||||
|
||||
|
||||
def has_admin_unit_member_role(admin_unit_member, role_name):
|
||||
for role in admin_unit_member.roles:
|
||||
if role.name == role_name:
|
||||
|
||||
@ -1,4 +1,107 @@
|
||||
from project import rest_api, api_docs
|
||||
from flask_restful import Api
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
||||
from werkzeug.exceptions import HTTPException, UnprocessableEntity
|
||||
from marshmallow import ValidationError
|
||||
from project.utils import get_localized_scope
|
||||
from project import app
|
||||
from flask_marshmallow import Marshmallow
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask_apispec.extension import FlaskApiSpec
|
||||
|
||||
|
||||
class RestApi(Api):
|
||||
def handle_error(self, err):
|
||||
from project.api.schemas import (
|
||||
ErrorResponseSchema,
|
||||
UnprocessableEntityResponseSchema,
|
||||
)
|
||||
|
||||
schema = None
|
||||
data = {}
|
||||
code = 500
|
||||
|
||||
if (
|
||||
isinstance(err, IntegrityError)
|
||||
and err.orig
|
||||
and err.orig.pgcode == UNIQUE_VIOLATION
|
||||
):
|
||||
data["name"] = "Unique Violation"
|
||||
data[
|
||||
"message"
|
||||
] = "An entry with the entered values already exists. Duplicate entries are not allowed."
|
||||
code = 400
|
||||
schema = ErrorResponseSchema()
|
||||
elif isinstance(err, HTTPException):
|
||||
data["name"] = err.name
|
||||
data["message"] = err.description
|
||||
code = err.code
|
||||
|
||||
if (
|
||||
isinstance(err, UnprocessableEntity)
|
||||
and err.exc
|
||||
and isinstance(err.exc, ValidationError)
|
||||
):
|
||||
data["name"] = err.name
|
||||
data["message"] = err.description
|
||||
code = err.code
|
||||
schema = UnprocessableEntityResponseSchema()
|
||||
|
||||
if (
|
||||
getattr(err.exc, "args", None)
|
||||
and isinstance(err.exc.args, tuple)
|
||||
and len(err.exc.args) > 0
|
||||
):
|
||||
arg = err.exc.args[0]
|
||||
if isinstance(arg, dict):
|
||||
errors = []
|
||||
for field, messages in arg.items():
|
||||
if isinstance(messages, list):
|
||||
for message in messages:
|
||||
error = {"field": field, "message": message}
|
||||
errors.append(error)
|
||||
|
||||
if len(errors) > 0:
|
||||
data["errors"] = errors
|
||||
else:
|
||||
schema = ErrorResponseSchema()
|
||||
|
||||
# Call default error handler that propagates error further
|
||||
try:
|
||||
super().handle_error(err)
|
||||
except Exception:
|
||||
if not schema:
|
||||
raise
|
||||
|
||||
return schema.dump(data), code
|
||||
|
||||
|
||||
scope_list = [
|
||||
"organizer:write",
|
||||
"place:write",
|
||||
"event:write",
|
||||
]
|
||||
scopes = {k: get_localized_scope(k) for v, k in enumerate(scope_list)}
|
||||
|
||||
rest_api = RestApi(app, "/api/v1", catch_all_404s=True)
|
||||
marshmallow = Marshmallow(app)
|
||||
marshmallow_plugin = MarshmallowPlugin()
|
||||
app.config.update(
|
||||
{
|
||||
"APISPEC_SPEC": APISpec(
|
||||
title="Oveda API",
|
||||
version="0.1.0",
|
||||
plugins=[marshmallow_plugin],
|
||||
openapi_version="2.0",
|
||||
info=dict(
|
||||
description="This API provides endpoints to interact with the Oveda data."
|
||||
),
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
api_docs = FlaskApiSpec(app)
|
||||
|
||||
|
||||
def enum_to_properties(self, field, **kwargs):
|
||||
@ -17,8 +120,6 @@ def add_api_resource(resource, url, endpoint):
|
||||
api_docs.register(resource, endpoint=endpoint)
|
||||
|
||||
|
||||
from project import marshmallow_plugin
|
||||
|
||||
marshmallow_plugin.converter.add_attribute_function(enum_to_properties)
|
||||
|
||||
import project.api.event.resources
|
||||
@ -26,8 +127,6 @@ import project.api.event_category.resources
|
||||
import project.api.event_date.resources
|
||||
import project.api.event_reference.resources
|
||||
import project.api.dump.resources
|
||||
import project.api.image.resources
|
||||
import project.api.location.resources
|
||||
import project.api.organization.resources
|
||||
import project.api.organizer.resources
|
||||
import project.api.place.resources
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from marshmallow import fields
|
||||
from project.api.event.schemas import EventDumpSchema
|
||||
from project.api.place.schemas import PlaceDumpSchema
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from marshmallow import fields, validate
|
||||
from marshmallow_enum import EnumField
|
||||
from project.models import (
|
||||
@ -10,7 +10,7 @@ from project.models import (
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
from project.api.organization.schemas import OrganizationRefSchema
|
||||
from project.api.organizer.schemas import OrganizerRefSchema
|
||||
from project.api.image.schemas import ImageRefSchema
|
||||
from project.api.image.schemas import ImageSchema
|
||||
from project.api.place.schemas import PlaceRefSchema, PlaceSearchItemSchema
|
||||
from project.api.event_category.schemas import (
|
||||
EventCategoryRefSchema,
|
||||
@ -55,7 +55,7 @@ class EventSchema(EventBaseSchema):
|
||||
organization = fields.Nested(OrganizationRefSchema, attribute="admin_unit")
|
||||
organizer = fields.Nested(OrganizerRefSchema)
|
||||
place = fields.Nested(PlaceRefSchema, attribute="event_place")
|
||||
photo = fields.Nested(ImageRefSchema)
|
||||
photo = fields.Nested(ImageSchema)
|
||||
categories = fields.List(fields.Nested(EventCategoryRefSchema))
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ class EventSearchItemSchema(EventRefSchema):
|
||||
start = marshmallow.auto_field()
|
||||
end = marshmallow.auto_field()
|
||||
recurrence_rule = marshmallow.auto_field()
|
||||
photo = fields.Nested(ImageRefSchema)
|
||||
photo = fields.Nested(ImageSchema)
|
||||
place = fields.Nested(PlaceSearchItemSchema, attribute="event_place")
|
||||
status = EnumField(EventStatus)
|
||||
booked_up = marshmallow.auto_field()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from project.models import EventCategory
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from marshmallow import fields
|
||||
from project.models import EventDate
|
||||
from project.api.event.schemas import (
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from project.models import EventReference
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
from project.api.event.schemas import EventRefSchema
|
||||
|
||||
15
project/api/fields.py
Normal file
15
project/api/fields.py
Normal file
@ -0,0 +1,15 @@
|
||||
from marshmallow import fields, ValidationError
|
||||
|
||||
|
||||
class NumericStr(fields.String):
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
return str(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError as error:
|
||||
raise ValidationError("Must be a numeric value.") from error
|
||||
@ -1,15 +0,0 @@
|
||||
from project.api import add_api_resource
|
||||
from flask_apispec import marshal_with, doc
|
||||
from project.api.resources import BaseResource
|
||||
from project.api.image.schemas import ImageSchema
|
||||
from project.models import Image
|
||||
|
||||
|
||||
class ImageResource(BaseResource):
|
||||
@doc(summary="Get image", tags=["Images"])
|
||||
@marshal_with(ImageSchema)
|
||||
def get(self, id):
|
||||
return Image.query.get_or_404(id)
|
||||
|
||||
|
||||
add_api_resource(ImageResource, "/images/<int:id>", "api_v1_image")
|
||||
@ -1,4 +1,4 @@
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from project.models import Image
|
||||
|
||||
|
||||
@ -10,8 +10,6 @@ class ImageIdSchema(marshmallow.SQLAlchemySchema):
|
||||
|
||||
|
||||
class ImageBaseSchema(ImageIdSchema):
|
||||
created_at = marshmallow.auto_field()
|
||||
updated_at = marshmallow.auto_field()
|
||||
copyright_text = marshmallow.auto_field()
|
||||
|
||||
|
||||
@ -27,13 +25,3 @@ class ImageSchema(ImageBaseSchema):
|
||||
|
||||
class ImageDumpSchema(ImageBaseSchema):
|
||||
pass
|
||||
|
||||
|
||||
class ImageRefSchema(ImageIdSchema):
|
||||
image_url = marshmallow.URLFor(
|
||||
"image",
|
||||
values=dict(id="<id>", s=500),
|
||||
metadata={
|
||||
"description": "Append query arguments w for width, h for height or s for size(width and height)."
|
||||
},
|
||||
)
|
||||
|
||||
@ -1,15 +0,0 @@
|
||||
from project.api import add_api_resource
|
||||
from flask_apispec import marshal_with, doc
|
||||
from project.api.resources import BaseResource
|
||||
from project.api.location.schemas import LocationSchema
|
||||
from project.models import Location
|
||||
|
||||
|
||||
class LocationResource(BaseResource):
|
||||
@doc(summary="Get location", tags=["Locations"])
|
||||
@marshal_with(LocationSchema)
|
||||
def get(self, id):
|
||||
return Location.query.get_or_404(id)
|
||||
|
||||
|
||||
add_api_resource(LocationResource, "/locations/<int:id>", "api_v1_location")
|
||||
@ -1,38 +1,65 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from marshmallow import fields, validate
|
||||
from project.api import marshmallow
|
||||
from project.models import Location
|
||||
from project.api.fields import NumericStr
|
||||
|
||||
|
||||
class LocationIdSchema(marshmallow.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = Location
|
||||
|
||||
id = marshmallow.auto_field()
|
||||
|
||||
|
||||
class LocationSchema(LocationIdSchema):
|
||||
created_at = marshmallow.auto_field()
|
||||
updated_at = marshmallow.auto_field()
|
||||
street = marshmallow.auto_field()
|
||||
postalCode = marshmallow.auto_field()
|
||||
city = marshmallow.auto_field()
|
||||
state = marshmallow.auto_field()
|
||||
country = marshmallow.auto_field()
|
||||
longitude = fields.Str()
|
||||
latitude = fields.Str()
|
||||
longitude = NumericStr()
|
||||
latitude = NumericStr()
|
||||
|
||||
|
||||
class LocationDumpSchema(LocationSchema):
|
||||
pass
|
||||
|
||||
|
||||
class LocationRefSchema(LocationIdSchema):
|
||||
class LocationSearchItemSchema(LocationSchema):
|
||||
pass
|
||||
|
||||
|
||||
class LocationSearchItemSchema(LocationRefSchema):
|
||||
class LocationPostRequestSchema(marshmallow.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = Location
|
||||
|
||||
longitude = fields.Str()
|
||||
latitude = fields.Str()
|
||||
street = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||
postalCode = fields.Str(validate=validate.Length(max=10), missing=None)
|
||||
city = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||
state = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||
country = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||
longitude = NumericStr(validate=validate.Range(-180, 180), missing=None)
|
||||
latitude = NumericStr(validate=validate.Range(-90, 90), missing=None)
|
||||
|
||||
|
||||
class LocationPostRequestLoadSchema(LocationPostRequestSchema):
|
||||
class Meta:
|
||||
model = Location
|
||||
load_instance = True
|
||||
|
||||
|
||||
class LocationPatchRequestSchema(marshmallow.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = Location
|
||||
|
||||
street = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||
postalCode = fields.Str(validate=validate.Length(max=10), allow_none=True)
|
||||
city = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||
state = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||
country = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||
longitude = NumericStr(validate=validate.Range(-180, 180), allow_none=True)
|
||||
latitude = NumericStr(validate=validate.Range(-90, 90), allow_none=True)
|
||||
|
||||
|
||||
class LocationPatchRequestLoadSchema(LocationPatchRequestSchema):
|
||||
class Meta:
|
||||
model = Location
|
||||
load_instance = True
|
||||
|
||||
@ -27,7 +27,13 @@ from project.services.reference import (
|
||||
get_reference_incoming_query,
|
||||
get_reference_outgoing_query,
|
||||
)
|
||||
from project.api.place.schemas import PlaceListRequestSchema, PlaceListResponseSchema
|
||||
from project.api.place.schemas import (
|
||||
PlaceListRequestSchema,
|
||||
PlaceListResponseSchema,
|
||||
PlaceIdSchema,
|
||||
PlacePostRequestSchema,
|
||||
PlacePostRequestLoadSchema,
|
||||
)
|
||||
from project.services.event import get_event_dates_query, get_events_query
|
||||
from project.services.event_search import EventSearchParams
|
||||
from project.services.admin_unit import (
|
||||
@ -35,6 +41,14 @@ from project.services.admin_unit import (
|
||||
get_organizer_query,
|
||||
get_place_query,
|
||||
)
|
||||
from project.oauth2 import require_oauth
|
||||
from authlib.integrations.flask_oauth2 import current_token
|
||||
from project import db
|
||||
from project.access import (
|
||||
access_or_401,
|
||||
get_admin_unit_for_manage_or_404,
|
||||
login_api_user_or_401,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationResource(BaseResource):
|
||||
@ -113,6 +127,26 @@ class OrganizationPlaceListResource(BaseResource):
|
||||
pagination = get_place_query(admin_unit.id, name).paginate()
|
||||
return pagination
|
||||
|
||||
@doc(
|
||||
summary="Add new place",
|
||||
tags=["Organizations", "Places"],
|
||||
security=[{"oauth2": ["place:write"]}],
|
||||
)
|
||||
@use_kwargs(PlacePostRequestSchema, location="json")
|
||||
@marshal_with(PlaceIdSchema, 201)
|
||||
@require_oauth("place:write")
|
||||
def post(self, id, **kwargs):
|
||||
login_api_user_or_401(current_token.user)
|
||||
admin_unit = get_admin_unit_for_manage_or_404(id)
|
||||
access_or_401(admin_unit, "place:create")
|
||||
|
||||
place = PlacePostRequestLoadSchema().load(kwargs, session=db.session)
|
||||
place.admin_unit_id = admin_unit.id
|
||||
db.session.add(place)
|
||||
db.session.commit()
|
||||
|
||||
return place, 201
|
||||
|
||||
|
||||
class OrganizationIncomingEventReferenceListResource(BaseResource):
|
||||
@doc(
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from project.models import AdminUnit
|
||||
from project.api.location.schemas import LocationRefSchema
|
||||
from project.api.image.schemas import ImageRefSchema
|
||||
from project.api.location.schemas import LocationSchema
|
||||
from project.api.image.schemas import ImageSchema
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
|
||||
|
||||
@ -25,8 +25,8 @@ class OrganizationBaseSchema(OrganizationIdSchema):
|
||||
|
||||
|
||||
class OrganizationSchema(OrganizationBaseSchema):
|
||||
location = fields.Nested(LocationRefSchema)
|
||||
logo = fields.Nested(ImageRefSchema)
|
||||
location = fields.Nested(LocationSchema)
|
||||
logo = fields.Nested(ImageSchema)
|
||||
|
||||
|
||||
class OrganizationDumpSchema(OrganizationBaseSchema):
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from project.models import EventOrganizer
|
||||
from project.api.location.schemas import LocationRefSchema
|
||||
from project.api.image.schemas import ImageRefSchema
|
||||
from project.api.location.schemas import LocationSchema
|
||||
from project.api.image.schemas import ImageSchema
|
||||
from project.api.organization.schemas import OrganizationRefSchema
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
|
||||
@ -25,8 +25,8 @@ class OrganizerBaseSchema(OrganizerIdSchema):
|
||||
|
||||
|
||||
class OrganizerSchema(OrganizerBaseSchema):
|
||||
location = fields.Nested(LocationRefSchema)
|
||||
logo = fields.Nested(ImageRefSchema)
|
||||
location = fields.Nested(LocationSchema)
|
||||
logo = fields.Nested(ImageSchema)
|
||||
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
from project.api import add_api_resource
|
||||
from flask_apispec import marshal_with, doc
|
||||
from flask import make_response
|
||||
from flask_apispec import marshal_with, doc, use_kwargs
|
||||
from project.api.resources import BaseResource
|
||||
from project.api.place.schemas import PlaceSchema
|
||||
from project.api.place.schemas import (
|
||||
PlaceSchema,
|
||||
PlacePostRequestSchema,
|
||||
PlacePostRequestLoadSchema,
|
||||
PlacePatchRequestSchema,
|
||||
PlacePatchRequestLoadSchema,
|
||||
)
|
||||
from project.models import EventPlace
|
||||
from project.oauth2 import require_oauth
|
||||
from authlib.integrations.flask_oauth2 import current_token
|
||||
from project import db
|
||||
from project.access import access_or_401, login_api_user_or_401
|
||||
|
||||
|
||||
class PlaceResource(BaseResource):
|
||||
@ -11,5 +22,54 @@ class PlaceResource(BaseResource):
|
||||
def get(self, id):
|
||||
return EventPlace.query.get_or_404(id)
|
||||
|
||||
@doc(
|
||||
summary="Update place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
|
||||
)
|
||||
@use_kwargs(PlacePostRequestSchema, location="json")
|
||||
@marshal_with(None, 204)
|
||||
@require_oauth("place:write")
|
||||
def put(self, id, **kwargs):
|
||||
login_api_user_or_401(current_token.user)
|
||||
place = EventPlace.query.get_or_404(id)
|
||||
access_or_401(place.adminunit, "place:update")
|
||||
|
||||
place = PlacePostRequestLoadSchema().load(
|
||||
kwargs, session=db.session, instance=place
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
return make_response("", 204)
|
||||
|
||||
@doc(summary="Patch place", tags=["Places"], security=[{"oauth2": ["place:write"]}])
|
||||
@use_kwargs(PlacePatchRequestSchema, location="json")
|
||||
@marshal_with(None, 204)
|
||||
@require_oauth("place:write")
|
||||
def patch(self, id, **kwargs):
|
||||
login_api_user_or_401(current_token.user)
|
||||
place = EventPlace.query.get_or_404(id)
|
||||
access_or_401(place.adminunit, "place:update")
|
||||
|
||||
place = PlacePatchRequestLoadSchema().load(
|
||||
kwargs, session=db.session, instance=place
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
return make_response("", 204)
|
||||
|
||||
@doc(
|
||||
summary="Delete place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
|
||||
)
|
||||
@marshal_with(None, 204)
|
||||
@require_oauth("place:write")
|
||||
def delete(self, id):
|
||||
login_api_user_or_401(current_token.user)
|
||||
place = EventPlace.query.get_or_404(id)
|
||||
access_or_401(place.adminunit, "place:delete")
|
||||
|
||||
db.session.delete(place)
|
||||
db.session.commit()
|
||||
|
||||
return make_response("", 204)
|
||||
|
||||
|
||||
add_api_resource(PlaceResource, "/places/<int:id>", "api_v1_place")
|
||||
|
||||
@ -1,8 +1,15 @@
|
||||
from marshmallow import fields
|
||||
from project import marshmallow
|
||||
from marshmallow import fields, validate
|
||||
from project.api import marshmallow
|
||||
from project.models import EventPlace
|
||||
from project.api.image.schemas import ImageRefSchema
|
||||
from project.api.location.schemas import LocationRefSchema, LocationSearchItemSchema
|
||||
from project.api.image.schemas import ImageSchema
|
||||
from project.api.location.schemas import (
|
||||
LocationSchema,
|
||||
LocationSearchItemSchema,
|
||||
LocationPostRequestSchema,
|
||||
LocationPostRequestLoadSchema,
|
||||
LocationPatchRequestSchema,
|
||||
LocationPatchRequestLoadSchema,
|
||||
)
|
||||
from project.api.organization.schemas import OrganizationRefSchema
|
||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||
|
||||
@ -23,8 +30,8 @@ class PlaceBaseSchema(PlaceIdSchema):
|
||||
|
||||
|
||||
class PlaceSchema(PlaceBaseSchema):
|
||||
location = fields.Nested(LocationRefSchema)
|
||||
photo = fields.Nested(ImageRefSchema)
|
||||
location = fields.Nested(LocationSchema)
|
||||
photo = fields.Nested(ImageSchema)
|
||||
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
||||
|
||||
|
||||
@ -55,3 +62,41 @@ class PlaceListResponseSchema(PaginationResponseSchema):
|
||||
items = fields.List(
|
||||
fields.Nested(PlaceRefSchema), metadata={"description": "Places"}
|
||||
)
|
||||
|
||||
|
||||
class PlacePostRequestSchema(marshmallow.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = EventPlace
|
||||
|
||||
name = fields.Str(required=True, validate=validate.Length(min=3, max=255))
|
||||
url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None)
|
||||
description = fields.Str(missing=None)
|
||||
location = fields.Nested(LocationPostRequestSchema, missing=None)
|
||||
|
||||
|
||||
class PlacePostRequestLoadSchema(PlacePostRequestSchema):
|
||||
class Meta:
|
||||
model = EventPlace
|
||||
load_instance = True
|
||||
|
||||
location = fields.Nested(LocationPostRequestLoadSchema, missing=None)
|
||||
|
||||
|
||||
class PlacePatchRequestSchema(marshmallow.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = EventPlace
|
||||
|
||||
name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True)
|
||||
url = fields.Str(
|
||||
validate=[validate.URL(), validate.Length(max=255)], allow_none=True
|
||||
)
|
||||
description = fields.Str(allow_none=True)
|
||||
location = fields.Nested(LocationPatchRequestSchema, allow_none=True)
|
||||
|
||||
|
||||
class PlacePatchRequestLoadSchema(PlacePatchRequestSchema):
|
||||
class Meta:
|
||||
model = EventPlace
|
||||
load_instance = True
|
||||
|
||||
location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from flask import request
|
||||
from flask_apispec import marshal_with
|
||||
from flask_apispec.views import MethodResource
|
||||
from functools import wraps
|
||||
from project.api.schemas import ErrorResponseSchema, UnprocessableEntityResponseSchema
|
||||
|
||||
|
||||
def etag_cache(func):
|
||||
@ -13,5 +15,7 @@ def etag_cache(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
@marshal_with(ErrorResponseSchema, 400, "Bad Request")
|
||||
@marshal_with(UnprocessableEntityResponseSchema, 422, "Unprocessable Entity")
|
||||
class BaseResource(MethodResource):
|
||||
decorators = [etag_cache]
|
||||
|
||||
@ -1,7 +1,21 @@
|
||||
from project import marshmallow
|
||||
from project.api import marshmallow
|
||||
from marshmallow import fields, validate
|
||||
|
||||
|
||||
class ErrorResponseSchema(marshmallow.Schema):
|
||||
name = fields.Str()
|
||||
message = fields.Str()
|
||||
|
||||
|
||||
class UnprocessableEntityErrorSchema(marshmallow.Schema):
|
||||
field = fields.Str()
|
||||
message = fields.Str()
|
||||
|
||||
|
||||
class UnprocessableEntityResponseSchema(ErrorResponseSchema):
|
||||
errors = fields.List(fields.Nested(UnprocessableEntityErrorSchema))
|
||||
|
||||
|
||||
class PaginationRequestSchema(marshmallow.Schema):
|
||||
page = fields.Integer(
|
||||
required=False,
|
||||
|
||||
92
project/forms/oauth2_client.py
Normal file
92
project/forms/oauth2_client.py
Normal file
@ -0,0 +1,92 @@
|
||||
from flask_wtf import FlaskForm
|
||||
from flask_babelex import lazy_gettext
|
||||
from wtforms import StringField, TextAreaField, SubmitField, SelectField
|
||||
from wtforms.validators import Optional, DataRequired
|
||||
from project.forms.widgets import MultiCheckboxField
|
||||
from project.api import scopes
|
||||
from project.utils import split_by_crlf
|
||||
import os
|
||||
|
||||
|
||||
class BaseOAuth2ClientForm(FlaskForm):
|
||||
client_name = StringField(lazy_gettext("Client name"), validators=[DataRequired()])
|
||||
redirect_uris = TextAreaField(
|
||||
lazy_gettext("Redirect URIs"), validators=[Optional()]
|
||||
)
|
||||
grant_types = MultiCheckboxField(
|
||||
lazy_gettext("Grant types"),
|
||||
validators=[DataRequired()],
|
||||
choices=[
|
||||
("authorization_code", lazy_gettext("Authorization Code")),
|
||||
("refresh_token", lazy_gettext("Refresh Token")),
|
||||
],
|
||||
default=["authorization_code", "refresh_token"],
|
||||
)
|
||||
response_types = MultiCheckboxField(
|
||||
lazy_gettext("Response types"),
|
||||
validators=[DataRequired()],
|
||||
choices=[
|
||||
("code", "code"),
|
||||
],
|
||||
default=["code"],
|
||||
)
|
||||
scope = MultiCheckboxField(
|
||||
lazy_gettext("Scopes"),
|
||||
validators=[DataRequired()],
|
||||
choices=[(k, k) for k, v in scopes.items()],
|
||||
)
|
||||
token_endpoint_auth_method = SelectField(
|
||||
lazy_gettext("Token endpoint auth method"),
|
||||
validators=[DataRequired()],
|
||||
choices=[
|
||||
("client_secret_post", lazy_gettext("Client secret post")),
|
||||
("client_secret_basic", lazy_gettext("Client secret basic")),
|
||||
],
|
||||
)
|
||||
|
||||
submit = SubmitField(lazy_gettext("Save"))
|
||||
|
||||
def populate_obj(self, obj):
|
||||
meta_keys = [
|
||||
"client_name",
|
||||
"client_uri",
|
||||
"grant_types",
|
||||
"redirect_uris",
|
||||
"response_types",
|
||||
"scope",
|
||||
"token_endpoint_auth_method",
|
||||
]
|
||||
metadata = dict()
|
||||
for name, field in self._fields.items():
|
||||
if name in meta_keys:
|
||||
if name == "redirect_uris":
|
||||
metadata[name] = split_by_crlf(field.data)
|
||||
elif name == "scope":
|
||||
metadata[name] = " ".join(field.data)
|
||||
else:
|
||||
metadata[name] = field.data
|
||||
else:
|
||||
field.populate_obj(obj, name)
|
||||
obj.set_client_metadata(metadata)
|
||||
|
||||
def process(self, formdata=None, obj=None, data=None, **kwargs):
|
||||
super().process(formdata, obj, data, **kwargs)
|
||||
|
||||
if not obj:
|
||||
return
|
||||
|
||||
self.redirect_uris.data = os.linesep.join(obj.redirect_uris)
|
||||
self.scope.data = obj.scope.split(" ")
|
||||
|
||||
|
||||
class CreateOAuth2ClientForm(BaseOAuth2ClientForm):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateOAuth2ClientForm(BaseOAuth2ClientForm):
|
||||
pass
|
||||
|
||||
|
||||
class DeleteOAuth2ClientForm(FlaskForm):
|
||||
submit = SubmitField(lazy_gettext("Delete OAuth2 client"))
|
||||
name = StringField(lazy_gettext("Name"), validators=[DataRequired()])
|
||||
7
project/forms/oauth2_token.py
Normal file
7
project/forms/oauth2_token.py
Normal file
@ -0,0 +1,7 @@
|
||||
from flask_wtf import FlaskForm
|
||||
from flask_babelex import lazy_gettext
|
||||
from wtforms import SubmitField
|
||||
|
||||
|
||||
class RevokeOAuth2TokenForm(FlaskForm):
|
||||
submit = SubmitField(lazy_gettext("Revoke OAuth2 token"))
|
||||
@ -1,7 +1,9 @@
|
||||
from flask_security.forms import RegisterForm, EqualTo, get_form_field_label
|
||||
from wtforms import BooleanField, PasswordField
|
||||
from wtforms import BooleanField, PasswordField, SubmitField
|
||||
from wtforms.validators import DataRequired
|
||||
from project.forms.common import get_accept_tos_markup
|
||||
from flask_wtf import FlaskForm
|
||||
from flask_babelex import lazy_gettext
|
||||
|
||||
|
||||
class ExtendedRegisterForm(RegisterForm):
|
||||
@ -20,3 +22,8 @@ class ExtendedRegisterForm(RegisterForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ExtendedRegisterForm, self).__init__(*args, **kwargs)
|
||||
self._fields["accept_tos"].label.text = get_accept_tos_markup()
|
||||
|
||||
|
||||
class AuthorizeForm(FlaskForm):
|
||||
allow = SubmitField(lazy_gettext("Allow"))
|
||||
deny = SubmitField(lazy_gettext("Deny"))
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@ -39,3 +39,8 @@ def print_dynamic_texts():
|
||||
gettext("EventReviewStatus.inbox")
|
||||
gettext("EventReviewStatus.verified")
|
||||
gettext("EventReviewStatus.rejected")
|
||||
gettext("read")
|
||||
gettext("write")
|
||||
gettext("Event")
|
||||
gettext("Organizer")
|
||||
gettext("Place")
|
||||
|
||||
@ -1,8 +1,27 @@
|
||||
from project import app, db
|
||||
from project.api import api_docs, scopes
|
||||
from project.services.user import upsert_user_role
|
||||
from project.services.admin_unit import upsert_admin_unit_member_role
|
||||
from project.services.event import upsert_event_category
|
||||
from project.models import Location
|
||||
from flask import url_for
|
||||
from apispec.exceptions import DuplicateComponentNameError
|
||||
|
||||
|
||||
@app.before_first_request
|
||||
def add_oauth2_scheme():
|
||||
oauth2_scheme = {
|
||||
"type": "oauth2",
|
||||
"authorizationUrl": url_for("authorize", _external=True),
|
||||
"tokenUrl": url_for("issue_token", _external=True),
|
||||
"flow": "accessCode",
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
try:
|
||||
api_docs.spec.components.security_scheme("oauth2", oauth2_scheme)
|
||||
except DuplicateComponentNameError: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
@app.before_first_request
|
||||
@ -37,12 +56,23 @@ def create_initial_data():
|
||||
"reference_request:delete",
|
||||
"reference_request:verify",
|
||||
]
|
||||
early_adopter_permissions = [
|
||||
"oauth2_client:create",
|
||||
"oauth2_client:read",
|
||||
"oauth2_client:update",
|
||||
"oauth2_client:delete",
|
||||
"oauth2_token:create",
|
||||
"oauth2_token:read",
|
||||
"oauth2_token:update",
|
||||
"oauth2_token:delete",
|
||||
]
|
||||
|
||||
upsert_admin_unit_member_role("admin", "Administrator", admin_permissions)
|
||||
upsert_admin_unit_member_role("event_verifier", "Event expert", event_permissions)
|
||||
|
||||
upsert_user_role("admin", "Administrator", admin_permissions)
|
||||
upsert_user_role("event_verifier", "Event expert", event_permissions)
|
||||
upsert_user_role("early_adopter", "Early Adopter", early_adopter_permissions)
|
||||
|
||||
Location.update_coordinates()
|
||||
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
from project import app
|
||||
from project.utils import get_event_category_name, get_localized_enum_name
|
||||
from project.utils import (
|
||||
get_event_category_name,
|
||||
get_localized_enum_name,
|
||||
get_localized_scope,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
import os
|
||||
|
||||
@ -8,10 +12,16 @@ def env_override(value, key):
|
||||
return os.getenv(key, value)
|
||||
|
||||
|
||||
def is_list(value):
|
||||
return isinstance(value, list)
|
||||
|
||||
|
||||
app.jinja_env.filters["event_category_name"] = lambda u: get_event_category_name(u)
|
||||
app.jinja_env.filters["loc_enum"] = lambda u: get_localized_enum_name(u)
|
||||
app.jinja_env.filters["loc_scope"] = lambda s: get_localized_scope(s)
|
||||
app.jinja_env.filters["env_override"] = env_override
|
||||
app.jinja_env.filters["quote_plus"] = lambda u: quote_plus(u)
|
||||
app.jinja_env.filters["is_list"] = is_list
|
||||
|
||||
|
||||
@app.context_processor
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from project import db
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import relationship, backref, deferred
|
||||
from sqlalchemy.orm import relationship, backref, deferred, object_session
|
||||
from sqlalchemy.schema import CheckConstraint
|
||||
from sqlalchemy.event import listens_for
|
||||
from sqlalchemy import (
|
||||
@ -24,6 +24,12 @@ import datetime
|
||||
from project.dbtypes import IntegerEnum
|
||||
from geoalchemy2 import Geometry
|
||||
from sqlalchemy import and_
|
||||
from authlib.integrations.sqla_oauth2 import (
|
||||
OAuth2ClientMixin,
|
||||
OAuth2AuthorizationCodeMixin,
|
||||
OAuth2TokenMixin,
|
||||
)
|
||||
import time
|
||||
|
||||
# Base
|
||||
|
||||
@ -156,6 +162,12 @@ class User(db.Model, UserMixin):
|
||||
"Role", secondary="roles_users", backref=backref("users", lazy="dynamic")
|
||||
)
|
||||
|
||||
def get_user_id(self):
|
||||
return self.id
|
||||
|
||||
|
||||
# OAuth Consumer: Wenn wir OAuth consumen und sich ein Nutzer per Google oder Facebook anmelden möchte
|
||||
|
||||
|
||||
class OAuth(OAuthConsumerMixin, db.Model):
|
||||
provider_user_id = Column(String(256), unique=True, nullable=False)
|
||||
@ -163,6 +175,51 @@ class OAuth(OAuthConsumerMixin, db.Model):
|
||||
user = db.relationship("User")
|
||||
|
||||
|
||||
# OAuth Server: Wir bieten an, dass sich ein Nutzer per OAuth2 auf unserer Seite anmeldet
|
||||
|
||||
|
||||
class OAuth2Client(db.Model, OAuth2ClientMixin):
|
||||
__tablename__ = "oauth2_client"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||
user = db.relationship("User")
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return True
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
|
||||
__tablename__ = "oauth2_code"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||
user = db.relationship("User")
|
||||
|
||||
|
||||
class OAuth2Token(db.Model, OAuth2TokenMixin):
|
||||
__tablename__ = "oauth2_token"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||
user = db.relationship("User")
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return (
|
||||
object_session(self)
|
||||
.query(OAuth2Client)
|
||||
.filter(OAuth2Client.client_id == self.client_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def is_refresh_token_active(self):
|
||||
if self.revoked:
|
||||
return False
|
||||
expires_at = self.issued_at + self.expires_in * 2
|
||||
return expires_at >= time.time()
|
||||
|
||||
|
||||
# Admin Unit
|
||||
|
||||
|
||||
@ -298,6 +355,7 @@ def update_location_coordinate(mapper, connect, self):
|
||||
# Events
|
||||
class EventPlace(db.Model, TrackableMixin):
|
||||
__tablename__ = "eventplace"
|
||||
__table_args__ = (UniqueConstraint("name", "admin_unit_id"),)
|
||||
id = Column(Integer(), primary_key=True)
|
||||
name = Column(Unicode(255), nullable=False)
|
||||
location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
|
||||
|
||||
114
project/oauth2.py
Normal file
114
project/oauth2.py
Normal file
@ -0,0 +1,114 @@
|
||||
from authlib.integrations.flask_oauth2 import (
|
||||
AuthorizationServer,
|
||||
ResourceProtector,
|
||||
)
|
||||
from authlib.integrations.sqla_oauth2 import (
|
||||
create_query_client_func,
|
||||
create_save_token_func,
|
||||
create_bearer_token_validator,
|
||||
create_query_token_func,
|
||||
)
|
||||
from authlib.oauth2.rfc6749 import grants
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from project import db
|
||||
from project.models import User, OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
|
||||
|
||||
|
||||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none",
|
||||
]
|
||||
|
||||
def save_authorization_code(self, code, request):
|
||||
code_challenge = request.data.get("code_challenge")
|
||||
code_challenge_method = request.data.get("code_challenge_method")
|
||||
auth_code = OAuth2AuthorizationCode(
|
||||
code=code,
|
||||
client_id=request.client.client_id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
scope=request.scope,
|
||||
user_id=request.user.id,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
)
|
||||
db.session.add(auth_code)
|
||||
db.session.commit()
|
||||
return auth_code
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
auth_code = OAuth2AuthorizationCode.query.filter_by(
|
||||
code=code, client_id=client.client_id
|
||||
).first()
|
||||
if auth_code and not auth_code.is_expired():
|
||||
return auth_code
|
||||
|
||||
def delete_authorization_code(self, authorization_code):
|
||||
db.session.delete(authorization_code)
|
||||
db.session.commit()
|
||||
|
||||
def authenticate_user(self, authorization_code):
|
||||
return User.query.get(authorization_code.user_id)
|
||||
|
||||
|
||||
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
|
||||
|
||||
def authenticate_refresh_token(self, refresh_token):
|
||||
token = OAuth2Token.query.filter_by(refresh_token=refresh_token).first()
|
||||
if token and token.is_refresh_token_active():
|
||||
return token
|
||||
|
||||
def authenticate_user(self, credential):
|
||||
return User.query.get(credential.user_id)
|
||||
|
||||
def revoke_old_credential(self, credential):
|
||||
credential.revoked = True
|
||||
db.session.add(credential)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
query_client = create_query_client_func(db.session, OAuth2Client)
|
||||
save_token = create_save_token_func(db.session, OAuth2Token)
|
||||
authorization = AuthorizationServer(
|
||||
query_client=query_client,
|
||||
save_token=save_token,
|
||||
)
|
||||
require_oauth = ResourceProtector()
|
||||
|
||||
|
||||
def create_revocation_endpoint(session, token_model):
|
||||
from authlib.oauth2.rfc7009 import RevocationEndpoint
|
||||
|
||||
query_token = create_query_token_func(session, token_model)
|
||||
|
||||
class _RevocationEndpoint(RevocationEndpoint):
|
||||
CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
|
||||
|
||||
def query_token(self, token, token_type_hint, client):
|
||||
return query_token(token, token_type_hint, client)
|
||||
|
||||
def revoke_token(self, token):
|
||||
token.revoked = True
|
||||
session.add(token)
|
||||
session.commit()
|
||||
|
||||
return _RevocationEndpoint
|
||||
|
||||
|
||||
def config_oauth(app):
|
||||
app.config["OAUTH2_REFRESH_TOKEN_GENERATOR"] = True
|
||||
authorization.init_app(app)
|
||||
|
||||
# support grants
|
||||
authorization.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])
|
||||
authorization.register_grant(RefreshTokenGrant)
|
||||
|
||||
# support revocation
|
||||
revocation_cls = create_revocation_endpoint(db.session, OAuth2Token)
|
||||
authorization.register_endpoint(revocation_cls)
|
||||
|
||||
# protect resource
|
||||
bearer_cls = create_bearer_token_validator(db.session, OAuth2Token)
|
||||
require_oauth.register_token_validator(bearer_cls())
|
||||
12
project/services/oauth2_client.py
Normal file
12
project/services/oauth2_client.py
Normal file
@ -0,0 +1,12 @@
|
||||
from project.models import OAuth2Client
|
||||
import time
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
|
||||
def complete_oauth2_client(oauth2_client: OAuth2Client) -> None:
|
||||
if not oauth2_client.id:
|
||||
oauth2_client.client_id = gen_salt(24)
|
||||
oauth2_client.client_id_issued_at = int(time.time())
|
||||
|
||||
if oauth2_client.client_secret is None:
|
||||
oauth2_client.client_secret = gen_salt(48)
|
||||
@ -17,7 +17,7 @@ def add_roles_to_user(email, roles):
|
||||
|
||||
|
||||
def add_admin_roles_to_user(email):
|
||||
add_roles_to_user(email, ["admin", "event_verifier"])
|
||||
add_roles_to_user(email, ["admin", "event_verifier", "early_adopter"])
|
||||
|
||||
|
||||
def remove_roles_from_user(email, roles):
|
||||
|
||||
@ -160,6 +160,21 @@
|
||||
{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{% macro render_kv_begin() %}
|
||||
<dl class="row">
|
||||
{% endmacro %}
|
||||
|
||||
{% macro render_kv_end() %}
|
||||
</dl>
|
||||
{% endmacro %}
|
||||
|
||||
{% macro render_kv_prop(prop, label_key = None) %}
|
||||
{% if prop %}
|
||||
<dt class="col-sm-3">{{ _(label_key) }}</dt>
|
||||
<dd class="col-sm-9">{% if prop|is_list %}{{ prop|join(', ') }}{% else %}{{ prop }}{% endif %}</dd>
|
||||
{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{% macro render_string_prop(prop, icon = None, label_key = None) %}
|
||||
{% if prop %}
|
||||
<div>
|
||||
|
||||
26
project/templates/oauth2_client/create.html
Normal file
26
project/templates/oauth2_client/create.html
Normal file
@ -0,0 +1,26 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||
{% block title %}
|
||||
{{ _('Create OAuth2 client') }}
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
|
||||
<h1>{{ _('Create OAuth2 client') }}</h1>
|
||||
|
||||
<form action="" method="POST">
|
||||
{{ form.hidden_tag() }}
|
||||
|
||||
<div class="card mb-4">
|
||||
<div class="card-body">
|
||||
{{ render_field_with_errors(form.client_name) }}
|
||||
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
|
||||
{{ render_field_with_errors(form.redirect_uris) }}
|
||||
|
||||
|
||||
{{ render_field(form.submit) }}
|
||||
</form>
|
||||
|
||||
{% endblock %}
|
||||
24
project/templates/oauth2_client/delete.html
Normal file
24
project/templates/oauth2_client/delete.html
Normal file
@ -0,0 +1,24 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||
|
||||
{% block content %}
|
||||
|
||||
<h1>{{ _('Delete OAuth2 client') }} "{{ oauth2_client.client_name }}"</h1>
|
||||
|
||||
<form action="" method="POST">
|
||||
{{ form.hidden_tag() }}
|
||||
|
||||
<div class="card mb-4">
|
||||
<div class="card-header">
|
||||
{{ _('OAuth2 client') }}
|
||||
</div>
|
||||
<div class="card-body">
|
||||
{{ render_field_with_errors(form.name) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{{ render_field(form.submit) }}
|
||||
|
||||
</form>
|
||||
|
||||
{% endblock %}
|
||||
44
project/templates/oauth2_client/list.html
Normal file
44
project/templates/oauth2_client/list.html
Normal file
@ -0,0 +1,44 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_pagination %}
|
||||
{% block title %}
|
||||
{{ _('OAuth2 clients') }}
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
|
||||
<nav aria-label="breadcrumb">
|
||||
<ol class="breadcrumb">
|
||||
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 clients') }}</li>
|
||||
</ol>
|
||||
</nav>
|
||||
|
||||
{% if current_user.has_permission('oauth2_client:create') %}
|
||||
<div class="my-4">
|
||||
<a class="btn btn-outline-secondary my-1" href="{{ url_for('oauth2_client_create') }}" role="button"><i class="fa fa-plus"></i> {{ _('Create OAuth2 client') }}</a>
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
<div class="table-responsive">
|
||||
<table class="table table-sm table-bordered table-hover table-striped">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>{{ _('Name') }}</th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for oauth2_client in oauth2_clients %}
|
||||
<tr>
|
||||
<td><a href="{{ url_for('oauth2_client', id=oauth2_client.id) }}">{{ oauth2_client.client_name }}</a></td>
|
||||
<td><a href="{{ url_for('oauth2_client_update', id=oauth2_client.id) }}">{{ _('Edit') }}</a></td>
|
||||
<td><a href="{{ url_for('oauth2_client_delete', id=oauth2_client.id) }}">{{ _('Delete') }}</a></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<div class="my-4">{{ render_pagination(pagination) }}</div>
|
||||
|
||||
{% endblock %}
|
||||
34
project/templates/oauth2_client/read.html
Normal file
34
project/templates/oauth2_client/read.html
Normal file
@ -0,0 +1,34 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_kv_begin, render_kv_end, render_kv_prop %}
|
||||
{% block title %}
|
||||
{{ oauth2_client.client_name }}
|
||||
{% endblock %}
|
||||
{% block header %}
|
||||
<script type="application/ld+json">
|
||||
{{ structured_data | safe }}
|
||||
</script>
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
|
||||
<nav aria-label="breadcrumb">
|
||||
<ol class="breadcrumb">
|
||||
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||
<li class="breadcrumb-item"><a href="{{ url_for('oauth2_clients') }}">{{ _('OAuth2 clients') }}</a></li>
|
||||
<li class="breadcrumb-item active" aria-current="page">{{ oauth2_client.client_name }}</li>
|
||||
</ol>
|
||||
</nav>
|
||||
|
||||
<div class="w-normal">
|
||||
{{ render_kv_begin() }}
|
||||
{{ render_kv_prop(oauth2_client.client_id, 'Client ID') }}
|
||||
{{ render_kv_prop(oauth2_client.client_secret, 'Client secret') }}
|
||||
{{ render_kv_prop(oauth2_client.client_uri, 'Client URI') }}
|
||||
{{ render_kv_prop(oauth2_client.grant_types, 'Grant types') }}
|
||||
{{ render_kv_prop(oauth2_client.redirect_uris, 'Redirect URIs') }}
|
||||
{{ render_kv_prop(oauth2_client.response_types, 'Response types') }}
|
||||
{{ render_kv_prop(oauth2_client.scope, 'Scope') }}
|
||||
{{ render_kv_prop(oauth2_client.token_endpoint_auth_method, 'Token endpoint auth method') }}
|
||||
{{ render_kv_end() }}
|
||||
</div>
|
||||
|
||||
{% endblock %}
|
||||
27
project/templates/oauth2_client/update.html
Normal file
27
project/templates/oauth2_client/update.html
Normal file
@ -0,0 +1,27 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||
{% block title %}
|
||||
{{ _('Update OAuth2 client') }}
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
|
||||
<h1>{{ _('Update OAuth2 client') }}</h1>
|
||||
|
||||
<form action="" method="POST">
|
||||
{{ form.hidden_tag() }}
|
||||
|
||||
<div class="card mb-4">
|
||||
<div class="card-body">
|
||||
{{ render_field_with_errors(form.client_name) }}
|
||||
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
|
||||
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
|
||||
{{ render_field_with_errors(form.redirect_uris) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{{ render_field(form.submit) }}
|
||||
</form>
|
||||
|
||||
{% endblock %}
|
||||
40
project/templates/oauth2_token/list.html
Normal file
40
project/templates/oauth2_token/list.html
Normal file
@ -0,0 +1,40 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_pagination %}
|
||||
{% block title %}
|
||||
{{ _('OAuth2 tokens') }}
|
||||
{% endblock %}
|
||||
{% block content %}
|
||||
|
||||
<nav aria-label="breadcrumb">
|
||||
<ol class="breadcrumb">
|
||||
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 tokens') }}</li>
|
||||
</ol>
|
||||
</nav>
|
||||
|
||||
<div class="table-responsive">
|
||||
<table class="table table-sm table-bordered table-hover table-striped">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>{{ _('Client') }}</th>
|
||||
<th>{{ _('Scopes') }}</th>
|
||||
<th>{{ _('Status') }}</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for oauth2_token in oauth2_tokens %}
|
||||
<tr>
|
||||
<td>{{ oauth2_token.client.client_name }}</td>
|
||||
<td>{{ oauth2_token.client.scope }}</td>
|
||||
<td>{% if oauth2_token.revoked %}{{ _('Revoked') }}{% else %}{{ _('Active') }}{% endif %}</td>
|
||||
<td>{% if not oauth2_token.revoked %}<a href="{{ url_for('oauth2_token_revoke', id=oauth2_token.id) }}">{{ _('Revoke') }}</a>{% endif %}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<div class="my-4">{{ render_pagination(pagination) }}</div>
|
||||
|
||||
{% endblock %}
|
||||
16
project/templates/oauth2_token/revoke.html
Normal file
16
project/templates/oauth2_token/revoke.html
Normal file
@ -0,0 +1,16 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||
|
||||
{% block content %}
|
||||
|
||||
<h1>{{ _('Revoke OAuth2 token') }}</h1>
|
||||
<h2>{{ oauth2_token.client.client_name }}</h2>
|
||||
|
||||
<form action="" method="POST">
|
||||
{{ form.hidden_tag() }}
|
||||
|
||||
{{ render_field(form.submit) }}
|
||||
|
||||
</form>
|
||||
|
||||
{% endblock %}
|
||||
@ -8,7 +8,23 @@
|
||||
<h1>{{ current_user.email }}</h1>
|
||||
|
||||
<h2>{{ _('Profile') }}</h2>
|
||||
<p><a href="{{ url_for_security('change_password') }}">{{ _fsdomain('Change password') }}</a></p>
|
||||
|
||||
<div class="list-group">
|
||||
<a href="{{ url_for('security.change_password') }}" class="list-group-item">
|
||||
{{ _fsdomain('Change password') }}
|
||||
<i class="fa fa-caret-right"></i>
|
||||
</a>
|
||||
{% if current_user.has_permission('oauth2_client:read') %}
|
||||
<a href="{{ url_for('oauth2_clients') }}" class="list-group-item">
|
||||
{{ _('OAuth2 clients') }}
|
||||
<i class="fa fa-caret-right"></i>
|
||||
</a>
|
||||
{% endif %}
|
||||
<a href="{{ url_for('oauth2_tokens') }}" class="list-group-item">
|
||||
{{ _('OAuth2 tokens') }}
|
||||
<i class="fa fa-caret-right"></i>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{% if invitations %}
|
||||
<h2>{{ _('Invitations') }}</h2>
|
||||
|
||||
40
project/templates/security/authorize.html
Normal file
40
project/templates/security/authorize.html
Normal file
@ -0,0 +1,40 @@
|
||||
{% extends "layout.html" %}
|
||||
{% from "_macros.html" import render_field %}
|
||||
|
||||
{% block content %}
|
||||
|
||||
<div class="w-normal d-flex flex-column">
|
||||
|
||||
<div class="card mx-auto">
|
||||
<div class="card-body">
|
||||
<h5>{{ _('"%(client_name)s" wants to access your account', client_name=grant.client.client_name) }}</h5>
|
||||
|
||||
<p class="text-center"><strong>{{ user.email }}</strong></p>
|
||||
|
||||
<p class="mb-1">{{ _('This will allow "%(client_name)s" to:', client_name=grant.client.client_name) }}</p>
|
||||
|
||||
<ul>
|
||||
{% for key, value in scopes.items() %}
|
||||
<li>{{ value }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
|
||||
<form action="" method="POST">
|
||||
{{ form.hidden_tag() }}
|
||||
|
||||
<div class="d-flex flex-column mt-5">
|
||||
<div class="mx-auto">
|
||||
{{ render_field(form.allow, class="btn btn-success mx-auto") }}
|
||||
</div>
|
||||
<div class="mx-auto">
|
||||
{{ render_field(form.deny, class="btn btn-light mx-auto") }}
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
{% endblock %}
|
||||
@ -4,7 +4,8 @@
|
||||
{% block content %}
|
||||
|
||||
<h1>{{ _fsdomain('Login') }}</h1>
|
||||
<form action="{{ url_for_security('login', next='manage') }}" method="POST" name="login_user_form">
|
||||
{% set next = request.args['next'] if 'next' in request.args and 'authorize' in request.args['next'] else 'manage' %}
|
||||
<form action="{{ url_for_security('login', next=next) }}" method="POST" name="login_user_form">
|
||||
{{ login_user_form.hidden_tag() }}
|
||||
{{ render_field_with_errors(login_user_form.email) }}
|
||||
{{ render_field_with_errors(login_user_form.password) }}
|
||||
|
||||
@ -11,9 +11,20 @@ def get_localized_enum_name(enum):
|
||||
return lazy_gettext(enum.__class__.__name__ + "." + enum.name)
|
||||
|
||||
|
||||
def get_localized_scope(scope: str) -> str:
|
||||
type_name, action = scope.split(":")
|
||||
loc_lazy_gettext = lazy_gettext(type_name.capitalize())
|
||||
loc_action = lazy_gettext(action)
|
||||
return f"{loc_lazy_gettext} ({loc_action})"
|
||||
|
||||
|
||||
def make_dir(path):
|
||||
try:
|
||||
original_umask = os.umask(0)
|
||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||
finally:
|
||||
os.umask(original_umask)
|
||||
|
||||
|
||||
def split_by_crlf(s):
|
||||
return [v for v in s.splitlines() if v]
|
||||
|
||||
53
project/views/oauth.py
Normal file
53
project/views/oauth.py
Normal file
@ -0,0 +1,53 @@
|
||||
from authlib.oauth2 import OAuth2Error
|
||||
from project import app
|
||||
from project.api import scopes
|
||||
from flask_security import current_user
|
||||
from flask import redirect, request, url_for, render_template
|
||||
from project.forms.security import AuthorizeForm
|
||||
from project.oauth2 import authorization
|
||||
|
||||
|
||||
@app.route("/oauth/authorize", methods=["GET", "POST"])
|
||||
def authorize():
|
||||
user = current_user
|
||||
|
||||
if not user or not user.is_authenticated:
|
||||
return redirect(url_for("security.login", next=request.url))
|
||||
|
||||
form = AuthorizeForm()
|
||||
|
||||
if form.validate_on_submit():
|
||||
grant_user = user if form.allow.data else None
|
||||
return authorization.create_authorization_response(grant_user=grant_user)
|
||||
else:
|
||||
try:
|
||||
grant = authorization.validate_consent_request(end_user=user)
|
||||
except OAuth2Error as error:
|
||||
return error.error
|
||||
|
||||
grant_scopes = grant.request.scope.split(" ")
|
||||
filtered_scopes = {k: scopes[k] for k in grant_scopes}
|
||||
return render_template(
|
||||
"security/authorize.html",
|
||||
form=form,
|
||||
scopes=filtered_scopes,
|
||||
user=user,
|
||||
grant=grant,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/oauth/token", methods=["POST"])
|
||||
def issue_token():
|
||||
return authorization.create_token_response()
|
||||
|
||||
|
||||
@app.route("/oauth/revoke", methods=["POST"])
|
||||
def revoke_token():
|
||||
return authorization.create_endpoint_response("revocation")
|
||||
|
||||
|
||||
@app.route("/oauth2-redirect.html")
|
||||
def swagger_oauth2_redirect():
|
||||
return redirect(
|
||||
url_for("flask-apispec.static", filename="oauth2-redirect.html", **request.args)
|
||||
)
|
||||
125
project/views/oauth2_client.py
Normal file
125
project/views/oauth2_client.py
Normal file
@ -0,0 +1,125 @@
|
||||
from project import app, db
|
||||
from flask import render_template, redirect, flash, url_for
|
||||
from flask_babelex import gettext
|
||||
from flask_security import permissions_required, current_user
|
||||
from project.models import OAuth2Client
|
||||
from project.views.utils import (
|
||||
get_pagination_urls,
|
||||
handleSqlError,
|
||||
flash_errors,
|
||||
non_match_for_deletion,
|
||||
)
|
||||
from project.forms.oauth2_client import (
|
||||
CreateOAuth2ClientForm,
|
||||
UpdateOAuth2ClientForm,
|
||||
DeleteOAuth2ClientForm,
|
||||
)
|
||||
from project.services.oauth2_client import complete_oauth2_client
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from project.access import owner_access_or_401
|
||||
|
||||
|
||||
@app.route("/oauth2_client/create", methods=("GET", "POST"))
|
||||
@permissions_required("oauth2_client:create")
|
||||
def oauth2_client_create():
|
||||
form = CreateOAuth2ClientForm()
|
||||
|
||||
if form.validate_on_submit():
|
||||
oauth2_client = OAuth2Client()
|
||||
form.populate_obj(oauth2_client)
|
||||
oauth2_client.user_id = current_user.id
|
||||
complete_oauth2_client(oauth2_client)
|
||||
|
||||
try:
|
||||
db.session.add(oauth2_client)
|
||||
db.session.commit()
|
||||
flash(gettext("OAuth2 client successfully created"), "success")
|
||||
return redirect(url_for("oauth2_client", id=oauth2_client.id))
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(handleSqlError(e), "danger")
|
||||
else:
|
||||
flash_errors(form)
|
||||
|
||||
return render_template("oauth2_client/create.html", form=form)
|
||||
|
||||
|
||||
@app.route("/oauth2_client/<int:id>/update", methods=("GET", "POST"))
|
||||
@permissions_required("oauth2_client:update")
|
||||
def oauth2_client_update(id):
|
||||
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||
owner_access_or_401(oauth2_client.user_id)
|
||||
|
||||
form = UpdateOAuth2ClientForm(obj=oauth2_client)
|
||||
|
||||
if form.validate_on_submit():
|
||||
form.populate_obj(oauth2_client)
|
||||
complete_oauth2_client(oauth2_client)
|
||||
|
||||
try:
|
||||
db.session.commit()
|
||||
flash(gettext("OAuth2 client successfully updated"), "success")
|
||||
return redirect(url_for("oauth2_client", id=oauth2_client.id))
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(handleSqlError(e), "danger")
|
||||
else:
|
||||
flash_errors(form)
|
||||
|
||||
return render_template(
|
||||
"oauth2_client/update.html", form=form, oauth2_client=oauth2_client
|
||||
)
|
||||
|
||||
|
||||
@app.route("/oauth2_client/<int:id>/delete", methods=("GET", "POST"))
|
||||
@permissions_required("oauth2_client:delete")
|
||||
def oauth2_client_delete(id):
|
||||
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||
owner_access_or_401(oauth2_client.user_id)
|
||||
|
||||
form = DeleteOAuth2ClientForm()
|
||||
|
||||
if form.validate_on_submit():
|
||||
if non_match_for_deletion(form.name.data, oauth2_client.client_name):
|
||||
flash(gettext("Entered name does not match OAuth2 client name"), "danger")
|
||||
else:
|
||||
try:
|
||||
db.session.delete(oauth2_client)
|
||||
db.session.commit()
|
||||
flash(gettext("OAuth2 client successfully deleted"), "success")
|
||||
return redirect(url_for("oauth2_clients"))
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(handleSqlError(e), "danger")
|
||||
else:
|
||||
flash_errors(form)
|
||||
|
||||
return render_template(
|
||||
"oauth2_client/delete.html", form=form, oauth2_client=oauth2_client
|
||||
)
|
||||
|
||||
|
||||
@app.route("/oauth2_client/<int:id>")
|
||||
@permissions_required("oauth2_client:read")
|
||||
def oauth2_client(id):
|
||||
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||
owner_access_or_401(oauth2_client.user_id)
|
||||
|
||||
return render_template(
|
||||
"oauth2_client/read.html",
|
||||
oauth2_client=oauth2_client,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/oauth2_clients")
|
||||
@permissions_required("oauth2_client:read")
|
||||
def oauth2_clients():
|
||||
oauth2_clients = OAuth2Client.query.filter(
|
||||
OAuth2Client.user_id == current_user.id
|
||||
).paginate()
|
||||
|
||||
return render_template(
|
||||
"oauth2_client/list.html",
|
||||
oauth2_clients=oauth2_clients.items,
|
||||
pagination=get_pagination_urls(oauth2_clients),
|
||||
)
|
||||
53
project/views/oauth2_token.py
Normal file
53
project/views/oauth2_token.py
Normal file
@ -0,0 +1,53 @@
|
||||
from project import app, db
|
||||
from flask import render_template, redirect, flash, url_for
|
||||
from flask_babelex import gettext
|
||||
from flask_security import current_user
|
||||
from project.models import OAuth2Token
|
||||
from project.views.utils import (
|
||||
get_pagination_urls,
|
||||
handleSqlError,
|
||||
flash_errors,
|
||||
)
|
||||
from project.forms.oauth2_token import RevokeOAuth2TokenForm
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from project.access import owner_access_or_401
|
||||
|
||||
|
||||
@app.route("/oauth2_token/<int:id>/revoke", methods=("GET", "POST"))
|
||||
def oauth2_token_revoke(id):
|
||||
oauth2_token = OAuth2Token.query.get_or_404(id)
|
||||
owner_access_or_401(oauth2_token.user_id)
|
||||
|
||||
if oauth2_token.revoked:
|
||||
return redirect(url_for("oauth2_tokens"))
|
||||
|
||||
form = RevokeOAuth2TokenForm()
|
||||
|
||||
if form.validate_on_submit():
|
||||
try:
|
||||
oauth2_token.revoked = True
|
||||
db.session.commit()
|
||||
flash(gettext("OAuth2 token successfully revoked"), "success")
|
||||
return redirect(url_for("oauth2_tokens"))
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
flash(handleSqlError(e), "danger")
|
||||
else:
|
||||
flash_errors(form)
|
||||
|
||||
return render_template(
|
||||
"oauth2_token/revoke.html", form=form, oauth2_token=oauth2_token
|
||||
)
|
||||
|
||||
|
||||
@app.route("/oauth2_tokens")
|
||||
def oauth2_tokens():
|
||||
oauth2_tokens = OAuth2Token.query.filter(
|
||||
OAuth2Token.user_id == current_user.id
|
||||
).paginate()
|
||||
|
||||
return render_template(
|
||||
"oauth2_token/list.html",
|
||||
oauth2_tokens=oauth2_tokens.items,
|
||||
pagination=get_pagination_urls(oauth2_tokens),
|
||||
)
|
||||
@ -5,6 +5,7 @@ apispec-webframeworks==0.5.2
|
||||
appdirs==1.4.4
|
||||
argh==0.26.2
|
||||
attrs==20.3.0
|
||||
Authlib==0.15.3
|
||||
Babel==2.9.0
|
||||
bcrypt==3.2.0
|
||||
beautifulsoup4==4.9.3
|
||||
@ -18,6 +19,7 @@ click==7.1.2
|
||||
colour==0.1.5
|
||||
coverage==5.3
|
||||
coveralls==2.2.0
|
||||
cryptography==3.3.1
|
||||
distlib==0.3.1
|
||||
dnspython==2.0.0
|
||||
docopt==0.6.2
|
||||
|
||||
55
tests/api/test___init__.py
Normal file
55
tests/api/test___init__.py
Normal file
@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
from project.api import RestApi
|
||||
|
||||
|
||||
class Psycog2Error(object):
|
||||
def __init__(self, pgcode):
|
||||
self.pgcode = pgcode
|
||||
|
||||
|
||||
def test_handle_error_unique(app):
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
||||
|
||||
orig = Psycog2Error(UNIQUE_VIOLATION)
|
||||
error = IntegrityError("Select", list(), orig)
|
||||
|
||||
api = RestApi(app)
|
||||
(data, code) = api.handle_error(error)
|
||||
assert code == 400
|
||||
assert data["name"] == "Unique Violation"
|
||||
|
||||
|
||||
def test_handle_error_httpException(app):
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
error = InternalServerError()
|
||||
|
||||
api = RestApi(app)
|
||||
(data, code) = api.handle_error(error)
|
||||
assert code == 500
|
||||
|
||||
|
||||
def test_handle_error_unprocessableEntity(app):
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
from marshmallow import ValidationError
|
||||
|
||||
args = {"name": ["Required"]}
|
||||
validation_error = ValidationError(args)
|
||||
|
||||
error = UnprocessableEntity()
|
||||
error.exc = validation_error
|
||||
|
||||
api = RestApi(app)
|
||||
(data, code) = api.handle_error(error)
|
||||
assert code == 422
|
||||
assert data["errors"][0]["field"] == "name"
|
||||
assert data["errors"][0]["message"] == "Required"
|
||||
|
||||
|
||||
def test_handle_error_unspecificRaises(app):
|
||||
error = Exception()
|
||||
api = RestApi(app)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
api.handle_error(error)
|
||||
47
tests/api/test_fields.py
Normal file
47
tests/api/test_fields.py
Normal file
@ -0,0 +1,47 @@
|
||||
def test_numeric_str_serialize(client, seeder, utils):
|
||||
from project.api.location.schemas import LocationSchema
|
||||
from project.models import Location
|
||||
|
||||
location = Location()
|
||||
location.street = "Markt 7"
|
||||
location.postalCode = "38640"
|
||||
location.city = "Goslar"
|
||||
location.latitude = 51.9077888
|
||||
location.longitude = 10.4333312
|
||||
|
||||
schema = LocationSchema()
|
||||
data = schema.dump(location)
|
||||
|
||||
assert data["latitude"] == "51.9077888"
|
||||
assert data["longitude"] == "10.4333312"
|
||||
|
||||
|
||||
def test_numeric_str_deserialize(client, seeder, utils):
|
||||
from project.api.location.schemas import LocationPostRequestLoadSchema
|
||||
|
||||
data = {
|
||||
"latitude": "51.9077888",
|
||||
"longitude": "10.4333312",
|
||||
}
|
||||
|
||||
schema = LocationPostRequestLoadSchema()
|
||||
location = schema.load(data)
|
||||
|
||||
assert location.latitude == 51.9077888
|
||||
assert location.longitude == 10.4333312
|
||||
|
||||
|
||||
def test_numeric_str_deserialize_invalid(client, seeder, utils):
|
||||
from project.api.location.schemas import LocationPostRequestLoadSchema
|
||||
import pytest
|
||||
from marshmallow import ValidationError
|
||||
|
||||
data = {
|
||||
"latitude": "Quatsch",
|
||||
"longitude": "Quatsch",
|
||||
}
|
||||
|
||||
schema = LocationPostRequestLoadSchema()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
schema.load(data)
|
||||
@ -1,6 +0,0 @@
|
||||
def test_read(client, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_base()
|
||||
image_id = seeder.upsert_default_image()
|
||||
|
||||
url = utils.get_url("api_v1_image", id=image_id)
|
||||
utils.get_ok(url)
|
||||
@ -1,20 +0,0 @@
|
||||
def test_read(client, app, db, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_base()
|
||||
|
||||
with app.app_context():
|
||||
from project.models import Location
|
||||
|
||||
location = Location()
|
||||
location.street = "Markt 7"
|
||||
location.postalCode = "38640"
|
||||
location.city = "Goslar"
|
||||
location.latitude = 51.9077888
|
||||
location.longitude = 10.4333312
|
||||
|
||||
db.session.add(location)
|
||||
db.session.commit()
|
||||
location_id = location.id
|
||||
|
||||
url = utils.get_url("api_v1_location", id=location_id)
|
||||
response = utils.get_ok(url)
|
||||
assert response.json["latitude"] == "51.9077888000000000"
|
||||
@ -46,6 +46,25 @@ def test_places(client, seeder, utils):
|
||||
utils.get_ok(url)
|
||||
|
||||
|
||||
def test_places_post(client, seeder, utils, app):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
|
||||
url = utils.get_url("api_v1_organization_place_list", id=admin_unit_id, name="crew")
|
||||
response = utils.post_json(url, {"name": "Neuer Ort"})
|
||||
utils.assert_response_created(response)
|
||||
assert "id" in response.json
|
||||
|
||||
with app.app_context():
|
||||
from project.models import EventPlace
|
||||
|
||||
place = (
|
||||
EventPlace.query.filter(EventPlace.admin_unit_id == admin_unit_id)
|
||||
.filter(EventPlace.name == "Neuer Ort")
|
||||
.first()
|
||||
)
|
||||
assert place is not None
|
||||
|
||||
|
||||
def test_references_incoming(client, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_base()
|
||||
(
|
||||
|
||||
@ -4,3 +4,65 @@ def test_read(client, app, db, seeder, utils):
|
||||
|
||||
url = utils.get_url("api_v1_place", id=place_id)
|
||||
utils.get_ok(url)
|
||||
|
||||
|
||||
def test_put(client, seeder, utils, app):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||
|
||||
url = utils.get_url("api_v1_place", id=place_id)
|
||||
response = utils.put_json(url, {"name": "Neuer Name"})
|
||||
utils.assert_response_no_content(response)
|
||||
|
||||
with app.app_context():
|
||||
from project.models import EventPlace
|
||||
|
||||
place = EventPlace.query.get(place_id)
|
||||
assert place.name == "Neuer Name"
|
||||
|
||||
|
||||
def test_put_nonActiveReturnsUnauthorized(client, seeder, db, utils, app):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||
|
||||
with app.app_context():
|
||||
from project.models import User
|
||||
|
||||
user = User.query.get(user_id)
|
||||
user.active = False
|
||||
db.session.commit()
|
||||
|
||||
url = utils.get_url("api_v1_place", id=place_id)
|
||||
response = utils.put_json(url, {"name": "Neuer Name"})
|
||||
utils.assert_response_unauthorized(response)
|
||||
|
||||
|
||||
def test_patch(client, seeder, utils, app):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||
|
||||
url = utils.get_url("api_v1_place", id=place_id)
|
||||
response = utils.patch_json(url, {"description": "Klasse"})
|
||||
utils.assert_response_no_content(response)
|
||||
|
||||
with app.app_context():
|
||||
from project.models import EventPlace
|
||||
|
||||
place = EventPlace.query.get(place_id)
|
||||
assert place.name == "Meine Crew"
|
||||
assert place.description == "Klasse"
|
||||
|
||||
|
||||
def test_delete(client, seeder, utils, app):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||
|
||||
url = utils.get_url("api_v1_place", id=place_id)
|
||||
response = utils.delete(url)
|
||||
utils.assert_response_no_content(response)
|
||||
|
||||
with app.app_context():
|
||||
from project.models import EventPlace
|
||||
|
||||
place = EventPlace.query.get(place_id)
|
||||
assert place is None
|
||||
|
||||
@ -8,6 +8,7 @@ def pytest_generate_tests(metafunc):
|
||||
os.environ["DATABASE_URL"] = os.environ.get(
|
||||
"TEST_DATABASE_URL", "postgresql://postgres@localhost/gsevpt_tests"
|
||||
)
|
||||
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -4,9 +4,10 @@ class Seeder(object):
|
||||
self._db = db
|
||||
self._utils = utils
|
||||
|
||||
def setup_base(self, admin=False):
|
||||
def setup_base(self, admin=False, log_in=True):
|
||||
user_id = self.create_user(admin=admin)
|
||||
self._utils.login()
|
||||
if log_in:
|
||||
self._utils.login()
|
||||
admin_unit_id = self.create_admin_unit(user_id)
|
||||
return (user_id, admin_unit_id)
|
||||
|
||||
@ -126,6 +127,45 @@ class Seeder(object):
|
||||
|
||||
return organizer_id
|
||||
|
||||
def insert_default_oauth2_client(self, user_id):
|
||||
from project.api import scope_list
|
||||
from project.models import OAuth2Client
|
||||
from project.services.oauth2_client import complete_oauth2_client
|
||||
|
||||
with self._app.app_context():
|
||||
client = OAuth2Client()
|
||||
client.user_id = user_id
|
||||
complete_oauth2_client(client)
|
||||
|
||||
metadata = dict()
|
||||
metadata["client_name"] = "Mein Client"
|
||||
metadata["scope"] = " ".join(scope_list)
|
||||
metadata["grant_types"] = ["authorization_code", "refresh_token"]
|
||||
metadata["response_types"] = ["code"]
|
||||
metadata["token_endpoint_auth_method"] = "client_secret_post"
|
||||
client.set_client_metadata(metadata)
|
||||
|
||||
self._db.session.add(client)
|
||||
self._db.session.commit()
|
||||
client_id = client.id
|
||||
|
||||
return client_id
|
||||
|
||||
def setup_api_access(self):
|
||||
user_id, admin_unit_id = self.setup_base(admin=True)
|
||||
oauth2_client_id = self.insert_default_oauth2_client(user_id)
|
||||
|
||||
with self._app.app_context():
|
||||
from project.models import OAuth2Client
|
||||
|
||||
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||
client_id = oauth2_client.client_id
|
||||
client_secret = oauth2_client.client_secret
|
||||
scope = oauth2_client.scope
|
||||
|
||||
self._utils.authorize(client_id, client_secret, scope)
|
||||
return (user_id, admin_unit_id)
|
||||
|
||||
def create_event(self, admin_unit_id, recurrence_rule=None):
|
||||
from project.models import Event
|
||||
from project.services.event import insert_event, upsert_event_category
|
||||
|
||||
@ -21,3 +21,16 @@ def test_event_category(client, app, db, seeder):
|
||||
db.session.commit()
|
||||
|
||||
assert event.category is None
|
||||
|
||||
|
||||
def test_oauth2_token(client, app):
|
||||
from project.models import OAuth2Token
|
||||
|
||||
token = OAuth2Token()
|
||||
token.revoked = True
|
||||
assert not token.is_refresh_token_active()
|
||||
|
||||
token.revoked = False
|
||||
token.issued_at = 0
|
||||
token.expires_in = 0
|
||||
assert not token.is_refresh_token_active()
|
||||
|
||||
7
tests/test_oauth2.py
Normal file
7
tests/test_oauth2.py
Normal file
@ -0,0 +1,7 @@
|
||||
def test_authorization_code(seeder):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
|
||||
|
||||
def test_refresh_token(seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
utils.refresh_token()
|
||||
161
tests/utils.py
161
tests/utils.py
@ -2,12 +2,17 @@ import re
|
||||
from flask import g, url_for
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from bs4 import BeautifulSoup
|
||||
from urllib.parse import urlsplit, parse_qs
|
||||
|
||||
|
||||
class UtilActions(object):
|
||||
def __init__(self, client, app):
|
||||
self._client = client
|
||||
self._app = app
|
||||
self._access_token = None
|
||||
self._refresh_token = None
|
||||
self._client_id = None
|
||||
self._client_secret = None
|
||||
|
||||
def register(self, email="test@test.de", password="MeinPasswortIstDasBeste"):
|
||||
response = self._client.get("/register")
|
||||
@ -77,14 +82,56 @@ class UtilActions(object):
|
||||
form = Form(soup.find("form"))
|
||||
return form.fill(values)
|
||||
|
||||
def post_form(self, url, response, values: dict):
|
||||
data = self.create_form_data(response, values)
|
||||
def post_form_data(self, url, data: dict):
|
||||
return self._client.post(url, data=data)
|
||||
|
||||
def post_form(self, url, response, values: dict):
|
||||
data = self.create_form_data(response, values)
|
||||
return self.post_form_data(url, data=data)
|
||||
|
||||
def get_headers(self):
|
||||
headers = dict()
|
||||
|
||||
if self._access_token:
|
||||
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||
|
||||
return headers
|
||||
|
||||
def log_request(self, url):
|
||||
print(url)
|
||||
|
||||
def log_json_request(self, url, data: dict):
|
||||
self.log_request(url)
|
||||
print(data)
|
||||
|
||||
def log_response(self, response):
|
||||
print(response.status_code)
|
||||
print(response.data)
|
||||
print(response.json)
|
||||
|
||||
def post_json(self, url, data: dict):
|
||||
response = self._client.post(url, json=data)
|
||||
assert response.content_type == "application/json"
|
||||
return response.json
|
||||
self.log_json_request(url, data)
|
||||
response = self._client.post(url, json=data, headers=self.get_headers())
|
||||
self.log_response(response)
|
||||
return response
|
||||
|
||||
def put_json(self, url, data: dict):
|
||||
self.log_json_request(url, data)
|
||||
response = self._client.put(url, json=data, headers=self.get_headers())
|
||||
self.log_response(response)
|
||||
return response
|
||||
|
||||
def patch_json(self, url, data: dict):
|
||||
self.log_json_request(url, data)
|
||||
response = self._client.patch(url, json=data, headers=self.get_headers())
|
||||
self.log_response(response)
|
||||
return response
|
||||
|
||||
def delete(self, url):
|
||||
self.log_request(url)
|
||||
response = self._client.delete(url, headers=self.get_headers())
|
||||
self.log_response(response)
|
||||
return response
|
||||
|
||||
def mock_db_commit(self, mocker, orig=None):
|
||||
mocked_commit = mocker.patch("project.db.session.commit")
|
||||
@ -105,14 +152,23 @@ class UtilActions(object):
|
||||
url = url_for(endpoint, **values, _external=False)
|
||||
return url
|
||||
|
||||
def get(self, url):
|
||||
return self._client.get(url)
|
||||
|
||||
def get_ok(self, url):
|
||||
response = self._client.get(url)
|
||||
response = self.get(url)
|
||||
self.assert_response_ok(response)
|
||||
return response
|
||||
|
||||
def assert_response_ok(self, response):
|
||||
assert response.status_code == 200
|
||||
|
||||
def assert_response_created(self, response):
|
||||
assert response.status_code == 201
|
||||
|
||||
def assert_response_no_content(self, response):
|
||||
assert response.status_code == 204
|
||||
|
||||
def get_unauthorized(self, url):
|
||||
response = self._client.get(url)
|
||||
self.assert_response_unauthorized(response)
|
||||
@ -146,3 +202,96 @@ class UtilActions(object):
|
||||
|
||||
def assert_response_permission_missing(self, response, endpoint, **values):
|
||||
self.assert_response_redirect(response, endpoint, **values)
|
||||
|
||||
def parse_query_parameters(self, url):
|
||||
query = urlsplit(url).query
|
||||
params = parse_qs(query)
|
||||
return {k: v[0] for k, v in params.items()}
|
||||
|
||||
def authorize(self, client_id, client_secret, scope):
|
||||
# Authorize-Seite öffnen
|
||||
redirect_uri = self.get_url("swagger_oauth2_redirect")
|
||||
url = self.get_url(
|
||||
"authorize",
|
||||
response_type="code",
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
response = self.get_ok(url)
|
||||
|
||||
# Authorisieren
|
||||
response = self.post_form(
|
||||
url,
|
||||
response,
|
||||
{},
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert redirect_uri in response.headers["Location"]
|
||||
|
||||
# Code aus der Redirect-Antwort lesen
|
||||
params = self.parse_query_parameters(response.headers["Location"])
|
||||
assert "code" in params
|
||||
code = params["code"]
|
||||
|
||||
# Mit dem Code den Access-Token abfragen
|
||||
token_url = self.get_url("issue_token")
|
||||
response = self.post_form_data(
|
||||
token_url,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": "authorization_code",
|
||||
"scope": scope,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
self.assert_response_ok(response)
|
||||
assert response.content_type == "application/json"
|
||||
assert "access_token" in response.json
|
||||
assert "expires_in" in response.json
|
||||
assert "refresh_token" in response.json
|
||||
assert response.json["scope"] == scope
|
||||
assert response.json["token_type"] == "Bearer"
|
||||
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
self._access_token = response.json["access_token"]
|
||||
self._refresh_token = response.json["refresh_token"]
|
||||
|
||||
def refresh_token(self):
|
||||
token_url = self.get_url("issue_token")
|
||||
response = self.post_form_data(
|
||||
token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self._refresh_token,
|
||||
"client_id": self._client_id,
|
||||
"client_secret": self._client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
self.assert_response_ok(response)
|
||||
assert response.content_type == "application/json"
|
||||
assert response.json["token_type"] == "Bearer"
|
||||
assert "access_token" in response.json
|
||||
assert "expires_in" in response.json
|
||||
|
||||
self._access_token = response.json["access_token"]
|
||||
|
||||
def revoke_token(self):
|
||||
url = self.get_url("revoke_token")
|
||||
response = self.post_form_data(
|
||||
url,
|
||||
data={
|
||||
"token": self._access_token,
|
||||
"token_type_hint": "access_token",
|
||||
"client_id": self._client_id,
|
||||
"client_secret": self._client_secret,
|
||||
},
|
||||
)
|
||||
|
||||
self.assert_response_ok(response)
|
||||
|
||||
@ -430,7 +430,7 @@ def test_delete_nameDoesNotMatch(client, seeder, utils, app, mocker):
|
||||
|
||||
def test_rrule(client, seeder, utils, app):
|
||||
url = utils.get_url("event_rrule")
|
||||
json = utils.post_json(
|
||||
response = utils.post_json(
|
||||
url,
|
||||
{
|
||||
"year": 2020,
|
||||
@ -440,6 +440,7 @@ def test_rrule(client, seeder, utils, app):
|
||||
"start": 0,
|
||||
},
|
||||
)
|
||||
json = response.json
|
||||
|
||||
assert json["batch"]["batch_size"] == 10
|
||||
|
||||
|
||||
25
tests/views/test_oauth.py
Normal file
25
tests/views/test_oauth.py
Normal file
@ -0,0 +1,25 @@
|
||||
def test_authorize_unauthorizedRedirects(seeder, utils):
|
||||
url = utils.get_url("authorize")
|
||||
response = utils.get(url)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "login" in response.headers["Location"]
|
||||
|
||||
|
||||
def test_authorize_validateThrowsError(seeder, utils):
|
||||
seeder.setup_base()
|
||||
url = utils.get_url("authorize")
|
||||
response = utils.get(url)
|
||||
|
||||
utils.assert_response_error_message(response, b"invalid_grant")
|
||||
|
||||
|
||||
def test_revoke_token(seeder, utils):
|
||||
seeder.setup_api_access()
|
||||
utils.revoke_token()
|
||||
|
||||
|
||||
def test_swagger_redirect(utils):
|
||||
url = utils.get_url("swagger_oauth2_redirect")
|
||||
response = utils.get(url)
|
||||
assert response.status_code == 302
|
||||
136
tests/views/test_oauth2_client.py
Normal file
136
tests/views/test_oauth2_client.py
Normal file
@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def test_read(client, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_base(True)
|
||||
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||
|
||||
url = utils.get_url("oauth2_client", id=oauth2_client_id)
|
||||
utils.get_ok(url)
|
||||
|
||||
|
||||
def test_read_notOwner(client, seeder, utils):
|
||||
user_id = seeder.create_user(email="other@other.de", admin=True)
|
||||
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||
|
||||
seeder.setup_base(True)
|
||||
url = utils.get_url("oauth2_client", id=oauth2_client_id)
|
||||
utils.get_unauthorized(url)
|
||||
|
||||
|
||||
def test_list(client, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_base(True)
|
||||
|
||||
url = utils.get_url("oauth2_clients")
|
||||
utils.get_ok(url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("db_error", [True, False])
|
||||
def test_create_authorization_code(client, app, utils, seeder, mocker, db_error):
|
||||
from project.api import scope_list
|
||||
|
||||
user_id, admin_unit_id = seeder.setup_base(True)
|
||||
|
||||
url = utils.get_url("oauth2_client_create")
|
||||
response = utils.get_ok(url)
|
||||
|
||||
if db_error:
|
||||
utils.mock_db_commit(mocker)
|
||||
|
||||
response = utils.post_form(
|
||||
url,
|
||||
response,
|
||||
{
|
||||
"client_name": "Mein Client",
|
||||
"scope": scope_list,
|
||||
},
|
||||
)
|
||||
|
||||
if db_error:
|
||||
utils.assert_response_db_error(response)
|
||||
return
|
||||
|
||||
with app.app_context():
|
||||
from project.models import OAuth2Client
|
||||
|
||||
oauth2_client = OAuth2Client.query.filter(
|
||||
OAuth2Client.user_id == user_id
|
||||
).first()
|
||||
assert oauth2_client is not None
|
||||
client_id = oauth2_client.id
|
||||
|
||||
utils.assert_response_redirect(response, "oauth2_client", id=client_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("db_error", [True, False])
|
||||
def test_update(client, seeder, utils, app, mocker, db_error):
|
||||
user_id, admin_unit_id = seeder.setup_base(True)
|
||||
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||
|
||||
url = utils.get_url("oauth2_client_update", id=oauth2_client_id)
|
||||
response = utils.get_ok(url)
|
||||
|
||||
if db_error:
|
||||
utils.mock_db_commit(mocker)
|
||||
|
||||
response = utils.post_form(
|
||||
url,
|
||||
response,
|
||||
{
|
||||
"client_name": "Neuer Name",
|
||||
},
|
||||
)
|
||||
|
||||
if db_error:
|
||||
utils.assert_response_db_error(response)
|
||||
return
|
||||
|
||||
utils.assert_response_redirect(response, "oauth2_client", id=oauth2_client_id)
|
||||
|
||||
with app.app_context():
|
||||
from project.models import OAuth2Client
|
||||
|
||||
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||
assert oauth2_client.client_name == "Neuer Name"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("db_error", [True, False])
|
||||
@pytest.mark.parametrize("non_match", [True, False])
|
||||
def test_delete(client, seeder, utils, app, mocker, db_error, non_match):
|
||||
user_id, admin_unit_id = seeder.setup_base(True)
|
||||
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||
|
||||
url = utils.get_url("oauth2_client_delete", id=oauth2_client_id)
|
||||
response = utils.get_ok(url)
|
||||
|
||||
if db_error:
|
||||
utils.mock_db_commit(mocker)
|
||||
|
||||
form_name = "Mein Client"
|
||||
|
||||
if non_match:
|
||||
form_name = "Falscher Name"
|
||||
|
||||
response = utils.post_form(
|
||||
url,
|
||||
response,
|
||||
{
|
||||
"name": form_name,
|
||||
},
|
||||
)
|
||||
|
||||
if non_match:
|
||||
utils.assert_response_error_message(response)
|
||||
return
|
||||
|
||||
if db_error:
|
||||
utils.assert_response_db_error(response)
|
||||
return
|
||||
|
||||
utils.assert_response_redirect(response, "oauth2_clients")
|
||||
|
||||
with app.app_context():
|
||||
from project.models import OAuth2Client
|
||||
|
||||
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||
assert oauth2_client is None
|
||||
47
tests/views/test_oauth2_token.py
Normal file
47
tests/views/test_oauth2_token.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def test_list(client, seeder, utils):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
|
||||
url = utils.get_url("oauth2_tokens")
|
||||
utils.get_ok(url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("db_error", [True, False])
|
||||
def test_revoke(client, seeder, utils, app, mocker, db_error):
|
||||
user_id, admin_unit_id = seeder.setup_api_access()
|
||||
|
||||
with app.app_context():
|
||||
from project.models import OAuth2Token
|
||||
|
||||
oauth2_token = OAuth2Token.query.filter(OAuth2Token.user_id == user_id).first()
|
||||
oauth2_token_id = oauth2_token.id
|
||||
|
||||
url = utils.get_url("oauth2_token_revoke", id=oauth2_token_id)
|
||||
response = utils.get_ok(url)
|
||||
|
||||
if db_error:
|
||||
utils.mock_db_commit(mocker)
|
||||
|
||||
response = utils.post_form(
|
||||
url,
|
||||
response,
|
||||
{},
|
||||
)
|
||||
|
||||
if db_error:
|
||||
utils.assert_response_db_error(response)
|
||||
return
|
||||
|
||||
utils.assert_response_redirect(response, "oauth2_tokens")
|
||||
|
||||
with app.app_context():
|
||||
from project.models import OAuth2Token
|
||||
|
||||
oauth2_token = OAuth2Token.query.get(oauth2_token_id)
|
||||
assert oauth2_token.revoked
|
||||
|
||||
# Kann nicht zweimal revoked werden
|
||||
response = utils.get(url)
|
||||
utils.assert_response_redirect(response, "oauth2_tokens")
|
||||
Loading…
x
Reference in New Issue
Block a user