2023-04-16 23:04:39 +02:00

87 lines
2.5 KiB
Python

import time
from authlib.integrations.sqla_oauth2 import (
OAuth2AuthorizationCodeMixin,
OAuth2ClientMixin,
OAuth2TokenMixin,
)
from flask import request
from sqlalchemy.orm import object_session
from project import db
# OAuth Server: Wir bieten an, dass sich ein Nutzer per OAuth2 auf unserer Seite anmeldet
class OAuth2Client(db.Model, OAuth2ClientMixin):
    """An OAuth2 client application registered with this server.

    Every client is owned by a user. The foreign key to ``user.id`` is
    declared with ``ondelete="CASCADE"``. The grant types, response types and
    token-endpoint authentication methods are hard-coded below, overriding
    whatever the Authlib mixin would derive from stored metadata. So all
    clients share the same capabilities.
    """

    __tablename__ = "oauth2_client"

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
    user = db.relationship("User")

    @OAuth2ClientMixin.grant_types.getter
    def grant_types(self):
        """Grant types every client may use: authorization code and refresh token."""
        return ["authorization_code", "refresh_token"]

    @OAuth2ClientMixin.response_types.getter
    def response_types(self):
        """Response types every client may request: only ``code``."""
        return ["code"]

    @OAuth2ClientMixin.token_endpoint_auth_method.getter
    def token_endpoint_auth_method(self):
        """All authentication methods accepted at the token endpoint.

        This is a *list* of method names, not a single name. That is why
        ``check_token_endpoint_auth_method`` tests membership instead of equality.
        ``"none"`` is included, so clients without a secret are accepted too.
        """
        return ["client_secret_basic", "client_secret_post", "none"]

    def check_redirect_uri(self, redirect_uri):
        """Return whether ``redirect_uri`` is an acceptable redirect target.

        Any URI on this server's own host (prefix ``request.host_url``) is
        accepted. Anything else is delegated to the mixin's check against the
        client's registered redirect URIs.
        """
        return redirect_uri.startswith(request.host_url) or super().check_redirect_uri(
            redirect_uri
        )

    def check_token_endpoint_auth_method(self, method):
        """Return whether ``method`` is one of the allowed token-endpoint auth methods."""
        allowed_methods = self.token_endpoint_auth_method
        return method in allowed_methods

    def check_endpoint_auth_method(self, method, endpoint):
        """Return whether ``method`` may be used to authenticate at ``endpoint``.

        Only the ``"token"`` endpoint is restricted. For any other endpoint
        name, every method is accepted.
        """
        if endpoint != "token":
            return True
        return self.check_token_endpoint_auth_method(method)
class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
    """Persisted authorization code for the OAuth2 authorization-code grant.

    Only the primary key and the owning user are declared here. The
    code-specific columns (code value, client id, redirect URI, scope, ...)
    are presumably contributed by Authlib's ``OAuth2AuthorizationCodeMixin``.
    They are not visible in this file; verify against the installed Authlib
    version.
    """

    __tablename__ = "oauth2_code"

    id = db.Column(db.Integer, primary_key=True)
    # Owning user. ondelete="CASCADE" asks the database to remove codes when
    # their user row is deleted (requires FK enforcement by the DB engine).
    user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
    user = db.relationship("User")
class OAuth2Token(db.Model, OAuth2TokenMixin):
    """An issued OAuth2 token (access token plus optional refresh token) for a user.

    The foreign key to ``user.id`` is declared with ``ondelete="CASCADE"``.
    Token columns such as ``client_id``, ``issued_at``, ``expires_in`` and
    ``access_token_revoked_at`` are read below but declared by Authlib's
    ``OAuth2TokenMixin``.
    """

    __tablename__ = "oauth2_token"

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey("user.id", ondelete="CASCADE"))
    user = db.relationship("User")

    @property
    def client(self):
        """The ``OAuth2Client`` whose ``client_id`` matches this token, or ``None``.

        The query runs in the session this token instance belongs to.
        A token that is not attached to a session has no such session,
        so this lookup would fail for it.
        """
        session = object_session(self)
        matching_clients = session.query(OAuth2Client).filter(
            OAuth2Client.client_id == self.client_id
        )
        return matching_clients.first()

    @property
    def expires_at(self):
        """Expiry time in epoch seconds, computed as ``issued_at + expires_in``."""
        return self.issued_at + self.expires_in

    def is_refresh_token_active(self):
        """Return whether the refresh token of this token may still be used.

        It is inactive once ``is_revoked()`` (from the mixin) reports a
        revocation, or once ``expires_at`` lies in the past. In other words,
        the refresh token lives exactly as long as the access token.

        NOTE(review): with ``expires_in == 0`` this returns False almost
        immediately after issuance. Confirm that no such tokens are issued.
        """
        return not self.is_revoked() and self.expires_at >= time.time()

    def revoke_token(self):
        """Record revocation by stamping ``access_token_revoked_at`` with the current time.

        The timestamp is whole epoch seconds. ``refresh_token_revoked_at`` is
        not written. ``is_refresh_token_active`` still treats the token as
        revoked, presumably because the mixin's ``is_revoked()`` consults
        ``access_token_revoked_at``. Verify against the installed Authlib
        version.
        """
        revoked_at = int(time.time())
        self.access_token_revoked_at = revoked_at