import json
import os
import urllib.parse

from django.conf import settings
from django.db import transaction
from django.http import FileResponse
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import status
from rest_framework.decorators import api_view, parser_classes
from rest_framework.mixins import ListModelMixin
from rest_framework.parsers import MultiPartParser
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet, ModelViewSet

from device_data_op.models import TableAllDevCmdDefine, TableDevCmdNamePoll

from .models import (AllDevCmdDefineAndVersion, AllProtocolDefinAndVersion,
                     AllProtocolVersion, CurrentDevVersion)
from .serializers import (AllDevCmdDefineAndVersionSerializer, AllProtocolDefinAndVersionSerializer,
                          AllProtocolVersionSerializer, CurrentDevVersionSerializer)
from .services import (init_protocol_version_manage, update_device_protocol_and_cmds,
                       add_protocol_version_manage, update_protocol_version_manage)


@swagger_auto_schema(methods=['POST'], request_body=openapi.Schema(
    type=openapi.TYPE_OBJECT,
    properties={
        'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
    }
))
@api_view(['POST'])
def init(request):
    """
    Initialise version information for a protocol (rarely used).

    If no ``AllProtocolVersion`` row exists yet for ``protocol_name``,
    ``init_protocol_version_manage`` is called to create it. The response
    carries the decoded ``version_paths`` list and the protocol's current
    version taken from ``CurrentDevVersion``.

    Responds 400 when ``protocol_name`` is missing, 500 when the version
    records cannot be created or read, 200 with the data otherwise.
    """
    protocol_name = request.data.get('protocol_name')
    if protocol_name is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    try:
        all_protocol_version = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
        if all_protocol_version is None:
            # No version information yet for this protocol: create it first.
            init_protocol_version_manage(protocol_name)
            all_protocol_version = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
        current_protocol_version = CurrentDevVersion.objects.filter(protocol_name=protocol_name).first()
        if all_protocol_version is None or current_protocol_version is None:
            # Initialisation did not leave the rows this endpoint reports on.
            print('version records missing for protocol', protocol_name)
            return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        res_data = {
            'version_paths': json.loads(all_protocol_version.version_paths),
            'current_version': current_protocol_version.version,
        }
    except Exception as e:
        # Boundary handler: report the failure instead of silently swallowing it
        # (a bare ``except:`` would also trap SystemExit/KeyboardInterrupt).
        print(e)
        return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    return Response(data=res_data, status=status.HTTP_200_OK)


@swagger_auto_schema(
    methods=['POST'],
    request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        properties={
            'version': openapi.Schema(type=openapi.TYPE_STRING),
            'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
        },
    ),
)
@api_view(['POST'])
def change_protocol_version(request):
    """
    Switch a protocol to the requested version.

    Both ``protocol_name`` and ``version`` must be present in the request
    body; otherwise the response is 400. When they are present, the response
    is whatever ``update_device_protocol_and_cmds`` returns.
    """
    payload = request.data
    name = payload.get('protocol_name')
    target_version = payload.get('version')
    if name is not None and target_version is not None:
        return update_device_protocol_and_cmds(name, target_version)
    return Response(status=status.HTTP_400_BAD_REQUEST)
    

@swagger_auto_schema(
    methods=['POST'],
    request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        properties={
            'version': openapi.Schema(type=openapi.TYPE_STRING),
            'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
            'cmds': openapi.Schema(
                type=openapi.TYPE_ARRAY,
                items=openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    properties={
                        'context': openapi.Schema(
                            type=openapi.TYPE_STRING,
                            default='内容为 AllProtocolDefinAndVersionSerializer 的内容',
                        ),
                    },
                ),
            ),
        },
    ),
)
@api_view(['POST'])
def add_protocol_version(request):
    """
    Add a protocol version by delegating to ``add_protocol_version_manage``.

    Required body fields: ``protocol_name``, ``version`` and ``cmds``.
    Responds 400 when any of them is absent; otherwise returns the service's
    response unchanged.
    """
    body = request.data
    required = (body.get('protocol_name'), body.get('version'), body.get('cmds'))
    if any(value is None for value in required):
        return Response(status=status.HTTP_400_BAD_REQUEST)
    protocol_name, version, cmds = required
    return add_protocol_version_manage(protocol_name, version, cmds)


@swagger_auto_schema(
    methods=['POST'],
    request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        properties={
            'version': openapi.Schema(type=openapi.TYPE_STRING),
            'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
            'cmds': openapi.Schema(
                type=openapi.TYPE_ARRAY,
                items=openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    properties={
                        'context': openapi.Schema(
                            type=openapi.TYPE_STRING,
                            default='内容为 AllProtocolDefinAndVersionSerializer 的内容',
                        ),
                    },
                ),
            ),
        },
    ),
)
@api_view(['POST'])
def update_protocol_version(request):
    """
    Update an existing protocol version via ``update_protocol_version_manage``.

    The body must contain ``protocol_name``, ``version`` and ``cmds``; a
    missing field yields 400, otherwise the service's response is returned.
    """
    fields = {key: request.data.get(key) for key in ('protocol_name', 'version', 'cmds')}
    if None not in fields.values():
        return update_protocol_version_manage(fields['protocol_name'], fields['version'], fields['cmds'])
    return Response(status=status.HTTP_400_BAD_REQUEST)


class AllProtocolVersionViewSet(GenericViewSet, ListModelMixin):
    """
    List-only endpoint over every ``AllProtocolVersion`` record.
    """
    serializer_class = AllProtocolVersionSerializer
    queryset = AllProtocolVersion.objects.all()


# @swagger_auto_schema(methods=['POST'], request_body=openapi.Schema(
#     type=openapi.TYPE_OBJECT,
#     properties={
#         'file': openapi.Schema(type=openapi.TYPE_FILE),
#         'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
#         'version': openapi.Schema(type=openapi.TYPE_STRING)
#     }
# ))
@api_view(['POST'])
@parser_classes([MultiPartParser])
def raw_file_upload(request):
    """
    Store an uploaded raw protocol file and record its path on the version entry.

    Multipart fields: ``file`` (the upload), ``protocol_name`` and ``version``.
    The file is written to ``BASE_DIR/protocol_raw_files/<protocol_name>/<version>/``
    and its path is saved into the matching entry of
    ``AllProtocolVersion.version_paths``. A previously stored file for that
    version (at a different path) is removed after the new one is written.

    Responds 400 for missing fields or names that would escape the storage
    directory, 404 when the protocol or version is unknown, 500 when the file
    cannot be written, and 200 with ``{'path': <saved path>}`` on success.
    """
    file_obj = request.FILES.get('file')
    protocol_name = request.data.get('protocol_name')
    version = request.data.get('version')
    if file_obj is None or protocol_name is None or version is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    # Only the base name of the client-supplied file name is ever used.
    file_name = os.path.basename(file_obj.name or '')
    folder_path = _resolve_upload_folder(protocol_name, version)
    if not file_name or folder_path is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    # Locate the version entry first: without it the stored file would be
    # unreachable from the download endpoint.
    protocol_versions = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
    if protocol_versions is None:
        return Response(status=status.HTTP_404_NOT_FOUND)
    version_path_list = json.loads(protocol_versions.version_paths)
    entry = next((item for item in version_path_list if item['version'] == version), None)
    if entry is None:
        return Response(status=status.HTTP_404_NOT_FOUND)

    # Write the new file before touching the old one, so a failed write does
    # not leave the version without any file.
    file_path = os.path.join(folder_path, file_name)
    try:
        os.makedirs(folder_path, exist_ok=True)
        with open(file_path, 'wb') as f:
            for chunk in file_obj.chunks():
                f.write(chunk)
    except OSError as e:
        print(e)
        return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    old_path = entry.get('path')
    if old_path and os.path.abspath(old_path) != os.path.abspath(file_path):
        try:
            os.remove(old_path)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
        except OSError as e:
            # The new file is in place; a stale leftover must not fail the upload.
            print(e)

    entry['path'] = file_path
    protocol_versions.version_paths = json.dumps(version_path_list)
    protocol_versions.save()

    return Response(data={'path': file_path}, status=status.HTTP_200_OK)


def _resolve_upload_folder(protocol_name, version):
    """
    Return ``BASE_DIR/protocol_raw_files/<protocol_name>/<version>``.

    Returns None when the client-supplied names would resolve outside of (or
    onto) the storage root, e.g. via ``..`` components or absolute paths.
    """
    storage_root = os.path.join(settings.BASE_DIR, 'protocol_raw_files')
    folder_path = os.path.join(storage_root, protocol_name, version)
    real_root = os.path.realpath(storage_root)
    real_folder = os.path.realpath(folder_path)
    try:
        inside = os.path.commonpath([real_root, real_folder]) == real_root
    except ValueError:
        # e.g. paths on different drives on Windows
        inside = False
    if not inside or real_folder == real_root:
        return None
    return folder_path


# @swagger_auto_schema(methods=['GET'])
@api_view(['GET'])
@parser_classes([MultiPartParser])
def raw_file_download(request, protocol_name, version):
    """
    Return the raw file stored for ``protocol_name``/``version`` as an attachment.

    The stored path is looked up in ``AllProtocolVersion.version_paths`` (the
    first entry whose ``version`` matches wins).

    Status codes:
    - 404 when the protocol, the version entry, its ``path``, or the file on
      disk is missing.
    - 500 when the record cannot be decoded or the file cannot be opened.

    The URL-quoted file name is sent both in ``Content-Disposition`` and in a
    custom ``filename`` header.
    """
    protocol_versions = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
    if protocol_versions is None:
        return Response(status=status.HTTP_404_NOT_FOUND)
    try:
        version_path_list = json.loads(protocol_versions.version_paths)
    except (TypeError, ValueError) as e:
        print(e)
        return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    file_path = next(
        (item.get('path') for item in version_path_list if item.get('version') == version),
        None,
    )
    if not file_path:
        return Response(status=status.HTTP_404_NOT_FOUND)

    try:
        file_handle = open(file_path, 'rb')
    except FileNotFoundError:
        return Response(status=status.HTTP_404_NOT_FOUND)
    except OSError as e:
        print(e)
        return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    # FileResponse takes ownership of the handle and closes it when done
    # (per the Django FileResponse documentation).
    response = FileResponse(file_handle)
    file_name = urllib.parse.quote(os.path.basename(file_path))
    response['Content-Disposition'] = 'attachment; filename=' + file_name
    # The body is the raw file, not a multipart payload.
    response['Content-Type'] = 'application/octet-stream'
    response['filename'] = file_name
    return response


