调整最佳模型在绘图之前

This commit is contained in:
liurui 2024-11-07 11:32:38 +08:00
parent bf831258e6
commit c8320bf849
2 changed files with 44 additions and 43 deletions

38
main.py
View File

@ -118,25 +118,25 @@ def predict_main():
row,col = df.shape row,col = df.shape
now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
ex_Model(df, # ex_Model(df,
horizon=horizon, # horizon=horizon,
input_size=input_size, # input_size=input_size,
train_steps=train_steps, # train_steps=train_steps,
val_check_steps=val_check_steps, # val_check_steps=val_check_steps,
early_stop_patience_steps=early_stop_patience_steps, # early_stop_patience_steps=early_stop_patience_steps,
is_debug=is_debug, # is_debug=is_debug,
dataset=dataset, # dataset=dataset,
is_train=is_train, # is_train=is_train,
is_fivemodels=is_fivemodels, # is_fivemodels=is_fivemodels,
val_size=val_size, # val_size=val_size,
test_size=test_size, # test_size=test_size,
settings=settings, # settings=settings,
now=now, # now=now,
etadata = etadata, # etadata = etadata,
modelsindex = modelsindex, # modelsindex = modelsindex,
data = data, # data = data,
is_eta=is_eta, # is_eta=is_eta,
) # )
logger.info('模型训练完成') logger.info('模型训练完成')

View File

@ -509,6 +509,31 @@ def model_losss_juxiting(sqlitedb):
df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要 df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要
# 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE
cellText = []
# 遍历模型名称,计算模型评估指标
for model in modelnames:
modelmse = mse(df_combined['y'], df_combined[model])
modelrmse = rmse(df_combined['y'], df_combined[model])
modelmae = mae(df_combined['y'], df_combined[model])
# modelmape = mape(df_combined['y'], df_combined[model])
# modelsmape = smape(df_combined['y'], df_combined[model])
# modelr2 = r2_score(df_combined['y'], df_combined[model])
cellText.append([model,round(modelmse, 3), round(modelrmse, 3), round(modelmae, 3)])
model_results3 = pd.DataFrame(cellText,columns=['模型(Model)','平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)'])
# 按MSE降序排列
model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True)
model_results3.to_csv(os.path.join(dataset,"model_evaluation.csv"),index=False)
modelnames = model_results3['模型(Model)'].tolist()
allmodelnames = modelnames.copy()
# 保存5个最佳模型的名称
if len(modelnames) > 5:
modelnames = modelnames[0:5]
with open(os.path.join(dataset,"best_modelnames.txt"), 'w') as f:
f.write(','.join(modelnames) + '\n')
# 使用最佳五个模型进行绘图 # 使用最佳五个模型进行绘图
best_models = pd.read_csv(os.path.join(dataset,'best_modelnames.txt'),header=None).values.flatten().tolist() best_models = pd.read_csv(os.path.join(dataset,'best_modelnames.txt'),header=None).values.flatten().tolist()
def find_min_max_within_quantile(row): def find_min_max_within_quantile(row):
@ -605,30 +630,6 @@ def model_losss_juxiting(sqlitedb):
df_combined3.to_csv(os.path.join(dataset,"testandpredict_groupby.csv"),index=False) df_combined3.to_csv(os.path.join(dataset,"testandpredict_groupby.csv"),index=False)
# 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE
cellText = []
# 遍历模型名称,计算模型评估指标
for model in modelnames:
modelmse = mse(df_combined['y'], df_combined[model])
modelrmse = rmse(df_combined['y'], df_combined[model])
modelmae = mae(df_combined['y'], df_combined[model])
# modelmape = mape(df_combined['y'], df_combined[model])
# modelsmape = smape(df_combined['y'], df_combined[model])
# modelr2 = r2_score(df_combined['y'], df_combined[model])
cellText.append([model,round(modelmse, 3), round(modelrmse, 3), round(modelmae, 3)])
model_results3 = pd.DataFrame(cellText,columns=['模型(Model)','平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)'])
# 按MSE降序排列
model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True)
model_results3.to_csv(os.path.join(dataset,"model_evaluation.csv"),index=False)
modelnames = model_results3['模型(Model)'].tolist()
allmodelnames = modelnames.copy()
# 保存5个最佳模型的名称
if len(modelnames) > 5:
modelnames = modelnames[0:5]
with open(os.path.join(dataset,"best_modelnames.txt"), 'w') as f:
f.write(','.join(modelnames) + '\n')
# 预测值与真实值对比图 # 预测值与真实值对比图
plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['font.sans-serif'] = ['SimHei']