diff --git a/project/api/__init__.py b/project/api/__init__.py index 4ca00d0..7662634 100644 --- a/project/api/__init__.py +++ b/project/api/__init__.py @@ -9,6 +9,8 @@ from flask_marshmallow import Marshmallow from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask_apispec.extension import FlaskApiSpec +from flask import url_for +from apispec.exceptions import DuplicateComponentNameError class RestApi(Api): @@ -139,6 +141,28 @@ def add_api_resource(resource, url, endpoint): api_docs.register(resource, endpoint=endpoint) +def add_oauth2_scheme_with_transport(insecure: bool): + if insecure: + authorizationUrl = url_for("authorize", _external=True) + tokenUrl = url_for("issue_token", _external=True) + else: + authorizationUrl = url_for("authorize", _external=True, _scheme="https") + tokenUrl = url_for("issue_token", _external=True, _scheme="https") + + oauth2_scheme = { + "type": "oauth2", + "authorizationUrl": authorizationUrl, + "tokenUrl": tokenUrl, + "flow": "accessCode", + "scopes": scopes, + } + + try: + api_docs.spec.components.security_scheme("oauth2", oauth2_scheme) + except DuplicateComponentNameError: # pragma: no cover + pass + + marshmallow_plugin.converter.add_attribute_function(enum_to_properties) import project.api.event.resources diff --git a/project/init_data.py b/project/init_data.py index 28c6500..020ff27 100644 --- a/project/init_data.py +++ b/project/init_data.py @@ -1,27 +1,17 @@ from project import app, db -from project.api import api_docs, scopes +from project.api import add_oauth2_scheme_with_transport from project.services.user import upsert_user_role from project.services.admin_unit import upsert_admin_unit_member_role from project.services.event import upsert_event_category from project.models import Location -from flask import url_for -from apispec.exceptions import DuplicateComponentNameError +import os @app.before_first_request def add_oauth2_scheme(): - oauth2_scheme = { - "type": "oauth2", - "authorizationUrl": url_for("authorize", _external=True), - "tokenUrl": url_for("issue_token", _external=True), - "flow": "accessCode", - "scopes": scopes, - } - - try: - api_docs.spec.components.security_scheme("oauth2", oauth2_scheme) - except DuplicateComponentNameError: # pragma: no cover - pass + # At some sites the https scheme is not set yet + insecure = os.getenv("AUTHLIB_INSECURE_TRANSPORT", "False").lower() in ["true", "1"] + add_oauth2_scheme_with_transport(insecure) @app.before_first_request diff --git a/tests/api/test___init__.py b/tests/api/test___init__.py index 07f3d71..dd83f51 100644 --- a/tests/api/test___init__.py +++ b/tests/api/test___init__.py @@ -81,3 +81,11 @@ def test_handle_error_unspecificRaises(app): with pytest.raises(Exception): api.handle_error(error) + + +def test_add_oauth2_scheme(app, utils): + from project.api import add_oauth2_scheme_with_transport + + app.config["SERVER_NAME"] = "127.0.0.1" + with app.app_context(): + add_oauth2_scheme_with_transport(False) diff --git a/tests/conftest.py b/tests/conftest.py index 7572e7e..5a2cb77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ def pytest_generate_tests(metafunc): def app(): from project import app + app.config["SERVER_NAME"] = None app.config["TESTING"] = True app.testing = True