From 9de1b4a8570d95871bd4bc26ca0b87797901325f Mon Sep 17 00:00:00 2001 From: liurui Date: Wed, 18 Dec 2024 17:49:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=BC=82=E5=B8=B8=E8=A3=85?= =?UTF-8?q?=E9=A5=B0=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_jingbo.py | 2 +- lib/tools.py | 13 ++++++ main_yuanyou.py | 18 ++++----- models/nerulforcastmodels.py | 77 +++++++++++++++++++++++++----------- 4 files changed, 78 insertions(+), 32 deletions(-) diff --git a/config_jingbo.py b/config_jingbo.py index 0fa4074..c47bfac 100644 --- a/config_jingbo.py +++ b/config_jingbo.py @@ -223,7 +223,7 @@ table_name = 'v_tbl_crude_oil_warning' ### 开关 -is_train = False # 是否训练 +is_train = True # 是否训练 is_debug = False # 是否调试 is_eta = False # 是否使用eta接口 is_timefurture = True # 是否使用时间特征 diff --git a/lib/tools.py b/lib/tools.py index 6a79bf8..5584f84 100644 --- a/lib/tools.py +++ b/lib/tools.py @@ -512,5 +512,18 @@ class MySQLDB: self.connection.close() logging.info("Database connection closed.") +def exception_logger(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + # 记录异常日志 + logging.error(f"An error occurred in function {func.__name__}: {str(e)}") + # 可以选择重新抛出异常,或者在这里处理异常 + raise e # 重新抛出异常 + return wrapper + + + if __name__ == '__main__': print('This is a tool, not a script.') \ No newline at end of file diff --git a/main_yuanyou.py b/main_yuanyou.py index 8e3788a..165ff37 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -1,6 +1,6 @@ # 读取配置 from lib.dataread import * -from lib.tools import SendMail +from lib.tools import SendMail,exception_logger from models.nerulforcastmodels import ex_Model,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting import glob @@ -206,14 +206,14 @@ def predict_main(end_time): logger.info('训练数据绘图end') # 模型报告 - logger.info('制作报告ing') - title = f'{settings}--{now}-预测报告' # 报告标题 + # logger.info('制作报告ing') + # title = f'{settings}--{now}-预测报告' # 报告标题 - brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time, - reportname=reportname,sqlitedb=sqlitedb), + # brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time, + # reportname=reportname,sqlitedb=sqlitedb), - logger.info('制作报告end') - logger.info('模型训练完成') + # logger.info('制作报告end') + # logger.info('模型训练完成') # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) @@ -234,12 +234,12 @@ def predict_main(end_time): file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), ssl=ssl, ) - m.send_mail() + # m.send_mail() if __name__ == '__main__': # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2024-11-22', '2024-12-16', freq='B'): + for i_time in pd.date_range('2024-12-02', '2024-12-16', freq='B'): end_time = i_time.strftime('%Y-%m-%d') # print(e_time) predict_main(end_time) \ No newline at end of file diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 6fa1757..44e4c2a 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -6,7 +6,7 @@ import seaborn as sns import matplotlib.pyplot as plt import matplotlib.dates as mdates import datetime -from lib.tools import Graphs,mse,rmse,mae +from lib.tools import Graphs,mse,rmse,mae,exception_logger from lib.dataread import * from neuralforecast import NeuralForecast from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer @@ -36,7 +36,7 @@ from reportlab.lib.units import cm # 单位:cm pdfmetrics.registerFont(TTFont('SimSun', 'SimSun.ttf')) - +@exception_logger def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patience_steps, is_debug,dataset,is_train,is_fivemodels,val_size,test_size,settings,now, etadata,modelsindex,data,is_eta): @@ -222,6 +222,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien # 原油计算预测评估指数 +@exception_logger def model_losss(sqlitedb,end_time): global dataset global rote @@ -411,7 +412,15 @@ def model_losss(sqlitedb,end_time): df_predict2['id'] = range(1, 1 + len(df_predict2)) # df_predict2['CREAT_DATE'] = now if end_time == '' else end_time df_predict2['CREAT_DATE'] = end_time - df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) + def get_common_columns(df1, df2): + # 获取两个DataFrame的公共列名 + return list(set(df1.columns).intersection(df2.columns)) + + common_columns = get_common_columns(df_predict2, existing_data) + try: + df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) + except: + df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) # 更新accuracy表中的y值 update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null') @@ -425,15 +434,21 @@ def model_losss(sqlitedb,end_time): logger.error(f'更新accuracy表中的y值失败,row={row}') # 上周准确率计算 predict_y = sqlitedb.select_data(table_name = "accuracy") - ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist() - # 模型评估前五最大最小 + # ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist() + ids = predict_y['id'].tolist() + # 准确率基准与绘图上下界逻辑一致 # predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']] - # 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1) + # 模型评估前五均值 # predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1 # predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1 - # 模型评估前十均值 - predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1) - predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1) + # 模型评估前十均值 + # predict_y['min_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) -1 + # predict_y['max_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) +1 + # 模型评估前十最大最小 + # allmodelnames 和 predict_y 列 重复的 + allmodelnames = [col for col in allmodelnames if col in predict_y.columns] + predict_y['min_price'] = predict_y[allmodelnames[0:10]].min(axis=1) + predict_y['max_price'] = predict_y[allmodelnames[0:10]].max(axis=1) for id in ids: row = predict_y[predict_y['id'] == id] try: @@ -448,25 +463,37 @@ def model_losss(sqlitedb,end_time): df = pd.merge(predict_y,df2,on=['ds'],how='left') df['ds'] = pd.to_datetime(df['ds']) df = df.reindex() + + # 判断预测值在不在布伦特最高最低价范围内,准确率为1,否则为0 + def is_within_range(row): + for model in allmodelnames: + if row['LOW_PRICE'] <= row[col] <= row['HIGH_PRICE']: + return 1 + else: + return 0 + + # 比较真实最高最低,和预测最高最低 计算准确率 def calculate_accuracy(row): - # 子集情况: - if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \ - (row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']): + # 全子集情况: + if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \ + (row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']): return 1 + # 无交集情况: + if row['max_price'] < row['LOW_PRICE'] or \ + row['min_price'] > row['HIGH_PRICE']: + return 0 # 有交集情况: - if row['HIGH_PRICE'] > row['min_price'] or \ - row['max_price'] > row['LOW_PRICE']: + else: sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']]) middle_diff = sorted_prices[2] - sorted_prices[1] price_range = row['HIGH_PRICE'] - row['LOW_PRICE'] accuracy = middle_diff / price_range return accuracy - # 无交集情况: - else: - return 0 + columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price'] df[columns] = df[columns].astype(float) df['ACCURACY'] = df.apply(calculate_accuracy, axis=1) + # df['ACCURACY'] = df.apply(is_within_range, axis=1) # 取结束日期上一周的日期 endtime = end_time endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d') @@ -477,12 +504,15 @@ def model_losss(sqlitedb,end_time): df3 = df.copy() df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)] df3 = df3[df3['ds'].isin(up_week_dates)] + # df3.to_csv(os.path.join(dataset,f'accuracy_{endtime}.csv'),index=False) total = len(df3) accuracy_rote = 0 + # 设置权重字典 + weight_dict = [0.4,0.15,0.1,0.1,0.25] for i,group in df3.groupby('ds'): - print('权重:',round(len(group)/total,2)) - print('准确率:',group['ACCURACY'].sum()/(len(group)/total)) - accuracy_rote += group['ACCURACY'].sum()/(len(group)/total) + print('权重:',weight_dict[len(group)-1]) + print('准确率:',(group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1]) + accuracy_rote += (group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1] df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率']) df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote} print(df4) @@ -598,6 +628,7 @@ def model_losss(sqlitedb,end_time): # 聚烯烃计算预测评估指数 +@exception_logger def model_losss_juxiting(sqlitedb): global dataset global rote @@ -857,7 +888,7 @@ def model_losss_juxiting(sqlitedb): import matplotlib.dates as mdates - +@exception_logger def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'): global y # 创建内容对应的空列表 @@ -1138,6 +1169,7 @@ def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inpu except TimeoutError as e: print(f"请求超时: {e}") +@exception_logger def pp_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'): global y # 创建内容对应的空列表 @@ -1748,7 +1780,8 @@ def pp_export_pdf_v1(num_indicators=475,num_models=21, num_dayindicator=202,inpu upload_report_data(token, upload_data) except TimeoutError as e: print(f"请求超时: {e}") - + +@exception_logger def tansuanli_export_pdf(num_indicators=475,num_models=22, num_dayindicator=202,inputsize=5,dataset='dataset',y='电碳价格',end_time='2024-07-30',reportname='tansuanli.pdf'): # 创建内容对应的空列表 content = list()