From 1a0e026138e9139f68a1152e87f9de9d6933b748 Mon Sep 17 00:00:00 2001
From: Hangzhi Yu <hangzhi@protonmail.com>
Date: Sun, 16 May 2021 16:47:13 +0200
Subject: [PATCH] Implement permission filtering

---
 aleksis/core/util/core_helpers.py | 59 +++++++++++++++++++++++++++++--
 aleksis/core/views.py             | 42 ++++++++++++++++++++--
 2 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/aleksis/core/util/core_helpers.py b/aleksis/core/util/core_helpers.py
index 46680d08a..04319fa61 100644
--- a/aleksis/core/util/core_helpers.py
+++ b/aleksis/core/util/core_helpers.py
@@ -2,10 +2,10 @@ from datetime import datetime, timedelta
 from importlib import import_module, metadata
 from itertools import groupby
 from operator import itemgetter
-from typing import Any, Callable, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union
 
 from django.conf import settings
-from django.db.models import Model, QuerySet
+from django.db.models import Model, Q, QuerySet
 from django.http import HttpRequest
 from django.shortcuts import get_object_or_404
 from django.utils import timezone
@@ -13,6 +13,11 @@ from django.utils.functional import lazy
 
 from cache_memoize import cache_memoize
 
+if TYPE_CHECKING:
+    from django.contrib.auth import get_user_model
+
+    User = get_user_model()  # noqa
+
 
 def copyright_years(years: Sequence[int], seperator: str = ", ", joiner: str = "–") -> str:
     """Take a sequence of integegers and produces a string with ranges.
@@ -249,3 +254,53 @@ def queryset_rules_filter(
 def unread_notifications_badge(request: HttpRequest) -> int:
     """Generate badge content with the number of unread notifications."""
     return request.user.person.unread_notifications_count
+
+
+def get_persons_for_user(user: "User"):
+    """Get all persons the given user is allowed to view."""
+
+    from guardian.core import ObjectPermissionChecker
+
+    from ..models import Person
+    from .predicates import check_global_permission
+
+    checker = ObjectPermissionChecker(user)
+
+    allowed_persons = Person.objects.filter(is_active=True)
+    if not check_global_permission(user, "core.view_person"):
+        checker.prefetch_perms(allowed_persons)
+
+        obj_perm_persons = set()
+
+        for allowed_person in allowed_persons:
+            if checker.has_perm("core.view_person", allowed_person):
+                obj_perm_persons.add(allowed_person.pk)
+
+        allowed_persons = allowed_persons.filter(Q(pk__in=obj_perm_persons) | Q(pk=user.person.pk))
+
+    return allowed_persons
+
+
+def get_groups_for_user(user: "User"):
+    """Get all groups the given user is allowed to view."""
+
+    from guardian.core import ObjectPermissionChecker
+
+    from ..models import Group
+    from .predicates import check_global_permission
+
+    checker = ObjectPermissionChecker(user)
+
+    allowed_groups = Group.objects.all()
+    if not check_global_permission(user, "core.view_group"):
+        checker.prefetch_perms(allowed_groups)
+
+        obj_perm_groups = set()
+
+        for allowed_group in allowed_groups:
+            if checker.has_perm("core.view_group", allowed_group):
+                obj_perm_groups.add(allowed_group.pk)
+
+        allowed_groups = allowed_groups.filter(pk__in=obj_perm_groups)
+
+    return allowed_groups
diff --git a/aleksis/core/views.py b/aleksis/core/views.py
index 5dd11a398..d1d11023e 100644
--- a/aleksis/core/views.py
+++ b/aleksis/core/views.py
@@ -23,10 +23,11 @@ from celery_progress.views import get_progress
 from django_celery_results.models import TaskResult
 from django_tables2 import RequestConfig, SingleTableView
 from dynamic_preferences.forms import preference_form_builder
+from guardian.core import ObjectPermissionChecker
 from guardian.shortcuts import get_objects_for_user
+from haystack.generic_views import SearchView
 from haystack.inputs import AutoQuery
 from haystack.query import SearchQuerySet
-from haystack.generic_views import SearchView
 from health_check.views import MainView
 from oauth2_provider.models import Application
 from reversion import set_user
@@ -84,7 +85,13 @@ from .tables import (
 from .util import messages
 from .util.apps import AppConfig
 from .util.celery_progress import render_progress_page
-from .util.core_helpers import get_site_preferences, has_person, objectgetter_optional
+from .util.core_helpers import (
+    get_groups_for_user,
+    get_persons_for_user,
+    get_site_preferences,
+    has_person,
+    objectgetter_optional,
+)
 from .util.forms import PreferenceLayout
 from .util.pdf import render_pdf
 
@@ -558,7 +565,20 @@ def searchbar_snippets(request: HttpRequest) -> HttpResponse:
     query = request.GET.get("q", "")
     limit = int(request.GET.get("limit", "5"))
 
-    results = SearchQuerySet().filter(text=AutoQuery(query))[:limit]
+    allowed_person_ids = [
+        f"core.person.{pk}"
+        for pk in get_persons_for_user(request.user).values_list("pk", flat=True)
+    ]
+    allowed_group_ids = [
+        f"core.group.{pk}" for pk in get_groups_for_user(request.user).values_list("pk", flat=True)
+    ]
+
+    results = (
+        SearchQuerySet()
+        .filter(id__in=allowed_person_ids + allowed_group_ids)
+        .filter(text=AutoQuery(query))[:limit]
+    )
+
     context = {"results": results}
 
     return render(request, "search/searchbar_snippets.html", context)
@@ -569,6 +589,22 @@ class PermissionSearchView(PermissionRequiredMixin, SearchView):
 
     permission_required = "core.search"
 
+    def get_context_data(self, *, object_list=None, **kwargs):
+        queryset = object_list if object_list is not None else self.object_list
+
+        allowed_person_ids = [
+            f"core.person.{pk}"
+            for pk in get_persons_for_user(self.request.user).values_list("pk", flat=True)
+        ]
+        allowed_group_ids = [
+            f"core.group.{pk}"
+            for pk in get_groups_for_user(self.request.user).values_list("pk", flat=True)
+        ]
+
+        queryset = queryset.filter(id__in=allowed_person_ids + allowed_group_ids)
+
+        return super().get_context_data(object_list=queryset, **kwargs)
+
 
 @never_cache
 def preferences(
-- 
GitLab