From 6e7f425382213d49af80dbb578062563fd2ed239 Mon Sep 17 00:00:00 2001 From: liurui Date: Sun, 15 Dec 2024 23:23:21 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=98=E5=9B=BE=E4=BA=94=E4=B8=AA=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=9D=87=E5=80=BC+=E4=B8=8A=E5=91=A8=E5=87=86?= =?UTF-8?q?=E7=A1=AE=E7=8E=87=E5=9C=A8=E6=8A=A5=E5=91=8A=E4=B8=AD=E6=98=BE?= =?UTF-8?q?=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_yuanyou.py | 2 +- models/nerulforcastmodels.py | 87 ++++++++++++++---------------------- up_week_dates.csv | 16 +++++++ 3 files changed, 50 insertions(+), 55 deletions(-) diff --git a/main_yuanyou.py b/main_yuanyou.py index 7ce141a..87a8c4b 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -239,7 +239,7 @@ def predict_main(end_time): if __name__ == '__main__': # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2024-11-22', '2024-12-17', freq='B'): + for i_time in pd.date_range('2024-12-16', '2024-12-17', freq='B'): end_time = i_time.strftime('%Y-%m-%d') # print(e_time) predict_main(end_time) \ No newline at end of file diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 54d3231..103d121 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -344,14 +344,18 @@ def model_losss(sqlitedb,end_time): names_df['columns'] = names_df.apply(add_rote_column, axis=1) def add_upper_lower_bound(row): - print(row['columns']) - print(type(row['columns'])) + # 计算上边界值 - upper_bound = df_combined3.loc[row.name,row['columns']].max() + upper_bound = row.max() # 计算下边界值 - lower_bound = df_combined3.loc[row.name,row['columns']].min() + lower_bound = row.min() return pd.Series([lower_bound, upper_bound], index=['min_within_quantile', 'max_within_quantile']) - df_combined3[['min_within_quantile','max_within_quantile']] = names_df.apply(add_upper_lower_bound, axis=1) + + # df_combined3[['min_within_quantile','max_within_quantile']] = names_df.apply(add_upper_lower_bound, axis=1) + + # 取前五最佳模型的最大最小值作为上下边界值 + df_combined3[['min_within_quantile','max_within_quantile']]= df_combined3[modelnames].apply(add_upper_lower_bound, axis=1) + @@ -407,12 +411,6 @@ def model_losss(sqlitedb,end_time): df_predict2['id'] = range(1, 1 + len(df_predict2)) # df_predict2['CREAT_DATE'] = now if end_time == '' else end_time df_predict2['CREAT_DATE'] = end_time - - # df_predict2['PREDICT_DATE'] = df_predict2['ds'] - # df_predict2['MIN_PRICE'] = df_predict2['min_within_quantile'] - # df_predict2['MAX_PRICE'] = df_predict2['max_within_quantile'] - # df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']] - df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) # 更新accuracy表中的y值 @@ -424,37 +422,25 @@ def model_losss(sqlitedb,end_time): try: sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'") except: - print(row) - # 准确率计算 + logger.error(f'更新accuracy表中的y值失败,row={row}') + # 上周准确率计算 predict_y = sqlitedb.select_data(table_name = "accuracy") - # 找到min_price为空的行的id ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist() - # 计算min_price和max_price - # predict_y[['min_price','max_price']] = predict_y[['y']+allmodelnames].apply(find_closest_values, axis=1) predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']] - # min_price为空的行更新数据到数据库中 for id in ids: row = predict_y[predict_y['id'] == id] - print(row) try: sqlitedb.update_data('accuracy',f"min_price = {row['min_price'].values[0]},max_price = {row['max_price'].values[0]}",f"id = {id}") except: - print(row) - - # 最高最低价 + logger.error(f'更新accuracy表中的min_price,max_price值失败,row={row}') + # 拼接市场最高最低价 xlsfilename = os.path.join(dataset,'数据项下载.xls') df2 = pd.read_excel(xlsfilename)[5:] df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'LOW_PRICE','布伦特最高价':'HIGH_PRICE'}) print(df2.shape) df = pd.merge(predict_y,df2,on=['ds'],how='left') df['ds'] = pd.to_datetime(df['ds']) - # df['PREDICT_DATE'] = pd.to_datetime(df['PREDICT_DATE']) df = df.reindex() - print(df.shape) - - df.to_csv(os.path.join(dataset,'123.csv')) - - def calculate_accuracy(row): if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']: sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']]) @@ -464,49 +450,30 @@ def model_losss(sqlitedb,end_time): return accuracy else: return 0 - - # 使用 apply 函数来应用计算准确率的函数 columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price'] df[columns] = df[columns].astype(float) df['ACCURACY'] = df.apply(calculate_accuracy, axis=1) - - # 取结束日期上一周的日期 endtime = end_time endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d') up_week = endtimeweek - datetime.timedelta(days=endtimeweek.weekday() + 14) up_week_dates = [up_week + datetime.timedelta(days=i) for i in range(14)][4:-2] up_week_dates = [date.strftime('%Y-%m-%d') for date in up_week_dates] - print(up_week_dates) - - + # df3 = df.copy() df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)] df3 = df3[df3['ds'].isin(up_week_dates)] - print(df3.shape) - df3.to_csv('up_week_dates.csv',index=False) - total = len(df3) accuracy_rote = 0 - # for i,group in df3.groupby('CREAT_DATE'): for i,group in df3.groupby('ds'): - print(i) print('权重:',round(len(group)/total,2)) print('准确率:',group['ACCURACY'].sum()/(len(group)/total)) accuracy_rote += group['ACCURACY'].sum()/(len(group)/total) - - print(accuracy_rote) - # 列 开始日期,结束日期,准确率 df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率']) - # 追加一行数据 df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote} print(df4) df4.to_sql("accuracy_rote", con=sqlitedb.connection, if_exists='append', index=False) - - - - def _add_abs_error_rate(): # 计算每个预测值与真实值之间的偏差率 for model in allmodelnames: @@ -540,20 +507,24 @@ def model_losss(sqlitedb,end_time): # 历史价格+预测价格 sqlitedb.drop_table('testandpredict_groupby') df_combined3.to_sql('testandpredict_groupby',sqlitedb.connection,index=False) - + # 新增均值列 + df_combined3['mean'] = df_combined3[modelnames].mean(axis=1) + def _plt_predict_ture(df): lens = df.shape[0] if df.shape[0] < 180 else 90 df = df[-lens:] # 取180个数据点画图 # 历史价格 plt.figure(figsize=(20, 10)) plt.plot(df['ds'], df['y'], label='真实值') + # 均值线 + plt.plot(df['ds'], df['mean'], color='r', linestyle='--', label='前五模型预测均值') # 颜色填充 plt.fill_between(df['ds'], df['max_within_quantile'], df['min_within_quantile'], alpha=0.2) - # markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd'] - # random_marker = random.choice(markers) - # for model in allmodelnames: + markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd'] + random_marker = random.choice(markers) + for model in modelnames: # for model in ['BiTCN','RNN']: - # plt.plot(df['ds'], df[model], label=model,marker=random_marker) + plt.plot(df['ds'][-horizon:], df[model][-horizon:], label=model,marker=random_marker) # plt.plot(df_combined3['ds'], df_combined3['min_abs_error_rate_prediction'], label='最小绝对误差', linestyle='--', color='orange') # 网格 plt.grid(True) @@ -561,8 +532,8 @@ def model_losss(sqlitedb,end_time): for i, j in zip(df['ds'], df['y']): plt.text(i, j, str(j), ha='center', va='bottom') - for model in most_model: - plt.plot(df['ds'], df[model], label=model,marker='o') + # for model in most_model: + # plt.plot(df['ds'], df[model], label=model,marker='o') # 当前日期画竖虚线 plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--') plt.legend() @@ -1006,6 +977,14 @@ def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inpu col_width = 500/len(df3.columns) content.append(Graphs.draw_table(col_width,*data)) + + content.append(Graphs.draw_little_title('上一周预测准确率:')) + df4 = sqlitedb.select_data('accuracy_rote',order_by='结束日期 desc',limit=1) + df4 = df4.T + df4 = df4.reset_index() + data = df4.values.tolist() + col_width = 500/len(df4.columns) + content.append(Graphs.draw_table(col_width,*data)) content.append(Graphs.draw_little_title('三、预测过程解析:')) ### 特征、模型、参数配置 diff --git a/up_week_dates.csv b/up_week_dates.csv index 40e7294..919e1c3 100644 --- a/up_week_dates.csv +++ b/up_week_dates.csv @@ -1 +1,17 @@ ds,NHITS,Informer,LSTM,iTransformer,TSMixer,TSMixerx,PatchTST,RNN,GRU,TCN,BiTCN,DilatedRNN,MLP,DLinear,NLinear,TFT,StemGNN,MLPMultivariate,TiDE,DeepNPTS,y,min_within_quantile,max_within_quantile,min_price,max_price,id,CREAT_DATE,序号,LOW_PRICE,HIGH_PRICE,ACCURACY +2024-12-09,71.63069,72.14122,70.19943,69.644196,71.80898,70.52471,71.53064,67.38501,70.75815,73.546684,71.92033,70.327484,71.95594,71.18323,71.882935,71.90951,73.94965,72.28691,71.25138,73.34018,72.13999938964844,67.38501,73.94965,67.38501,73.94965,51,2024-12-06,5.0,70.92,72.65,1.0 +2024-12-10,71.14343,71.9462,70.405106,69.48242,71.70601,70.66241,71.57484,67.23587,70.46323,73.37324,71.720894,70.45846,71.88132,72.34705,71.75997,72.41326,73.74943,72.31887,71.47953,72.78831,72.19000244140625,67.23587,73.74943,67.23587,73.74943,52,2024-12-06,4.0,70.73,71.77,1.0 +2024-12-11,71.71588,72.31544,70.175125,69.58213,71.59609,70.91783,71.54794,67.433334,70.518196,73.76477,71.84062,70.746284,72.27111,71.85789,70.77939,72.912704,73.91716,72.42111,71.47695,72.61624,73.5199966430664,67.433334,73.91716,67.433334,73.91716,53,2024-12-06,3.0,72.15,73.75,1.0 +2024-12-12,72.46348,71.87648,70.26041,69.922165,71.65103,70.689384,71.72716,67.54506,70.99872,73.52567,71.78495,70.777115,72.34328,72.756325,70.9607,73.391495,73.944244,72.465836,71.445244,71.69109,73.88999938964844,67.54506,73.944244,67.54506,73.944244,54,2024-12-06,2.0,72.42,74.0,0.9647113924050618 +2024-12-13,72.85073,72.36679,70.489136,69.759766,71.78641,70.69935,71.60861,67.47295,70.81146,73.85618,71.966835,70.923485,72.63866,72.29209,71.11011,73.777534,73.82516,72.58803,71.807915,72.6083,,67.47295,73.85618,67.47295,73.85618,55,2024-12-06,1.0,73.3,74.59,0.43114728682170156 +2024-12-10,71.97169,73.11586,71.41545,70.59598,71.81998,71.541016,72.397606,68.78458,71.30409,73.65467,71.78349,71.03456,71.8188,71.85696,72.014534,72.33824,73.77812,72.2061,71.36985,72.46154,72.19000244140625,68.78458,73.77812,68.78458,73.77812,56,2024-12-09,4.0,70.73,71.77,1.0 +2024-12-11,72.35509,73.087166,71.63927,70.430176,71.82658,71.76624,72.57018,68.59579,71.03884,73.45442,71.75791,71.15297,72.066956,72.28318,72.36303,72.75233,73.594864,72.24322,71.16944,72.39418,73.5199966430664,68.59579,73.594864,68.59579,73.594864,57,2024-12-09,3.0,72.15,73.75,0.9030400000000004 +2024-12-12,73.268654,73.31714,71.39498,70.65145,71.76794,71.6176,73.038315,68.85988,71.12751,73.84763,71.82094,71.41621,72.99199,72.42496,71.474464,73.17588,73.74838,72.31574,71.692444,72.48112,73.88999938964844,68.85988,73.84763,68.85988,73.84763,58,2024-12-09,2.0,72.42,74.0,0.9035632911392374 +2024-12-13,73.11824,73.70929,71.442184,70.72518,71.878975,71.82435,73.13541,68.760376,71.51967,73.62109,71.86448,71.44768,72.952484,72.107254,72.57947,73.61362,73.77317,72.3857,71.873566,72.40313,,68.760376,73.77317,68.760376,73.77317,59,2024-12-09,1.0,73.3,74.59,0.36679844961239827 +2024-12-11,72.372795,72.6546,71.37516,71.41169,71.977135,72.13211,72.774284,69.25488,71.62906,73.70365,72.27147,71.489716,72.77891,71.61932,72.090294,72.166916,73.72794,72.78662,71.960884,72.473854,73.5199966430664,69.25488,73.72794,69.25488,73.72794,61,2024-12-10,3.0,72.15,73.75,0.9862125000000024 +2024-12-12,72.91286,72.80857,71.612946,71.193184,72.05397,72.48625,72.85768,69.06833,71.35641,73.5444,72.19988,71.585815,73.47307,72.22734,72.219765,72.48997,73.544106,72.96339,71.66363,72.122116,73.88999938964844,69.06833,73.5444,69.06833,73.5444,62,2024-12-10,2.0,72.42,74.0,0.7116455696202503 +2024-12-13,73.01643,72.78796,71.3701,71.516174,72.21711,72.4036,73.66599,69.338615,71.467476,73.9247,72.2582,71.84913,73.643265,71.70761,72.41268,72.85244,73.69811,72.91225,71.75731,71.76489,,69.338615,73.9247,69.338615,73.9247,63,2024-12-10,1.0,73.3,74.59,0.4842635658914738 +2024-12-12,73.82699,73.58274,72.81162,73.18572,72.8827,72.958374,74.36039,71.09452,72.673904,74.025635,73.17438,72.72577,73.668365,72.48243,73.224655,74.286606,73.70824,73.668076,72.934685,73.16352,73.88999938964844,71.09452,74.36039,71.09452,74.36039,66,2024-12-11,2.0,72.42,74.0,1.0 +2024-12-13,74.025696,73.275696,73.0588,72.85848,72.719444,73.238945,74.05083,71.060295,72.455666,73.90173,72.966385,72.81345,73.77382,72.488365,73.78919,74.5482,73.52537,73.58226,72.8307,73.26384,,71.060295,74.5482,71.060295,74.5482,67,2024-12-11,1.0,73.3,74.59,0.9675968992247993 +2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,71,2024-12-12,1.0,73.3,74.59,0.877170542635658 +2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,76,2024-12-13,1.0,73.3,74.59,0.877170542635658