diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 8dbe22e..d201f7e 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -199,6 +199,19 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien # 保存预测值 df_predict.to_csv(os.path.join(dataset,"predict.csv"),index=False) + + df_predict2 = df_predict.copy() + # 保存到数据库 + if not sqlitedb.check_table_exists('accuracy'): + sqlitedb.create_table('accuracy', columns="id int,PREDICT_DATE datetime,CREAT_DATE datetime, MIN_PRICE TEXT,MAX_PRICE TEXT,HIGH_PRICE TEXT,LOW_PRICE TEXT,RIGHT_ROTE ") + existing_data = sqlitedb.execute_query("SELECT * FROM accuracy") + if not existing_data.empty: + max_id = existing_data['id'].astype(int).max() + df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2)) + else: + df_predict2['id'] = range(1, 1 + len(df_predict2)) + df_predict2.to_sql(table_name, con=sqlitedb.connect, if_exists='append', index=False) + # 把预测值上传到eta if is_update_eta: diff --git a/原油价格预测准确率计算.ipynb b/原油价格预测准确率计算.ipynb new file mode 100644 index 0000000..e2e9058 --- /dev/null +++ b/原油价格预测准确率计算.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 43, + "id": "9daadf20-caa6-4b25-901c-6cc3ef563f58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(135, 32)\n", + "(20, 4)\n", + "(135, 35)\n", + " ds NHITS Informer LSTM iTransformer TSMixer \\\n", + "0 2024-05-21 83.607803 85.574699 82.902397 86.689003 83.845596 \n", + "1 2024-05-22 83.306198 84.577400 82.204903 86.366203 83.341599 \n", + "2 2024-05-23 82.561203 83.894600 81.447899 85.681198 82.593300 \n", + "3 2024-05-24 82.007896 81.672203 81.100899 84.773598 82.038101 \n", + "4 2024-05-27 82.463600 81.498398 81.599297 85.471298 82.187202 \n", + "\n", + " TSMixerx PatchTST RNN GRU ... quantile_90 \\\n", + "0 85.564499 83.900002 60.229801 82.869698 ... 0.0149 \n", + "1 85.500900 82.637299 60.229698 82.206802 ... 0.0149 \n", + "2 85.219299 82.346603 60.229698 81.471298 ... 0.0149 \n", + "3 83.707901 81.840401 60.229599 81.120201 ... 0.0149 \n", + "4 83.175301 82.142998 60.229698 81.602699 ... 0.0145 \n", + "\n", + " quantile_10_price quantile_90_price min_within_quantile \\\n", + "0 81.993181 84.114909 82.757301 \n", + "1 80.966342 83.120312 82.173500 \n", + "2 80.432497 82.572265 81.447899 \n", + "3 80.907020 83.059412 81.100899 \n", + "4 81.935165 84.081757 82.010399 \n", + "\n", + " max_within_quantile min_model max_model 序号 lowprice highprice \n", + "0 83.954498 MLPMultivariate DLinear NaN NaN NaN \n", + "1 83.068497 MLPMultivariate NLinear NaN NaN NaN \n", + "2 82.452904 LSTM TiDE NaN NaN NaN \n", + "3 82.933502 LSTM DeepNPTS NaN NaN NaN \n", + "4 83.266098 DLinear TFT NaN NaN NaN \n", + "\n", + "[5 rows x 35 columns]\n", + " ds NHITS Informer LSTM iTransformer TSMixer \\\n", + "130 2024-12-03 72.612854 73.548740 73.456590 73.261700 72.531740 \n", + "131 2024-12-04 73.003940 73.408714 73.738144 73.359000 72.759056 \n", + "132 2024-12-05 72.435104 73.837450 73.900570 73.456290 72.775955 \n", + "133 2024-12-06 71.915250 73.683400 73.801330 73.549644 72.517870 \n", + "134 2024-12-09 72.446610 73.003040 74.081820 73.970030 72.687225 \n", + "\n", + " TSMixerx PatchTST RNN GRU ... quantile_90 \\\n", + "130 72.424950 73.314370 62.063490 73.14000 ... NaN \n", + "131 71.986694 74.242960 63.026990 73.08031 ... NaN \n", + "132 72.056160 73.494240 62.319950 73.15901 ... NaN \n", + "133 71.949066 73.083786 62.659557 73.37916 ... NaN \n", + "134 71.930230 72.724205 63.025295 73.40915 ... NaN \n", + "\n", + " quantile_10_price quantile_90_price min_within_quantile \\\n", + "130 NaN NaN 72.386850 \n", + "131 NaN NaN 72.940490 \n", + "132 NaN NaN 73.387985 \n", + "133 NaN NaN 73.056526 \n", + "134 NaN NaN 73.581210 \n", + "\n", + " max_within_quantile min_model max_model 序号 lowprice highprice \n", + "130 73.48625 TCN FEDformer 4.0 71.68 73.93 \n", + "131 73.58921 TCN FEDformer 3.0 72.25 74.28 \n", + "132 73.36786 TCN FEDformer 2.0 71.80 72.92 \n", + "133 73.30574 TCN FEDformer 1.0 70.85 72.19 \n", + "134 73.06101 TCN FEDformer NaN NaN NaN \n", + "\n", + "[5 rows x 35 columns]\n" + ] + } + ], + "source": [ + "import sqlite3\n", + "import os\n", + "import pandas as pd\n", + "\n", + "# 预测价格数据\n", + "dbfilename = os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','jbsh_yuanyou.db')\n", + "conn = sqlite3.connect(dbfilename)\n", + "query = 'SELECT * FROM testandpredict_groupby'\n", + "df1 = pd.read_sql_query(query, conn)\n", + "df1['ds'] = pd.to_datetime(df1['ds'])\n", + "conn.close()\n", + "print(df1.shape)\n", + "# 最高最低价\n", + "xlsfilename = os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','数据项下载.xls')\n", + "df2 = pd.read_excel(xlsfilename)[5:]\n", + "df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'lowprice','布伦特最高价':'highprice'})\n", + "df2['ds'] = pd.to_datetime(df2['ds'])\n", + "\n", + "print(df2.shape)\n", + "df = pd.merge(df1,df2,on='ds',how='outer')\n", + "\n", + "df = df.reindex()\n", + "\n", + "print(df.shape)\n", + "\n", + "df.to_csv(os.path.join(r'D:\\code\\PriceForecast\\yuanyoudataset','123.csv'))\n", + "# df = df[['ds','min_within_quantile','max_within_quantile']]\n", + "\n", + "\n", + "\n", + "# 打印数据框的前几行\n", + "print(df.head())\n", + "print(df.tail())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e51c3fd0-6bff-45de-b8b6-971e7986c7a7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}