From 7788cfda6f07f60a1d1385d0956baea506fdd63e Mon Sep 17 00:00:00 2001 From: jingboyitiji Date: Tue, 27 May 2025 18:08:52 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9F=B3=E6=B2=B9=E7=84=A6=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_juxiting.py | 136 +++++++++ config_shiyoujiao_lvyong.py | 13 +- config_shiyoujiao_lvyong_yuedu.py | 20 +- config_shiyoujiao_lvyong_zhoudu.py | 12 +- lib/dataread.py | 3 + main_shiyoujiao_lvyong_yuedu.py | 442 +++++++++++++++++++++++++++++ main_shiyoujiao_lvyong_zhoudu.py | 3 +- models/nerulforcastmodels.py | 6 +- 8 files changed, 610 insertions(+), 25 deletions(-) create mode 100644 main_shiyoujiao_lvyong_yuedu.py diff --git a/config_juxiting.py b/config_juxiting.py index df36f2b..50c79d2 100644 --- a/config_juxiting.py +++ b/config_juxiting.py @@ -137,6 +137,142 @@ data = { ClassifyId = 1161 + +# 变量定义--线上环境 +server_host = '10.200.32.39' +login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" +upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" +upload_warning_url = "http://10.200.32.39/jingbo-api/api/basicBuiness/crudeOilWarning/save" +query_data_list_item_nos_url = f"http://{server_host}/jingbo-api/api/warehouse/dwDataItem/queryDataListItemNos" +# 上传数据项值 +push_data_value_list_url = f"http://{server_host}/jingbo-api/api/dw/dataValue/pushDataValueList" +上传停更数据到市场信息平台 +push_waring_data_value_list_url = f"http://{server_host}:8080/jingbo-dev/api/basicBuiness/crudeOilWarning/crudeSaveOrupdate" +获取预警数据中取消订阅指标ID +get_waring_data_value_list_url = f"http://{server_host}:8080/jingbo-dev/api/basicBuiness/crudeOilWarning/dataList" + + +login_data = { + "data": { + "account": "api_dev", + "password": "ZTEwYWRjMzk0OWJhNTlhYmJlNTZlMDU3ZjIwZjg4M2U=", + "tenantHashCode": "8a4577dbd919675758d57999a1e891fe", + "terminal": "API" + }, + "funcModule": "API", + "funcOperation": "获取token" +} + + +upload_data = { + "funcModule":'研究报告信息', + "funcOperation":'上传聚烯烃PP价格预测报告', + "data":{ +    "groupNo":'000211' # 用户组编号 +        "ownerAccount":'36541', #报告所属用户账号  36541 - 贾青雪 +        "reportType":'OIL_PRICE_FORECAST', # 报告类型,固定为OIL_PRICE_FORECAST +        "fileName": '', #文件名称 +        "fileBase64": '' ,#文件内容base64 +        "categoryNo":'jxtjgycbg', # 研究报告分类编码 +        "smartBusinessClassCode":'JXTJGYCBG', #分析报告分类编码 +        "reportEmployeeCode":"E40482" ,# 报告人  E40482  - 管理员  0000027663 - 刘小朋   +        "reportDeptCode" :"JXTJGYCBG", # 报告部门 - 002000621000  SH期货研究部   +        "productGroupCode":"RAW_MATERIAL"  # 商品分类 +  } +} + +warning_data = { + "funcModule": '原油特征停更预警', + "funcOperation": '原油特征停更预警', + "data": { + "groupNo": "000211", + 'WARNING_TYPE_NAME': '特征数据停更预警', + 'WARNING_CONTENT': '', + 'WARNING_DATE': '' + } +} + +query_data_list_item_nos_data = { + "funcModule": "数据项", + "funcOperation": "查询", + "data": { + "dateStart":"20200101", + "dateEnd":"20241231", + "dataItemNoList":["Brentzdj","Brentzgj"] # 数据项编码,代表 brent最低价和最高价 + } +} + + +push_data_value_list_data = { + "funcModule": "数据表信息列表", + "funcOperation": "新增", + "data": [ + {"dataItemNo": "91230600716676129", + "dataDate": "20230113", + "dataStatus": "add", + "dataValue": 100.11 + }, + {"dataItemNo": "91230600716676129P|ETHYL_BEN|CAPACITY", + "dataDate": "20230113", + "dataStatus": "add", + "dataValue": 100.55 + }, + {"dataItemNo": "91230600716676129P|ETHYL_BEN|CAPACITY", + "dataDate": "20230113", + "dataStatus": "add", + "dataValue": 100.55 + } + ] +} + + +push_waring_data_value_list_data = { + "data": { + "crudeOilWarningDtoList": [ + { + "lastUpdateDate": "20240501", + "updateSuspensionCycle": 1, + "dataSource": "9", + "frequency": "1", + "indicatorName": "美元指数", + "indicatorId": "myzs001", + "warningDate": "2024-05-13" + } + ], + "dataSource": "9" + }, + "funcModule": "商品数据同步", + "funcOperation": "同步" +} + + +get_waring_data_value_list_data = { + "data": "9", "funcModule": "商品数据同步", "funcOperation": "同步"} + + +# 八大维度数据项编码 +bdwd_items = { + 'ciri': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE', + 'benzhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE01', + 'cizhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE02', + 'gezhou': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE03', + 'ciyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE04', + 'cieryue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE05', + 'cisanyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE06', + 'cisiyue': '91371600MAC3TYFN6M|LSBM00007|FORECAST_PRICE07', +} + + +# 生产环境数据库 +host = 'rm-2zehj3r1n60ttz9x5.mysql.rds.aliyuncs.com' +port = 3306 +dbusername ='jingbo' +password = 'shihua@123' +dbname = 'jingbo' +table_name = 'v_tbl_crude_oil_warning' + + + # 变量定义--测试环境 server_host = '192.168.100.53' # 内网 # server_host = '183.242.74.28' # 外网 diff --git a/config_shiyoujiao_lvyong.py b/config_shiyoujiao_lvyong.py index cc230e6..4b7e454 100644 --- a/config_shiyoujiao_lvyong.py +++ b/config_shiyoujiao_lvyong.py @@ -295,16 +295,17 @@ push_data_value_list_data = { } ] } + # 八大维度数据项编码 bdwd_items = { 'ciri': 'syjlyycbdwdcr', 'benzhou': 'syjlyycbdwdbz', - 'cizhou': 'syjlyycbdwdcz', - 'gezhou': 'syjlyycbdwdgz', - 'ciyue': 'syjlyycbdwdcy', - 'cieryue': 'syjlyycbdwdcey', - 'cisanyue': 'syjlyycbdwdcsy', - 'cisiyue': 'syjlyycbdwdcsiy', + 'cizhou': 'syj|nextweek|price', + 'gezhou': 'syj|next-one-week|price', + 'ciyue': 'syj|next-month|price', + 'cieryue': 'syj|next-one-month|price', + 'cisanyue': 'yj|next-two-month|price', + 'cisiyue': 'yj|next-three-month|price', } # 北京环境数据库 diff --git a/config_shiyoujiao_lvyong_yuedu.py b/config_shiyoujiao_lvyong_yuedu.py index 804fc24..5570610 100644 --- a/config_shiyoujiao_lvyong_yuedu.py +++ b/config_shiyoujiao_lvyong_yuedu.py @@ -466,21 +466,23 @@ push_data_value_list_data = { } ] } + # 八大维度数据项编码 bdwd_items = { 'ciri': 'syjlyycbdwdcr', 'benzhou': 'syjlyycbdwdbz', - 'cizhou': 'syjlyycbdwdcz', - 'gezhou': 'syjlyycbdwdgz', - 'ciyue': 'syjlyycbdwdcy', - 'cieryue': 'syjlyycbdwdcey', - 'cisanyue': 'syjlyycbdwdcsy', - 'cisiyue': 'syjlyycbdwdcsiy', + 'cizhou': 'syj|nextweek|price', + 'gezhou': 'syj|next-one-week|price', + 'ciyue': 'syj|next-month|price', + 'cieryue': 'syj|next-one-month|price', + 'cisanyue': 'yj|next-two-month|price', + 'cisiyue': 'yj|next-three-month|price', } # 报告中八大维度数据项重命名 -columnsrename={'syjlyycbdwdbz': '本周', 'syjlyycbdwdcey': '次二月', 'syjlyycbdwdcr': '次日', 'syjlyycbdwdcsiy': '次四月', - 'syjlyycbdwdcsy': '次三月', 'syjlyycbdwdcy': '次月', 'syjlyycbdwdcz': '次周', 'syjlyycbdwdgz': '隔周', } +columnsrename={ 'syjlyycbdwdcr': '次日', 'syjlyycbdwdbz': '本周', + 'syj|nextweek|price': '次周', 'syj|next-one-week|price': '隔周', + 'syj|next-month|price': '次月', 'syj|next-one-month|price': '次二月', 'yj|next-two-month|price': '次三月', 'yj|next-three-month|price': '次四月'} # 北京环境数据库 host = '192.168.101.27' port = 3306 @@ -542,7 +544,7 @@ avg_cols = [ offsite = 80 offsite_col = [] horizon = 4 # 预测的步长 -input_size = 16 # 输入序列长度 +input_size = 8 # 输入序列长度 train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数 val_check_steps = 30 # 评估频率 early_stop_patience_steps = 5 # 早停的耐心步数 diff --git a/config_shiyoujiao_lvyong_zhoudu.py b/config_shiyoujiao_lvyong_zhoudu.py index 029d1e3..fd16269 100644 --- a/config_shiyoujiao_lvyong_zhoudu.py +++ b/config_shiyoujiao_lvyong_zhoudu.py @@ -326,12 +326,12 @@ push_data_value_list_data = { bdwd_items = { 'ciri': 'syjlyycbdwdcr', 'benzhou': 'syjlyycbdwdbz', - 'cizhou': 'syjlyycbdwdcz', - 'gezhou': 'syjlyycbdwdgz', - 'ciyue': 'syjlyycbdwdcy', - 'cieryue': 'syjlyycbdwdcey', - 'cisanyue': 'syjlyycbdwdcsy', - 'cisiyue': 'syjlyycbdwdcsiy', + 'cizhou': 'syj|nextweek|price', + 'gezhou': 'syj|next-one-week|price', + 'ciyue': 'syj|next-month|price', + 'cieryue': 'syj|next-one-month|price', + 'cisanyue': 'yj|next-two-month|price', + 'cisiyue': 'yj|next-three-month|price', } # 北京环境数据库 diff --git a/lib/dataread.py b/lib/dataread.py index dc798f6..f8475c0 100644 --- a/lib/dataread.py +++ b/lib/dataread.py @@ -79,6 +79,9 @@ global_config = { 'upload_warning_url': None, # 预警数据上传地址 'upload_warning_data': None, # 预警数据结构 + # 报告上传 + 'upload_data': None, # 报告数据结构 + # 查询接口 'query_data_list_item_nos_url': None, # 数据项查询地址 'query_data_list_item_nos_data': None, # 数据项查询参数 diff --git a/main_shiyoujiao_lvyong_yuedu.py b/main_shiyoujiao_lvyong_yuedu.py new file mode 100644 index 0000000..d2baba9 --- /dev/null +++ b/main_shiyoujiao_lvyong_yuedu.py @@ -0,0 +1,442 @@ +# 读取配置 + +from lib.dataread import * +from config_shiyoujiao_lvyong_yuedu import * +from lib.tools import SendMail, exception_logger +from models.nerulforcastmodels import ex_Model, model_losss, shiyoujiao_lvyong_export_pdf +import datetime +import torch +torch.set_float32_matmul_precision("high") + +global_config.update({ + # 核心参数 + 'logger': logger, + 'dataset': dataset, + 'y': y, + 'is_debug': is_debug, + 'is_train': is_train, + 'is_fivemodels': is_fivemodels, + 'is_update_report': is_update_report, + 'settings': settings, + 'weight_dict': weight_dict, + 'baichuanidnamedict': baichuanidnamedict, + 'bdwdname': bdwdname, + + + # 模型参数 + 'data_set': data_set, + 'input_size': input_size, + 'horizon': horizon, + 'train_steps': train_steps, + 'val_check_steps': val_check_steps, + 'val_size': val_size, + 'test_size': test_size, + 'modelsindex': modelsindex, + 'rote': rote, + 'bdwd_items': bdwd_items, + 'baichuanidnamedict': baichuanidnamedict, + + # 特征工程开关 + 'is_del_corr': is_del_corr, + 'is_del_tow_month': is_del_tow_month, + 'is_eta': is_eta, + 'is_update_eta': is_update_eta, + 'is_fivemodels': is_fivemodels, + 'is_update_predict_value': is_update_predict_value, + 'early_stop_patience_steps': early_stop_patience_steps, + + # 时间参数 + 'start_year': start_year, + 'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"), + 'freq': freq, # 保持列表结构 + + # 接口配置 + 'login_pushreport_url': login_pushreport_url, + 'login_data': login_data, + 'upload_url': upload_url, + 'upload_data': upload_data, + 'upload_warning_url': upload_warning_url, + 'warning_data': warning_data, + + # 查询接口 + 'query_data_list_item_nos_url': query_data_list_item_nos_url, + 'query_data_list_item_nos_data': query_data_list_item_nos_data, + + # 上传数据项 + 'push_data_value_list_url': push_data_value_list_url, + 'push_data_value_list_data': push_data_value_list_data, + + # eta 配置 + 'APPID': APPID, + 'SECRET': SECRET, + 'etadata': data, + 'edbcodelist': edbcodelist, + 'ClassifyId': ClassifyId, + 'edbcodedataurl': edbcodedataurl, + 'classifyidlisturl': classifyidlisturl, + 'edbdatapushurl': edbdatapushurl, + 'edbdeleteurl': edbdeleteurl, + 'edbbusinessurl': edbbusinessurl, + 'edbcodenamedict': edbcodenamedict, + 'ClassifyId': ClassifyId, + 'classifylisturl': classifylisturl, + + # 数据库配置 + 'sqlitedb': sqlitedb, + 'is_bdwd': is_bdwd, + 'columnsrename':columnsrename, + 'db_mysql': db_mysql, + 'baichuan_table_name': baichuan_table_name, +}) + + +def push_market_value(): + logger.info('发送预测结果到市场信息平台') + # 读取预测数据和模型评估数据 + predict_file_path = os.path.join(config.dataset, 'predict.csv') + model_eval_file_path = os.path.join(config.dataset, 'model_evaluation.csv') + try: + predictdata_df = pd.read_csv(predict_file_path) + top_models_df = pd.read_csv(model_eval_file_path) + except FileNotFoundError as e: + logger.error(f"文件未找到: {e}") + return + + predictdata = predictdata_df.copy() + + # 取模型前十 + top_models = top_models_df['模型(Model)'].head(10).tolist() + # 去掉FDBformer + if 'FEDformer' in top_models: + top_models.remove('FEDformer') + # 计算前十模型的均值 + predictdata_df['top_models_mean'] = predictdata_df[top_models].mean(axis=1) + + # 打印日期和前十模型均值 + print(predictdata_df[['ds', 'top_models_mean']]) + + # 准备要推送的数据 + ciyue_mean = predictdata_df['top_models_mean'].iloc[0] + cieryue_mean = predictdata_df['top_models_mean'].iloc[1] + cisanyue_mean = predictdata_df['top_models_mean'].iloc[2] + cisieryue_mean = predictdata_df['top_models_mean'].iloc[3] + # 保留两位小数 + ciyue_mean = round(ciyue_mean, 2) + cieryue_mean = round(cieryue_mean, 2) + cisanyue_mean = round(cisanyue_mean, 2) + cisieryue_mean = round(cisieryue_mean, 2) + + predictdata = [ + { + "dataItemNo": global_config['bdwd_items']['ciyue'], + "dataDate": global_config['end_time'].replace('-', ''), + "dataStatus": "add", + "dataValue": ciyue_mean + }, + { + "dataItemNo": global_config['bdwd_items']['cieryue'], + "dataDate": global_config['end_time'].replace('-', ''), + "dataStatus": "add", + "dataValue": cieryue_mean + }, + { + "dataItemNo": global_config['bdwd_items']['cisanyue'], + "dataDate": global_config['end_time'].replace('-', ''), + "dataStatus": "add", + "dataValue": cisanyue_mean + }, + { + "dataItemNo": global_config['bdwd_items']['cisiyue'], + "dataDate": global_config['end_time'].replace('-', ''), + "dataStatus": "add", + "dataValue": cisieryue_mean + } + ] + + print(predictdata) + + # 推送数据到市场信息平台 + try: + push_market_data(predictdata) + except Exception as e: + logger.error(f"推送数据失败: {e}") + + +def predict_main(): + """ + 主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。 + + 参数: + signature (BinanceAPI): Binance API 实例。 + etadata (EtaReader): ETA 数据读取器实例。 + is_eta (bool): 是否从 ETA 获取数据。 + data_set (str): 数据集名称。 + dataset (str): 数据集路径。 + add_kdj (bool): 是否添加 KDJ 指标。 + is_timefurture (bool): 是否添加时间衍生特征。 + end_time (str): 结束时间。 + is_edbnamelist (bool): 是否使用 EDB 名称列表。 + edbnamelist (list): EDB 名称列表。 + y (str): 预测目标列名。 + sqlitedb (SQLiteDB): SQLite 数据库实例。 + is_corr (bool): 是否进行相关性分析。 + horizon (int): 预测时域。 + input_size (int): 输入数据大小。 + train_steps (int): 训练步数。 + val_check_steps (int): 验证检查步数。 + early_stop_patience_steps (int): 早停耐心步数。 + is_debug (bool): 是否调试模式。 + dataset (str): 数据集名称。 + is_train (bool): 是否训练模型。 + is_fivemodels (bool): 是否使用五个模型。 + val_size (float): 验证集大小。 + test_size (float): 测试集大小。 + settings (dict): 模型设置。 + now (str): 当前时间。 + etadata (EtaReader): ETA 数据读取器实例。 + modelsindex (list): 模型索引列表。 + data (str): 数据类型。 + is_eta (bool): 是否从 ETA 获取数据。 + + 返回: + None + """ + + end_time = global_config['end_time'] + # 获取数据 + if is_eta: + logger.info('从eta获取数据...') + signature = BinanceAPI(APPID, SECRET) + etadata = EtaReader(signature=signature, + classifylisturl=global_config['classifylisturl'], + classifyidlisturl=global_config['classifyidlisturl'], + edbcodedataurl=global_config['edbcodedataurl'], + edbcodelist=global_config['edbcodelist'], + edbdatapushurl=global_config['edbdatapushurl'], + edbdeleteurl=global_config['edbdeleteurl'], + edbbusinessurl=global_config['edbbusinessurl'], + classifyId=global_config['ClassifyId'], + ) + df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_shiyoujiao_lvyong_data( + data_set=data_set, dataset=dataset) # 原始数据,未处理 + + if is_market: + logger.info('从市场信息平台获取数据...') + try: + # 如果是测试环境,最高价最低价取excel文档 + if server_host == '192.168.100.53': + logger.info('从excel文档获取最高价最低价') + df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) + else: + logger.info('从市场信息平台获取数据') + df_zhibiaoshuju = get_market_data( + end_time, df_zhibiaoshuju) + + except: + logger.info('最高最低价拼接失败') + + if len(global_config['baichuanidnamedict']) > 0: + logger.info('从市场数据库获取百川数据...') + baichuandf = get_baichuan_data(global_config['baichuanidnamedict']) + df_zhibiaoshuju = pd.merge( + df_zhibiaoshuju, baichuandf, on='date', how='outer') + # 指标列表添加百川数据 + df_baichuanliebiao = pd.DataFrame( + global_config['baichuanidnamedict'].items(), columns=['指标id', '指标名称']) + df_baichuanliebiao['指标分类'] = '石油焦对标炼厂价格' + df_baichuanliebiao['频度'] = '其他' + df_zhibiaoliebiao = pd.concat( + [df_zhibiaoliebiao, df_baichuanliebiao], axis=0) + + # 保存到xlsx文件的sheet表 + with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: + df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) + df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) + + # 数据处理 + df = zhoududatachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, + end_time=end_time) + + else: + # 读取数据 + logger.info('读取本地数据:' + os.path.join(dataset, data_set)) + df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, + is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 + + # 更改预测列名称 + df.rename(columns={y: 'y'}, inplace=True) + + if is_edbnamelist: + df = df[edbnamelist] + df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False) + # 保存最新日期的y值到数据库 + # 取第一行数据存储到数据库中 + first_row = df[['ds', 'y']].tail(1) + # 判断y的类型是否为float + if not isinstance(first_row['y'].values[0], float): + logger.info(f'{end_time}预测目标数据为空,跳过') + return None + + # 将最新真实值保存到数据库 + if not sqlitedb.check_table_exists('trueandpredict'): + first_row.to_sql('trueandpredict', sqlitedb.connection, index=False) + else: + for row in first_row.itertuples(index=False): + row_dict = row._asdict() + config.logger.info(f'要保存的真实值:{row_dict}') + # 判断ds是否为字符串类型,如果不是则转换为字符串类型 + if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)): + row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') + elif not isinstance(row_dict['ds'], str): + try: + row_dict['ds'] = pd.to_datetime( + row_dict['ds']).strftime('%Y-%m-%d') + except: + logger.warning(f"无法解析的时间格式: {row_dict['ds']}") + # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') + # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') + check_query = sqlitedb.select_data( + 'trueandpredict', where_condition=f"ds = '{row.ds}'") + if len(check_query) > 0: + set_clause = ", ".join( + [f"{key} = '{value}'" for key, value in row_dict.items()]) + sqlitedb.update_data( + 'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") + continue + sqlitedb.insert_data('trueandpredict', tuple( + row_dict.values()), columns=row_dict.keys()) + + # 更新accuracy表的y值 + if not sqlitedb.check_table_exists('accuracy'): + pass + else: + update_y = sqlitedb.select_data( + 'accuracy', where_condition="y is null") + if len(update_y) > 0: + logger.info('更新accuracy表的y值') + # 找到update_y 中ds且df中的y的行 + update_y = update_y[update_y['ds'] <= end_time] + logger.info(f'要更新y的信息:{update_y}') + # try: + for row in update_y.itertuples(index=False): + try: + row_dict = row._asdict() + yy = df[df['ds'] == row_dict['ds']]['y'].values[0] + LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] + HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] + sqlitedb.update_data( + 'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") + except: + logger.info(f'更新accuracy表的y值失败:{row_dict}') + # except Exception as e: + # logger.info(f'更新accuracy表的y值失败:{e}') + + # 判断当前日期是不是周一 + is_weekday = datetime.datetime.now().weekday() == 0 + if is_weekday: + logger.info('今天是周一,更新预测模型') + # 计算最近60天预测残差最低的模型名称 + model_results = sqlitedb.select_data( + 'trueandpredict', order_by="ds DESC", limit="60") + # 删除空值率为90%以上的列 + if len(model_results) > 10: + model_results = model_results.dropna( + thresh=len(model_results)*0.1, axis=1) + # 删除空行 + model_results = model_results.dropna() + modelnames = model_results.columns.to_list()[2:-2] + for col in model_results[modelnames].select_dtypes(include=['object']).columns: + model_results[col] = model_results[col].astype(np.float32) + # 计算每个预测值与真实值之间的偏差率 + for model in modelnames: + model_results[f'{model}_abs_error_rate'] = abs( + model_results['y'] - model_results[model]) / model_results['y'] + # 获取每行对应的最小偏差率值 + min_abs_error_rate_values = model_results.apply( + lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) + # 获取每行对应的最小偏差率值对应的列名 + min_abs_error_rate_column_name = model_results.apply( + lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) + # 将列名索引转换为列名 + min_abs_error_rate_column_name = min_abs_error_rate_column_name.map( + lambda x: x.split('_')[0]) + # 取出现次数最多的模型名称 + most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() + logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") + # 保存结果到数据库 + if not sqlitedb.check_table_exists('most_model'): + sqlitedb.create_table( + 'most_model', columns="ds datetime, most_common_model TEXT") + sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( + '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) + + if is_corr: + df = corr_feature(df=df) + + df1 = df.copy() # 备份一下,后面特征筛选完之后加入ds y 列用 + logger.info(f"开始训练模型...") + row, col = df.shape + + now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ex_Model(df, + horizon=global_config['horizon'], + input_size=global_config['input_size'], + train_steps=global_config['train_steps'], + val_check_steps=global_config['val_check_steps'], + early_stop_patience_steps=global_config['early_stop_patience_steps'], + is_debug=global_config['is_debug'], + dataset=global_config['dataset'], + is_train=global_config['is_train'], + is_fivemodels=global_config['is_fivemodels'], + val_size=global_config['val_size'], + test_size=global_config['test_size'], + settings=global_config['settings'], + now=now, + etadata=etadata, + modelsindex=global_config['modelsindex'], + data=data, + is_eta=global_config['is_eta'], + end_time=global_config['end_time'], + ) + + logger.info('模型训练完成') + + logger.info('训练数据绘图ing') + model_results3 = model_losss(sqlitedb, end_time=end_time) + logger.info('训练数据绘图end') + + # 模型报告 + logger.info('制作报告ing') + title = f'{settings}--{end_time}-预测报告' # 报告标题 + reportname = '石油焦大模型铝用渠道.pdf' # 报告文件名 + # reportname = f'石油焦铝用大模型月度预测--{end_time}.pdf' # 报告文件名 + # reportname = reportname.replace(':', '-') # 替换冒号 + shiyoujiao_lvyong_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + reportname=reportname, sqlitedb=sqlitedb), + + logger.info('制作报告end') + logger.info('模型训练完成') + + push_market_value() + + # 发送邮件 + # m = SendMail( + # username=username, + # passwd=passwd, + # recv=recv, + # title=title, + # content=content, + # file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), + # ssl=ssl, + # ) + # m.send_mail() + + +if __name__ == '__main__': + # global end_time + # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 + # for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'): + # end_time = i_time.strftime('%Y-%m-%d') + # predict_main() + + predict_main() diff --git a/main_shiyoujiao_lvyong_zhoudu.py b/main_shiyoujiao_lvyong_zhoudu.py index 66b75f1..f500bf4 100644 --- a/main_shiyoujiao_lvyong_zhoudu.py +++ b/main_shiyoujiao_lvyong_zhoudu.py @@ -54,6 +54,7 @@ global_config.update({ 'login_pushreport_url': login_pushreport_url, 'login_data': login_data, 'upload_url': upload_url, + 'upload_data': upload_data, 'upload_warning_url': upload_warning_url, 'warning_data': warning_data, @@ -373,7 +374,7 @@ def predict_main(): test_size=global_config['test_size'], settings=global_config['settings'], now=now, - etadata=global_config['etadata'], + etadata=etadata, modelsindex=global_config['modelsindex'], data=data, is_eta=global_config['is_eta'], diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index ce91f3f..45ed060 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -2851,10 +2851,10 @@ def shiyoujiao_lvyong_export_pdf(num_indicators=475, num_models=21, num_dayindic if config.is_update_report: with open(os.path.join(config.dataset, reportname), 'rb') as f: base64_data = base64.b64encode(f.read()).decode('utf-8') - upload_data["data"]["fileBase64"] = base64_data - upload_data["data"]["fileName"] = reportname + config.upload_data["data"]["fileBase64"] = base64_data + config.upload_data["data"]["fileName"] = reportname token = get_head_auth_report() - upload_report_data(token, upload_data) + upload_report_data(token, config.upload_data) except TimeoutError as e: print(f"请求超时: {e}")