From 98836d7a4d71116ae89b3ac8fb72aa6c4e8c1547 Mon Sep 17 00:00:00 2001 From: workpc Date: Fri, 11 Jul 2025 14:46:23 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=98=E5=8E=9F=E5=8E=9F=E6=B2=B9=E5=91=A8?= =?UTF-8?q?=E5=BA=A6=E6=89=A7=E8=A1=8C=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_juxiting.py | 67 +++++++++++---------- main_yuanyou_zhoudu.py | 130 +++++++++++++++++------------------------ 2 files changed, 87 insertions(+), 110 deletions(-) diff --git a/main_juxiting.py b/main_juxiting.py index 03887ef..4469ff6 100644 --- a/main_juxiting.py +++ b/main_juxiting.py @@ -436,42 +436,41 @@ def predict_main(): sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) - # try: - # if is_weekday: - if True: - logger.info('发送特征预警') - # 获取取消订阅的指标ID - quxiaodingyueidlist = get_waring_data() - # 上传预警信息到数据库 - warning_data_df = df_zhibiaoliebiao.copy() - warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ - '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] - # 重命名列名 - warning_data_df = warning_data_df.rename(columns={'指标名称': 'indicatorName', '指标id': 'indicatorId', '频度': 'frequency', - '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'updateSuspensionCycle'}) + try: + if True: + logger.info('发送特征预警') + # 获取取消订阅的指标ID + quxiaodingyueidlist = get_waring_data() + # 上传预警信息到数据库 + warning_data_df = df_zhibiaoliebiao.copy() + warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ + '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] + # 重命名列名 + warning_data_df = warning_data_df.rename(columns={'指标名称': 'indicatorName', '指标id': 'indicatorId', '频度': 'frequency', + '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'updateSuspensionCycle'}) - warning_data_df['warningDate'] = datetime.date.today().strftime( - "%Y-%m-%d %H:%M:%S") - warning_data_df['dataSource'] = 9 + warning_data_df['warningDate'] = datetime.date.today().strftime( + "%Y-%m-%d %H:%M:%S") + warning_data_df['dataSource'] = 9 - if len(quxiaodingyueidlist) > 0: - # 去掉取消订阅的指标 - print(warning_data_df.shape) - warning_data_df = warning_data_df[~warning_data_df['indicatorId'].isin( - quxiaodingyueidlist)] - print(warning_data_df.shape) - warning_data = warning_data_df.to_json( - orient='records', force_ascii=False) - warning_data = warning_data.replace('日度', '1') - warning_data = warning_data.replace('周度', '2') - warning_data = warning_data.replace('月度', '3') - warning_data = json.loads(warning_data) - push_waring_market_data( - warning_data, dataSource=warning_data_df['dataSource'].values[0]) - # if is_update_warning_data: - # upload_warning_info(len(warning_data_df)) - # except: - # logger.info('上传预警信息到数据库失败') + if len(quxiaodingyueidlist) > 0: + # 去掉取消订阅的指标 + print(warning_data_df.shape) + warning_data_df = warning_data_df[~warning_data_df['indicatorId'].isin( + quxiaodingyueidlist)] + print(warning_data_df.shape) + warning_data = warning_data_df.to_json( + orient='records', force_ascii=False) + warning_data = warning_data.replace('日度', '1') + warning_data = warning_data.replace('周度', '2') + warning_data = warning_data.replace('月度', '3') + warning_data = json.loads(warning_data) + push_waring_market_data( + warning_data, dataSource=warning_data_df['dataSource'].values[0]) + # if is_update_warning_data: + # upload_warning_info(len(warning_data_df)) + except: + logger.info('上传预警信息到数据库失败') if is_corr: df = corr_feature(df=df) diff --git a/main_yuanyou_zhoudu.py b/main_yuanyou_zhoudu.py index e3c94c3..ebb4d29 100644 --- a/main_yuanyou_zhoudu.py +++ b/main_yuanyou_zhoudu.py @@ -1,9 +1,9 @@ # 读取配置 from lib.dataread import * -from config_juxiting_zhoudu import * -from lib.tools import SendMail, exception_logger, convert_df_to_pydantic, exception_logger, get_modelsname -from models.nerulforcastmodels import ex_Model_Juxiting, model_losss_juxiting, pp_export_pdf +from config_jingbo_zhoudu import * +from lib.tools import SendMail, convert_df_to_pydantic, exception_logger, get_modelsname +from models.nerulforcastmodels import ex_Model, model_losss, brent_export_pdf import datetime import torch torch.set_float32_matmul_precision("high") @@ -13,10 +13,6 @@ global_config.update({ 'logger': logger, 'dataset': dataset, 'y': y, - # 'offsite_col': offsite_col, - # 'avg_cols': avg_cols, - # 'offsite': offsite, - 'edbcodenamedict': edbcodenamedict, 'is_debug': is_debug, 'is_train': is_train, 'is_fivemodels': is_fivemodels, @@ -67,14 +63,6 @@ global_config.update({ 'push_data_value_list_url': push_data_value_list_url, 'push_data_value_list_data': push_data_value_list_data, - # 上传预警数据 - 'push_waring_data_value_list_url': push_waring_data_value_list_url, - 'push_waring_data_value_list_data': push_waring_data_value_list_data, - - # 获取取消订阅的数据 - 'get_waring_data_value_list_url': get_waring_data_value_list_url, - 'get_waring_data_value_list_data': get_waring_data_value_list_data, - # eta 配置 'APPID': APPID, 'SECRET': SECRET, @@ -91,7 +79,6 @@ global_config.update({ # 数据库配置 'sqlitedb': sqlitedb, - 'bdwd_items': bdwd_items, 'is_bdwd': is_bdwd, 'db_mysql': db_mysql, 'DEFAULT_CONFIG': DEFAULT_CONFIG, @@ -99,7 +86,7 @@ global_config.update({ def push_market_value(): - config.logger.info('发送预测结果到市场信息平台') + logger.info('发送预测结果到市场信息平台') # 读取预测数据和模型评估数据 predict_file_path = os.path.join(config.dataset, 'predict.csv') model_eval_file_path = os.path.join(config.dataset, 'model_evaluation.csv') @@ -107,7 +94,7 @@ def push_market_value(): predictdata_df = pd.read_csv(predict_file_path) top_models_df = pd.read_csv(model_eval_file_path) except FileNotFoundError as e: - config.logger.error(f"文件未找到: {e}") + logger.error(f"文件未找到: {e}") return predictdata = predictdata_df.copy() @@ -151,7 +138,7 @@ def push_market_value(): try: push_market_data(predictdata) except Exception as e: - config.logger.error(f"推送数据失败: {e}") + logger.error(f"推送数据失败: {e}") def sql_inset_predict(global_config): @@ -167,7 +154,7 @@ def sql_inset_predict(global_config): model_name_list, model_id_name_dict = get_modelsname(df, global_config) PRICE_COLUMNS = [ - 'day_price', 'week_price', 'second_week_price', 'next_week_price', + 'day_price', 'week_price', 'second_week_price', 'next_week_price', 'next_month_price', 'next_february_price', 'next_march_price', 'next_april_price' ] @@ -217,8 +204,8 @@ def sql_inset_predict(global_config): params = ( result.feature_factor_frequency, result.strategy_id, - global_config['DEFAULT_CONFIG']['oil_code'], - global_config['DEFAULT_CONFIG']['oil_name'], + result.oil_code, + result.oil_name, next_day_df['created_dt'].values[0], result.market_price, *price_values, @@ -279,6 +266,7 @@ def predict_main(): None """ end_time = global_config['end_time'] + signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, classifylisturl=global_config['classifylisturl'], @@ -294,7 +282,7 @@ def predict_main(): if is_eta: logger.info('从eta获取数据...') - df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data( + df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data( data_set=data_set, dataset=dataset) # 原始数据,未处理 if is_market: @@ -302,7 +290,7 @@ def predict_main(): try: # 如果是测试环境,最高价最低价取excel文档 if server_host == '192.168.100.53': - logger.info('从excel文档获取市场信息平台指标') + logger.info('从excel文档获取最高价最低价') df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) else: logger.info('从市场信息平台获取数据') @@ -310,7 +298,7 @@ def predict_main(): end_time, df_zhibiaoshuju) except: - logger.info('市场信息平台数据项-eta数据项 拼接失败') + logger.info('最高最低价拼接失败') # 保存到xlsx文件的sheet表 with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: @@ -318,14 +306,14 @@ def predict_main(): df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) # 数据处理 - df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, - end_time=end_time) + df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, + end_time=end_time) else: # 读取数据 logger.info('读取本地数据:' + os.path.join(dataset, data_set)) - df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, - is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 + df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, + is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 # 更改预测列名称 df.rename(columns={y: 'y'}, inplace=True) @@ -395,8 +383,8 @@ def predict_main(): # except Exception as e: # logger.info(f'更新accuracy表的y值失败:{e}') - # 判断当前日期是不是周一 - is_weekday = datetime.datetime.now().weekday() == 0 + # 判断当前日期是不是周一 预测目标周度许转换,暂注释 + # is_weekday = datetime.datetime.strptime(global_config['end_time'], "%Y-%m-%d").weekday() == 0 # if is_weekday: # logger.info('今天是周一,更新预测模型') # # 计算最近60天预测残差最低的模型名称 @@ -442,41 +430,45 @@ def predict_main(): row, col = df.shape now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') - ex_Model_Juxiting(df, - horizon=global_config['horizon'], - input_size=global_config['input_size'], - train_steps=global_config['train_steps'], - val_check_steps=global_config['val_check_steps'], - early_stop_patience_steps=global_config['early_stop_patience_steps'], - is_debug=global_config['is_debug'], - dataset=global_config['dataset'], - is_train=global_config['is_train'], - is_fivemodels=global_config['is_fivemodels'], - val_size=global_config['val_size'], - test_size=global_config['test_size'], - settings=global_config['settings'], - now=now, - etadata=etadata, - modelsindex=global_config['modelsindex'], - data=data, - is_eta=global_config['is_eta'], - end_time=global_config['end_time'], - ) + ex_Model(df, + horizon=global_config['horizon'], + input_size=global_config['input_size'], + train_steps=global_config['train_steps'], + val_check_steps=global_config['val_check_steps'], + early_stop_patience_steps=global_config['early_stop_patience_steps'], + is_debug=global_config['is_debug'], + dataset=global_config['dataset'], + is_train=global_config['is_train'], + is_fivemodels=global_config['is_fivemodels'], + val_size=global_config['val_size'], + test_size=global_config['test_size'], + settings=global_config['settings'], + now=now, + etadata=etadata, + modelsindex=global_config['modelsindex'], + data=data, + is_eta=global_config['is_eta'], + end_time=global_config['end_time'], + ) logger.info('模型训练完成') logger.info('训练数据绘图ing') - model_results3 = model_losss_juxiting( - sqlitedb, end_time=global_config['end_time'], is_fivemodels=global_config['is_fivemodels']) + model_results3 = model_losss(sqlitedb, end_time=end_time) logger.info('训练数据绘图end') # # 模型报告 logger.info('制作报告ing') title = f'{settings}--{end_time}-预测报告' # 报告标题 - reportname = f'聚烯烃PP大模型周度预测--{end_time}.pdf' # 报告文件名 + reportname = f'Brent原油大模型周度预测--{end_time}.pdf' # 报告文件名 reportname = reportname.replace(':', '-') # 替换冒号 - pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, - reportname=reportname, sqlitedb=sqlitedb), + brent_export_pdf(dataset=dataset, + num_models=5 if is_fivemodels else 22, + time=end_time, + reportname=reportname, + inputsize=global_config['horizon'], + sqlitedb=sqlitedb + ), logger.info('制作报告end') logger.info('模型训练完成') @@ -484,15 +476,6 @@ def predict_main(): push_market_value() sql_inset_predict(global_config) - # # LSTM 单变量模型 - # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) - - # # lstm 多变量模型 - # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset) - - # # GRU 模型 - # # ex_GRU(df) - # 发送邮件 # m = SendMail( # username=username, @@ -509,15 +492,10 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - # for i_time in pd.date_range('2025-4-14', '2025-4-15', freq='B'): - # try: - # global_config['end_time'] = i_time.strftime('%Y-%m-%d') - # predict_main() - # except Exception as e: - # logger.info(f'预测失败:{e}') - # continue + for i_time in pd.date_range('2025-6-23', '2025-6-30', freq='B'): + global_config['end_time'] = i_time.strftime('%Y-%m-%d') + global_config['db_mysql'].connect() + predict_main() - predict_main() - - # push_market_value() - # sql_inset_predict(global_config) + # predict_main() + # sql_inset_predict(global_config=global_config)