diff --git a/config_jingbo.py b/config_jingbo.py index c194ec0..790127b 100644 --- a/config_jingbo.py +++ b/config_jingbo.py @@ -198,7 +198,7 @@ warning_data = { ### 开关 is_train = True # 是否训练 is_debug = False # 是否调试 -is_eta = True # 是否使用eta接口 +is_eta = False # 是否使用eta接口 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = True # 特征使用edbcoding列表中的 @@ -235,13 +235,13 @@ if add_kdj and is_edbnamelist: ### 模型参数 y = 'Brent连1合约价格' # 原油指标数据的目标变量 Brent连1合约价格 Brent活跃合约 # y = '期货结算价(连续):布伦特原油:前一个观测值' # ineoil的目标变量 -horizon =20 # 预测的步长 +horizon =30 # 预测的步长 input_size = 120 # 输入序列长度 train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数 val_check_steps = 30 # 评估频率 early_stop_patience_steps = 5 # 早停的耐心步数 # --- 交叉验证用的参数 -test_size = 100 # 测试集大小,定义100,后面使用的时候重新赋值 +test_size = 200 # 测试集大小,定义100,后面使用的时候重新赋值 val_size = test_size # 验证集大小,同测试集大小 ### 特征筛选用到的参数 diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index c3ebc46..ba4b38b 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -415,7 +415,7 @@ def model_losss(sqlitedb): df_combined3.to_csv(os.path.join(dataset,"df_combined3.csv"),index=False) # 历史价格+预测价格 - df_combined3 = df_combined3[-50:] # 取50个数据点画图 + # df_combined3 = df_combined3[-50:] # 取50个数据点画图 # 历史价格 plt.figure(figsize=(20, 10)) plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值') @@ -429,8 +429,8 @@ def model_losss(sqlitedb): plt.text(i, j, str(j), ha='center', va='bottom') # 数据库查询最佳模型名称 - most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]] - + # most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]] + most_model = modelnames[0:1] for model in most_model: plt.plot(df_combined3['ds'], df_combined3[model], label=model,marker='o') # 当前日期画竖虚线