Coverage for sm / mixins.py: 31%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 13:46 +0000

1from django.db.models import Q 

2from django.contrib import messages 

3from django.utils.translation import gettext as _ 

4from typing import Any, Optional 

5from django.db.models.query import QuerySet 

6from django.forms import ModelForm 

7from django.contrib.auth.models import Group 

8from django.db import transaction 

9 

10 

def get_tenant_model_counts(group: Optional[Group]) -> int:
    """Return the total number of tenant items owned by *group*, for quota checks.

    The total is the sum of one ``COUNT`` query per tenant-scoped model.
    A falsy *group* (e.g. ``None``) yields ``0`` without importing any model
    module or touching the database.
    """
    if group:
        # Model modules are imported lazily, exactly as before; presumably
        # this avoids circular imports at module load time (TODO confirm).
        from server.models import Model as Server
        from cluster.models import Model as Cluster
        from domain.models import Model as Domain
        from vendor.models import Model as Vendor
        from operatingsystem.models import Model as OS
        from status.models import Model as Status
        from location.models import Model as Location
        from patchtime.models import Model as Patchtime
        from servermodel.models import Model as ServerModel
        from clusterpackage.models import Model as ClusterPackage
        from clustersoftware.models import Model as ClusterSoftware
        from clusterpackagetype.models import Model as ClusterPackageType

        tenant_models = (
            Server,
            Cluster,
            Domain,
            Vendor,
            OS,
            Status,
            Location,
            Patchtime,
            ServerModel,
            ClusterPackage,
            ClusterSoftware,
            ClusterPackageType,
        )
        return sum(
            tenant_model.objects.filter(group=group).count()
            for tenant_model in tenant_models
        )

    return 0

43 

44 

