From 34c4a9e2053ea29a8269b2092be9a390ca3cec6b Mon Sep 17 00:00:00 2001 From: liurui Date: Tue, 17 Dec 2024 09:31:36 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=A1=E7=AE=97=E5=87=86=E7=A1=AE=E7=8E=87?= =?UTF-8?q?=E9=80=BB=E8=BE=91=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/nerulforcastmodels.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 64d2e8c..6fa1757 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -400,7 +400,7 @@ def model_losss(sqlitedb,end_time): # 保存到数据库 if not sqlitedb.check_table_exists('accuracy'): - columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE']) + columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE','min_price','max_price']) sqlitedb.create_table('accuracy',columns=columns) existing_data = sqlitedb.select_data(table_name = "accuracy") @@ -429,9 +429,11 @@ def model_losss(sqlitedb,end_time): # 模型评估前五最大最小 # predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']] # 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1) - - predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1 - predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1 + # predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1 + # predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1 + # 模型评估前十均值 + predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1) + predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1) for id in ids: row = predict_y[predict_y['id'] == id] try: @@ -447,12 +449,19 @@ def model_losss(sqlitedb,end_time): df['ds'] = pd.to_datetime(df['ds']) df = df.reindex() def calculate_accuracy(row): - if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']: + # 子集情况: + if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \ + (row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']): + return 1 + # 有交集情况: + if row['HIGH_PRICE'] > row['min_price'] or \ + row['max_price'] > row['LOW_PRICE']: sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']]) middle_diff = sorted_prices[2] - sorted_prices[1] price_range = row['HIGH_PRICE'] - row['LOW_PRICE'] accuracy = middle_diff / price_range return accuracy + # 无交集情况: else: return 0 columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price'] @@ -586,7 +595,7 @@ def model_losss(sqlitedb,end_time): _plt_model_results3() return model_results3 - + # 聚烯烃计算预测评估指数 def model_losss_juxiting(sqlitedb):