552 lines
30 KiB
Plaintext
552 lines
30 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"id": "7fadc60c-d710-4b8c-89cd-1d889ece1eaf",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"从eta获取数据...\n",
|
||
"从eta获取数据...\n",
|
||
"['ID01385938', 'lmcads03 lme comdty', 'GC1 COMB Comdty', 'C2404171822', 'dxy curncy', 'S5443199 ', 'S5479800', 'S5443108', 'H7358586', 'LC3FM1 INDEX', 'CNY REGN Curncy', 's0105897', 'M0067419', 'M0066351', 'S0266372', 'S0266438', 'S0266506']\n",
|
||
"['ID01385938', 'lmcads03 lme comdty', 'GC1 COMB Comdty', 'C2404171822', 'dxy curncy', 'S5443199 ', 'S5479800', 'S5443108', 'H7358586', 'LC3FM1 INDEX', 'CNY REGN Curncy', 's0105897', 'M0067419', 'M0066351', 'S0266372', 'S0266438', 'S0266506']\n",
|
||
"ID01385938\n",
|
||
"ID01385938\n",
|
||
"lmcads03 lme comdty\n",
|
||
"lmcads03 lme comdty\n",
|
||
"GC1 COMB Comdty\n",
|
||
"GC1 COMB Comdty\n",
|
||
"C2404171822\n",
|
||
"C2404171822\n",
|
||
"dxy curncy\n",
|
||
"dxy curncy\n",
|
||
"S5443199 \n",
|
||
"S5443199 \n",
|
||
"S5479800\n",
|
||
"S5479800\n",
|
||
"S5443108\n",
|
||
"S5443108\n",
|
||
"H7358586\n",
|
||
"H7358586\n",
|
||
"LC3FM1 INDEX\n",
|
||
"LC3FM1 INDEX\n",
|
||
"CNY REGN Curncy\n",
|
||
"CNY REGN Curncy\n",
|
||
"s0105897\n",
|
||
"s0105897\n",
|
||
"M0067419\n",
|
||
"M0067419\n",
|
||
"M0066351\n",
|
||
"M0066351\n",
|
||
"S0266372\n",
|
||
"S0266372\n",
|
||
"S0266438\n",
|
||
"S0266438\n",
|
||
"S0266506\n",
|
||
"S0266506\n",
|
||
" date PP:拉丝:1102K:市场价:青州:国家能源宁煤(日) LME铜价 黄金连1合约 Brent-WTI \\\n",
|
||
"0 2024-11-19 7390.0 NaN 2634.7 4.12 \n",
|
||
"1 2024-11-18 7380.0 9072.5 2614.6 4.14 \n",
|
||
"2 2024-11-15 7380.0 9002.5 2570.1 4.02 \n",
|
||
"3 2024-11-14 7380.0 8990.0 2572.9 3.86 \n",
|
||
"4 2024-11-13 7380.0 9047.0 2586.5 3.85 \n",
|
||
"\n",
|
||
" 美元指数 甲醇鲁南价格 甲醇太仓港口价格 山东丙烯主流价 丙烷(山东) FEI丙烷 M1 在岸人民币汇率 南华工业品指数 \\\n",
|
||
"0 106.206 NaN NaN NaN NaN 615.08 7.2387 NaN \n",
|
||
"1 106.275 2340.0 NaN 6850.0 NaN 606.60 7.2320 3727.05 \n",
|
||
"2 106.687 2310.0 NaN 6930.0 NaN 611.69 7.2294 3708.14 \n",
|
||
"3 106.673 2310.0 2472.0 6945.0 NaN 618.61 7.2271 3739.76 \n",
|
||
"4 106.481 2310.0 2480.0 6800.0 NaN 621.88 7.2340 3772.43 \n",
|
||
"\n",
|
||
" PVC期货主力 PE期货收盘价 PP连续-1月 PP连续-5月 PP连续-9月 \n",
|
||
"0 NaN NaN NaN NaN NaN \n",
|
||
"1 5284.0 NaN NaN NaN NaN \n",
|
||
"2 5282.0 NaN NaN NaN NaN \n",
|
||
"3 5310.0 NaN NaN NaN NaN \n",
|
||
"4 5347.0 NaN NaN NaN NaN \n",
|
||
" date PP:拉丝:1102K:市场价:青州:国家能源宁煤(日) LME铜价 黄金连1合约 Brent-WTI \\\n",
|
||
"0 2024-11-19 7390.0 NaN 2634.7 4.12 \n",
|
||
"1 2024-11-18 7380.0 9072.5 2614.6 4.14 \n",
|
||
"2 2024-11-15 7380.0 9002.5 2570.1 4.02 \n",
|
||
"3 2024-11-14 7380.0 8990.0 2572.9 3.86 \n",
|
||
"4 2024-11-13 7380.0 9047.0 2586.5 3.85 \n",
|
||
"\n",
|
||
" 美元指数 甲醇鲁南价格 甲醇太仓港口价格 山东丙烯主流价 丙烷(山东) FEI丙烷 M1 在岸人民币汇率 南华工业品指数 \\\n",
|
||
"0 106.206 NaN NaN NaN NaN 615.08 7.2387 NaN \n",
|
||
"1 106.275 2340.0 NaN 6850.0 NaN 606.60 7.2320 3727.05 \n",
|
||
"2 106.687 2310.0 NaN 6930.0 NaN 611.69 7.2294 3708.14 \n",
|
||
"3 106.673 2310.0 2472.0 6945.0 NaN 618.61 7.2271 3739.76 \n",
|
||
"4 106.481 2310.0 2480.0 6800.0 NaN 621.88 7.2340 3772.43 \n",
|
||
"\n",
|
||
" PVC期货主力 PE期货收盘价 PP连续-1月 PP连续-5月 PP连续-9月 \n",
|
||
"0 NaN NaN NaN NaN NaN \n",
|
||
"1 5284.0 NaN NaN NaN NaN \n",
|
||
"2 5282.0 NaN NaN NaN NaN \n",
|
||
"3 5310.0 NaN NaN NaN NaN \n",
|
||
"4 5347.0 NaN NaN NaN NaN \n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "KeyError",
|
||
"evalue": "\"None of [Index(['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'], dtype='object')] are in the [columns]\"",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[10], line 41\u001b[0m\n\u001b[0;32m 38\u001b[0m df_zhibiaoshuju,df_zhibiaoliebiao \u001b[38;5;241m=\u001b[39m etadata\u001b[38;5;241m.\u001b[39mget_eta_api_pp_data(data_set\u001b[38;5;241m=\u001b[39mdata_set,dataset\u001b[38;5;241m=\u001b[39mdataset) \u001b[38;5;66;03m# 原始数据,未处理\u001b[39;00m\n\u001b[0;32m 40\u001b[0m \u001b[38;5;66;03m# 数据处理\u001b[39;00m\n\u001b[1;32m---> 41\u001b[0m df \u001b[38;5;241m=\u001b[39m datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,y \u001b[38;5;241m=\u001b[39m y,dataset\u001b[38;5;241m=\u001b[39mdataset,add_kdj\u001b[38;5;241m=\u001b[39madd_kdj,is_timefurture\u001b[38;5;241m=\u001b[39mis_timefurture,end_time\u001b[38;5;241m=\u001b[39mend_time) \n\u001b[0;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 44\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m读取本地数据:\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39mos\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(dataset,data_set))\n",
|
||
"File \u001b[1;32md:\\liurui\\dev\\code\\PriceForecast\\lib\\dataread.py:635\u001b[0m, in \u001b[0;36mdatachuli_juxiting\u001b[1;34m(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, end_time, y, dataset, delweekenday, add_kdj, is_timefurture)\u001b[0m\n\u001b[0;32m 632\u001b[0m df\u001b[38;5;241m.\u001b[39mrename(columns\u001b[38;5;241m=\u001b[39m{datecol:\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mds\u001b[39m\u001b[38;5;124m'\u001b[39m},inplace\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m 634\u001b[0m \u001b[38;5;66;03m# 指定列统一减少数值\u001b[39;00m\n\u001b[1;32m--> 635\u001b[0m df[offsite_col] \u001b[38;5;241m=\u001b[39m df[offsite_col]\u001b[38;5;241m-\u001b[39moffsite\n\u001b[0;32m 636\u001b[0m \u001b[38;5;66;03m# 预测列为avg_cols的均值\u001b[39;00m\n\u001b[0;32m 637\u001b[0m df[y] \u001b[38;5;241m=\u001b[39m df[avg_cols]\u001b[38;5;241m.\u001b[39mmean(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\Lib\\site-packages\\pandas\\core\\frame.py:3899\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[1;34m(self, key)\u001b[0m\n\u001b[0;32m 3897\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_iterator(key):\n\u001b[0;32m 3898\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(key)\n\u001b[1;32m-> 3899\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39m_get_indexer_strict(key, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m 3901\u001b[0m \u001b[38;5;66;03m# take() does not accept boolean indexers\u001b[39;00m\n\u001b[0;32m 3902\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(indexer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n",
|
||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\Lib\\site-packages\\pandas\\core\\indexes\\base.py:6115\u001b[0m, in \u001b[0;36mIndex._get_indexer_strict\u001b[1;34m(self, key, axis_name)\u001b[0m\n\u001b[0;32m 6112\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 6113\u001b[0m keyarr, indexer, new_indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reindex_non_unique(keyarr)\n\u001b[1;32m-> 6115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_raise_if_missing(keyarr, indexer, axis_name)\n\u001b[0;32m 6117\u001b[0m keyarr \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtake(indexer)\n\u001b[0;32m 6118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, Index):\n\u001b[0;32m 6119\u001b[0m \u001b[38;5;66;03m# GH 42790 - Preserve name from an Index\u001b[39;00m\n",
|
||
"File \u001b[1;32md:\\ProgramData\\anaconda3\\Lib\\site-packages\\pandas\\core\\indexes\\base.py:6176\u001b[0m, in \u001b[0;36mIndex._raise_if_missing\u001b[1;34m(self, key, indexer, axis_name)\u001b[0m\n\u001b[0;32m 6174\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_interval_msg:\n\u001b[0;32m 6175\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(key)\n\u001b[1;32m-> 6176\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNone of [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m] are in the [\u001b[39m\u001b[38;5;132;01m{\u001b[39;00maxis_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 6178\u001b[0m not_found \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(ensure_index(key)[missing_mask\u001b[38;5;241m.\u001b[39mnonzero()[\u001b[38;5;241m0\u001b[39m]]\u001b[38;5;241m.\u001b[39munique())\n\u001b[0;32m 6179\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnot_found\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in index\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||
"\u001b[1;31mKeyError\u001b[0m: \"None of [Index(['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'], dtype='object')] are in the [columns]\""
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 读取配置\n",
|
||
"from lib.dataread import *\n",
|
||
"from lib.tools import *\n",
|
||
"from models.nerulforcastmodels import ex_Model,model_losss,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting\n",
|
||
"\n",
|
||
"import glob\n",
|
||
"import torch\n",
|
||
"torch.set_float32_matmul_precision(\"high\")\n",
|
||
"\n",
|
||
"sqlitedb = SQLiteHandler(db_name) \n",
|
||
"sqlitedb.connect()\n",
|
||
"\n",
|
||
"signature = BinanceAPI(APPID, SECRET)\n",
|
||
"etadata = EtaReader(signature=signature,\n",
|
||
" classifylisturl = classifylisturl,\n",
|
||
" classifyidlisturl=classifyidlisturl,\n",
|
||
" edbcodedataurl=edbcodedataurl,\n",
|
||
" edbcodelist=edbcodelist,\n",
|
||
" edbdatapushurl=edbdatapushurl,\n",
|
||
" edbdeleteurl=edbdeleteurl,\n",
|
||
" edbbusinessurl=edbbusinessurl\n",
|
||
" )\n",
|
||
"# 获取数据\n",
|
||
"if is_eta:\n",
|
||
" # eta数据\n",
|
||
" logger.info('从eta获取数据...')\n",
|
||
" signature = BinanceAPI(APPID, SECRET)\n",
|
||
" etadata = EtaReader(signature=signature,\n",
|
||
" classifylisturl = classifylisturl,\n",
|
||
" classifyidlisturl=classifyidlisturl,\n",
|
||
" edbcodedataurl=edbcodedataurl,\n",
|
||
" edbcodelist=edbcodelist,\n",
|
||
" edbdatapushurl=edbdatapushurl,\n",
|
||
" edbdeleteurl=edbdeleteurl,\n",
|
||
" edbbusinessurl=edbbusinessurl,\n",
|
||
" )\n",
|
||
" df_zhibiaoshuju,df_zhibiaoliebiao = etadata.get_eta_api_pp_data(data_set=data_set,dataset=dataset) # 原始数据,未处理\n",
|
||
"\n",
|
||
" # 数据处理\n",
|
||
" df = datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) \n",
|
||
"\n",
|
||
"else:\n",
|
||
" logger.info('读取本地数据:'+os.path.join(dataset,data_set))\n",
|
||
" df = getdata_juxiting(filename=os.path.join(dataset,data_set),y=y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) # 原始数据,未处理\n",
|
||
"\n",
|
||
"# 更改预测列名称\n",
|
||
"df.rename(columns={y:'y'},inplace=True)\n",
|
||
" \n",
|
||
"if is_edbnamelist:\n",
|
||
" df = df[edbnamelist] \n",
|
||
"df.to_csv(os.path.join(dataset,'指标数据.csv'), index=False)\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ae059224-976c-4839-b455-f81da7f25179",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 保存最新日期的y值到数据库\n",
|
||
"# Take the last row of df (tail(1)) and store it in the database\n",
|
||
"first_row = df[['ds','y']].tail(1)\n",
|
||
"# 将最新真实值保存到数据库\n",
|
||
"if not sqlitedb.check_table_exists('trueandpredict'):\n",
|
||
" first_row.to_sql('trueandpredict',sqlitedb.connection,index=False)\n",
|
||
"else:\n",
|
||
" for row in first_row.itertuples(index=False):\n",
|
||
" row_dict = row._asdict()\n",
|
||
" row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')\n",
|
||
" check_query = sqlitedb.select_data('trueandpredict',where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" if len(check_query) > 0:\n",
|
||
" set_clause = \", \".join([f\"{key} = '{value}'\" for key, value in row_dict.items()])\n",
|
||
" sqlitedb.update_data('trueandpredict',set_clause,where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" continue\n",
|
||
" sqlitedb.insert_data('trueandpredict',tuple(row_dict.values()),columns=row_dict.keys())\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "abb597fc-c5f3-4d76-8099-5eff358cb634",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import datetime\n",
|
||
"# 判断当前日期是不是周一\n",
|
||
"is_weekday = datetime.datetime.now().weekday() == 1\n",
|
||
"if is_weekday:\n",
|
||
" logger.info('今天是周一,更新预测模型')\n",
|
||
" # 计算最近20天预测残差最低的模型名称\n",
|
||
"\n",
|
||
" model_results = sqlitedb.select_data('trueandpredict',order_by = \"ds DESC\",limit = \"20\")\n",
|
||
" # 删除空值率为40%以上的列\n",
|
||
" print(model_results.shape)\n",
|
||
" model_results = model_results.dropna(thresh=len(model_results)*0.6,axis=1)\n",
|
||
" model_results = model_results.dropna()\n",
|
||
" print(model_results.shape)\n",
|
||
" modelnames = model_results.columns.to_list()[2:] \n",
|
||
" for col in model_results[modelnames].select_dtypes(include=['object']).columns:\n",
|
||
" model_results[col] = model_results[col].astype(np.float32)\n",
|
||
" # 计算每个预测值与真实值之间的偏差率\n",
|
||
" for model in modelnames:\n",
|
||
" model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y']\n",
|
||
"\n",
|
||
" # 获取每行对应的最小偏差率值\n",
|
||
" min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)\n",
|
||
" # 获取每行对应的最小偏差率值对应的列名\n",
|
||
" min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)\n",
|
||
" print(min_abs_error_rate_column_name)\n",
|
||
" # 将列名索引转换为列名\n",
|
||
" min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0])\n",
|
||
" # 取出现次数最多的模型名称\n",
|
||
" most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()\n",
|
||
" logger.info(f\"最近20天预测残差最低的模型名称:{most_common_model}\")\n",
|
||
"\n",
|
||
" # 保存结果到数据库\n",
|
||
" \n",
|
||
" if not sqlitedb.check_table_exists('most_model'):\n",
|
||
" sqlitedb.create_table('most_model',columns=\"ds datetime, most_common_model TEXT\")\n",
|
||
" sqlitedb.insert_data('most_model',(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),most_common_model,),columns=('ds','most_common_model',))\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ade7026e-8cf2-405f-a2da-9e90f364adab",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"if is_corr:\n",
|
||
" df = corr_feature(df=df)\n",
|
||
"\n",
|
||
"df1 = df.copy() # 备份一下,后面特征筛选完之后加入ds y 列用\n",
|
||
"logger.info(f\"开始训练模型...\")\n",
|
||
"row,col = df.shape\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dfef57d8-36da-423b-bbe7-05a13e15f71b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')\n",
|
||
"ex_Model(df,\n",
|
||
" horizon=horizon,\n",
|
||
" input_size=input_size,\n",
|
||
" train_steps=train_steps,\n",
|
||
" val_check_steps=val_check_steps,\n",
|
||
" early_stop_patience_steps=early_stop_patience_steps,\n",
|
||
" is_debug=is_debug,\n",
|
||
" dataset=dataset,\n",
|
||
" is_train=is_train,\n",
|
||
" is_fivemodels=is_fivemodels,\n",
|
||
" val_size=val_size,\n",
|
||
" test_size=test_size,\n",
|
||
" settings=settings,\n",
|
||
" now=now,\n",
|
||
" etadata = etadata,\n",
|
||
" modelsindex = modelsindex,\n",
|
||
" data = data,\n",
|
||
" is_eta=is_eta,\n",
|
||
" )\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0e5b6f30-b7ca-4718-97a3-48b54156e07f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"logger.info('模型训练完成')\n",
|
||
"# # 模型评估\n",
|
||
"\n",
|
||
"pd.set_option('display.max_columns', 100)\n",
|
||
"# 计算预测评估指数\n",
|
||
"def model_losss_juxiting(sqlitedb):\n",
|
||
" global dataset\n",
|
||
" # 数据库查询最佳模型名称\n",
|
||
" most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]\n",
|
||
" most_model_name = most_model[0]\n",
|
||
"\n",
|
||
" # 预测数据处理 predict\n",
|
||
" df_combined = loadcsv(os.path.join(dataset,\"cross_validation.csv\")) \n",
|
||
" df_combined = dateConvert(df_combined)\n",
|
||
" # 删除空列\n",
|
||
" df_combined.dropna(axis=1,inplace=True)\n",
|
||
" # 删除缺失值,预测过程不能有缺失值\n",
|
||
" df_combined.dropna(inplace=True) \n",
|
||
" # 其他列转为数值类型\n",
|
||
" df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] })\n",
|
||
"    # Use groupby + transform('max') to get the largest cutoff within each ds group and store it in a new max_cutoff column\n",
|
||
" df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max')\n",
|
||
"\n",
|
||
" # 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列\n",
|
||
" df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']]\n",
|
||
" # 删除模型生成的cutoff列\n",
|
||
" df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True)\n",
|
||
" # 获取模型名称\n",
|
||
" modelnames = df_combined.columns.to_list()[1:] \n",
|
||
" if 'y' in modelnames:\n",
|
||
" modelnames.remove('y')\n",
|
||
" df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要\n",
|
||
"\n",
|
||
"\n",
|
||
" # 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE\n",
|
||
" cellText = []\n",
|
||
"\n",
|
||
" # 遍历模型名称,计算模型评估指标 \n",
|
||
" for model in modelnames:\n",
|
||
" modelmse = mse(df_combined['y'], df_combined[model])\n",
|
||
" modelrmse = rmse(df_combined['y'], df_combined[model])\n",
|
||
" modelmae = mae(df_combined['y'], df_combined[model])\n",
|
||
" # modelmape = mape(df_combined['y'], df_combined[model])\n",
|
||
" # modelsmape = smape(df_combined['y'], df_combined[model])\n",
|
||
" # modelr2 = r2_score(df_combined['y'], df_combined[model])\n",
|
||
" cellText.append([model,round(modelmse, 3), round(modelrmse, 3), round(modelmae, 3)])\n",
|
||
" \n",
|
||
" model_results3 = pd.DataFrame(cellText,columns=['模型(Model)','平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)'])\n",
|
||
"    # Sort by MSE in ascending order (lowest error first)\n",
|
||
" model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True)\n",
|
||
" model_results3.to_csv(os.path.join(dataset,\"model_evaluation.csv\"),index=False)\n",
|
||
" modelnames = model_results3['模型(Model)'].tolist()\n",
|
||
" allmodelnames = modelnames.copy()\n",
|
||
" # 保存5个最佳模型的名称\n",
|
||
" if len(modelnames) > 5:\n",
|
||
" modelnames = modelnames[0:5]\n",
|
||
" with open(os.path.join(dataset,\"best_modelnames.txt\"), 'w') as f:\n",
|
||
" f.write(','.join(modelnames) + '\\n')\n",
|
||
"\n",
|
||
"\n",
|
||
" # 去掉方差最大的模型,其余模型预测最大最小值确定通道边界\n",
|
||
" best_models = pd.read_csv(os.path.join(dataset,'best_modelnames.txt'),header=None).values.flatten().tolist()\n",
|
||
" \n",
|
||
"\n",
|
||
" # 预测值与真实值对比图\n",
|
||
" plt.rcParams['font.sans-serif'] = ['SimHei']\n",
|
||
" plt.figure(figsize=(15, 10))\n",
|
||
" # 设置有5个子图的画布\n",
|
||
" for n,model in enumerate(modelnames[:5]):\n",
|
||
" plt.subplot(3, 2, n+1)\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值')\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3[model], label=model)\n",
|
||
" plt.legend()\n",
|
||
" plt.xlabel('日期')\n",
|
||
" plt.ylabel('价格')\n",
|
||
" plt.title(model+'拟合')\n",
|
||
" plt.subplots_adjust(hspace=0.5)\n",
|
||
" plt.savefig(os.path.join(dataset,'预测值与真实值对比图.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" \n",
|
||
" # 历史数据+预测数据\n",
|
||
" # 拼接未来时间预测\n",
|
||
" df_predict = loadcsv(os.path.join(dataset,'predict.csv'))\n",
|
||
" df_predict.drop('unique_id',inplace=True,axis=1)\n",
|
||
" df_predict.dropna(axis=1,inplace=True)\n",
|
||
" df_predict2 = df_predict.copy()\n",
|
||
" try:\n",
|
||
" df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y-%m-%d')\n",
|
||
" except ValueError :\n",
|
||
" df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y/%m/%d')\n",
|
||
"\n",
|
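||
"    # Take the first prediction row and format its ds as a 'YYYY-MM-DD 00:00:00' string (first_row is not referenced again below - TODO confirm whether it was meant to be written to the database)\n",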
|
||
" first_row = df_predict.head(1)\n",
|
||
" first_row['ds'] = first_row['ds'].dt.strftime('%Y-%m-%d 00:00:00')\n",
|
||
"\n",
|
||
" # # 将预测结果保存到数据库\n",
|
||
" # # 判断表存在\n",
|
||
" if not sqlitedb.check_table_exists('testandpredict_groupby'):\n",
|
||
" df_predict2.to_sql('testandpredict_groupby',sqlitedb.connection,index=False)\n",
|
||
" else:\n",
|
||
" for row in df_predict2.itertuples(index=False):\n",
|
||
" row_dict = row._asdict()\n",
|
||
" check_query = sqlitedb.select_data('testandpredict_groupby',where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" if len(check_query) > 0:\n",
|
||
" set_clause = \", \".join([f\"{key} = '{value}'\" for key, value in row_dict.items()])\n",
|
||
" sqlitedb.update_data('testandpredict_groupby',set_clause,where_condition = f\"ds = '{row.ds}'\")\n",
|
||
" continue\n",
|
||
" sqlitedb.insert_data('testandpredict_groupby',tuple(row_dict.values()),columns=row_dict.keys())\n",
|
||
"\n",
|
||
" df_combined3 = pd.concat([df_combined3, df_predict]).reset_index(drop=True)\n",
|
||
"\n",
|
||
" # # 判断 df 的数值列转为float\n",
|
||
" for col in df_combined3.columns:\n",
|
||
" try:\n",
|
||
" if col != 'ds':\n",
|
||
" df_combined3[col] = df_combined3[col].astype(float)\n",
|
||
" df_combined3[col] = df_combined3[col].round(2)\n",
|
||
" except ValueError:\n",
|
||
" pass\n",
|
||
" df_combined3.to_csv(os.path.join(dataset,\"df_combined3.csv\"),index=False) \n",
|
||
" df_combined3.to_sql('testandpredict_groupby', sqlitedb.connection, if_exists='replace', index=False)\n",
|
||
" df_combined3.to_csv(os.path.join(dataset,\"testandpredict_groupby.csv\"),index=False)\n",
|
||
" \n",
|
||
" \n",
|
||
" ten_models = allmodelnames\n",
|
||
" # 计算每个模型的方差\n",
|
||
" variances = df_combined3[ten_models].var()\n",
|
||
" # 找到方差最大的模型\n",
|
||
" max_variance_model = variances.idxmax()\n",
|
||
" # 打印方差最大的模型\n",
|
||
" print(\"方差最大的模型是:\", max_variance_model)\n",
|
||
" # 去掉方差最大的模型\n",
|
||
" df_combined3 = df_combined3.drop(columns=[max_variance_model])\n",
|
||
" if max_variance_model in allmodelnames:\n",
|
||
" allmodelnames.remove(max_variance_model)\n",
|
||
" df_combined3['min'] = df_combined3[allmodelnames].min(axis=1)\n",
|
||
" df_combined3['max'] = df_combined3[allmodelnames].max(axis=1)\n",
|
||
" print(df_combined3[['min','max']])\n",
|
||
" # 历史价格+预测价格\n",
|
||
" df_combined3 = df_combined3[-50:] # 取50个数据点画图\n",
|
||
" plt.figure(figsize=(20, 10))\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值',marker='o')\n",
|
||
" plt.plot(df_combined3['ds'], df_combined3[most_model], label=most_model_name)\n",
|
||
" plt.fill_between(df_combined3['ds'], df_combined3['min'], df_combined3['max'], alpha=0.2)\n",
|
||
" plt.grid(True)\n",
|
||
" # 当前日期画竖虚线\n",
|
||
" plt.axvline(x=df_combined3['ds'].iloc[-horizon], color='r', linestyle='--')\n",
|
||
" plt.legend()\n",
|
||
" plt.xlabel('日期')\n",
|
||
" plt.ylabel('价格')\n",
|
||
"\n",
|
||
" # # 显示历史值\n",
|
||
" for i, j in zip(df_combined3['ds'][:-5], df_combined3['y'][:-5]):\n",
|
||
" plt.text(i, j, str(j), ha='center', va='bottom')\n",
|
||
" plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight')\n",
|
||
" plt.show()\n",
|
||
" plt.close()\n",
|
||
" \n",
|
||
" # 预测值表格\n",
|
||
" fig, ax = plt.subplots(figsize=(20, 6))\n",
|
||
" ax.axis('off') # 关闭坐标轴\n",
|
||
" # 数值保留2位小数\n",
|
||
" df_combined3 = df_combined3.round(2)\n",
|
||
" df_combined3 = df_combined3[-horizon:]\n",
|
||
" df_combined3['Day'] = [f'Day_{i}' for i in range(1,horizon+1)]\n",
|
||
" # Day列放到最前面\n",
|
||
" df_combined3 = df_combined3[['Day'] + list(df_combined3.columns[:-1])]\n",
|
||
" table = ax.table(cellText=df_combined3.values, colLabels=df_combined3.columns, loc='center')\n",
|
||
" #加宽表格\n",
|
||
" table.auto_set_font_size(False)\n",
|
||
" table.set_fontsize(10)\n",
|
||
"\n",
|
||
" # 设置表格样式,列数据最小的用绿色标识\n",
|
||
" plt.savefig(os.path.join(dataset,'预测值表格.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" # plt.show()\n",
|
||
" \n",
|
||
" # 可视化评估结果\n",
|
||
" plt.rcParams['font.sans-serif'] = ['SimHei']\n",
|
||
" fig, ax = plt.subplots(figsize=(20, 10))\n",
|
||
" ax.axis('off') # 关闭坐标轴\n",
|
||
" table = ax.table(cellText=model_results3.values, colLabels=model_results3.columns, loc='center')\n",
|
||
" # 加宽表格\n",
|
||
" table.auto_set_font_size(False)\n",
|
||
" table.set_fontsize(10)\n",
|
||
"\n",
|
||
" # 设置表格样式,列数据最小的用绿色标识\n",
|
||
" plt.savefig(os.path.join(dataset,'模型评估.png'), bbox_inches='tight')\n",
|
||
" plt.close()\n",
|
||
" return model_results3\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"logger.info('训练数据绘图ing')\n",
|
||
"model_results3 = model_losss_juxiting(sqlitedb)\n",
|
||
"\n",
|
||
"logger.info('训练数据绘图end')\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "85b557de-8235-4e27-b5b8-58b36dfe6724",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 模型报告\n",
|
||
"\n",
|
||
"logger.info('制作报告ing')\n",
|
||
"title = f'{settings}--{now}-预测报告' # 报告标题\n",
|
||
"\n",
|
||
"pp_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,\n",
|
||
" reportname=reportname,sqlitedb=sqlitedb),\n",
|
||
"\n",
|
||
"logger.info('制作报告end')\n",
|
||
"logger.info('模型训练完成')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d4129e71-ee2c-4af1-81ed-fadf14efa206",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 发送邮件\n",
|
||
"m = SendMail(\n",
|
||
" username=username,\n",
|
||
" passwd=passwd,\n",
|
||
" recv=recv,\n",
|
||
" title=title,\n",
|
||
" content=content,\n",
|
||
" file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),\n",
|
||
" ssl=ssl,\n",
|
||
")\n",
|
||
"# m.send_mail() \n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.7"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|