From 620db6f65af15609b1101ab7b51e215cc75172a2 Mon Sep 17 00:00:00 2001 From: liurui Date: Sun, 15 Dec 2024 19:43:14 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=91=A8=E5=BA=A6=E5=87=86?= =?UTF-8?q?=E7=A1=AE=E7=8E=87=E5=88=B0=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_jingbo.py | 2 +- main_yuanyou.py | 10 +++- models/nerulforcastmodels.py | 109 ++++++++++++++++++++++++++++++----- up_week_dates.csv | 17 +----- 4 files changed, 104 insertions(+), 34 deletions(-) diff --git a/config_jingbo.py b/config_jingbo.py index d2c6cbd..0fa4074 100644 --- a/config_jingbo.py +++ b/config_jingbo.py @@ -243,7 +243,7 @@ print("数据库连接成功",host,dbname,dbusername) # 数据截取日期 start_year = 2018 # 数据开始年份 -end_time = '2024-11-29' # 数据截取日期 +end_time = '2024-12-04' # 数据截取日期 freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 delweekenday = True if freq == 'B' else False # 是否删除周末数据 is_corr = False # 特征是否参与滞后领先提升相关系数 diff --git a/main_yuanyou.py b/main_yuanyou.py index 4ce188b..7ce141a 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -9,7 +9,7 @@ torch.set_float32_matmul_precision("high") -def predict_main(): +def predict_main(end_time): """ 主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。 @@ -202,7 +202,7 @@ def predict_main(): logger.info('模型训练完成') logger.info('训练数据绘图ing') - model_results3 = model_losss(sqlitedb) + model_results3 = model_losss(sqlitedb,end_time=end_time) logger.info('训练数据绘图end') # 模型报告 @@ -238,4 +238,8 @@ def predict_main(): if __name__ == '__main__': - predict_main() \ No newline at end of file + # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 + for i_time in pd.date_range('2024-11-22', '2024-12-17', freq='B'): + end_time = i_time.strftime('%Y-%m-%d') + # print(e_time) + predict_main(end_time) \ No newline at end of file diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index e00e4ad..54d3231 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -222,7 +222,7 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien # 原油计算预测评估指数 -def model_losss(sqlitedb): +def model_losss(sqlitedb,end_time): global dataset global rote most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]] @@ -357,23 +357,22 @@ def model_losss(sqlitedb): def find_closest_values(row): x = row.y - if x is None: - return pd.Series([None, None], index=['min_within_quantile','max_within_quantile']) - row = row.drop('ds') - row = row[:,0].values.tolist() - row = row.sort() + if x is None or np.isnan(x): + return pd.Series([None, None], index=['min_price','max_price']) + # row = row.drop('ds') + row = row.values.tolist() + row.sort() print(row) # x 在row中的索引 index = row.index(x) if index == 0: - return pd.Series([row[index+1], row[index+2]], index=['min_within_quantile','max_within_quantile']) + return pd.Series([row[index+1], row[index+2]], index=['min_price','max_price']) elif index == len(row)-1: - return pd.Series([row[index-2], row[index-1]], index=['min_within_quantile','max_within_quantile']) + return pd.Series([row[index-2], row[index-1]], index=['min_price','max_price']) else: - return pd.Series([row[index-1], row[index+1]], index=['min_within_quantile','max_within_quantile']) + return pd.Series([row[index-1], row[index+1]], index=['min_price','max_price']) + - # df_combined3[['min_within_quantile','max_within_quantile']] = df_combined3.apply(find_closest_values, axis=1) - def find_most_common_model(): # 最多频率的模型名称 @@ -406,24 +405,106 @@ def model_losss(sqlitedb): df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2)) else: df_predict2['id'] = range(1, 1 + len(df_predict2)) + # df_predict2['CREAT_DATE'] = now if end_time == '' else end_time + df_predict2['CREAT_DATE'] = end_time - df_predict2['CREAT_DATE'] = now if end_time == '' else end_time # df_predict2['PREDICT_DATE'] = df_predict2['ds'] # df_predict2['MIN_PRICE'] = df_predict2['min_within_quantile'] # df_predict2['MAX_PRICE'] = df_predict2['max_within_quantile'] # df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']] + df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) + # 更新accuracy表中的y值 update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null') if len(update_y) > 0: df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())] - if len(df_combined4) > 0: + if len(df_combined4) > 0: for index, row in df_combined4.iterrows(): try: sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'") except: print(row) - print(df_combined4) + # 准确率计算 + predict_y = sqlitedb.select_data(table_name = "accuracy") + # 找到min_price为空的行的id + ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist() + # 计算min_price和max_price + # predict_y[['min_price','max_price']] = predict_y[['y']+allmodelnames].apply(find_closest_values, axis=1) + predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']] + # min_price为空的行更新数据到数据库中 + for id in ids: + row = predict_y[predict_y['id'] == id] + print(row) + try: + sqlitedb.update_data('accuracy',f"min_price = {row['min_price'].values[0]},max_price = {row['max_price'].values[0]}",f"id = {id}") + except: + print(row) + + # 最高最低价 + xlsfilename = os.path.join(dataset,'数据项下载.xls') + df2 = pd.read_excel(xlsfilename)[5:] + df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'LOW_PRICE','布伦特最高价':'HIGH_PRICE'}) + print(df2.shape) + df = pd.merge(predict_y,df2,on=['ds'],how='left') + df['ds'] = pd.to_datetime(df['ds']) + # df['PREDICT_DATE'] = pd.to_datetime(df['PREDICT_DATE']) + df = df.reindex() + print(df.shape) + + df.to_csv(os.path.join(dataset,'123.csv')) + + + def calculate_accuracy(row): + if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']: + sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']]) + middle_diff = sorted_prices[2] - sorted_prices[1] + price_range = row['HIGH_PRICE'] - row['LOW_PRICE'] + accuracy = middle_diff / price_range + return accuracy + else: + return 0 + + # 使用 apply 函数来应用计算准确率的函数 + columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price'] + df[columns] = df[columns].astype(float) + df['ACCURACY'] = df.apply(calculate_accuracy, axis=1) + + + # 取结束日期上一周的日期 + endtime = end_time + endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d') + up_week = endtimeweek - datetime.timedelta(days=endtimeweek.weekday() + 14) + up_week_dates = [up_week + datetime.timedelta(days=i) for i in range(14)][4:-2] + up_week_dates = [date.strftime('%Y-%m-%d') for date in up_week_dates] + print(up_week_dates) + + + df3 = df.copy() + df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)] + df3 = df3[df3['ds'].isin(up_week_dates)] + print(df3.shape) + df3.to_csv('up_week_dates.csv',index=False) + + total = len(df3) + accuracy_rote = 0 + # for i,group in df3.groupby('CREAT_DATE'): + for i,group in df3.groupby('ds'): + print(i) + print('权重:',round(len(group)/total,2)) + print('准确率:',group['ACCURACY'].sum()/(len(group)/total)) + accuracy_rote += group['ACCURACY'].sum()/(len(group)/total) + + print(accuracy_rote) + # 列 开始日期,结束日期,准确率 + df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率']) + # 追加一行数据 + df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote} + print(df4) + df4.to_sql("accuracy_rote", con=sqlitedb.connection, if_exists='append', index=False) + + + def _add_abs_error_rate(): diff --git a/up_week_dates.csv b/up_week_dates.csv index 2bb7d4e..40e7294 100644 --- a/up_week_dates.csv +++ b/up_week_dates.csv @@ -1,16 +1 @@ -ds,ACCURACY,PREDICT_DATE,CREAT_DATE,HIGH_PRICE_y,LOW_PRICE_y,MIN_PRICE,MAX_PRICE,Ds_Week,Pre_Week -2024-11-26,1.0,2024-11-26,2024-11-25,73.8,71.63,71.071556,76.0069,47,47 -2024-11-27,1.0,2024-11-27,2024-11-25,72.85,71.71,71.003624,75.58056,47,47 -2024-11-28,0.7893243243243208,2024-11-28,2024-11-25,72.96,71.85,72.08385,76.20426,47,47 -2024-11-29,1.0,2024-11-29,2024-11-25,73.34,71.75,71.32973,75.70395,47,47 -2024-11-27,0.6068728070175414,2024-11-27,2024-11-26,72.85,71.71,72.158165,76.17365,47,47 -2024-11-28,0.8021441441441385,2024-11-28,2024-11-26,72.96,71.85,72.06962,76.447,47,47 -2024-11-29,0.6082389937106918,2024-11-29,2024-11-26,73.34,71.75,72.3729,76.08291,47,47 -2024-11-28,1.0,2024-11-28,2024-11-27,72.96,71.85,70.70975,75.20803,47,47 -2024-11-29,1.0,2024-11-29,2024-11-27,73.34,71.75,69.92311,75.423775,47,47 -2024-11-29,1.0,2024-11-29,2024-11-28,73.34,71.75,70.43276,75.48062,47,47 -2024-11-25,0.11832806324110515,2024-11-25,2024-11-22,74.83,72.3,74.53063,76.67314,47,47 -2024-11-26,0.0,2024-11-26,2024-11-22,73.8,71.63,74.44043,76.874565,47,47 -2024-11-27,0.0,2024-11-27,2024-11-22,72.85,71.71,74.66318,76.73413,47,47 -2024-11-28,0.0,2024-11-28,2024-11-22,72.96,71.85,74.70841,77.14105,47,47 -2024-11-29,0.0,2024-11-29,2024-11-22,73.34,71.75,74.70321,77.74617,47,47 +ds,NHITS,Informer,LSTM,iTransformer,TSMixer,TSMixerx,PatchTST,RNN,GRU,TCN,BiTCN,DilatedRNN,MLP,DLinear,NLinear,TFT,StemGNN,MLPMultivariate,TiDE,DeepNPTS,y,min_within_quantile,max_within_quantile,min_price,max_price,id,CREAT_DATE,序号,LOW_PRICE,HIGH_PRICE,ACCURACY