"""Brent price-forecast accuracy evaluation (code cells of the notebook).

Steps, in the order of the notebook cells:
  1. load the model forecasts (CSV) and the actual low/high prices (Excel),
     merge them on 'ds' and save a snapshot;
  2. derive a predicted price band from the mean of the model columns;
  3. score every row with calculate_accuracy and write one weighted
     accuracy figure per forecast creation date;
  4. add week-of-year helper columns.
"""
import datetime
import os
import sqlite3  # kept from the original imports; only the retired SQLite loader used it
import time

import pandas as pd

# Folder that holds the input files; every output CSV is written here as well.
# NOTE(review): machine-specific absolute path - adjust before running elsewhere.
dataset = r'C:\Users\Administrator\Desktop'

# Weight applied to the mean accuracy of one target date, indexed by
# (number of forecasts available for that date - 1).
weight_dict = [0.4, 0.15, 0.1, 0.1, 0.25]

# Half-width of the band placed around the mean model forecast.
PRICE_BAND_HALF_WIDTH = 2


def calculate_accuracy(row):
    """Score how well the predicted price band matches the actual range.

    Args:
        row: mapping (for example a DataFrame row) providing 'min_price',
            'max_price' (predicted band) and 'LOW_PRICE', 'HIGH_PRICE'
            (actual range).

    Returns:
        1 when one interval fully contains the other, 0 when the intervals
        do not overlap, otherwise the overlap length divided by
        (HIGH_PRICE - LOW_PRICE). When both actual prices are NaN every
        comparison is false, the last branch runs and the result is NaN.
    """
    pred_low, pred_high = row['min_price'], row['max_price']
    real_low, real_high = row['LOW_PRICE'], row['HIGH_PRICE']

    # Full containment, in either direction, counts as a perfect hit.
    prediction_covers_actual = pred_high >= real_high and pred_low <= real_low
    actual_covers_prediction = pred_high <= real_high and pred_low >= real_low
    if prediction_covers_actual or actual_covers_prediction:
        return 1

    # Disjoint intervals score zero.
    if pred_high < real_low or pred_low > real_high:
        return 0

    # Partial overlap: after sorting the four bounds, the two middle values
    # delimit the intersection of the two intervals.
    ordered = sorted([real_low, pred_low, pred_high, real_high])
    return (ordered[2] - ordered[1]) / (real_high - real_low)


def get_week_date(end_time):
    """Return the evaluation window that belongs to a forecast date.

    Starting from the Monday two weeks before the week of ``end_time``,
    fourteen consecutive days are generated and positions 4..11 are kept,
    i.e. the Friday of that first week through the Friday of the next week
    (eight dates).

    Args:
        end_time: date string in '%Y-%m-%d' format.

    Returns:
        List of eight '%Y-%m-%d' strings in ascending order.

    Raises:
        ValueError: if ``end_time`` does not match '%Y-%m-%d'.
    """
    end_day = datetime.datetime.strptime(end_time, '%Y-%m-%d')
    window_start = end_day - datetime.timedelta(days=end_day.weekday() + 14)
    two_weeks = [window_start + datetime.timedelta(days=i) for i in range(14)]
    return [day.strftime('%Y-%m-%d') for day in two_weeks[4:-2]]


def _get_accuracy_rate(df, up_week_dates, endtime):
    """Compute the weighted accuracy of one evaluation window and save it.

    Rows are kept when both 'CREAT_DATE' and 'ds' fall inside
    ``up_week_dates``. For every remaining target date the mean 'ACCURACY'
    is multiplied by the weight selected by the number of forecasts for that
    date, and the products are summed.

    Side effects: writes ``accuracy_<endtime>.csv`` (the filtered rows) and
    ``accuracy_rote_<endtime>.csv`` (the one-row summary) into ``dataset``
    and prints the summary.

    Args:
        df: frame with 'CREAT_DATE', 'ds' and 'ACCURACY' columns.
        up_week_dates: non-empty list of '%Y-%m-%d' strings (the window).
        endtime: label used in the output file names.

    Raises:
        IndexError: if a target date has more forecasts than ``weight_dict``
            has entries.
    """
    # 'ds' is datetime64 after loading while the window holds strings; cast
    # both sides explicitly instead of relying on isin() to parse strings.
    window_days = pd.to_datetime(up_week_dates)
    in_window = (
        df['CREAT_DATE'].isin(up_week_dates)
        & pd.to_datetime(df['ds']).isin(window_days)
    )
    df3 = df[in_window].copy()

    accuracy_rote = 0
    for day, group in df3.groupby('ds'):
        n_forecasts = len(group)
        if n_forecasts > len(weight_dict):
            raise IndexError(
                f'{n_forecasts} forecasts for {day}, but weight_dict only '
                f'defines {len(weight_dict)} weights'
            )
        accuracy_rote += (group['ACCURACY'].sum() / n_forecasts) * weight_dict[n_forecasts - 1]

    df3.to_csv(os.path.join(dataset, f'accuracy_{endtime}.csv'), index=False)

    df4 = pd.DataFrame(
        [{'开始日期': up_week_dates[0], '结束日期': up_week_dates[-1], '准确率': accuracy_rote}],
        columns=['开始日期', '结束日期', '准确率'],
    )
    df4.to_csv(os.path.join(dataset, f'accuracy_rote_{endtime}.csv'), index=False)
    print(df4)


# In a notebook kernel and when run as a script __name__ is '__main__', so the
# steps below execute exactly as the cells did and their variables stay
# global; the guard only keeps file access out of a plain import of the
# helper functions above.
if __name__ == "__main__":
    # ---- Cell 1: load, merge and snapshot -------------------------------
    # Forecast data. An earlier variant read the same data from table
    # `accuracy` of jbsh_yuanyou.db through sqlite3 and copied PREDICT_DATE
    # into 'ds'.
    dfcsvfilename = os.path.join(dataset, 'accuracy_ten.csv')
    df1 = pd.read_csv(dfcsvfilename)
    print(df1.shape)

    # Actual lowest / highest prices; the first five rows of the sheet are
    # skipped.
    xlsfilename = os.path.join(dataset, '数据项下载.xls')
    df2 = pd.read_excel(xlsfilename)[5:]
    df2 = df2.rename(columns={'数据项名称': 'ds', '布伦特最低价': 'LOW_PRICE', '布伦特最高价': 'HIGH_PRICE'})
    print(df2.shape)

    df = pd.merge(df1, df2, on=['ds'], how='left')
    df['ds'] = pd.to_datetime(df['ds'])
    # The original called a bare reindex(), which leaves the index as it is;
    # reset_index is presumably what was intended and yields the same
    # 0..n-1 index the merge already produced.
    df = df.reset_index(drop=True)
    print(df.shape)

    # Timestamped snapshot, so every run writes a new file.
    df.to_csv(os.path.join(dataset, f'预测数据-{time.time()}.csv'))

    # Show the first and last rows of the merged frame.
    print(df.head())
    print(df.tail())

    # ---- Cell 2: predicted band around the mean forecast ----------------
    # NOTE(review): the original comment said "mean of the top five evaluated
    # models", but positional columns 1..10 (ten columns) are averaged -
    # confirm which is intended.
    model_mean = df.iloc[:, 1:11].mean(axis=1)
    df['min_price'] = model_mean - PRICE_BAND_HALF_WIDTH
    df['max_price'] = model_mean + PRICE_BAND_HALF_WIDTH

    # ---- Cell 3: row accuracy and weekly weighted accuracy --------------
    columns = ['HIGH_PRICE', 'LOW_PRICE', 'min_price', 'max_price']
    df[columns] = df[columns].astype(float)
    df['ACCURACY'] = df.apply(calculate_accuracy, axis=1)

    end_times = df['CREAT_DATE'].unique()
    for endtime in end_times:
        up_week_dates = get_week_date(endtime)
        # Fixed: the original passed the undefined name `end_time` here.
        _get_accuracy_rate(df, up_week_dates, endtime)

    # ---- Cell 4: week-of-year columns -----------------------------------
    # .dt.strftime also copes with string dates and missing values.
    # NOTE(review): 'PREDICT_DATE' is not among the columns visible in the
    # printed frame; this line raises KeyError if the CSV does not provide it.
    df['Ds_Week'] = pd.to_datetime(df['ds']).dt.strftime('%U')
    df['Pre_Week'] = pd.to_datetime(df['PREDICT_DATE']).dt.strftime('%U')
\n", " | ds | \n", "ACCURACY | \n", "PREDICT_DATE | \n", "CREAT_DATE | \n", "HIGH_PRICE_y | \n", "LOW_PRICE_y | \n", "MIN_PRICE | \n", "MAX_PRICE | \n", "Ds_Week | \n", "Pre_Week | \n", "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2024-11-26 | \n", "1.000000 | \n", "2024-11-26 | \n", "2024-11-25 | \n", "73.80 | \n", "71.63 | \n", "71.071556 | \n", "76.006900 | \n", "47 | \n", "47 | \n", "
1 | \n", "2024-11-27 | \n", "1.000000 | \n", "2024-11-27 | \n", "2024-11-25 | \n", "72.85 | \n", "71.71 | \n", "71.003624 | \n", "75.580560 | \n", "47 | \n", "47 | \n", "
2 | \n", "2024-11-28 | \n", "0.789324 | \n", "2024-11-28 | \n", "2024-11-25 | \n", "72.96 | \n", "71.85 | \n", "72.083850 | \n", "76.204260 | \n", "47 | \n", "47 | \n", "
3 | \n", "2024-11-29 | \n", "1.000000 | \n", "2024-11-29 | \n", "2024-11-25 | \n", "73.34 | \n", "71.75 | \n", "71.329730 | \n", "75.703950 | \n", "47 | \n", "47 | \n", "
4 | \n", "2024-12-02 | \n", "0.853412 | \n", "2024-12-02 | \n", "2024-11-25 | \n", "72.89 | \n", "71.52 | \n", "71.720825 | \n", "76.264275 | \n", "48 | \n", "48 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
70 | \n", "2024-11-25 | \n", "0.118328 | \n", "2024-11-25 | \n", "2024-11-22 | \n", "74.83 | \n", "72.30 | \n", "74.530630 | \n", "76.673140 | \n", "47 | \n", "47 | \n", "
71 | \n", "2024-11-26 | \n", "0.000000 | \n", "2024-11-26 | \n", "2024-11-22 | \n", "73.80 | \n", "71.63 | \n", "74.440430 | \n", "76.874565 | \n", "47 | \n", "47 | \n", "
72 | \n", "2024-11-27 | \n", "0.000000 | \n", "2024-11-27 | \n", "2024-11-22 | \n", "72.85 | \n", "71.71 | \n", "74.663180 | \n", "76.734130 | \n", "47 | \n", "47 | \n", "
73 | \n", "2024-11-28 | \n", "0.000000 | \n", "2024-11-28 | \n", "2024-11-22 | \n", "72.96 | \n", "71.85 | \n", "74.708410 | \n", "77.141050 | \n", "47 | \n", "47 | \n", "
74 | \n", "2024-11-29 | \n", "0.000000 | \n", "2024-11-29 | \n", "2024-11-22 | \n", "73.34 | \n", "71.75 | \n", "74.703210 | \n", "77.746170 | \n", "47 | \n", "47 | \n", "
75 rows × 10 columns
\n", "