import requests
import json
import functools
from fastapi import FastAPI, HTTPException, Body,Request
from fastapi.middleware.cors import CORSMiddleware
from requests_ntlm import HttpNtlmAuth
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
from typing import Dict
import mysql.connector
from datetime import datetime


# host = 'rm-2zehj3r1n60ttz9x5.mysql.rds.aliyuncs.com'  # used when accessing from the server
# database = 'jingbo_test'  # used when accessing from the server
host = 'rm-2zehj3r1n60ttz9x5ko.mysql.rds.aliyuncs.com'  # used when accessing from Beijing
database = 'jingbo-test'  # used when accessing from Beijing

# Database connection settings (keyword arguments for mysql.connector.connect);
# adjust to the actual environment.
# NOTE(review): the database password is hard-coded in source — consider
# loading it from the environment or a secrets store.
config = {
    "user": "jingbo",
    "password": "shihua@123",
    "host": host,
    "database": database
}

# GraphQL API eg:  url = 'http://10.88.14.86/AspenTech/AspenUnified/api/v1/model/Chambroad20241205/graphql'
graphql_host = 'http://10.88.14.86'
graphql_path = '/AspenTech/AspenUnified/api/v1/model/Chambroad20241205/graphql'
url = graphql_host + graphql_path
# NTLM credentials for the GraphQL endpoint.
# NOTE(review): also hard-coded — same concern as the database password above.
graphql_username = "bw19382"
graphql_password = "Fudong3!"
auth = HttpNtlmAuth(f'{graphql_username}', f'{graphql_password}')

# Request headers sent with every GraphQL call
headers = {'content-type': 'application/json;charset=UTF-8'}
def insert_api_log(request_time, request_url, request_method, request_params, response_content, response_time):
    """
    Write one row to the request-log table ``v_tbl_aup_api_log``.

    Logging is best-effort: MySQL errors are printed and swallowed so that a
    logging failure never breaks the request that is being logged.

    Args:
        request_time: When the request was received (bound to REQUEST_TIME).
        request_url: Request path (bound to REQUEST_URL).
        request_method: HTTP method (bound to REQUEST_METHOD).
        request_params: Request payload as ``str``; stored UTF-8 encoded.
        response_content: Response payload as ``str``; stored UTF-8 encoded.
        response_time: When the response was available (bound to RESPONSE_TIME).
    """
    global cnx
    # Initialised up front so the cleanup below never touches an unbound name.
    cursor = None
    try:
        # The module keeps one shared connection. Open it lazily, and re-open
        # it when the server has dropped it; otherwise a single lost
        # connection would make every later log write fail.
        if cnx is None or not cnx.is_connected():
            cnx = mysql.connector.connect(**config)
        cursor = cnx.cursor()
        # No ID column is supplied here; presumably the table generates it
        # (e.g. AUTO_INCREMENT) — TODO confirm against the schema.
        insert_query = """
        INSERT INTO v_tbl_aup_api_log (REQUEST_TIME, REQUEST_URL, REQUEST_METHOD, REQUEST_PARAMS, RESPONSE_CONTENT, RESPONSE_TIME)
        VALUES (%s, %s, %s, %s, %s, %s)
        """
        # Value order must match the placeholder order in the statement.
        data = (
            request_time,
            request_url,
            request_method,
            request_params.encode('utf-8'),
            response_content.encode('utf-8'),
            response_time,
        )
        cursor.execute(insert_query, data)
        # Commit so the inserted row becomes visible.
        cnx.commit()
    except mysql.connector.Error as err:
        print(f"Error: {err}")
    finally:
        # Always release the cursor; the connection stays open for reuse.
        if cursor is not None:
            try:
                cursor.close()
            except mysql.connector.Error as err:
                print(f"Error: {err}")


# Shared MySQL connection used by insert_api_log(); it is opened there on
# first use, so it starts out as None.
cnx = None


# Tag descriptions handed to FastAPI through `openapi_tags` below.
tags_metadata = [
    {
        "name": "get_cases",
        "description": "获取所有cases",
    },
    {
        "name": "generate_graphql_query",
        "description": "生成Graphql查询语句,并接收查询结果",
        
    },
]
# The FastAPI application object that the route decorators in this file attach to.
app = FastAPI(
    title="AUP数据集成信息化接口",
    version="0.0.1",
    openapi_tags=tags_metadata,
    # openapi_url=""
)

# Allow cross-origin requests: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class GraphqlQueryTemplates:
    '''
    Container for GraphQL text fragments taken from the query, plus example
    configuration values (JSON text) kept alongside them.
    '''
    def __init__(self):
        # Parameter templates: fragments written in GraphQL syntax.
        # The leading/trailing whitespace inside these literals is part of
        # the value.
        self.purchase_inputs_template = """
        name:"11月度计划"
        inputs:[
          {
            name:"CWT"
            inputs:[
              {
                field:Cost
                periodName:"1"
                value: 3100
              }
            ]
          },
                {
            name:"CWT"
            inputs:[
              {
                field:Cost
                periodName:"1"
                value: 3100
              }
            ]
          },
        ]
        """
        self.case_execution_input_template = """
      name: "Job2"
      cases: [
        {name: "11月度计划"}
        {name: "二催开工"}
        {name: "一焦化停工"}    
        {name: "焦化加工油浆"}    
        {name: "焦化加工低硫原油"}    
        {name: "焦化加工低硫渣油"}    

      ]
        """
        self.wait_for_case_stack_job_template = "waitForCaseStackJob(name: \"Job2\")"
        # Query text that selects the name of every case.
        # NOTE: the attribute name is misspelled ("qurey"); it is left as is
        # because it is referenced by that name elsewhere in this file.
        self.case_qurey = '''
                          query
                          {
                          cases
                          {
                            items
                            {
                              name
                            }
                          }
                          }
                          '''
        # Parameter examples: the same data as the templates above, expressed
        # as JSON text.
        self.purchase_inputs = '''
        [{
            "name": "11月度计划",
            "inputs": [
                {
                    "name": "CWT",
                    "inputs": [
                        {
                            "field": "Cost",
                            "periodName": "1",
                            "value": 3100
                        }
                    ]
                },
                {
                    "name": "CWT",
                    "inputs": [
                        {
                            "field": "Cost",
                            "periodName": "1",
                            "value": 3100
                        }
                    ]
                }
            ]
        }]
        '''
        self.case_execution_input = '''
        {
            "name": "Job2",
            "cases": [
                {
                    "name": "11月度计划"
                },
                {
                    "name": "二催开工"
                },
                {
                    "name": "一焦化停工"
                },
                {
                    "name": "焦化加工油浆"
                },
                {
                    "name": "焦化加工低硫原油"
                },
                {
                    "name": "焦化加工低硫渣油"
                }
            ]
        }
        '''
        # Example job name (plain string, not JSON).
        self.wait_for_case_stack_job_name = 'Job2'

# Module-level instance shared by the code in this file.
templates = GraphqlQueryTemplates()


# Values used when an argument is omitted (or empty); together they describe
# the mutation that was previously hard-coded in this function.
_DEFAULT_PURCHASE_INPUTS = [
    {
        "name": "11月度计划",
        "inputs": [
            {"name": "CWT", "inputs": [{"field": "Cost", "periodName": "1", "value": 3100}]},
            {"name": "CWT", "inputs": [{"field": "Cost", "periodName": "1", "value": 3100}]},
        ],
    },
]
_DEFAULT_CASE_EXECUTION_INPUT = {
    "name": "Job2",
    "cases": [
        {"name": "11月度计划"},
        {"name": "二催开工"},
        {"name": "一焦化停工"},
        {"name": "焦化加工油浆"},
        {"name": "焦化加工低硫原油"},
        {"name": "焦化加工低硫渣油"},
    ],
}
_DEFAULT_WAIT_JOB_NAME = "Job2"


def _graphql_string(value):
    """Render *value* as a double-quoted, escaped GraphQL string literal."""
    # JSON string escaping only produces escapes that GraphQL also accepts,
    # so quotes/backslashes in user input cannot break out of the literal.
    return json.dumps(str(value), ensure_ascii=False)


def _graphql_enum(value):
    """Return *value* unchanged if it is a valid, unquoted GraphQL enum name."""
    if not (isinstance(value, str) and value.isascii() and value.isidentifier()):
        raise TypeError(f"field should be a GraphQL enum name, got {value!r}.")
    return value


