计算准确率逻辑修改
This commit is contained in:
parent
1ad1553e01
commit
34c4a9e205
@ -400,7 +400,7 @@ def model_losss(sqlitedb,end_time):
|
||||
|
||||
# 保存到数据库
|
||||
if not sqlitedb.check_table_exists('accuracy'):
|
||||
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE'])
|
||||
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE','min_price','max_price'])
|
||||
sqlitedb.create_table('accuracy',columns=columns)
|
||||
existing_data = sqlitedb.select_data(table_name = "accuracy")
|
||||
|
||||
@ -429,9 +429,11 @@ def model_losss(sqlitedb,end_time):
|
||||
# 模型评估前五最大最小
|
||||
# predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']]
|
||||
# 模型评估前五均值 df_combined3['mean'] = df_combined3[modelnames].mean(axis=1)
|
||||
|
||||
predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
|
||||
predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
|
||||
# predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1
|
||||
# predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1
|
||||
# 模型评估前十均值
|
||||
predict_y['min_price'] = predict_y[allmodelnames[0:5]].min(axis=1)
|
||||
predict_y['max_price'] = predict_y[allmodelnames[0:5]].max(axis=1)
|
||||
for id in ids:
|
||||
row = predict_y[predict_y['id'] == id]
|
||||
try:
|
||||
@ -447,12 +449,19 @@ def model_losss(sqlitedb,end_time):
|
||||
df['ds'] = pd.to_datetime(df['ds'])
|
||||
df = df.reindex()
|
||||
def calculate_accuracy(row):
|
||||
if row['HIGH_PRICE'] > row['min_price'] or row['max_price'] > row['LOW_PRICE']:
|
||||
# 子集情况:
|
||||
if (row['HIGH_PRICE'] >= row['max_price'] and row['min_price'] >= row['LOW_PRICE']) or \
|
||||
(row['LOW_PRICE'] >= row['min_price'] and row['max_price'] >= row['HIGH_PRICE']):
|
||||
return 1
|
||||
# 有交集情况:
|
||||
if row['HIGH_PRICE'] > row['min_price'] or \
|
||||
row['max_price'] > row['LOW_PRICE']:
|
||||
sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']])
|
||||
middle_diff = sorted_prices[2] - sorted_prices[1]
|
||||
price_range = row['HIGH_PRICE'] - row['LOW_PRICE']
|
||||
accuracy = middle_diff / price_range
|
||||
return accuracy
|
||||
# 无交集情况:
|
||||
else:
|
||||
return 0
|
||||
columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price']
|
||||
|
Loading…
Reference in New Issue
Block a user