class MultiTenantMixin:
    """
    Mixin to filter querysets by user groups and auto-assign group on save.
    Enforces item quotas per group.

    Superusers bypass both the group filter and the auto-assignment.
    For everyone else, rows are limited to the groups chosen in
    ``request.session["selected_groups"]`` (or, when that holds no usable
    IDs, to all groups the user belongs to), plus rows without a group.
    """

    @staticmethod
    def _parse_group_ids(raw_ids: Any) -> list:
        """Return the integer group IDs found in *raw_ids*, skipping the rest.

        Entries may be strings or ints (session contents depend on whoever
        stored them). Anything that is not a plain ASCII run of digits is
        dropped instead of raising.
        """
        group_ids = []
        for raw in raw_ids or ():
            text = str(raw)
            # str.isdigit() alone accepts characters such as "²" that int()
            # rejects with ValueError, hence the extra isascii() guard.
            if text.isascii() and text.isdigit():
                group_ids.append(int(text))
        return group_ids

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset restricted to the rows the user may see.

        Raises:
            PermissionDenied: a non-superuser lacks the ``view_<model>``
                permission for ``self.model``.
        """
        user = self.request.user

        # Check basic view permission for the model
        model = getattr(self, "model", None)
        if model and not user.is_superuser:
            opts = model._meta
            codename = f"view_{opts.model_name.lower()}"
            if not user.has_perm(f"{opts.app_label}.{codename}"):
                from django.core.exceptions import PermissionDenied

                raise PermissionDenied

        queryset = super().get_queryset()  # type: ignore
        if user.is_superuser:
            return queryset

        # NOTE(review): the selected IDs are not checked against the user's
        # own group membership here; presumably the view that stores them in
        # the session validates them -- confirm.
        group_ids = self._parse_group_ids(
            self.request.session.get("selected_groups", [])
        )
        if group_ids:
            return queryset.filter(
                Q(group__id__in=group_ids) | Q(group__isnull=True)
            )

        return queryset.filter(
            Q(group__in=user.groups.all()) | Q(group__isnull=True)
        )

    def get_context_data(self, **kwargs: Any) -> Any:
        """Add ``recent_history`` (the 10 newest history rows) to the context.

        Only applies when the model exposes ``history`` and the view has an
        ``object_list``. Non-superusers get the same group restriction as
        in ``get_queryset``, including the fallback to the user's groups
        when the session selection contains no usable IDs.
        """
        context = super().get_context_data(**kwargs)  # type: ignore
        model = getattr(self, "model", None)
        # Only add history to context for ListViews that have it configured
        if not (model and hasattr(model, "history") and hasattr(self, "object_list")):
            return context

        history_qs = model.history.all()
        user = self.request.user

        if not user.is_superuser:
            group_ids = self._parse_group_ids(
                self.request.session.get("selected_groups", [])
            )
            # Previously an unusable selection produced an empty ID list here
            # (hiding all grouped history) while get_queryset fell back to
            # the user's groups; both now use the same fallback.
            visible_groups = group_ids if group_ids else user.groups.all()
            history_qs = history_qs.filter(
                Q(group_id__in=visible_groups) | Q(group_id__isnull=True)
            )

        context["recent_history"] = history_qs.order_by("-history_date")[:10]
        return context

    def check_quota(self, group: Optional[Group]) -> bool:
        """Return True when *group* may still create another item.

        Groups that are falsy or have no ``profile`` are unlimited.
        Otherwise the group's profile row is locked with
        ``select_for_update`` and the current item total is compared with
        its ``max_items``.

        The row lock lasts until the outermost transaction ends. Called on
        its own, that is the end of this method; callers that go on to
        insert a row must wrap this call and the insert in one
        ``transaction.atomic()`` block (as ``form_valid`` does) for the
        check to be race-free.
        """
        if not group or not hasattr(group, "profile"):
            return True

        profile = group.profile

        with transaction.atomic():
            # Lock the group profile to serialise concurrent quota checks and
            # read the limit from the locked row rather than the cached one.
            GroupProfile = profile.__class__
            locked_profile = GroupProfile.objects.select_for_update().get(
                pk=profile.pk
            )
            count = get_tenant_model_counts(group)
            return count < locked_profile.max_items

    def form_valid(self, form: ModelForm) -> Any:
        """Assign a default group, enforce the quota for new items, then save.

        For new instances the quota check and the save performed by
        ``super().form_valid`` share one transaction, so the profile lock
        taken by ``check_quota`` is still held while the row is inserted.
        On a quota violation an error message is queued and the result of
        ``form_invalid`` is returned.
        """
        user = self.request.user

        # Auto-assign first group if not set and not superuser
        if not form.instance.group and not user.is_superuser:
            # first() already returns None for an empty queryset, so no
            # separate exists() query is needed.
            first_group = user.groups.all().first()
            if first_group is not None:
                form.instance.group = first_group

        # Existing items are never quota-checked.
        if form.instance.pk:
            return super().form_valid(form)  # type: ignore

        # Check quota for NEW items, keeping the lock until the save commits.
        with transaction.atomic():
            if self.check_quota(form.instance.group):
                # Call super().form_valid(form) to let other mixins (like
                # SuccessMessageMixin) or the base view handle the actual
                # saving and response.
                return super().form_valid(form)  # type: ignore

        quota_limit = 0
        if form.instance.group and hasattr(form.instance.group, "profile"):
            quota_limit = form.instance.group.profile.max_items
        messages.error(
            self.request,
            _("Quota exceeded for this group (%d items).") % quota_limit,
        )
        return self.form_invalid(form)

141 

142 

class APIMultiTenantMixin:
    """
    Mixin for DRF ViewSets to filter by user groups and auto-assign on create.

    Superusers see everything and create objects without a forced group.
    Other users are limited to the groups listed in
    ``request.session["selected_groups"]`` (or, when that holds no usable
    IDs, to all groups they belong to), plus rows without a group.
    """

    @staticmethod
    def _parse_group_ids(raw_ids: Any) -> list:
        """Return the integer group IDs found in *raw_ids*, skipping the rest.

        Entries may be strings or ints. Anything that is not a plain ASCII
        run of digits is dropped instead of raising.
        """
        group_ids = []
        for raw in raw_ids or ():
            text = str(raw)
            # str.isdigit() alone accepts characters such as "²" that int()
            # rejects with ValueError, hence the extra isascii() guard.
            if text.isascii() and text.isdigit():
                group_ids.append(int(text))
        return group_ids

    def get_queryset(self) -> QuerySet:
        """Return the parent queryset restricted to the rows the user may see.

        Raises:
            PermissionDenied: a non-superuser lacks the ``view_<model>``
                permission for ``self.model``.
        """
        user = self.request.user

        # Check basic view permission for the model
        model = getattr(self, "model", None)
        if model and not user.is_superuser:
            opts = model._meta
            codename = f"view_{opts.model_name.lower()}"
            if not user.has_perm(f"{opts.app_label}.{codename}"):
                from django.core.exceptions import PermissionDenied

                raise PermissionDenied

        queryset = super().get_queryset()  # type: ignore
        if user.is_superuser:
            return queryset

        # NOTE(review): the selected IDs are not checked against the user's
        # own group membership here; presumably whoever stores them in the
        # session validates them -- confirm.
        group_ids = self._parse_group_ids(
            self.request.session.get("selected_groups", [])
        )
        if group_ids:
            return queryset.filter(
                Q(group__id__in=group_ids) | Q(group__isnull=True)
            )

        return queryset.filter(
            Q(group__in=user.groups.all()) | Q(group__isnull=True)
        )

    def perform_create(self, serializer: Any) -> None:
        """Save the serializer, forcing the user's first group and its quota.

        Superusers save the serializer unchanged. For other users the
        object is saved with their first group (``None`` when they have no
        groups). When that group has a ``profile``, the profile row is
        locked and the save happens inside the same transaction, so
        concurrent requests cannot both pass the quota check.

        Raises:
            rest_framework.exceptions.ValidationError: the group already
                holds ``max_items`` items or more.
        """
        user = self.request.user
        if user.is_superuser:
            serializer.save()
            return

        # first() already returns None for an empty queryset, so no separate
        # exists() query is needed.
        group = user.groups.all().first()

        if group is None or not hasattr(group, "profile"):
            # No quota configured for this user: nothing to enforce.
            serializer.save(group=group)
            return

        with transaction.atomic():
            # Lock the group profile to serialise concurrent creations; the
            # lock is held until this block exits, i.e. after the save.
            GroupProfile = group.profile.__class__
            locked_profile = GroupProfile.objects.select_for_update().get(
                pk=group.profile.pk
            )

            count = get_tenant_model_counts(group)
            if count >= locked_profile.max_items:
                from rest_framework.exceptions import ValidationError

                raise ValidationError(_("Quota exceeded for this group."))

            serializer.save(group=group)

196 serializer.save()