原油画图用最佳模型,而不是最多拟合的模型。

This commit is contained in:
liurui 2024-12-04 16:39:37 +08:00
parent 16bfe73eec
commit 03abddb68c
2 changed files with 6 additions and 6 deletions

View File

@ -198,7 +198,7 @@ warning_data = {
### 开关
is_train = True # 是否训练
is_debug = False # 是否调试
is_eta = True # 是否使用eta接口
is_eta = False # 是否使用eta接口
is_timefurture = True # 是否使用时间特征
is_fivemodels = False # 是否使用之前保存的最佳的5个模型
is_edbcode = True # 特征使用edbcoding列表中的
@ -235,13 +235,13 @@ if add_kdj and is_edbnamelist:
### 模型参数
y = 'Brent连1合约价格' # 原油指标数据的目标变量 Brent连1合约价格 Brent活跃合约
# y = '期货结算价(连续):布伦特原油:前一个观测值' # ineoil的目标变量
horizon =20 # 预测的步长
horizon =30 # 预测的步长
input_size = 120 # 输入序列长度
train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数
val_check_steps = 30 # 评估频率
early_stop_patience_steps = 5 # 早停的耐心步数
# --- 交叉验证用的参数
test_size = 100 # 测试集大小定义100后面使用的时候重新赋值
test_size = 200 # 测试集大小定义100后面使用的时候重新赋值
val_size = test_size # 验证集大小,同测试集大小
### 特征筛选用到的参数

View File

@ -415,7 +415,7 @@ def model_losss(sqlitedb):
df_combined3.to_csv(os.path.join(dataset,"df_combined3.csv"),index=False)
# 历史价格+预测价格
df_combined3 = df_combined3[-50:] # 取50个数据点画图
# df_combined3 = df_combined3[-50:] # 取50个数据点画图
# 历史价格
plt.figure(figsize=(20, 10))
plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值')
@ -429,8 +429,8 @@ def model_losss(sqlitedb):
plt.text(i, j, str(j), ha='center', va='bottom')
# 数据库查询最佳模型名称
most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]
# most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]]
most_model = modelnames[0:1]
for model in most_model:
plt.plot(df_combined3['ds'], df_combined3[model], label=model,marker='o')
# 当前日期画竖虚线