diff --git a/config_juxiting.py b/config_juxiting.py index f84c822..9c68abd 100644 --- a/config_juxiting.py +++ b/config_juxiting.py @@ -431,15 +431,15 @@ DEFAULT_CONFIG = { # 开关 is_train = True # 是否训练 is_debug = False # 是否调试 -is_eta = False # 是否使用eta接口 +is_eta = True # 是否使用eta接口 is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = False # 特征使用edbcoding列表中的 is_edbnamelist = False # 自定义特征,对应上面的edbnamelist is_update_eta = False # 预测结果上传到eta -is_update_report = True # 是否上传报告 -is_update_warning_data = True # 是否上传预警数据 +is_update_report = False # 是否上传报告 +is_update_warning_data = False # 是否上传预警数据 is_update_predict_value = True # 是否上传预测值到市场信息平台 is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 is_del_tow_month = True # 是否删除两个月不更新的特征 diff --git a/config_juxiting_yuedu.py b/config_juxiting_yuedu.py index 9900153..1be7581 100644 --- a/config_juxiting_yuedu.py +++ b/config_juxiting_yuedu.py @@ -375,7 +375,8 @@ query_data_list_item_nos_data = { "data": { "dateStart": "20200101", "dateEnd": "20241231", - "dataItemNoList": ["Brentzdj", "Brentzgj"] # 数据项编码,代表 brent最低价和最高价 + # 数据项编码,代表 PP期货 价格 + "dataItemNoList": ["MAIN_CONFT_SETTLE_PRICE"] } } @@ -442,8 +443,8 @@ DEFAULT_CONFIG = { # 开关 is_train = True # 是否训练 is_debug = False # 是否调试 -is_eta = False # 是否使用eta接口 -is_market = False # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 +is_eta = True # 是否使用eta接口 +is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = False # 特征使用edbcoding列表中的 @@ -466,7 +467,7 @@ print("数据库连接成功", host, dbname, dbusername) # 数据截取日期 start_year = 2000 # 数据开始年份 -end_time = '2025-07-22' # 数据截取日期 +end_time = '' # 数据截取日期 freq = 'M' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 delweekenday = True if freq == 'B' else False # 是否删除周末数据 is_corr = False # 特征是否参与滞后领先提升相关系数 diff --git a/config_juxiting_zhoudu.py b/config_juxiting_zhoudu.py index b7d58cb..ccac039 100644 --- a/config_juxiting_zhoudu.py +++ b/config_juxiting_zhoudu.py @@ -86,6 +86,11 @@ bdwdname = [ '次周', '隔周', ] +# 数据库预测结果表八大维度列名 +price_columns = [ + 'day_price', 'week_price', 'second_week_price', 'next_week_price', + 'next_month_price', 'next_february_price', 'next_march_price', 'next_april_price' +] modelsindex = [{ "NHITS": "SELF0000231", "Informer": "SELF0000232", diff --git a/lib/dataread.py b/lib/dataread.py index 6253b03..a76247a 100644 --- a/lib/dataread.py +++ b/lib/dataread.py @@ -1101,7 +1101,7 @@ def datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_t df = calculate_kdj(df) calculate_correlation(df=df) - featureAnalysis(df, dataset=dataset, y=y) + featureAnalysis(df, dataset=dataset, y='y') return df diff --git a/lib/tools.py b/lib/tools.py index 24f5e64..0188f49 100644 --- a/lib/tools.py +++ b/lib/tools.py @@ -782,6 +782,7 @@ def get_model_id_name_dict(global_config): sql = f'select model_name,id from {tb} ' modelsname = global_config['db_mysql'].execute_query(sql) model_id_name_dict = {row['id']: row['model_name'] for row in modelsname} + global_config['logger'].info(f'模型id-name: {model_id_name_dict}') return model_id_name_dict diff --git a/main_juxiting_yuedu.py b/main_juxiting_yuedu.py index c4c2edf..bcad998 100644 --- a/main_juxiting_yuedu.py +++ b/main_juxiting_yuedu.py @@ -488,7 +488,7 @@ def predict_main(): sql_inset_predict(global_config) - 模型报告 + # 模型报告 logger.info('制作报告ing') title = f'{settings}--{end_time}-预测报告' # 报告标题 reportname = f'聚烯烃PP大模型月度预测--{end_time}.pdf' # 报告文件名 @@ -497,7 +497,6 @@ def predict_main(): reportname=reportname, sqlitedb=sqlitedb), logger.info('制作报告end') - logger.info('模型训练完成') # 图片报告 logger.info('图片报告ing') @@ -529,12 +528,13 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - # for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'): - # try: - # global_config['end_time'] = i_time.strftime('%Y-%m-%d') - # predict_main() - # except Exception as e: - # logger.info(f'预测失败:{e}') - # continue + for i_time in pd.date_range('2025-7-28', '2025-7-29', freq='B'): + try: + global_config['end_time'] = i_time.strftime('%Y-%m-%d') + global_config['db_mysql'].connect() + predict_main() + except Exception as e: + logger.info(f'预测失败:{e}') + continue - predict_main() + # predict_main() diff --git a/main_juxiting_zhoudu.py b/main_juxiting_zhoudu.py index 95930c3..16f5c17 100644 --- a/main_juxiting_zhoudu.py +++ b/main_juxiting_zhoudu.py @@ -3,7 +3,7 @@ from lib.dataread import * from config_juxiting_zhoudu import * from lib.tools import SendMail, exception_logger, convert_df_to_pydantic_pp, exception_logger, get_modelsname -from models.nerulforcastmodels import ex_Model_Juxiting, model_losss_juxiting, pp_export_pdf +from models.nerulforcastmodels import ex_Model_Juxiting, model_losss_juxiting, pp_bdwd_png, pp_export_pdf import datetime import torch torch.set_float32_matmul_precision("high") @@ -23,6 +23,8 @@ global_config.update({ 'is_update_report': is_update_report, 'settings': settings, 'bdwdname': bdwdname, + 'columnsrename': columnsrename, + 'price_columns': price_columns, # 模型参数 @@ -470,6 +472,9 @@ def predict_main(): # sqlitedb, end_time=global_config['end_time'], is_fivemodels=global_config['is_fivemodels']) # logger.info('训练数据绘图end') + push_market_value() + sql_inset_predict(global_config) + # # # 模型报告 # logger.info('制作报告ing') # title = f'{settings}--{end_time}-预测报告' # 报告标题 @@ -480,8 +485,10 @@ def predict_main(): # logger.info('制作报告end') - push_market_value() - sql_inset_predict(global_config) + # 图片报告 + logger.info('图片报告ing') + pp_bdwd_png(global_config=global_config) + logger.info('图片报告end') # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) @@ -508,16 +515,16 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2025-7-18', '2025-7-23', freq='B'): - try: - global_config['end_time'] = i_time.strftime('%Y-%m-%d') - global_config['db_mysql'].connect() - predict_main() - except Exception as e: - logger.info(f'预测失败:{e}') - continue + # for i_time in pd.date_range('2025-7-18', '2025-7-23', freq='B'): + # try: + # global_config['end_time'] = i_time.strftime('%Y-%m-%d') + # global_config['db_mysql'].connect() + # predict_main() + # except Exception as e: + # logger.info(f'预测失败:{e}') + # continue - # predict_main() + predict_main() # push_market_value() # sql_inset_predict(global_config) diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index b803cb4..f613641 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -2048,7 +2048,11 @@ def model_losss_juxiting(sqlitedb, end_time, is_fivemodels): if row[r] <= config.rote: columns.append(r.split('-')[0]) return pd.Series([columns], index=['columns']) - names_df['columns'] = names_df.apply(add_rote_column, axis=1) + try: + names_df['columns'] = names_df.apply(add_rote_column, axis=1) + except ValueError as e: + print(e) + def add_upper_lower_bound(row): print(row['columns']) @@ -3329,21 +3333,25 @@ def pp_export_pdf(num_indicators=475, num_models=21, num_dayindicator=202, input df3['ds'] = df4['ds'] for col in fivemodels_list: df3[col] = round(abs(df4[col] - df4['y']) / df4['y'] * 100, 2) - # 找出决定系数前五的偏差率 - df3 = df3[['ds']+fivemodels_list.tolist()][-inputsize:] - # 找出上一预测区间的时间 - stime = df3['ds'].iloc[0] - etime = df3['ds'].iloc[-1] - # 添加偏差率表格 - fivemodels = '、'.join(eval_df['模型(Model)'].values[:5]) # 字符串形式,后面写入字符串使用 - content.append(Graphs.draw_text( - f'预测使用了{num_models}个模型进行训练,使用评估结果MAE前五的模型分别是 {fivemodels} ,模型上一预测区间 {stime} -- {etime}的偏差率(%)分别是:')) - # # 添加偏差率表格 - df3 = df3.T - df3 = df3.reset_index() - data = df3.values.tolist() - col_width = 500/len(df3.columns) - content.append(Graphs.draw_table(col_width, *data)) + + try: + # 找出决定系数前五的偏差率 + df3 = df3[['ds']+fivemodels_list.tolist()][-inputsize:] + # 找出上一预测区间的时间 + stime = df3['ds'].iloc[0] + etime = df3['ds'].iloc[-1] + # 添加偏差率表格 + fivemodels = '、'.join(eval_df['模型(Model)'].values[:5]) # 字符串形式,后面写入字符串使用 + content.append(Graphs.draw_text( + f'预测使用了{num_models}个模型进行训练,使用评估结果MAE前五的模型分别是 {fivemodels} ,模型上一预测区间 {stime} -- {etime}的偏差率(%)分别是:')) + # # 添加偏差率表格 + df3 = df3.T + df3 = df3.reset_index() + data = df3.values.tolist() + col_width = 500/len(df3.columns) + content.append(Graphs.draw_table(col_width, *data)) + except: + print('偏差率计算错误,跳过') content.append(Graphs.draw_little_title('三、预测过程解析:')) # 特征、模型、参数配置 diff --git a/up_week_dates.csv b/up_week_dates.csv new file mode 100644 index 0000000..919e1c3 --- /dev/null +++ b/up_week_dates.csv @@ -0,0 +1,17 @@ +ds,NHITS,Informer,LSTM,iTransformer,TSMixer,TSMixerx,PatchTST,RNN,GRU,TCN,BiTCN,DilatedRNN,MLP,DLinear,NLinear,TFT,StemGNN,MLPMultivariate,TiDE,DeepNPTS,y,min_within_quantile,max_within_quantile,min_price,max_price,id,CREAT_DATE,序号,LOW_PRICE,HIGH_PRICE,ACCURACY +2024-12-09,71.63069,72.14122,70.19943,69.644196,71.80898,70.52471,71.53064,67.38501,70.75815,73.546684,71.92033,70.327484,71.95594,71.18323,71.882935,71.90951,73.94965,72.28691,71.25138,73.34018,72.13999938964844,67.38501,73.94965,67.38501,73.94965,51,2024-12-06,5.0,70.92,72.65,1.0 +2024-12-10,71.14343,71.9462,70.405106,69.48242,71.70601,70.66241,71.57484,67.23587,70.46323,73.37324,71.720894,70.45846,71.88132,72.34705,71.75997,72.41326,73.74943,72.31887,71.47953,72.78831,72.19000244140625,67.23587,73.74943,67.23587,73.74943,52,2024-12-06,4.0,70.73,71.77,1.0 +2024-12-11,71.71588,72.31544,70.175125,69.58213,71.59609,70.91783,71.54794,67.433334,70.518196,73.76477,71.84062,70.746284,72.27111,71.85789,70.77939,72.912704,73.91716,72.42111,71.47695,72.61624,73.5199966430664,67.433334,73.91716,67.433334,73.91716,53,2024-12-06,3.0,72.15,73.75,1.0 +2024-12-12,72.46348,71.87648,70.26041,69.922165,71.65103,70.689384,71.72716,67.54506,70.99872,73.52567,71.78495,70.777115,72.34328,72.756325,70.9607,73.391495,73.944244,72.465836,71.445244,71.69109,73.88999938964844,67.54506,73.944244,67.54506,73.944244,54,2024-12-06,2.0,72.42,74.0,0.9647113924050618 +2024-12-13,72.85073,72.36679,70.489136,69.759766,71.78641,70.69935,71.60861,67.47295,70.81146,73.85618,71.966835,70.923485,72.63866,72.29209,71.11011,73.777534,73.82516,72.58803,71.807915,72.6083,,67.47295,73.85618,67.47295,73.85618,55,2024-12-06,1.0,73.3,74.59,0.43114728682170156 +2024-12-10,71.97169,73.11586,71.41545,70.59598,71.81998,71.541016,72.397606,68.78458,71.30409,73.65467,71.78349,71.03456,71.8188,71.85696,72.014534,72.33824,73.77812,72.2061,71.36985,72.46154,72.19000244140625,68.78458,73.77812,68.78458,73.77812,56,2024-12-09,4.0,70.73,71.77,1.0 +2024-12-11,72.35509,73.087166,71.63927,70.430176,71.82658,71.76624,72.57018,68.59579,71.03884,73.45442,71.75791,71.15297,72.066956,72.28318,72.36303,72.75233,73.594864,72.24322,71.16944,72.39418,73.5199966430664,68.59579,73.594864,68.59579,73.594864,57,2024-12-09,3.0,72.15,73.75,0.9030400000000004 +2024-12-12,73.268654,73.31714,71.39498,70.65145,71.76794,71.6176,73.038315,68.85988,71.12751,73.84763,71.82094,71.41621,72.99199,72.42496,71.474464,73.17588,73.74838,72.31574,71.692444,72.48112,73.88999938964844,68.85988,73.84763,68.85988,73.84763,58,2024-12-09,2.0,72.42,74.0,0.9035632911392374 +2024-12-13,73.11824,73.70929,71.442184,70.72518,71.878975,71.82435,73.13541,68.760376,71.51967,73.62109,71.86448,71.44768,72.952484,72.107254,72.57947,73.61362,73.77317,72.3857,71.873566,72.40313,,68.760376,73.77317,68.760376,73.77317,59,2024-12-09,1.0,73.3,74.59,0.36679844961239827 +2024-12-11,72.372795,72.6546,71.37516,71.41169,71.977135,72.13211,72.774284,69.25488,71.62906,73.70365,72.27147,71.489716,72.77891,71.61932,72.090294,72.166916,73.72794,72.78662,71.960884,72.473854,73.5199966430664,69.25488,73.72794,69.25488,73.72794,61,2024-12-10,3.0,72.15,73.75,0.9862125000000024 +2024-12-12,72.91286,72.80857,71.612946,71.193184,72.05397,72.48625,72.85768,69.06833,71.35641,73.5444,72.19988,71.585815,73.47307,72.22734,72.219765,72.48997,73.544106,72.96339,71.66363,72.122116,73.88999938964844,69.06833,73.5444,69.06833,73.5444,62,2024-12-10,2.0,72.42,74.0,0.7116455696202503 +2024-12-13,73.01643,72.78796,71.3701,71.516174,72.21711,72.4036,73.66599,69.338615,71.467476,73.9247,72.2582,71.84913,73.643265,71.70761,72.41268,72.85244,73.69811,72.91225,71.75731,71.76489,,69.338615,73.9247,69.338615,73.9247,63,2024-12-10,1.0,73.3,74.59,0.4842635658914738 +2024-12-12,73.82699,73.58274,72.81162,73.18572,72.8827,72.958374,74.36039,71.09452,72.673904,74.025635,73.17438,72.72577,73.668365,72.48243,73.224655,74.286606,73.70824,73.668076,72.934685,73.16352,73.88999938964844,71.09452,74.36039,71.09452,74.36039,66,2024-12-11,2.0,72.42,74.0,1.0 +2024-12-13,74.025696,73.275696,73.0588,72.85848,72.719444,73.238945,74.05083,71.060295,72.455666,73.90173,72.966385,72.81345,73.77382,72.488365,73.78919,74.5482,73.52537,73.58226,72.8307,73.26384,,71.060295,74.5482,71.060295,74.5482,67,2024-12-11,1.0,73.3,74.59,0.9675968992247993 +2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,71,2024-12-12,1.0,73.3,74.59,0.877170542635658 +2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,76,2024-12-13,1.0,73.3,74.59,0.877170542635658