添加异常装饰器

This commit is contained in:
liurui 2024-12-18 17:49:23 +08:00
parent 34c4a9e205
commit 9de1b4a857
4 changed files with 78 additions and 32 deletions

View File

@ -223,7 +223,7 @@ table_name = 'v_tbl_crude_oil_warning'
### 开关
is_train = False # 是否训练
is_train = True # 是否训练
is_debug = False # 是否调试
is_eta = False # 是否使用eta接口
is_timefurture = True # 是否使用时间特征

View File

@ -512,5 +512,18 @@ class MySQLDB:
self.connection.close()
logging.info("Database connection closed.")
def exception_logger(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
# 记录异常日志
logging.error(f"An error occurred in function {func.__name__}: {str(e)}")
# 可以选择重新抛出异常,或者在这里处理异常
raise e # 重新抛出异常
return wrapper
if __name__ == '__main__':
print('This is a tool, not a script.')

View File

@ -1,6 +1,6 @@
# 读取配置
from lib.dataread import *
from lib.tools import SendMail
from lib.tools import SendMail,exception_logger
from models.nerulforcastmodels import ex_Model,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting
import glob
@ -206,14 +206,14 @@ def predict_main(end_time):
logger.info('训练数据绘图end')
# 模型报告
logger.info('制作报告ing')
title = f'{settings}--{now}-预测报告' # 报告标题
# logger.info('制作报告ing')
# title = f'{settings}--{now}-预测报告' # 报告标题
brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,
reportname=reportname,sqlitedb=sqlitedb),
# brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,
# reportname=reportname,sqlitedb=sqlitedb),
logger.info('制作报告end')
logger.info('模型训练完成')
# logger.info('制作报告end')
# logger.info('模型训练完成')
# # LSTM 单变量模型
# ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
@ -234,12 +234,12 @@ def predict_main(end_time):
file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
ssl=ssl,
)
m.send_mail()
# m.send_mail()
if __name__ == '__main__':
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
for i_time in pd.date_range('2024-11-22', '2024-12-16', freq='B'):
for i_time in pd.date_range('2024-12-02', '2024-12-16', freq='B'):
end_time = i_time.strftime('%Y-%m-%d')
# print(e_time)
predict_main(end_time)

View File

