From 996d7a9761dcc6665249867057444f634d7bf73c Mon Sep 17 00:00:00 2001
From: Hangzhi Yu <hangzhi@protonmail.com>
Date: Wed, 19 May 2021 10:34:52 +0200
Subject: [PATCH] Filtering by rules

Reformat
---
 aleksis/core/util/core_helpers.py | 59 ++-----------------------------
 aleksis/core/views.py             | 52 ++++++++++++++++-----------
 2 files changed, 33 insertions(+), 78 deletions(-)

diff --git a/aleksis/core/util/core_helpers.py b/aleksis/core/util/core_helpers.py
index 04319fa61..46680d08a 100644
--- a/aleksis/core/util/core_helpers.py
+++ b/aleksis/core/util/core_helpers.py
@@ -2,10 +2,10 @@ from datetime import datetime, timedelta
 from importlib import import_module, metadata
 from itertools import groupby
 from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 from django.conf import settings
-from django.db.models import Model, Q, QuerySet
+from django.db.models import Model, QuerySet
 from django.http import HttpRequest
 from django.shortcuts import get_object_or_404
 from django.utils import timezone
@@ -13,11 +13,6 @@ from django.utils.functional import lazy
 
 from cache_memoize import cache_memoize
 
-if TYPE_CHECKING:
-    from django.contrib.auth import get_user_model
-
-    User = get_user_model()  # noqa
-
 
 def copyright_years(years: Sequence[int], seperator: str = ", ", joiner: str = "–") -> str:
     """Take a sequence of integegers and produces a string with ranges.
@@ -254,53 +249,3 @@ def queryset_rules_filter(
 def unread_notifications_badge(request: HttpRequest) -> int:
     """Generate badge content with the number of unread notifications."""
     return request.user.person.unread_notifications_count
-
-
-def get_persons_for_user(user: "User"):
-    """Get all persons the given user is allowed to view."""
-
-    from guardian.core import ObjectPermissionChecker
-
-    from ..models import Person
-    from .predicates import check_global_permission
-
-    checker = ObjectPermissionChecker(user)
-
-    allowed_persons = Person.objects.filter(is_active=True)
-    if not check_global_permission(user, "core.view_person"):
-        checker.prefetch_perms(allowed_persons)
-
-        obj_perm_persons = set()
-
-        for allowed_person in allowed_persons:
-            if checker.has_perm("core.view_person", allowed_person):
-                obj_perm_persons.add(allowed_person.pk)
-
-        allowed_persons = allowed_persons.filter(Q(pk__in=obj_perm_persons) | Q(pk=user.person.pk))
-
-    return allowed_persons
-
-
-def get_groups_for_user(user: "User"):
-    """Get all groups the given user is allowed to view."""
-
-    from guardian.core import ObjectPermissionChecker
-
-    from ..models import Group
-    from .predicates import check_global_permission
-
-    checker = ObjectPermissionChecker(user)
-
-    allowed_groups = Group.objects.all()
-    if not check_global_permission(user, "core.view_group"):
-        checker.prefetch_perms(allowed_groups)
-
-        obj_perm_groups = set()
-
-        for allowed_group in allowed_groups:
-            if checker.has_perm("core.view_group", allowed_group):
-                obj_perm_groups.add(allowed_group.pk)
-
-        allowed_groups = allowed_groups.filter(pk__in=obj_perm_groups)
-
-    return allowed_groups
diff --git a/aleksis/core/views.py b/aleksis/core/views.py
index d1d11023e..edc67b120 100644
--- a/aleksis/core/views.py
+++ b/aleksis/core/views.py
@@ -28,6 +28,7 @@ from guardian.shortcuts import get_objects_for_user
 from haystack.generic_views import SearchView
 from haystack.inputs import AutoQuery
 from haystack.query import SearchQuerySet
+from haystack.utils.loading import UnifiedIndex
 from health_check.views import MainView
 from oauth2_provider.models import Application
 from reversion import set_user
@@ -86,11 +87,10 @@ from .util import messages
 from .util.apps import AppConfig
 from .util.celery_progress import render_progress_page
 from .util.core_helpers import (
-    get_groups_for_user,
-    get_persons_for_user,
     get_site_preferences,
     has_person,
     objectgetter_optional,
+    queryset_rules_filter,
 )
 from .util.forms import PreferenceLayout
 from .util.pdf import render_pdf
@@ -565,18 +565,22 @@ def searchbar_snippets(request: HttpRequest) -> HttpResponse:
     query = request.GET.get("q", "")
     limit = int(request.GET.get("limit", "5"))
 
-    allowed_person_ids = [
-        f"core.person.{pk}"
-        for pk in get_persons_for_user(request.user).values_list("pk", flat=True)
-    ]
-    allowed_group_ids = [
-        f"core.group.{pk}" for pk in get_groups_for_user(request.user).values_list("pk", flat=True)
-    ]
+    indexed_models = UnifiedIndex().get_indexed_models()
+
+    allowed_object_ids = []
+
+    for model in indexed_models:
+        app_label = ContentType.objects.get_for_model(model).app_label
+        model_name = ContentType.objects.get_for_model(model).model
+        allowed_object_ids += [
+            f"{app_label}.{model_name}.{pk}"
+            for pk in queryset_rules_filter(
+                request, model.objects.all(), f"{app_label}.view_{model_name}"
+            ).values_list("pk", flat=True)
+        ]
 
     results = (
-        SearchQuerySet()
-        .filter(id__in=allowed_person_ids + allowed_group_ids)
-        .filter(text=AutoQuery(query))[:limit]
+        SearchQuerySet().filter(id__in=allowed_object_ids).filter(text=AutoQuery(query))[:limit]
     )
 
     context = {"results": results}
@@ -592,16 +596,22 @@ class PermissionSearchView(PermissionRequiredMixin, SearchView):
     def get_context_data(self, *, object_list=None, **kwargs):
         queryset = object_list if object_list is not None else self.object_list
 
-        allowed_person_ids = [
-            f"core.person.{pk}"
-            for pk in get_persons_for_user(self.request.user).values_list("pk", flat=True)
-        ]
-        allowed_group_ids = [
-            f"core.group.{pk}"
-            for pk in get_groups_for_user(self.request.user).values_list("pk", flat=True)
-        ]
+        indexed_models = UnifiedIndex().get_indexed_models()
+
+        allowed_object_ids = []
+
+        for model in indexed_models:
+
+            app_label = ContentType.objects.get_for_model(model).app_label
+            model_name = ContentType.objects.get_for_model(model).model
+            allowed_object_ids += [
+                f"{app_label}.{model_name}.{pk}"
+                for pk in queryset_rules_filter(
+                    self.request, model.objects.all(), f"{app_label}.view_{model_name}"
+                ).values_list("pk", flat=True)
+            ]
 
-        queryset = queryset.filter(id__in=allowed_person_ids + allowed_group_ids)
+        queryset = queryset.filter(id__in=allowed_object_ids)
 
         return super().get_context_data(object_list=queryset, **kwargs)
 
-- 
GitLab