添加周度准确率到数据库
This commit is contained in:
parent
24174fa9a0
commit
620db6f65a
@ -243,7 +243,7 @@ print("数据库连接成功",host,dbname,dbusername)
|
|||||||
|
|
||||||
# 数据截取日期
|
# 数据截取日期
|
||||||
start_year = 2018 # 数据开始年份
|
start_year = 2018 # 数据开始年份
|
||||||
end_time = '2024-11-29' # 数据截取日期
|
end_time = '2024-12-04' # 数据截取日期
|
||||||
freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
|
freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
|
||||||
delweekenday = True if freq == 'B' else False # 是否删除周末数据
|
delweekenday = True if freq == 'B' else False # 是否删除周末数据
|
||||||
is_corr = False # 特征是否参与滞后领先提升相关系数
|
is_corr = False # 特征是否参与滞后领先提升相关系数
|
||||||
|
@ -9,7 +9,7 @@ torch.set_float32_matmul_precision("high")
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def predict_main():
|
def predict_main(end_time):
|
||||||
"""
|
"""
|
||||||
主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。
|
主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。
|
||||||
|
|
||||||
@ -202,7 +202,7 @@ def predict_main():
|
|||||||
logger.info('模型训练完成')
|
logger.info('模型训练完成')
|
||||||
|
|
||||||
logger.info('训练数据绘图ing')
|
logger.info('训练数据绘图ing')
|
||||||
model_results3 = model_losss(sqlitedb)
|
model_results3 = model_losss(sqlitedb,end_time=end_time)
|
||||||
logger.info('训练数据绘图end')
|
logger.info('训练数据绘图end')
|
||||||
|
|
||||||
# 模型报告
|
# 模型报告
|
||||||
@ -238,4 +238,8 @@ def predict_main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
predict_main()
|
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
|
||||||
|
for i_time in pd.date_range('2024-11-22', '2024-12-17', freq='B'):
|
||||||
|
end_time = i_time.strftime('%Y-%m-%d')
|
||||||
|
# print(e_time)
|
||||||
|
predict_main(end_time)
|
@ -222,7 +222,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
|
|||||||
|
|
||||||
|
|
||||||
# 原油计算预测评估指数
|
# 原油计算预测评估指数
|
||||||
def model_losss(sqlitedb):
|
def model_losss(sqlitedb,end_time):
|
||||||
global dataset
|
global dataset
|
||||||
global rote
|
global rote
|
||||||
most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]
|
most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]
|
||||||
@ -357,22 +357,21 @@ def model_losss(sqlitedb):
|
|||||||
|
|
||||||
def find_closest_values(row):
|
def find_closest_values(row):
|
||||||
x = row.y
|
x = row.y
|
||||||
if x is None:
|
if x is None or np.isnan(x):
|
||||||
return pd.Series([None, None], index=['min_within_quantile','max_within_quantile'])
|
return pd.Series([None, None], index=['min_price','max_price'])
|
||||||
row = row.drop('ds')
|
# row = row.drop('ds')
|
||||||
row = row[:,0].values.tolist()
|
row = row.values.tolist()
|
||||||
row = row.sort()
|
row.sort()
|
||||||
print(row)
|
print(row)
|
||||||
# x 在row中的索引
|
# x 在row中的索引
|
||||||
index = row.index(x)
|
index = row.index(x)
|
||||||
if index == 0:
|
if index == 0:
|
||||||
return pd.Series([row[index+1], row[index+2]], index=['min_within_quantile','max_within_quantile'])
|
return pd.Series([row[index+1], row[index+2]], index=['min_price','max_price'])
|
||||||
elif index == len(row)-1:
|
elif index == len(row)-1:
|
||||||
return pd.Series([row[index-2], row[index-1]], index=['min_within_quantile','max_within_quantile'])
|
return pd.Series([row[index-2], row[index-1]], index=['min_price','max_price'])
|
||||||
else:
|
else:
|
||||||
return pd.Series([row[index-1], row[index+1]], index=['min_within_quantile','max_within_quantile'])
|
return pd.Series([row[index-1], row[index+1]], index=['min_price','max_price'])
|
||||||
|
|
||||||
# df_combined3[['min_within_quantile','max_within_quantile']] = df_combined3.apply(find_closest_values, axis=1)
|
|
||||||
|
|
||||||
|
|
||||||
def find_most_common_model():
|
def find_most_common_model():
|
||||||
@ -406,14 +405,17 @@ def model_losss(sqlitedb):
|
|||||||
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
||||||
else:
|
else:
|
||||||
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
df_predict2['id'] = range(1, 1 + len(df_predict2))
|
||||||
|
# df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
|
||||||
|
df_predict2['CREAT_DATE'] = end_time
|
||||||
|
|
||||||
df_predict2['CREAT_DATE'] = now if end_time == '' else end_time
|
|
||||||
# df_predict2['PREDICT_DATE'] = df_predict2['ds']
|
# df_predict2['PREDICT_DATE'] = df_predict2['ds']
|
||||||
# df_predict2['MIN_PRICE'] = df_predict2['min_within_quantile']
|
# df_predict2['MIN_PRICE'] = df_predict2['min_within_quantile']
|
||||||
# df_predict2['MAX_PRICE'] = df_predict2['max_within_quantile']
|
# df_predict2['MAX_PRICE'] = df_predict2['max_within_quantile']
|
||||||
# df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']]
|
# df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']]
|
||||||
|
|
||||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||||
|
|
||||||
|
# 更新accuracy表中的y值
|
||||||
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
|
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
|
||||||
if len(update_y) > 0:
|
if len(update_y) > 0:
|
||||||
df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())]
|
df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())]
|
||||||
@ -423,7 +425,86 @@ def model_losss(sqlitedb):
|
|||||||
sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'")
|
sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'")
|
||||||
except:
|
except:
|
||||||
print(row)
|
print(row)
|
||||||
print(df_combined4)
|
# 准确率计算
|
||||||
|
predict_y = sqlitedb.select_data(table_name = "accuracy")
|
||||||
|
# 找到min_price为空的行的id
|
||||||
|
ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist()
|
||||||
|
# 计算min_price和max_price
|
||||||
|
# predict_y[['min_price','max_price']] = predict_y[['y']+allmodelnames].apply(find_closest_values, axis=1)
|
||||||
|
predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
|
||||||
|
# min_price为空的行更新数据到数据库中
|
||||||
|
for id in ids:
|
||||||
|
row = predict_y[predict_y['id'] == id]
|
||||||
|
print(row)
|
||||||
|
try:
|
||||||
|
sqlitedb.update_data('accuracy',f"min_price = {row['min_price'].values[0]},max_price = {row['max_price'].values[0]}",f"id = {id}")
|
||||||
|
except:
|
||||||
|
print(row)
|
||||||
|
|
||||||
|
# 最高最低价
|
||||||
|
xlsfilename = os.path.join(dataset,'数据项下载.xls')
|
||||||
|
df2 = pd.read_excel(xlsfilename)[5:]
|
||||||
|
df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'LOW_PRICE','布伦特最高价':'HIGH_PRICE'})
|
||||||
|
print(df2.shape)
|
||||||
|
df = pd.merge(predict_y,df2,on=['ds'],how='left')
|
||||||
|
df['ds'] = pd.to_datetime(df['ds'])
|
||||||
|
# df['PREDICT_DATE'] = pd.to_datetime(df['PREDICT_DATE'])
|
||||||
|
df = df.reindex()
|
||||||
|
print(df.shape)
|
||||||
|
|
||||||
|
df.to_csv(os.path.join(dataset,'123.csv'))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_accuracy(row):
|
||||||
|
if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']:
|
||||||
|
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
|
||||||
|
middle_diff = sorted_prices[2] - sorted_prices[1]
|
||||||
|
price_range = row['HIGH_PRICE'] - row['LOW_PRICE']
|
||||||
|
accuracy = middle_diff / price_range
|
||||||
|
return accuracy
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 使用 apply 函数来应用计算准确率的函数
|
||||||
|
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
|
||||||
|
df[columns] = df[columns].astype(float)
|
||||||
|
df['ACCURACY'] = df.apply(calculate_accuracy, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
# 取结束日期上一周的日期
|
||||||
|
endtime = end_time
|
||||||
|
endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d')
|
||||||
|
up_week = endtimeweek - datetime.timedelta(days=endtimeweek.weekday() + 14)
|
||||||
|
up_week_dates = [up_week + datetime.timedelta(days=i) for i in range(14)][4:-2]
|
||||||
|
up_week_dates = [date.strftime('%Y-%m-%d') for date in up_week_dates]
|
||||||
|
print(up_week_dates)
|
||||||
|
|
||||||
|
|
||||||
|
df3 = df.copy()
|
||||||
|
df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)]
|
||||||
|
df3 = df3[df3['ds'].isin(up_week_dates)]
|
||||||
|
print(df3.shape)
|
||||||
|
df3.to_csv('up_week_dates.csv',index=False)
|
||||||
|
|
||||||
|
total = len(df3)
|
||||||
|
accuracy_rote = 0
|
||||||
|
# for i,group in df3.groupby('CREAT_DATE'):
|
||||||
|
for i,group in df3.groupby('ds'):
|
||||||
|
print(i)
|
||||||
|
print('权重:',round(len(group)/total,2))
|
||||||
|
print('准确率:',group['ACCURACY'].sum()/(len(group)/total))
|
||||||
|
accuracy_rote += group['ACCURACY'].sum()/(len(group)/total)
|
||||||
|
|
||||||
|
print(accuracy_rote)
|
||||||
|
# 列 开始日期,结束日期,准确率
|
||||||
|
df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率'])
|
||||||
|
# 追加一行数据
|
||||||
|
df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote}
|
||||||
|
print(df4)
|
||||||
|
df4.to_sql("accuracy_rote", con=sqlitedb.connection, if_exists='append', index=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _add_abs_error_rate():
|
def _add_abs_error_rate():
|
||||||
|
@ -1,16 +1 @@
|
|||||||
ds,ACCURACY,PREDICT_DATE,CREAT_DATE,HIGH_PRICE_y,LOW_PRICE_y,MIN_PRICE,MAX_PRICE,Ds_Week,Pre_Week
|
ds,NHITS,Informer,LSTM,iTransformer,TSMixer,TSMixerx,PatchTST,RNN,GRU,TCN,BiTCN,DilatedRNN,MLP,DLinear,NLinear,TFT,StemGNN,MLPMultivariate,TiDE,DeepNPTS,y,min_within_quantile,max_within_quantile,min_price,max_price,id,CREAT_DATE,序号,LOW_PRICE,HIGH_PRICE,ACCURACY
|
||||||
2024-11-26,1.0,2024-11-26,2024-11-25,73.8,71.63,71.071556,76.0069,47,47
|
|
||||||
2024-11-27,1.0,2024-11-27,2024-11-25,72.85,71.71,71.003624,75.58056,47,47
|
|
||||||
2024-11-28,0.7893243243243208,2024-11-28,2024-11-25,72.96,71.85,72.08385,76.20426,47,47
|
|
||||||
2024-11-29,1.0,2024-11-29,2024-11-25,73.34,71.75,71.32973,75.70395,47,47
|
|
||||||
2024-11-27,0.6068728070175414,2024-11-27,2024-11-26,72.85,71.71,72.158165,76.17365,47,47
|
|
||||||
2024-11-28,0.8021441441441385,2024-11-28,2024-11-26,72.96,71.85,72.06962,76.447,47,47
|
|
||||||
2024-11-29,0.6082389937106918,2024-11-29,2024-11-26,73.34,71.75,72.3729,76.08291,47,47
|
|
||||||
2024-11-28,1.0,2024-11-28,2024-11-27,72.96,71.85,70.70975,75.20803,47,47
|
|
||||||
2024-11-29,1.0,2024-11-29,2024-11-27,73.34,71.75,69.92311,75.423775,47,47
|
|
||||||
2024-11-29,1.0,2024-11-29,2024-11-28,73.34,71.75,70.43276,75.48062,47,47
|
|
||||||
2024-11-25,0.11832806324110515,2024-11-25,2024-11-22,74.83,72.3,74.53063,76.67314,47,47
|
|
||||||
2024-11-26,0.0,2024-11-26,2024-11-22,73.8,71.63,74.44043,76.874565,47,47
|
|
||||||
2024-11-27,0.0,2024-11-27,2024-11-22,72.85,71.71,74.66318,76.73413,47,47
|
|
||||||
2024-11-28,0.0,2024-11-28,2024-11-22,72.96,71.85,74.70841,77.14105,47,47
|
|
||||||
2024-11-29,0.0,2024-11-29,2024-11-22,73.34,71.75,74.70321,77.74617,47,47
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user