@ -6,7 +6,7 @@ import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import datetime
from lib.tools import Graphs,mse,rmse,mae
from lib.tools import Graphs,mse,rmse,mae,exception_logger
from lib.dataread import *
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer
@ -36,7 +36,7 @@ from reportlab.lib.units import cm # 单位cm
pdfmetrics.registerFont(TTFont('SimSun', 'SimSun.ttf'))
@exception_logger
def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patience_steps,
is_debug,dataset,is_train,is_fivemodels,val_size,test_size,settings,now,
etadata,modelsindex,data,is_eta):
@ -222,6 +222,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
# 原油计算预测评估指数
@exception_logger
def model_losss(sqlitedb,end_time):
global dataset
global rote
@ -411,7 +412,15 @@ def model_losss(sqlitedb,end_time):
df_predict2['id'] = range(1, 1 + len(df_predict2))
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
df_predict2['CREAT_DATE'] = end_time
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
def get_common_columns(df1, df2):
# 获取两个DataFrame的公共列名
return list(set(df1.columns).intersection(df2.columns))
common_columns = get_common_columns(df_predict2, existing_data)
try:
df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
except:
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
# 更新accuracy表中的y值
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
@ -425,15 +434,21 @@ def model_losss(sqlitedb,end_time):
logger.error(f'更新accuracy表中的y值失败row={row}')
# 上周准确率计算
predict_y = sqlitedb.select_data(table_name = "accuracy")
ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
# 模型评估前五最大最小
# ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
ids = predict_y['id'].tolist()
# 准确率基准与绘图上下界逻辑一致
# predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
# 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1)
# 模型评估前五均值
# predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
# predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
# 模型评估前十均值
predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1)
predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1)
# 模型评估前十均值
# predict_y['min_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) -1
# predict_y['max_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) +1
# 模型评估前十最大最小
# allmodelnames 和 predict_y 列 重复的
allmodelnames = [col for col in allmodelnames if col in predict_y.columns]
predict_y['min_price'] = predict_y[allmodelnames[0:10]].min(axis=1)
predict_y['max_price'] = predict_y[allmodelnames[0:10]].max(axis=1)
for id in ids:
row = predict_y[predict_y['id'] == id]
try:
@ -448,25 +463,37 @@ def model_losss(sqlitedb,end_time):
df = pd.merge(predict_y,df2,on=['ds'],how='left')
df['ds'] = pd.to_datetime(df['ds'])
df = df.reindex()
# 判断预测值在不在布伦特最高最低价范围内准确率为1否则为0
def is_within_range(row):
for model in allmodelnames:
if row['LOW_PRICE'] <= row[col] <= row['HIGH_PRICE']:
return 1
else:
return 0
# 比较真实最高最低,和预测最高最低 计算准确率
def calculate_accuracy(row):
# 子集情况:
if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \
(row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']):
# 子集情况:
if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \
(row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']):
return 1
# 无交集情况:
if row['max_price'] < row['LOW_PRICE'] or \
row['min_price'] > row['HIGH_PRICE']:
return 0
# 有交集情况:
if row['HIGH_PRICE'] > row['min_price'] or \
row['max_price'] > row['LOW_PRICE']:
else:
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
middle_diff = sorted_prices[2] - sorted_prices[1]
price_range = row['HIGH_PRICE'] - row['LOW_PRICE']
accuracy = middle_diff / price_range
return accuracy
# 无交集情况:
else:
return 0
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
df[columns] = df[columns].astype(float)
df['ACCURACY'] = df.apply(calculate_accuracy, axis=1)
# df['ACCURACY'] = df.apply(is_within_range, axis=1)
# 取结束日期上一周的日期
endtime = end_time
endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d')
@ -477,12 +504,15 @@ def model_losss(sqlitedb,end_time):
df3 = df.copy()
df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)]
df3 = df3[df3['ds'].isin(up_week_dates)]
# df3.to_csv(os.path.join(dataset,f'accuracy_{endtime}.csv'),index=False)
total = len(df3)
accuracy_rote = 0
# 设置权重字典
weight_dict = [0.4,0.15,0.1,0.1,0.25]
for i,group in df3.groupby('ds'):
print('权重:',round(len(group)/total,2))
print('准确率:',group['ACCURACY'].sum()/(len(group)/total))
accuracy_rote += group['ACCURACY'].sum()/(len(group)/total)
print('权重:',weight_dict[len(group)-1])
print('准确率:',(group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1])
accuracy_rote += (group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1]
df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率'])
df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote}
print(df4)
@ -598,6 +628,7 @@ def model_losss(sqlitedb,end_time):
# 聚烯烃计算预测评估指数
@exception_logger
def model_losss_juxiting(sqlitedb):
global dataset
global rote
@ -857,7 +888,7 @@ def model_losss_juxiting(sqlitedb):
import matplotlib.dates as mdates
@exception_logger
def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'):
global y
# 创建内容对应的空列表
@ -1138,6 +1169,7 @@ def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inpu
except TimeoutError as e:
print(f"请求超时: {e}")
@exception_logger
def pp_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'):
global y
# 创建内容对应的空列表
@ -1748,7 +1780,8 @@ def pp_export_pdf_v1(num_indicators=475,num_models=21, num_dayindicator=202,inpu
upload_report_data(token, upload_data)
except TimeoutError as e:
print(f"请求超时: {e}")
@exception_logger
def tansuanli_export_pdf(num_indicators=475,num_models=22, num_dayindicator=202,inputsize=5,dataset='dataset',y='电碳价格',end_time='2024-07-30',reportname='tansuanli.pdf'):
# 创建内容对应的空列表
content = list()