diff --git a/aleksis/core/util/celery_progress.py b/aleksis/core/util/celery_progress.py index d3a4ecef5d8e121e9c0c4d513f4e02f902a90f29..abd44f16609b9a526aa760fcfed7ba50d6ea9a69 100644 --- a/aleksis/core/util/celery_progress.py +++ b/aleksis/core/util/celery_progress.py @@ -140,18 +140,24 @@ class ProgressRecorder(AbstractProgressRecorder): self.set_progress(description=message, level=level) -def recorded_task(orig: Callable, *args, **kwargs) -> app.Task: +def recorded_task(*args, **kwargs) -> Union[Callable, app.Task]: """Create a Celery task that receives a ProgressRecorder. Returns a Task object with a wrapper that passes the recorder instance as the recorder keyword argument. """ - @wraps(orig) - def _inject_recorder(task, *args, **kwargs): - recorder = ProgressRecorder(task) - return orig(*args, **kwargs, recorder=recorder) + def _real_decorator(orig: Callable) -> app.Task: + @wraps(orig) + def _inject_recorder(task, *args, **kwargs): + recorder = ProgressRecorder(task) + return orig(*args, **kwargs, recorder=recorder) - # Force bind to True because _inject_recorder needs the Task object - kwargs["bind"] = True - return app.task(_inject_recorder, *args, **kwargs) + # Force bind to True because _inject_recorder needs the Task object + kwargs["bind"] = True + return app.task(_inject_recorder, *args, **kwargs) + + if len(args) == 1 and isinstance(args[0], Callable) and not kwargs: + return _real_decorator(args[0]) + else: + return _real_decorator