def _graphql_number(value):
    """Render a finite int/float as a GraphQL numeric literal."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"value should be a number, got {value!r}.")
    if value != value or value in (float("inf"), float("-inf")):
        raise TypeError(f"value should be a finite number, got {value!r}.")
    return repr(value)


def _render_purchase_input(purchase):
    """Render one element of ``purchase_inputs`` as a GraphQL input object."""
    groups = []
    for group in purchase.get('inputs', []):
        values = []
        for item in group['inputs']:
            field = _graphql_enum(item['field'])
            period = _graphql_string(item['periodName'])
            number = _graphql_number(item['value'])
            values.append(f"{{field: {field} periodName: {period} value: {number}}}")
        group_name = _graphql_string(group['name'])
        joined_values = " ".join(values)
        groups.append(f"{{name: {group_name} inputs: [{joined_values}]}}")
    purchase_name = _graphql_string(purchase['name'])
    joined_groups = " ".join(groups)
    return f"{{name: {purchase_name} inputs: [{joined_groups}]}}"


def generate_custom_graphql_query(
        purchase_inputs=None,
        case_execution_input=None,
        wait_for_case_stack_job_name=None
):
    """
    Build the GraphQL mutation that updates purchases, submits a case stack
    and waits for the resulting job.

    The mutation text is rendered directly from the arguments. (The previous
    implementation tried to splice fragments into a fixed query with
    ``str.replace``; the search text never matched, so caller-supplied
    purchases and cases were silently ignored.)

    Args:
        purchase_inputs: ``None`` or a list of dicts shaped like
            ``{"name": str, "inputs": [{"name": str, "inputs":
            [{"field": enum name, "periodName": str, "value": number}]}]}``.
            ``None`` or an empty list selects the built-in default.
        case_execution_input: ``None`` or a dict shaped like
            ``{"name": str, "cases": [{"name": str}, ...]}``.
            ``None`` or an empty dict selects the built-in default.
        wait_for_case_stack_job_name: ``None`` or the job name to wait for.
            ``None`` or an empty string selects ``"Job2"``.

    Returns:
        The mutation as a ``str``.

    Raises:
        TypeError: If an argument (or a ``field``/``value`` entry) has the
            wrong type or is not safe to embed in the query. TypeError is
            used throughout because that is the one exception the HTTP
            handler in this file reports back to the client.
        KeyError: If a required key (``name``, inner ``inputs``, ``field``,
            ``periodName``, ``value``, ``cases``) is missing.
    """
    # purchase_inputs must be None or a list of dicts.
    if purchase_inputs is not None:
        if not isinstance(purchase_inputs, list):
            raise TypeError("purchase_inputs should be a list or None.")
        for input_data in purchase_inputs:
            if not isinstance(input_data, dict):
                raise TypeError("Elements in purchase_inputs should be dictionaries.")

    # case_execution_input must be None or a dict.
    if case_execution_input is not None:
        if not isinstance(case_execution_input, dict):
            raise TypeError("case_execution_input should be a dictionary or None.")

    # wait_for_case_stack_job_name must be None or a string.
    if wait_for_case_stack_job_name is not None:
        if not isinstance(wait_for_case_stack_job_name, str):
            raise TypeError("wait_for_case_stack_job_name should be a string or None.")

    # Falsy arguments (None, [], {}, "") fall back to the defaults.
    purchases = " ".join(
        _render_purchase_input(purchase)
        for purchase in (purchase_inputs or _DEFAULT_PURCHASE_INPUTS)
    )
    execution = case_execution_input or _DEFAULT_CASE_EXECUTION_INPUT
    job_name = _graphql_string(execution['name'])
    cases = " ".join(
        f"{{name: {_graphql_string(case['name'])}}}" for case in execution['cases']
    )
    wait_name = _graphql_string(wait_for_case_stack_job_name or _DEFAULT_WAIT_JOB_NAME)

    return (
        "mutation {\n"
        "  purchases {\n"
        f"    update(inputs: [{purchases}])\n"
        "  }\n"
        "  caseExecution {\n"
        f"    submitCaseStack(input: {{name: {job_name} cases: [{cases}]}}) {{\n"
        "      id\n"
        "    }\n"
        f"    waitForCaseStackJob(name: {wait_name}) {{\n"
        "      started\n"
        "      submitted\n"
        "      finished\n"
        "      executionStatus\n"
        "      cases {\n"
        "        items {\n"
        "          name\n"
        "          objectiveValue\n"
        "        }\n"
        "      }\n"
        "    }\n"
        "  }\n"
        "}\n"
    )

@app.post("/generate_graphql_query",tags=['generate_graphql_query'])
async def generate_graphql_query(
        request: Request,
        purchase_inputs: list[dict] = Body(json.loads(templates.purchase_inputs), embed=True,example_query=json.loads(templates.purchase_inputs)),
        case_execution_input: dict = Body(json.loads(templates.case_execution_input), embed=True,example_query=json.loads(templates.case_execution_input)),
        wait_for_case_stack_job_name: str = Body(templates.wait_for_case_stack_job_name, embed=True,example_query=templates.wait_for_case_stack_job_name),
):
    """
    Build the GraphQL mutation from the request body, send it to the GraphQL
    endpoint and return the endpoint's JSON reply.

    Every call is recorded through ``insert_api_log`` whether it succeeds or
    fails. The body defaults are the parsed example values, so a client that
    omits a field gets the example configuration (previously the defaults
    were raw JSON strings, which the query builder rejected).

    Returns:
        The parsed JSON reply, or ``{"error": <message>}`` when the arguments
        are rejected by ``generate_custom_graphql_query``.

    Raises:
        HTTPException: 503 on connect timeout; 500 on any other request
            failure; the upstream status code when it is not 200; 502 when
            the upstream answers 200 with a body that is not JSON.
    """
    # Only the query construction is allowed to turn a TypeError into an
    # error payload; TypeErrors from anywhere else should surface normally.
    try:
        custom_query = generate_custom_graphql_query(purchase_inputs, case_execution_input, wait_for_case_stack_job_name)
    except TypeError as e:
        return {"error": str(e)}

    payload_json = {
        "query": custom_query
    }
    request_time = datetime.now()
    full_path = str(request.url.path)

    response = None
    result = None
    body_is_json = False
    # What gets written to the log; stays None if an unexpected error occurs.
    log_content = None
    try:
        # NOTE(review): requests is a blocking client called from an
        # `async def` handler, so the event loop is presumably stalled for
        # the duration of this call (up to the 300 s timeout) — confirm
        # whether that is acceptable for this deployment.
        with requests.Session() as session:
            response = session.post(url=url, headers=headers, json=payload_json, auth=auth, verify=False, timeout=300)
        try:
            result = response.json()
            body_is_json = True
            log_content = result
        except ValueError:
            # Upstream error pages are not always JSON; log the raw text
            # instead of failing before the status check below.
            log_content = response.text
    except requests.exceptions.ConnectTimeout as e:
        # 503: the upstream service is temporarily unreachable.
        log_content = {
            "errors": [{"message": "连接超时,请检查网络或稍后重试"}],
            "data": {},
            "status_code": 503
        }
        raise HTTPException(status_code=503, detail=log_content) from e
    except requests.exceptions.RequestException as e:
        # Any other request-level failure is reported uniformly.
        log_content = {
            "errors": [{"message": "请求出现其他错误,请联系管理员"}],
            "data": {},
            "status_code": 500
        }
        raise HTTPException(status_code=500, detail=log_content) from e
    finally:
        # Record the call in the database regardless of the outcome.
        insert_api_log(
            request_time,
            full_path,
            request.method,
            json.dumps(payload_json),
            json.dumps(log_content),
            datetime.now()
        )

    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    if not body_is_json:
        raise HTTPException(status_code=502, detail=response.text)
    print(result)
    return result

@app.get("/get_cases",tags=['get_cases'])
async def get_cases_query_async(request: Request):
    """
    Run the "list all cases" GraphQL query and return the endpoint's JSON
    reply.

    Every call is recorded through ``insert_api_log`` whether it succeeds or
    fails.

    Raises:
        HTTPException: 503 on connect timeout; 500 on any other request
            failure; the upstream status code when it is not 200; 502 when
            the upstream answers 200 with a body that is not JSON.
    """
    payload_json2 = {
        "query": templates.case_qurey
    }
    full_path = str(request.url.path)
    request_time = datetime.now()

    response = None
    res = None
    body_is_json = False
    # What gets written to the log; stays None if an unexpected error occurs,
    # so the `finally` block never reads an unbound name.
    log_content = None
    try:
        # A timeout is set so a hung upstream cannot block this handler
        # indefinitely; 300 s matches the mutation endpoint in this file.
        with requests.Session() as session:
            response = session.post(url=url, headers=headers, json=payload_json2, auth=auth, verify=False, timeout=300)
        try:
            # Parse the JSON reply into Python objects.
            res = response.json()
            body_is_json = True
            log_content = res
        except ValueError:
            # Upstream error pages are not always JSON; log the raw text
            # instead of failing before the status check below.
            log_content = response.text
    except requests.exceptions.ConnectTimeout as e:
        # 503: the upstream service is temporarily unreachable.
        log_content = {
            "errors": [{"message": "连接超时,请检查网络或稍后重试"}],
            "data": {},
            "status_code": 503
        }
        raise HTTPException(status_code=503, detail=log_content) from e
    except requests.exceptions.RequestException as e:
        # Any other request-level failure is reported uniformly.
        log_content = {
            "errors": [{"message": "请求出现其他错误,请联系管理员"}],
            "data": {},
            "status_code": 500
        }
        raise HTTPException(status_code=500, detail=log_content) from e
    finally:
        # Record the call in the database regardless of the outcome.
        insert_api_log(
            request_time,
            full_path,
            request.method,
            json.dumps(payload_json2),
            json.dumps(log_content),
            datetime.now()
        )

    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    if not body_is_json:
        raise HTTPException(status_code=502, detail=response.text)

    return res

# When executed as a script, serve the app with uvicorn on all interfaces,
# port 8003.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8003)


# query = """
# mutation{
#     purchases{
#       update(inputs:[{
#         name:"11月度计划"
#         inputs:[
#           {
#             name:"CWT"
#             inputs:[
#               {
#                 field:Cost
#                 periodName:"1"
#                 value: 3100
#               }
#             ]
#           },
#                 {
#             name:"CWT"
#             inputs:[
#               {
#                 field:Cost
#                 periodName:"1"
#                 value: 3100
#               }
#             ]
#           },
#         ]
#       }])
#       }
#   caseExecution {
#   submitCaseStack(
#     input:{
#       name: "Job2"
#       cases: [
#         {name: "11月度计划"}
#         {name: "二催开工"}
#         {name: "一焦化停工"}	
#         {name: "焦化加工油浆"}	
#         {name: "焦化加工低硫原油"}	
#         {name: "焦化加工低硫渣油"}	

