diff --git a/config_juxiting_zhoudu.py b/config_juxiting_zhoudu.py
index 8cfc725..fce7337 100644
--- a/config_juxiting_zhoudu.py
+++ b/config_juxiting_zhoudu.py
@@ -202,15 +202,15 @@ table_name = 'v_tbl_crude_oil_warning'
 
 ### 开关
 is_train = False  # 是否训练
-is_debug = False  # 是否调试
-is_eta = False  # 是否使用eta接口
+is_debug = True  # 是否调试
+is_eta = True  # 是否使用eta接口
 is_market = False  # 是否通过市场信息平台获取特征,在is_eta为true的情况下生效
 is_timefurture = True  # 是否使用时间特征
 is_fivemodels = False  # 是否使用之前保存的最佳的5个模型
 is_edbcode = False  # 特征使用edbcoding列表中的
 is_edbnamelist = False  # 自定义特征,对应上面的edbnamelist
-is_update_eta = True  # 预测结果上传到eta
-is_update_report = True  # 是否上传报告
+is_update_eta = False  # 预测结果上传到eta
+is_update_report = False  # 是否上传报告
 is_update_warning_data = False  # 是否上传预警数据
 is_del_corr = 0.6  # 是否删除相关性高的特征,取值为 0-1,0 为不删除,0.6 表示删除相关性小于0.6的特征
 is_del_tow_month = True  # 是否删除两个月不更新的特征
@@ -224,9 +224,9 @@ print("数据库连接成功",host,dbname,dbusername)
 
 # 数据截取日期
-start_year = 2020  # 数据开始年份
-end_time = '2025-01-27'  # 数据截取日期
-freq = 'W'  # 时间频率,"D": 天 "W": 周 "M": 月 "Q": 季度 "A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
+start_year = 2015  # 数据开始年份
+end_time = ''  # 数据截取日期,为空则在运行时取当天
+freq = 'WW'  # 时间频率,"D": 天 "W": 周 "M": 月 "Q": 季度 "A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
 delweekenday = True if freq == 'B' else False  # 是否删除周末数据
 is_corr = False  # 特征是否参与滞后领先提升相关系数
 add_kdj = False  # 是否添加kdj指标
@@ -243,8 +243,8 @@ avg_cols = [
 ]
 offsite = 80
 offsite_col = ['PP:拉丝:HP550J:市场价:青岛:金能化学(日)']
-horizon = 1  # 预测的步长
-input_size = 7  # 输入序列长度
+horizon = 2  # 预测的步长
+input_size = 14  # 输入序列长度
 train_steps = 50 if is_debug else 1000  # 训练步数,用来限定epoch次数
 val_check_steps = 30  # 评估频率
 early_stop_patience_steps = 5  # 早停的耐心步数
@@ -263,10 +263,10 @@ weight_dict = [0.4,0.15,0.1,0.1,0.25]  # 权重
 
 ### 文件
 data_set = 'PP指标数据.xlsx'  # 数据集文件
-dataset = 'juxitingzhududataset'  # 数据集文件夹
+dataset = 'juxitingzhoududataset'  # 数据集文件夹
 
 # 数据库名称
-db_name = os.path.join(dataset,'jbsh_juxiting.db')
+db_name = os.path.join(dataset, 'jbsh_juxiting_zhoudu.db')
 sqlitedb = SQLiteHandler(db_name)
 sqlitedb.connect()
diff --git a/lib/dataread.py b/lib/dataread.py
index 7321ae9..5b3c278 100644
--- a/lib/dataread.py
+++ b/lib/dataread.py
@@ -1063,6 +1063,22 @@ def getdata_juxiting(filename, datecol='date', y='y', dataset='', add_kdj=False,
     return df, df_zhibiaoliebiao
 
+def getdata_zhoudu_juxiting(filename, datecol='date', y='y', dataset='', add_kdj=False, is_timefurture=False, end_time=''):
+    config.logger.info('getdata接收:'+filename+' '+datecol+' '+end_time)
+    # 判断后缀名 csv 或 excel
+    if filename.endswith('.csv'):
+        df = loadcsv(filename)
+    else:
+        # 读取excel 指标数据
+        df_zhibiaoshuju = pd.read_excel(filename, sheet_name='指标数据')
+        df_zhibiaoliebiao = pd.read_excel(filename, sheet_name='指标列表')
+
+    # 日期字符串转为datetime,并做周度数据处理
+    df = zhoududatachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, y=y, dataset=dataset,
+                         add_kdj=add_kdj, is_timefurture=is_timefurture, end_time=end_time)
+
+    return df, df_zhibiaoliebiao
+
 
 def sanitize_filename(filename):
     # 使用正则表达式替换不合规的字符
diff --git a/main_juxiting_zhoudu.py b/main_juxiting_zhoudu.py
index 7b3a525..56d6d1b 100644
--- a/main_juxiting_zhoudu.py
+++ b/main_juxiting_zhoudu.py
@@ -1,12 +1,80 @@
 # 读取配置
-from lib.dataread import *
-from lib.tools import SendMail,exception_logger
-from models.nerulforcastmodels import ex_Model_Juxiting,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting
-import glob
+from lib.dataread import *
+from config_juxiting_zhoudu import *
+from lib.tools import SendMail, exception_logger
+from models.nerulforcastmodels import ex_Model, model_losss_juxiting, tansuanli_export_pdf, pp_export_pdf
+import datetime
 import torch
 torch.set_float32_matmul_precision("high")
 
+global_config.update({
+    # 核心参数
+    'logger': logger,
+    'dataset': dataset,
+    'y': y,
+    'offsite_col': offsite_col,
+    'avg_cols': avg_cols,
+    'offsite': offsite,
+    'edbcodenamedict': edbcodenamedict,
+    'is_debug': is_debug,
+    'is_train': is_train,
+    'is_fivemodels': is_fivemodels,
+    'settings': settings,
+
+    # 模型参数
+    'data_set': data_set,
+    'input_size': input_size,
+    'horizon': horizon,
+    'train_steps': train_steps,
+    'val_check_steps': val_check_steps,
+    'val_size': val_size,
+    'test_size': test_size,
+    'modelsindex': modelsindex,
+    'rote': rote,
+
+    # 特征工程开关
+    'is_del_corr': is_del_corr,
+    'is_del_tow_month': is_del_tow_month,
+    'is_eta': is_eta,
+    'is_update_eta': is_update_eta,
+    'early_stop_patience_steps': early_stop_patience_steps,
+
+    # 时间参数
+    'start_year': start_year,
+    'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"),
+    'freq': freq,  # 时间频率
+
+    # 接口配置
+    'login_pushreport_url': login_pushreport_url,
+    'login_data': login_data,
+    'upload_url': upload_url,
+    'upload_warning_url': upload_warning_url,
+    'warning_data': warning_data,
+
+    # 查询接口
+    'query_data_list_item_nos_url': query_data_list_item_nos_url,
+    'query_data_list_item_nos_data': query_data_list_item_nos_data,
+
+    # eta 配置
+    'APPID': APPID,
+    'SECRET': SECRET,
+    'etadata': data,
+    'edbcodelist': edbcodelist,
+    'ClassifyId': ClassifyId,
+    'edbcodedataurl': edbcodedataurl,
+    'classifyidlisturl': classifyidlisturl,
+    'edbdatapushurl': edbdatapushurl,
+    'edbdeleteurl': edbdeleteurl,
+    'edbbusinessurl': edbbusinessurl,
+    'classifylisturl': classifylisturl,
+
+    # 数据库配置
+    'sqlitedb': sqlitedb,
+})
 
 
 def predict_main():
@@ -48,31 +116,23 @@ def predict_main():
     返回:
     None
     """
-    global end_time
-    signature = BinanceAPI(APPID, SECRET)
-    etadata = EtaReader(signature=signature,
-                        classifylisturl=classifylisturl,
-                        classifyidlisturl=classifyidlisturl,
-                        edbcodedataurl=edbcodedataurl,
-                        edbcodelist=edbcodelist,
-                        edbdatapushurl=edbdatapushurl,
-                        edbdeleteurl=edbdeleteurl,
-                        edbbusinessurl=edbbusinessurl
-                        )
+    end_time = global_config['end_time']
     # 获取数据
     if is_eta:
         logger.info('从eta获取数据...')
         signature = BinanceAPI(APPID, SECRET)
         etadata = EtaReader(signature=signature,
-                            classifylisturl=classifylisturl,
-                            classifyidlisturl=classifyidlisturl,
-                            edbcodedataurl=edbcodedataurl,
-                            edbcodelist=edbcodelist,
-                            edbdatapushurl=edbdatapushurl,
-                            edbdeleteurl=edbdeleteurl,
-                            edbbusinessurl=edbbusinessurl,
+                            classifylisturl=global_config['classifylisturl'],
+                            classifyidlisturl=global_config['classifyidlisturl'],
+                            edbcodedataurl=global_config['edbcodedataurl'],
+                            edbcodelist=global_config['edbcodelist'],
+                            edbdatapushurl=global_config['edbdatapushurl'],
+                            edbdeleteurl=global_config['edbdeleteurl'],
+                            edbbusinessurl=global_config['edbbusinessurl'],
+                            classifyId=global_config['ClassifyId'],
                             )
-        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(data_set=data_set, dataset=dataset)  # 原始数据,未处理
+        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(
+            data_set=data_set, dataset=dataset)  # 原始数据,未处理
 
     if is_market:
         logger.info('从市场信息平台获取数据...')
@@ -83,26 +143,26 @@ def predict_main():
             df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
         else:
             logger.info('从市场信息平台获取数据')
-            df_zhibiaoshuju = get_market_data(end_time,df_zhibiaoshuju)
-
-        except :
+            df_zhibiaoshuju = get_market_data(
+                end_time, df_zhibiaoshuju)
+
+        except:
             logger.info('最高最低价拼接失败')
-
+
         # 保存到xlsx文件的sheet表
-        with pd.ExcelWriter(os.path.join(dataset,data_set)) as file:
+        with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
             df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
             df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)
-
-
+
         # 数据处理
-        df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
-                       end_time=end_time)
+        df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
+                                end_time=end_time)
 
     else:
         # 读取数据
         logger.info('读取本地数据:' + os.path.join(dataset, data_set))
-        df,df_zhibiaoliebiao = getdata_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
-                              is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理
+        df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
+                                                        is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理
 
     # 更改预测列名称
     df.rename(columns={y: 'y'}, inplace=True)
@@ -124,47 +184,65 @@ def predict_main():
     else:
         for row in first_row.itertuples(index=False):
             row_dict = row._asdict()
-            row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
-            check_query = sqlitedb.select_data('trueandpredict', where_condition=f"ds = '{row.ds}'")
+            config.logger.info(f'要保存的真实值:{row_dict}')
+            # 判断ds是否为字符串类型,如果不是则转换为字符串类型
+            if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
+                row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
+            elif not isinstance(row_dict['ds'], str):
+                try:
+                    row_dict['ds'] = pd.to_datetime(
+                        row_dict['ds']).strftime('%Y-%m-%d')
+                except:
+                    logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
+            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
+            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
+            check_query = sqlitedb.select_data(
+                'trueandpredict', where_condition=f"ds = '{row.ds}'")
             if len(check_query) > 0:
-                set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
-                sqlitedb.update_data('trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
+                set_clause = ", ".join(
+                    [f"{key} = '{value}'" for key, value in row_dict.items()])
+                sqlitedb.update_data(
+                    'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
                 continue
-            sqlitedb.insert_data('trueandpredict', tuple(row_dict.values()), columns=row_dict.keys())
+            sqlitedb.insert_data('trueandpredict', tuple(
+                row_dict.values()), columns=row_dict.keys())
 
     # 更新accuracy表的y值
     if not sqlitedb.check_table_exists('accuracy'):
         pass
     else:
-        update_y = sqlitedb.select_data('accuracy',where_condition="y is null")
+        update_y = sqlitedb.select_data(
+            'accuracy', where_condition="y is null")
         if len(update_y) > 0:
             logger.info('更新accuracy表的y值')
             # 找到update_y 中ds且df中的y的行
-            update_y = update_y[update_y['ds']<=end_time]
+            update_y = update_y[update_y['ds'] <= end_time]
             logger.info(f'要更新y的信息:{update_y}')
             # try:
             for row in update_y.itertuples(index=False):
                 try:
-                    row_dict = row._asdict()
-                    yy = df[df['ds']==row_dict['ds']]['y'].values[0]
-                    LOW = df[df['ds']==row_dict['ds']]['Brentzdj'].values[0]
-                    HIGH = df[df['ds']==row_dict['ds']]['Brentzgj'].values[0]
-                    sqlitedb.update_data('accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
+                    row_dict = row._asdict()
+                    yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
+                    LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
+                    HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
+                    sqlitedb.update_data(
+                        'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
                 except:
                     logger.info(f'更新accuracy表的y值失败:{row_dict}')
             # except Exception as e:
             #     logger.info(f'更新accuracy表的y值失败:{e}')
 
-    import datetime
     # 判断当前日期是不是周一
     is_weekday = datetime.datetime.now().weekday() == 0
     if is_weekday:
         logger.info('今天是周一,更新预测模型')
         # 计算最近60天预测残差最低的模型名称
-        model_results = sqlitedb.select_data('trueandpredict', order_by="ds DESC", limit="60")
+        model_results = sqlitedb.select_data(
+            'trueandpredict', order_by="ds DESC", limit="60")
        # 删除空值率为90%以上的列
         if len(model_results) > 10:
-            model_results = model_results.dropna(thresh=len(model_results)*0.1,axis=1)
+            model_results = model_results.dropna(
+                thresh=len(model_results)*0.1, axis=1)
         # 删除空行
         model_results = model_results.dropna()
         modelnames = model_results.columns.to_list()[2:-1]
@@ -172,47 +250,59 @@ def predict_main():
             model_results[col] = model_results[col].astype(np.float32)
         # 计算每个预测值与真实值之间的偏差率
         for model in modelnames:
-            model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y']
+            model_results[f'{model}_abs_error_rate'] = abs(
+                model_results['y'] - model_results[model]) / model_results['y']
         # 获取每行对应的最小偏差率值
-        min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
+        min_abs_error_rate_values = model_results.apply(
+            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
         # 获取每行对应的最小偏差率值对应的列名
-        min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
+        min_abs_error_rate_column_name = model_results.apply(
+            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
         # 将列名索引转换为列名
-        min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0])
+        min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
+            lambda x: x.split('_')[0])
         # 取出现次数最多的模型名称
         most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
         logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}")
         # 保存结果到数据库
         if not sqlitedb.check_table_exists('most_model'):
-            sqlitedb.create_table('most_model', columns="ds datetime, most_common_model TEXT")
-        sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
+            sqlitedb.create_table(
+                'most_model', columns="ds datetime, most_common_model TEXT")
+        sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
+            '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
 
     try:
         if is_weekday:
-        # if True:
             logger.info('今天是周一,发送特征预警')
             # 上传预警信息到数据库
             warning_data_df = df_zhibiaoliebiao.copy()
-            warning_data_df = warning_data_df[warning_data_df['停更周期']> 3 ][['指标名称', '指标id', '频度','更新周期','指标来源','最后更新时间','停更周期']]
+            warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[
+                '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']]
             # 重命名列名
-            warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
+            warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY',
+                                                              '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
             from sqlalchemy import create_engine
             import urllib
             global password
             if '@' in password:
                 password = urllib.parse.quote_plus(password)
-            engine = create_engine(f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
-            warning_data_df['WARNING_DATE'] = datetime.date.today().strftime("%Y-%m-%d %H:%M:%S")
-            warning_data_df['TENANT_CODE'] = 'T0004'
+            engine = create_engine(
+                f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
+            warning_data_df['WARNING_DATE'] = datetime.date.today().strftime(
+                "%Y-%m-%d %H:%M:%S")
+            warning_data_df['TENANT_CODE'] = 'T0004'
             # 插入数据之前查询表数据然后新增id列
             existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine)
             if not existing_data.empty:
                 max_id = existing_data['ID'].astype(int).max()
-                warning_data_df['ID'] = range(max_id + 1, max_id + 1 + len(warning_data_df))
+                warning_data_df['ID'] = range(
+                    max_id + 1, max_id + 1 + len(warning_data_df))
             else:
                 warning_data_df['ID'] = range(1, 1 + len(warning_data_df))
-            warning_data_df.to_sql(table_name, con=engine, if_exists='append', index=False)
+            warning_data_df.to_sql(
+                table_name, con=engine, if_exists='append', index=False)
             if is_update_warning_data:
                 upload_warning_info(len(warning_data_df))
     except:
@@ -226,76 +316,75 @@ def predict_main():
     row, col = df.shape
 
     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
-    ex_Model_Juxiting(df,
-             horizon=horizon,
-             input_size=input_size,
-             train_steps=train_steps,
-             val_check_steps=val_check_steps,
-             early_stop_patience_steps=early_stop_patience_steps,
-             is_debug=is_debug,
-             dataset=dataset,
-             is_train=is_train,
-             is_fivemodels=is_fivemodels,
-             val_size=val_size,
-             test_size=test_size,
-             settings=settings,
+    ex_Model(df,
+             horizon=global_config['horizon'],
+             input_size=global_config['input_size'],
+             train_steps=global_config['train_steps'],
+             val_check_steps=global_config['val_check_steps'],
+             early_stop_patience_steps=global_config['early_stop_patience_steps'],
+             is_debug=global_config['is_debug'],
+             dataset=global_config['dataset'],
+             is_train=global_config['is_train'],
+             is_fivemodels=global_config['is_fivemodels'],
+             val_size=global_config['val_size'],
+             test_size=global_config['test_size'],
+             settings=global_config['settings'],
              now=now,
-             etadata=etadata,
-             modelsindex=modelsindex,
+             etadata=global_config['etadata'],
+             modelsindex=global_config['modelsindex'],
             data=data,
-             is_eta=is_eta,
-             end_time=end_time,
+             is_eta=global_config['is_eta'],
+             end_time=global_config['end_time'],
             )
 
-    logger.info('模型训练完成')
-
+
     logger.info('训练数据绘图ing')
-    model_results3 = model_losss_juxiting(sqlitedb)
+    model_results3 = model_losss_juxiting(
+        sqlitedb, end_time=global_config['end_time'], is_fivemodels=global_config['is_fivemodels'])
     logger.info('训练数据绘图end')
-
-    # 模型报告
+
+    # # 模型报告
     logger.info('制作报告ing')
-    title = f'{settings}--{end_time}-预测报告'  # 报告标题
-    reportname = f'PP大模型预测报告--{end_time}.pdf'  # 报告文件名
-    reportname = reportname.replace(':', '-')  # 替换冒号
-    pp_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time,
-                  reportname=reportname,sqlitedb=sqlitedb),
+    title = f'{settings}--{end_time}-预测报告'  # 报告标题
+    reportname = f'Brent原油大模型周度预测--{end_time}.pdf'  # 报告文件名
+    reportname = reportname.replace(':', '-')  # 替换冒号
+    pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
+                  reportname=reportname, sqlitedb=sqlitedb),
 
     logger.info('制作报告end')
     logger.info('模型训练完成')
 
     # # LSTM 单变量模型
     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
-
+
     # # lstm 多变量模型
     # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset)
-
+
     # # GRU 模型
     # # ex_GRU(df)
 
     # 发送邮件
-    m = SendMail(
-        username=username,
-        passwd=passwd,
-        recv=recv,
-        title=title,
-        content=content,
-        file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
-        ssl=ssl,
-    )
-    # m.send_mail()
+    # m = SendMail(
+    #     username=username,
+    #     passwd=passwd,
+    #     recv=recv,
+    #     title=title,
+    #     content=content,
+    #     file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
+    #     ssl=ssl,
+    # )
+    # m.send_mail()
 
 
 if __name__ == '__main__':
     # global end_time
-    # is_on = True
-    # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期
-    # for i_time in pd.date_range('2025-1-20', '2025-2-6', freq='B'):
-    #     end_time = i_time.strftime('%Y-%m-%d')
+    # 遍历2022-1-1 到 2025-3-26 之间的月度日期
+    # for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'):
     #     try:
+    #         global_config['end_time'] = i_time.strftime('%Y-%m-%d')
     #         predict_main()
-    #     except:
-    #         pass
+    #     except Exception as e:
+    #         logger.info(f'预测失败:{e}')
+    #         continue
 
-    predict_main()
\ No newline at end of file
+    predict_main()
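For reference, the Monday "best model" selection that predict_main runs over the last 60 trueandpredict rows reduces to: compute each model's absolute error rate against y, take the per-row argmin, and keep the model that wins most often. A minimal standalone sketch of that logic under toy data (the column names NHITS and Informer are hypothetical stand-ins for whatever modelnames actually contains):

import pandas as pd

# Toy frame: 'y' is the truth, the other columns are per-model predictions.
model_results = pd.DataFrame({
    'y':        [100.0, 102.0,  98.0, 101.0],
    'NHITS':    [ 99.0, 103.5,  97.0, 100.5],
    'Informer': [101.5, 101.8,  99.5, 103.0],
})
modelnames = ['NHITS', 'Informer']

# Absolute error rate |y - yhat| / y per model, as in predict_main.
for model in modelnames:
    model_results[f'{model}_abs_error_rate'] = abs(
        model_results['y'] - model_results[model]) / model_results['y']

# Per-row winner: column label with the smallest error rate, suffix stripped.
best_per_row = model_results.apply(
    lambda row: row[[f'{m}_abs_error_rate' for m in modelnames]].idxmin(), axis=1)
best_per_row = best_per_row.map(lambda x: x.split('_')[0])

# The model that wins the most rows becomes most_common_model.
most_common_model = best_per_row.value_counts().idxmax()
print(most_common_model)  # -> 'NHITS' (wins 3 of 4 rows)

One caveat visible in the sketch: x.split('_')[0] keeps only the text before the first underscore, so a model name that itself contains an underscore would be truncated.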