import json
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework import status
from rest_framework.viewsets import ModelViewSet

from .models import (SimulateDeviceCommunicationParameter, DeviceCommunicationParameter)
from protocol_version_manage.models import (AllProtocolVersion, AllProtocolDefinAndVersion,
                                            AllDevCmdDefineAndVersion)
from device_data_op.models import TableXproAllDevinfo
from .serializers import (SimulateDeviceCommunicationParameterSerializer,
                          DeviceCommunicationParameterSerializer)


@api_view(['GET'])
def get_protocol_names(request):
    """
    Return every protocol as a ``{'value': name, 'label': name}`` option.

    One option is produced per AllProtocolVersion row, in queryset order;
    repeated protocol names are not collapsed.
    """
    options = []
    for protocol in AllProtocolVersion.objects.all():
        name = protocol.protocol_name
        options.append({'value': name, 'label': name})

    return Response(data=options, status=status.HTTP_200_OK)


@api_view(['GET'])
def get_protocol_field_names(request, protocol_name):
    """
    Return the distinct field names used by a protocol's RX commands.

    Args:
        protocol_name: protocol whose RX commands are looked up in
            AllProtocolDefinAndVersion.

    Returns:
        200 with a sorted list of the distinct ``fieldname`` values of the
        AllDevCmdDefineAndVersion rows whose ``cmd_name`` is one of those
        commands and whose ``cmd_type`` is 'RX'.
    """
    # Names of this protocol's RX commands (only the one column is fetched).
    cmd_names = set(
        AllProtocolDefinAndVersion.objects
        .filter(protocol_name=protocol_name, cmd_type='RX')
        .values_list('cmd_name', flat=True)
    )

    # Field names defined for those commands, de-duplicated.
    field_names = set(
        AllDevCmdDefineAndVersion.objects
        .filter(cmd_name__in=cmd_names, cmd_type='RX')
        .values_list('fieldname', flat=True)
    )

    # A set has no stable iteration order (string hashing is randomized per
    # process), so sort to give clients a deterministic response.
    # key=str keeps the sort from failing on unexpected non-str values.
    return Response(data=sorted(field_names, key=str), status=status.HTTP_200_OK)


@api_view(['GET'])
def get_checked_field_names(request, type, protocol_name):
    """
    Return the fields selected for monitoring under a protocol.

    Collects the union of every matching record's ``performance_fields``
    (a JSON-encoded collection, presumably a list of field names).

    Args:
        type: '1' reads DeviceCommunicationParameter; any other value reads
            SimulateDeviceCommunicationParameter. The name shadows the builtin
            but is kept because callers (presumably the URL pattern) pass it
            by keyword.
        protocol_name: only records with this ``protocol_name`` are scanned.

    Returns:
        200 with a sorted list of the distinct selected field names.
    """
    # Choose the real-device or the simulated-device table.
    if type == '1':
        manager = DeviceCommunicationParameter.objects
    else:
        manager = SimulateDeviceCommunicationParameter.objects

    checked_fields = set()
    for record in manager.filter(protocol_name=protocol_name):
        raw = record.performance_fields
        # A record with nothing selected (NULL / empty string) would make
        # json.loads raise and turn the whole request into a 500.
        if not raw:
            continue
        checked_fields.update(json.loads(raw))

    # Sort for a deterministic response; key=str tolerates mixed JSON types.
    return Response(data=sorted(checked_fields, key=str), status=status.HTTP_200_OK)


@api_view(['POST'])
def set_communication_to_devinfo_table(request):
    """
    Copy communication parameters into the device-info table.

    The request body's ``type`` selects the source table:
    ``'simulate_communicate'`` reads SimulateDeviceCommunicationParameter, and
    any other non-empty value reads DeviceCommunicationParameter. Every source
    row is inserted into TableXproAllDevinfo with a single bulk_create.
    ``cmd_excel_path`` is set to the string "null", and the UDP-multicast
    columns and ``remarks`` get empty-string / zero placeholders.

    Returns 400 (no body) when ``type`` is missing or an empty string,
    otherwise 200 (no body).
    """
    comm_type = request.data.get('type')
    if comm_type in (None, ''):
        return Response(status=status.HTTP_400_BAD_REQUEST)

    if comm_type == 'simulate_communicate':
        source_model = SimulateDeviceCommunicationParameter
    else:
        source_model = DeviceCommunicationParameter

    def to_devinfo(comm):
        # Map one communication-parameter row onto an (unsaved) device-info
        # row. Note the target field is spelled ``comunitate_mode``.
        return TableXproAllDevinfo(
            sta_id=comm.station_id,
            dev_id=comm.device_id,
            dev_name=comm.device_name,
            dev_name_chn=comm.device_name_chn,
            protocol_name=comm.protocol_name,
            cmd_excel_path="null",
            comunitate_mode=comm.communicate_mode,
            tcp_ip=comm.tcp_ip,
            tcp_port=comm.tcp_port,
            udp_ip_src=comm.udp_ip_src,
            udp_port_src=comm.udp_port_src,
            udp_ip_dst=comm.udp_ip_dst,
            udp_port_dst=comm.udp_port_dst,
            udpmc_ip="",
            udpmc_ip_tx="",
            udpmc_port_tx=0,
            udpmc_ip_rx="",
            udpmc_port_rx=0,
            remarks="",
        )

    # Clearing the device-info table first is deliberately disabled for now.
    # TODO: re-enable the line below once this goes into real use:
    # TableXproAllDevinfo.objects.all().delete()

    TableXproAllDevinfo.objects.bulk_create(
        [to_devinfo(comm) for comm in source_model.objects.all()]
    )
    return Response(status=status.HTTP_200_OK)


