原油月度基础数据调试

This commit is contained in:
jingboyitiji 2025-03-07 13:28:10 +08:00
parent 337ec33ed1
commit 29a65d7c70
3 changed files with 27 additions and 20 deletions

View File

@ -158,7 +158,7 @@ table_name = 'v_tbl_crude_oil_warning'
# 开关 # 开关
is_train = False # 是否训练 is_train = False # 是否训练
is_debug = True # 是否调试 is_debug = False # 是否调试
is_eta = False # 是否使用eta接口 is_eta = False # 是否使用eta接口
is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效
is_timefurture = True # 是否使用时间特征 is_timefurture = True # 是否使用时间特征
@ -180,8 +180,8 @@ print("数据库连接成功", host, dbname, dbusername)
# 数据截取日期 # 数据截取日期
start_year = 2005 # 数据开始年份 start_year = 2000 # 数据开始年份
end_time = '' # 数据截取日期 end_time = '2023-3-1' # 数据截取日期
freq = 'M' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 "WW" 自定义周 freq = 'M' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 "WW" 自定义周
delweekenday = True if freq == 'B' else False # 是否删除周末数据 delweekenday = True if freq == 'B' else False # 是否删除周末数据
is_corr = False # 特征是否参与滞后领先提升相关系数 is_corr = False # 特征是否参与滞后领先提升相关系数

View File

@ -102,7 +102,7 @@ def predict_main():
返回: 返回:
None None
""" """
global end_time end_time = global_config['end_time']
signature = BinanceAPI(APPID, SECRET) signature = BinanceAPI(APPID, SECRET)
etadata = EtaReader(signature=signature, etadata = EtaReader(signature=signature,
classifylisturl=classifylisturl, classifylisturl=classifylisturl,
@ -332,22 +332,22 @@ def predict_main():
end_time=global_config['end_time'], end_time=global_config['end_time'],
) )
# logger.info('模型训练完成') logger.info('模型训练完成')
logger.info('训练数据绘图ing') logger.info('训练数据绘图ing')
model_results3 = model_losss(sqlitedb, end_time=end_time) model_results3 = model_losss(sqlitedb, end_time=end_time)
logger.info('训练数据绘图end') logger.info('训练数据绘图end')
# # 模型报告 # # 模型报告
logger.info('制作报告ing') # logger.info('制作报告ing')
title = f'{settings}--{end_time}-预测报告' # 报告标题 # title = f'{settings}--{end_time}-预测报告' # 报告标题
reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名 # reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名
reportname = reportname.replace(':', '-') # 替换冒号 # reportname = reportname.replace(':', '-') # 替换冒号
brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, # brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
reportname=reportname, sqlitedb=sqlitedb), # reportname=reportname, sqlitedb=sqlitedb),
logger.info('制作报告end') # logger.info('制作报告end')
logger.info('模型训练完成') # logger.info('模型训练完成')
# # LSTM 单变量模型 # # LSTM 单变量模型
# ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
@ -373,9 +373,14 @@ def predict_main():
if __name__ == '__main__': if __name__ == '__main__':
# global end_time # global end_time
# # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 # 遍历2024-11-25 到 2024-12-3 之间的工作日日期
# for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'): for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'):
# end_time = i_time.strftime('%Y-%m-%d') try:
# predict_main() global_config['end_time'] = i_time.strftime('%Y-%m-%d')
predict_main()
except Exception as e:
logger.info(f'预测失败:{e}')
continue
predict_main()
# predict_main()

View File

@ -935,10 +935,10 @@ def model_losss(sqlitedb, end_time):
try: try:
df_combined = sqlitedb.select_data( df_combined = sqlitedb.select_data(
'accuracy', where_condition=f"created_dt <= '{end_time}'") 'accuracy', where_condition=f"created_dt <= '{end_time}'")
if len(df_combined) < 10: if len(df_combined) < 100:
len(df_combined) + '' len(df_combined) + ''
except: except:
df_combined = loadcsv(os.path.join(dataset, "cross_validation.csv")) df_combined = loadcsv(os.path.join(config.dataset, "cross_validation.csv"))
df_combined = dateConvert(df_combined) df_combined = dateConvert(df_combined)
df_combined['CREAT_DATE'] = df_combined['cutoff'] df_combined['CREAT_DATE'] = df_combined['cutoff']
df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要 df_combined4 = df_combined.copy() # 备份df_combined,后面画图需要
@ -968,6 +968,8 @@ def model_losss(sqlitedb, end_time):
modelnames.remove('y') modelnames.remove('y')
if 'cutoff' in modelnames: if 'cutoff' in modelnames:
modelnames.remove('cutoff') modelnames.remove('cutoff')
if 'ds' in modelnames:
modelnames.remove('ds')
df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要 df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要
# 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE # 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE