from __future__ import annotations

from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Engine
from sqlalchemy.orm import sessionmaker, scoped_session, Session
from feature.logging.easy_logging import logger
from sqlalchemy import and_

class DBEngine:
    """Build and hold one SQLAlchemy engine per configured MySQL database.

    ``engines`` is a class-level dict, so every instance writes into (and
    reads from) the same mapping. Constructing another ``DBEngine`` with a
    key that is already present replaces the engine stored under that key.
    """

    # Class-level default; each instance replaces it with its own dict in
    # ``__init__``.
    db_params = {}
    # Shared by all instances: maps the caller-chosen key to its Engine.
    engines = {}

    def __init__(self, **params):
        """
        Initialise the database connections.

        :param params: database connection parameters, one keyword per
            database. Each value is a dict with the keys ``user``,
            ``passwd``, ``host``, ``port`` and ``db``.

        Example::

            params = {
                'db_name1': {
                    'user': 'root',
                    'passwd': 'passwd',
                    'host': '127.0.0.1',
                    'port': 3306,
                    'db': 'db_name1',
                },
                'db_name2': {
                    'user': 'root',
                    'passwd': 'passwd',
                    'host': '127.0.0.1',
                    'port': 3306,
                    'db': 'db_name2',
                },
            }
            DBEngine(**params)

        :raises KeyError: if a per-database dict lacks one of the keys used
            in the URL template.
        """
        self.db_params = params
        # dict.items(): the Python 2-only iteritems() does not exist on
        # Python 3 dicts and raised AttributeError here.
        for name, conf in self.db_params.items():
            # The values are interpolated verbatim, so credentials that
            # contain URL-special characters (e.g. '@', '/') must already be
            # URL-escaped by the caller.
            mysql_url = ("mysql+pymysql://{user}:{passwd}@{host}:{port}/{db}"
                         "?charset=utf8".format(**conf))
            self.engines[name] = create_engine(mysql_url,
                                               pool_size=10,
                                               max_overflow=-1,
                                               pool_recycle=1000,
                                               echo=False)


class Dao:
    """Small data-access helper built around a single SQLAlchemy session.

    Every operation runs inside :meth:`session_scope`, which rolls back,
    logs and then suppresses any exception. Database errors therefore do
    not propagate to callers of the methods below.
    """

    def __init__(self, engine: Engine):
        """
        Create the session used by every operation on this object.

        :param engine: engine the session is bound to
        """
        # scoped_session hands out one session per thread, but the session
        # fetched here is stored on the instance.
        # NOTE(review): a Dao shared between threads shares this one session;
        # confirm that is intended.
        DBSession = scoped_session(sessionmaker(bind=engine))
        # Create the session object.
        self.session = DBSession()

    @contextmanager
    def session_scope(self, commit=0):
        """
        Transaction management around a unit of work.

        Any exception raised by the body or by the commit causes a rollback,
        is logged with ``logger.exception`` and is NOT re-raised. The session
        is closed on exit in every case.

        :param commit: ``0`` (the default) commits once the body finishes
            without raising; any other value skips the commit (the read-only
            queries in this class pass ``1``).
        :return: context manager yielding the session
        """
        session = self.session
        try:
            yield session
            if commit == 0:
                session.commit()
        except Exception as e:
            session.rollback()
            logger.exception(e)
        finally:
            session.close()

    def close(self):
        """
        Close the session manually.

        It is not yet clear whether the connection pool is managed
        automatically, so the connection is closed by hand after use.

        :return:
        """
        self.session.close()

    def add(self, model_value):
        """
        Insert a single object.

        :param model_value: table (model) instance
        :return:
        """
        with self.session_scope() as session:
            session.add(model_value)
            # Flush so the pending INSERT is emitted, then detach the
            # instance before the commit -- presumably so the caller can keep
            # using it after the session has been closed.
            session.flush()
            session.expunge(model_value)

    def add_all(self, model_list):
        """
        Insert several objects at once.

        :param model_list: list of table (model) instances
        :return:
        """
        with self.session_scope() as session:
            session.add_all(model_list)

    def update(self, model_type, model_value, filter_params):
        """
        Update the rows matching ``filter_params`` with one UPDATE statement.

        Primary-key fields and fields whose value is ``None`` are left out of
        the UPDATE. If nothing remains to be written, no statement is issued.

        :param model_type: table (model) class
        :param model_value: table (model) instance; it must provide
            ``to_dict()`` and a ``get_pk()`` method returning the list of
            primary-key field names
        :param filter_params: condition for the SQL WHERE clause
        :return:
        """
        pk_names = model_value.get_pk()
        values = {
            field: value
            for field, value in model_value.to_dict().items()
            if field not in pk_names and value is not None
        }
        if not values:
            return
        # Issued once, after filtering. Running it inside the filtering loop
        # sent one UPDATE per field, the early ones still carrying None
        # values and not-yet-removed keys.
        with self.session_scope() as session:
            session.query(model_type).filter(filter_params).update(
                values, synchronize_session=False)

    def delete(self, model_type, filter_params):
        """
        Delete the rows matching ``filter_params``.

        :param model_type: table (model) class
        :param filter_params: condition for the SQL WHERE clause
        :return:
        """
        with self.session_scope() as session:
            session.query(model_type).filter(filter_params).delete()

    def select(self, model_type, filter_params):
        """
        Query a single table.

        :param model_type: table (model) class
        :param filter_params: condition for the SQL WHERE clause
        :return: list of matching rows; an empty list if the query failed
            (the error is logged by ``session_scope``)
        """
        # Bound up front: session_scope suppresses exceptions, so without
        # this a failed query ended in UnboundLocalError at the return.
        rows = []
        with self.session_scope(1) as session:
            rows = session.query(model_type).filter(filter_params).all()
        return rows

    def multi_select(self, model_type, join_model_type, filter_params):
        """
        Query two joined tables.

        :param model_type: table (model) class
        :param join_model_type: table (model) class used for the join
        :param filter_params: condition for the SQL WHERE clause
        :return: list of result rows; an empty list if the query failed
            (the error is logged by ``session_scope``)
        """
        # Bound up front for the same reason as in select().
        rows = []
        with self.session_scope(1) as session:
            rows = (session.query(model_type, join_model_type)
                    .join(join_model_type)
                    .filter(filter_params)
                    .all())
        return rows

