diff --git a/config_jingbo_zhoudu.py b/config_jingbo_zhoudu.py index a737a6c..189bf30 100644 --- a/config_jingbo_zhoudu.py +++ b/config_jingbo_zhoudu.py @@ -159,9 +159,9 @@ table_name = 'v_tbl_crude_oil_warning' ### 开关 -is_train = True # 是否训练 +is_train = False # 是否训练 is_debug = False # 是否调试 -is_eta = True # 是否使用eta接口 +is_eta = False # 是否使用eta接口 is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index 604b815..1336783 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -188,9 +188,9 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien logger.info('读取模型:'+ filename) nf = load(filename) # # 测试集预测 - # nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None) - # # 测试集预测结果保存 - # nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False) + nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None) + # 测试集预测结果保存 + nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False) df_test['ds'] = pd.to_datetime(df_test['ds'], errors='coerce')