准确率计算

This commit is contained in:
workpc 2024-12-09 13:59:33 +08:00
parent 11630bfcce
commit 363da754b2
2 changed files with 156 additions and 0 deletions

View File

@ -199,6 +199,19 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
# 保存预测值 # 保存预测值
df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False) df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False)
df_predict2 = df_predict.copy()
# 保存到数据库
if not sqlitedb.check_table_exists('accuracy'):
sqlitedb.create_table('accuracy', columns="id int,PREDICT_DATE datetime,CREAT_DATE datetime, MIN_PRICE TEXT,MAX_PRICE TEXT,HIGH_PRICE TEXT,LOW_PRICE TEXT,RIGHT_ROTE ")
existing_data = sqlitedb.execute_query("SELECT * FROM accuracy")
if not existing_data.empty:
max_id = existing_data['id'].astype(int).max()
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
else:
df_predict2['id'] = range(1, 1 + len(df_predict2))
df_predict2.to_sql(table_name, con=sqlitedb.connect, if_exists='append', index=False)
# 把预测值上传到eta # 把预测值上传到eta
if is_update_eta: if is_update_eta:

View File

@ -0,0 +1,143 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 43,
"id": "9daadf20-caa6-4b25-901c-6cc3ef563f58",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(135, 32)\n",
"(20, 4)\n",
"(135, 35)\n",
" ds NHITS Informer LSTM iTransformer TSMixer \\\n",
"0 2024-05-21 83.607803 85.574699 82.902397 86.689003 83.845596 \n",
"1 2024-05-22 83.306198 84.577400 82.204903 86.366203 83.341599 \n",
"2 2024-05-23 82.561203 83.894600 81.447899 85.681198 82.593300 \n",
"3 2024-05-24 82.007896 81.672203 81.100899 84.773598 82.038101 \n",
"4 2024-05-27 82.463600 81.498398 81.599297 85.471298 82.187202 \n",
"\n",
" TSMixerx PatchTST RNN GRU ... quantile_90 \\\n",
"0 85.564499 83.900002 60.229801 82.869698 ... 0.0149 \n",
"1 85.500900 82.637299 60.229698 82.206802 ... 0.0149 \n",
"2 85.219299 82.346603 60.229698 81.471298 ... 0.0149 \n",
"3 83.707901 81.840401 60.229599 81.120201 ... 0.0149 \n",
"4 83.175301 82.142998 60.229698 81.602699 ... 0.0145 \n",
"\n",
" quantile_10_price quantile_90_price min_within_quantile \\\n",
"0 81.993181 84.114909 82.757301 \n",
"1 80.966342 83.120312 82.173500 \n",
"2 80.432497 82.572265 81.447899 \n",
"3 80.907020 83.059412 81.100899 \n",
"4 81.935165 84.081757 82.010399 \n",
"\n",
" max_within_quantile min_model max_model 序号 lowprice highprice \n",
"0 83.954498 MLPMultivariate DLinear NaN NaN NaN \n",
"1 83.068497 MLPMultivariate NLinear NaN NaN NaN \n",
"2 82.452904 LSTM TiDE NaN NaN NaN \n",
"3 82.933502 LSTM DeepNPTS NaN NaN NaN \n",
"4 83.266098 DLinear TFT NaN NaN NaN \n",
"\n",
"[5 rows x 35 columns]\n",
" ds NHITS Informer LSTM iTransformer TSMixer \\\n",
"130 2024-12-03 72.612854 73.548740 73.456590 73.261700 72.531740 \n",
"131 2024-12-04 73.003940 73.408714 73.738144 73.359000 72.759056 \n",
"132 2024-12-05 72.435104 73.837450 73.900570 73.456290 72.775955 \n",
"133 2024-12-06 71.915250 73.683400 73.801330 73.549644 72.517870 \n",
"134 2024-12-09 72.446610 73.003040 74.081820 73.970030 72.687225 \n",
"\n",
" TSMixerx PatchTST RNN GRU ... quantile_90 \\\n",
"130 72.424950 73.314370 62.063490 73.14000 ... NaN \n",
"131 71.986694 74.242960 63.026990 73.08031 ... NaN \n",
"132 72.056160 73.494240 62.319950 73.15901 ... NaN \n",
"133 71.949066 73.083786 62.659557 73.37916 ... NaN \n",
"134 71.930230 72.724205 63.025295 73.40915 ... NaN \n",
"\n",
" quantile_10_price quantile_90_price min_within_quantile \\\n",
"130 NaN NaN 72.386850 \n",
"131 NaN NaN 72.940490 \n",
"132 NaN NaN 73.387985 \n",
"133 NaN NaN 73.056526 \n",
"134 NaN NaN 73.581210 \n",
"\n",
" max_within_quantile min_model max_model 序号 lowprice highprice \n",
"130 73.48625 TCN FEDformer 4.0 71.68 73.93 \n",
"131 73.58921 TCN FEDformer 3.0 72.25 74.28 \n",
"132 73.36786 TCN FEDformer 2.0 71.80 72.92 \n",
"133 73.30574 TCN FEDformer 1.0 70.85 72.19 \n",
"134 73.06101 TCN FEDformer NaN NaN NaN \n",
"\n",
"[5 rows x 35 columns]\n"
]
}
],
"source": [
"import sqlite3\n",
"import os\n",
"import pandas as pd\n",
"\n",
"# 预测价格数据\n",
"dbfilename = os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','jbsh_yuanyou.db')\n",
"conn = sqlite3.connect(dbfilename)\n",
"query = 'SELECT * FROM testandpredict_groupby'\n",
"df1 = pd.read_sql_query(query, conn)\n",
"df1['ds'] = pd.to_datetime(df1['ds'])\n",
"conn.close()\n",
"print(df1.shape)\n",
"# 最高最低价\n",
"xlsfilename = os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','数据项下载.xls')\n",
"df2 = pd.read_excel(xlsfilename)[5:]\n",
"df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'lowprice','布伦特最高价':'highprice'})\n",
"df2['ds'] = pd.to_datetime(df2['ds'])\n",
"\n",
"print(df2.shape)\n",
"df = pd.merge(df1,df2,on='ds',how='outer')\n",
"\n",
"df = df.reindex()\n",
"\n",
"print(df.shape)\n",
"\n",
"df.to_csv(os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','123.csv'))\n",
"# df = df[['ds','min_within_quantile','max_within_quantile']]\n",
"\n",
"\n",
"\n",
"# 打印数据框的前几行\n",
"print(df.head())\n",
"print(df.tail())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e51c3fd0-6bff-45de-b8b6-971e7986c7a7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}