diff --git a/aleksis/core/tasks.py b/aleksis/core/tasks.py index f50300a9d52d3829065eca4bd14a12b88641941e..2ffa51ff55a340911a15acbad0cdcf66ff5b1942 100644 --- a/aleksis/core/tasks.py +++ b/aleksis/core/tasks.py @@ -1,8 +1,6 @@ from django.conf import settings from django.core import management -from celery.decorators import task - from .celery import app from .util.celery_progress import ProgressRecorder from .util.notifications import send_notification as _send_notification diff --git a/aleksis/core/util/celery_progress.py b/aleksis/core/util/celery_progress.py index 9de760eeb56b8ce535a0cd8e88f7380c5411f78d..d3a4ecef5d8e121e9c0c4d513f4e02f902a90f29 100644 --- a/aleksis/core/util/celery_progress.py +++ b/aleksis/core/util/celery_progress.py @@ -1,34 +1,157 @@ -from decimal import Decimal -from typing import Union +from functools import wraps +from numbers import Number +from typing import Callable, Generator, Iterable, Optional, Sequence, Union + +from django.contrib import messages from celery_progress.backend import PROGRESS_STATE, AbstractProgressRecorder +from ..celery import app + class ProgressRecorder(AbstractProgressRecorder): + """Track the progress of a Celery task and give data to the frontend. + + This recorder provides the functions `set_progress` and `add_message` + which can be used to track the status of a Celery task. + + How to use + ---------- + 1. Write a function and include tracking methods + + :: + + from django.contrib import messages + + from aleksis.core.util.celery_progress import recorded_task + + @recorded_task + def do_something(foo, bar, recorder, baz=None): + # ... + recorder.set_progress(total=len(list_with_data)) + + for i, item in enumerate(list_with_data): + # ... + recorder.set_progress(i+1) + # ... + + recorder.add_message(messages.SUCCESS, "All data were imported successfully.") + + You can also use `recorder.iterate` to simplify iterating and counting. + + 2. Track progress in view: + + :: + + def my_view(request): + context = {} + # ... + result = do_something.delay(foo, bar, baz=baz) + + context = { + "title": _("Progress: Import data"), + "back_url": reverse("index"), + "progress": { + "task_id": result.task_id, + "title": _("Import objects …"), + "success": _("The import was done successfully."), + "error": _("There was a problem while importing data."), + }, + } + + # Render progress view + return render(request, "core/progress.html", context) + """ + def __init__(self, task): self.task = task - self.messages = [] - self.total = 100 - self.current = 0 + self._messages = [] + self._current = 0 + self._total = 100 + + def iterate(self, data: Union[Iterable, Sequence], total: Optional[int] = None) -> Generator: + """Iterate over a sequence or iterable, updating progress on the move. + + :: + + @recorded_task + def do_something(long_list, recorder): + for item in recorder.iterate(long_list): + do_something_with(item) + + :param data: A sequence (tuple, list, set,...) or an iterable + :param total: Total number of items, in case data does not support len() + """ + if total is None and hasattr(data, "__len__"): + total = len(data) + else: + raise TypeError("No total value passed, and data does not support len()") + + for current, item in enumerate(data): + self.set_progress(current, total) + yield item + + def set_progress( + self, + current: Optional[Number] = None, + total: Optional[Number] = None, + description: Optional[str] = None, + level: int = messages.INFO, + ): + """Set the current progress in the frontend. + + The progress percentage is automatically calculated in relation to self.total. - def set_progress(self, current: Union[int, float], **kwargs): - self.current = current + :param current: The number of processed items; relative to total, default unchanged + :param total: The total number of items (or 100 if using a percentage), default unchanged + :param description: A textual description, routed to the frontend as an INFO message + """ + if current is not None: + self._current = current + if total is not None: + self._total = total percent = 0 - if self.total > 0: - percent = (Decimal(current) / Decimal(self.total)) * Decimal(100) - percent = float(round(percent, 2)) + if self._total > 0: + percent = self._current / self._total + + if description is not None: + self._messages.append((level, description)) self.task.update_state( state=PROGRESS_STATE, meta={ - "current": current, - "total": self.total, + "current": self._current, + "total": self._total, "percent": percent, - "messages": self.messages, + "messages": self._messages, }, ) - def add_message(self, level: int, message: str, **kwargs): - self.messages.append((level, message)) - self.set_progress(self.current) + def add_message(self, level: int, message: str) -> None: + """Show a message in the progress frontend. + + This method is a shortcut for set_progress with no new progress arguments, + passing only the message and level as description. + + :param level: The message level (default levels from django.contrib.messages) + :param message: The actual message (should be translated) + """ + self.set_progress(description=message, level=level) + + +def recorded_task(orig: Callable, *args, **kwargs) -> app.Task: + """Create a Celery task that receives a ProgressRecorder. + + Returns a Task object with a wrapper that passes the recorder instance + as the recorder keyword argument. + """ + + @wraps(orig) + def _inject_recorder(task, *args, **kwargs): + recorder = ProgressRecorder(task) + return orig(*args, **kwargs, recorder=recorder) + + # Force bind to True because _inject_recorder needs the Task object + kwargs["bind"] = True + return app.task(_inject_recorder, *args, **kwargs)