From 29a65d7c704d3d3f14e8e500219ad4817410338c Mon Sep 17 00:00:00 2001 From: jingboyitiji Date: Fri, 7 Mar 2025 13:28:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8E=9F=E6=B2=B9=E6=9C=88=E5=BA=A6=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E6=95=B0=E6=8D=AE=E8=B0=83=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_jingbo_yuedu.py | 6 +++--- main_yuanyou_yuedu.py | 35 ++++++++++++++++++++--------------- models/nerulforcastmodels.py | 6 ++++-- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/config_jingbo_yuedu.py b/config_jingbo_yuedu.py index 6693283..1cdcfbe 100644 --- a/config_jingbo_yuedu.py +++ b/config_jingbo_yuedu.py @@ -158,7 +158,7 @@ table_name = 'v_tbl_crude_oil_warning' # 开关 is_train = False # 是否训练 -is_debug = True # 是否调试 +is_debug = False # 是否调试 is_eta = False # 是否使用eta接口 is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 @@ -180,8 +180,8 @@ print("数据库连接成功", host, dbname, dbusername) # 数据截取日期 -start_year = 2005 # 数据开始年份 -end_time = '' # 数据截取日期 +start_year = 2000 # 数据开始年份 +end_time = '2023-3-1' # 数据截取日期 freq = 'M' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 "WW" 自定义周 delweekenday = True if freq == 'B' else False # 是否删除周末数据 is_corr = False # 特征是否参与滞后领先提升相关系数 diff --git a/main_yuanyou_yuedu.py b/main_yuanyou_yuedu.py index 1dee36d..d9b055b 100644 --- a/main_yuanyou_yuedu.py +++ b/main_yuanyou_yuedu.py @@ -102,7 +102,7 @@ def predict_main(): 返回: None """ - global end_time + end_time = global_config['end_time'] signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, classifylisturl=classifylisturl, @@ -332,22 +332,22 @@ def predict_main(): end_time=global_config['end_time'], ) - # logger.info('模型训练完成') + logger.info('模型训练完成') logger.info('训练数据绘图ing') model_results3 = model_losss(sqlitedb, end_time=end_time) logger.info('训练数据绘图end') # # 模型报告 - logger.info('制作报告ing') - title = f'{settings}--{end_time}-预测报告' # 报告标题 - reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名 - reportname = reportname.replace(':', '-') # 替换冒号 - brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, - reportname=reportname, sqlitedb=sqlitedb), + # logger.info('制作报告ing') + # title = f'{settings}--{end_time}-预测报告' # 报告标题 + # reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名 + # reportname = reportname.replace(':', '-') # 替换冒号 + # brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + # reportname=reportname, sqlitedb=sqlitedb), - logger.info('制作报告end') - logger.info('模型训练完成') + # logger.info('制作报告end') + # logger.info('模型训练完成') # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) @@ -373,9 +373,14 @@ def predict_main(): if __name__ == '__main__': # global end_time - # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - # for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'): - # end_time = i_time.strftime('%Y-%m-%d') - # predict_main() + # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 + for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'): + try: + global_config['end_time'] = i_time.strftime('%Y-%m-%d') + predict_main() + except Exception as e: + logger.info(f'预测失败:{e}') + continue - predict_main() + + # predict_main() diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 7ce32d0..5824b6e 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -935,10 +935,10 @@ def model_losss(sqlitedb, end_time): try: df_combined = sqlitedb.select_data( 'accuracy', where_condition=f"created_dt <= '{end_time}'") - if len(df_combined) < 10: + if len(df_combined) < 100: len(df_combined) + '' except: - df_combined = loadcsv(os.path.join(dataset, "cross_validation.csv")) + df_combined = loadcsv(os.path.join(config.dataset, "cross_validation.csv")) df_combined = dateConvert(df_combined) df_combined['CREAT_DATE'] = df_combined['cutoff'] df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要 @@ -968,6 +968,8 @@ def model_losss(sqlitedb, end_time): modelnames.remove('y') if 'cutoff' in modelnames: modelnames.remove('cutoff') + if 'ds' in modelnames: + modelnames.remove('ds') df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要 # 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE