添加异常装饰器
This commit is contained in:
parent
34c4a9e205
commit
9de1b4a857
@ -223,7 +223,7 @@ table_name = 'v_tbl_crude_oil_warning'
|
||||
|
||||
|
||||
### 开关
|
||||
is_train = False # 是否训练
|
||||
is_train = True # 是否训练
|
||||
is_debug = False # 是否调试
|
||||
is_eta = False # 是否使用eta接口
|
||||
is_timefurture = True # 是否使用时间特征
|
||||
|
13
lib/tools.py
13
lib/tools.py
@ -512,5 +512,18 @@ class MySQLDB:
|
||||
self.connection.close()
|
||||
logging.info("Database connection closed.")
|
||||
|
||||
def exception_logger(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# 记录异常日志
|
||||
logging.error(f"An error occurred in function {func.__name__}: {str(e)}")
|
||||
# 可以选择重新抛出异常,或者在这里处理异常
|
||||
raise e # 重新抛出异常
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('This is a tool, not a script.')
|
@ -1,6 +1,6 @@
|
||||
# 读取配置
|
||||
from lib.dataread import *
|
||||
from lib.tools import SendMail
|
||||
from lib.tools import SendMail,exception_logger
|
||||
from models.nerulforcastmodels import ex_Model,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting
|
||||
|
||||
import glob
|
||||
@ -206,14 +206,14 @@ def predict_main(end_time):
|
||||
logger.info('训练数据绘图end')
|
||||
|
||||
# 模型报告
|
||||
logger.info('制作报告ing')
|
||||
title = f'{settings}--{now}-预测报告' # 报告标题
|
||||
# logger.info('制作报告ing')
|
||||
# title = f'{settings}--{now}-预测报告' # 报告标题
|
||||
|
||||
brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,
|
||||
reportname=reportname,sqlitedb=sqlitedb),
|
||||
# brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,
|
||||
# reportname=reportname,sqlitedb=sqlitedb),
|
||||
|
||||
logger.info('制作报告end')
|
||||
logger.info('模型训练完成')
|
||||
# logger.info('制作报告end')
|
||||
# logger.info('模型训练完成')
|
||||
|
||||
# # LSTM 单变量模型
|
||||
# ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
|
||||
@ -234,12 +234,12 @@ def predict_main(end_time):
|
||||
file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
|
||||
ssl=ssl,
|
||||
)
|
||||
m.send_mail()
|
||||
# m.send_mail()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
|
||||
for i_time in pd.date_range('2024-11-22', '2024-12-16', freq='B'):
|
||||
for i_time in pd.date_range('2024-12-02', '2024-12-16', freq='B'):
|
||||
end_time = i_time.strftime('%Y-%m-%d')
|
||||
# print(e_time)
|
||||
predict_main(end_time)
|
@ -6,7 +6,7 @@ import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
import datetime
|
||||
from lib.tools import Graphs,mse,rmse,mae
|
||||
from lib.tools import Graphs,mse,rmse,mae,exception_logger
|
||||
from lib.dataread import *
|
||||
from neuralforecast import NeuralForecast
|
||||
from neuralforecast.models import NHITS,Informer, NBEATSx,LSTM,PatchTST, iTransformer, TSMixer
|
||||
@ -36,7 +36,7 @@ from reportlab.lib.units import cm # 单位:cm
|
||||
pdfmetrics.registerFont(TTFont('SimSun', 'SimSun.ttf'))
|
||||
|
||||
|
||||
|
||||
@exception_logger
|
||||
def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patience_steps,
|
||||
is_debug,dataset,is_train,is_fivemodels,val_size,test_size,settings,now,
|
||||
etadata,modelsindex,data,is_eta):
|
||||
@ -222,6 +222,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
|
||||
|
||||
|
||||
# 原油计算预测评估指数
|
||||
@exception_logger
|
||||
def model_losss(sqlitedb,end_time):
|
||||
global dataset
|
||||
global rote
|
||||
@ -411,7 +412,15 @@ def model_losss(sqlitedb,end_time):
|
||||
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
||||
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
|
||||
df_predict2['CREAT_DATE'] = end_time
|
||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||
def get_common_columns(df1, df2):
|
||||
# 获取两个DataFrame的公共列名
|
||||
return list(set(df1.columns).intersection(df2.columns))
|
||||
|
||||
common_columns = get_common_columns(df_predict2, existing_data)
|
||||
try:
|
||||
df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||
except:
|
||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||
|
||||
# 更新accuracy表中的y值
|
||||
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
|
||||
@ -425,15 +434,21 @@ def model_losss(sqlitedb,end_time):
|
||||
logger.error(f'更新accuracy表中的y值失败,row={row}')
|
||||
# 上周准确率计算
|
||||
predict_y = sqlitedb.select_data(table_name = "accuracy")
|
||||
ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
|
||||
# 模型评估前五最大最小
|
||||
# ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
|
||||
ids = predict_y['id'].tolist()
|
||||
# 准确率基准与绘图上下界逻辑一致
|
||||
# predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
|
||||
# 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1)
|
||||
# 模型评估前五均值
|
||||
# predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
|
||||
# predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
|
||||
# 模型评估前十均值
|
||||
predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1)
|
||||
predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1)
|
||||
# 模型评估前十均值
|
||||
# predict_y['min_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) -1
|
||||
# predict_y['max_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) +1
|
||||
# 模型评估前十最大最小
|
||||
# allmodelnames 和 predict_y 列 重复的
|
||||
allmodelnames = [col for col in allmodelnames if col in predict_y.columns]
|
||||
predict_y['min_price'] = predict_y[allmodelnames[0:10]].min(axis=1)
|
||||
predict_y['max_price'] = predict_y[allmodelnames[0:10]].max(axis=1)
|
||||
for id in ids:
|
||||
row = predict_y[predict_y['id'] == id]
|
||||
try:
|
||||
@ -448,25 +463,37 @@ def model_losss(sqlitedb,end_time):
|
||||
df = pd.merge(predict_y,df2,on=['ds'],how='left')
|
||||
df['ds'] = pd.to_datetime(df['ds'])
|
||||
df = df.reindex()
|
||||
|
||||
# 判断预测值在不在布伦特最高最低价范围内,准确率为1,否则为0
|
||||
def is_within_range(row):
|
||||
for model in allmodelnames:
|
||||
if row['LOW_PRICE'] <= row[col] <= row['HIGH_PRICE']:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
# 比较真实最高最低,和预测最高最低 计算准确率
|
||||
def calculate_accuracy(row):
|
||||
# 子集情况:
|
||||
if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \
|
||||
(row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']):
|
||||
# 全子集情况:
|
||||
if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \
|
||||
(row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']):
|
||||
return 1
|
||||
# 无交集情况:
|
||||
if row['max_price'] < row['LOW_PRICE'] or \
|
||||
row['min_price'] > row['HIGH_PRICE']:
|
||||
return 0
|
||||
# 有交集情况:
|
||||
if row['HIGH_PRICE'] > row['min_price'] or \
|
||||
row['max_price'] > row['LOW_PRICE']:
|
||||
else:
|
||||
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
|
||||
middle_diff = sorted_prices[2] - sorted_prices[1]
|
||||
price_range = row['HIGH_PRICE'] - row['LOW_PRICE']
|
||||
accuracy = middle_diff / price_range
|
||||
return accuracy
|
||||
# 无交集情况:
|
||||
else:
|
||||
return 0
|
||||
|
||||
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
|
||||
df[columns] = df[columns].astype(float)
|
||||
df['ACCURACY'] = df.apply(calculate_accuracy, axis=1)
|
||||
# df['ACCURACY'] = df.apply(is_within_range, axis=1)
|
||||
# 取结束日期上一周的日期
|
||||
endtime = end_time
|
||||
endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d')
|
||||
@ -477,12 +504,15 @@ def model_losss(sqlitedb,end_time):
|
||||
df3 = df.copy()
|
||||
df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)]
|
||||
df3 = df3[df3['ds'].isin(up_week_dates)]
|
||||
# df3.to_csv(os.path.join(dataset,f'accuracy_{endtime}.csv'),index=False)
|
||||
total = len(df3)
|
||||
accuracy_rote = 0
|
||||
# 设置权重字典
|
||||
weight_dict = [0.4,0.15,0.1,0.1,0.25]
|
||||
for i,group in df3.groupby('ds'):
|
||||
print('权重:',round(len(group)/total,2))
|
||||
print('准确率:',group['ACCURACY'].sum()/(len(group)/total))
|
||||
accuracy_rote += group['ACCURACY'].sum()/(len(group)/total)
|
||||
print('权重:',weight_dict[len(group)-1])
|
||||
print('准确率:',(group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1])
|
||||
accuracy_rote += (group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1]
|
||||
df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率'])
|
||||
df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote}
|
||||
print(df4)
|
||||
@ -598,6 +628,7 @@ def model_losss(sqlitedb,end_time):
|
||||
|
||||
|
||||
# 聚烯烃计算预测评估指数
|
||||
@exception_logger
|
||||
def model_losss_juxiting(sqlitedb):
|
||||
global dataset
|
||||
global rote
|
||||
@ -857,7 +888,7 @@ def model_losss_juxiting(sqlitedb):
|
||||
|
||||
|
||||
import matplotlib.dates as mdates
|
||||
|
||||
@exception_logger
|
||||
def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'):
|
||||
global y
|
||||
# 创建内容对应的空列表
|
||||
@ -1138,6 +1169,7 @@ def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inpu
|
||||
except TimeoutError as e:
|
||||
print(f"请求超时: {e}")
|
||||
|
||||
@exception_logger
|
||||
def pp_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inputsize=5,dataset='dataset',time = '2024-07-30',reportname='report.pdf',sqlitedb='jbsh_yuanyou.db'):
|
||||
global y
|
||||
# 创建内容对应的空列表
|
||||
@ -1748,7 +1780,8 @@ def pp_export_pdf_v1(num_indicators=475,num_models=21, num_dayindicator=202,inpu
|
||||
upload_report_data(token, upload_data)
|
||||
except TimeoutError as e:
|
||||
print(f"请求超时: {e}")
|
||||
|
||||
|
||||
@exception_logger
|
||||
def tansuanli_export_pdf(num_indicators=475,num_models=22, num_dayindicator=202,inputsize=5,dataset='dataset',y='电碳价格',end_time='2024-07-30',reportname='tansuanli.pdf'):
|
||||
# 创建内容对应的空列表
|
||||
content = list()
|
||||
|
Loading…
Reference in New Issue
Block a user