class DeviceCommunicationParameterViewSet(ModelViewSet):
    """
    CRUD endpoints for DeviceCommunicationParameter.

    ``device_id`` is kept as a 1-based sequence within each ``station_id``:
    the station's records are renumbered after every create and destroy.
    """
    queryset = DeviceCommunicationParameter.objects.all()
    serializer_class = DeviceCommunicationParameterSerializer

    def list(self, request, *args, **kwargs):
        """
        Return every record, sorted by ``(station_id, device_id)``.

        ``*args``/``**kwargs`` must be accepted: DRF forwards URL keyword
        arguments to every action (e.g. ``format`` from format-suffix routes,
        which DefaultRouter enables), and the old two-argument signature
        raised TypeError for them.
        """
        serializer = self.get_serializer(self.get_queryset(), many=True)
        data = sorted(serializer.data, key=lambda item: (item['station_id'], item['device_id']))
        return Response(data)

    def perform_create(self, serializer):
        """Create the record, then renumber its station's device_ids."""
        super().perform_create(serializer)

        instance = serializer.instance
        new_id = self._renumber_station(instance.station_id).get(instance.pk)
        if new_id is not None:
            # Reflect the server-assigned number in the response payload.
            instance.device_id = new_id

    def perform_destroy(self, instance):
        """
        Delete the record, then renumber the device_ids of the remaining
        records of the same station.
        """
        super().perform_destroy(instance)
        self._renumber_station(instance.station_id)

    def _renumber_station(self, station_id):
        """
        Reassign ``device_id`` as 1..n across all records of ``station_id``.

        Only rows whose number actually changes are written, and only the
        ``device_id`` column is updated, so other columns are not overwritten.

        Returns:
            dict mapping each record's primary key to its assigned device_id.
        """
        # NOTE(review): there is no explicit order_by here, so numbering follows
        # the model's default ordering (Meta.ordering) or, if none, whatever
        # order the database returns. Confirm the intended order and make it
        # explicit.
        assigned = {}
        records = self.get_queryset().filter(station_id=station_id)
        for number, record in enumerate(records, start=1):
            assigned[record.pk] = number
            if record.device_id != number:
                record.device_id = number
                record.save(update_fields=['device_id'])
        return assigned


class SimulateDeviceCommunicationParameterViewSet(ModelViewSet):
    """
    CRUD endpoints for SimulateDeviceCommunicationParameter.

    ``device_id`` is kept as a 1-based sequence within each ``station_id``:
    the station's records are renumbered after every create and destroy.
    """
    queryset = SimulateDeviceCommunicationParameter.objects.all()
    serializer_class = SimulateDeviceCommunicationParameterSerializer

    def list(self, request, *args, **kwargs):
        """
        Return every record, sorted by ``(station_id, device_id)``.

        ``*args``/``**kwargs`` must be accepted: DRF forwards URL keyword
        arguments to every action (e.g. ``format`` from format-suffix routes,
        which DefaultRouter enables), and the old two-argument signature
        raised TypeError for them.
        """
        serializer = self.get_serializer(self.get_queryset(), many=True)
        data = sorted(serializer.data, key=lambda item: (item['station_id'], item['device_id']))
        return Response(data)

    def perform_create(self, serializer):
        """Create the record, then renumber its station's device_ids."""
        super().perform_create(serializer)

        instance = serializer.instance
        new_id = self._renumber_station(instance.station_id).get(instance.pk)
        if new_id is not None:
            # Reflect the server-assigned number in the response payload.
            instance.device_id = new_id

    def perform_destroy(self, instance):
        """
        Delete the record, then renumber the device_ids of the remaining
        records of the same station.
        """
        super().perform_destroy(instance)
        self._renumber_station(instance.station_id)

    def _renumber_station(self, station_id):
        """
        Reassign ``device_id`` as 1..n across all records of ``station_id``.

        Only rows whose number actually changes are written, and only the
        ``device_id`` column is updated, so other columns are not overwritten.

        Returns:
            dict mapping each record's primary key to its assigned device_id.
        """
        # NOTE(review): there is no explicit order_by here, so numbering follows
        # the model's default ordering (Meta.ordering) or, if none, whatever
        # order the database returns. Confirm the intended order and make it
        # explicit.
        assigned = {}
        records = self.get_queryset().filter(station_id=station_id)
        for number, record in enumerate(records, start=1):
            assigned[record.pk] = number
            if record.device_id != number:
                record.device_id = number
                record.save(update_fields=['device_id'])
        return assigned
