Coverage for sm / mixins.py: 29%
85 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-24 12:43 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-24 12:43 +0000
1from django.db.models import Q
2from django.contrib import messages
3from django.utils.translation import gettext as _
4from typing import Any, Optional
5from django.db.models.query import QuerySet
6from django.forms import ModelForm
7from django.contrib.auth.models import Group
8from django.db import transaction
def get_tenant_model_counts(group: Optional[Group]) -> int:
    """Return how many rows belong to ``group`` across all tenant models.

    Used for quota checking.  A falsy ``group`` yields 0 without importing
    or querying any model; otherwise ``.count()`` is called once per model
    on the rows filtered by ``group`` and the results are added up.
    """
    if not group:
        return 0

    # The models are imported inside the function rather than at module
    # level (presumably to avoid circular imports -- confirm).
    from server.models import Model as Server
    from cluster.models import Model as Cluster
    from domain.models import Model as Domain
    from vendor.models import Model as Vendor
    from operatingsystem.models import Model as OS
    from status.models import Model as Status
    from location.models import Model as Location
    from patchtime.models import Model as Patchtime
    from servermodel.models import Model as ServerModel
    from clusterpackage.models import Model as ClusterPackage
    from clustersoftware.models import Model as ClusterSoftware
    from clusterpackagetype.models import Model as ClusterPackageType

    # Every model that contributes to a group's quota, in counting order.
    tenant_models = (
        Server,
        Cluster,
        Domain,
        Vendor,
        OS,
        Status,
        Location,
        Patchtime,
        ServerModel,
        ClusterPackage,
        ClusterSoftware,
        ClusterPackageType,
    )
    return sum(
        model.objects.filter(group=group).count() for model in tenant_models
    )
