import json

from django.db import transaction
from django.shortcuts import render
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from protocol_version_manage.models import (CurrentDevVersion, AllDevCmdDefineAndVersion,
                                            AllProtocolDefinAndVersion, AllProtocolVersion)
from .models import (TableAllDevCmdDefine, TableDevCmdNamePoll,
                     TableSoftLimitAngle, TableXproAllDevinfo)
from .serializers import (TableAllDevCmdDefineSerializer, TableDevCmdNamePollSerializer,
                          TableSoftLimitAngleSerializer, TableXproAllDevinfoSerializer)
from .utils import tree_data


class TableAllDevCmdDefineView_1(ModelViewSet):
    """
    CRUD endpoint for command fields (``TableAllDevCmdDefine``).

    Every create/update/delete is mirrored into ``AllDevCmdDefineAndVersion``
    under the protocol's current version (read from ``CurrentDevVersion``).
    Each write runs in a single transaction so both tables change together.
    """
    queryset = TableAllDevCmdDefine.objects.all()
    serializer_class = TableAllDevCmdDefineSerializer

    def list(self, request):
        """Return all command fields grouped by ``cmd_name`` via ``tree_data``."""
        serializer = self.get_serializer(self.get_queryset(), many=True)
        data = tree_data(serializer.data, 'cmd_name')
        return Response(data)

    @staticmethod
    def _current_version(protocol_name):
        """
        Return the current version of ``protocol_name`` from ``CurrentDevVersion``.

        Raises:
            ValidationError: the protocol has no ``CurrentDevVersion`` row.
        """
        current_obj = CurrentDevVersion.objects.filter(protocol_name=protocol_name).first()
        if current_obj is None:
            raise ValidationError("当前协议不存在")
        return current_obj.version

    @staticmethod
    def _protocol_name_of_cmd(cmd_name):
        """
        Resolve the protocol a command belongs to through ``TableDevCmdNamePoll``.

        Raises:
            ValidationError: no command named ``cmd_name`` exists.
        """
        cmd_obj = TableDevCmdNamePoll.objects.filter(cmd_name=cmd_name).first()
        if cmd_obj is None:
            raise ValidationError("当前指令不存在")
        return cmd_obj.protocol_name

    @staticmethod
    def _renumber(queryset):
        """Rewrite ``fieldindex`` as 1..n, keeping the rows' existing fieldindex order."""
        # Explicit ordering: without it the database may return rows in any order.
        for index, field in enumerate(queryset.order_by('fieldindex', 'pk'), start=1):
            field.fieldindex = index
            field.save()

    @transaction.atomic
    def perform_create(self, serializer):
        """
        Create a command field and mirror it into ``AllDevCmdDefineAndVersion``.

        The protocol is validated before anything is written. An unknown
        ``protocol_name`` therefore no longer leaves an orphan
        ``TableAllDevCmdDefine`` row behind.

        Raises:
            ValidationError: the protocol has no current version.
        """
        current_version = self._current_version(serializer.validated_data['protocol_name'])
        super().perform_create(serializer)

        # Build the versioned row from a copy so the serializer's state is untouched.
        versioned = dict(serializer.validated_data)
        versioned.pop('protocol_name')
        versioned['version'] = current_version
        AllDevCmdDefineAndVersion.objects.create(**versioned)

    @transaction.atomic
    def perform_destroy(self, instance):
        """
        Delete a command field and its current-version copy.

        After the delete, the remaining fields of the same command are
        renumbered in both tables. ``protocol_name`` is not read from the
        instance; it is resolved through ``TableDevCmdNamePoll`` by
        ``cmd_name``.

        Raises:
            ValidationError: the command or its protocol cannot be found.
        """
        cmd_name = instance.cmd_name
        current_version = self._current_version(self._protocol_name_of_cmd(cmd_name))

        AllDevCmdDefineAndVersion.objects.filter(cmd_name=cmd_name,
                                                 fieldname=instance.fieldname,
                                                 version=current_version).delete()
        super().perform_destroy(instance)

        # Close the gap left by the deleted field in both tables.
        self._renumber(self.get_queryset().filter(cmd_name=cmd_name))
        self._renumber(AllDevCmdDefineAndVersion.objects.filter(cmd_name=cmd_name,
                                                                version=current_version))

    @transaction.atomic
    def perform_update(self, serializer):
        """
        Update a command field and apply the same values to its current-version copy.

        The versioned row is located by the field's identity before the update
        (the stored ``cmd_name``/``fieldname``), so renaming a field keeps both
        tables in sync. On partial updates ``protocol_name`` may be omitted; it
        is then resolved from the stored ``cmd_name``.

        Raises:
            ValidationError: the command or its protocol cannot be found.
        """
        instance = serializer.instance
        old_cmd_name = instance.cmd_name
        old_fieldname = instance.fieldname

        protocol_name = serializer.validated_data.get('protocol_name')
        if protocol_name is None:
            protocol_name = self._protocol_name_of_cmd(old_cmd_name)
        current_version = self._current_version(protocol_name)

        super().perform_update(serializer)

        versioned = dict(serializer.validated_data)
        versioned.pop('protocol_name', None)
        versioned['version'] = current_version
        AllDevCmdDefineAndVersion.objects.filter(cmd_name=old_cmd_name,
                                                 fieldname=old_fieldname,
                                                 version=current_version).update(**versioned)


class TableDevCmdNamePollView_1(ModelViewSet):
    """
    CRUD endpoint for protocol commands (``TableDevCmdNamePoll``).

    Every write is mirrored into ``AllProtocolDefinAndVersion`` under the
    protocol's current version. The protocol's version bookkeeping
    (``CurrentDevVersion`` / ``AllProtocolVersion``) is created with the first
    command and removed with the last one. Each write runs in one transaction.
    """
    queryset = TableDevCmdNamePoll.objects.all()
    serializer_class = TableDevCmdNamePollSerializer

    def list(self, request):
        """Return all commands grouped by ``protocol_name`` via ``tree_data``."""
        serializer = self.get_serializer(self.get_queryset(), many=True)
        data = tree_data(serializer.data, 'protocol_name')
        return Response(data)

    @staticmethod
    def _current_version(protocol_name):
        """
        Return the current version of ``protocol_name`` from ``CurrentDevVersion``.

        Raises:
            ValidationError: the protocol has no ``CurrentDevVersion`` row.
        """
        current_obj = CurrentDevVersion.objects.filter(protocol_name=protocol_name).first()
        if current_obj is None:
            raise ValidationError("当前协议不存在")
        return current_obj.version

    @transaction.atomic
    def perform_create(self, serializer):
        """
        Create a command and mirror it into ``AllProtocolDefinAndVersion``.

        If the protocol has no ``CurrentDevVersion`` yet, its version history
        is started at ``"init"``. That creates both a ``CurrentDevVersion`` row
        and an ``AllProtocolVersion`` row whose ``version_paths`` is
        ``[{"version": "init"}]``.
        """
        super().perform_create(serializer)

        protocol_name = serializer.validated_data['protocol_name']
        current_obj = CurrentDevVersion.objects.filter(protocol_name=protocol_name).first()
        if current_obj is None:
            current_version = "init"
            CurrentDevVersion.objects.create(protocol_name=protocol_name, version=current_version)
            AllProtocolVersion.objects.create(protocol_name=protocol_name,
                                              version_paths=json.dumps([{"version": current_version}]))
        else:
            current_version = current_obj.version

        # Build the versioned row from a copy so the serializer's state is untouched.
        versioned = dict(serializer.validated_data)
        versioned['version'] = current_version
        AllProtocolDefinAndVersion.objects.create(**versioned)

    @transaction.atomic
    def perform_destroy(self, instance):
        """
        Delete a command, its fields, and their current-version copies.

        When this was the protocol's last command, the protocol's
        ``CurrentDevVersion`` and ``AllProtocolVersion`` rows are removed too.

        Raises:
            ValidationError: the protocol has no current version. This is
                checked before anything is deleted.
        """
        protocol_name = instance.protocol_name
        cmd_name = instance.cmd_name
        current_version = self._current_version(protocol_name)

        # Fields of this command in the working table.
        TableAllDevCmdDefine.objects.filter(cmd_name=cmd_name).delete()
        # The command and its fields in the current version snapshot.
        AllProtocolDefinAndVersion.objects.filter(protocol_name=protocol_name,
                                                  cmd_name=cmd_name,
                                                  version=current_version).delete()
        AllDevCmdDefineAndVersion.objects.filter(cmd_name=cmd_name, version=current_version).delete()

        super().perform_destroy(instance)
        # Once the protocol has no commands left, drop its version bookkeeping as well.
        if not TableDevCmdNamePoll.objects.filter(protocol_name=protocol_name).exists():
            CurrentDevVersion.objects.filter(protocol_name=protocol_name).delete()
            AllProtocolVersion.objects.filter(protocol_name=protocol_name).delete()

    @transaction.atomic
    def perform_update(self, serializer):
        """
        Update a command and apply the same values to its current-version copy.

        The versioned row is located by the command's identity before the
        update (stored ``protocol_name``/``cmd_name`` and that protocol's
        current version), so renames stay in sync. On partial updates
        ``protocol_name`` defaults to the stored one.

        NOTE(review): renaming ``cmd_name`` does not re-key the command's rows
        in ``TableAllDevCmdDefine`` / ``AllDevCmdDefineAndVersion``, which are
        looked up by ``cmd_name`` elsewhere in this module. Confirm intended.

        Raises:
            ValidationError: the old or new protocol has no current version.
        """
        instance = serializer.instance
        old_protocol_name = instance.protocol_name
        old_cmd_name = instance.cmd_name

        protocol_name = serializer.validated_data.get('protocol_name', old_protocol_name)
        current_version = self._current_version(protocol_name)
        old_version = (current_version if protocol_name == old_protocol_name
                       else self._current_version(old_protocol_name))

        super().perform_update(serializer)

        versioned = dict(serializer.validated_data)
        versioned['version'] = current_version
        AllProtocolDefinAndVersion.objects.filter(protocol_name=old_protocol_name,
                                                  cmd_name=old_cmd_name,
                                                  version=old_version).update(**versioned)


