Coverage for sm / mixins.py: 31%
113 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 13:46 +0000
1from django.db.models import Q
2from django.contrib import messages
3from django.utils.translation import gettext as _
4from typing import Any, Optional
5from django.db.models.query import QuerySet
6from django.forms import ModelForm
7from django.contrib.auth.models import Group
8from django.db import transaction
def get_tenant_model_counts(group: Optional[Group]) -> int:
    """Return how many tenant-scoped rows belong to ``group``.

    Sums ``<Model>.objects.filter(group=group).count()`` over every
    tenant-scoped model listed below; the quota checks compare this total
    against the group's limit. A falsy ``group`` short-circuits to 0 without
    importing the models or touching the database.

    Note: one COUNT query is issued per model (twelve in total).
    """
    if not group:
        return 0

    # Model imports are deferred until needed — presumably to avoid circular
    # imports between this mixin module and the apps that use it.
    from server.models import Model as Server
    from cluster.models import Model as Cluster
    from domain.models import Model as Domain
    from vendor.models import Model as Vendor
    from operatingsystem.models import Model as OS
    from status.models import Model as Status
    from location.models import Model as Location
    from patchtime.models import Model as Patchtime
    from servermodel.models import Model as ServerModel
    from clusterpackage.models import Model as ClusterPackage
    from clustersoftware.models import Model as ClusterSoftware
    from clusterpackagetype.models import Model as ClusterPackageType

    tenant_models = (
        Server,
        Cluster,
        Domain,
        Vendor,
        OS,
        Status,
        Location,
        Patchtime,
        ServerModel,
        ClusterPackage,
        ClusterSoftware,
        ClusterPackageType,
    )
    return sum(
        tenant_model.objects.filter(group=group).count()
        for tenant_model in tenant_models
    )
