diff --git a/citadel/README.md b/citadel/README.md index 1f0552b..1966da0 100644 --- a/citadel/README.md +++ b/citadel/README.md @@ -5,6 +5,10 @@ to provide advanced search functionality using an Elasticsearch backend. ## Changelog +### 3.1 + +- Correctly handle remote groups whose capitalization changed at some point + ### 3.0 - Initial release diff --git a/citadel/indico_citadel/util.py b/citadel/indico_citadel/util.py index 990fe48..01b5293 100644 --- a/citadel/indico_citadel/util.py +++ b/citadel/indico_citadel/util.py @@ -8,11 +8,14 @@ import re import sys import threading +from collections import defaultdict from functools import wraps from flask import current_app from flask.globals import _app_ctx_stack +from indico.core.db import db +from indico.core.db.sqlalchemy.principals import PrincipalMixin, PrincipalPermissionsMixin, PrincipalType from indico.modules.groups import GroupProxy from indico.util.caching import memoize_redis @@ -174,6 +177,31 @@ def _flatten(obj, target_key='buckets', parent_key=''): yield from _flatten(value, target_key, f'{parent_key}_{key}' if parent_key else key) +@memoize_redis(86400) +def _get_alternative_group_names(): + """Get non-lowercase versions of group names.""" + classes = [sc for sc in [*PrincipalMixin.__subclasses__(), *PrincipalPermissionsMixin.__subclasses__()] + if hasattr(sc, 'query')] + alternatives = defaultdict(set) + for cls in classes: + res = (db.session.query(cls.multipass_group_provider, cls.multipass_group_name) + .distinct() + .filter(cls.type == PrincipalType.multipass_group, + cls.multipass_group_name != db.func.lower(cls.multipass_group_name)) + .all()) + for provider, name in res: + alternatives[(provider, name.lower())].add(name) + return dict(alternatives) + + +def _include_capitalized_groups(groups): + alternatives = _get_alternative_group_names() + for group in groups: + yield group.identifier + for alt_name in alternatives.get((group.provider, group.name.lower()), ()): + yield GroupProxy(alt_name, group.provider).identifier + + @memoize_redis(3600) def get_user_access(user, admin_override_enabled=False): if not user: @@ -183,6 +211,7 @@ def get_user_access(user, admin_override_enabled=False): access = [user.identifier] + [u.identifier for u in user.get_merged_from_users_recursive()] access += [GroupProxy(x.id, _group=x).identifier for x in user.local_groups] if user.can_get_all_multipass_groups: - access += [GroupProxy(x.name, x.provider.name, x).identifier - for x in user.iter_all_multipass_groups()] + multipass_groups = [GroupProxy(x.name, x.provider.name, x) + for x in user.iter_all_multipass_groups()] + access += _include_capitalized_groups(multipass_groups) return access