class TableAllDevCmdDefineView(ModelViewSet):
    """CRUD endpoint for command fields (``TableAllDevCmdDefine``) without version syncing."""
    queryset = TableAllDevCmdDefine.objects.all()
    serializer_class = TableAllDevCmdDefineSerializer

    def list(self, request):
        """Return all command fields grouped by ``cmd_name`` via ``tree_data``."""
        serializer = self.get_serializer(self.get_queryset(), many=True)
        data = tree_data(serializer.data, 'cmd_name')
        return Response(data)

    def perform_destroy(self, instance):
        """
        Delete a field and renumber the remaining fields of the same command.

        After the delete, the command's ``fieldindex`` values are contiguous
        again (1..n) and keep their previous relative order.
        """
        cmd_name = instance.cmd_name
        super().perform_destroy(instance)

        # Explicit ordering: without it the database may return rows in any order,
        # which would shuffle the fields while renumbering.
        remaining = self.get_queryset().filter(cmd_name=cmd_name).order_by('fieldindex', 'pk')
        for index, field in enumerate(remaining, start=1):
            field.fieldindex = index
            field.save()


class TableDevCmdNamePollView(ModelViewSet):
    """CRUD endpoint for protocol commands; ``list`` groups them by protocol."""
    serializer_class = TableDevCmdNamePollSerializer
    queryset = TableDevCmdNamePoll.objects.all()

    def list(self, request):
        """Return every command, grouped by ``protocol_name`` through ``tree_data``."""
        commands = self.get_queryset()
        rows = self.get_serializer(commands, many=True).data
        return Response(tree_data(rows, 'protocol_name'))
    

class TableSoftLimitAngleView(ModelViewSet):
    """Default ModelViewSet CRUD endpoint for ``TableSoftLimitAngle`` records."""
    serializer_class = TableSoftLimitAngleSerializer
    queryset = TableSoftLimitAngle.objects.all()


class TableXproAllDevinfoView(ModelViewSet):
    """Default ModelViewSet CRUD endpoint for ``TableXproAllDevinfo`` records."""
    serializer_class = TableXproAllDevinfoSerializer
    queryset = TableXproAllDevinfo.objects.all()


def _normalize_cmd_explain(cmd_explain):
    """
    Return ``cmd_explain`` as a JSON-object string.

    Text that already parses as a JSON object is returned unchanged. Anything
    else (plain text, or JSON that is not an object) is wrapped as
    ``{"explain": <text>, "version": "20230101"}``.
    """
    try:
        parsed = json.loads(cmd_explain)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return cmd_explain
    return json.dumps({
        'explain': cmd_explain,
        'version': "20230101"
    })


@swagger_auto_schema(method='post', request_body=openapi.Schema(
    type=openapi.TYPE_OBJECT,
    properties={
        'cmds': openapi.Schema(type=openapi.TYPE_OBJECT, properties={
            'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
            'cmd_name': openapi.Schema(type=openapi.TYPE_STRING),
            'cmd_type': openapi.Schema(type=openapi.TYPE_STRING),
            'encode': openapi.Schema(type=openapi.TYPE_STRING),
            'timing_cmd_cycle_period': openapi.Schema(type=openapi.TYPE_INTEGER),
            'cmd_explain': openapi.Schema(type=openapi.TYPE_STRING),
            'fields': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_OBJECT, properties={
                "cmd_name": openapi.Schema(type=openapi.TYPE_STRING),
                "cmd_type": openapi.Schema(type=openapi.TYPE_STRING),
                "fieldindex": openapi.Schema(type=openapi.TYPE_INTEGER),
                "fieldname": openapi.Schema(type=openapi.TYPE_STRING),
                "fieldsize": openapi.Schema(type=openapi.TYPE_INTEGER),
                "value": openapi.Schema(type=openapi.TYPE_STRING),
                "minvalue": openapi.Schema(type=openapi.TYPE_STRING),
                "maxvalue": openapi.Schema(type=openapi.TYPE_STRING),
                "datatype": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operation_in": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operation_in_num": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operation_out": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operation_out_num": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operabo_in": openapi.Schema(type=openapi.TYPE_INTEGER),
                "operabo_out": openapi.Schema(type=openapi.TYPE_INTEGER),
                "lua_script_in": openapi.Schema(type=openapi.TYPE_STRING),
                "lua_script_out": openapi.Schema(type=openapi.TYPE_STRING)
            }))
        })
    }
))
@api_view(['POST'])
def test(request):
    """
    Bulk-import protocol commands together with their fields.

    ``request.data['cmds']`` must be an object whose values are command
    objects. Each command carries its field list under ``'fields'``; a missing
    key means no fields. Commands are written through the
    ``TableDevCmdNamePollView`` serializer and fields through the
    ``TableAllDevCmdDefineView`` serializer, so only the working tables are
    written here, not the versioned ones. A string ``cmd_explain`` is
    normalized by ``_normalize_cmd_explain``.

    The whole import runs in one transaction. Any invalid command or field
    rolls back everything this request already created.

    NOTE(review): the swagger schema above documents ``cmds`` as a single
    command, while the code iterates ``cmds.values()``. Consider updating the
    schema.

    Returns 201 on success. Raises ValidationError (400) for malformed input.
    """
    data = request.data if isinstance(request.data, dict) else {}
    cmds = data.get('cmds')
    if not isinstance(cmds, dict):
        raise ValidationError({'cmds': 'Expected an object mapping keys to command definitions.'})

    # The viewsets are used outside their normal dispatch(). Their serializer
    # context reads ``request`` and ``format_kwarg``, so set both by hand.
    protocol_cmd = TableDevCmdNamePollView()
    protocol_cmd.request = request
    protocol_cmd.format_kwarg = None

    cmd_fields = TableAllDevCmdDefineView()
    cmd_fields.request = request
    cmd_fields.format_kwarg = None

    with transaction.atomic():
        for cmd in cmds.values():
            if not isinstance(cmd, dict):
                raise ValidationError({'cmds': 'Each command must be an object.'})
            # Split the field list off a copy so request.data is left untouched.
            cmd_data = dict(cmd)
            fields = cmd_data.pop('fields', None) or []

            # Create the protocol command.
            protocol_cmd_serializer = protocol_cmd.get_serializer(data=cmd_data)
            protocol_cmd_serializer.is_valid(raise_exception=True)
            cmd_explain = protocol_cmd_serializer.validated_data.get('cmd_explain')
            if isinstance(cmd_explain, str):
                protocol_cmd_serializer.validated_data['cmd_explain'] = _normalize_cmd_explain(cmd_explain)
            protocol_cmd.perform_create(protocol_cmd_serializer)

            # Create the command's fields.
            for field in fields:
                cmd_fields_serializer = cmd_fields.get_serializer(data=field)
                cmd_fields_serializer.is_valid(raise_exception=True)
                cmd_fields.perform_create(cmd_fields_serializer)
    return Response(status=status.HTTP_201_CREATED)


@api_view(['DELETE'])
def remove_protocol(request, protocol_name):
    """
    Delete a protocol with all of its commands, fields and version data.

    The following rows are removed in a single transaction, so a failure part
    way through leaves nothing half-deleted:

    - the working tables (``TableDevCmdNamePoll`` / ``TableAllDevCmdDefine``);
    - the version bookkeeping (``CurrentDevVersion`` / ``AllProtocolVersion``);
    - every versioned snapshot (``AllProtocolDefinAndVersion`` /
      ``AllDevCmdDefineAndVersion``, all versions).

    Fields are matched by ``cmd_name`` only, as elsewhere in this module.

    Returns 204.
    """
    with transaction.atomic():
        # Working tables: the protocol's commands and their fields.
        cmds = TableDevCmdNamePoll.objects.filter(protocol_name=protocol_name)
        cmd_names = list(cmds.values_list('cmd_name', flat=True))
        TableAllDevCmdDefine.objects.filter(cmd_name__in=cmd_names).delete()
        cmds.delete()

        # Version bookkeeping.
        CurrentDevVersion.objects.filter(protocol_name=protocol_name).delete()
        AllProtocolVersion.objects.filter(protocol_name=protocol_name).delete()

        # Versioned snapshots of every version of the protocol.
        versioned_cmds = AllProtocolDefinAndVersion.objects.filter(protocol_name=protocol_name)
        versioned_cmd_names = list(versioned_cmds.values_list('cmd_name', flat=True))
        AllDevCmdDefineAndVersion.objects.filter(cmd_name__in=versioned_cmd_names).delete()
        versioned_cmds.delete()

    return Response(status=status.HTTP_204_NO_CONTENT)


# Attributes copied verbatim when one version is extended into another
# (everything except ``protocol_name``/``version``, which are set explicitly).
_VERSIONED_CMD_ATTRS = ('cmd_name', 'cmd_type', 'encode', 'timing_cmd_cycle_period', 'cmd_explain')
_VERSIONED_FIELD_ATTRS = ('cmd_name', 'cmd_type', 'fieldindex', 'fieldname', 'fieldsize',
                          'value', 'minvalue', 'maxvalue', 'datatype',
                          'operation_in', 'operation_in_num', 'operation_out', 'operation_out_num',
                          'operabo_in', 'operabo_out', 'lua_script_in', 'lua_script_out')


