原油月度基础数据调试

This commit is contained in:
jingboyitiji 2025-03-07 13:28:10 +08:00
parent 337ec33ed1
commit 29a65d7c70
3 changed files with 27 additions and 20 deletions

View File

@ -158,7 +158,7 @@ table_name = 'v_tbl_crude_oil_warning'
# 开关
is_train = False # 是否训练
is_debug = True # 是否调试
is_debug = False # 是否调试
is_eta = False # 是否使用eta接口
is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效
is_timefurture = True # 是否使用时间特征
@ -180,8 +180,8 @@ print("数据库连接成功", host, dbname, dbusername)
# 数据截取日期
start_year = 2005 # 数据开始年份
end_time = '' # 数据截取日期
start_year = 2000 # 数据开始年份
end_time = '2023-3-1' # 数据截取日期
freq = 'M' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 "WW" 自定义周
delweekenday = True if freq == 'B' else False # 是否删除周末数据
is_corr = False # 特征是否参与滞后领先提升相关系数

View File

@ -102,7 +102,7 @@ def predict_main():
返回:
None
"""
global end_time
end_time = global_config['end_time']
signature = BinanceAPI(APPID, SECRET)
etadata = EtaReader(signature=signature,
classifylisturl=classifylisturl,
@ -332,22 +332,22 @@ def predict_main():
end_time=global_config['end_time'],
)
# logger.info('模型训练完成')
logger.info('模型训练完成')
logger.info('训练数据绘图ing')
model_results3 = model_losss(sqlitedb, end_time=end_time)
logger.info('训练数据绘图end')
# # 模型报告
logger.info('制作报告ing')
title = f'{settings}--{end_time}-预测报告' # 报告标题
reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名
reportname = reportname.replace(':', '-') # 替换冒号
brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
reportname=reportname, sqlitedb=sqlitedb),
# logger.info('制作报告ing')
# title = f'{settings}--{end_time}-预测报告' # 报告标题
# reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名
# reportname = reportname.replace(':', '-') # 替换冒号
# brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
# reportname=reportname, sqlitedb=sqlitedb),
logger.info('制作报告end')
logger.info('模型训练完成')
# logger.info('制作报告end')
# logger.info('模型训练完成')
# # LSTM 单变量模型
# ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
@ -373,9 +373,14 @@ def predict_main():
if __name__ == '__main__':
# global end_time
# # 遍历2024-11-25 到 2024-12-3 之间的工作日日期
# for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'):
# end_time = i_time.strftime('%Y-%m-%d')
# predict_main()
# 遍历2024-11-25 到 2024-12-3 之间的工作日日期
for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'):
try:
global_config['end_time'] = i_time.strftime('%Y-%m-%d')
predict_main()
except Exception as e:
logger.info(f'预测失败:{e}')
continue
# predict_main()

View File

@ -935,10 +935,10 @@ def model_losss(sqlitedb, end_time):
try:
df_combined = sqlitedb.select_data(
'accuracy', where_condition=f"created_dt <= '{end_time}'")
if len(df_combined) < 10:
if len(df_combined) < 100:
len(df_combined) + ''
except:
df_combined = loadcsv(os.path.join(dataset, "cross_validation.csv"))
df_combined = loadcsv(os.path.join(config.dataset, "cross_validation.csv"))
df_combined = dateConvert(df_combined)
df_combined['CREAT_DATE'] = df_combined['cutoff']
df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要
@ -968,6 +968,8 @@ def model_losss(sqlitedb, end_time):
modelnames.remove('y')
if 'cutoff' in modelnames:
modelnames.remove('cutoff')
if 'ds' in modelnames:
modelnames.remove('ds')
df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要
# 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE