diff --git a/project/api/__init__.py b/project/api/__init__.py index 3de94e2..d383c50 100644 --- a/project/api/__init__.py +++ b/project/api/__init__.py @@ -2,6 +2,8 @@ from apispec import APISpec from apispec.exceptions import DuplicateComponentNameError from apispec.ext.marshmallow import MarshmallowPlugin from flask import url_for +from flask.globals import current_app +from flask.signals import got_request_exception from flask_apispec.extension import FlaskApiSpec from flask_babelex import gettext from flask_marshmallow import Marshmallow @@ -73,11 +75,12 @@ class RestApi(Api): self.fill_validation_data(err, data) # Call default error handler that propagates error further - try: - super().handle_error(err) - except Exception: - if not schema: - raise + if code >= 500: + try: + super().handle_error(err) + except Exception: + if not schema: + raise if data and "message" in data: data["message"] = gettext(data["message"]) diff --git a/tests/api/test___init__.py b/tests/api/test___init__.py index ffcf8ec..b5096f6 100644 --- a/tests/api/test___init__.py +++ b/tests/api/test___init__.py @@ -8,10 +8,12 @@ def test_handle_error_unique(app): error = make_unique_violation() - api = RestApi(app) - (data, code) = api.handle_error(error) - assert code == 400 - assert data["name"] == "Unique Violation" + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 400 + assert data["name"] == "Unique Violation" def test_handle_error_checkViolation(app): @@ -19,10 +21,12 @@ def test_handle_error_checkViolation(app): error = make_check_violation() - api = RestApi(app) - (data, code) = api.handle_error(error) - assert code == 400 - assert data["name"] == "Check Violation" + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 400 + assert data["name"] == "Check Violation" def test_handle_error_integrity(app): @@ -30,10 +34,12 @@ def test_handle_error_integrity(app): error = make_integrity_error("custom") - api = RestApi(app) - (data, code) = api.handle_error(error) - assert code == 400 - assert data["name"] == "Integrity Error" + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 400 + assert data["name"] == "Integrity Error" def test_handle_error_httpException(app): @@ -41,9 +47,11 @@ def test_handle_error_httpException(app): error = InternalServerError() - api = RestApi(app) - (data, code) = api.handle_error(error) - assert code == 500 + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 500 def test_handle_error_unprocessableEntity(app): @@ -56,11 +64,13 @@ def test_handle_error_unprocessableEntity(app): error = UnprocessableEntity() error.exc = validation_error - api = RestApi(app) - (data, code) = api.handle_error(error) - assert code == 422 - assert data["errors"][0]["field"] == "name" - assert data["errors"][0]["message"] == "Required" + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(error) + assert code == 422 + assert data["errors"][0]["field"] == "name" + assert data["errors"][0]["message"] == "Required" def test_handle_error_validationError(app): @@ -69,19 +79,24 @@ def test_handle_error_validationError(app): args = {"name": ["Required"]} validation_error = ValidationError(args) - api = RestApi(app) - (data, code) = api.handle_error(validation_error) - assert code == 422 - assert data["errors"][0]["field"] == "name" - assert data["errors"][0]["message"] == "Required" + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + (data, code) = api.handle_error(validation_error) + assert code == 422 + assert data["errors"][0]["field"] == "name" + assert data["errors"][0]["message"] == "Required" def test_handle_error_unspecificRaises(app): error = Exception() - api = RestApi(app) - with pytest.raises(Exception): - api.handle_error(error) + with app.app_context(): + app.config["PROPAGATE_EXCEPTIONS"] = False + api = RestApi(app) + + with pytest.raises(Exception): + api.handle_error(error) def test_add_oauth2_scheme(app, utils):