From 996d7a9761dcc6665249867057444f634d7bf73c Mon Sep 17 00:00:00 2001 From: Hangzhi Yu <hangzhi@protonmail.com> Date: Wed, 19 May 2021 10:34:52 +0200 Subject: [PATCH] Filtering by rules Reformat --- aleksis/core/util/core_helpers.py | 59 ++----------------------------- aleksis/core/views.py | 52 ++++++++++++++++----------- 2 files changed, 33 insertions(+), 78 deletions(-) diff --git a/aleksis/core/util/core_helpers.py b/aleksis/core/util/core_helpers.py index 04319fa61..46680d08a 100644 --- a/aleksis/core/util/core_helpers.py +++ b/aleksis/core/util/core_helpers.py @@ -2,10 +2,10 @@ from datetime import datetime, timedelta from importlib import import_module, metadata from itertools import groupby from operator import itemgetter -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union from django.conf import settings -from django.db.models import Model, Q, QuerySet +from django.db.models import Model, QuerySet from django.http import HttpRequest from django.shortcuts import get_object_or_404 from django.utils import timezone @@ -13,11 +13,6 @@ from django.utils.functional import lazy from cache_memoize import cache_memoize -if TYPE_CHECKING: - from django.contrib.auth import get_user_model - - User = get_user_model() # noqa - def copyright_years(years: Sequence[int], seperator: str = ", ", joiner: str = "–") -> str: """Take a sequence of integegers and produces a string with ranges. @@ -254,53 +249,3 @@ def queryset_rules_filter( def unread_notifications_badge(request: HttpRequest) -> int: """Generate badge content with the number of unread notifications.""" return request.user.person.unread_notifications_count - - -def get_persons_for_user(user: "User"): - """Get all persons the given user is allowed to view.""" - - from guardian.core import ObjectPermissionChecker - - from ..models import Person - from .predicates import check_global_permission - - checker = ObjectPermissionChecker(user) - - allowed_persons = Person.objects.filter(is_active=True) - if not check_global_permission(user, "core.view_person"): - checker.prefetch_perms(allowed_persons) - - obj_perm_persons = set() - - for allowed_person in allowed_persons: - if checker.has_perm("core.view_person", allowed_person): - obj_perm_persons.add(allowed_person.pk) - - allowed_persons = allowed_persons.filter(Q(pk__in=obj_perm_persons) | Q(pk=user.person.pk)) - - return allowed_persons - - -def get_groups_for_user(user: "User"): - """Get all groups the given user is allowed to view.""" - - from guardian.core import ObjectPermissionChecker - - from ..models import Group - from .predicates import check_global_permission - - checker = ObjectPermissionChecker(user) - - allowed_groups = Group.objects.all() - if not check_global_permission(user, "core.view_group"): - checker.prefetch_perms(allowed_groups) - - obj_perm_groups = set() - - for allowed_group in allowed_groups: - if checker.has_perm("core.view_group", allowed_group): - obj_perm_groups.add(allowed_group.pk) - - allowed_groups = allowed_groups.filter(pk__in=obj_perm_groups) - - return allowed_groups diff --git a/aleksis/core/views.py b/aleksis/core/views.py index d1d11023e..edc67b120 100644 --- a/aleksis/core/views.py +++ b/aleksis/core/views.py @@ -28,6 +28,7 @@ from guardian.shortcuts import get_objects_for_user from haystack.generic_views import SearchView from haystack.inputs import AutoQuery from haystack.query import SearchQuerySet +from haystack.utils.loading import UnifiedIndex from health_check.views import MainView from oauth2_provider.models import Application from reversion import set_user @@ -86,11 +87,10 @@ from .util import messages from .util.apps import AppConfig from .util.celery_progress import render_progress_page from .util.core_helpers import ( - get_groups_for_user, - get_persons_for_user, get_site_preferences, has_person, objectgetter_optional, + queryset_rules_filter, ) from .util.forms import PreferenceLayout from .util.pdf import render_pdf @@ -565,18 +565,22 @@ def searchbar_snippets(request: HttpRequest) -> HttpResponse: query = request.GET.get("q", "") limit = int(request.GET.get("limit", "5")) - allowed_person_ids = [ - f"core.person.{pk}" - for pk in get_persons_for_user(request.user).values_list("pk", flat=True) - ] - allowed_group_ids = [ - f"core.group.{pk}" for pk in get_groups_for_user(request.user).values_list("pk", flat=True) - ] + indexed_models = UnifiedIndex().get_indexed_models() + + allowed_object_ids = [] + + for model in indexed_models: + app_label = ContentType.objects.get_for_model(model).app_label + model_name = ContentType.objects.get_for_model(model).model + allowed_object_ids += [ + f"{app_label}.{model_name}.{pk}" + for pk in queryset_rules_filter( + request, model.objects.all(), f"{app_label}.view_{model_name}" + ).values_list("pk", flat=True) + ] results = ( - SearchQuerySet() - .filter(id__in=allowed_person_ids + allowed_group_ids) - .filter(text=AutoQuery(query))[:limit] + SearchQuerySet().filter(id__in=allowed_object_ids).filter(text=AutoQuery(query))[:limit] ) context = {"results": results} @@ -592,16 +596,22 @@ class PermissionSearchView(PermissionRequiredMixin, SearchView): def get_context_data(self, *, object_list=None, **kwargs): queryset = object_list if object_list is not None else self.object_list - allowed_person_ids = [ - f"core.person.{pk}" - for pk in get_persons_for_user(self.request.user).values_list("pk", flat=True) - ] - allowed_group_ids = [ - f"core.group.{pk}" - for pk in get_groups_for_user(self.request.user).values_list("pk", flat=True) - ] + indexed_models = UnifiedIndex().get_indexed_models() + + allowed_object_ids = [] + + for model in indexed_models: + + app_label = ContentType.objects.get_for_model(model).app_label + model_name = ContentType.objects.get_for_model(model).model + allowed_object_ids += [ + f"{app_label}.{model_name}.{pk}" + for pk in queryset_rules_filter( + self.request, model.objects.all(), f"{app_label}.view_{model_name}" + ).values_list("pk", flat=True) + ] - queryset = queryset.filter(id__in=allowed_person_ids + allowed_group_ids) + queryset = queryset.filter(id__in=allowed_object_ids) return super().get_context_data(object_list=queryset, **kwargs) -- GitLab