165 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			165 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# -*- coding: utf-8 -*-
 | 
						||
 | 
						||
"""
 | 
						||
@author: 猿小天
 | 
						||
@contact: QQ:1638245306
 | 
						||
@Created on: 2021/6/1 001 22:57
 | 
						||
@Remark: 自定义视图集
 | 
						||
"""
 | 
						||
import copy
 | 
						||
 | 
						||
from django.db import transaction
 | 
						||
from django_filters import DateTimeFromToRangeFilter
 | 
						||
from django_filters.rest_framework import FilterSet
 | 
						||
from drf_yasg import openapi
 | 
						||
from drf_yasg.utils import swagger_auto_schema
 | 
						||
from rest_framework.decorators import action
 | 
						||
from rest_framework.viewsets import ModelViewSet
 | 
						||
 | 
						||
from dvadmin.utils.filters import CoreModelFilterBankend, DataLevelPermissionMargeFilter
 | 
						||
from dvadmin.utils.import_export_mixin import ExportSerializerMixin, ImportSerializerMixin
 | 
						||
from dvadmin.utils.json_response import SuccessResponse, ErrorResponse, DetailResponse
 | 
						||
from dvadmin.utils.permission import CustomPermission
 | 
						||
from dvadmin.utils.models import get_custom_app_models, CoreModel
 | 
						||
from dvadmin.system.models import FieldPermission, MenuField
 | 
						||
from django_restql.mixins import QueryArgumentsMixin
 | 
						||
 | 
						||
 | 
						||
