mirror of
https://github.com/lucaspalomodevelop/eventcally.git
synced 2026-03-13 00:07:22 +00:00
API Write Access with OAuth2 #104
This commit is contained in:
parent
bc9d2aae3c
commit
6c2384e678
@ -1,2 +0,0 @@
|
|||||||
[run]
|
|
||||||
relative_files = True
|
|
||||||
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@ -5,7 +5,7 @@
|
|||||||
"python.linting.pylintEnabled": false,
|
"python.linting.pylintEnabled": false,
|
||||||
"python.linting.flake8Enabled": true,
|
"python.linting.flake8Enabled": true,
|
||||||
"python.testing.pytestArgs": [
|
"python.testing.pytestArgs": [
|
||||||
"tests"
|
"tests", "--capture=sys"
|
||||||
],
|
],
|
||||||
"python.testing.unittestEnabled": false,
|
"python.testing.unittestEnabled": false,
|
||||||
"python.testing.nosetestsEnabled": false,
|
"python.testing.nosetestsEnabled": false,
|
||||||
|
|||||||
91
migrations/versions/ddb85cb1c21e_.py
Normal file
91
migrations/versions/ddb85cb1c21e_.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
"""empty message
|
||||||
|
|
||||||
|
Revision ID: ddb85cb1c21e
|
||||||
|
Revises: b1a6e7630185
|
||||||
|
Create Date: 2021-02-02 15:17:21.988363
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlalchemy_utils
|
||||||
|
from project import dbtypes
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "ddb85cb1c21e"
|
||||||
|
down_revision = "b1a6e7630185"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"oauth2_client",
|
||||||
|
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||||
|
sa.Column("client_secret", sa.String(length=120), nullable=True),
|
||||||
|
sa.Column("client_id_issued_at", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("client_secret_expires_at", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("client_metadata", sa.Text(), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth2_client_client_id"), "oauth2_client", ["client_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth2_code",
|
||||||
|
sa.Column("code", sa.String(length=120), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||||
|
sa.Column("redirect_uri", sa.Text(), nullable=True),
|
||||||
|
sa.Column("response_type", sa.Text(), nullable=True),
|
||||||
|
sa.Column("scope", sa.Text(), nullable=True),
|
||||||
|
sa.Column("nonce", sa.Text(), nullable=True),
|
||||||
|
sa.Column("auth_time", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.Text(), nullable=True),
|
||||||
|
sa.Column("code_challenge_method", sa.String(length=48), nullable=True),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("code"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth2_token",
|
||||||
|
sa.Column("client_id", sa.String(length=48), nullable=True),
|
||||||
|
sa.Column("token_type", sa.String(length=40), nullable=True),
|
||||||
|
sa.Column("access_token", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("refresh_token", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("scope", sa.Text(), nullable=True),
|
||||||
|
sa.Column("revoked", sa.Boolean(), nullable=True),
|
||||||
|
sa.Column("issued_at", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("expires_in", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("access_token"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth2_token_refresh_token"),
|
||||||
|
"oauth2_token",
|
||||||
|
["refresh_token"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_unique_constraint(
|
||||||
|
"eventplace_name_admin_unit_id", "eventplace", ["name", "admin_unit_id"]
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint("eventplace_name_admin_unit_id", "eventplace", type_="unique")
|
||||||
|
op.drop_index(op.f("ix_oauth2_token_refresh_token"), table_name="oauth2_token")
|
||||||
|
op.drop_table("oauth2_token")
|
||||||
|
op.drop_table("oauth2_code")
|
||||||
|
op.drop_index(op.f("ix_oauth2_client_client_id"), table_name="oauth2_client")
|
||||||
|
op.drop_table("oauth2_client")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from flask import Flask
|
from flask import Flask, url_for, redirect, request, jsonify
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from flask_security import (
|
from flask_security import (
|
||||||
Security,
|
Security,
|
||||||
@ -10,12 +10,8 @@ from flask_cors import CORS
|
|||||||
from flask_qrcode import QRcode
|
from flask_qrcode import QRcode
|
||||||
from flask_mail import Mail, email_dispatched
|
from flask_mail import Mail, email_dispatched
|
||||||
from flask_migrate import Migrate
|
from flask_migrate import Migrate
|
||||||
from flask_marshmallow import Marshmallow
|
|
||||||
from flask_restful import Api
|
|
||||||
from apispec import APISpec
|
|
||||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
|
||||||
from flask_apispec.extension import FlaskApiSpec
|
|
||||||
from flask_gzip import Gzip
|
from flask_gzip import Gzip
|
||||||
|
from webargs import flaskparser
|
||||||
|
|
||||||
# Create app
|
# Create app
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
@ -59,25 +55,6 @@ babel = Babel(app)
|
|||||||
# cors
|
# cors
|
||||||
cors = CORS(app, resources={r"/api/*", "/swagger/"})
|
cors = CORS(app, resources={r"/api/*", "/swagger/"})
|
||||||
|
|
||||||
# API
|
|
||||||
rest_api = Api(app, "/api/v1")
|
|
||||||
marshmallow = Marshmallow(app)
|
|
||||||
marshmallow_plugin = MarshmallowPlugin()
|
|
||||||
app.config.update(
|
|
||||||
{
|
|
||||||
"APISPEC_SPEC": APISpec(
|
|
||||||
title="Oveda API",
|
|
||||||
version="0.1.0",
|
|
||||||
plugins=[marshmallow_plugin],
|
|
||||||
openapi_version="2.0",
|
|
||||||
info=dict(
|
|
||||||
description="This API provides endpoints to interact with the Oveda data. At the moment, there is no authorization needed."
|
|
||||||
),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
api_docs = FlaskApiSpec(app)
|
|
||||||
|
|
||||||
# Mail
|
# Mail
|
||||||
mail_server = os.getenv("MAIL_SERVER")
|
mail_server = os.getenv("MAIL_SERVER")
|
||||||
|
|
||||||
@ -108,6 +85,9 @@ if app.config["MAIL_SUPPRESS_SEND"]:
|
|||||||
db = SQLAlchemy(app)
|
db = SQLAlchemy(app)
|
||||||
migrate = Migrate(app, db)
|
migrate = Migrate(app, db)
|
||||||
|
|
||||||
|
# API
|
||||||
|
from project.api import RestApi
|
||||||
|
|
||||||
# qr code
|
# qr code
|
||||||
QRcode(app)
|
QRcode(app)
|
||||||
|
|
||||||
@ -123,6 +103,13 @@ from project.forms.security import ExtendedRegisterForm
|
|||||||
user_datastore = SQLAlchemySessionUserDatastore(db.session, User, Role)
|
user_datastore = SQLAlchemySessionUserDatastore(db.session, User, Role)
|
||||||
security = Security(app, user_datastore, register_form=ExtendedRegisterForm)
|
security = Security(app, user_datastore, register_form=ExtendedRegisterForm)
|
||||||
|
|
||||||
|
# OAuth2
|
||||||
|
from project.oauth2 import config_oauth
|
||||||
|
|
||||||
|
config_oauth(app)
|
||||||
|
|
||||||
|
# Init misc modules
|
||||||
|
|
||||||
from project import i10n
|
from project import i10n
|
||||||
from project import jinja_filters
|
from project import jinja_filters
|
||||||
from project import init_data
|
from project import init_data
|
||||||
@ -142,6 +129,9 @@ from project.views import (
|
|||||||
image,
|
image,
|
||||||
manage,
|
manage,
|
||||||
organizer,
|
organizer,
|
||||||
|
oauth,
|
||||||
|
oauth2_client,
|
||||||
|
oauth2_token,
|
||||||
planing,
|
planing,
|
||||||
reference,
|
reference,
|
||||||
reference_request,
|
reference_request,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from flask import abort
|
from flask import abort
|
||||||
|
from flask_login import login_user
|
||||||
from flask_security import current_user
|
from flask_security import current_user
|
||||||
from flask_security.utils import FsPermNeed
|
from flask_security.utils import FsPermNeed
|
||||||
from flask_principal import Permission
|
from flask_principal import Permission
|
||||||
@ -12,6 +13,20 @@ def has_current_user_permission(permission):
|
|||||||
return user_perm.can()
|
return user_perm.can()
|
||||||
|
|
||||||
|
|
||||||
|
def has_owner_access(user_id):
|
||||||
|
return user_id == current_user.id
|
||||||
|
|
||||||
|
|
||||||
|
def owner_access_or_401(user_id):
|
||||||
|
if not has_owner_access(user_id):
|
||||||
|
abort(401)
|
||||||
|
|
||||||
|
|
||||||
|
def login_api_user_or_401(user):
|
||||||
|
if not login_user(user):
|
||||||
|
abort(401)
|
||||||
|
|
||||||
|
|
||||||
def has_admin_unit_member_role(admin_unit_member, role_name):
|
def has_admin_unit_member_role(admin_unit_member, role_name):
|
||||||
for role in admin_unit_member.roles:
|
for role in admin_unit_member.roles:
|
||||||
if role.name == role_name:
|
if role.name == role_name:
|
||||||
|
|||||||
@ -1,4 +1,107 @@
|
|||||||
from project import rest_api, api_docs
|
from flask_restful import Api
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
||||||
|
from werkzeug.exceptions import HTTPException, UnprocessableEntity
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
from project.utils import get_localized_scope
|
||||||
|
from project import app
|
||||||
|
from flask_marshmallow import Marshmallow
|
||||||
|
from apispec import APISpec
|
||||||
|
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||||
|
from flask_apispec.extension import FlaskApiSpec
|
||||||
|
|
||||||
|
|
||||||
|
class RestApi(Api):
|
||||||
|
def handle_error(self, err):
|
||||||
|
from project.api.schemas import (
|
||||||
|
ErrorResponseSchema,
|
||||||
|
UnprocessableEntityResponseSchema,
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = None
|
||||||
|
data = {}
|
||||||
|
code = 500
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(err, IntegrityError)
|
||||||
|
and err.orig
|
||||||
|
and err.orig.pgcode == UNIQUE_VIOLATION
|
||||||
|
):
|
||||||
|
data["name"] = "Unique Violation"
|
||||||
|
data[
|
||||||
|
"message"
|
||||||
|
] = "An entry with the entered values already exists. Duplicate entries are not allowed."
|
||||||
|
code = 400
|
||||||
|
schema = ErrorResponseSchema()
|
||||||
|
elif isinstance(err, HTTPException):
|
||||||
|
data["name"] = err.name
|
||||||
|
data["message"] = err.description
|
||||||
|
code = err.code
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(err, UnprocessableEntity)
|
||||||
|
and err.exc
|
||||||
|
and isinstance(err.exc, ValidationError)
|
||||||
|
):
|
||||||
|
data["name"] = err.name
|
||||||
|
data["message"] = err.description
|
||||||
|
code = err.code
|
||||||
|
schema = UnprocessableEntityResponseSchema()
|
||||||
|
|
||||||
|
if (
|
||||||
|
getattr(err.exc, "args", None)
|
||||||
|
and isinstance(err.exc.args, tuple)
|
||||||
|
and len(err.exc.args) > 0
|
||||||
|
):
|
||||||
|
arg = err.exc.args[0]
|
||||||
|
if isinstance(arg, dict):
|
||||||
|
errors = []
|
||||||
|
for field, messages in arg.items():
|
||||||
|
if isinstance(messages, list):
|
||||||
|
for message in messages:
|
||||||
|
error = {"field": field, "message": message}
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
|
if len(errors) > 0:
|
||||||
|
data["errors"] = errors
|
||||||
|
else:
|
||||||
|
schema = ErrorResponseSchema()
|
||||||
|
|
||||||
|
# Call default error handler that propagates error further
|
||||||
|
try:
|
||||||
|
super().handle_error(err)
|
||||||
|
except Exception:
|
||||||
|
if not schema:
|
||||||
|
raise
|
||||||
|
|
||||||
|
return schema.dump(data), code
|
||||||
|
|
||||||
|
|
||||||
|
scope_list = [
|
||||||
|
"organizer:write",
|
||||||
|
"place:write",
|
||||||
|
"event:write",
|
||||||
|
]
|
||||||
|
scopes = {k: get_localized_scope(k) for v, k in enumerate(scope_list)}
|
||||||
|
|
||||||
|
rest_api = RestApi(app, "/api/v1", catch_all_404s=True)
|
||||||
|
marshmallow = Marshmallow(app)
|
||||||
|
marshmallow_plugin = MarshmallowPlugin()
|
||||||
|
app.config.update(
|
||||||
|
{
|
||||||
|
"APISPEC_SPEC": APISpec(
|
||||||
|
title="Oveda API",
|
||||||
|
version="0.1.0",
|
||||||
|
plugins=[marshmallow_plugin],
|
||||||
|
openapi_version="2.0",
|
||||||
|
info=dict(
|
||||||
|
description="This API provides endpoints to interact with the Oveda data."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
api_docs = FlaskApiSpec(app)
|
||||||
|
|
||||||
|
|
||||||
def enum_to_properties(self, field, **kwargs):
|
def enum_to_properties(self, field, **kwargs):
|
||||||
@ -17,8 +120,6 @@ def add_api_resource(resource, url, endpoint):
|
|||||||
api_docs.register(resource, endpoint=endpoint)
|
api_docs.register(resource, endpoint=endpoint)
|
||||||
|
|
||||||
|
|
||||||
from project import marshmallow_plugin
|
|
||||||
|
|
||||||
marshmallow_plugin.converter.add_attribute_function(enum_to_properties)
|
marshmallow_plugin.converter.add_attribute_function(enum_to_properties)
|
||||||
|
|
||||||
import project.api.event.resources
|
import project.api.event.resources
|
||||||
@ -26,8 +127,6 @@ import project.api.event_category.resources
|
|||||||
import project.api.event_date.resources
|
import project.api.event_date.resources
|
||||||
import project.api.event_reference.resources
|
import project.api.event_reference.resources
|
||||||
import project.api.dump.resources
|
import project.api.dump.resources
|
||||||
import project.api.image.resources
|
|
||||||
import project.api.location.resources
|
|
||||||
import project.api.organization.resources
|
import project.api.organization.resources
|
||||||
import project.api.organizer.resources
|
import project.api.organizer.resources
|
||||||
import project.api.place.resources
|
import project.api.place.resources
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project.api.event.schemas import EventDumpSchema
|
from project.api.event.schemas import EventDumpSchema
|
||||||
from project.api.place.schemas import PlaceDumpSchema
|
from project.api.place.schemas import PlaceDumpSchema
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from marshmallow import fields, validate
|
from marshmallow import fields, validate
|
||||||
from marshmallow_enum import EnumField
|
from marshmallow_enum import EnumField
|
||||||
from project.models import (
|
from project.models import (
|
||||||
@ -10,7 +10,7 @@ from project.models import (
|
|||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
from project.api.organization.schemas import OrganizationRefSchema
|
from project.api.organization.schemas import OrganizationRefSchema
|
||||||
from project.api.organizer.schemas import OrganizerRefSchema
|
from project.api.organizer.schemas import OrganizerRefSchema
|
||||||
from project.api.image.schemas import ImageRefSchema
|
from project.api.image.schemas import ImageSchema
|
||||||
from project.api.place.schemas import PlaceRefSchema, PlaceSearchItemSchema
|
from project.api.place.schemas import PlaceRefSchema, PlaceSearchItemSchema
|
||||||
from project.api.event_category.schemas import (
|
from project.api.event_category.schemas import (
|
||||||
EventCategoryRefSchema,
|
EventCategoryRefSchema,
|
||||||
@ -55,7 +55,7 @@ class EventSchema(EventBaseSchema):
|
|||||||
organization = fields.Nested(OrganizationRefSchema, attribute="admin_unit")
|
organization = fields.Nested(OrganizationRefSchema, attribute="admin_unit")
|
||||||
organizer = fields.Nested(OrganizerRefSchema)
|
organizer = fields.Nested(OrganizerRefSchema)
|
||||||
place = fields.Nested(PlaceRefSchema, attribute="event_place")
|
place = fields.Nested(PlaceRefSchema, attribute="event_place")
|
||||||
photo = fields.Nested(ImageRefSchema)
|
photo = fields.Nested(ImageSchema)
|
||||||
categories = fields.List(fields.Nested(EventCategoryRefSchema))
|
categories = fields.List(fields.Nested(EventCategoryRefSchema))
|
||||||
|
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ class EventSearchItemSchema(EventRefSchema):
|
|||||||
start = marshmallow.auto_field()
|
start = marshmallow.auto_field()
|
||||||
end = marshmallow.auto_field()
|
end = marshmallow.auto_field()
|
||||||
recurrence_rule = marshmallow.auto_field()
|
recurrence_rule = marshmallow.auto_field()
|
||||||
photo = fields.Nested(ImageRefSchema)
|
photo = fields.Nested(ImageSchema)
|
||||||
place = fields.Nested(PlaceSearchItemSchema, attribute="event_place")
|
place = fields.Nested(PlaceSearchItemSchema, attribute="event_place")
|
||||||
status = EnumField(EventStatus)
|
status = EnumField(EventStatus)
|
||||||
booked_up = marshmallow.auto_field()
|
booked_up = marshmallow.auto_field()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import EventCategory
|
from project.models import EventCategory
|
||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project.models import EventDate
|
from project.models import EventDate
|
||||||
from project.api.event.schemas import (
|
from project.api.event.schemas import (
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import EventReference
|
from project.models import EventReference
|
||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
from project.api.event.schemas import EventRefSchema
|
from project.api.event.schemas import EventRefSchema
|
||||||
|
|||||||
15
project/api/fields.py
Normal file
15
project/api/fields.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from marshmallow import fields, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class NumericStr(fields.String):
|
||||||
|
def _serialize(self, value, attr, obj, **kwargs):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def _deserialize(self, value, attr, data, **kwargs):
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except ValueError as error:
|
||||||
|
raise ValidationError("Must be a numeric value.") from error
|
||||||
@ -1,15 +0,0 @@
|
|||||||
from project.api import add_api_resource
|
|
||||||
from flask_apispec import marshal_with, doc
|
|
||||||
from project.api.resources import BaseResource
|
|
||||||
from project.api.image.schemas import ImageSchema
|
|
||||||
from project.models import Image
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResource(BaseResource):
|
|
||||||
@doc(summary="Get image", tags=["Images"])
|
|
||||||
@marshal_with(ImageSchema)
|
|
||||||
def get(self, id):
|
|
||||||
return Image.query.get_or_404(id)
|
|
||||||
|
|
||||||
|
|
||||||
add_api_resource(ImageResource, "/images/<int:id>", "api_v1_image")
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import Image
|
from project.models import Image
|
||||||
|
|
||||||
|
|
||||||
@ -10,8 +10,6 @@ class ImageIdSchema(marshmallow.SQLAlchemySchema):
|
|||||||
|
|
||||||
|
|
||||||
class ImageBaseSchema(ImageIdSchema):
|
class ImageBaseSchema(ImageIdSchema):
|
||||||
created_at = marshmallow.auto_field()
|
|
||||||
updated_at = marshmallow.auto_field()
|
|
||||||
copyright_text = marshmallow.auto_field()
|
copyright_text = marshmallow.auto_field()
|
||||||
|
|
||||||
|
|
||||||
@ -27,13 +25,3 @@ class ImageSchema(ImageBaseSchema):
|
|||||||
|
|
||||||
class ImageDumpSchema(ImageBaseSchema):
|
class ImageDumpSchema(ImageBaseSchema):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ImageRefSchema(ImageIdSchema):
|
|
||||||
image_url = marshmallow.URLFor(
|
|
||||||
"image",
|
|
||||||
values=dict(id="<id>", s=500),
|
|
||||||
metadata={
|
|
||||||
"description": "Append query arguments w for width, h for height or s for size(width and height)."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,15 +0,0 @@
|
|||||||
from project.api import add_api_resource
|
|
||||||
from flask_apispec import marshal_with, doc
|
|
||||||
from project.api.resources import BaseResource
|
|
||||||
from project.api.location.schemas import LocationSchema
|
|
||||||
from project.models import Location
|
|
||||||
|
|
||||||
|
|
||||||
class LocationResource(BaseResource):
|
|
||||||
@doc(summary="Get location", tags=["Locations"])
|
|
||||||
@marshal_with(LocationSchema)
|
|
||||||
def get(self, id):
|
|
||||||
return Location.query.get_or_404(id)
|
|
||||||
|
|
||||||
|
|
||||||
add_api_resource(LocationResource, "/locations/<int:id>", "api_v1_location")
|
|
||||||
@ -1,38 +1,65 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields, validate
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import Location
|
from project.models import Location
|
||||||
|
from project.api.fields import NumericStr
|
||||||
|
|
||||||
|
|
||||||
class LocationIdSchema(marshmallow.SQLAlchemySchema):
|
class LocationIdSchema(marshmallow.SQLAlchemySchema):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Location
|
model = Location
|
||||||
|
|
||||||
id = marshmallow.auto_field()
|
|
||||||
|
|
||||||
|
|
||||||
class LocationSchema(LocationIdSchema):
|
class LocationSchema(LocationIdSchema):
|
||||||
created_at = marshmallow.auto_field()
|
|
||||||
updated_at = marshmallow.auto_field()
|
|
||||||
street = marshmallow.auto_field()
|
street = marshmallow.auto_field()
|
||||||
postalCode = marshmallow.auto_field()
|
postalCode = marshmallow.auto_field()
|
||||||
city = marshmallow.auto_field()
|
city = marshmallow.auto_field()
|
||||||
state = marshmallow.auto_field()
|
state = marshmallow.auto_field()
|
||||||
country = marshmallow.auto_field()
|
country = marshmallow.auto_field()
|
||||||
longitude = fields.Str()
|
longitude = NumericStr()
|
||||||
latitude = fields.Str()
|
latitude = NumericStr()
|
||||||
|
|
||||||
|
|
||||||
class LocationDumpSchema(LocationSchema):
|
class LocationDumpSchema(LocationSchema):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LocationRefSchema(LocationIdSchema):
|
class LocationSearchItemSchema(LocationSchema):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LocationSearchItemSchema(LocationRefSchema):
|
class LocationPostRequestSchema(marshmallow.SQLAlchemySchema):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Location
|
model = Location
|
||||||
|
|
||||||
longitude = fields.Str()
|
street = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||||
latitude = fields.Str()
|
postalCode = fields.Str(validate=validate.Length(max=10), missing=None)
|
||||||
|
city = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||||
|
state = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||||
|
country = fields.Str(validate=validate.Length(max=255), missing=None)
|
||||||
|
longitude = NumericStr(validate=validate.Range(-180, 180), missing=None)
|
||||||
|
latitude = NumericStr(validate=validate.Range(-90, 90), missing=None)
|
||||||
|
|
||||||
|
|
||||||
|
class LocationPostRequestLoadSchema(LocationPostRequestSchema):
|
||||||
|
class Meta:
|
||||||
|
model = Location
|
||||||
|
load_instance = True
|
||||||
|
|
||||||
|
|
||||||
|
class LocationPatchRequestSchema(marshmallow.SQLAlchemySchema):
|
||||||
|
class Meta:
|
||||||
|
model = Location
|
||||||
|
|
||||||
|
street = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||||
|
postalCode = fields.Str(validate=validate.Length(max=10), allow_none=True)
|
||||||
|
city = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||||
|
state = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||||
|
country = fields.Str(validate=validate.Length(max=255), allow_none=True)
|
||||||
|
longitude = NumericStr(validate=validate.Range(-180, 180), allow_none=True)
|
||||||
|
latitude = NumericStr(validate=validate.Range(-90, 90), allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LocationPatchRequestLoadSchema(LocationPatchRequestSchema):
|
||||||
|
class Meta:
|
||||||
|
model = Location
|
||||||
|
load_instance = True
|
||||||
|
|||||||
@ -27,7 +27,13 @@ from project.services.reference import (
|
|||||||
get_reference_incoming_query,
|
get_reference_incoming_query,
|
||||||
get_reference_outgoing_query,
|
get_reference_outgoing_query,
|
||||||
)
|
)
|
||||||
from project.api.place.schemas import PlaceListRequestSchema, PlaceListResponseSchema
|
from project.api.place.schemas import (
|
||||||
|
PlaceListRequestSchema,
|
||||||
|
PlaceListResponseSchema,
|
||||||
|
PlaceIdSchema,
|
||||||
|
PlacePostRequestSchema,
|
||||||
|
PlacePostRequestLoadSchema,
|
||||||
|
)
|
||||||
from project.services.event import get_event_dates_query, get_events_query
|
from project.services.event import get_event_dates_query, get_events_query
|
||||||
from project.services.event_search import EventSearchParams
|
from project.services.event_search import EventSearchParams
|
||||||
from project.services.admin_unit import (
|
from project.services.admin_unit import (
|
||||||
@ -35,6 +41,14 @@ from project.services.admin_unit import (
|
|||||||
get_organizer_query,
|
get_organizer_query,
|
||||||
get_place_query,
|
get_place_query,
|
||||||
)
|
)
|
||||||
|
from project.oauth2 import require_oauth
|
||||||
|
from authlib.integrations.flask_oauth2 import current_token
|
||||||
|
from project import db
|
||||||
|
from project.access import (
|
||||||
|
access_or_401,
|
||||||
|
get_admin_unit_for_manage_or_404,
|
||||||
|
login_api_user_or_401,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrganizationResource(BaseResource):
|
class OrganizationResource(BaseResource):
|
||||||
@ -113,6 +127,26 @@ class OrganizationPlaceListResource(BaseResource):
|
|||||||
pagination = get_place_query(admin_unit.id, name).paginate()
|
pagination = get_place_query(admin_unit.id, name).paginate()
|
||||||
return pagination
|
return pagination
|
||||||
|
|
||||||
|
@doc(
|
||||||
|
summary="Add new place",
|
||||||
|
tags=["Organizations", "Places"],
|
||||||
|
security=[{"oauth2": ["place:write"]}],
|
||||||
|
)
|
||||||
|
@use_kwargs(PlacePostRequestSchema, location="json")
|
||||||
|
@marshal_with(PlaceIdSchema, 201)
|
||||||
|
@require_oauth("place:write")
|
||||||
|
def post(self, id, **kwargs):
|
||||||
|
login_api_user_or_401(current_token.user)
|
||||||
|
admin_unit = get_admin_unit_for_manage_or_404(id)
|
||||||
|
access_or_401(admin_unit, "place:create")
|
||||||
|
|
||||||
|
place = PlacePostRequestLoadSchema().load(kwargs, session=db.session)
|
||||||
|
place.admin_unit_id = admin_unit.id
|
||||||
|
db.session.add(place)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return place, 201
|
||||||
|
|
||||||
|
|
||||||
class OrganizationIncomingEventReferenceListResource(BaseResource):
|
class OrganizationIncomingEventReferenceListResource(BaseResource):
|
||||||
@doc(
|
@doc(
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import AdminUnit
|
from project.models import AdminUnit
|
||||||
from project.api.location.schemas import LocationRefSchema
|
from project.api.location.schemas import LocationSchema
|
||||||
from project.api.image.schemas import ImageRefSchema
|
from project.api.image.schemas import ImageSchema
|
||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
|
|
||||||
|
|
||||||
@ -25,8 +25,8 @@ class OrganizationBaseSchema(OrganizationIdSchema):
|
|||||||
|
|
||||||
|
|
||||||
class OrganizationSchema(OrganizationBaseSchema):
|
class OrganizationSchema(OrganizationBaseSchema):
|
||||||
location = fields.Nested(LocationRefSchema)
|
location = fields.Nested(LocationSchema)
|
||||||
logo = fields.Nested(ImageRefSchema)
|
logo = fields.Nested(ImageSchema)
|
||||||
|
|
||||||
|
|
||||||
class OrganizationDumpSchema(OrganizationBaseSchema):
|
class OrganizationDumpSchema(OrganizationBaseSchema):
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import EventOrganizer
|
from project.models import EventOrganizer
|
||||||
from project.api.location.schemas import LocationRefSchema
|
from project.api.location.schemas import LocationSchema
|
||||||
from project.api.image.schemas import ImageRefSchema
|
from project.api.image.schemas import ImageSchema
|
||||||
from project.api.organization.schemas import OrganizationRefSchema
|
from project.api.organization.schemas import OrganizationRefSchema
|
||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
|
|
||||||
@ -25,8 +25,8 @@ class OrganizerBaseSchema(OrganizerIdSchema):
|
|||||||
|
|
||||||
|
|
||||||
class OrganizerSchema(OrganizerBaseSchema):
|
class OrganizerSchema(OrganizerBaseSchema):
|
||||||
location = fields.Nested(LocationRefSchema)
|
location = fields.Nested(LocationSchema)
|
||||||
logo = fields.Nested(ImageRefSchema)
|
logo = fields.Nested(ImageSchema)
|
||||||
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,19 @@
|
|||||||
from project.api import add_api_resource
|
from project.api import add_api_resource
|
||||||
from flask_apispec import marshal_with, doc
|
from flask import make_response
|
||||||
|
from flask_apispec import marshal_with, doc, use_kwargs
|
||||||
from project.api.resources import BaseResource
|
from project.api.resources import BaseResource
|
||||||
from project.api.place.schemas import PlaceSchema
|
from project.api.place.schemas import (
|
||||||
|
PlaceSchema,
|
||||||
|
PlacePostRequestSchema,
|
||||||
|
PlacePostRequestLoadSchema,
|
||||||
|
PlacePatchRequestSchema,
|
||||||
|
PlacePatchRequestLoadSchema,
|
||||||
|
)
|
||||||
from project.models import EventPlace
|
from project.models import EventPlace
|
||||||
|
from project.oauth2 import require_oauth
|
||||||
|
from authlib.integrations.flask_oauth2 import current_token
|
||||||
|
from project import db
|
||||||
|
from project.access import access_or_401, login_api_user_or_401
|
||||||
|
|
||||||
|
|
||||||
class PlaceResource(BaseResource):
|
class PlaceResource(BaseResource):
|
||||||
@ -11,5 +22,54 @@ class PlaceResource(BaseResource):
|
|||||||
def get(self, id):
|
def get(self, id):
|
||||||
return EventPlace.query.get_or_404(id)
|
return EventPlace.query.get_or_404(id)
|
||||||
|
|
||||||
|
@doc(
|
||||||
|
summary="Update place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
|
||||||
|
)
|
||||||
|
@use_kwargs(PlacePostRequestSchema, location="json")
|
||||||
|
@marshal_with(None, 204)
|
||||||
|
@require_oauth("place:write")
|
||||||
|
def put(self, id, **kwargs):
|
||||||
|
login_api_user_or_401(current_token.user)
|
||||||
|
place = EventPlace.query.get_or_404(id)
|
||||||
|
access_or_401(place.adminunit, "place:update")
|
||||||
|
|
||||||
|
place = PlacePostRequestLoadSchema().load(
|
||||||
|
kwargs, session=db.session, instance=place
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return make_response("", 204)
|
||||||
|
|
||||||
|
@doc(summary="Patch place", tags=["Places"], security=[{"oauth2": ["place:write"]}])
|
||||||
|
@use_kwargs(PlacePatchRequestSchema, location="json")
|
||||||
|
@marshal_with(None, 204)
|
||||||
|
@require_oauth("place:write")
|
||||||
|
def patch(self, id, **kwargs):
|
||||||
|
login_api_user_or_401(current_token.user)
|
||||||
|
place = EventPlace.query.get_or_404(id)
|
||||||
|
access_or_401(place.adminunit, "place:update")
|
||||||
|
|
||||||
|
place = PlacePatchRequestLoadSchema().load(
|
||||||
|
kwargs, session=db.session, instance=place
|
||||||
|
)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return make_response("", 204)
|
||||||
|
|
||||||
|
@doc(
|
||||||
|
summary="Delete place", tags=["Places"], security=[{"oauth2": ["place:write"]}]
|
||||||
|
)
|
||||||
|
@marshal_with(None, 204)
|
||||||
|
@require_oauth("place:write")
|
||||||
|
def delete(self, id):
|
||||||
|
login_api_user_or_401(current_token.user)
|
||||||
|
place = EventPlace.query.get_or_404(id)
|
||||||
|
access_or_401(place.adminunit, "place:delete")
|
||||||
|
|
||||||
|
db.session.delete(place)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return make_response("", 204)
|
||||||
|
|
||||||
|
|
||||||
add_api_resource(PlaceResource, "/places/<int:id>", "api_v1_place")
|
add_api_resource(PlaceResource, "/places/<int:id>", "api_v1_place")
|
||||||
|
|||||||
@ -1,8 +1,15 @@
|
|||||||
from marshmallow import fields
|
from marshmallow import fields, validate
|
||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from project.models import EventPlace
|
from project.models import EventPlace
|
||||||
from project.api.image.schemas import ImageRefSchema
|
from project.api.image.schemas import ImageSchema
|
||||||
from project.api.location.schemas import LocationRefSchema, LocationSearchItemSchema
|
from project.api.location.schemas import (
|
||||||
|
LocationSchema,
|
||||||
|
LocationSearchItemSchema,
|
||||||
|
LocationPostRequestSchema,
|
||||||
|
LocationPostRequestLoadSchema,
|
||||||
|
LocationPatchRequestSchema,
|
||||||
|
LocationPatchRequestLoadSchema,
|
||||||
|
)
|
||||||
from project.api.organization.schemas import OrganizationRefSchema
|
from project.api.organization.schemas import OrganizationRefSchema
|
||||||
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
from project.api.schemas import PaginationRequestSchema, PaginationResponseSchema
|
||||||
|
|
||||||
@ -23,8 +30,8 @@ class PlaceBaseSchema(PlaceIdSchema):
|
|||||||
|
|
||||||
|
|
||||||
class PlaceSchema(PlaceBaseSchema):
|
class PlaceSchema(PlaceBaseSchema):
|
||||||
location = fields.Nested(LocationRefSchema)
|
location = fields.Nested(LocationSchema)
|
||||||
photo = fields.Nested(ImageRefSchema)
|
photo = fields.Nested(ImageSchema)
|
||||||
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
organization = fields.Nested(OrganizationRefSchema, attribute="adminunit")
|
||||||
|
|
||||||
|
|
||||||
@ -55,3 +62,41 @@ class PlaceListResponseSchema(PaginationResponseSchema):
|
|||||||
items = fields.List(
|
items = fields.List(
|
||||||
fields.Nested(PlaceRefSchema), metadata={"description": "Places"}
|
fields.Nested(PlaceRefSchema), metadata={"description": "Places"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlacePostRequestSchema(marshmallow.SQLAlchemySchema):
|
||||||
|
class Meta:
|
||||||
|
model = EventPlace
|
||||||
|
|
||||||
|
name = fields.Str(required=True, validate=validate.Length(min=3, max=255))
|
||||||
|
url = fields.Str(validate=[validate.URL(), validate.Length(max=255)], missing=None)
|
||||||
|
description = fields.Str(missing=None)
|
||||||
|
location = fields.Nested(LocationPostRequestSchema, missing=None)
|
||||||
|
|
||||||
|
|
||||||
|
class PlacePostRequestLoadSchema(PlacePostRequestSchema):
|
||||||
|
class Meta:
|
||||||
|
model = EventPlace
|
||||||
|
load_instance = True
|
||||||
|
|
||||||
|
location = fields.Nested(LocationPostRequestLoadSchema, missing=None)
|
||||||
|
|
||||||
|
|
||||||
|
class PlacePatchRequestSchema(marshmallow.SQLAlchemySchema):
|
||||||
|
class Meta:
|
||||||
|
model = EventPlace
|
||||||
|
|
||||||
|
name = fields.Str(validate=validate.Length(min=3, max=255), allow_none=True)
|
||||||
|
url = fields.Str(
|
||||||
|
validate=[validate.URL(), validate.Length(max=255)], allow_none=True
|
||||||
|
)
|
||||||
|
description = fields.Str(allow_none=True)
|
||||||
|
location = fields.Nested(LocationPatchRequestSchema, allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PlacePatchRequestLoadSchema(PlacePatchRequestSchema):
|
||||||
|
class Meta:
|
||||||
|
model = EventPlace
|
||||||
|
load_instance = True
|
||||||
|
|
||||||
|
location = fields.Nested(LocationPatchRequestLoadSchema, allow_none=True)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
|
from flask_apispec import marshal_with
|
||||||
from flask_apispec.views import MethodResource
|
from flask_apispec.views import MethodResource
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from project.api.schemas import ErrorResponseSchema, UnprocessableEntityResponseSchema
|
||||||
|
|
||||||
|
|
||||||
def etag_cache(func):
|
def etag_cache(func):
|
||||||
@ -13,5 +15,7 @@ def etag_cache(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@marshal_with(ErrorResponseSchema, 400, "Bad Request")
|
||||||
|
@marshal_with(UnprocessableEntityResponseSchema, 422, "Unprocessable Entity")
|
||||||
class BaseResource(MethodResource):
|
class BaseResource(MethodResource):
|
||||||
decorators = [etag_cache]
|
decorators = [etag_cache]
|
||||||
|
|||||||
@ -1,7 +1,21 @@
|
|||||||
from project import marshmallow
|
from project.api import marshmallow
|
||||||
from marshmallow import fields, validate
|
from marshmallow import fields, validate
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponseSchema(marshmallow.Schema):
|
||||||
|
name = fields.Str()
|
||||||
|
message = fields.Str()
|
||||||
|
|
||||||
|
|
||||||
|
class UnprocessableEntityErrorSchema(marshmallow.Schema):
|
||||||
|
field = fields.Str()
|
||||||
|
message = fields.Str()
|
||||||
|
|
||||||
|
|
||||||
|
class UnprocessableEntityResponseSchema(ErrorResponseSchema):
|
||||||
|
errors = fields.List(fields.Nested(UnprocessableEntityErrorSchema))
|
||||||
|
|
||||||
|
|
||||||
class PaginationRequestSchema(marshmallow.Schema):
|
class PaginationRequestSchema(marshmallow.Schema):
|
||||||
page = fields.Integer(
|
page = fields.Integer(
|
||||||
required=False,
|
required=False,
|
||||||
|
|||||||
92
project/forms/oauth2_client.py
Normal file
92
project/forms/oauth2_client.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from flask_wtf import FlaskForm
|
||||||
|
from flask_babelex import lazy_gettext
|
||||||
|
from wtforms import StringField, TextAreaField, SubmitField, SelectField
|
||||||
|
from wtforms.validators import Optional, DataRequired
|
||||||
|
from project.forms.widgets import MultiCheckboxField
|
||||||
|
from project.api import scopes
|
||||||
|
from project.utils import split_by_crlf
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOAuth2ClientForm(FlaskForm):
|
||||||
|
client_name = StringField(lazy_gettext("Client name"), validators=[DataRequired()])
|
||||||
|
redirect_uris = TextAreaField(
|
||||||
|
lazy_gettext("Redirect URIs"), validators=[Optional()]
|
||||||
|
)
|
||||||
|
grant_types = MultiCheckboxField(
|
||||||
|
lazy_gettext("Grant types"),
|
||||||
|
validators=[DataRequired()],
|
||||||
|
choices=[
|
||||||
|
("authorization_code", lazy_gettext("Authorization Code")),
|
||||||
|
("refresh_token", lazy_gettext("Refresh Token")),
|
||||||
|
],
|
||||||
|
default=["authorization_code", "refresh_token"],
|
||||||
|
)
|
||||||
|
response_types = MultiCheckboxField(
|
||||||
|
lazy_gettext("Response types"),
|
||||||
|
validators=[DataRequired()],
|
||||||
|
choices=[
|
||||||
|
("code", "code"),
|
||||||
|
],
|
||||||
|
default=["code"],
|
||||||
|
)
|
||||||
|
scope = MultiCheckboxField(
|
||||||
|
lazy_gettext("Scopes"),
|
||||||
|
validators=[DataRequired()],
|
||||||
|
choices=[(k, k) for k, v in scopes.items()],
|
||||||
|
)
|
||||||
|
token_endpoint_auth_method = SelectField(
|
||||||
|
lazy_gettext("Token endpoint auth method"),
|
||||||
|
validators=[DataRequired()],
|
||||||
|
choices=[
|
||||||
|
("client_secret_post", lazy_gettext("Client secret post")),
|
||||||
|
("client_secret_basic", lazy_gettext("Client secret basic")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
submit = SubmitField(lazy_gettext("Save"))
|
||||||
|
|
||||||
|
def populate_obj(self, obj):
|
||||||
|
meta_keys = [
|
||||||
|
"client_name",
|
||||||
|
"client_uri",
|
||||||
|
"grant_types",
|
||||||
|
"redirect_uris",
|
||||||
|
"response_types",
|
||||||
|
"scope",
|
||||||
|
"token_endpoint_auth_method",
|
||||||
|
]
|
||||||
|
metadata = dict()
|
||||||
|
for name, field in self._fields.items():
|
||||||
|
if name in meta_keys:
|
||||||
|
if name == "redirect_uris":
|
||||||
|
metadata[name] = split_by_crlf(field.data)
|
||||||
|
elif name == "scope":
|
||||||
|
metadata[name] = " ".join(field.data)
|
||||||
|
else:
|
||||||
|
metadata[name] = field.data
|
||||||
|
else:
|
||||||
|
field.populate_obj(obj, name)
|
||||||
|
obj.set_client_metadata(metadata)
|
||||||
|
|
||||||
|
def process(self, formdata=None, obj=None, data=None, **kwargs):
|
||||||
|
super().process(formdata, obj, data, **kwargs)
|
||||||
|
|
||||||
|
if not obj:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.redirect_uris.data = os.linesep.join(obj.redirect_uris)
|
||||||
|
self.scope.data = obj.scope.split(" ")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateOAuth2ClientForm(BaseOAuth2ClientForm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateOAuth2ClientForm(BaseOAuth2ClientForm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteOAuth2ClientForm(FlaskForm):
|
||||||
|
submit = SubmitField(lazy_gettext("Delete OAuth2 client"))
|
||||||
|
name = StringField(lazy_gettext("Name"), validators=[DataRequired()])
|
||||||
7
project/forms/oauth2_token.py
Normal file
7
project/forms/oauth2_token.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from flask_wtf import FlaskForm
|
||||||
|
from flask_babelex import lazy_gettext
|
||||||
|
from wtforms import SubmitField
|
||||||
|
|
||||||
|
|
||||||
|
class RevokeOAuth2TokenForm(FlaskForm):
|
||||||
|
submit = SubmitField(lazy_gettext("Revoke OAuth2 token"))
|
||||||
@ -1,7 +1,9 @@
|
|||||||
from flask_security.forms import RegisterForm, EqualTo, get_form_field_label
|
from flask_security.forms import RegisterForm, EqualTo, get_form_field_label
|
||||||
from wtforms import BooleanField, PasswordField
|
from wtforms import BooleanField, PasswordField, SubmitField
|
||||||
from wtforms.validators import DataRequired
|
from wtforms.validators import DataRequired
|
||||||
from project.forms.common import get_accept_tos_markup
|
from project.forms.common import get_accept_tos_markup
|
||||||
|
from flask_wtf import FlaskForm
|
||||||
|
from flask_babelex import lazy_gettext
|
||||||
|
|
||||||
|
|
||||||
class ExtendedRegisterForm(RegisterForm):
|
class ExtendedRegisterForm(RegisterForm):
|
||||||
@ -20,3 +22,8 @@ class ExtendedRegisterForm(RegisterForm):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ExtendedRegisterForm, self).__init__(*args, **kwargs)
|
super(ExtendedRegisterForm, self).__init__(*args, **kwargs)
|
||||||
self._fields["accept_tos"].label.text = get_accept_tos_markup()
|
self._fields["accept_tos"].label.text = get_accept_tos_markup()
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizeForm(FlaskForm):
|
||||||
|
allow = SubmitField(lazy_gettext("Allow"))
|
||||||
|
deny = SubmitField(lazy_gettext("Deny"))
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -39,3 +39,8 @@ def print_dynamic_texts():
|
|||||||
gettext("EventReviewStatus.inbox")
|
gettext("EventReviewStatus.inbox")
|
||||||
gettext("EventReviewStatus.verified")
|
gettext("EventReviewStatus.verified")
|
||||||
gettext("EventReviewStatus.rejected")
|
gettext("EventReviewStatus.rejected")
|
||||||
|
gettext("read")
|
||||||
|
gettext("write")
|
||||||
|
gettext("Event")
|
||||||
|
gettext("Organizer")
|
||||||
|
gettext("Place")
|
||||||
|
|||||||
@ -1,8 +1,27 @@
|
|||||||
from project import app, db
|
from project import app, db
|
||||||
|
from project.api import api_docs, scopes
|
||||||
from project.services.user import upsert_user_role
|
from project.services.user import upsert_user_role
|
||||||
from project.services.admin_unit import upsert_admin_unit_member_role
|
from project.services.admin_unit import upsert_admin_unit_member_role
|
||||||
from project.services.event import upsert_event_category
|
from project.services.event import upsert_event_category
|
||||||
from project.models import Location
|
from project.models import Location
|
||||||
|
from flask import url_for
|
||||||
|
from apispec.exceptions import DuplicateComponentNameError
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_first_request
|
||||||
|
def add_oauth2_scheme():
|
||||||
|
oauth2_scheme = {
|
||||||
|
"type": "oauth2",
|
||||||
|
"authorizationUrl": url_for("authorize", _external=True),
|
||||||
|
"tokenUrl": url_for("issue_token", _external=True),
|
||||||
|
"flow": "accessCode",
|
||||||
|
"scopes": scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_docs.spec.components.security_scheme("oauth2", oauth2_scheme)
|
||||||
|
except DuplicateComponentNameError: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@app.before_first_request
|
@app.before_first_request
|
||||||
@ -37,12 +56,23 @@ def create_initial_data():
|
|||||||
"reference_request:delete",
|
"reference_request:delete",
|
||||||
"reference_request:verify",
|
"reference_request:verify",
|
||||||
]
|
]
|
||||||
|
early_adopter_permissions = [
|
||||||
|
"oauth2_client:create",
|
||||||
|
"oauth2_client:read",
|
||||||
|
"oauth2_client:update",
|
||||||
|
"oauth2_client:delete",
|
||||||
|
"oauth2_token:create",
|
||||||
|
"oauth2_token:read",
|
||||||
|
"oauth2_token:update",
|
||||||
|
"oauth2_token:delete",
|
||||||
|
]
|
||||||
|
|
||||||
upsert_admin_unit_member_role("admin", "Administrator", admin_permissions)
|
upsert_admin_unit_member_role("admin", "Administrator", admin_permissions)
|
||||||
upsert_admin_unit_member_role("event_verifier", "Event expert", event_permissions)
|
upsert_admin_unit_member_role("event_verifier", "Event expert", event_permissions)
|
||||||
|
|
||||||
upsert_user_role("admin", "Administrator", admin_permissions)
|
upsert_user_role("admin", "Administrator", admin_permissions)
|
||||||
upsert_user_role("event_verifier", "Event expert", event_permissions)
|
upsert_user_role("event_verifier", "Event expert", event_permissions)
|
||||||
|
upsert_user_role("early_adopter", "Early Adopter", early_adopter_permissions)
|
||||||
|
|
||||||
Location.update_coordinates()
|
Location.update_coordinates()
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
from project import app
|
from project import app
|
||||||
from project.utils import get_event_category_name, get_localized_enum_name
|
from project.utils import (
|
||||||
|
get_event_category_name,
|
||||||
|
get_localized_enum_name,
|
||||||
|
get_localized_scope,
|
||||||
|
)
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -8,10 +12,16 @@ def env_override(value, key):
|
|||||||
return os.getenv(key, value)
|
return os.getenv(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def is_list(value):
|
||||||
|
return isinstance(value, list)
|
||||||
|
|
||||||
|
|
||||||
app.jinja_env.filters["event_category_name"] = lambda u: get_event_category_name(u)
|
app.jinja_env.filters["event_category_name"] = lambda u: get_event_category_name(u)
|
||||||
app.jinja_env.filters["loc_enum"] = lambda u: get_localized_enum_name(u)
|
app.jinja_env.filters["loc_enum"] = lambda u: get_localized_enum_name(u)
|
||||||
|
app.jinja_env.filters["loc_scope"] = lambda s: get_localized_scope(s)
|
||||||
app.jinja_env.filters["env_override"] = env_override
|
app.jinja_env.filters["env_override"] = env_override
|
||||||
app.jinja_env.filters["quote_plus"] = lambda u: quote_plus(u)
|
app.jinja_env.filters["quote_plus"] = lambda u: quote_plus(u)
|
||||||
|
app.jinja_env.filters["is_list"] = is_list
|
||||||
|
|
||||||
|
|
||||||
@app.context_processor
|
@app.context_processor
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from project import db
|
from project import db
|
||||||
from sqlalchemy.ext.declarative import declared_attr
|
from sqlalchemy.ext.declarative import declared_attr
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import relationship, backref, deferred
|
from sqlalchemy.orm import relationship, backref, deferred, object_session
|
||||||
from sqlalchemy.schema import CheckConstraint
|
from sqlalchemy.schema import CheckConstraint
|
||||||
from sqlalchemy.event import listens_for
|
from sqlalchemy.event import listens_for
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
@ -24,6 +24,12 @@ import datetime
|
|||||||
from project.dbtypes import IntegerEnum
|
from project.dbtypes import IntegerEnum
|
||||||
from geoalchemy2 import Geometry
|
from geoalchemy2 import Geometry
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
|
from authlib.integrations.sqla_oauth2 import (
|
||||||
|
OAuth2ClientMixin,
|
||||||
|
OAuth2AuthorizationCodeMixin,
|
||||||
|
OAuth2TokenMixin,
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
|
||||||
# Base
|
# Base
|
||||||
|
|
||||||
@ -156,6 +162,12 @@ class User(db.Model, UserMixin):
|
|||||||
"Role", secondary="roles_users", backref=backref("users", lazy="dynamic")
|
"Role", secondary="roles_users", backref=backref("users", lazy="dynamic")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_user_id(self):
|
||||||
|
return self.id
|
||||||
|
|
||||||
|
|
||||||
|
# OAuth Consumer: Wenn wir OAuth consumen und sich ein Nutzer per Google oder Facebook anmelden möchte
|
||||||
|
|
||||||
|
|
||||||
class OAuth(OAuthConsumerMixin, db.Model):
|
class OAuth(OAuthConsumerMixin, db.Model):
|
||||||
provider_user_id = Column(String(256), unique=True, nullable=False)
|
provider_user_id = Column(String(256), unique=True, nullable=False)
|
||||||
@ -163,6 +175,51 @@ class OAuth(OAuthConsumerMixin, db.Model):
|
|||||||
user = db.relationship("User")
|
user = db.relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
# OAuth Server: Wir bieten an, dass sich ein Nutzer per OAuth2 auf unserer Seite anmeldet
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Client(db.Model, OAuth2ClientMixin):
|
||||||
|
__tablename__ = "oauth2_client"
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||||
|
user = db.relationship("User")
|
||||||
|
|
||||||
|
def check_redirect_uri(self, redirect_uri):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
|
||||||
|
__tablename__ = "oauth2_code"
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||||
|
user = db.relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Token(db.Model, OAuth2TokenMixin):
|
||||||
|
__tablename__ = "oauth2_token"
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
|
||||||
|
user = db.relationship("User")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
return (
|
||||||
|
object_session(self)
|
||||||
|
.query(OAuth2Client)
|
||||||
|
.filter(OAuth2Client.client_id == self.client_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_refresh_token_active(self):
|
||||||
|
if self.revoked:
|
||||||
|
return False
|
||||||
|
expires_at = self.issued_at + self.expires_in * 2
|
||||||
|
return expires_at >= time.time()
|
||||||
|
|
||||||
|
|
||||||
# Admin Unit
|
# Admin Unit
|
||||||
|
|
||||||
|
|
||||||
@ -298,6 +355,7 @@ def update_location_coordinate(mapper, connect, self):
|
|||||||
# Events
|
# Events
|
||||||
class EventPlace(db.Model, TrackableMixin):
|
class EventPlace(db.Model, TrackableMixin):
|
||||||
__tablename__ = "eventplace"
|
__tablename__ = "eventplace"
|
||||||
|
__table_args__ = (UniqueConstraint("name", "admin_unit_id"),)
|
||||||
id = Column(Integer(), primary_key=True)
|
id = Column(Integer(), primary_key=True)
|
||||||
name = Column(Unicode(255), nullable=False)
|
name = Column(Unicode(255), nullable=False)
|
||||||
location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
|
location_id = db.Column(db.Integer, db.ForeignKey("location.id"))
|
||||||
|
|||||||
114
project/oauth2.py
Normal file
114
project/oauth2.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from authlib.integrations.flask_oauth2 import (
|
||||||
|
AuthorizationServer,
|
||||||
|
ResourceProtector,
|
||||||
|
)
|
||||||
|
from authlib.integrations.sqla_oauth2 import (
|
||||||
|
create_query_client_func,
|
||||||
|
create_save_token_func,
|
||||||
|
create_bearer_token_validator,
|
||||||
|
create_query_token_func,
|
||||||
|
)
|
||||||
|
from authlib.oauth2.rfc6749 import grants
|
||||||
|
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||||
|
from project import db
|
||||||
|
from project.models import User, OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||||
|
TOKEN_ENDPOINT_AUTH_METHODS = [
|
||||||
|
"client_secret_basic",
|
||||||
|
"client_secret_post",
|
||||||
|
"none",
|
||||||
|
]
|
||||||
|
|
||||||
|
def save_authorization_code(self, code, request):
|
||||||
|
code_challenge = request.data.get("code_challenge")
|
||||||
|
code_challenge_method = request.data.get("code_challenge_method")
|
||||||
|
auth_code = OAuth2AuthorizationCode(
|
||||||
|
code=code,
|
||||||
|
client_id=request.client.client_id,
|
||||||
|
redirect_uri=request.redirect_uri,
|
||||||
|
scope=request.scope,
|
||||||
|
user_id=request.user.id,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
)
|
||||||
|
db.session.add(auth_code)
|
||||||
|
db.session.commit()
|
||||||
|
return auth_code
|
||||||
|
|
||||||
|
def query_authorization_code(self, code, client):
|
||||||
|
auth_code = OAuth2AuthorizationCode.query.filter_by(
|
||||||
|
code=code, client_id=client.client_id
|
||||||
|
).first()
|
||||||
|
if auth_code and not auth_code.is_expired():
|
||||||
|
return auth_code
|
||||||
|
|
||||||
|
def delete_authorization_code(self, authorization_code):
|
||||||
|
db.session.delete(authorization_code)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
def authenticate_user(self, authorization_code):
|
||||||
|
return User.query.get(authorization_code.user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||||
|
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
|
||||||
|
|
||||||
|
def authenticate_refresh_token(self, refresh_token):
|
||||||
|
token = OAuth2Token.query.filter_by(refresh_token=refresh_token).first()
|
||||||
|
if token and token.is_refresh_token_active():
|
||||||
|
return token
|
||||||
|
|
||||||
|
def authenticate_user(self, credential):
|
||||||
|
return User.query.get(credential.user_id)
|
||||||
|
|
||||||
|
def revoke_old_credential(self, credential):
|
||||||
|
credential.revoked = True
|
||||||
|
db.session.add(credential)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
query_client = create_query_client_func(db.session, OAuth2Client)
|
||||||
|
save_token = create_save_token_func(db.session, OAuth2Token)
|
||||||
|
authorization = AuthorizationServer(
|
||||||
|
query_client=query_client,
|
||||||
|
save_token=save_token,
|
||||||
|
)
|
||||||
|
require_oauth = ResourceProtector()
|
||||||
|
|
||||||
|
|
||||||
|
def create_revocation_endpoint(session, token_model):
|
||||||
|
from authlib.oauth2.rfc7009 import RevocationEndpoint
|
||||||
|
|
||||||
|
query_token = create_query_token_func(session, token_model)
|
||||||
|
|
||||||
|
class _RevocationEndpoint(RevocationEndpoint):
|
||||||
|
CLIENT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
|
||||||
|
|
||||||
|
def query_token(self, token, token_type_hint, client):
|
||||||
|
return query_token(token, token_type_hint, client)
|
||||||
|
|
||||||
|
def revoke_token(self, token):
|
||||||
|
token.revoked = True
|
||||||
|
session.add(token)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return _RevocationEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
def config_oauth(app):
|
||||||
|
app.config["OAUTH2_REFRESH_TOKEN_GENERATOR"] = True
|
||||||
|
authorization.init_app(app)
|
||||||
|
|
||||||
|
# support grants
|
||||||
|
authorization.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])
|
||||||
|
authorization.register_grant(RefreshTokenGrant)
|
||||||
|
|
||||||
|
# support revocation
|
||||||
|
revocation_cls = create_revocation_endpoint(db.session, OAuth2Token)
|
||||||
|
authorization.register_endpoint(revocation_cls)
|
||||||
|
|
||||||
|
# protect resource
|
||||||
|
bearer_cls = create_bearer_token_validator(db.session, OAuth2Token)
|
||||||
|
require_oauth.register_token_validator(bearer_cls())
|
||||||
12
project/services/oauth2_client.py
Normal file
12
project/services/oauth2_client.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from project.models import OAuth2Client
|
||||||
|
import time
|
||||||
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
|
||||||
|
def complete_oauth2_client(oauth2_client: OAuth2Client) -> None:
|
||||||
|
if not oauth2_client.id:
|
||||||
|
oauth2_client.client_id = gen_salt(24)
|
||||||
|
oauth2_client.client_id_issued_at = int(time.time())
|
||||||
|
|
||||||
|
if oauth2_client.client_secret is None:
|
||||||
|
oauth2_client.client_secret = gen_salt(48)
|
||||||
@ -17,7 +17,7 @@ def add_roles_to_user(email, roles):
|
|||||||
|
|
||||||
|
|
||||||
def add_admin_roles_to_user(email):
|
def add_admin_roles_to_user(email):
|
||||||
add_roles_to_user(email, ["admin", "event_verifier"])
|
add_roles_to_user(email, ["admin", "event_verifier", "early_adopter"])
|
||||||
|
|
||||||
|
|
||||||
def remove_roles_from_user(email, roles):
|
def remove_roles_from_user(email, roles):
|
||||||
|
|||||||
@ -160,6 +160,21 @@
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
{% endmacro %}
|
{% endmacro %}
|
||||||
|
|
||||||
|
{% macro render_kv_begin() %}
|
||||||
|
<dl class="row">
|
||||||
|
{% endmacro %}
|
||||||
|
|
||||||
|
{% macro render_kv_end() %}
|
||||||
|
</dl>
|
||||||
|
{% endmacro %}
|
||||||
|
|
||||||
|
{% macro render_kv_prop(prop, label_key = None) %}
|
||||||
|
{% if prop %}
|
||||||
|
<dt class="col-sm-3">{{ _(label_key) }}</dt>
|
||||||
|
<dd class="col-sm-9">{% if prop|is_list %}{{ prop|join(', ') }}{% else %}{{ prop }}{% endif %}</dd>
|
||||||
|
{% endif %}
|
||||||
|
{% endmacro %}
|
||||||
|
|
||||||
{% macro render_string_prop(prop, icon = None, label_key = None) %}
|
{% macro render_string_prop(prop, icon = None, label_key = None) %}
|
||||||
{% if prop %}
|
{% if prop %}
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
26
project/templates/oauth2_client/create.html
Normal file
26
project/templates/oauth2_client/create.html
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||||
|
{% block title %}
|
||||||
|
{{ _('Create OAuth2 client') }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<h1>{{ _('Create OAuth2 client') }}</h1>
|
||||||
|
|
||||||
|
<form action="" method="POST">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
|
||||||
|
<div class="card mb-4">
|
||||||
|
<div class="card-body">
|
||||||
|
{{ render_field_with_errors(form.client_name) }}
|
||||||
|
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
|
||||||
|
{{ render_field_with_errors(form.redirect_uris) }}
|
||||||
|
|
||||||
|
|
||||||
|
{{ render_field(form.submit) }}
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
24
project/templates/oauth2_client/delete.html
Normal file
24
project/templates/oauth2_client/delete.html
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<h1>{{ _('Delete OAuth2 client') }} "{{ oauth2_client.client_name }}"</h1>
|
||||||
|
|
||||||
|
<form action="" method="POST">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
|
||||||
|
<div class="card mb-4">
|
||||||
|
<div class="card-header">
|
||||||
|
{{ _('OAuth2 client') }}
|
||||||
|
</div>
|
||||||
|
<div class="card-body">
|
||||||
|
{{ render_field_with_errors(form.name) }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ render_field(form.submit) }}
|
||||||
|
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
44
project/templates/oauth2_client/list.html
Normal file
44
project/templates/oauth2_client/list.html
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_pagination %}
|
||||||
|
{% block title %}
|
||||||
|
{{ _('OAuth2 clients') }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<nav aria-label="breadcrumb">
|
||||||
|
<ol class="breadcrumb">
|
||||||
|
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||||
|
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 clients') }}</li>
|
||||||
|
</ol>
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
{% if current_user.has_permission('oauth2_client:create') %}
|
||||||
|
<div class="my-4">
|
||||||
|
<a class="btn btn-outline-secondary my-1" href="{{ url_for('oauth2_client_create') }}" role="button"><i class="fa fa-plus"></i> {{ _('Create OAuth2 client') }}</a>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
<div class="table-responsive">
|
||||||
|
<table class="table table-sm table-bordered table-hover table-striped">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>{{ _('Name') }}</th>
|
||||||
|
<th></th>
|
||||||
|
<th></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for oauth2_client in oauth2_clients %}
|
||||||
|
<tr>
|
||||||
|
<td><a href="{{ url_for('oauth2_client', id=oauth2_client.id) }}">{{ oauth2_client.client_name }}</a></td>
|
||||||
|
<td><a href="{{ url_for('oauth2_client_update', id=oauth2_client.id) }}">{{ _('Edit') }}</a></td>
|
||||||
|
<td><a href="{{ url_for('oauth2_client_delete', id=oauth2_client.id) }}">{{ _('Delete') }}</a></td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="my-4">{{ render_pagination(pagination) }}</div>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
34
project/templates/oauth2_client/read.html
Normal file
34
project/templates/oauth2_client/read.html
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_kv_begin, render_kv_end, render_kv_prop %}
|
||||||
|
{% block title %}
|
||||||
|
{{ oauth2_client.client_name }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block header %}
|
||||||
|
<script type="application/ld+json">
|
||||||
|
{{ structured_data | safe }}
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<nav aria-label="breadcrumb">
|
||||||
|
<ol class="breadcrumb">
|
||||||
|
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||||
|
<li class="breadcrumb-item"><a href="{{ url_for('oauth2_clients') }}">{{ _('OAuth2 clients') }}</a></li>
|
||||||
|
<li class="breadcrumb-item active" aria-current="page">{{ oauth2_client.client_name }}</li>
|
||||||
|
</ol>
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
<div class="w-normal">
|
||||||
|
{{ render_kv_begin() }}
|
||||||
|
{{ render_kv_prop(oauth2_client.client_id, 'Client ID') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.client_secret, 'Client secret') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.client_uri, 'Client URI') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.grant_types, 'Grant types') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.redirect_uris, 'Redirect URIs') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.response_types, 'Response types') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.scope, 'Scope') }}
|
||||||
|
{{ render_kv_prop(oauth2_client.token_endpoint_auth_method, 'Token endpoint auth method') }}
|
||||||
|
{{ render_kv_end() }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
27
project/templates/oauth2_client/update.html
Normal file
27
project/templates/oauth2_client/update.html
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||||
|
{% block title %}
|
||||||
|
{{ _('Update OAuth2 client') }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<h1>{{ _('Update OAuth2 client') }}</h1>
|
||||||
|
|
||||||
|
<form action="" method="POST">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
|
||||||
|
<div class="card mb-4">
|
||||||
|
<div class="card-body">
|
||||||
|
{{ render_field_with_errors(form.client_name) }}
|
||||||
|
{{ render_field_with_errors(form.grant_types, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.response_types, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.scope, ri="multicheckbox") }}
|
||||||
|
{{ render_field_with_errors(form.token_endpoint_auth_method) }}
|
||||||
|
{{ render_field_with_errors(form.redirect_uris) }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ render_field(form.submit) }}
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
40
project/templates/oauth2_token/list.html
Normal file
40
project/templates/oauth2_token/list.html
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_pagination %}
|
||||||
|
{% block title %}
|
||||||
|
{{ _('OAuth2 tokens') }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<nav aria-label="breadcrumb">
|
||||||
|
<ol class="breadcrumb">
|
||||||
|
<li class="breadcrumb-item"><a href="{{ url_for('profile') }}">{{ _('Profile') }}</a></li>
|
||||||
|
<li class="breadcrumb-item active" aria-current="page">{{ _('OAuth2 tokens') }}</li>
|
||||||
|
</ol>
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
<div class="table-responsive">
|
||||||
|
<table class="table table-sm table-bordered table-hover table-striped">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>{{ _('Client') }}</th>
|
||||||
|
<th>{{ _('Scopes') }}</th>
|
||||||
|
<th>{{ _('Status') }}</th>
|
||||||
|
<th></th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for oauth2_token in oauth2_tokens %}
|
||||||
|
<tr>
|
||||||
|
<td>{{ oauth2_token.client.client_name }}</td>
|
||||||
|
<td>{{ oauth2_token.client.scope }}</td>
|
||||||
|
<td>{% if oauth2_token.revoked %}{{ _('Revoked') }}{% else %}{{ _('Active') }}{% endif %}</td>
|
||||||
|
<td>{% if not oauth2_token.revoked %}<a href="{{ url_for('oauth2_token_revoke', id=oauth2_token.id) }}">{{ _('Revoke') }}</a>{% endif %}</td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="my-4">{{ render_pagination(pagination) }}</div>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
16
project/templates/oauth2_token/revoke.html
Normal file
16
project/templates/oauth2_token/revoke.html
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_field_with_errors, render_field %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<h1>{{ _('Revoke OAuth2 token') }}</h1>
|
||||||
|
<h2>{{ oauth2_token.client.client_name }}</h2>
|
||||||
|
|
||||||
|
<form action="" method="POST">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
|
||||||
|
{{ render_field(form.submit) }}
|
||||||
|
|
||||||
|
</form>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
@ -8,7 +8,23 @@
|
|||||||
<h1>{{ current_user.email }}</h1>
|
<h1>{{ current_user.email }}</h1>
|
||||||
|
|
||||||
<h2>{{ _('Profile') }}</h2>
|
<h2>{{ _('Profile') }}</h2>
|
||||||
<p><a href="{{ url_for_security('change_password') }}">{{ _fsdomain('Change password') }}</a></p>
|
|
||||||
|
<div class="list-group">
|
||||||
|
<a href="{{ url_for('security.change_password') }}" class="list-group-item">
|
||||||
|
{{ _fsdomain('Change password') }}
|
||||||
|
<i class="fa fa-caret-right"></i>
|
||||||
|
</a>
|
||||||
|
{% if current_user.has_permission('oauth2_client:read') %}
|
||||||
|
<a href="{{ url_for('oauth2_clients') }}" class="list-group-item">
|
||||||
|
{{ _('OAuth2 clients') }}
|
||||||
|
<i class="fa fa-caret-right"></i>
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
|
<a href="{{ url_for('oauth2_tokens') }}" class="list-group-item">
|
||||||
|
{{ _('OAuth2 tokens') }}
|
||||||
|
<i class="fa fa-caret-right"></i>
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
{% if invitations %}
|
{% if invitations %}
|
||||||
<h2>{{ _('Invitations') }}</h2>
|
<h2>{{ _('Invitations') }}</h2>
|
||||||
|
|||||||
40
project/templates/security/authorize.html
Normal file
40
project/templates/security/authorize.html
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
{% extends "layout.html" %}
|
||||||
|
{% from "_macros.html" import render_field %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
|
||||||
|
<div class="w-normal d-flex flex-column">
|
||||||
|
|
||||||
|
<div class="card mx-auto">
|
||||||
|
<div class="card-body">
|
||||||
|
<h5>{{ _('"%(client_name)s" wants to access your account', client_name=grant.client.client_name) }}</h5>
|
||||||
|
|
||||||
|
<p class="text-center"><strong>{{ user.email }}</strong></p>
|
||||||
|
|
||||||
|
<p class="mb-1">{{ _('This will allow "%(client_name)s" to:', client_name=grant.client.client_name) }}</p>
|
||||||
|
|
||||||
|
<ul>
|
||||||
|
{% for key, value in scopes.items() %}
|
||||||
|
<li>{{ value }}</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<form action="" method="POST">
|
||||||
|
{{ form.hidden_tag() }}
|
||||||
|
|
||||||
|
<div class="d-flex flex-column mt-5">
|
||||||
|
<div class="mx-auto">
|
||||||
|
{{ render_field(form.allow, class="btn btn-success mx-auto") }}
|
||||||
|
</div>
|
||||||
|
<div class="mx-auto">
|
||||||
|
{{ render_field(form.deny, class="btn btn-light mx-auto") }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
@ -4,7 +4,8 @@
|
|||||||
{% block content %}
|
{% block content %}
|
||||||
|
|
||||||
<h1>{{ _fsdomain('Login') }}</h1>
|
<h1>{{ _fsdomain('Login') }}</h1>
|
||||||
<form action="{{ url_for_security('login', next='manage') }}" method="POST" name="login_user_form">
|
{% set next = request.args['next'] if 'next' in request.args and 'authorize' in request.args['next'] else 'manage' %}
|
||||||
|
<form action="{{ url_for_security('login', next=next) }}" method="POST" name="login_user_form">
|
||||||
{{ login_user_form.hidden_tag() }}
|
{{ login_user_form.hidden_tag() }}
|
||||||
{{ render_field_with_errors(login_user_form.email) }}
|
{{ render_field_with_errors(login_user_form.email) }}
|
||||||
{{ render_field_with_errors(login_user_form.password) }}
|
{{ render_field_with_errors(login_user_form.password) }}
|
||||||
|
|||||||
@ -11,9 +11,20 @@ def get_localized_enum_name(enum):
|
|||||||
return lazy_gettext(enum.__class__.__name__ + "." + enum.name)
|
return lazy_gettext(enum.__class__.__name__ + "." + enum.name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_localized_scope(scope: str) -> str:
|
||||||
|
type_name, action = scope.split(":")
|
||||||
|
loc_lazy_gettext = lazy_gettext(type_name.capitalize())
|
||||||
|
loc_action = lazy_gettext(action)
|
||||||
|
return f"{loc_lazy_gettext} ({loc_action})"
|
||||||
|
|
||||||
|
|
||||||
def make_dir(path):
|
def make_dir(path):
|
||||||
try:
|
try:
|
||||||
original_umask = os.umask(0)
|
original_umask = os.umask(0)
|
||||||
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
|
||||||
finally:
|
finally:
|
||||||
os.umask(original_umask)
|
os.umask(original_umask)
|
||||||
|
|
||||||
|
|
||||||
|
def split_by_crlf(s):
|
||||||
|
return [v for v in s.splitlines() if v]
|
||||||
|
|||||||
53
project/views/oauth.py
Normal file
53
project/views/oauth.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from authlib.oauth2 import OAuth2Error
|
||||||
|
from project import app
|
||||||
|
from project.api import scopes
|
||||||
|
from flask_security import current_user
|
||||||
|
from flask import redirect, request, url_for, render_template
|
||||||
|
from project.forms.security import AuthorizeForm
|
||||||
|
from project.oauth2 import authorization
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth/authorize", methods=["GET", "POST"])
|
||||||
|
def authorize():
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
if not user or not user.is_authenticated:
|
||||||
|
return redirect(url_for("security.login", next=request.url))
|
||||||
|
|
||||||
|
form = AuthorizeForm()
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
grant_user = user if form.allow.data else None
|
||||||
|
return authorization.create_authorization_response(grant_user=grant_user)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
grant = authorization.validate_consent_request(end_user=user)
|
||||||
|
except OAuth2Error as error:
|
||||||
|
return error.error
|
||||||
|
|
||||||
|
grant_scopes = grant.request.scope.split(" ")
|
||||||
|
filtered_scopes = {k: scopes[k] for k in grant_scopes}
|
||||||
|
return render_template(
|
||||||
|
"security/authorize.html",
|
||||||
|
form=form,
|
||||||
|
scopes=filtered_scopes,
|
||||||
|
user=user,
|
||||||
|
grant=grant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth/token", methods=["POST"])
|
||||||
|
def issue_token():
|
||||||
|
return authorization.create_token_response()
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth/revoke", methods=["POST"])
|
||||||
|
def revoke_token():
|
||||||
|
return authorization.create_endpoint_response("revocation")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2-redirect.html")
|
||||||
|
def swagger_oauth2_redirect():
|
||||||
|
return redirect(
|
||||||
|
url_for("flask-apispec.static", filename="oauth2-redirect.html", **request.args)
|
||||||
|
)
|
||||||
125
project/views/oauth2_client.py
Normal file
125
project/views/oauth2_client.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from project import app, db
|
||||||
|
from flask import render_template, redirect, flash, url_for
|
||||||
|
from flask_babelex import gettext
|
||||||
|
from flask_security import permissions_required, current_user
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
from project.views.utils import (
|
||||||
|
get_pagination_urls,
|
||||||
|
handleSqlError,
|
||||||
|
flash_errors,
|
||||||
|
non_match_for_deletion,
|
||||||
|
)
|
||||||
|
from project.forms.oauth2_client import (
|
||||||
|
CreateOAuth2ClientForm,
|
||||||
|
UpdateOAuth2ClientForm,
|
||||||
|
DeleteOAuth2ClientForm,
|
||||||
|
)
|
||||||
|
from project.services.oauth2_client import complete_oauth2_client
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from project.access import owner_access_or_401
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_client/create", methods=("GET", "POST"))
|
||||||
|
@permissions_required("oauth2_client:create")
|
||||||
|
def oauth2_client_create():
|
||||||
|
form = CreateOAuth2ClientForm()
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
oauth2_client = OAuth2Client()
|
||||||
|
form.populate_obj(oauth2_client)
|
||||||
|
oauth2_client.user_id = current_user.id
|
||||||
|
complete_oauth2_client(oauth2_client)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.session.add(oauth2_client)
|
||||||
|
db.session.commit()
|
||||||
|
flash(gettext("OAuth2 client successfully created"), "success")
|
||||||
|
return redirect(url_for("oauth2_client", id=oauth2_client.id))
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
flash(handleSqlError(e), "danger")
|
||||||
|
else:
|
||||||
|
flash_errors(form)
|
||||||
|
|
||||||
|
return render_template("oauth2_client/create.html", form=form)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_client/<int:id>/update", methods=("GET", "POST"))
|
||||||
|
@permissions_required("oauth2_client:update")
|
||||||
|
def oauth2_client_update(id):
|
||||||
|
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||||
|
owner_access_or_401(oauth2_client.user_id)
|
||||||
|
|
||||||
|
form = UpdateOAuth2ClientForm(obj=oauth2_client)
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
form.populate_obj(oauth2_client)
|
||||||
|
complete_oauth2_client(oauth2_client)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.session.commit()
|
||||||
|
flash(gettext("OAuth2 client successfully updated"), "success")
|
||||||
|
return redirect(url_for("oauth2_client", id=oauth2_client.id))
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
flash(handleSqlError(e), "danger")
|
||||||
|
else:
|
||||||
|
flash_errors(form)
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_client/update.html", form=form, oauth2_client=oauth2_client
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_client/<int:id>/delete", methods=("GET", "POST"))
|
||||||
|
@permissions_required("oauth2_client:delete")
|
||||||
|
def oauth2_client_delete(id):
|
||||||
|
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||||
|
owner_access_or_401(oauth2_client.user_id)
|
||||||
|
|
||||||
|
form = DeleteOAuth2ClientForm()
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
if non_match_for_deletion(form.name.data, oauth2_client.client_name):
|
||||||
|
flash(gettext("Entered name does not match OAuth2 client name"), "danger")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
db.session.delete(oauth2_client)
|
||||||
|
db.session.commit()
|
||||||
|
flash(gettext("OAuth2 client successfully deleted"), "success")
|
||||||
|
return redirect(url_for("oauth2_clients"))
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
flash(handleSqlError(e), "danger")
|
||||||
|
else:
|
||||||
|
flash_errors(form)
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_client/delete.html", form=form, oauth2_client=oauth2_client
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_client/<int:id>")
|
||||||
|
@permissions_required("oauth2_client:read")
|
||||||
|
def oauth2_client(id):
|
||||||
|
oauth2_client = OAuth2Client.query.get_or_404(id)
|
||||||
|
owner_access_or_401(oauth2_client.user_id)
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_client/read.html",
|
||||||
|
oauth2_client=oauth2_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_clients")
|
||||||
|
@permissions_required("oauth2_client:read")
|
||||||
|
def oauth2_clients():
|
||||||
|
oauth2_clients = OAuth2Client.query.filter(
|
||||||
|
OAuth2Client.user_id == current_user.id
|
||||||
|
).paginate()
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_client/list.html",
|
||||||
|
oauth2_clients=oauth2_clients.items,
|
||||||
|
pagination=get_pagination_urls(oauth2_clients),
|
||||||
|
)
|
||||||
53
project/views/oauth2_token.py
Normal file
53
project/views/oauth2_token.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from project import app, db
|
||||||
|
from flask import render_template, redirect, flash, url_for
|
||||||
|
from flask_babelex import gettext
|
||||||
|
from flask_security import current_user
|
||||||
|
from project.models import OAuth2Token
|
||||||
|
from project.views.utils import (
|
||||||
|
get_pagination_urls,
|
||||||
|
handleSqlError,
|
||||||
|
flash_errors,
|
||||||
|
)
|
||||||
|
from project.forms.oauth2_token import RevokeOAuth2TokenForm
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from project.access import owner_access_or_401
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_token/<int:id>/revoke", methods=("GET", "POST"))
|
||||||
|
def oauth2_token_revoke(id):
|
||||||
|
oauth2_token = OAuth2Token.query.get_or_404(id)
|
||||||
|
owner_access_or_401(oauth2_token.user_id)
|
||||||
|
|
||||||
|
if oauth2_token.revoked:
|
||||||
|
return redirect(url_for("oauth2_tokens"))
|
||||||
|
|
||||||
|
form = RevokeOAuth2TokenForm()
|
||||||
|
|
||||||
|
if form.validate_on_submit():
|
||||||
|
try:
|
||||||
|
oauth2_token.revoked = True
|
||||||
|
db.session.commit()
|
||||||
|
flash(gettext("OAuth2 token successfully revoked"), "success")
|
||||||
|
return redirect(url_for("oauth2_tokens"))
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
db.session.rollback()
|
||||||
|
flash(handleSqlError(e), "danger")
|
||||||
|
else:
|
||||||
|
flash_errors(form)
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_token/revoke.html", form=form, oauth2_token=oauth2_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/oauth2_tokens")
|
||||||
|
def oauth2_tokens():
|
||||||
|
oauth2_tokens = OAuth2Token.query.filter(
|
||||||
|
OAuth2Token.user_id == current_user.id
|
||||||
|
).paginate()
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
"oauth2_token/list.html",
|
||||||
|
oauth2_tokens=oauth2_tokens.items,
|
||||||
|
pagination=get_pagination_urls(oauth2_tokens),
|
||||||
|
)
|
||||||
@ -5,6 +5,7 @@ apispec-webframeworks==0.5.2
|
|||||||
appdirs==1.4.4
|
appdirs==1.4.4
|
||||||
argh==0.26.2
|
argh==0.26.2
|
||||||
attrs==20.3.0
|
attrs==20.3.0
|
||||||
|
Authlib==0.15.3
|
||||||
Babel==2.9.0
|
Babel==2.9.0
|
||||||
bcrypt==3.2.0
|
bcrypt==3.2.0
|
||||||
beautifulsoup4==4.9.3
|
beautifulsoup4==4.9.3
|
||||||
@ -18,6 +19,7 @@ click==7.1.2
|
|||||||
colour==0.1.5
|
colour==0.1.5
|
||||||
coverage==5.3
|
coverage==5.3
|
||||||
coveralls==2.2.0
|
coveralls==2.2.0
|
||||||
|
cryptography==3.3.1
|
||||||
distlib==0.3.1
|
distlib==0.3.1
|
||||||
dnspython==2.0.0
|
dnspython==2.0.0
|
||||||
docopt==0.6.2
|
docopt==0.6.2
|
||||||
|
|||||||
55
tests/api/test___init__.py
Normal file
55
tests/api/test___init__.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import pytest
|
||||||
|
from project.api import RestApi
|
||||||
|
|
||||||
|
|
||||||
|
class Psycog2Error(object):
|
||||||
|
def __init__(self, pgcode):
|
||||||
|
self.pgcode = pgcode
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_error_unique(app):
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from psycopg2.errorcodes import UNIQUE_VIOLATION
|
||||||
|
|
||||||
|
orig = Psycog2Error(UNIQUE_VIOLATION)
|
||||||
|
error = IntegrityError("Select", list(), orig)
|
||||||
|
|
||||||
|
api = RestApi(app)
|
||||||
|
(data, code) = api.handle_error(error)
|
||||||
|
assert code == 400
|
||||||
|
assert data["name"] == "Unique Violation"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_error_httpException(app):
|
||||||
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
error = InternalServerError()
|
||||||
|
|
||||||
|
api = RestApi(app)
|
||||||
|
(data, code) = api.handle_error(error)
|
||||||
|
assert code == 500
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_error_unprocessableEntity(app):
|
||||||
|
from werkzeug.exceptions import UnprocessableEntity
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
args = {"name": ["Required"]}
|
||||||
|
validation_error = ValidationError(args)
|
||||||
|
|
||||||
|
error = UnprocessableEntity()
|
||||||
|
error.exc = validation_error
|
||||||
|
|
||||||
|
api = RestApi(app)
|
||||||
|
(data, code) = api.handle_error(error)
|
||||||
|
assert code == 422
|
||||||
|
assert data["errors"][0]["field"] == "name"
|
||||||
|
assert data["errors"][0]["message"] == "Required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_error_unspecificRaises(app):
|
||||||
|
error = Exception()
|
||||||
|
api = RestApi(app)
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
api.handle_error(error)
|
||||||
47
tests/api/test_fields.py
Normal file
47
tests/api/test_fields.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
def test_numeric_str_serialize(client, seeder, utils):
|
||||||
|
from project.api.location.schemas import LocationSchema
|
||||||
|
from project.models import Location
|
||||||
|
|
||||||
|
location = Location()
|
||||||
|
location.street = "Markt 7"
|
||||||
|
location.postalCode = "38640"
|
||||||
|
location.city = "Goslar"
|
||||||
|
location.latitude = 51.9077888
|
||||||
|
location.longitude = 10.4333312
|
||||||
|
|
||||||
|
schema = LocationSchema()
|
||||||
|
data = schema.dump(location)
|
||||||
|
|
||||||
|
assert data["latitude"] == "51.9077888"
|
||||||
|
assert data["longitude"] == "10.4333312"
|
||||||
|
|
||||||
|
|
||||||
|
def test_numeric_str_deserialize(client, seeder, utils):
|
||||||
|
from project.api.location.schemas import LocationPostRequestLoadSchema
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"latitude": "51.9077888",
|
||||||
|
"longitude": "10.4333312",
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = LocationPostRequestLoadSchema()
|
||||||
|
location = schema.load(data)
|
||||||
|
|
||||||
|
assert location.latitude == 51.9077888
|
||||||
|
assert location.longitude == 10.4333312
|
||||||
|
|
||||||
|
|
||||||
|
def test_numeric_str_deserialize_invalid(client, seeder, utils):
|
||||||
|
from project.api.location.schemas import LocationPostRequestLoadSchema
|
||||||
|
import pytest
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"latitude": "Quatsch",
|
||||||
|
"longitude": "Quatsch",
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = LocationPostRequestLoadSchema()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
schema.load(data)
|
||||||
@ -1,6 +0,0 @@
|
|||||||
def test_read(client, seeder, utils):
|
|
||||||
user_id, admin_unit_id = seeder.setup_base()
|
|
||||||
image_id = seeder.upsert_default_image()
|
|
||||||
|
|
||||||
url = utils.get_url("api_v1_image", id=image_id)
|
|
||||||
utils.get_ok(url)
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
def test_read(client, app, db, seeder, utils):
|
|
||||||
user_id, admin_unit_id = seeder.setup_base()
|
|
||||||
|
|
||||||
with app.app_context():
|
|
||||||
from project.models import Location
|
|
||||||
|
|
||||||
location = Location()
|
|
||||||
location.street = "Markt 7"
|
|
||||||
location.postalCode = "38640"
|
|
||||||
location.city = "Goslar"
|
|
||||||
location.latitude = 51.9077888
|
|
||||||
location.longitude = 10.4333312
|
|
||||||
|
|
||||||
db.session.add(location)
|
|
||||||
db.session.commit()
|
|
||||||
location_id = location.id
|
|
||||||
|
|
||||||
url = utils.get_url("api_v1_location", id=location_id)
|
|
||||||
response = utils.get_ok(url)
|
|
||||||
assert response.json["latitude"] == "51.9077888000000000"
|
|
||||||
@ -46,6 +46,25 @@ def test_places(client, seeder, utils):
|
|||||||
utils.get_ok(url)
|
utils.get_ok(url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_places_post(client, seeder, utils, app):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
|
||||||
|
url = utils.get_url("api_v1_organization_place_list", id=admin_unit_id, name="crew")
|
||||||
|
response = utils.post_json(url, {"name": "Neuer Ort"})
|
||||||
|
utils.assert_response_created(response)
|
||||||
|
assert "id" in response.json
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import EventPlace
|
||||||
|
|
||||||
|
place = (
|
||||||
|
EventPlace.query.filter(EventPlace.admin_unit_id == admin_unit_id)
|
||||||
|
.filter(EventPlace.name == "Neuer Ort")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert place is not None
|
||||||
|
|
||||||
|
|
||||||
def test_references_incoming(client, seeder, utils):
|
def test_references_incoming(client, seeder, utils):
|
||||||
user_id, admin_unit_id = seeder.setup_base()
|
user_id, admin_unit_id = seeder.setup_base()
|
||||||
(
|
(
|
||||||
|
|||||||
@ -4,3 +4,65 @@ def test_read(client, app, db, seeder, utils):
|
|||||||
|
|
||||||
url = utils.get_url("api_v1_place", id=place_id)
|
url = utils.get_url("api_v1_place", id=place_id)
|
||||||
utils.get_ok(url)
|
utils.get_ok(url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_put(client, seeder, utils, app):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||||
|
|
||||||
|
url = utils.get_url("api_v1_place", id=place_id)
|
||||||
|
response = utils.put_json(url, {"name": "Neuer Name"})
|
||||||
|
utils.assert_response_no_content(response)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import EventPlace
|
||||||
|
|
||||||
|
place = EventPlace.query.get(place_id)
|
||||||
|
assert place.name == "Neuer Name"
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_nonActiveReturnsUnauthorized(client, seeder, db, utils, app):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import User
|
||||||
|
|
||||||
|
user = User.query.get(user_id)
|
||||||
|
user.active = False
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
url = utils.get_url("api_v1_place", id=place_id)
|
||||||
|
response = utils.put_json(url, {"name": "Neuer Name"})
|
||||||
|
utils.assert_response_unauthorized(response)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch(client, seeder, utils, app):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||||
|
|
||||||
|
url = utils.get_url("api_v1_place", id=place_id)
|
||||||
|
response = utils.patch_json(url, {"description": "Klasse"})
|
||||||
|
utils.assert_response_no_content(response)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import EventPlace
|
||||||
|
|
||||||
|
place = EventPlace.query.get(place_id)
|
||||||
|
assert place.name == "Meine Crew"
|
||||||
|
assert place.description == "Klasse"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete(client, seeder, utils, app):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
place_id = seeder.upsert_default_event_place(admin_unit_id)
|
||||||
|
|
||||||
|
url = utils.get_url("api_v1_place", id=place_id)
|
||||||
|
response = utils.delete(url)
|
||||||
|
utils.assert_response_no_content(response)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import EventPlace
|
||||||
|
|
||||||
|
place = EventPlace.query.get(place_id)
|
||||||
|
assert place is None
|
||||||
|
|||||||
@ -8,6 +8,7 @@ def pytest_generate_tests(metafunc):
|
|||||||
os.environ["DATABASE_URL"] = os.environ.get(
|
os.environ["DATABASE_URL"] = os.environ.get(
|
||||||
"TEST_DATABASE_URL", "postgresql://postgres@localhost/gsevpt_tests"
|
"TEST_DATABASE_URL", "postgresql://postgres@localhost/gsevpt_tests"
|
||||||
)
|
)
|
||||||
|
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@ -4,9 +4,10 @@ class Seeder(object):
|
|||||||
self._db = db
|
self._db = db
|
||||||
self._utils = utils
|
self._utils = utils
|
||||||
|
|
||||||
def setup_base(self, admin=False):
|
def setup_base(self, admin=False, log_in=True):
|
||||||
user_id = self.create_user(admin=admin)
|
user_id = self.create_user(admin=admin)
|
||||||
self._utils.login()
|
if log_in:
|
||||||
|
self._utils.login()
|
||||||
admin_unit_id = self.create_admin_unit(user_id)
|
admin_unit_id = self.create_admin_unit(user_id)
|
||||||
return (user_id, admin_unit_id)
|
return (user_id, admin_unit_id)
|
||||||
|
|
||||||
@ -126,6 +127,45 @@ class Seeder(object):
|
|||||||
|
|
||||||
return organizer_id
|
return organizer_id
|
||||||
|
|
||||||
|
def insert_default_oauth2_client(self, user_id):
|
||||||
|
from project.api import scope_list
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
from project.services.oauth2_client import complete_oauth2_client
|
||||||
|
|
||||||
|
with self._app.app_context():
|
||||||
|
client = OAuth2Client()
|
||||||
|
client.user_id = user_id
|
||||||
|
complete_oauth2_client(client)
|
||||||
|
|
||||||
|
metadata = dict()
|
||||||
|
metadata["client_name"] = "Mein Client"
|
||||||
|
metadata["scope"] = " ".join(scope_list)
|
||||||
|
metadata["grant_types"] = ["authorization_code", "refresh_token"]
|
||||||
|
metadata["response_types"] = ["code"]
|
||||||
|
metadata["token_endpoint_auth_method"] = "client_secret_post"
|
||||||
|
client.set_client_metadata(metadata)
|
||||||
|
|
||||||
|
self._db.session.add(client)
|
||||||
|
self._db.session.commit()
|
||||||
|
client_id = client.id
|
||||||
|
|
||||||
|
return client_id
|
||||||
|
|
||||||
|
def setup_api_access(self):
|
||||||
|
user_id, admin_unit_id = self.setup_base(admin=True)
|
||||||
|
oauth2_client_id = self.insert_default_oauth2_client(user_id)
|
||||||
|
|
||||||
|
with self._app.app_context():
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
|
||||||
|
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||||
|
client_id = oauth2_client.client_id
|
||||||
|
client_secret = oauth2_client.client_secret
|
||||||
|
scope = oauth2_client.scope
|
||||||
|
|
||||||
|
self._utils.authorize(client_id, client_secret, scope)
|
||||||
|
return (user_id, admin_unit_id)
|
||||||
|
|
||||||
def create_event(self, admin_unit_id, recurrence_rule=None):
|
def create_event(self, admin_unit_id, recurrence_rule=None):
|
||||||
from project.models import Event
|
from project.models import Event
|
||||||
from project.services.event import insert_event, upsert_event_category
|
from project.services.event import insert_event, upsert_event_category
|
||||||
|
|||||||
@ -21,3 +21,16 @@ def test_event_category(client, app, db, seeder):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
assert event.category is None
|
assert event.category is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth2_token(client, app):
|
||||||
|
from project.models import OAuth2Token
|
||||||
|
|
||||||
|
token = OAuth2Token()
|
||||||
|
token.revoked = True
|
||||||
|
assert not token.is_refresh_token_active()
|
||||||
|
|
||||||
|
token.revoked = False
|
||||||
|
token.issued_at = 0
|
||||||
|
token.expires_in = 0
|
||||||
|
assert not token.is_refresh_token_active()
|
||||||
|
|||||||
7
tests/test_oauth2.py
Normal file
7
tests/test_oauth2.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
def test_authorization_code(seeder):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_token(seeder, utils):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
utils.refresh_token()
|
||||||
161
tests/utils.py
161
tests/utils.py
@ -2,12 +2,17 @@ import re
|
|||||||
from flask import g, url_for
|
from flask import g, url_for
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from urllib.parse import urlsplit, parse_qs
|
||||||
|
|
||||||
|
|
||||||
class UtilActions(object):
|
class UtilActions(object):
|
||||||
def __init__(self, client, app):
|
def __init__(self, client, app):
|
||||||
self._client = client
|
self._client = client
|
||||||
self._app = app
|
self._app = app
|
||||||
|
self._access_token = None
|
||||||
|
self._refresh_token = None
|
||||||
|
self._client_id = None
|
||||||
|
self._client_secret = None
|
||||||
|
|
||||||
def register(self, email="test@test.de", password="MeinPasswortIstDasBeste"):
|
def register(self, email="test@test.de", password="MeinPasswortIstDasBeste"):
|
||||||
response = self._client.get("/register")
|
response = self._client.get("/register")
|
||||||
@ -77,14 +82,56 @@ class UtilActions(object):
|
|||||||
form = Form(soup.find("form"))
|
form = Form(soup.find("form"))
|
||||||
return form.fill(values)
|
return form.fill(values)
|
||||||
|
|
||||||
def post_form(self, url, response, values: dict):
|
def post_form_data(self, url, data: dict):
|
||||||
data = self.create_form_data(response, values)
|
|
||||||
return self._client.post(url, data=data)
|
return self._client.post(url, data=data)
|
||||||
|
|
||||||
|
def post_form(self, url, response, values: dict):
|
||||||
|
data = self.create_form_data(response, values)
|
||||||
|
return self.post_form_data(url, data=data)
|
||||||
|
|
||||||
|
def get_headers(self):
|
||||||
|
headers = dict()
|
||||||
|
|
||||||
|
if self._access_token:
|
||||||
|
headers["Authorization"] = f"Bearer {self._access_token}"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def log_request(self, url):
|
||||||
|
print(url)
|
||||||
|
|
||||||
|
def log_json_request(self, url, data: dict):
|
||||||
|
self.log_request(url)
|
||||||
|
print(data)
|
||||||
|
|
||||||
|
def log_response(self, response):
|
||||||
|
print(response.status_code)
|
||||||
|
print(response.data)
|
||||||
|
print(response.json)
|
||||||
|
|
||||||
def post_json(self, url, data: dict):
|
def post_json(self, url, data: dict):
|
||||||
response = self._client.post(url, json=data)
|
self.log_json_request(url, data)
|
||||||
assert response.content_type == "application/json"
|
response = self._client.post(url, json=data, headers=self.get_headers())
|
||||||
return response.json
|
self.log_response(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def put_json(self, url, data: dict):
|
||||||
|
self.log_json_request(url, data)
|
||||||
|
response = self._client.put(url, json=data, headers=self.get_headers())
|
||||||
|
self.log_response(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def patch_json(self, url, data: dict):
|
||||||
|
self.log_json_request(url, data)
|
||||||
|
response = self._client.patch(url, json=data, headers=self.get_headers())
|
||||||
|
self.log_response(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def delete(self, url):
|
||||||
|
self.log_request(url)
|
||||||
|
response = self._client.delete(url, headers=self.get_headers())
|
||||||
|
self.log_response(response)
|
||||||
|
return response
|
||||||
|
|
||||||
def mock_db_commit(self, mocker, orig=None):
|
def mock_db_commit(self, mocker, orig=None):
|
||||||
mocked_commit = mocker.patch("project.db.session.commit")
|
mocked_commit = mocker.patch("project.db.session.commit")
|
||||||
@ -105,14 +152,23 @@ class UtilActions(object):
|
|||||||
url = url_for(endpoint, **values, _external=False)
|
url = url_for(endpoint, **values, _external=False)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
def get(self, url):
|
||||||
|
return self._client.get(url)
|
||||||
|
|
||||||
def get_ok(self, url):
|
def get_ok(self, url):
|
||||||
response = self._client.get(url)
|
response = self.get(url)
|
||||||
self.assert_response_ok(response)
|
self.assert_response_ok(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def assert_response_ok(self, response):
|
def assert_response_ok(self, response):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def assert_response_created(self, response):
|
||||||
|
assert response.status_code == 201
|
||||||
|
|
||||||
|
def assert_response_no_content(self, response):
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
def get_unauthorized(self, url):
|
def get_unauthorized(self, url):
|
||||||
response = self._client.get(url)
|
response = self._client.get(url)
|
||||||
self.assert_response_unauthorized(response)
|
self.assert_response_unauthorized(response)
|
||||||
@ -146,3 +202,96 @@ class UtilActions(object):
|
|||||||
|
|
||||||
def assert_response_permission_missing(self, response, endpoint, **values):
|
def assert_response_permission_missing(self, response, endpoint, **values):
|
||||||
self.assert_response_redirect(response, endpoint, **values)
|
self.assert_response_redirect(response, endpoint, **values)
|
||||||
|
|
||||||
|
def parse_query_parameters(self, url):
|
||||||
|
query = urlsplit(url).query
|
||||||
|
params = parse_qs(query)
|
||||||
|
return {k: v[0] for k, v in params.items()}
|
||||||
|
|
||||||
|
def authorize(self, client_id, client_secret, scope):
|
||||||
|
# Authorize-Seite öffnen
|
||||||
|
redirect_uri = self.get_url("swagger_oauth2_redirect")
|
||||||
|
url = self.get_url(
|
||||||
|
"authorize",
|
||||||
|
response_type="code",
|
||||||
|
client_id=client_id,
|
||||||
|
scope=scope,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
)
|
||||||
|
response = self.get_ok(url)
|
||||||
|
|
||||||
|
# Authorisieren
|
||||||
|
response = self.post_form(
|
||||||
|
url,
|
||||||
|
response,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert redirect_uri in response.headers["Location"]
|
||||||
|
|
||||||
|
# Code aus der Redirect-Antwort lesen
|
||||||
|
params = self.parse_query_parameters(response.headers["Location"])
|
||||||
|
assert "code" in params
|
||||||
|
code = params["code"]
|
||||||
|
|
||||||
|
# Mit dem Code den Access-Token abfragen
|
||||||
|
token_url = self.get_url("issue_token")
|
||||||
|
response = self.post_form_data(
|
||||||
|
token_url,
|
||||||
|
data={
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"scope": scope,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_response_ok(response)
|
||||||
|
assert response.content_type == "application/json"
|
||||||
|
assert "access_token" in response.json
|
||||||
|
assert "expires_in" in response.json
|
||||||
|
assert "refresh_token" in response.json
|
||||||
|
assert response.json["scope"] == scope
|
||||||
|
assert response.json["token_type"] == "Bearer"
|
||||||
|
|
||||||
|
self._client_id = client_id
|
||||||
|
self._client_secret = client_secret
|
||||||
|
self._access_token = response.json["access_token"]
|
||||||
|
self._refresh_token = response.json["refresh_token"]
|
||||||
|
|
||||||
|
def refresh_token(self):
|
||||||
|
token_url = self.get_url("issue_token")
|
||||||
|
response = self.post_form_data(
|
||||||
|
token_url,
|
||||||
|
data={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": self._refresh_token,
|
||||||
|
"client_id": self._client_id,
|
||||||
|
"client_secret": self._client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_response_ok(response)
|
||||||
|
assert response.content_type == "application/json"
|
||||||
|
assert response.json["token_type"] == "Bearer"
|
||||||
|
assert "access_token" in response.json
|
||||||
|
assert "expires_in" in response.json
|
||||||
|
|
||||||
|
self._access_token = response.json["access_token"]
|
||||||
|
|
||||||
|
def revoke_token(self):
|
||||||
|
url = self.get_url("revoke_token")
|
||||||
|
response = self.post_form_data(
|
||||||
|
url,
|
||||||
|
data={
|
||||||
|
"token": self._access_token,
|
||||||
|
"token_type_hint": "access_token",
|
||||||
|
"client_id": self._client_id,
|
||||||
|
"client_secret": self._client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_response_ok(response)
|
||||||
|
|||||||
@ -430,7 +430,7 @@ def test_delete_nameDoesNotMatch(client, seeder, utils, app, mocker):
|
|||||||
|
|
||||||
def test_rrule(client, seeder, utils, app):
|
def test_rrule(client, seeder, utils, app):
|
||||||
url = utils.get_url("event_rrule")
|
url = utils.get_url("event_rrule")
|
||||||
json = utils.post_json(
|
response = utils.post_json(
|
||||||
url,
|
url,
|
||||||
{
|
{
|
||||||
"year": 2020,
|
"year": 2020,
|
||||||
@ -440,6 +440,7 @@ def test_rrule(client, seeder, utils, app):
|
|||||||
"start": 0,
|
"start": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
json = response.json
|
||||||
|
|
||||||
assert json["batch"]["batch_size"] == 10
|
assert json["batch"]["batch_size"] == 10
|
||||||
|
|
||||||
|
|||||||
25
tests/views/test_oauth.py
Normal file
25
tests/views/test_oauth.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
def test_authorize_unauthorizedRedirects(seeder, utils):
|
||||||
|
url = utils.get_url("authorize")
|
||||||
|
response = utils.get(url)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert "login" in response.headers["Location"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_authorize_validateThrowsError(seeder, utils):
|
||||||
|
seeder.setup_base()
|
||||||
|
url = utils.get_url("authorize")
|
||||||
|
response = utils.get(url)
|
||||||
|
|
||||||
|
utils.assert_response_error_message(response, b"invalid_grant")
|
||||||
|
|
||||||
|
|
||||||
|
def test_revoke_token(seeder, utils):
|
||||||
|
seeder.setup_api_access()
|
||||||
|
utils.revoke_token()
|
||||||
|
|
||||||
|
|
||||||
|
def test_swagger_redirect(utils):
|
||||||
|
url = utils.get_url("swagger_oauth2_redirect")
|
||||||
|
response = utils.get(url)
|
||||||
|
assert response.status_code == 302
|
||||||
136
tests/views/test_oauth2_client.py
Normal file
136
tests/views/test_oauth2_client.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_read(client, seeder, utils):
|
||||||
|
user_id, admin_unit_id = seeder.setup_base(True)
|
||||||
|
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_client", id=oauth2_client_id)
|
||||||
|
utils.get_ok(url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_notOwner(client, seeder, utils):
|
||||||
|
user_id = seeder.create_user(email="other@other.de", admin=True)
|
||||||
|
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||||
|
|
||||||
|
seeder.setup_base(True)
|
||||||
|
url = utils.get_url("oauth2_client", id=oauth2_client_id)
|
||||||
|
utils.get_unauthorized(url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list(client, seeder, utils):
|
||||||
|
user_id, admin_unit_id = seeder.setup_base(True)
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_clients")
|
||||||
|
utils.get_ok(url)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("db_error", [True, False])
|
||||||
|
def test_create_authorization_code(client, app, utils, seeder, mocker, db_error):
|
||||||
|
from project.api import scope_list
|
||||||
|
|
||||||
|
user_id, admin_unit_id = seeder.setup_base(True)
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_client_create")
|
||||||
|
response = utils.get_ok(url)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.mock_db_commit(mocker)
|
||||||
|
|
||||||
|
response = utils.post_form(
|
||||||
|
url,
|
||||||
|
response,
|
||||||
|
{
|
||||||
|
"client_name": "Mein Client",
|
||||||
|
"scope": scope_list,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.assert_response_db_error(response)
|
||||||
|
return
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
|
||||||
|
oauth2_client = OAuth2Client.query.filter(
|
||||||
|
OAuth2Client.user_id == user_id
|
||||||
|
).first()
|
||||||
|
assert oauth2_client is not None
|
||||||
|
client_id = oauth2_client.id
|
||||||
|
|
||||||
|
utils.assert_response_redirect(response, "oauth2_client", id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("db_error", [True, False])
|
||||||
|
def test_update(client, seeder, utils, app, mocker, db_error):
|
||||||
|
user_id, admin_unit_id = seeder.setup_base(True)
|
||||||
|
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_client_update", id=oauth2_client_id)
|
||||||
|
response = utils.get_ok(url)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.mock_db_commit(mocker)
|
||||||
|
|
||||||
|
response = utils.post_form(
|
||||||
|
url,
|
||||||
|
response,
|
||||||
|
{
|
||||||
|
"client_name": "Neuer Name",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.assert_response_db_error(response)
|
||||||
|
return
|
||||||
|
|
||||||
|
utils.assert_response_redirect(response, "oauth2_client", id=oauth2_client_id)
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
|
||||||
|
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||||
|
assert oauth2_client.client_name == "Neuer Name"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("db_error", [True, False])
|
||||||
|
@pytest.mark.parametrize("non_match", [True, False])
|
||||||
|
def test_delete(client, seeder, utils, app, mocker, db_error, non_match):
|
||||||
|
user_id, admin_unit_id = seeder.setup_base(True)
|
||||||
|
oauth2_client_id = seeder.insert_default_oauth2_client(user_id)
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_client_delete", id=oauth2_client_id)
|
||||||
|
response = utils.get_ok(url)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.mock_db_commit(mocker)
|
||||||
|
|
||||||
|
form_name = "Mein Client"
|
||||||
|
|
||||||
|
if non_match:
|
||||||
|
form_name = "Falscher Name"
|
||||||
|
|
||||||
|
response = utils.post_form(
|
||||||
|
url,
|
||||||
|
response,
|
||||||
|
{
|
||||||
|
"name": form_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if non_match:
|
||||||
|
utils.assert_response_error_message(response)
|
||||||
|
return
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.assert_response_db_error(response)
|
||||||
|
return
|
||||||
|
|
||||||
|
utils.assert_response_redirect(response, "oauth2_clients")
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import OAuth2Client
|
||||||
|
|
||||||
|
oauth2_client = OAuth2Client.query.get(oauth2_client_id)
|
||||||
|
assert oauth2_client is None
|
||||||
47
tests/views/test_oauth2_token.py
Normal file
47
tests/views/test_oauth2_token.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_list(client, seeder, utils):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_tokens")
|
||||||
|
utils.get_ok(url)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("db_error", [True, False])
|
||||||
|
def test_revoke(client, seeder, utils, app, mocker, db_error):
|
||||||
|
user_id, admin_unit_id = seeder.setup_api_access()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import OAuth2Token
|
||||||
|
|
||||||
|
oauth2_token = OAuth2Token.query.filter(OAuth2Token.user_id == user_id).first()
|
||||||
|
oauth2_token_id = oauth2_token.id
|
||||||
|
|
||||||
|
url = utils.get_url("oauth2_token_revoke", id=oauth2_token_id)
|
||||||
|
response = utils.get_ok(url)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.mock_db_commit(mocker)
|
||||||
|
|
||||||
|
response = utils.post_form(
|
||||||
|
url,
|
||||||
|
response,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_error:
|
||||||
|
utils.assert_response_db_error(response)
|
||||||
|
return
|
||||||
|
|
||||||
|
utils.assert_response_redirect(response, "oauth2_tokens")
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from project.models import OAuth2Token
|
||||||
|
|
||||||
|
oauth2_token = OAuth2Token.query.get(oauth2_token_id)
|
||||||
|
assert oauth2_token.revoked
|
||||||
|
|
||||||
|
# Kann nicht zweimal revoked werden
|
||||||
|
response = utils.get(url)
|
||||||
|
utils.assert_response_redirect(response, "oauth2_tokens")
|
||||||
Loading…
x
Reference in New Issue
Block a user