PriceForecast/原油预测绘图调试.ipynb
2024-12-25 16:13:22 +08:00

711 lines
32 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "7fadc60c-d710-4b8c-89cd-1d889ece1eaf",
"metadata": {},
"outputs": [],
"source": [
"# 读取配置\n",
"# 父目录下的lib\n",
"from lib.dataread import *\n",
"from lib.tools import Graphs,mse,rmse,mae,exception_logger\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0e5b6f30-b7ca-4718-97a3-48b54156e07f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(51, 30)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>模型(Model)</th>\n",
" <th>平均平方误差(MSE)</th>\n",
" <th>均方根误差(RMSE)</th>\n",
" <th>平均绝对误差(MAE)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>DilatedRNN</td>\n",
" <td>1.567000</td>\n",
" <td>1.252</td>\n",
" <td>0.978</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>NLinear</td>\n",
" <td>1.905000</td>\n",
" <td>1.380</td>\n",
" <td>1.104</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>BiTCN</td>\n",
" <td>1.906000</td>\n",
" <td>1.380</td>\n",
" <td>1.042</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>PatchTST</td>\n",
" <td>1.939000</td>\n",
" <td>1.393</td>\n",
" <td>1.129</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>TiDE</td>\n",
" <td>1.967000</td>\n",
" <td>1.402</td>\n",
" <td>1.090</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>TSMixer</td>\n",
" <td>2.056000</td>\n",
" <td>1.434</td>\n",
" <td>1.111</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>RNN</td>\n",
" <td>2.101000</td>\n",
" <td>1.449</td>\n",
" <td>1.144</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>DLinear</td>\n",
" <td>2.162000</td>\n",
" <td>1.470</td>\n",
" <td>1.178</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>TFT</td>\n",
" <td>2.196000</td>\n",
" <td>1.482</td>\n",
" <td>1.137</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>FEDformer</td>\n",
" <td>2.211000</td>\n",
" <td>1.487</td>\n",
" <td>1.239</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>TCN</td>\n",
" <td>2.397000</td>\n",
" <td>1.548</td>\n",
" <td>1.276</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NHITS</td>\n",
" <td>2.454000</td>\n",
" <td>1.567</td>\n",
" <td>1.190</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>MLP</td>\n",
" <td>2.468000</td>\n",
" <td>1.571</td>\n",
" <td>1.224</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>TSMixerx</td>\n",
" <td>2.490000</td>\n",
" <td>1.578</td>\n",
" <td>1.231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Informer</td>\n",
" <td>3.095000</td>\n",
" <td>1.759</td>\n",
" <td>1.352</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>DeepNPTS</td>\n",
" <td>3.267000</td>\n",
" <td>1.808</td>\n",
" <td>1.357</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>GRU</td>\n",
" <td>5.172000</td>\n",
" <td>2.274</td>\n",
" <td>1.909</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>LSTM</td>\n",
" <td>6.844000</td>\n",
" <td>2.616</td>\n",
" <td>2.386</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>MLPMultivariate</td>\n",
" <td>8.163000</td>\n",
" <td>2.857</td>\n",
" <td>2.221</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>StemGNN</td>\n",
" <td>17.216000</td>\n",
" <td>4.149</td>\n",
" <td>3.359</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>iTransformer</td>\n",
" <td>21.568001</td>\n",
" <td>4.644</td>\n",
" <td>3.487</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 模型(Model) 平均平方误差(MSE) 均方根误差(RMSE) 平均绝对误差(MAE)\n",
"11 DilatedRNN 1.567000 1.252 0.978\n",
"14 NLinear 1.905000 1.380 1.104\n",
"10 BiTCN 1.906000 1.380 1.042\n",
"6 PatchTST 1.939000 1.393 1.129\n",
"19 TiDE 1.967000 1.402 1.090\n",
"4 TSMixer 2.056000 1.434 1.111\n",
"7 RNN 2.101000 1.449 1.144\n",
"13 DLinear 2.162000 1.470 1.178\n",
"15 TFT 2.196000 1.482 1.137\n",
"16 FEDformer 2.211000 1.487 1.239\n",
"9 TCN 2.397000 1.548 1.276\n",
"0 NHITS 2.454000 1.567 1.190\n",
"12 MLP 2.468000 1.571 1.224\n",
"5 TSMixerx 2.490000 1.578 1.231\n",
"1 Informer 3.095000 1.759 1.352\n",
"20 DeepNPTS 3.267000 1.808 1.357\n",
"8 GRU 5.172000 2.274 1.909\n",
"2 LSTM 6.844000 2.616 2.386\n",
"18 MLPMultivariate 8.163000 2.857 2.221\n",
"17 StemGNN 17.216000 4.149 3.359\n",
"3 iTransformer 21.568001 4.644 3.487"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Crude oil forecast: per-model evaluation metrics and report figures (debugging copy).\n",
"# dataset, horizon, sqlitedb, loadcsv, dateConvert, pd, np, os and plt are not defined in this\n",
"# notebook; they come from `from lib.dataread import *` in the first cell.\n",
"\n",
"\n",
"def _add_abs_error_rate(df, models):\n",
"    \"\"\"Add per-model relative error columns and the best model of each row to ``df`` (in place).\n",
"\n",
"    For every model present in ``df`` adds ``{model}_abs_error_rate`` = |y - prediction| / y,\n",
"    then ``min_abs_error_rate_prediction`` (prediction of the lowest-error model on that row)\n",
"    and ``min_abs_error_rate_column_name`` (that model's name). Rows whose errors are all NaN\n",
"    (e.g. future rows without y) get NaN in both summary columns.\n",
"    \"\"\"\n",
"    suffix = '_abs_error_rate'\n",
"    models = [model for model in models if model in df.columns]\n",
"    error_cols = [f'{model}{suffix}' for model in models]\n",
"    for model, col in zip(models, error_cols):\n",
"        df[col] = (df['y'] - df[model]).abs() / df['y']\n",
"\n",
"    errors = df[error_cols].astype(float)\n",
"    # idxmin on an all-NaN row is deprecated (slated to raise), so only rank rows that have data\n",
"    has_error = errors.notna().any(axis=1)\n",
"    best_model = pd.Series(np.nan, index=df.index, dtype=object)\n",
"    # Strip the exact suffix; split('_')[0] would truncate model names that contain '_'\n",
"    best_model.loc[has_error] = errors.loc[has_error].idxmin(axis=1).map(lambda name: name[:-len(suffix)])\n",
"    df['min_abs_error_rate_prediction'] = [\n",
"        df.at[idx, name] if isinstance(name, str) else np.nan\n",
"        for idx, name in best_model.items()\n",
"    ]\n",
"    df['min_abs_error_rate_column_name'] = best_model\n",
"\n",
"\n",
"def _plt_top10_predict_ture(sqlitedb, allmodelnames, max_points=180):\n",
"    \"\"\"Plot actual prices, the top-10 model mean and the min/max price band from the accuracy table.\n",
"\n",
"    Saves '历史价格-预测值1.png' under ``dataset``.\n",
"\n",
"    Args:\n",
"        sqlitedb: database wrapper; the ``accuracy`` table is expected to contain ds, y, id,\n",
"            CREAT_DATE, min_price and max_price.\n",
"        allmodelnames: model names ranked best-first; the first 10 found in the table are averaged.\n",
"        max_points: number of most recent rows to plot.\n",
"    \"\"\"\n",
"    df = sqlitedb.select_data(table_name='accuracy')\n",
"    # One row per CREAT_DATE, keeping the largest id. id is compared numerically because it\n",
"    # may be stored as text (as strings '10' < '9').\n",
"    df = (\n",
"        df.assign(_id_num=pd.to_numeric(df['id'], errors='coerce'))\n",
"        .sort_values(by=['CREAT_DATE', '_id_num'])\n",
"        .drop_duplicates(subset=['CREAT_DATE'], keep='last')\n",
"        .drop(columns='_id_num')\n",
"    )\n",
"    # Rows are in ascending CREAT_DATE order, so tail() keeps the most recent points and\n",
"    # the x axis runs left-to-right in time.\n",
"    df = df.tail(max_points).copy()\n",
"    top_models = [model for model in allmodelnames if model in df.columns][:10]\n",
"    df['mean'] = df[top_models].mean(axis=1)\n",
"\n",
"    fig, ax = plt.subplots(figsize=(20, 10))\n",
"    ax.plot(df['ds'], df['y'], label='真实值')\n",
"    ax.plot(df['ds'], df['mean'], color='g', linestyle='--', label='前十模型预测均值')\n",
"    ax.plot(df['ds'], df['min_price'], color='r', linestyle='--', label='min_price')\n",
"    ax.plot(df['ds'], df['max_price'], color='r', linestyle='--', label='max_price')\n",
"    ax.fill_between(df['ds'], df['max_price'], df['min_price'], alpha=0.2)\n",
"    ax.grid(True)\n",
"    # Dashed vertical line marking the current date (the row `horizon` steps from the end);\n",
"    # skipped when there are not enough rows instead of raising IndexError.\n",
"    if 0 < horizon <= len(df):\n",
"        ax.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--')\n",
"    ax.legend()\n",
"    ax.set_xlabel('日期')\n",
"    ax.set_ylabel('价格')\n",
"    fig.savefig(os.path.join(dataset, '历史价格-预测值1.png'), bbox_inches='tight')\n",
"    plt.close(fig)\n",
"\n",
"\n",
"def _plt_predict_table(df):\n",
"    \"\"\"Render the last ``horizon`` rows of ``df`` (rounded to 2 dp) as a table image.\n",
"\n",
"    A Day_1..Day_N label column is placed first. Saves '预测值表格.png' under ``dataset``.\n",
"    \"\"\"\n",
"    df = df.round(2).tail(horizon).copy()\n",
"    # Label length follows the actual row count, so fewer than `horizon` rows does not raise\n",
"    df.insert(0, 'Day', [f'Day_{i}' for i in range(1, len(df) + 1)])\n",
"    fig, ax = plt.subplots(figsize=(20, 6))\n",
"    ax.axis('off')\n",
"    table = ax.table(cellText=df.values, colLabels=df.columns, loc='center')\n",
"    table.auto_set_font_size(False)\n",
"    table.set_fontsize(10)\n",
"    # TODO: highlight the smallest value of each column in green (intended, not implemented)\n",
"    fig.savefig(os.path.join(dataset, '预测值表格.png'), bbox_inches='tight')\n",
"    plt.close(fig)\n",
"\n",
"\n",
"def _plt_model_results3(model_results3):\n",
"    \"\"\"Render the model evaluation table as an image ('模型评估.png' under ``dataset``).\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(20, 10))\n",
"    ax.axis('off')\n",
"    table = ax.table(cellText=model_results3.values, colLabels=model_results3.columns, loc='center')\n",
"    table.auto_set_font_size(False)\n",
"    table.set_fontsize(10)\n",
"    # TODO: highlight the smallest value of each column in green (intended, not implemented)\n",
"    fig.savefig(os.path.join(dataset, '模型评估.png'), bbox_inches='tight')\n",
"    plt.close(fig)\n",
"\n",
"\n",
"@exception_logger\n",
"def model_losss(sqlitedb, end_time):\n",
"    \"\"\"Evaluate each model's cross-validation predictions and regenerate the report outputs.\n",
"\n",
"    Steps: compute MSE/RMSE/MAE per model from cross_validation.csv; append the rows of\n",
"    predict.csv to the history read from the ``accuracy`` table; add per-model relative\n",
"    error columns; write ``testandpredict_groupby`` (CSV and DB table); draw three figures.\n",
"    Saving predictions back to ``accuracy`` and the weekly accuracy-rate step are disabled\n",
"    in this debugging copy.\n",
"\n",
"    Args:\n",
"        sqlitedb: database wrapper providing ``select_data``, ``drop_table`` and ``connection``.\n",
"        end_time: 'YYYY-MM-DD' run date; not used by this debugging copy, accepted so\n",
"            existing calls keep working.\n",
"\n",
"    Returns:\n",
"        DataFrame of MSE/RMSE/MAE per model, sorted by MSE ascending (best model first).\n",
"    \"\"\"\n",
"    # ---- Error metrics from the cross-validation predictions ----\n",
"    df_combined = dateConvert(loadcsv(os.path.join(dataset, 'cross_validation.csv')))\n",
"    # Drop columns with any missing value, then rows with any missing value: metrics need complete data\n",
"    df_combined = df_combined.dropna(axis=1).dropna()\n",
"    # Every column except the cutoff/ds dates is numeric\n",
"    df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff', 'ds']})\n",
"    # Each ds is predicted from several cutoffs; keep only the prediction made from the latest cutoff\n",
"    latest_cutoff = df_combined.groupby('ds')['cutoff'].transform('max')\n",
"    df_combined = df_combined[df_combined['cutoff'] == latest_cutoff].drop(columns=['cutoff'])\n",
"    # Model columns are everything except the date and the ground truth\n",
"    modelnames = [col for col in df_combined.columns if col not in ('ds', 'y')]\n",
"\n",
"    cellText = [\n",
"        [\n",
"            model,\n",
"            round(mse(df_combined['y'], df_combined[model]), 3),\n",
"            round(rmse(df_combined['y'], df_combined[model]), 3),\n",
"            round(mae(df_combined['y'], df_combined[model]), 3),\n",
"        ]\n",
"        for model in modelnames\n",
"    ]\n",
"    model_results3 = pd.DataFrame(cellText, columns=['模型(Model)', '平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)'])\n",
"    # Sort by MSE ascending: best model first\n",
"    model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True)\n",
"    model_results3.to_csv(os.path.join(dataset, 'model_evaluation.csv'), index=False)\n",
"    # Model names ranked best-first by MSE\n",
"    allmodelnames = model_results3['模型(Model)'].tolist()\n",
"\n",
"    # ---- History (accuracy table) + future predictions ----\n",
"    df_predict = pd.read_csv(os.path.join(dataset, 'predict.csv'))\n",
"    df_predict = df_predict.drop(columns='unique_id').dropna(axis=1)\n",
"    try:\n",
"        df_predict['ds'] = pd.to_datetime(df_predict['ds'], format=r'%Y-%m-%d')\n",
"    except ValueError:\n",
"        df_predict['ds'] = pd.to_datetime(df_predict['ds'], format=r'%Y/%m/%d')\n",
"    df_combined3 = pd.concat([sqlitedb.select_data('accuracy'), df_predict]).reset_index(drop=True)\n",
"    # After the concat ds may mix strings and timestamps; normalise to 'YYYY-MM-DD' strings\n",
"    df_combined3['ds'] = pd.to_datetime(df_combined3['ds']).dt.strftime('%Y-%m-%d')\n",
"\n",
"    _add_abs_error_rate(df_combined3, allmodelnames)\n",
"\n",
"    # Cast every column except ds to float rounded to 2 dp; columns that cannot be cast\n",
"    # (e.g. date strings, model-name strings) are left unchanged\n",
"    for col in df_combined3.columns:\n",
"        if col == 'ds':\n",
"            continue\n",
"        try:\n",
"            df_combined3[col] = df_combined3[col].astype(float).round(2)\n",
"        except (ValueError, TypeError):\n",
"            pass\n",
"    df_combined3.to_csv(os.path.join(dataset, 'testandpredict_groupby.csv'), index=False)\n",
"\n",
"    # Replace the history + prediction table in the database\n",
"    sqlitedb.drop_table('testandpredict_groupby')\n",
"    df_combined3.to_sql('testandpredict_groupby', sqlitedb.connection, index=False)\n",
"    # Mean of all evaluated models (included in the prediction table image)\n",
"    df_combined3['mean'] = df_combined3[[m for m in allmodelnames if m in df_combined3.columns]].mean(axis=1)\n",
"\n",
"    # ---- Figures ----\n",
"    # Set the CJK font before the first figure, otherwise its Chinese labels do not render on a fresh kernel\n",
"    plt.rcParams['font.sans-serif'] = ['SimHei']\n",
"    _plt_top10_predict_ture(sqlitedb, allmodelnames)\n",
"    _plt_predict_table(df_combined3)\n",
"    _plt_model_results3(model_results3)\n",
"\n",
"    return model_results3\n",
"\n",
"\n",
"model_losss(sqlitedb=sqlitedb, end_time='2024-12-16')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ce1967f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}