diff --git a/livesync/indico_livesync/initial.py b/livesync/indico_livesync/initial.py index 536d236..594bb88 100644 --- a/livesync/indico_livesync/initial.py +++ b/livesync/indico_livesync/initial.py @@ -41,7 +41,7 @@ def apply_acl_entry_strategy(rel, principal): def _get_excluded_category_filter(event_model=Event): - if excluded_category_ids := get_excluded_categories(): + if excluded_category_ids := get_excluded_categories(deep=True): return event_model.category_id.notin_(excluded_category_ids) return True diff --git a/livesync/indico_livesync/simplify.py b/livesync/indico_livesync/simplify.py index 48de859..30af2d5 100644 --- a/livesync/indico_livesync/simplify.py +++ b/livesync/indico_livesync/simplify.py @@ -135,7 +135,7 @@ def _process_cascaded_category_contents(records): :param records: queue records to process """ - excluded_categories = get_excluded_categories() + excluded_categories = get_excluded_categories(deep=True) excluded_categories_filter = Event.category_id.notin_(excluded_categories) if excluded_categories else True category_prot_records = {rec.category_id for rec in records if rec.type == EntryType.category diff --git a/livesync/indico_livesync/util.py b/livesync/indico_livesync/util.py index 8571a11..8555a61 100644 --- a/livesync/indico_livesync/util.py +++ b/livesync/indico_livesync/util.py @@ -9,8 +9,9 @@ from datetime import timedelta from werkzeug.datastructures import ImmutableDict +from indico.core.db import db from indico.modules.attachments.models.attachments import Attachment -from indico.modules.categories.models.categories import Category +from indico.modules.categories import Category from indico.modules.events import Event from indico.modules.events.contributions.models.contributions import Contribution from indico.modules.events.contributions.models.subcontributions import SubContribution @@ -78,7 +79,17 @@ def clean_old_entries(): @memoize_request -def get_excluded_categories(): - """Get excluded category IDs.""" +def get_excluded_categories(*, deep=False): + """Get excluded category IDs. + + :param deep: Whether to get all subcategory ids as well + """ from indico_livesync.plugin import LiveSyncPlugin - return {int(x['id']) for x in LiveSyncPlugin.settings.get('excluded_categories')} + ids = {int(x['id']) for x in LiveSyncPlugin.settings.get('excluded_categories')} + if not deep or not ids: + return ids + cte = Category.get_tree_cte() + query = (db.session.query(Category.id) + .join(cte, Category.id == cte.c.id) + .filter(cte.c.path.overlap(ids), ~cte.c.is_deleted)) + return {x.id for x in query} diff --git a/livesync/tests/util_test.py b/livesync/tests/util_test.py index 281ca0e..af80adc 100644 --- a/livesync/tests/util_test.py +++ b/livesync/tests/util_test.py @@ -11,7 +11,7 @@ from indico.util.date_time import now_utc from indico_livesync.models.queue import ChangeType, EntryType, LiveSyncQueueEntry from indico_livesync.plugin import LiveSyncPlugin -from indico_livesync.util import clean_old_entries +from indico_livesync.util import clean_old_entries, get_excluded_categories def test_clean_old_entries(dummy_event, db, dummy_agent): @@ -34,3 +34,17 @@ def test_clean_old_entries(dummy_event, db, dummy_agent): clean_old_entries() assert LiveSyncQueueEntry.query.filter_by(processed=False).count() == 10 assert LiveSyncQueueEntry.query.filter_by(processed=True).count() == 3 + + +def test_get_excluded_categories(dummy_category, create_category): + cat1 = create_category() + cat1a = create_category(parent=cat1) + cat1aa = create_category(parent=cat1a) + cat1b = create_category(parent=cat1) + cat2 = create_category() + LiveSyncPlugin.settings.set('excluded_categories', [{'id': cat1.id}]) + assert get_excluded_categories() == {cat1.id} + assert get_excluded_categories(deep=True) == {cat1.id, cat1a.id, cat1aa.id, cat1b.id} + LiveSyncPlugin.settings.set('excluded_categories', [{'id': 0}]) + assert get_excluded_categories() == {0} + assert get_excluded_categories(deep=True) == {cat1.id, cat1a.id, cat1aa.id, cat1b.id, cat2.id, dummy_category.id, 0}