绘图五个模型均值+上周准确率在报告中显示
This commit is contained in:
parent
620db6f65a
commit
6e7f425382
@ -239,7 +239,7 @@ def predict_main(end_time):
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
|
||||
for i_time in pd.date_range('2024-11-22', '2024-12-17', freq='B'):
|
||||
for i_time in pd.date_range('2024-12-16', '2024-12-17', freq='B'):
|
||||
end_time = i_time.strftime('%Y-%m-%d')
|
||||
# print(e_time)
|
||||
predict_main(end_time)
|
@ -344,14 +344,18 @@ def model_losss(sqlitedb,end_time):
|
||||
names_df['columns'] = names_df.apply(add_rote_column, axis=1)
|
||||
|
||||
def add_upper_lower_bound(row):
|
||||
print(row['columns'])
|
||||
print(type(row['columns']))
|
||||
|
||||
# 计算上边界值
|
||||
upper_bound = df_combined3.loc[row.name,row['columns']].max()
|
||||
upper_bound = row.max()
|
||||
# 计算下边界值
|
||||
lower_bound = df_combined3.loc[row.name,row['columns']].min()
|
||||
lower_bound = row.min()
|
||||
return pd.Series([lower_bound, upper_bound], index=['min_within_quantile', 'max_within_quantile'])
|
||||
df_combined3[['min_within_quantile','max_within_quantile']] = names_df.apply(add_upper_lower_bound, axis=1)
|
||||
|
||||
# df_combined3[['min_within_quantile','max_within_quantile']] = names_df.apply(add_upper_lower_bound, axis=1)
|
||||
|
||||
# 取前五最佳模型的最大最小值作为上下边界值
|
||||
df_combined3[['min_within_quantile','max_within_quantile']]= df_combined3[modelnames].apply(add_upper_lower_bound, axis=1)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -407,12 +411,6 @@ def model_losss(sqlitedb,end_time):
|
||||
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
||||
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
|
||||
df_predict2['CREAT_DATE'] = end_time
|
||||
|
||||
# df_predict2['PREDICT_DATE'] = df_predict2['ds']
|
||||
# df_predict2['MIN_PRICE'] = df_predict2['min_within_quantile']
|
||||
# df_predict2['MAX_PRICE'] = df_predict2['max_within_quantile']
|
||||
# df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']]
|
||||
|
||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||
|
||||
# 更新accuracy表中的y值
|
||||
@ -424,37 +422,25 @@ def model_losss(sqlitedb,end_time):
|
||||
try:
|
||||
sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'")
|
||||
except:
|
||||
print(row)
|
||||
# 准确率计算
|
||||
logger.error(f'更新accuracy表中的y值失败,row={row}')
|
||||
# 上周准确率计算
|
||||
predict_y = sqlitedb.select_data(table_name = "accuracy")
|
||||
# 找到min_price为空的行的id
|
||||
ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
|
||||
# 计算min_price和max_price
|
||||
# predict_y[['min_price','max_price']] = predict_y[['y']+allmodelnames].apply(find_closest_values, axis=1)
|
||||
predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
|
||||
# min_price为空的行更新数据到数据库中
|
||||
for id in ids:
|
||||
row = predict_y[predict_y['id'] == id]
|
||||
print(row)
|
||||
try:
|
||||
sqlitedb.update_data('accuracy',f"min_price = {row['min_price'].values[0]},max_price = {row['max_price'].values[0]}",f"id = {id}")
|
||||
except:
|
||||
print(row)
|
||||
|
||||
# 最高最低价
|
||||
logger.error(f'更新accuracy表中的min_price,max_price值失败,row={row}')
|
||||
# 拼接市场最高最低价
|
||||
xlsfilename = os.path.join(dataset,'数据项下载.xls')
|
||||
df2 = pd.read_excel(xlsfilename)[5:]
|
||||
df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'LOW_PRICE','布伦特最高价':'HIGH_PRICE'})
|
||||
print(df2.shape)
|
||||
df = pd.merge(predict_y,df2,on=['ds'],how='left')
|
||||
df['ds'] = pd.to_datetime(df['ds'])
|
||||
# df['PREDICT_DATE'] = pd.to_datetime(df['PREDICT_DATE'])
|
||||
df = df.reindex()
|
||||
print(df.shape)
|
||||
|
||||
df.to_csv(os.path.join(dataset,'123.csv'))
|
||||
|
||||
|
||||
def calculate_accuracy(row):
|
||||
if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']:
|
||||
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
|
||||
@ -464,49 +450,30 @@ def model_losss(sqlitedb,end_time):
|
||||
return accuracy
|
||||
else:
|
||||
return 0
|
||||
|
||||
# 使用 apply 函数来应用计算准确率的函数
|
||||
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
|
||||
df[columns] = df[columns].astype(float)
|
||||
df['ACCURACY'] = df.apply(calculate_accuracy, axis=1)
|
||||
|
||||
|
||||
# 取结束日期上一周的日期
|
||||
endtime = end_time
|
||||
endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d')
|
||||
up_week = endtimeweek - datetime.timedelta(days=endtimeweek.weekday() + 14)
|
||||
up_week_dates = [up_week + datetime.timedelta(days=i) for i in range(14)][4:-2]
|
||||
up_week_dates = [date.strftime('%Y-%m-%d') for date in up_week_dates]
|
||||
print(up_week_dates)
|
||||
|
||||
|
||||
#
|
||||
df3 = df.copy()
|
||||
df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)]
|
||||
df3 = df3[df3['ds'].isin(up_week_dates)]
|
||||
print(df3.shape)
|
||||
df3.to_csv('up_week_dates.csv',index=False)
|
||||
|
||||
total = len(df3)
|
||||
accuracy_rote = 0
|
||||
# for i,group in df3.groupby('CREAT_DATE'):
|
||||
for i,group in df3.groupby('ds'):
|
||||
print(i)
|
||||
print('权重:',round(len(group)/total,2))
|
||||
print('准确率:',group['ACCURACY'].sum()/(len(group)/total))
|
||||
accuracy_rote += group['ACCURACY'].sum()/(len(group)/total)
|
||||
|
||||
print(accuracy_rote)
|
||||
# 列 开始日期,结束日期,准确率
|
||||
df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率'])
|
||||
# 追加一行数据
|
||||
df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote}
|
||||
print(df4)
|
||||
df4.to_sql("accuracy_rote", con=sqlitedb.connection, if_exists='append', index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _add_abs_error_rate():
|
||||
# 计算每个预测值与真实值之间的偏差率
|
||||
for model in allmodelnames:
|
||||
@ -540,6 +507,8 @@ def model_losss(sqlitedb,end_time):
|
||||
# 历史价格+预测价格
|
||||
sqlitedb.drop_table('testandpredict_groupby')
|
||||
df_combined3.to_sql('testandpredict_groupby',sqlitedb.connection,index=False)
|
||||
# 新增均值列
|
||||
df_combined3['mean'] = df_combined3[modelnames].mean(axis=1)
|
||||
|
||||
def _plt_predict_ture(df):
|
||||
lens = df.shape[0] if df.shape[0] < 180 else 90
|
||||
@ -547,13 +516,15 @@ def model_losss(sqlitedb,end_time):
|
||||
# 历史价格
|
||||
plt.figure(figsize=(20, 10))
|
||||
plt.plot(df['ds'], df['y'], label='真实值')
|
||||
# 均值线
|
||||
plt.plot(df['ds'], df['mean'], color='r', linestyle='--', label='前五模型预测均值')
|
||||
# 颜色填充
|
||||
plt.fill_between(df['ds'], df['max_within_quantile'], df['min_within_quantile'], alpha=0.2)
|
||||
# markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd']
|
||||
# random_marker = random.choice(markers)
|
||||
# for model in allmodelnames:
|
||||
markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd']
|
||||
random_marker = random.choice(markers)
|
||||
for model in modelnames:
|
||||
# for model in ['BiTCN','RNN']:
|
||||
# plt.plot(df['ds'], df[model], label=model,marker=random_marker)
|
||||
plt.plot(df['ds'][-horizon:], df[model][-horizon:], label=model,marker=random_marker)
|
||||
# plt.plot(df_combined3['ds'], df_combined3['min_abs_error_rate_prediction'], label='最小绝对误差', linestyle='--', color='orange')
|
||||
# 网格
|
||||
plt.grid(True)
|
||||
@ -561,8 +532,8 @@ def model_losss(sqlitedb,end_time):
|
||||
for i, j in zip(df['ds'], df['y']):
|
||||
plt.text(i, j, str(j), ha='center', va='bottom')
|
||||
|
||||
for model in most_model:
|
||||
plt.plot(df['ds'], df[model], label=model,marker='o')
|
||||
# for model in most_model:
|
||||
# plt.plot(df['ds'], df[model], label=model,marker='o')
|
||||
# 当前日期画竖虚线
|
||||
plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--')
|
||||
plt.legend()
|
||||
@ -1007,6 +978,14 @@ def brent_export_pdf(num_indicators=475,num_models=21, num_dayindicator=202,inpu
|
||||
content.append(Graphs.draw_table(col_width,*data))
|
||||
|
||||
|
||||
content.append(Graphs.draw_little_title('上一周预测准确率:'))
|
||||
df4 = sqlitedb.select_data('accuracy_rote',order_by='结束日期 desc',limit=1)
|
||||
df4 = df4.T
|
||||
df4 = df4.reset_index()
|
||||
data = df4.values.tolist()
|
||||
col_width = 500/len(df4.columns)
|
||||
content.append(Graphs.draw_table(col_width,*data))
|
||||
|
||||
content.append(Graphs.draw_little_title('三、预测过程解析:'))
|
||||
### 特征、模型、参数配置
|
||||
content.append(Graphs.draw_little_title('模型选择:'))
|
||||
|
@ -1 +1,17 @@
|
||||
ds,NHITS,Informer,LSTM,iTransformer,TSMixer,TSMixerx,PatchTST,RNN,GRU,TCN,BiTCN,DilatedRNN,MLP,DLinear,NLinear,TFT,StemGNN,MLPMultivariate,TiDE,DeepNPTS,y,min_within_quantile,max_within_quantile,min_price,max_price,id,CREAT_DATE,序号,LOW_PRICE,HIGH_PRICE,ACCURACY
|
||||
2024-12-09,71.63069,72.14122,70.19943,69.644196,71.80898,70.52471,71.53064,67.38501,70.75815,73.546684,71.92033,70.327484,71.95594,71.18323,71.882935,71.90951,73.94965,72.28691,71.25138,73.34018,72.13999938964844,67.38501,73.94965,67.38501,73.94965,51,2024-12-06,5.0,70.92,72.65,1.0
|
||||
2024-12-10,71.14343,71.9462,70.405106,69.48242,71.70601,70.66241,71.57484,67.23587,70.46323,73.37324,71.720894,70.45846,71.88132,72.34705,71.75997,72.41326,73.74943,72.31887,71.47953,72.78831,72.19000244140625,67.23587,73.74943,67.23587,73.74943,52,2024-12-06,4.0,70.73,71.77,1.0
|
||||
2024-12-11,71.71588,72.31544,70.175125,69.58213,71.59609,70.91783,71.54794,67.433334,70.518196,73.76477,71.84062,70.746284,72.27111,71.85789,70.77939,72.912704,73.91716,72.42111,71.47695,72.61624,73.5199966430664,67.433334,73.91716,67.433334,73.91716,53,2024-12-06,3.0,72.15,73.75,1.0
|
||||
2024-12-12,72.46348,71.87648,70.26041,69.922165,71.65103,70.689384,71.72716,67.54506,70.99872,73.52567,71.78495,70.777115,72.34328,72.756325,70.9607,73.391495,73.944244,72.465836,71.445244,71.69109,73.88999938964844,67.54506,73.944244,67.54506,73.944244,54,2024-12-06,2.0,72.42,74.0,0.9647113924050618
|
||||
2024-12-13,72.85073,72.36679,70.489136,69.759766,71.78641,70.69935,71.60861,67.47295,70.81146,73.85618,71.966835,70.923485,72.63866,72.29209,71.11011,73.777534,73.82516,72.58803,71.807915,72.6083,,67.47295,73.85618,67.47295,73.85618,55,2024-12-06,1.0,73.3,74.59,0.43114728682170156
|
||||
2024-12-10,71.97169,73.11586,71.41545,70.59598,71.81998,71.541016,72.397606,68.78458,71.30409,73.65467,71.78349,71.03456,71.8188,71.85696,72.014534,72.33824,73.77812,72.2061,71.36985,72.46154,72.19000244140625,68.78458,73.77812,68.78458,73.77812,56,2024-12-09,4.0,70.73,71.77,1.0
|
||||
2024-12-11,72.35509,73.087166,71.63927,70.430176,71.82658,71.76624,72.57018,68.59579,71.03884,73.45442,71.75791,71.15297,72.066956,72.28318,72.36303,72.75233,73.594864,72.24322,71.16944,72.39418,73.5199966430664,68.59579,73.594864,68.59579,73.594864,57,2024-12-09,3.0,72.15,73.75,0.9030400000000004
|
||||
2024-12-12,73.268654,73.31714,71.39498,70.65145,71.76794,71.6176,73.038315,68.85988,71.12751,73.84763,71.82094,71.41621,72.99199,72.42496,71.474464,73.17588,73.74838,72.31574,71.692444,72.48112,73.88999938964844,68.85988,73.84763,68.85988,73.84763,58,2024-12-09,2.0,72.42,74.0,0.9035632911392374
|
||||
2024-12-13,73.11824,73.70929,71.442184,70.72518,71.878975,71.82435,73.13541,68.760376,71.51967,73.62109,71.86448,71.44768,72.952484,72.107254,72.57947,73.61362,73.77317,72.3857,71.873566,72.40313,,68.760376,73.77317,68.760376,73.77317,59,2024-12-09,1.0,73.3,74.59,0.36679844961239827
|
||||
2024-12-11,72.372795,72.6546,71.37516,71.41169,71.977135,72.13211,72.774284,69.25488,71.62906,73.70365,72.27147,71.489716,72.77891,71.61932,72.090294,72.166916,73.72794,72.78662,71.960884,72.473854,73.5199966430664,69.25488,73.72794,69.25488,73.72794,61,2024-12-10,3.0,72.15,73.75,0.9862125000000024
|
||||
2024-12-12,72.91286,72.80857,71.612946,71.193184,72.05397,72.48625,72.85768,69.06833,71.35641,73.5444,72.19988,71.585815,73.47307,72.22734,72.219765,72.48997,73.544106,72.96339,71.66363,72.122116,73.88999938964844,69.06833,73.5444,69.06833,73.5444,62,2024-12-10,2.0,72.42,74.0,0.7116455696202503
|
||||
2024-12-13,73.01643,72.78796,71.3701,71.516174,72.21711,72.4036,73.66599,69.338615,71.467476,73.9247,72.2582,71.84913,73.643265,71.70761,72.41268,72.85244,73.69811,72.91225,71.75731,71.76489,,69.338615,73.9247,69.338615,73.9247,63,2024-12-10,1.0,73.3,74.59,0.4842635658914738
|
||||
2024-12-12,73.82699,73.58274,72.81162,73.18572,72.8827,72.958374,74.36039,71.09452,72.673904,74.025635,73.17438,72.72577,73.668365,72.48243,73.224655,74.286606,73.70824,73.668076,72.934685,73.16352,73.88999938964844,71.09452,74.36039,71.09452,74.36039,66,2024-12-11,2.0,72.42,74.0,1.0
|
||||
2024-12-13,74.025696,73.275696,73.0588,72.85848,72.719444,73.238945,74.05083,71.060295,72.455666,73.90173,72.966385,72.81345,73.77382,72.488365,73.78919,74.5482,73.52537,73.58226,72.8307,73.26384,,71.060295,74.5482,71.060295,74.5482,67,2024-12-11,1.0,73.3,74.59,0.9675968992247993
|
||||
2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,71,2024-12-12,1.0,73.3,74.59,0.877170542635658
|
||||
2024-12-13,73.624176,73.15132,73.069374,73.517944,73.56706,74.08185,73.880775,71.83479,73.23418,74.28763,73.372154,73.431366,73.51813,73.90266,73.87494,74.43155,73.691505,73.46715,73.56533,73.50562,,71.83479,74.43155,71.83479,74.43155,76,2024-12-13,1.0,73.3,74.59,0.877170542635658
|
||||
|
|
Loading…
Reference in New Issue
Block a user