From 4c5a1d1c6ea572055a29c3b17386b9dd1a2ce7fe Mon Sep 17 00:00:00 2001 From: workpc Date: Mon, 7 Jul 2025 16:48:55 +0800 Subject: [PATCH] =?UTF-8?q?=E8=81=9A=E7=83=AF=E7=83=83=E5=91=A8=E5=BA=A6?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=85=AB=E5=A4=A7=E7=BB=B4=E5=BA=A6=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E4=B8=8A=E4=BC=A0=E5=B8=82=E5=9C=BA=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_juxiting_zhoudu.py | 241 +++++++++++++++++++++++++++++++++----- main_juxiting_zhoudu.py | 193 ++++++++++++++++++++++-------- main_yuanyou_zhoudu.py | 130 +++++++++++--------- 3 files changed, 435 insertions(+), 129 deletions(-) diff --git a/config_juxiting_zhoudu.py b/config_juxiting_zhoudu.py index f2e18e0..d298eb2 100644 --- a/config_juxiting_zhoudu.py +++ b/config_juxiting_zhoudu.py @@ -1,3 +1,4 @@ +from decimal import Decimal import logging import os import logging.handlers @@ -159,16 +160,157 @@ data = { ClassifyId = 1161 -# 变量定义--测试环境 -server_host = '192.168.100.53:8080' # 内网 -# server_host = '183.242.74.28' # 外网 +# # 变量定义--线上环境 +# server_host = '10.200.32.39' +# login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" +# upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" +# upload_warning_url = "http://10.200.32.39/jingbo-api/api/basicBuiness/crudeOilWarning/save" +# query_data_list_item_nos_url = f"http://{server_host}/jingbo-api/api/warehouse/dwDataItem/queryDataListItemNos" +# # 上传数据项值 +# push_data_value_list_url = f"http://{server_host}/jingbo-api/api/dw/dataValue/pushDataValueList" +# # 上传停更数据到市场信息平台 +# push_waring_data_value_list_url = f"http://{server_host}/jingbo-api/api/basicBuiness/crudeOilWarning/crudeSaveOrupdate" +# # 获取预警数据中取消订阅指标ID +# get_waring_data_value_list_url = f"http://{server_host}/jingbo-api/api/basicBuiness/crudeOilWarning/dataList" -login_pushreport_url = f"http://{server_host}/jingbo-dev/api/server/login" -upload_url = f"http://{server_host}/jingbo-dev/api/analysis/reportInfo/researchUploadReportSave" -upload_warning_url = f"http://{server_host}/jingbo-dev/api/basicBuiness/crudeOilWarning/save" -query_data_list_item_nos_url = f"http://{server_host}/jingbo-dev/api/warehouse/dwDataItem/queryDataListItemNos" + +# login_data = { +# "data": { +# "account": "api_dev", +# "password": "ZTEwYWRjMzk0OWJhNTlhYmJlNTZlMDU3ZjIwZjg4M2U=", +# "tenantHashCode": "8a4577dbd919675758d57999a1e891fe", +# "terminal": "API" +# }, +# "funcModule": "API", +# "funcOperation": "获取token" +# } + + +# upload_data = { +# "funcModule": '研究报告信息', +# "funcOperation": '上传聚烯烃PP价格预测报告', +# "data": { +# "groupNo": '000211', # 用户组编号 +# "ownerAccount": '36541', # 报告所属用户账号  36541 - 贾青雪 +# "reportType": 'OIL_PRICE_FORECAST', # 报告类型,固定为OIL_PRICE_FORECAST +# "fileName": '', # 文件名称 +# "fileBase64": '', # 文件内容base64 +# "categoryNo": 'jxtjgycbg', # 研究报告分类编码 +# "smartBusinessClassCode": 'JXTJGYCBG', # 分析报告分类编码 +# "reportEmployeeCode": "E40482", # 报告人  E40482  - 管理员  0000027663 - 刘小朋   +# "reportDeptCode": "JXTJGYCBG", # 报告部门 - 002000621000  SH期货研究部   +# "productGroupCode": "RAW_MATERIAL" # 商品分类 +# } +# } + +# warning_data = { +# "funcModule": '原油特征停更预警', +# "funcOperation": '原油特征停更预警', +# "data": { +# "groupNo": "000211", +# 'WARNING_TYPE_NAME': '特征数据停更预警', +# 'WARNING_CONTENT': '', +# 'WARNING_DATE': '' +# } +# } + +# query_data_list_item_nos_data = { +# "funcModule": "数据项", +# "funcOperation": "查询", +# "data": { +# "dateStart": "20200101", +# "dateEnd": "", +# # 数据项编码,代表 PP期货 价格 +# "dataItemNoList": ["MAIN_CONFT_SETTLE_PRICE"] +# } +# } + + +# push_data_value_list_data = { +# "funcModule": "数据表信息列表", +# "funcOperation": "新增", +# "data": [ +# {"dataItemNo": "91230600716676129", +# "dataDate": "20230113", +# "dataStatus": "add", +# "dataValue": 100.11 +# }, +# {"dataItemNo": "91230600716676129P|ETHYL_BEN|CAPACITY", +# "dataDate": "20230113", +# "dataStatus": "add", +# "dataValue": 100.55 +# }, +# {"dataItemNo": "91230600716676129P|ETHYL_BEN|CAPACITY", +# "dataDate": "20230113", +# "dataStatus": "add", +# "dataValue": 100.55 +# } +# ] +# } + + +# push_waring_data_value_list_data = { +# "data": { +# "crudeOilWarningDtoList": [ +# { +# "lastUpdateDate": "20240501", +# "updateSuspensionCycle": 1, +# "dataSource": "9", +# "frequency": "1", +# "indicatorName": "美元指数", +# "indicatorId": "myzs001", +# "warningDate": "2024-05-13" +# } +# ], +# "dataSource": "9" +# }, +# "funcModule": "商品数据同步", +# "funcOperation": "同步" +# } + + +# get_waring_data_value_list_data = { +# "data": "9", "funcModule": "商品数据同步", "funcOperation": "同步"} + + +# # 八大维度数据项编码 +# bdwd_items = { +# 'ciri': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE', +# 'benzhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE01', +# 'cizhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE02', +# 'gezhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE03', +# 'ciyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE04', +# 'cieryue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE05', +# 'cisanyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE06', +# 'cisiyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE07', +# } + + +# # 生产环境数据库 +# host = 'rm-2zehj3r1n60ttz9x5.mysql.rds.aliyuncs.com' +# port = 3306 +# dbusername = 'jingbo' +# password = 'shihua@123' +# dbname = 'jingbo' +# table_name = 'v_tbl_crude_oil_warning' + + +# 变量定义--测试环境 +server_host = '192.168.100.53' # 内网 +# server_host = '183.242.74.28' # 外网 +login_pushreport_url = f"http://{server_host}:8080/jingbo-dev/api/server/login" +# 上传报告 +upload_url = f"http://{server_host}:8080/jingbo-dev/api/analysis/reportInfo/researchUploadReportSave" +# 停更预警 +upload_warning_url = f"http://{server_host}:8080/jingbo-dev/api/basicBuiness/crudeOilWarning/save" +# 查询数据项编码 +query_data_list_item_nos_url = f"http://{server_host}:8080/jingbo-dev/api/warehouse/dwDataItem/queryDataListItemNos" # 上传数据项值 -push_data_value_list_url = f"http://{server_host}/jingbo-dev/api/dw/dataValue/pushDataValueList" +push_data_value_list_url = f"http://{server_host}:8080/jingbo-dev/api/dw/dataValue/pushDataValueList" +# 上传停更数据到市场信息平台 +push_waring_data_value_list_url = f"http://{server_host}:8080/jingbo-dev/api/basicBuiness/crudeOilWarning/crudeSaveOrupdate" +# 获取预警数据中取消订阅指标ID +get_waring_data_value_list_url = f"http://{server_host}:8080/jingbo-dev/api/basicBuiness/crudeOilWarning/dataList" login_data = { "data": { @@ -186,6 +328,7 @@ upload_data = { "funcModule": '研究报告信息', "funcOperation": '上传聚烯烃PP价格预测报告', "data": { + "groupNo": "000127", "ownerAccount": 'arui', # 报告所属用户账号 "reportType": 'OIL_PRICE_FORECAST', # 报告类型,固定为OIL_PRICE_FORECAST "fileName": '2000-40-5-50--100-原油指标数据.xlsx-Brent活跃合约--2024-09-06-15-01-29-预测报告.pdf', # 文件名称 @@ -198,11 +341,12 @@ upload_data = { } } - +# 已弃用 warning_data = { "funcModule": '原油特征停更预警', "funcOperation": '原油特征停更预警', "data": { + "groupNo": "000127", 'WARNING_TYPE_NAME': '特征数据停更预警', 'WARNING_CONTENT': '', 'WARNING_DATE': '' @@ -214,8 +358,9 @@ query_data_list_item_nos_data = { "funcOperation": "查询", "data": { "dateStart": "20200101", - "dateEnd": "20241231", - "dataItemNoList": ["Brentzdj", "Brentzgj"] # 数据项编码,代表 brent最低价和最高价 + "dateEnd": "", + # 数据项编码,代表 PP期货 价格 + "dataItemNoList": ["MAIN_CONFT_SETTLE_PRICE"] } } @@ -241,6 +386,31 @@ push_data_value_list_data = { ] } + +push_waring_data_value_list_data = { + "data": { + "crudeOilWarningDtoList": [ + { + "lastUpdateDate": "20240501", + "updateSuspensionCycle": 1, + "dataSource": "9", + "frequency": "1", + "indicatorName": "美元指数", + "indicatorId": "myzs001", + "warningDate": "2024-05-13" + } + ], + "dataSource": "9" + }, + "funcModule": "商品数据同步", + "funcOperation": "同步" +} + + +get_waring_data_value_list_data = { + "data": "9", "funcModule": "商品数据同步", "funcOperation": "同步"} + + # 八大维度数据项编码 bdwd_items = { 'ciri': 'jxtppbdwdcr', @@ -262,22 +432,36 @@ password = '123456' dbname = 'jingbo_test' table_name = 'v_tbl_crude_oil_warning' +DEFAULT_CONFIG = { + 'feature_factor_frequency': 'D', + 'strategy_id': 2, + 'model_evaluation_id': 1, + 'tenant_code': '', + 'version_num': Decimal(1), + 'delete_flag': '0', + 'create_user': 'admin', + 'create_date': datetime.datetime.now(), + 'update_user': 'admin', + 'update_date': datetime.datetime.now(), + 'oil_code': 'PP', + 'oil_name': 'PP期货', +} # 开关 is_train = True # 是否训练 is_debug = False # 是否调试 -is_eta = True # 是否使用eta接口 -is_market = False # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 +is_eta = False # 是否使用eta接口 +is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = False # 特征使用edbcoding列表中的 is_edbnamelist = False # 自定义特征,对应上面的edbnamelist is_update_eta = True # 预测结果上传到eta -is_update_report = True # 是否上传报告 -is_update_warning_data = False # 是否上传预警数据 +is_update_report = False # 是否上传报告 +is_update_warning_data = True # 是否上传预警数据 is_update_predict_value = True # 是否上传预测值到市场信息平台 is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 -is_del_tow_month = False # 是否删除两个月不更新的特征 +is_del_tow_month = True # 是否删除两个月不更新的特征 is_bdwd = False # 是否使用八大维度 @@ -299,15 +483,17 @@ if add_kdj and is_edbnamelist: edbnamelist = edbnamelist+['K', 'D', 'J'] # 模型参数 -y = 'AVG-金能大唐久泰青州' -avg_cols = [ - 'PP:拉丝:1102K:出厂价:青州:国家能源宁煤(日)', - 'PP:拉丝:L5E89:出厂价:华北(第二区域):内蒙古久泰新材料(日)', - 'PP:拉丝:L5E89:出厂价:河北、鲁北:大唐内蒙多伦(日)', - 'PP:拉丝:HP550J:市场价:青岛:金能化学(日)' -] -offsite = 80 -offsite_col = ['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'] +# y = 'AVG-金能大唐久泰青州' +# avg_cols = [ +# 'PP:拉丝:1102K:出厂价:青州:国家能源宁煤(日)', +# 'PP:拉丝:L5E89:出厂价:华北(第二区域):内蒙古久泰新材料(日)', +# 'PP:拉丝:L5E89:出厂价:河北、鲁北:大唐内蒙多伦(日)', +# 'PP:拉丝:HP550J:市场价:青岛:金能化学(日)' +# ] +# offsite = 80 +# offsite_col = ['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'] + +y = 'MAIN_CONFT_SETTLE_PRICE' horizon = 2 # 预测的步长 input_size = 14 # 输入序列长度 train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数 @@ -330,9 +516,6 @@ weight_dict = [0.4, 0.15, 0.1, 0.1, 0.25] # 权重 data_set = 'PP指标数据.xlsx' # 数据集文件 dataset = 'juxitingzhoududataset' # 数据集文件夹 -print("当前工作目录:", os.getcwd()) -print("数据库路径:", os.path.abspath('juxitingzhoududataset/jbsh_juxiting_zhoudu.db')) - # 数据库名称 db_name = os.path.join(dataset, 'jbsh_juxiting_zhoudu.db') sqlitedb = SQLiteHandler(db_name) @@ -374,7 +557,7 @@ logger.setLevel(logging.INFO) file_handler = logging.handlers.RotatingFileHandler(os.path.join( log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5) file_handler.setFormatter(logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')) # 配置控制台处理器,将日志打印到控制台 console_handler = logging.StreamHandler() diff --git a/main_juxiting_zhoudu.py b/main_juxiting_zhoudu.py index 5048ad6..3ff723d 100644 --- a/main_juxiting_zhoudu.py +++ b/main_juxiting_zhoudu.py @@ -2,8 +2,8 @@ from lib.dataread import * from config_juxiting_zhoudu import * -from lib.tools import SendMail, exception_logger -from models.nerulforcastmodels import ex_Model, model_losss_juxiting, tansuanli_export_pdf, pp_export_pdf +from lib.tools import SendMail, exception_logger, convert_df_to_pydantic, exception_logger, get_modelsname +from models.nerulforcastmodels import ex_Model_Juxiting, model_losss_juxiting, pp_export_pdf import datetime import torch torch.set_float32_matmul_precision("high") @@ -13,9 +13,9 @@ global_config.update({ 'logger': logger, 'dataset': dataset, 'y': y, - 'offsite_col': offsite_col, - 'avg_cols': avg_cols, - 'offsite': offsite, + # 'offsite_col': offsite_col, + # 'avg_cols': avg_cols, + # 'offsite': offsite, 'edbcodenamedict': edbcodenamedict, 'is_debug': is_debug, 'is_train': is_train, @@ -67,6 +67,14 @@ global_config.update({ 'push_data_value_list_url': push_data_value_list_url, 'push_data_value_list_data': push_data_value_list_data, + # 上传预警数据 + 'push_waring_data_value_list_url': push_waring_data_value_list_url, + 'push_waring_data_value_list_data': push_waring_data_value_list_data, + + # 获取取消订阅的数据 + 'get_waring_data_value_list_url': get_waring_data_value_list_url, + 'get_waring_data_value_list_data': get_waring_data_value_list_data, + # eta 配置 'APPID': APPID, 'SECRET': SECRET, @@ -83,12 +91,15 @@ global_config.update({ # 数据库配置 'sqlitedb': sqlitedb, + 'bdwd_items': bdwd_items, 'is_bdwd': is_bdwd, + 'db_mysql': db_mysql, + 'DEFAULT_CONFIG': DEFAULT_CONFIG, }) def push_market_value(): - logger.info('发送预测结果到市场信息平台') + config.logger.info('发送预测结果到市场信息平台') # 读取预测数据和模型评估数据 predict_file_path = os.path.join(config.dataset, 'predict.csv') model_eval_file_path = os.path.join(config.dataset, 'model_evaluation.csv') @@ -96,7 +107,7 @@ def push_market_value(): predictdata_df = pd.read_csv(predict_file_path) top_models_df = pd.read_csv(model_eval_file_path) except FileNotFoundError as e: - logger.error(f"文件未找到: {e}") + config.logger.error(f"文件未找到: {e}") return predictdata = predictdata_df.copy() @@ -140,7 +151,92 @@ def push_market_value(): try: push_market_data(predictdata) except Exception as e: - logger.error(f"推送数据失败: {e}") + config.logger.error(f"推送数据失败: {e}") + + +def sql_inset_predict(global_config): + df = pd.read_csv(os.path.join(config.dataset, 'predict.csv')) + df['created_dt'] = pd.to_datetime(df['created_dt']) + df['ds'] = pd.to_datetime(df['ds']) + # 获取次日预测结果 + next_day_df = df[df['ds'] == df['ds'].min()] + # 获取本周预测结果 + this_week_df = df[df['ds'] == df['ds'].max()] + + wd = ['day_price', 'week_price'] + model_name_list, model_id_name_dict = get_modelsname(df, global_config) + + PRICE_COLUMNS = [ + 'day_price', 'week_price', 'second_week_price', 'next_week_price', + 'next_month_price', 'next_february_price', 'next_march_price', 'next_april_price' + ] + + params_list = [] + for df, price_type in zip([next_day_df, this_week_df], wd): + + update_columns = [ + "feature_factor_frequency = VALUES(feature_factor_frequency)", + "oil_code = VALUES(oil_code)", + "oil_name = VALUES(oil_name)", + "data_date = VALUES(data_date)", + "market_price = VALUES(market_price)", + f"{price_type} = VALUES({price_type})", + "model_evaluation_id = VALUES(model_evaluation_id)", + "tenant_code = VALUES(tenant_code)", + "version_num = VALUES(version_num)", + "delete_flag = VALUES(delete_flag)", + "update_user = VALUES(update_user)", + "update_date = VALUES(update_date)" + ] + + insert_query = f""" + INSERT INTO v_tbl_predict_prediction_results ( + feature_factor_frequency, strategy_id, oil_code, oil_name, data_date, + market_price, day_price, week_price, second_week_price, next_week_price, + next_month_price, next_february_price, next_march_price, next_april_price, + model_evaluation_id, model_id, tenant_code, version_num, delete_flag, + create_user, create_date, update_user, update_date + ) VALUES ( + %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s + ) + ON DUPLICATE KEY UPDATE + {', '.join(update_columns)} + """ + + next_day_df = df[['ds', 'created_dt'] + model_name_list] + + pydantic_results = convert_df_to_pydantic( + next_day_df, model_id_name_dict, global_config) + if pydantic_results: + + for result in pydantic_results: + price_values = [None] * len(PRICE_COLUMNS) + price_index = PRICE_COLUMNS.index(price_type) + price_values[price_index] = next_day_df[model_id_name_dict[result.model_id]].values[0] + + params = ( + result.feature_factor_frequency, + result.strategy_id, + global_config['DEFAULT_CONFIG']['oil_code'], + global_config['DEFAULT_CONFIG']['oil_name'], + next_day_df['created_dt'].values[0], + result.market_price, + *price_values, + result.model_evaluation_id, + result.model_id, + result.tenant_code, + 1, + '0', + result.create_user, + result.create_date, + result.update_user, + result.update_date + ) + params_list.append(params) + affected_rows = config.db_mysql.execute_batch_insert( + insert_query, params_list) + config.logger.info(f"成功插入或更新 {affected_rows} 条记录") + config.db_mysql.close() def predict_main(): @@ -206,7 +302,7 @@ def predict_main(): try: # 如果是测试环境,最高价最低价取excel文档 if server_host == '192.168.100.53': - logger.info('从excel文档获取最高价最低价') + logger.info('从excel文档获取市场信息平台指标') df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) else: logger.info('从市场信息平台获取数据') @@ -214,7 +310,7 @@ def predict_main(): end_time, df_zhibiaoshuju) except: - logger.info('最高最低价拼接失败') + logger.info('市场信息平台数据项-eta数据项 拼接失败') # 保存到xlsx文件的sheet表 with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: @@ -346,26 +442,26 @@ def predict_main(): row, col = df.shape now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') - ex_Model(df, - horizon=global_config['horizon'], - input_size=global_config['input_size'], - train_steps=global_config['train_steps'], - val_check_steps=global_config['val_check_steps'], - early_stop_patience_steps=global_config['early_stop_patience_steps'], - is_debug=global_config['is_debug'], - dataset=global_config['dataset'], - is_train=global_config['is_train'], - is_fivemodels=global_config['is_fivemodels'], - val_size=global_config['val_size'], - test_size=global_config['test_size'], - settings=global_config['settings'], - now=now, - etadata=etadata, - modelsindex=global_config['modelsindex'], - data=data, - is_eta=global_config['is_eta'], - end_time=global_config['end_time'], - ) + ex_Model_Juxiting(df, + horizon=global_config['horizon'], + input_size=global_config['input_size'], + train_steps=global_config['train_steps'], + val_check_steps=global_config['val_check_steps'], + early_stop_patience_steps=global_config['early_stop_patience_steps'], + is_debug=global_config['is_debug'], + dataset=global_config['dataset'], + is_train=global_config['is_train'], + is_fivemodels=global_config['is_fivemodels'], + val_size=global_config['val_size'], + test_size=global_config['test_size'], + settings=global_config['settings'], + now=now, + etadata=etadata, + modelsindex=global_config['modelsindex'], + data=data, + is_eta=global_config['is_eta'], + end_time=global_config['end_time'], + ) logger.info('模型训练完成') @@ -375,17 +471,18 @@ def predict_main(): logger.info('训练数据绘图end') # # 模型报告 - # logger.info('制作报告ing') - # title = f'{settings}--{end_time}-预测报告' # 报告标题 - # reportname = f'聚烯烃PP大模型周度预测--{end_time}.pdf' # 报告文件名 - # reportname = reportname.replace(':', '-') # 替换冒号 - # pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, - # reportname=reportname, sqlitedb=sqlitedb), + logger.info('制作报告ing') + title = f'{settings}--{end_time}-预测报告' # 报告标题 + reportname = f'聚烯烃PP大模型周度预测--{end_time}.pdf' # 报告文件名 + reportname = reportname.replace(':', '-') # 替换冒号 + pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + reportname=reportname, sqlitedb=sqlitedb), - # logger.info('制作报告end') - # logger.info('模型训练完成') + logger.info('制作报告end') + logger.info('模型训练完成') - # push_market_value() + push_market_value() + sql_inset_predict(global_config) # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) @@ -412,11 +509,15 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2025-3-1', '2025-5-26', freq='W'): - try: - global_config['end_time'] = i_time.strftime('%Y-%m-%d') - predict_main() - except Exception as e: - logger.info(f'预测失败:{e}') - continue + # for i_time in pd.date_range('2025-4-14', '2025-4-15', freq='B'): + # try: + # global_config['end_time'] = i_time.strftime('%Y-%m-%d') + # predict_main() + # except Exception as e: + # logger.info(f'预测失败:{e}') + # continue + # predict_main() + + # push_market_value() + sql_inset_predict(global_config) diff --git a/main_yuanyou_zhoudu.py b/main_yuanyou_zhoudu.py index ebb4d29..e3c94c3 100644 --- a/main_yuanyou_zhoudu.py +++ b/main_yuanyou_zhoudu.py @@ -1,9 +1,9 @@ # 读取配置 from lib.dataread import * -from config_jingbo_zhoudu import * -from lib.tools import SendMail, convert_df_to_pydantic, exception_logger, get_modelsname -from models.nerulforcastmodels import ex_Model, model_losss, brent_export_pdf +from config_juxiting_zhoudu import * +from lib.tools import SendMail, exception_logger, convert_df_to_pydantic, exception_logger, get_modelsname +from models.nerulforcastmodels import ex_Model_Juxiting, model_losss_juxiting, pp_export_pdf import datetime import torch torch.set_float32_matmul_precision("high") @@ -13,6 +13,10 @@ global_config.update({ 'logger': logger, 'dataset': dataset, 'y': y, + # 'offsite_col': offsite_col, + # 'avg_cols': avg_cols, + # 'offsite': offsite, + 'edbcodenamedict': edbcodenamedict, 'is_debug': is_debug, 'is_train': is_train, 'is_fivemodels': is_fivemodels, @@ -63,6 +67,14 @@ global_config.update({ 'push_data_value_list_url': push_data_value_list_url, 'push_data_value_list_data': push_data_value_list_data, + # 上传预警数据 + 'push_waring_data_value_list_url': push_waring_data_value_list_url, + 'push_waring_data_value_list_data': push_waring_data_value_list_data, + + # 获取取消订阅的数据 + 'get_waring_data_value_list_url': get_waring_data_value_list_url, + 'get_waring_data_value_list_data': get_waring_data_value_list_data, + # eta 配置 'APPID': APPID, 'SECRET': SECRET, @@ -79,6 +91,7 @@ global_config.update({ # 数据库配置 'sqlitedb': sqlitedb, + 'bdwd_items': bdwd_items, 'is_bdwd': is_bdwd, 'db_mysql': db_mysql, 'DEFAULT_CONFIG': DEFAULT_CONFIG, @@ -86,7 +99,7 @@ global_config.update({ def push_market_value(): - logger.info('发送预测结果到市场信息平台') + config.logger.info('发送预测结果到市场信息平台') # 读取预测数据和模型评估数据 predict_file_path = os.path.join(config.dataset, 'predict.csv') model_eval_file_path = os.path.join(config.dataset, 'model_evaluation.csv') @@ -94,7 +107,7 @@ def push_market_value(): predictdata_df = pd.read_csv(predict_file_path) top_models_df = pd.read_csv(model_eval_file_path) except FileNotFoundError as e: - logger.error(f"文件未找到: {e}") + config.logger.error(f"文件未找到: {e}") return predictdata = predictdata_df.copy() @@ -138,7 +151,7 @@ def push_market_value(): try: push_market_data(predictdata) except Exception as e: - logger.error(f"推送数据失败: {e}") + config.logger.error(f"推送数据失败: {e}") def sql_inset_predict(global_config): @@ -154,7 +167,7 @@ def sql_inset_predict(global_config): model_name_list, model_id_name_dict = get_modelsname(df, global_config) PRICE_COLUMNS = [ - 'day_price', 'week_price', 'second_week_price', 'next_week_price', + 'day_price', 'week_price', 'second_week_price', 'next_week_price', 'next_month_price', 'next_february_price', 'next_march_price', 'next_april_price' ] @@ -204,8 +217,8 @@ def sql_inset_predict(global_config): params = ( result.feature_factor_frequency, result.strategy_id, - result.oil_code, - result.oil_name, + global_config['DEFAULT_CONFIG']['oil_code'], + global_config['DEFAULT_CONFIG']['oil_name'], next_day_df['created_dt'].values[0], result.market_price, *price_values, @@ -266,7 +279,6 @@ def predict_main(): None """ end_time = global_config['end_time'] - signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, classifylisturl=global_config['classifylisturl'], @@ -282,7 +294,7 @@ def predict_main(): if is_eta: logger.info('从eta获取数据...') - df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data( + df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data( data_set=data_set, dataset=dataset) # 原始数据,未处理 if is_market: @@ -290,7 +302,7 @@ def predict_main(): try: # 如果是测试环境,最高价最低价取excel文档 if server_host == '192.168.100.53': - logger.info('从excel文档获取最高价最低价') + logger.info('从excel文档获取市场信息平台指标') df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) else: logger.info('从市场信息平台获取数据') @@ -298,7 +310,7 @@ def predict_main(): end_time, df_zhibiaoshuju) except: - logger.info('最高最低价拼接失败') + logger.info('市场信息平台数据项-eta数据项 拼接失败') # 保存到xlsx文件的sheet表 with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: @@ -306,14 +318,14 @@ def predict_main(): df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) # 数据处理 - df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, - end_time=end_time) + df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, + end_time=end_time) else: # 读取数据 logger.info('读取本地数据:' + os.path.join(dataset, data_set)) - df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, - is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 + df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, + is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 # 更改预测列名称 df.rename(columns={y: 'y'}, inplace=True) @@ -383,8 +395,8 @@ def predict_main(): # except Exception as e: # logger.info(f'更新accuracy表的y值失败:{e}') - # 判断当前日期是不是周一 预测目标周度许转换,暂注释 - # is_weekday = datetime.datetime.strptime(global_config['end_time'], "%Y-%m-%d").weekday() == 0 + # 判断当前日期是不是周一 + is_weekday = datetime.datetime.now().weekday() == 0 # if is_weekday: # logger.info('今天是周一,更新预测模型') # # 计算最近60天预测残差最低的模型名称 @@ -430,45 +442,41 @@ def predict_main(): row, col = df.shape now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') - ex_Model(df, - horizon=global_config['horizon'], - input_size=global_config['input_size'], - train_steps=global_config['train_steps'], - val_check_steps=global_config['val_check_steps'], - early_stop_patience_steps=global_config['early_stop_patience_steps'], - is_debug=global_config['is_debug'], - dataset=global_config['dataset'], - is_train=global_config['is_train'], - is_fivemodels=global_config['is_fivemodels'], - val_size=global_config['val_size'], - test_size=global_config['test_size'], - settings=global_config['settings'], - now=now, - etadata=etadata, - modelsindex=global_config['modelsindex'], - data=data, - is_eta=global_config['is_eta'], - end_time=global_config['end_time'], - ) + ex_Model_Juxiting(df, + horizon=global_config['horizon'], + input_size=global_config['input_size'], + train_steps=global_config['train_steps'], + val_check_steps=global_config['val_check_steps'], + early_stop_patience_steps=global_config['early_stop_patience_steps'], + is_debug=global_config['is_debug'], + dataset=global_config['dataset'], + is_train=global_config['is_train'], + is_fivemodels=global_config['is_fivemodels'], + val_size=global_config['val_size'], + test_size=global_config['test_size'], + settings=global_config['settings'], + now=now, + etadata=etadata, + modelsindex=global_config['modelsindex'], + data=data, + is_eta=global_config['is_eta'], + end_time=global_config['end_time'], + ) logger.info('模型训练完成') logger.info('训练数据绘图ing') - model_results3 = model_losss(sqlitedb, end_time=end_time) + model_results3 = model_losss_juxiting( + sqlitedb, end_time=global_config['end_time'], is_fivemodels=global_config['is_fivemodels']) logger.info('训练数据绘图end') # # 模型报告 logger.info('制作报告ing') title = f'{settings}--{end_time}-预测报告' # 报告标题 - reportname = f'Brent原油大模型周度预测--{end_time}.pdf' # 报告文件名 + reportname = f'聚烯烃PP大模型周度预测--{end_time}.pdf' # 报告文件名 reportname = reportname.replace(':', '-') # 替换冒号 - brent_export_pdf(dataset=dataset, - num_models=5 if is_fivemodels else 22, - time=end_time, - reportname=reportname, - inputsize=global_config['horizon'], - sqlitedb=sqlitedb - ), + pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + reportname=reportname, sqlitedb=sqlitedb), logger.info('制作报告end') logger.info('模型训练完成') @@ -476,6 +484,15 @@ def predict_main(): push_market_value() sql_inset_predict(global_config) + # # LSTM 单变量模型 + # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) + + # # lstm 多变量模型 + # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset) + + # # GRU 模型 + # # ex_GRU(df) + # 发送邮件 # m = SendMail( # username=username, @@ -492,10 +509,15 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2025-6-23', '2025-6-30', freq='B'): - global_config['end_time'] = i_time.strftime('%Y-%m-%d') - global_config['db_mysql'].connect() - predict_main() + # for i_time in pd.date_range('2025-4-14', '2025-4-15', freq='B'): + # try: + # global_config['end_time'] = i_time.strftime('%Y-%m-%d') + # predict_main() + # except Exception as e: + # logger.info(f'预测失败:{e}') + # continue - # predict_main() - # sql_inset_predict(global_config=global_config) + predict_main() + + # push_market_value() + # sql_inset_predict(global_config)