class CurrentDevVersionViewSet(GenericViewSet, ListModelMixin):
    """
    List the current version of every protocol (all ``CurrentDevVersion`` records).
    """
    queryset = CurrentDevVersion.objects.all()
    serializer_class = CurrentDevVersionSerializer


@swagger_auto_schema(methods=['POST'], request_body=openapi.Schema(
    type=openapi.TYPE_OBJECT,
    properties={
        'protocol_name': openapi.Schema(type=openapi.TYPE_STRING),
        'version': openapi.Schema(type=openapi.TYPE_STRING),
    }
))
@api_view(['POST'])
def delete_protocol_vesrion(request):
    """
    Delete one version of a protocol.

    ``version`` is removed from the JSON version list of every
    ``AllProtocolDefinAndVersion`` command of ``protocol_name``. For commands
    that contained it, the same is done for the ``AllDevCmdDefineAndVersion``
    fields with that command name. Rows whose only version it was are deleted.

    The version entry is then removed from ``AllProtocolVersion.version_paths``;
    the whole record is deleted if that was its last entry. All changes run
    in one transaction.

    Responds 400 when a field is missing and 404 when the protocol has no
    version record. Otherwise it responds 200, including when the protocol has
    no such version.
    """
    protocol_name = request.data.get('protocol_name')
    version = request.data.get('version')
    if protocol_name is None or version is None:
        return Response(status=status.HTTP_400_BAD_REQUEST)

    # Check up front so nothing is modified for an unknown protocol.
    protocol_versions = AllProtocolVersion.objects.filter(protocol_name=protocol_name).first()
    if protocol_versions is None:
        return Response(status=status.HTTP_404_NOT_FOUND)

    with transaction.atomic():
        for cmd in AllProtocolDefinAndVersion.objects.filter(protocol_name=protocol_name):
            cmd_name = cmd.cmd_name  # read before the row may be deleted
            if not _drop_version(cmd, version):
                # This command is not part of the version being deleted.
                continue
            # NOTE(review): fields are matched by cmd_name only, not by
            # protocol_name; confirm command names are unique across protocols.
            for field in AllDevCmdDefineAndVersion.objects.filter(cmd_name=cmd_name):
                _drop_version(field, version)

        version_entries = json.loads(protocol_versions.version_paths)
        known_versions = [entry['version'] for entry in version_entries]
        if version in known_versions:
            if len(known_versions) == 1:
                # It was the protocol's last version: drop the whole record.
                protocol_versions.delete()
            else:
                version_entries.pop(known_versions.index(version))
                protocol_versions.version_paths = json.dumps(version_entries)
                protocol_versions.save()

    return Response(status=status.HTTP_200_OK)


def _drop_version(record, version):
    """
    Remove ``version`` from ``record.version`` (a JSON-encoded list).

    Deletes the record when ``version`` was its only entry; otherwise saves
    the shortened list. Returns True if the record contained ``version``.
    """
    versions = json.loads(record.version)
    if version not in versions:
        return False
    if len(versions) == 1:
        record.delete()
    else:
        versions.remove(version)
        record.version = json.dumps(versions)
        record.save()
    return True


class AllDevCmdDefineAndVersionViewSet(ModelViewSet):
    """
    Model endpoints for ``AllDevCmdDefineAndVersion`` rows.

    Creating a row also creates a ``TableAllDevCmdDefine`` row from the same
    validated data, minus ``version``.
    """
    queryset = AllDevCmdDefineAndVersion.objects.all()
    serializer_class = AllDevCmdDefineAndVersionSerializer

    def perform_create(self, serializer):
        """
        Save the versioned definition and mirror it into ``TableAllDevCmdDefine``.

        The submitted ``version`` is stored as a one-element JSON list. Both
        inserts run in one transaction, so a failure of the second rolls back
        the first instead of leaving a half-created definition.
        """
        validated = serializer.validated_data
        with transaction.atomic():
            validated['version'] = json.dumps([validated['version']])
            super().perform_create(serializer)
            # ``version`` is deliberately not copied to the mirror table; build
            # a filtered copy instead of mutating the serializer's data.
            mirror_fields = {key: value for key, value in validated.items() if key != 'version'}
            TableAllDevCmdDefine.objects.create(**mirror_fields)

