计算准确率逻辑修改

This commit is contained in:
liurui 2024-12-17 09:31:36 +08:00
parent 1ad1553e01
commit 34c4a9e205

View File

@ -400,7 +400,7 @@ def model_losss(sqlitedb,end_time):
# 保存到数据库
if not sqlitedb.check_table_exists('accuracy'):
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE'])
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE','min_price','max_price'])
sqlitedb.create_table('accuracy',columns=columns)
existing_data = sqlitedb.select_data(table_name = "accuracy")
@ -429,9 +429,11 @@ def model_losss(sqlitedb,end_time):
# 模型评估前五最大最小
# predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
# 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1)
predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
# predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
# predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
# 模型评估前十均值
predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1)
predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1)
for id in ids:
row = predict_y[predict_y['id'] == id]
try:
@ -447,12 +449,19 @@ def model_losss(sqlitedb,end_time):
df['ds'] = pd.to_datetime(df['ds'])
df = df.reindex()
def calculate_accuracy(row):
if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']:
# 子集情况:
if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \
(row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']):
return 1
# 有交集情况:
if row['HIGH_PRICE'] > row['min_price'] or \
row['max_price'] > row['LOW_PRICE']:
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
middle_diff = sorted_prices[2] - sorted_prices[1]
price_range = row['HIGH_PRICE'] - row['LOW_PRICE']
accuracy = middle_diff / price_range
return accuracy
# 无交集情况:
else:
return 0
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
@ -586,7 +595,7 @@ def model_losss(sqlitedb,end_time):
_plt_model_results3()
return model_results3
# 聚烯烃计算预测评估指数
def model_losss_juxiting(sqlitedb):