class MultiTenantMixin:
    """
    Mixin to filter querysets by user groups and auto-assign group on save.
    Enforces item quotas per group.

    Meant for Django class-based views: it reads ``self.request`` and
    cooperates with ``super()`` for ``get_queryset``, ``get_context_data``
    and ``form_valid``. Non-superusers only see rows whose ``group`` is one
    of their own groups (optionally narrowed by the ``selected_groups``
    session key) or rows with no group at all.
    """

    def _selected_group_ids(self) -> Optional[list[int]]:
        """Return the session-selected group ids the user is allowed to see.

        ``selected_groups`` may hold ids as strings or ints, so each entry
        is normalized through ``str`` before parsing; non-numeric entries
        are ignored. Only ids of groups the user currently belongs to are
        kept, so a stale or tampered selection can never widen access beyond
        the user's memberships. Returns ``None`` when no usable selection
        remains, meaning "fall back to all of the user's groups".
        """
        selected = self.request.session.get("selected_groups", [])
        if not selected:
            return None
        # isdecimal() (unlike isdigit()) only accepts strings int() can parse.
        requested = {int(text) for text in map(str, selected) if text.isdecimal()}
        if not requested:
            return None
        member_ids = set(self.request.user.groups.values_list("id", flat=True))
        return sorted(requested & member_ids) or None

    def _tenant_filter(self, field: str) -> Q:
        """Build the visibility ``Q`` for a non-superuser on FK lookup ``field``.

        ``field`` is the lookup prefix for the group foreign key (``"group"``
        on live models, ``"group_id"`` on history rows). Rows in a visible
        group or with no group at all match.
        """
        group_ids = self._selected_group_ids()
        visible = group_ids if group_ids is not None else self.request.user.groups.all()
        return Q(**{f"{field}__in": visible}) | Q(**{f"{field}__isnull": True})

    def _check_view_permission(self) -> None:
        """Raise ``PermissionDenied`` unless the user may view ``self.model``.

        Superusers and views without a ``model`` attribute are not checked.
        """
        model = getattr(self, "model", None)
        if not model or self.request.user.is_superuser:
            return
        opts = model._meta
        codename = f"view_{opts.model_name.lower()}"
        if not self.request.user.has_perm(f"{opts.app_label}.{codename}"):
            from django.core.exceptions import PermissionDenied

            raise PermissionDenied

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset limited to the user's visible groups.

        Raises:
            PermissionDenied: a non-superuser lacks the model's view permission.
        """
        self._check_view_permission()
        queryset = super().get_queryset()  # type: ignore
        if self.request.user.is_superuser:
            return queryset
        return queryset.filter(self._tenant_filter("group"))

    def get_context_data(self, **kwargs: Any) -> Any:
        """Add ``recent_history`` (latest 10 history rows) for list views.

        Only applies when the model exposes a ``history`` manager and the
        view has an ``object_list``. Non-superusers get the same group
        visibility rules as ``get_queryset``, including the fallback to all
        of their groups when the session selection holds no usable ids.
        """
        context = super().get_context_data(**kwargs)  # type: ignore
        model = getattr(self, "model", None)
        # Only add history to context for ListViews that have it configured
        if model and hasattr(model, "history") and hasattr(self, "object_list"):
            history_qs = model.history.all()
            if not self.request.user.is_superuser:
                history_qs = history_qs.filter(self._tenant_filter("group_id"))
            context["recent_history"] = history_qs.order_by("-history_date")[:10]
        return context

    def check_quota(self, group: Optional[Group]) -> bool:
        """Return True if ``group`` may own one more tenant item.

        A missing group, or a group without a ``profile``, is unlimited.
        The profile row is locked with ``select_for_update`` and
        ``max_items`` is read from that locked row, so a concurrent limit
        change is seen. The lock lasts only as long as the enclosing
        transaction: callers that need the check and the following save to
        be race-free must call this inside their own ``transaction.atomic()``.
        """
        if not group or not hasattr(group, "profile"):
            return True

        profile = group.profile
        with transaction.atomic():
            locked_profile = type(profile).objects.select_for_update().get(pk=profile.pk)
            count = get_tenant_model_counts(group)
        return count < locked_profile.max_items

    def form_valid(self, form: ModelForm) -> Any:
        """Assign a default group, enforce the quota on creation, then save.

        A non-superuser saving an item without a group gets their first
        group. For new items, the quota check and the save done by
        ``super().form_valid`` share one transaction, so the profile lock
        taken by ``check_quota`` is held until the row is written and
        concurrent creates cannot both slip under the limit. When the quota
        is exhausted an error message is queued and ``form_invalid`` is
        returned instead.
        """
        instance = form.instance
        # Checking the FK id avoids a query and works for unset nullable FKs.
        if instance.group_id is None and not self.request.user.is_superuser:
            first_group = self.request.user.groups.first()
            if first_group is not None:
                instance.group = first_group

        # _state.adding is reliable even when the pk has a default (e.g. UUID),
        # unlike testing ``instance.pk``.
        if not instance._state.adding:
            return super().form_valid(form)  # type: ignore

        with transaction.atomic():
            if not self.check_quota(instance.group):
                quota_limit = 0
                if instance.group and hasattr(instance.group, "profile"):
                    quota_limit = instance.group.profile.max_items
                messages.error(
                    self.request,
                    _("Quota exceeded for this group (%d items).") % quota_limit,
                )
                return self.form_invalid(form)
            # Let other mixins (like SuccessMessageMixin) or the base view
            # handle the actual saving and response, still under the lock.
            return super().form_valid(form)  # type: ignore
class APIMultiTenantMixin:
    """
    Mixin for DRF ViewSets to filter by user groups and auto-assign on create.

    Non-superusers only see rows whose ``group`` is one of their own groups
    (optionally narrowed by the ``selected_groups`` session key) or rows
    with no group; new objects are stamped with the user's first group and
    checked against that group's item quota.
    """

    def _selected_group_ids(self) -> Optional[list[int]]:
        """Return the session-selected group ids the user is allowed to see.

        Entries may be strings or ints; non-numeric entries are ignored.
        Only ids of groups the user currently belongs to are kept, so a
        stale or tampered selection cannot widen access. Returns ``None``
        when no usable selection remains ("use all of the user's groups").
        """
        selected = self.request.session.get("selected_groups", [])
        if not selected:
            return None
        # isdecimal() (unlike isdigit()) only accepts strings int() can parse.
        requested = {int(text) for text in map(str, selected) if text.isdecimal()}
        if not requested:
            return None
        member_ids = set(self.request.user.groups.values_list("id", flat=True))
        return sorted(requested & member_ids) or None

    def _check_view_permission(self) -> None:
        """Raise ``PermissionDenied`` unless the user may view ``self.model``.

        Superusers and viewsets without a ``model`` attribute are not checked.
        """
        # NOTE(review): DRF viewsets usually declare ``queryset`` rather than
        # ``model``; without a ``model`` attribute this check is skipped
        # entirely — confirm subclasses set it.
        model = getattr(self, "model", None)
        if not model or self.request.user.is_superuser:
            return
        opts = model._meta
        codename = f"view_{opts.model_name.lower()}"
        if not self.request.user.has_perm(f"{opts.app_label}.{codename}"):
            from django.core.exceptions import PermissionDenied

            raise PermissionDenied

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset limited to the user's visible groups.

        Raises:
            PermissionDenied: a non-superuser lacks the model's view permission.
        """
        self._check_view_permission()
        queryset = super().get_queryset()  # type: ignore
        if self.request.user.is_superuser:
            return queryset

        group_ids = self._selected_group_ids()
        if group_ids is not None:
            return queryset.filter(Q(group__id__in=group_ids) | Q(group__isnull=True))
        user_groups = self.request.user.groups.all()
        return queryset.filter(Q(group__in=user_groups) | Q(group__isnull=True))

    def perform_create(self, serializer: Any) -> None:
        """Save a new object, stamping non-superusers' first group on it.

        Superusers save with whatever the serializer provides. Everyone else
        saves with ``group`` set to their first group, or ``None`` if they
        have none. When that group has a ``profile``, the profile row is
        locked, the quota is checked against the locked ``max_items``, and
        the save happens in the same transaction. The lock is therefore held
        until the row is written, so concurrent creates cannot exceed the
        limit.

        Raises:
            rest_framework.exceptions.ValidationError: the group's quota is
                already reached.
        """
        if self.request.user.is_superuser:
            serializer.save()
            return

        group = self.request.user.groups.first()
        if group is None or not hasattr(group, "profile"):
            serializer.save(group=group)
            return

        with transaction.atomic():
            locked_profile = type(group.profile).objects.select_for_update().get(
                pk=group.profile.pk
            )
            if get_tenant_model_counts(group) >= locked_profile.max_items:
                from rest_framework.exceptions import ValidationError

                raise ValidationError(_("Quota exceeded for this group."))
            serializer.save(group=group)