class MultiTenantMixin:
    """
    Mixin to filter querysets by user groups and auto-assign group on save.
    Enforces item quotas per group.

    Relies on the host view providing ``self.request`` and, further along
    the MRO, ``get_queryset()``, ``form_valid()`` and ``form_invalid()``.
    """

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset limited to the groups the user may see.

        Superusers get the parent queryset unchanged.  For everybody else
        the rows are limited to the group ids stored in the session under
        ``"selected_groups"``; when that yields no usable id, the user's
        own groups are used instead.  Rows without a group are always
        included.
        """
        queryset = super().get_queryset()  # type: ignore
        if self.request.user.is_superuser:
            return queryset

        selected_groups = self.request.session.get("selected_groups", [])
        user_groups = self.request.user.groups.all()

        if selected_groups:
            # Go through str() so ids stored as integers are accepted too,
            # and use isdecimal() because isdigit() also accepts characters
            # (e.g. superscript digits) that int() refuses to parse.
            # NOTE(review): these ids are not intersected with the user's
            # own groups here; presumably the code that stores them in the
            # session validates membership -- confirm.
            group_ids = [
                int(raw) for raw in map(str, selected_groups) if raw.isdecimal()
            ]
            if group_ids:
                return queryset.filter(
                    Q(group__id__in=group_ids) | Q(group__isnull=True)
                )

        return queryset.filter(Q(group__in=user_groups) | Q(group__isnull=True))

    def check_quota(self, group: Optional[Group]) -> bool:
        """Return True when ``group`` may still receive another item.

        A missing group, or a group without a ``profile``, is treated as
        unlimited.  Otherwise the group's profile row is locked with
        ``select_for_update()`` and the current item count is compared with
        the ``max_items`` value of that locked row.

        The row lock only lasts until the surrounding transaction ends, so
        callers that create an item afterwards must call this method and
        save inside the same ``transaction.atomic()`` block.
        """
        if not group or not hasattr(group, "profile"):
            return True

        profile = group.profile

        with transaction.atomic():
            # Lock the profile row and read the limit from the locked row,
            # so a concurrent change of max_items cannot be missed.
            locked_profile = (
                type(profile).objects.select_for_update().get(pk=profile.pk)
            )
            count = get_tenant_model_counts(group)
            return count < locked_profile.max_items

    def form_valid(self, form: ModelForm) -> Any:
        """Assign a group if needed, enforce the quota, then delegate the save.

        Non-superusers submitting an instance without a group get their
        first group assigned.  For new instances (no primary key yet) the
        quota is checked; when it is exhausted an error message is queued
        and ``form_invalid()`` is returned instead of saving.
        """
        # Auto-assign first group if not set and not superuser
        if not form.instance.group and not self.request.user.is_superuser:
            first_group = self.request.user.groups.all().first()
            if first_group is not None:
                form.instance.group = first_group

        # Existing items are never subject to the quota check.
        if form.instance.pk:
            return super().form_valid(form)  # type: ignore

        group = form.instance.group

        # Check and save inside ONE transaction: the row lock taken in
        # check_quota() is then held until the new item is written, so two
        # concurrent requests cannot both pass the check.
        with transaction.atomic():
            if self.check_quota(group):
                # Let other mixins (like SuccessMessageMixin) or the base
                # view handle the actual saving and response.
                return super().form_valid(form)  # type: ignore

        quota_limit = 0
        if group and hasattr(group, "profile"):
            quota_limit = group.profile.max_items
        messages.error(
            self.request,
            _("Quota exceeded for this group (%d items).") % quota_limit,
        )
        return self.form_invalid(form)
class APIMultiTenantMixin:
    """
    Mixin for DRF ViewSets to filter by user groups and auto-assign on create.

    Relies on the host view providing ``self.request`` and, further along
    the MRO, ``get_queryset()``.
    """

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset limited to the groups the user may see.

        Superusers get the parent queryset unchanged.  For everybody else
        the rows are limited to the group ids stored in the session under
        ``"selected_groups"``; when that yields no usable id, the user's
        own groups are used instead.  Rows without a group are always
        included.
        """
        queryset = super().get_queryset()  # type: ignore
        if self.request.user.is_superuser:
            return queryset

        selected_groups = self.request.session.get("selected_groups", [])
        user_groups = self.request.user.groups.all()

        if selected_groups:
            # Go through str() so ids stored as integers are accepted too,
            # and use isdecimal() because isdigit() also accepts characters
            # (e.g. superscript digits) that int() refuses to parse.
            # NOTE(review): these ids are not intersected with the user's
            # own groups here; presumably the code that stores them in the
            # session validates membership -- confirm.
            group_ids = [
                int(raw) for raw in map(str, selected_groups) if raw.isdecimal()
            ]
            if group_ids:
                return queryset.filter(
                    Q(group__id__in=group_ids) | Q(group__isnull=True)
                )

        return queryset.filter(Q(group__in=user_groups) | Q(group__isnull=True))

    def perform_create(self, serializer: Any) -> None:
        """Save a new object, assigning the user's first group and enforcing quota.

        Superusers save the serializer as submitted, with no group
        assignment and no quota check.  Other users get their first group
        (or ``None`` when they have none) passed to ``serializer.save()``.
        If that group has a ``profile``, the item count is checked against
        its ``max_items`` first.

        Raises:
            rest_framework.exceptions.ValidationError: the group already
                holds ``max_items`` items or more.
        """
        if self.request.user.is_superuser:
            serializer.save()
            return

        # first() already returns None when the user has no groups.
        group = self.request.user.groups.all().first()

        if not group or not hasattr(group, "profile"):
            # No profile means no quota to enforce.
            serializer.save(group=group)
            return

        profile = group.profile

        # Check and save inside ONE transaction: the row lock is held until
        # the new object is written, so two concurrent requests cannot both
        # pass the quota check.
        with transaction.atomic():
            # Lock the profile row and read the limit from the locked row.
            locked_profile = (
                type(profile).objects.select_for_update().get(pk=profile.pk)
            )
            count = get_tenant_model_counts(group)
            if count >= locked_profile.max_items:
                from rest_framework.exceptions import ValidationError

                raise ValidationError(_("Quota exceeded for this group."))

            serializer.save(group=group)