预测结果保存到数据库
This commit is contained in:
parent
1f827d8224
commit
24174fa9a0
@ -223,7 +223,7 @@ table_name = 'v_tbl_crude_oil_warning'
|
|||||||
|
|
||||||
|
|
||||||
### 开关
|
### 开关
|
||||||
is_train = True # 是否训练
|
is_train = False # 是否训练
|
||||||
is_debug = False # 是否调试
|
is_debug = False # 是否调试
|
||||||
is_eta = False # 是否使用eta接口
|
is_eta = False # 是否使用eta接口
|
||||||
is_timefurture = True # 是否使用时间特征
|
is_timefurture = True # 是否使用时间特征
|
||||||
@ -243,7 +243,7 @@ print("数据库连接成功",host,dbname,dbusername)
|
|||||||
|
|
||||||
# 数据截取日期
|
# 数据截取日期
|
||||||
start_year = 2018 # 数据开始年份
|
start_year = 2018 # 数据开始年份
|
||||||
end_time = '' # 数据截取日期
|
end_time = '2024-11-29' # 数据截取日期
|
||||||
freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
|
freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
|
||||||
delweekenday = True if freq == 'B' else False # 是否删除周末数据
|
delweekenday = True if freq == 'B' else False # 是否删除周末数据
|
||||||
is_corr = False # 特征是否参与滞后领先提升相关系数
|
is_corr = False # 特征是否参与滞后领先提升相关系数
|
||||||
|
@ -178,25 +178,25 @@ def predict_main():
|
|||||||
row, col = df.shape
|
row, col = df.shape
|
||||||
|
|
||||||
now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||||
# ex_Model(df,
|
ex_Model(df,
|
||||||
# horizon=horizon,
|
horizon=horizon,
|
||||||
# input_size=input_size,
|
input_size=input_size,
|
||||||
# train_steps=train_steps,
|
train_steps=train_steps,
|
||||||
# val_check_steps=val_check_steps,
|
val_check_steps=val_check_steps,
|
||||||
# early_stop_patience_steps=early_stop_patience_steps,
|
early_stop_patience_steps=early_stop_patience_steps,
|
||||||
# is_debug=is_debug,
|
is_debug=is_debug,
|
||||||
# dataset=dataset,
|
dataset=dataset,
|
||||||
# is_train=is_train,
|
is_train=is_train,
|
||||||
# is_fivemodels=is_fivemodels,
|
is_fivemodels=is_fivemodels,
|
||||||
# val_size=val_size,
|
val_size=val_size,
|
||||||
# test_size=test_size,
|
test_size=test_size,
|
||||||
# settings=settings,
|
settings=settings,
|
||||||
# now=now,
|
now=now,
|
||||||
# etadata=etadata,
|
etadata=etadata,
|
||||||
# modelsindex=modelsindex,
|
modelsindex=modelsindex,
|
||||||
# data=data,
|
data=data,
|
||||||
# is_eta=is_eta,
|
is_eta=is_eta,
|
||||||
# )
|
)
|
||||||
|
|
||||||
|
|
||||||
logger.info('模型训练完成')
|
logger.info('模型训练完成')
|
||||||
@ -234,7 +234,7 @@ def predict_main():
|
|||||||
file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
|
file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
|
||||||
ssl=ssl,
|
ssl=ssl,
|
||||||
)
|
)
|
||||||
m.send_mail()
|
# m.send_mail()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -179,17 +179,17 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
|
|||||||
filename = f'{settings}--{now}.joblib'
|
filename = f'{settings}--{now}.joblib'
|
||||||
#文件名去掉冒号
|
#文件名去掉冒号
|
||||||
filename = filename.replace(':', '-') # 替换冒号
|
filename = filename.replace(':', '-') # 替换冒号
|
||||||
# dump(nf, os.path.join(dataset,filename))
|
dump(nf, os.path.join(dataset,filename))
|
||||||
else:
|
else:
|
||||||
# glob获取dataset下最新的joblib文件
|
# glob获取dataset下最新的joblib文件
|
||||||
import glob
|
import glob
|
||||||
filename = max(glob.glob(os.path.join(dataset,'*.joblib')), key=os.path.getctime)
|
filename = max(glob.glob(os.path.join(dataset,'*.joblib')), key=os.path.getctime)
|
||||||
# logger.info('读取模型:'+ filename)
|
logger.info('读取模型:'+ filename)
|
||||||
nf = load(filename)
|
nf = load(filename)
|
||||||
# # 测试集预测
|
# # # 测试集预测
|
||||||
nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None)
|
# nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None)
|
||||||
# 测试集预测结果保存
|
# # 测试集预测结果保存
|
||||||
nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False)
|
# nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False)
|
||||||
|
|
||||||
df_test['ds'] = pd.to_datetime(df_test['ds'], errors='coerce')
|
df_test['ds'] = pd.to_datetime(df_test['ds'], errors='coerce')
|
||||||
|
|
||||||
@ -217,7 +217,8 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien
|
|||||||
etadata.push_data(data)
|
etadata.push_data(data)
|
||||||
|
|
||||||
|
|
||||||
return nf_test_preds
|
# return nf_test_preds
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
# 原油计算预测评估指数
|
# 原油计算预测评估指数
|
||||||
@ -399,12 +400,7 @@ def model_losss(sqlitedb):
|
|||||||
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE'])
|
columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE'])
|
||||||
sqlitedb.create_table('accuracy',columns=columns)
|
sqlitedb.create_table('accuracy',columns=columns)
|
||||||
existing_data = sqlitedb.select_data(table_name = "accuracy")
|
existing_data = sqlitedb.select_data(table_name = "accuracy")
|
||||||
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
|
|
||||||
df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())]
|
|
||||||
if len(df_combined4) > 0:
|
|
||||||
for index, row in df_combined4.iterrows():
|
|
||||||
sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'")
|
|
||||||
print(df_combined4)
|
|
||||||
if not existing_data.empty:
|
if not existing_data.empty:
|
||||||
max_id = existing_data['id'].astype(int).max()
|
max_id = existing_data['id'].astype(int).max()
|
||||||
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2))
|
||||||
@ -418,7 +414,16 @@ def model_losss(sqlitedb):
|
|||||||
# df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']]
|
# df_predict2 = df_predict2[['id','PREDICT_DATE','CREAT_DATE','MIN_PRICE','MAX_PRICE']]
|
||||||
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False)
|
||||||
|
|
||||||
|
update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null')
|
||||||
|
if len(update_y) > 0:
|
||||||
|
df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())]
|
||||||
|
if len(df_combined4) > 0:
|
||||||
|
for index, row in df_combined4.iterrows():
|
||||||
|
try:
|
||||||
|
sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'")
|
||||||
|
except:
|
||||||
|
print(row)
|
||||||
|
print(df_combined4)
|
||||||
|
|
||||||
|
|
||||||
def _add_abs_error_rate():
|
def _add_abs_error_rate():
|
||||||
|
Loading…
Reference in New Issue
Block a user