496 lines
22 KiB
Plaintext
496 lines
22 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "7fadc60c-d710-4b8c-89cd-1d889ece1eaf",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"从eta获取数据...\n",
|
||
"跳过指标 美国:东海岸地区:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 美国:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 美国:洛基山地区:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 美国:墨西哥湾沿岸:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 美国:西海岸地区:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 美国:中西部地区:炼油厂的投入与使用情况:开工率:四周均值\n",
|
||
"跳过指标 中国航班执行数/7DMA\n",
|
||
"跳过指标 美国汽油表需(周度)1周环差负值\n",
|
||
"跳过指标 美国汽油表需(周度)1周环差\n",
|
||
"跳过指标 美国汽油产量(周度)1周环差\n",
|
||
"跳过指标 美国原油周度表需/4WMA\n",
|
||
"跳过指标 美国油品表需/4WMA\n",
|
||
"跳过指标 美国油品表需4周环差\n",
|
||
"跳过指标 道琼斯旅游与休闲/标普500\n",
|
||
"跳过指标 DOE-柴油产量1周环差\n",
|
||
"跳过指标 DOE-美国汽油产量1周环差\n",
|
||
"跳过指标 美国航煤表需一周环差负值\n",
|
||
"跳过指标 美国馏分油表需一周环差负值\n",
|
||
"跳过指标 美国汽油表需一周环差负值\n",
|
||
"跳过指标 欧洲汽油裂差\n",
|
||
"跳过指标 欧洲柴油裂差\n",
|
||
"跳过指标 中国主营炼厂产能利用率1周环差\n",
|
||
"跳过指标 美国零售拥堵指数(周环比)/3WMA\n",
|
||
"跳过指标 美国炼厂原油输入量1周环差\n",
|
||
"跳过指标 美国炼厂原油输入量4周均值\n",
|
||
"跳过指标 美国柴油期货裂差\n",
|
||
"跳过指标 美国墨西哥湾柴油裂差\n",
|
||
"跳过指标 美国纽约港柴油裂差\n",
|
||
"跳过指标 欧洲柴油期货裂差\n",
|
||
"跳过指标 西北欧轻柴油裂差\n",
|
||
"跳过指标 新加坡柴油裂差\n",
|
||
"跳过指标 美国汽油期货裂差\n",
|
||
"跳过指标 美国墨西哥湾汽油裂差\n",
|
||
"跳过指标 美国纽约港汽油裂差\n",
|
||
"跳过指标 欧洲鹿特丹汽油裂差\n",
|
||
"跳过指标 新加坡汽油裂差\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 读取配置\n",
|
||
"from lib.dataread import *\n",
|
||
"from lib.tools import *\n",
|
||
"from models.nerulforcastmodels import ex_Model,model_losss,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting\n",
|
||
"\n",
|
||
"import glob\n",
|
||
"import torch\n",
|
||
"torch.set_float32_matmul_precision(\"high\")\n",
|
||
"\n",
|
||
"sqlitedb = SQLiteHandler(db_name) \n",
|
||
"sqlitedb.connect()\n",
|
||
"\n",
|
||
"signature = BinanceAPI(APPID, SECRET)\n",
|
||
"etadata = EtaReader(signature=signature,\n",
|
||
" classifylisturl = classifylisturl,\n",
|
||
" classifyidlisturl=classifyidlisturl,\n",
|
||
" edbcodedataurl=edbcodedataurl,\n",
|
||
" edbcodelist=edbcodelist,\n",
|
||
" edbdatapushurl=edbdatapushurl,\n",
|
||
" edbdeleteurl=edbdeleteurl,\n",
|
||
" edbbusinessurl=edbbusinessurl\n",
|
||
" )\n",
|
||
"# 获取数据\n",
|
||
"if is_eta:\n",
|
||
" # eta数据\n",
|
||
" logger.info('从eta获取数据...')\n",
|
||
" signature = BinanceAPI(APPID, SECRET)\n",
|
||
" etadata = EtaReader(signature=signature,\n",
|
||
" classifylisturl = classifylisturl,\n",
|
||
" classifyidlisturl=classifyidlisturl,\n",
|
||
" edbcodedataurl=edbcodedataurl,\n",
|
||
" edbcodelist=edbcodelist,\n",
|
||
" edbdatapushurl=edbdatapushurl,\n",
|
||
" edbdeleteurl=edbdeleteurl,\n",
|
||
" edbbusinessurl=edbbusinessurl,\n",
|
||
" )\n",
|
||
" df_zhibiaoshuju,df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(data_set=data_set,dataset=dataset) # 原始数据,未处理\n",
|
||
"\n",
|
||
" # 数据处理\n",
|
||
" df = datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) \n",
|
||
"\n",
|
||
"else:\n",
|
||
" logger.info('读取本地数据:'+os.path.join(dataset,data_set))\n",
|
||
" df = getdata(filename=os.path.join(dataset,data_set),y=y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) # 原始数据,未处理\n",
|
||
"\n",
|
||
"# 更改预测列名称\n",
|
||
"df.rename(columns={y:'y'},inplace=True)\n",
|
||
" \n",
|
||
"if is_edbnamelist:\n",
|
||
" df = df[edbnamelist] \n",
|
||
"df.to_csv(os.path.join(dataset,'指标数据.csv'), index=False)\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ae059224-976c-4839-b455-f81da7f25179",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 保存最新日期的y值到数据库\n",
|
||
"# 取第一行数据存储到数据库中\n",
|
||
"first_row = df[['ds','y']].tail(1)\n",
|
||
"# 将最新真实值保存到数据库\n",
|
||
"if not sqlitedb.check_table_exists('trueandpredict'):\n",
|
||
" first_row.to_sql('trueandpredict',sqlitedb.connection,index=False)\n",
|
||
"else:\n",
|
||
" for row in first_row.itertuples(index=False):\n",
|
||
" row_dict = row._asdict()\n",
|
||
" row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')\n",
|
||
" check_query = sqlitedb.select_data('trueandpredict',where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" if len(check_query) > 0:\n",
|
||
" set_clause = \", \".join([f\"{key} = '{value}'\" for key, value in row_dict.items()])\n",
|
||
" sqlitedb.update_data('trueandpredict',set_clause,where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" continue\n",
|
||
" sqlitedb.insert_data('trueandpredict',tuple(row_dict.values()),columns=row_dict.keys())\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "abb597fc-c5f3-4d76-8099-5eff358cb634",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import datetime\n",
|
||
"# 判断当前日期是不是周一\n",
|
||
"is_weekday = datetime.datetime.now().weekday() == 1\n",
|
||
"if is_weekday:\n",
|
||
" logger.info('今天是周一,更新预测模型')\n",
|
||
" # 计算最近20天预测残差最低的模型名称\n",
|
||
"\n",
|
||
" model_results = sqlitedb.select_data('trueandpredict',order_by = \"ds DESC\",limit = \"20\")\n",
|
||
" # 删除空值率为40%以上的列\n",
|
||
" print(model_results.shape)\n",
|
||
" model_results = model_results.dropna(thresh=len(model_results)*0.6,axis=1)\n",
|
||
" model_results = model_results.dropna()\n",
|
||
" print(model_results.shape)\n",
|
||
" modelnames = model_results.columns.to_list()[2:] \n",
|
||
" for col in model_results[modelnames].select_dtypes(include=['object']).columns:\n",
|
||
" model_results[col] = model_results[col].astype(np.float32)\n",
|
||
" # 计算每个预测值与真实值之间的偏差率\n",
|
||
" for model in modelnames:\n",
|
||
" model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y']\n",
|
||
"\n",
|
||
" # 获取每行对应的最小偏差率值\n",
|
||
" min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)\n",
|
||
" # 获取每行对应的最小偏差率值对应的列名\n",
|
||
" min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)\n",
|
||
" print(min_abs_error_rate_column_name)\n",
|
||
" # 将列名索引转换为列名\n",
|
||
" min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0])\n",
|
||
" # 取出现次数最多的模型名称\n",
|
||
" most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()\n",
|
||
" logger.info(f\"最近20天预测残差最低的模型名称:{most_common_model}\")\n",
|
||
"\n",
|
||
" # 保存结果到数据库\n",
|
||
" \n",
|
||
" if not sqlitedb.check_table_exists('most_model'):\n",
|
||
" sqlitedb.create_table('most_model',columns=\"ds datetime, most_common_model TEXT\")\n",
|
||
" sqlitedb.insert_data('most_model',(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),most_common_model,),columns=('ds','most_common_model',))\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ade7026e-8cf2-405f-a2da-9e90f364adab",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"if is_corr:\n",
|
||
" df = corr_feature(df=df)\n",
|
||
"\n",
|
||
"df1 = df.copy() # 备份一下,后面特征筛选完之后加入ds y 列用\n",
|
||
"logger.info(f\"开始训练模型...\")\n",
|
||
"row,col = df.shape\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dfef57d8-36da-423b-bbe7-05a13e15f71b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')\n",
|
||
"ex_Model(df,\n",
|
||
" horizon=horizon,\n",
|
||
" input_size=input_size,\n",
|
||
" train_steps=train_steps,\n",
|
||
" val_check_steps=val_check_steps,\n",
|
||
" early_stop_patience_steps=early_stop_patience_steps,\n",
|
||
" is_debug=is_debug,\n",
|
||
" dataset=dataset,\n",
|
||
" is_train=is_train,\n",
|
||
" is_fivemodels=is_fivemodels,\n",
|
||
" val_size=val_size,\n",
|
||
" test_size=test_size,\n",
|
||
" settings=settings,\n",
|
||
" now=now,\n",
|
||
" etadata = etadata,\n",
|
||
" modelsindex = modelsindex,\n",
|
||
" data = data,\n",
|
||
" is_eta=is_eta,\n",
|
||
" )\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0e5b6f30-b7ca-4718-97a3-48b54156e07f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"logger.info('模型训练完成')\n",
|
||
"# # 模型评估\n",
|
||
"\n",
|
||
"pd.set_option('display.max_columns', 100)\n",
|
||
"# 计算预测评估指数\n",
|
||
"def model_losss_juxiting(sqlitedb):\n",
|
||
" global dataset\n",
|
||
" # 数据库查询最佳模型名称\n",
|
||
" most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]\n",
|
||
" most_model_name = most_model[0]\n",
|
||
"\n",
|
||
" # 预测数据处理 predict\n",
|
||
" df_combined = loadcsv(os.path.join(dataset,\"cross_validation.csv\")) \n",
|
||
" df_combined = dateConvert(df_combined)\n",
|
||
" # 删除空列\n",
|
||
" df_combined.dropna(axis=1,inplace=True)\n",
|
||
" # 删除缺失值,预测过程不能有缺失值\n",
|
||
" df_combined.dropna(inplace=True) \n",
|
||
" # 其他列转为数值类型\n",
|
||
" df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] })\n",
|
||
" # 使用 groupby 和 transform 结合 lambda 函数来获取每个分组中 cutoff 的最小值,并创建一个新的列来存储这个最大值\n",
|
||
" df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max')\n",
|
||
"\n",
|
||
" # 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列\n",
|
||
" df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']]\n",
|
||
" # 删除模型生成的cutoff列\n",
|
||
" df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True)\n",
|
||
" # 获取模型名称\n",
|
||
" modelnames = df_combined.columns.to_list()[1:] \n",
|
||
" if 'y' in modelnames:\n",
|
||
" modelnames.remove('y')\n",
|
||
" df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要\n",
|
||
"\n",
|
||
"\n",
|
||
" # 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE\n",
|
||
" cellText = []\n",
|
||
"\n",
|
||
" # 遍历模型名称,计算模型评估指标 \n",
|
||
" for model in modelnames:\n",
|
||
" modelmse = mse(df_combined['y'], df_combined[model])\n",
|
||
" modelrmse = rmse(df_combined['y'], df_combined[model])\n",
|
||
" modelmae = mae(df_combined['y'], df_combined[model])\n",
|
||
" # modelmape = mape(df_combined['y'], df_combined[model])\n",
|
||
" # modelsmape = smape(df_combined['y'], df_combined[model])\n",
|
||
" # modelr2 = r2_score(df_combined['y'], df_combined[model])\n",
|
||
" cellText.append([model,round(modelmse, 3), round(modelrmse, 3), round(modelmae, 3)])\n",
|
||
" \n",
|
||
" model_results3 = pd.DataFrame(cellText,columns=['模型(Model)','平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)'])\n",
|
||
" # 按MSE降序排列\n",
|
||
" model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True)\n",
|
||
" model_results3.to_csv(os.path.join(dataset,\"model_evaluation.csv\"),index=False)\n",
|
||
" modelnames = model_results3['模型(Model)'].tolist()\n",
|
||
" allmodelnames = modelnames.copy()\n",
|
||
" # 保存5个最佳模型的名称\n",
|
||
" if len(modelnames) > 5:\n",
|
||
" modelnames = modelnames[0:5]\n",
|
||
" with open(os.path.join(dataset,\"best_modelnames.txt\"), 'w') as f:\n",
|
||
" f.write(','.join(modelnames) + '\\n')\n",
|
||
"\n",
|
||
"\n",
|
||
" # 去掉方差最大的模型,其余模型预测最大最小值确定通道边界\n",
|
||
" best_models = pd.read_csv(os.path.join(dataset,'best_modelnames.txt'),header=None).values.flatten().tolist()\n",
|
||
" \n",
|
||
"\n",
|
||
" # 预测值与真实值对比图\n",
|
||
" plt.rcParams['font.sans-serif'] = ['SimHei']\n",
|
||
" plt.figure(figsize=(15, 10))\n",
|
||
" # 设置有5个子图的画布\n",
|
||
" for n,model in enumerate(modelnames[:5]):\n",
|
||
" plt.subplot(3, 2, n+1)\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值')\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3[model], label=model)\n",
|
||
" plt.legend()\n",
|
||
" plt.xlabel('日期')\n",
|
||
" plt.ylabel('价格')\n",
|
||
" plt.title(model+'拟合')\n",
|
||
" plt.subplots_adjust(hspace=0.5)\n",
|
||
" plt.savefig(os.path.join(dataset,'预测值与真实值对比图.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" \n",
|
||
" # 历史数据+预测数据\n",
|
||
" # 拼接未来时间预测\n",
|
||
" df_predict = loadcsv(os.path.join(dataset,'predict.csv'))\n",
|
||
" df_predict.drop('unique_id',inplace=True,axis=1)\n",
|
||
" df_predict.dropna(axis=1,inplace=True)\n",
|
||
" df_predict2 = df_predict.copy()\n",
|
||
" try:\n",
|
||
" df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y-%m-%d')\n",
|
||
" except ValueError :\n",
|
||
" df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y/%m/%d')\n",
|
||
"\n",
|
||
" # 取第一行数据存储到数据库中\n",
|
||
" first_row = df_predict.head(1)\n",
|
||
" first_row['ds'] = first_row['ds'].dt.strftime('%Y-%m-%d 00:00:00')\n",
|
||
"\n",
|
||
" # # 将预测结果保存到数据库\n",
|
||
" # # 判断表存在\n",
|
||
" if not sqlitedb.check_table_exists('testandpredict_groupby'):\n",
|
||
" df_predict2.to_sql('testandpredict_groupby',sqlitedb.connection,index=False)\n",
|
||
" else:\n",
|
||
" for row in df_predict2.itertuples(index=False):\n",
|
||
" row_dict = row._asdict()\n",
|
||
" check_query = sqlitedb.select_data('testandpredict_groupby',where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" if len(check_query) > 0:\n",
|
||
" set_clause = \", \".join([f\"{key} = '{value}'\" for key, value in row_dict.items()])\n",
|
||
" sqlitedb.update_data('testandpredict_groupby',set_clause,where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" continue\n",
|
||
" sqlitedb.insert_data('testandpredict_groupby',tuple(row_dict.values()),columns=row_dict.keys())\n",
|
||
"\n",
|
||
" df_combined3 = pd.concat([df_combined3, df_predict]).reset_index(drop=True)\n",
|
||
"\n",
|
||
" # # 判断 df 的数值列转为float\n",
|
||
" for col in df_combined3.columns:\n",
|
||
" try:\n",
|
||
" if col != 'ds':\n",
|
||
" df_combined3[col] = df_combined3[col].astype(float)\n",
|
||
" df_combined3[col] = df_combined3[col].round(2)\n",
|
||
" except ValueError:\n",
|
||
" pass\n",
|
||
" df_combined3.to_csv(os.path.join(dataset,\"df_combined3.csv\"),index=False) \n",
|
||
" df_combined3.to_sql('testandpredict_groupby', sqlitedb.connection, if_exists='replace', index=False)\n",
|
||
" df_combined3.to_csv(os.path.join(dataset,\"testandpredict_groupby.csv\"),index=False)\n",
|
||
" \n",
|
||
" \n",
|
||
" ten_models = allmodelnames\n",
|
||
" # 计算每个模型的方差\n",
|
||
" variances = df_combined3[ten_models].var()\n",
|
||
" # 找到方差最大的模型\n",
|
||
" max_variance_model = variances.idxmax()\n",
|
||
" # 打印方差最大的模型\n",
|
||
" print(\"方差最大的模型是:\", max_variance_model)\n",
|
||
" # 去掉方差最大的模型\n",
|
||
" df_combined3 = df_combined3.drop(columns=[max_variance_model])\n",
|
||
" if max_variance_model in allmodelnames:\n",
|
||
" allmodelnames.remove(max_variance_model)\n",
|
||
" df_combined3['min'] = df_combined3[allmodelnames].min(axis=1)\n",
|
||
" df_combined3['max'] = df_combined3[allmodelnames].max(axis=1)\n",
|
||
" print(df_combined3[['min','max']])\n",
|
||
" # 历史价格+预测价格\n",
|
||
" df_combined3 = df_combined3[-50:] # 取50个数据点画图\n",
|
||
" plt.figure(figsize=(20, 10))\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值',marker='o')\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3[most_model], label=most_model_name)\n",
|
||
" plt.fill_between(df_combined3['ds'], df_combined3['min'], df_combined3['max'], alpha=0.2)\n",
|
||
" plt.grid(True)\n",
|
||
" # 当前日期画竖虚线\n",
|
||
" plt.axvline(x=df_combined3['ds'].iloc[-horizon], color='r', linestyle='--')\n",
|
||
" plt.legend()\n",
|
||
" plt.xlabel('日期')\n",
|
||
" plt.ylabel('价格')\n",
|
||
"\n",
|
||
" # # 显示历史值\n",
|
||
" for i, j in zip(df_combined3['ds'][:-5], df_combined3['y'][:-5]):\n",
|
||
" plt.text(i, j, str(j), ha='center', va='bottom')\n",
|
||
" plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight')\n",
|
||
" plt.show()\n",
|
||
" plt.close()\n",
|
||
" \n",
|
||
" # 预测值表格\n",
|
||
" fig, ax = plt.subplots(figsize=(20, 6))\n",
|
||
" ax.axis('off') # 关闭坐标轴\n",
|
||
" # 数值保留2位小数\n",
|
||
" df_combined3 = df_combined3.round(2)\n",
|
||
" df_combined3 = df_combined3[-horizon:]\n",
|
||
" df_combined3['Day'] = [f'Day_{i}' for i in range(1,horizon+1)]\n",
|
||
" # Day列放到最前面\n",
|
||
" df_combined3 = df_combined3[['Day'] + list(df_combined3.columns[:-1])]\n",
|
||
" table = ax.table(cellText=df_combined3.values, colLabels=df_combined3.columns, loc='center')\n",
|
||
" #加宽表格\n",
|
||
" table.auto_set_font_size(False)\n",
|
||
" table.set_fontsize(10)\n",
|
||
"\n",
|
||
" # 设置表格样式,列数据最小的用绿色标识\n",
|
||
" plt.savefig(os.path.join(dataset,'预测值表格.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" # plt.show()\n",
|
||
" \n",
|
||
" # 可视化评估结果\n",
|
||
" plt.rcParams['font.sans-serif'] = ['SimHei']\n",
|
||
" fig, ax = plt.subplots(figsize=(20, 10))\n",
|
||
" ax.axis('off') # 关闭坐标轴\n",
|
||
" table = ax.table(cellText=model_results3.values, colLabels=model_results3.columns, loc='center')\n",
|
||
" # 加宽表格\n",
|
||
" table.auto_set_font_size(False)\n",
|
||
" table.set_fontsize(10)\n",
|
||
"\n",
|
||
" # 设置表格样式,列数据最小的用绿色标识\n",
|
||
" plt.savefig(os.path.join(dataset,'模型评估.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" return model_results3\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"logger.info('训练数据绘图ing')\n",
|
||
"model_results3 = model_losss_juxiting(sqlitedb)\n",
|
||
"\n",
|
||
"logger.info('训练数据绘图end')\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "85b557de-8235-4e27-b5b8-58b36dfe6724",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 模型报告\n",
|
||
"\n",
|
||
"logger.info('制作报告ing')\n",
|
||
"title = f'{settings}--{now}-预测报告' # 报告标题\n",
|
||
"\n",
|
||
"pp_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,\n",
|
||
" reportname=reportname,sqlitedb=sqlitedb),\n",
|
||
"\n",
|
||
"logger.info('制作报告end')\n",
|
||
"logger.info('模型训练完成')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d4129e71-ee2c-4af1-81ed-fadf14efa206",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 发送邮件\n",
|
||
"m = SendMail(\n",
|
||
" username=username,\n",
|
||
" passwd=passwd,\n",
|
||
" recv=recv,\n",
|
||
" title=title,\n",
|
||
" content=content,\n",
|
||
" file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),\n",
|
||
" ssl=ssl,\n",
|
||
")\n",
|
||
"# m.send_mail() \n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.7"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|