class CustomModelViewSet(ModelViewSet, ImportSerializerMixin, ExportSerializerMixin, QueryArgumentsMixin):
    """
    Project-wide ModelViewSet base class.

    Responses use a uniform format; create, query and update may each use a
    different serializer.
    (1) ORM performance: use the ``values_queryset`` form wherever possible.
    (2) ``xxx_serializer_class``: serializer used for one specific action
        (xxx=create|update|list|retrieve|destroy).
    (3) ``filter_fields = '__all__'``: by default every model field can be
        queried (except JSON fields).
    (4) ``import_field_dict = {}``: field mapping used on import,
        {model field: model label}.
    (5) ``export_field_label = {}``: fields used on export.
    """
    # Returned by get_queryset() instead of the regular queryset when set.
    values_queryset = None
    # NOTE(review): presumably read by an ordering filter backend - confirm.
    ordering_fields = '__all__'
    # Looked up by get_serializer_class() as "<action>_serializer_class".
    create_serializer_class = None
    update_serializer_class = None
    filter_fields = '__all__'
    search_fields = ()
    # Backends applied by filter_queryset() in addition to filter_backends.
    extra_filter_class = [CoreModelFilterBankend, DataLevelPermissionMargeFilter]
    permission_classes = [CustomPermission]
    # NOTE(review): presumably consumed by the import/export mixins - confirm.
    import_field_dict = {}
    export_field_label = {}
 | 
						||
 | 
						||
    def filter_queryset(self, queryset):
        """Run ``queryset`` through every configured filter backend.

        The backends are ``filter_backends`` followed by ``extra_filter_class``
        (which may be ``None`` or empty). A backend listed more than once is
        applied only once, and backends run in declaration order so the
        resulting query is the same on every request.

        :param queryset: the queryset to narrow down.
        :return: the queryset returned by the last backend.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order; a plain
        # set would make the application order arbitrary.
        backends = dict.fromkeys([*self.filter_backends, *(self.extra_filter_class or [])])
        for backend in backends:
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset
 | 
						||
 | 
						||
    def get_queryset(self):
        """Return ``values_queryset`` when one is configured, else the default.

        The attribute is compared against ``None`` rather than tested for
        truthiness: truth-testing a Django QuerySet executes it, and the
        fetched rows would then stay cached on the class-level object. An
        empty ``values_queryset`` is also still honoured this way.

        :return: a fresh clone of ``values_queryset`` when it offers ``all()``,
            the attribute itself otherwise, or the parent class's queryset
            when the attribute is unset.
        """
        values_queryset = getattr(self, 'values_queryset', None)
        if values_queryset is None:
            return super().get_queryset()
        # Clone per call so no request reuses another request's result cache.
        clone = getattr(values_queryset, 'all', None)
        return clone() if callable(clone) else values_queryset
 | 
						||
 | 
						||
    def get_serializer_class(self):
        """Pick the serializer class for the current action.

        An attribute named ``<action>_serializer_class`` takes precedence when
        it is set to a truthy value; otherwise the parent class decides.
        """
        override = getattr(self, '{}_serializer_class'.format(self.action), None)
        if not override:
            return super().get_serializer_class()
        return override
 | 
						||
 | 
						||
    # Forcing many=True for list payloads lets the existing API create in bulk.
    def get_serializer(self, *args, **kwargs):
        """Instantiate the serializer for the current request.

        The serializer context is filled in unless the caller supplied one.
        The field names returned by ``get_menu_field`` are stored on
        ``request.permission_fields`` (per the original comment they are used
        by the paginator); they are not applied to the serializer here.
        When the request body is a list the serializer is built with
        ``many=True``.

        :param args: positional arguments forwarded to the serializer class.
        :param kwargs: keyword arguments forwarded to the serializer class.
        :return: the serializer instance.
        """
        serializer_class = self.get_serializer_class()
        kwargs.setdefault('context', self.get_serializer_context())
        self.request.permission_fields = self.get_menu_field(serializer_class)
        if isinstance(self.request.data, list):
            # Set through kwargs so a caller that already passed ``many`` does
            # not trigger "got multiple values for keyword argument 'many'".
            # Building a serializer writes nothing, so no transaction is opened
            # here; bulk saves must be made atomic where save() is called.
            kwargs['many'] = True
        return serializer_class(*args, **kwargs)
 | 
						||
 | 
						||
    def get_menu_field(self, serializer_class):
        """Return the field names covered by field permissions for the serializer's model.

        Returns an empty list when the model is not among the entries of
        ``get_custom_app_models()``. Otherwise returns a flat ``values_list``
        of ``field__field_name`` from ``FieldPermission``. When the request
        user has a ``role`` attribute the rows are limited to queryable
        permissions of those roles.
        """
        for entry in get_custom_app_models():
            if entry['object'] is serializer_class.Meta.model:
                break
        else:
            # No registered model matched.
            return []

        permissions = FieldPermission.objects.filter(field__model=serializer_class.Meta.model.__name__)
        # Anonymous users have no roles, so the role filter is skipped for them.
        if hasattr(self.request.user, 'role'):
            role_ids = self.request.user.role.values_list('id', flat=True)
            permissions = permissions.filter(is_query=True, role__in=role_ids)

        return permissions.values_list('field__field_name', flat=True)
 | 
						||
 | 
						||
    def create(self, request, *args, **kwargs):
        """Validate the request body, save it via ``perform_create`` and return a DetailResponse."""
        new_serializer = self.get_serializer(data=request.data, request=request)
        # Invalid input raises here and never reaches perform_create.
        new_serializer.is_valid(raise_exception=True)
        self.perform_create(new_serializer)
        payload = new_serializer.data
        return DetailResponse(data=payload, msg="新增成功")
 | 
						||
 | 
						||
    def list(self, request, *args, **kwargs):
        """Return the filtered queryset, paginated when a page is produced.

        A paginated result goes through ``get_paginated_response``; otherwise
        the whole filtered queryset is returned in a SuccessResponse.
        """
        filtered = self.filter_queryset(self.get_queryset())
        current_page = self.paginate_queryset(filtered)
        if current_page is None:
            # No page produced: serialize everything.
            full_serializer = self.get_serializer(filtered, many=True, request=request)
            return SuccessResponse(data=full_serializer.data, msg="获取成功")
        page_serializer = self.get_serializer(current_page, many=True, request=request)
        return self.get_paginated_response(page_serializer.data)
 | 
						||
 | 
						||
    def retrieve(self, request, *args, **kwargs):
        """Look up the requested object and return its serialized form in a DetailResponse."""
        found = self.get_object()
        payload = self.get_serializer(found).data
        return DetailResponse(data=payload, msg="获取成功")
 | 
						||
 | 
						||
    def update(self, request, *args, **kwargs):
        """Validate and apply an update to the requested object.

        ``partial`` is taken out of ``kwargs`` (default ``False``) and passed
        on to the serializer. The result is returned in a DetailResponse.
        """
        is_partial = kwargs.pop('partial', False)
        target = self.get_object()
        serializer = self.get_serializer(
            target,
            data=request.data,
            request=request,
            partial=is_partial,
        )
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        # When 'prefetch_related' was applied to the queryset, the prefetch
        # cache on the instance must be invalidated explicitly.
        prefetched = getattr(target, '_prefetched_objects_cache', None)
        if prefetched:
            target._prefetched_objects_cache = {}
        return DetailResponse(data=serializer.data, msg="更新成功")
 | 
						||
 | 
						||
    def destroy(self, request, *args, **kwargs):
        """Delete the requested object and return a DetailResponse.

        Deletion goes through ``perform_destroy`` so that a subclass
        overriding that hook is honoured, in the same way ``create`` and
        ``update`` go through ``perform_create`` / ``perform_update``.
        """
        instance = self.get_object()
        self.perform_destroy(instance)
        return DetailResponse(data=[], msg="删除成功")
 | 
						||
 | 
						||
    # Schema of the request body field "keys" (list of primary keys), used by
    # the swagger description of multiple_delete below.
    keys = openapi.Schema(description='主键列表', type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING))

    @swagger_auto_schema(request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=['keys'],
        properties={'keys': keys}
    ), operation_summary='批量删除')
    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """Delete every object whose primary key is listed under ``keys``.

        Returns a SuccessResponse after deleting, or an ErrorResponse when the
        body carries no usable ``keys`` entry (missing, empty, or the body is
        not a mapping).

        NOTE(review): the rows come from ``get_queryset()`` directly, so the
        filter backends of ``filter_queryset`` are not applied here - confirm
        that this is intended.
        """
        body = request.data
        # A JSON array body has no .get(); treat it as "no keys supplied".
        pk_list = body.get('keys', None) if isinstance(body, dict) else None
        if not pk_list:
            return ErrorResponse(msg="未获取到keys字段")
        # "pk" resolves to the model's primary key whatever the field is named.
        self.get_queryset().filter(pk__in=pk_list).delete()
        return SuccessResponse(data=[], msg="删除成功")
 | 
						||
 | 
						||
    @action(methods=['post'], detail=False)
    def get_by_ids(self, request):
        """Return the objects whose primary keys are listed under ``ids``.

        Responds with ``data=None`` when ``ids`` is missing, empty, equal to
        ``['']``, or when the body is not a mapping.
        """
        body = request.data
        # A JSON array body has no .get(); treat it as "no ids supplied".
        ids = body.get('ids', []) if isinstance(body, dict) else []
        if ids and ids != ['']:
            # "pk" resolves to the model's primary key whatever the field is named.
            queryset = self.get_queryset().filter(pk__in=ids)
            serializer = self.get_serializer(queryset, many=True)
            return DetailResponse(data=serializer.data)
        return DetailResponse(data=None)
 |