def _copy_version(protocol_name, source_version, target_version):
    """Duplicate every versioned command and field of ``source_version`` as ``target_version``."""
    source_cmds = AllProtocolDefinAndVersion.objects.filter(protocol_name=protocol_name,
                                                            version=source_version)
    for cmd in source_cmds:
        AllProtocolDefinAndVersion.objects.create(
            protocol_name=protocol_name,
            version=target_version,
            **{attr: getattr(cmd, attr) for attr in _VERSIONED_CMD_ATTRS})
        source_fields = AllDevCmdDefineAndVersion.objects.filter(cmd_name=cmd.cmd_name,
                                                                 version=source_version)
        for field in source_fields:
            AllDevCmdDefineAndVersion.objects.create(
                version=target_version,
                **{attr: getattr(field, attr) for attr in _VERSIONED_FIELD_ATTRS})


@swagger_auto_schema(method='post', request_body=openapi.Schema(
    type=openapi.TYPE_OBJECT,
    properties={
        'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
        'version_name': openapi.Schema(type=openapi.TYPE_STRING),
        'is_extend': openapi.Schema(type=openapi.TYPE_BOOLEAN),
        'extend_version_name': openapi.Schema(type=openapi.TYPE_STRING),
    }
))
@api_view(['POST'])
def protocol_add_version(request):
    """
    Register a new version ``version_name`` for ``protocol_name``.

    Request body:
        protocol_name, version_name: required.
        is_extend: if truthy, every versioned command and field of
            ``extend_version_name`` is copied into new rows tagged
            ``version_name``. Otherwise the protocol's working tables
            (``TableDevCmdNamePoll`` / ``TableAllDevCmdDefine``) are cleared.
        extend_version_name: required when ``is_extend`` is truthy. It must be
            a version already listed in the protocol's ``version_paths``.

    Finally ``version_name`` is appended to ``AllProtocolVersion.version_paths``.
    All writes happen in one transaction, and only after every check has passed.

    NOTE(review): neither branch changes ``CurrentDevVersion``. Confirm whether
    switching the current version happens elsewhere.

    Returns:
        200 on success.
        400 if a required value is missing, ``version_name`` already exists, or
            ``extend_version_name`` is unknown.
        404 if the protocol has no ``AllProtocolVersion`` row.
    """
    protocol_name = request.data.get('protocol_name')
    version_name = request.data.get('version_name')
    is_extend = request.data.get('is_extend')
    extend_version_name = request.data.get('extend_version_name')

    if version_name is None or protocol_name is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)
    if is_extend and extend_version_name is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    protocol_version = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
    if protocol_version is None:
        return Response(status=status.HTTP_404_NOT_FOUND)

    versions = json.loads(protocol_version.version_paths)
    known_versions = {entry.get('version') for entry in versions}
    if version_name in known_versions:
        # Re-adding a version would duplicate every copied row.
        return Response(status=status.HTTP_400_BAD_REQUEST)
    if is_extend and extend_version_name not in known_versions:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    with transaction.atomic():
        if is_extend:
            _copy_version(protocol_name, extend_version_name, version_name)
        else:
            # Not extending: the new version starts empty, so clear the working tables.
            cmds = TableDevCmdNamePoll.objects.filter(protocol_name=protocol_name)
            cmd_names = list(cmds.values_list('cmd_name', flat=True))
            TableAllDevCmdDefine.objects.filter(cmd_name__in=cmd_names).delete()
            cmds.delete()

        versions.append({"version": version_name})
        protocol_version.version_paths = json.dumps(versions)
        protocol_version.save()

    return Response(status=status.HTTP_200_OK)
