diff --git a/citadel/indico_citadel/backend.py b/citadel/indico_citadel/backend.py index 073b939..9109ec8 100644 --- a/citadel/indico_citadel/backend.py +++ b/citadel/indico_citadel/backend.py @@ -310,8 +310,8 @@ class LiveSyncCitadelBackend(LiveSyncBackendBase): query = super().get_data_query(model_cls, ids) return query.options(joinedload(model_cls.citadel_id_mapping)) - def process_queue(self, uploader): - super().process_queue(uploader) + def process_queue(self, uploader, allowed_categories=()): + super().process_queue(uploader, allowed_categories) uploader_name = type(uploader).__name__ self.plugin.logger.info(f'{uploader_name} starting file upload') total, errors, aborted = self.run_export_files(verbose=False) diff --git a/livesync/indico_livesync/base.py b/livesync/indico_livesync/base.py index 6274c45..2771dac 100644 --- a/livesync/indico_livesync/base.py +++ b/livesync/indico_livesync/base.py @@ -97,13 +97,16 @@ class LiveSyncBackendBase: return True, None return False, 'initial export not performed' - def fetch_records(self): + def fetch_records(self, allowed_categories=()): query = (self.agent.queue .filter(~LiveSyncQueueEntry.processed) .order_by(LiveSyncQueueEntry.timestamp)) if LiveSyncPlugin.settings.get('skip_category_changes'): LiveSyncPlugin.logger.warning('Category changes are currently being skipped') - query = query.filter(LiveSyncQueueEntry.type != EntryType.category) + whitelist_filter = False + if allowed_categories: + whitelist_filter = LiveSyncQueueEntry.category_id.in_(allowed_categories) + query = query.filter((LiveSyncQueueEntry.type != EntryType.category) | whitelist_filter) return query.all() def update_last_run(self): @@ -113,20 +116,20 @@ class LiveSyncBackendBase: """ self.agent.last_run = now_utc() - def process_queue(self, uploader): + def process_queue(self, uploader, allowed_categories=()): """Process queued entries during an export run.""" - records = self.fetch_records() + records = self.fetch_records(allowed_categories) LiveSyncPlugin.logger.info(f'Uploading %d records via {self.uploader.__name__}', len(records)) uploader.run(records) - def run(self, verbose=False, from_cli=False): + def run(self, verbose=False, from_cli=False, allowed_categories=()): """Runs the livesync export""" if self.uploader is None: # pragma: no cover raise NotImplementedError uploader = self.uploader(self, verbose=verbose, from_cli=from_cli) self._precache_categories() - self.process_queue(uploader) + self.process_queue(uploader, allowed_categories) self.update_last_run() def get_initial_query(self, model_cls, force): diff --git a/livesync/indico_livesync/cli.py b/livesync/indico_livesync/cli.py index 1eb8146..8c2a976 100644 --- a/livesync/indico_livesync/cli.py +++ b/livesync/indico_livesync/cli.py @@ -116,11 +116,20 @@ def initial_export(agent_id, batch, force, verbose, retry): @click.argument('agent_id', type=int, required=False) @click.option('--force', '-f', is_flag=True, help="Run even if initial export was not done") @click.option('--verbose', '-v', is_flag=True, help="Be more verbose (what this does is up to the backend)") -def run(agent_id, force, verbose): +@click.option('--allow-category', '-c', 'allowed_categories', multiple=True, type=int, + help="Process changes for the specified category id even if 'Skip category changes' is enabled. " + "This setting can be used multiple times.") +def run(agent_id, force, verbose, allowed_categories): """Runs the livesync agent""" from indico_livesync.plugin import LiveSyncPlugin + if LiveSyncPlugin.settings.get('disable_queue_runs'): print(cformat('%{yellow!}Queue runs are disabled%{reset}')) + if LiveSyncPlugin.settings.get('skip_category_changes'): + print(cformat('%{yellow!}Category changes are currently being skipped%{reset}')) + if allowed_categories: + print(cformat('Whitelisted categories: %{green}{}%{reset}') + .format(', '.join(map(str, sorted(allowed_categories))))) if agent_id is None: agent_list = LiveSyncAgent.query.all() @@ -142,7 +151,7 @@ def run(agent_id, force, verbose): continue print(cformat('Running agent: %{white!}{}%{reset}').format(agent.name)) try: - backend.run(verbose, from_cli=True) + backend.run(verbose, from_cli=True, allowed_categories=allowed_categories) db.session.commit() except Exception: db.session.rollback() diff --git a/livesync/tests/agent_test.py b/livesync/tests/agent_test.py index e0de6e8..31ee421 100644 --- a/livesync/tests/agent_test.py +++ b/livesync/tests/agent_test.py @@ -7,8 +7,11 @@ from unittest.mock import MagicMock +import pytest + from indico_livesync.base import LiveSyncBackendBase from indico_livesync.models.queue import ChangeType, EntryType, LiveSyncQueueEntry +from indico_livesync.plugin import LiveSyncPlugin class DummyBackend(LiveSyncBackendBase): @@ -54,3 +57,20 @@ def test_fetch_records(db, dummy_event, dummy_agent): dummy_agent.queue = queue db.session.flush() assert backend.fetch_records() == [queue[1]] + + +@pytest.mark.parametrize('disabled', (True, False)) +@pytest.mark.parametrize('whitelisted', (True, False)) +def test_fetch_records_categories_disabled(db, dummy_event, dummy_category, dummy_agent, disabled, whitelisted): + """Test if the correct records are fetched""" + backend = DummyBackend(dummy_agent) + queue = [ + LiveSyncQueueEntry(change=ChangeType.protection_changed, type=EntryType.category, category=dummy_category), + LiveSyncQueueEntry(change=ChangeType.created, type=EntryType.event, event=dummy_event) + ] + dummy_agent.queue = queue + LiveSyncPlugin.settings.set('skip_category_changes', disabled) + db.session.flush() + expected = queue[1:] if disabled and not whitelisted else queue + whitelist = (dummy_category.id,) if whitelisted else () + assert backend.fetch_records(whitelist) == expected diff --git a/livesync_debug/indico_livesync_debug/backend.py b/livesync_debug/indico_livesync_debug/backend.py index 904f185..cb17f2f 100644 --- a/livesync_debug/indico_livesync_debug/backend.py +++ b/livesync_debug/indico_livesync_debug/backend.py @@ -77,8 +77,8 @@ class LiveSyncDebugBackend(LiveSyncBackendBase): uploader = DebugUploader - def process_queue(self, uploader): - records = self.fetch_records() + def process_queue(self, uploader, allowed_categories=()): + records = self.fetch_records(allowed_categories) if not records: print(cformat('%{yellow!}No records%{reset}')) return