diff --git a/aleksis/core/util/celery_progress.py b/aleksis/core/util/celery_progress.py index 378f63b8f55290bd1a140f508e06d2b5e9e1f4e8..ac48c7c56001a51696c8099950ed799c6a2a9cea 100644 --- a/aleksis/core/util/celery_progress.py +++ b/aleksis/core/util/celery_progress.py @@ -21,9 +21,9 @@ class ProgressRecorder(AbstractProgressRecorder): from django.contrib import messages - from aleksis.core.util.celery_progress import ProgressRecorder + from aleksis.core.util.celery_progress import recorded_task - @ProgressRecorder.recorded_task + @recorded_task def do_something(foo, bar, recorder, baz=None): # ... recorder.total = len(list_with_data) @@ -98,16 +98,19 @@ class ProgressRecorder(AbstractProgressRecorder): self.messages.append((level, message)) self.set_progress(self.current) - @classmethod - def recorded_task(cls, orig: Callable) -> app.Task: - """Create a Celery task that receives a ProgressRecorder. - Returns a Task object with a wrapper that passes the recorder instance - as the recorder keyword argument. - """ - @wraps(orig) - def _inject_recorder(task, *args, **kwargs): - recorder = ProgressRecorder(task) - return orig(*args, **kwargs, recorder=recorder) +def recorded_task(orig: Callable, *args, **kwargs) -> app.Task: + """Create a Celery task that receives a ProgressRecorder. + + Returns a Task object with a wrapper that passes the recorder instance + as the recorder keyword argument. + """ + + @wraps(orig) + def _inject_recorder(task, *args, **kwargs): + recorder = ProgressRecorder(task) + return orig(*args, **kwargs, recorder=recorder) - return app.task(_inject_recorder, bind=True) + # Force bind to True because _inject_recorder needs the Task object + kwargs["bind"] = True + return app.task(_inject_recorder, *args, **kwargs)