From 24525e227d33169061afd81db85aed941ac130f6 Mon Sep 17 00:00:00 2001 From: workpc Date: Tue, 11 Mar 2025 10:10:20 +0800 Subject: [PATCH] =?UTF-8?q?=E8=81=9A=E7=83=AF=E7=83=83=E6=97=A5=E5=BA=A6?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_juxiting.py | 10 +++++----- lib/dataread.py | 41 ++++++----------------------------------- main_juxiting.py | 32 ++++++++++++++++++-------------- 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/config_juxiting.py b/config_juxiting.py index 6054a38..2977718 100644 --- a/config_juxiting.py +++ b/config_juxiting.py @@ -201,17 +201,17 @@ table_name = 'v_tbl_crude_oil_warning' # 开关 -is_train = False # 是否训练 +is_train = True # 是否训练 is_debug = True # 是否调试 -is_eta = True # 是否使用eta接口 +is_eta = False # 是否使用eta接口 is_market = False # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = False # 特征使用edbcoding列表中的 is_edbnamelist = False # 自定义特征,对应上面的edbnamelist is_update_eta = False # 预测结果上传到eta -is_update_report = True # 是否上传报告 -is_update_warning_data = True # 是否上传预警数据 +is_update_report = False # 是否上传报告 +is_update_warning_data = False # 是否上传预警数据 is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 is_del_tow_month = True # 是否删除两个月不更新的特征 @@ -306,7 +306,7 @@ logger.setLevel(logging.INFO) file_handler = logging.handlers.RotatingFileHandler(os.path.join( log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5) file_handler.setFormatter(logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')) # 配置控制台处理器,将日志打印到控制台 console_handler = logging.StreamHandler() diff --git a/lib/dataread.py b/lib/dataread.py index e30c6a1..7321ae9 100644 --- a/lib/dataread.py +++ b/lib/dataread.py @@ -115,37 +115,6 @@ global_config = { # 数据库配置 'sqlitedb': None, } - -# logger = global_config['logger'] -# dataset = global_config['dataset'] -# y = global_config['y'] -# data_set = global_config['data_set'] -# input_size = global_config['input_size'] -# horizon = global_config['horizon'] -# train_steps = global_config['train_steps'] -# val_check_steps = global_config['val_check_steps'] -# is_del_corr = global_config['is_del_corr'] -# is_del_tow_month = global_config['is_del_tow_month'] -# is_eta = global_config['is_eta'] -# is_update_eta = global_config['is_update_eta'] -# is_update_eta_data = global_config['is_update_eta_data'] -# start_year = global_config['start_year'] -# end_time = global_config['end_time'] -# freq = global_config['freq'][0] -# offsite_col = global_config['offsite_col'] -# avg_cols = global_config['avg_col'] -# offsite = global_config['offsite'] -# edbcodenamedict = global_config['edbcodenamedict'] -# query_data_list_item_nos_url = global_config['query_data_list_item_nos_url'] -# query_data_list_item_nos_data = global_config['query_data_list_item_nos_data'] -# config.login_pushreport_url = global_config['config.login_pushreport_url'] -# login_data = global_config['login_data'] -# upload_url = global_config['upload_url'] -# upload_warning_url = global_config['upload_warning_url'] -# upload_warning_data = global_config['upload_warning_data'] -# warning_data = global_config['upload_warning_data'] -# APPID = global_config['APPID'] -# SECRET = global_config['SECRET'] # 定义函数 @@ -973,14 +942,16 @@ def datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_t df.rename(columns={datecol: 'ds'}, inplace=True) # 指定列统一减少数值 - df[config.offsite_col] = df[config.offsite_col]-config.offsite + print(global_config.keys()) + df[global_config['offsite_col']] = df[global_config['offsite_col']] - \ + global_config['offsite'] # 预测列为avg_cols的均值 - df[y] = df[config.avg_cols].mean(axis=1) + df[global_config['y']] = df[global_config['avg_cols']].mean(axis=1) # 去掉多余的列avg_cols - df = df.drop(columns=config.avg_cols) + df = df.drop(columns=global_config['avg_cols']) # 重命名预测列 - df.rename(columns={y: 'y'}, inplace=True) + df.rename(columns={global_config['y']: 'y'}, inplace=True) # 按时间顺序排列 df.sort_values(by='ds', inplace=True) df['ds'] = pd.to_datetime(df['ds']) diff --git a/main_juxiting.py b/main_juxiting.py index fab405c..1e4054b 100644 --- a/main_juxiting.py +++ b/main_juxiting.py @@ -3,7 +3,7 @@ from lib.dataread import * from config_juxiting import * from lib.tools import SendMail, exception_logger -from models.nerulforcastmodels import ex_Model, model_losss, model_losss_juxiting, brent_export_pdf, tansuanli_export_pdf, pp_export_pdf, model_losss_juxiting +from models.nerulforcastmodels import ex_Model, model_losss_juxiting, tansuanli_export_pdf, pp_export_pdf import datetime import torch torch.set_float32_matmul_precision("high") @@ -13,6 +13,10 @@ global_config.update({ 'logger': logger, 'dataset': dataset, 'y': y, + 'offsite_col': offsite_col, + 'avg_cols': avg_cols, + 'offsite': offsite, + 'edbcodenamedict': edbcodenamedict, 'is_debug': is_debug, 'is_train': is_train, 'is_fivemodels': is_fivemodels, @@ -150,14 +154,14 @@ def predict_main(): df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) # 数据处理 - df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, - end_time=end_time) + df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, + end_time=end_time) else: # 读取数据 logger.info('读取本地数据:' + os.path.join(dataset, data_set)) - df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, - is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 + df, df_zhibiaoliebiao = getdata_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, + is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 # 更改预测列名称 df.rename(columns={y: 'y'}, inplace=True) @@ -335,19 +339,19 @@ def predict_main(): logger.info('模型训练完成') logger.info('训练数据绘图ing') - model_results3 = model_losss(sqlitedb, end_time=end_time) + model_results3 = model_losss_juxiting(sqlitedb, end_time=end_time) logger.info('训练数据绘图end') # # 模型报告 - # logger.info('制作报告ing') - # title = f'{settings}--{end_time}-预测报告' # 报告标题 - # reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名 - # reportname = reportname.replace(':', '-') # 替换冒号 - # brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, - # reportname=reportname, sqlitedb=sqlitedb), + logger.info('制作报告ing') + title = f'{settings}--{end_time}-预测报告' # 报告标题 + reportname = f'Brent原油大模型月度预测--{end_time}.pdf' # 报告文件名 + reportname = reportname.replace(':', '-') # 替换冒号 + pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + reportname=reportname, sqlitedb=sqlitedb), - # logger.info('制作报告end') - # logger.info('模型训练完成') + logger.info('制作报告end') + logger.info('模型训练完成') # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)