#       ]
#     }
#   )
#   {id}
#   waitForCaseStackJob(name: "Job2")
#     {
#       started
#       submitted
#       finished
#       executionStatus
#       cases{
#         items{
#           name
#           objectiveValue
#         }
#       }
#     }
# }
# }
# """

# payload_json = {
#     "query": query,
#     "operationName": ""
# }



# query2 = '''
# query
# {
#  cases
#  {
#    items
#    {
#      name
#    }
#  }
# }
# '''

# payload_json2 = {
#     "query": query2,
#     "operationName": ""
# }




# @app.post("/graphql")
# async def post_execute_graphql_query(request: Request,
#                                      query:str = Body(query,example_query=query)
#                                      ):
#     payload_json = {
#         "query": query
#     }
#     request_time = datetime.now()
#     full_path = str(request.url.path)
#     session = requests.Session()
#     response = session.post(url=url, headers=headers, json=payload_json, auth=auth, verify=False)
#     response_time = datetime.now()

#     # Call the log-insert function to record this request/response in the database (assumes insert_api_log is defined and accessible)
#     insert_api_log(
#         request_time,
#         full_path,
#         'POST',
#         json.dumps(payload_json),
#         json.dumps(response.json()),
#         response_time
#     )


#     if response.status_code!= 200:
#         raise HTTPException(status_code=response.status_code, detail=response.text)
#     return response.json()

# def insert_api_log(request_time, request_url, request_method, request_params, response_content, response_time):