From 791928134243b9fd04f9861ab8396a7861d6cd9d Mon Sep 17 00:00:00 2001
From: Jonathan Weth <git@jonathanweth.de>
Date: Sat, 2 Nov 2024 13:12:18 +0100
Subject: [PATCH] Support additional filter in with_occurences

---
 django_pg_rrule/managers.py | 47 +++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 15 deletions(-)

diff --git a/django_pg_rrule/managers.py b/django_pg_rrule/managers.py
index aab4256..d3436db 100644
--- a/django_pg_rrule/managers.py
+++ b/django_pg_rrule/managers.py
@@ -14,7 +14,12 @@ class RecurrenceManager(CTEManager):
     datetime_end_field = "datetime_end"
     until_field = "rrule_until"
 
-    def with_occurrences(self, start: datetime | None, end: datetime | None):
+    def with_occurrences(
+        self,
+        start: datetime | None = None,
+        end: datetime | None = None,
+        additional_filter: Q | None = None,
+    ):
         """Evaluate rrules and annotate all occurrences."""
 
         # Annotate occurrences of datetimes
@@ -30,6 +35,17 @@ class RecurrenceManager(CTEManager):
                 )
             )
 
+        with_qs_rdatetimes = self.filter(
+            **{
+                f"{self.datetime_start_field}__isnull": False,
+                "rdatetimes__isnull": False,
+            }
+        )
+
+        if additional_filter:
+            with_qs_datetimes = with_qs_datetimes.filter(additional_filter)
+            with_qs_rdatetimes = with_qs_rdatetimes.filter(additional_filter)
+
         cte_datetimes = With(
             with_qs_datetimes.only("id")
             # Get occurrences
@@ -46,14 +62,9 @@ class RecurrenceManager(CTEManager):
             )
             # Combine with rdatetimes
             .union(
-                self.filter(
-                    **{
-                        f"{self.datetime_start_field}__isnull": False,
-                        "rdatetimes__isnull": False,
-                    }
+                with_qs_rdatetimes.only("id").annotate(
+                    odatetime=Func(F("rdatetimes"), function="UNNEST")
                 )
-                .only("id")
-                .annotate(odatetime=Func(F("rdatetimes"), function="UNNEST"))
             ),
             name="qodatetimes",
         )
@@ -71,6 +82,17 @@ class RecurrenceManager(CTEManager):
                 )
             )
 
+        with_qs_rdates = self.filter(
+            **{
+                f"{self.date_start_field}__isnull": False,
+                "rdates__isnull": False,
+            }
+        )
+
+        if additional_filter:
+            with_qs_dates = with_qs_dates.filter(additional_filter)
+            with_qs_rdates = with_qs_rdates.filter(additional_filter)
+
         cte_dates = With(
             with_qs_dates.only("id")
             # Get occurrences
@@ -87,14 +109,9 @@ class RecurrenceManager(CTEManager):
             )
             # Combine with rdates
             .union(
-                self.filter(
-                    **{
-                        f"{self.date_start_field}__isnull": False,
-                        "rdates__isnull": False,
-                    }
+                with_qs_rdates.only("id").annotate(
+                    odate=Func(F("rdates"), function="UNNEST")
                 )
-                .only("id")
-                .annotate(odate=Func(F("rdates"), function="UNNEST"))
             ),
             name="qodates",
         )
-- 
GitLab