diff --git a/config_jingbo.py b/config_jingbo.py index e6038ad..db41838 100644 --- a/config_jingbo.py +++ b/config_jingbo.py @@ -163,7 +163,7 @@ table_name = 'v_tbl_crude_oil_warning' ### 开关 is_train = False # 是否训练 is_debug = True # 是否调试 -is_eta = False # 是否使用eta接口 +is_eta = True # 是否使用eta接口 is_market = True # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 diff --git a/main_yuanyou.py b/main_yuanyou.py index d24d65d..ac4fbdc 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -1,12 +1,76 @@ # 读取配置 -from lib.dataread import * -from lib.tools import SendMail,exception_logger -from models.nerulforcastmodels import ex_Model,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting -import glob +from lib.dataread import * +from config_jingbo import * +from lib.tools import SendMail, exception_logger +from models.nerulforcastmodels import ex_Model, model_losss, model_losss_juxiting, brent_export_pdf, tansuanli_export_pdf, pp_export_pdf, model_losss_juxiting +import datetime import torch torch.set_float32_matmul_precision("high") +global_config.update({ + # 核心参数 + 'logger': logger, + 'dataset': dataset, + 'y': y, + 'is_debug': is_debug, + 'is_train': is_train, + 'is_fivemodels': is_fivemodels, + 'settings': settings, + + + # 模型参数 + 'data_set': data_set, + 'input_size': input_size, + 'horizon': horizon, + 'train_steps': train_steps, + 'val_check_steps': val_check_steps, + 'val_size': val_size, + 'test_size': test_size, + 'modelsindex': modelsindex, + 'rote': rote, + + # 特征工程开关 + 'is_del_corr': is_del_corr, + 'is_del_tow_month': is_del_tow_month, + 'is_eta': is_eta, + 'is_update_eta': is_update_eta, + 'is_fivemodels': is_fivemodels, + 'early_stop_patience_steps': early_stop_patience_steps, + + # 时间参数 + 'start_year': start_year, + 'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"), + 'freq': freq, # 保持列表结构 + + # 接口配置 + 'login_pushreport_url': login_pushreport_url, + 'login_data': login_data, + 'upload_url': upload_url, + 'upload_warning_url': upload_warning_url, + 'warning_data': warning_data, + + # 查询接口 + 'query_data_list_item_nos_url': query_data_list_item_nos_url, + 'query_data_list_item_nos_data': query_data_list_item_nos_data, + + # eta 配置 + 'APPID': APPID, + 'SECRET': SECRET, + 'etadata': data, + 'edbcodelist': edbcodelist, + 'ClassifyId': ClassifyId, + 'edbcodedataurl': edbcodedataurl, + 'classifyidlisturl': classifyidlisturl, + 'edbdatapushurl': edbdatapushurl, + 'edbdeleteurl': edbdeleteurl, + 'edbbusinessurl': edbbusinessurl, + 'ClassifyId': ClassifyId, + 'classifylisturl': classifylisturl, + + # 数据库配置 + 'sqlitedb': sqlitedb, +}) def predict_main(): @@ -48,31 +112,23 @@ def predict_main(): 返回: None """ - global end_time - signature = BinanceAPI(APPID, SECRET) - etadata = EtaReader(signature=signature, - classifylisturl=classifylisturl, - classifyidlisturl=classifyidlisturl, - edbcodedataurl=edbcodedataurl, - edbcodelist=edbcodelist, - edbdatapushurl=edbdatapushurl, - edbdeleteurl=edbdeleteurl, - edbbusinessurl=edbbusinessurl - ) + end_time = global_config['end_time'] # 获取数据 if is_eta: logger.info('从eta获取数据...') signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, - classifylisturl=classifylisturl, - classifyidlisturl=classifyidlisturl, - edbcodedataurl=edbcodedataurl, - edbcodelist=edbcodelist, - edbdatapushurl=edbdatapushurl, - edbdeleteurl=edbdeleteurl, - edbbusinessurl=edbbusinessurl, - ) - df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(data_set=data_set, dataset=dataset) # 原始数据,未处理 + classifylisturl=global_config['classifylisturl'], + classifyidlisturl=global_config['classifyidlisturl'], + edbcodedataurl=global_config['edbcodedataurl'], + edbcodelist=global_config['edbcodelist'], + edbdatapushurl=global_config['edbdatapushurl'], + edbdeleteurl=global_config['edbdeleteurl'], + edbbusinessurl=global_config['edbbusinessurl'], + classifyId=global_config['ClassifyId'], + ) + df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data( + data_set=data_set, dataset=dataset) # 原始数据,未处理 if is_market: logger.info('从市场信息平台获取数据...') @@ -83,26 +139,26 @@ def predict_main(): df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) else: logger.info('从市场信息平台获取数据') - df_zhibiaoshuju = get_market_data(end_time,df_zhibiaoshuju) - - except : + df_zhibiaoshuju = get_market_data( + end_time, df_zhibiaoshuju) + + except: logger.info('最高最低价拼接失败') - + # 保存到xlsx文件的sheet表 - with pd.ExcelWriter(os.path.join(dataset,data_set)) as file: + with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) - - + # 数据处理 df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, - end_time=end_time) + end_time=end_time) else: # 读取数据 logger.info('读取本地数据:' + os.path.join(dataset, data_set)) - df,df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, - is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 + df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, + is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理 # 更改预测列名称 df.rename(columns={y: 'y'}, inplace=True) @@ -124,47 +180,65 @@ def predict_main(): else: for row in first_row.itertuples(index=False): row_dict = row._asdict() - row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') - check_query = sqlitedb.select_data('trueandpredict', where_condition=f"ds = '{row.ds}'") + config.logger.info(f'要保存的真实值:{row_dict}') + # 判断ds是否为字符串类型,如果不是则转换为字符串类型 + if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)): + row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') + elif not isinstance(row_dict['ds'], str): + try: + row_dict['ds'] = pd.to_datetime( + row_dict['ds']).strftime('%Y-%m-%d') + except: + logger.warning(f"无法解析的时间格式: {row_dict['ds']}") + # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') + # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') + check_query = sqlitedb.select_data( + 'trueandpredict', where_condition=f"ds = '{row.ds}'") if len(check_query) > 0: - set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()]) - sqlitedb.update_data('trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") + set_clause = ", ".join( + [f"{key} = '{value}'" for key, value in row_dict.items()]) + sqlitedb.update_data( + 'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") continue - sqlitedb.insert_data('trueandpredict', tuple(row_dict.values()), columns=row_dict.keys()) + sqlitedb.insert_data('trueandpredict', tuple( + row_dict.values()), columns=row_dict.keys()) # 更新accuracy表的y值 if not sqlitedb.check_table_exists('accuracy'): pass else: - update_y = sqlitedb.select_data('accuracy',where_condition="y is null") + update_y = sqlitedb.select_data( + 'accuracy', where_condition="y is null") if len(update_y) > 0: logger.info('更新accuracy表的y值') # 找到update_y 中ds且df中的y的行 - update_y = update_y[update_y['ds']<=end_time] + update_y = update_y[update_y['ds'] <= end_time] logger.info(f'要更新y的信息:{update_y}') # try: for row in update_y.itertuples(index=False): try: - row_dict = row._asdict() - yy = df[df['ds']==row_dict['ds']]['y'].values[0] - LOW = df[df['ds']==row_dict['ds']]['Brentzdj'].values[0] - HIGH = df[df['ds']==row_dict['ds']]['Brentzgj'].values[0] - sqlitedb.update_data('accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") + row_dict = row._asdict() + yy = df[df['ds'] == row_dict['ds']]['y'].values[0] + LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] + HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] + sqlitedb.update_data( + 'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") except: logger.info(f'更新accuracy表的y值失败:{row_dict}') # except Exception as e: # logger.info(f'更新accuracy表的y值失败:{e}') - import datetime # 判断当前日期是不是周一 is_weekday = datetime.datetime.now().weekday() == 0 if is_weekday: logger.info('今天是周一,更新预测模型') # 计算最近60天预测残差最低的模型名称 - model_results = sqlitedb.select_data('trueandpredict', order_by="ds DESC", limit="60") + model_results = sqlitedb.select_data( + 'trueandpredict', order_by="ds DESC", limit="60") # 删除空值率为90%以上的列 if len(model_results) > 10: - model_results = model_results.dropna(thresh=len(model_results)*0.1,axis=1) + model_results = model_results.dropna( + thresh=len(model_results)*0.1, axis=1) # 删除空行 model_results = model_results.dropna() modelnames = model_results.columns.to_list()[2:-1] @@ -172,54 +246,61 @@ def predict_main(): model_results[col] = model_results[col].astype(np.float32) # 计算每个预测值与真实值之间的偏差率 for model in modelnames: - model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y'] + model_results[f'{model}_abs_error_rate'] = abs( + model_results['y'] - model_results[model]) / model_results['y'] # 获取每行对应的最小偏差率值 - min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) + min_abs_error_rate_values = model_results.apply( + lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) # 获取每行对应的最小偏差率值对应的列名 - min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) + min_abs_error_rate_column_name = model_results.apply( + lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) # 将列名索引转换为列名 - min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0]) + min_abs_error_rate_column_name = min_abs_error_rate_column_name.map( + lambda x: x.split('_')[0]) # 取出现次数最多的模型名称 most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") # 保存结果到数据库 if not sqlitedb.check_table_exists('most_model'): - sqlitedb.create_table('most_model', columns="ds datetime, most_common_model TEXT") - sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) - if is_update_warning_data: - upload_warning_info(len(warning_data_df)) + sqlitedb.create_table( + 'most_model', columns="ds datetime, most_common_model TEXT") + sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( + '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) try: if is_weekday: - # if True: + # if True: logger.info('今天是周一,发送特征预警') # 上传预警信息到数据库 warning_data_df = df_zhibiaoliebiao.copy() - warning_data_df = warning_data_df[warning_data_df['停更周期']> 3 ][['指标名称', '指标id', '频度','更新周期','指标来源','最后更新时间','停更周期']] + warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ + '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] # 重命名列名 - warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) + warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', + '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) from sqlalchemy import create_engine import urllib global password if '@' in password: password = urllib.parse.quote_plus(password) - engine = create_engine(f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') - warning_data_df['WARNING_DATE'] = datetime.date.today().strftime("%Y-%m-%d %H:%M:%S") - warning_data_df['TENANT_CODE'] = 'T0004' + engine = create_engine( + f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') + warning_data_df['WARNING_DATE'] = datetime.date.today().strftime( + "%Y-%m-%d %H:%M:%S") + warning_data_df['TENANT_CODE'] = 'T0004' # 插入数据之前查询表数据然后新增id列 existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine) if not existing_data.empty: max_id = existing_data['ID'].astype(int).max() - warning_data_df['ID'] = range(max_id + 1, max_id + 1 + len(warning_data_df)) + warning_data_df['ID'] = range( + max_id + 1, max_id + 1 + len(warning_data_df)) else: warning_data_df['ID'] = range(1, 1 + len(warning_data_df)) - warning_data_df.to_sql(table_name, con=engine, if_exists='append', index=False) + warning_data_df.to_sql( + table_name, con=engine, if_exists='append', index=False) if is_update_warning_data: upload_warning_info(len(warning_data_df)) - - # 发送钉钉消息 - upload_warning_info(len(warning_data_df)) except: logger.info('上传预警信息到数据库失败') @@ -232,72 +313,70 @@ def predict_main(): now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') ex_Model(df, - horizon=horizon, - input_size=input_size, - train_steps=train_steps, - val_check_steps=val_check_steps, - early_stop_patience_steps=early_stop_patience_steps, - is_debug=is_debug, - dataset=dataset, - is_train=is_train, - is_fivemodels=is_fivemodels, - val_size=val_size, - test_size=test_size, - settings=settings, + horizon=global_config['horizon'], + input_size=global_config['input_size'], + train_steps=global_config['train_steps'], + val_check_steps=global_config['val_check_steps'], + early_stop_patience_steps=global_config['early_stop_patience_steps'], + is_debug=global_config['is_debug'], + dataset=global_config['dataset'], + is_train=global_config['is_train'], + is_fivemodels=global_config['is_fivemodels'], + val_size=global_config['val_size'], + test_size=global_config['test_size'], + settings=global_config['settings'], now=now, - etadata=etadata, - modelsindex=modelsindex, + etadata=global_config['etadata'], + modelsindex=global_config['modelsindex'], data=data, - is_eta=is_eta, - end_time=end_time, + is_eta=global_config['is_eta'], + end_time=global_config['end_time'], ) + # logger.info('模型训练完成') - logger.info('模型训练完成') - logger.info('训练数据绘图ing') - model_results3 = model_losss(sqlitedb,end_time=end_time) + model_results3 = model_losss(sqlitedb, end_time=end_time) logger.info('训练数据绘图end') - - # 模型报告 + + # # 模型报告 logger.info('制作报告ing') - title = f'{settings}--{end_time}-预测报告' # 报告标题 - reportname = f'Brent原油大模型预测--{end_time}.pdf' # 报告文件名 - reportname = reportname.replace(':', '-') # 替换冒号 - brent_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time, - reportname=reportname,sqlitedb=sqlitedb), + title = f'{settings}--{end_time}-预测报告' # 报告标题 + reportname = f'Brent原油大模型日度预测--{end_time}.pdf' # 报告文件名 + reportname = reportname.replace(':', '-') # 替换冒号 + brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, + reportname=reportname, sqlitedb=sqlitedb), logger.info('制作报告end') logger.info('模型训练完成') # # LSTM 单变量模型 # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) - + # # lstm 多变量模型 # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset) - + # # GRU 模型 # # ex_GRU(df) # 发送邮件 - m = SendMail( - username=username, - passwd=passwd, - recv=recv, - title=title, - content=content, - file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), - ssl=ssl, - ) - # m.send_mail() + # m = SendMail( + # username=username, + # passwd=passwd, + # recv=recv, + # title=title, + # content=content, + # file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), + # ssl=ssl, + # ) + # m.send_mail() if __name__ == '__main__': # global end_time - # is_on = True # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - # for i_time in pd.date_range('2025-1-20', '2025-2-6', freq='B'): + # for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'): # end_time = i_time.strftime('%Y-%m-%d') # predict_main() - predict_main() \ No newline at end of file + predict_main() diff --git a/main_yuanyou_yuedu.py b/main_yuanyou_yuedu.py index c00a2b0..5a504fa 100644 --- a/main_yuanyou_yuedu.py +++ b/main_yuanyou_yuedu.py @@ -35,6 +35,7 @@ global_config.update({ 'is_del_tow_month': is_del_tow_month, 'is_eta': is_eta, 'is_update_eta': is_update_eta, + 'is_fivemodels': is_fivemodels, 'early_stop_patience_steps': early_stop_patience_steps, # 时间参数 @@ -57,6 +58,15 @@ global_config.update({ 'APPID': APPID, 'SECRET': SECRET, 'etadata': data, + 'edbcodelist': edbcodelist, + 'ClassifyId': ClassifyId, + 'edbcodedataurl': edbcodedataurl, + 'classifyidlisturl': classifyidlisturl, + 'edbdatapushurl': edbdatapushurl, + 'edbdeleteurl': edbdeleteurl, + 'edbbusinessurl': edbbusinessurl, + 'ClassifyId': ClassifyId, + 'classifylisturl': classifylisturl, # 数据库配置 'sqlitedb': sqlitedb, diff --git a/main_yuanyou_zhoudu.py b/main_yuanyou_zhoudu.py index 82df921..8718bbc 100644 --- a/main_yuanyou_zhoudu.py +++ b/main_yuanyou_zhoudu.py @@ -35,11 +35,12 @@ global_config.update({ 'is_del_tow_month': is_del_tow_month, 'is_eta': is_eta, 'is_update_eta': is_update_eta, + 'is_fivemodels': is_fivemodels, 'early_stop_patience_steps': early_stop_patience_steps, # 时间参数 'start_year': start_year, - 'end_time': end_time, + 'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"), 'freq': freq, # 保持列表结构 # 接口配置 @@ -57,12 +58,20 @@ global_config.update({ 'APPID': APPID, 'SECRET': SECRET, 'etadata': data, + 'edbcodelist': edbcodelist, + 'ClassifyId': ClassifyId, + 'edbcodedataurl': edbcodedataurl, + 'classifyidlisturl': classifyidlisturl, + 'edbdatapushurl': edbdatapushurl, + 'edbdeleteurl': edbdeleteurl, + 'edbbusinessurl': edbbusinessurl, + 'ClassifyId': ClassifyId, + 'classifylisturl': classifylisturl, # 数据库配置 'sqlitedb': sqlitedb, }) - def predict_main(): """ 主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。