diff --git a/config_juxiting_yuedu.py b/config_juxiting_yuedu.py
index ab8b83b..9900153 100644
--- a/config_juxiting_yuedu.py
+++ b/config_juxiting_yuedu.py
@@ -88,6 +88,13 @@ bdwdname = [
     '次三月',
     '次四月',
 ]
+
+# 数据库预测结果表八大维度列名
+price_columns = [
+    'day_price', 'week_price', 'second_week_price', 'next_week_price',
+    'next_month_price', 'next_february_price', 'next_march_price', 'next_april_price'
+]
+
 modelsindex = [
     {
         "NHITS": "SELF0000275",
@@ -406,8 +413,8 @@ bdwd_items = {
 }

 # 报告中八大维度数据项重命名
-columnsrename={'jxtppbdwdbz': '本周', 'jxtppbdwdcey': '次二月', 'jxtppbdwdcr': '次日', 'jxtppbdwdcsiy': '次四月',
-               'jxtppbdwdcsany': '次三月', 'jxtppbdwdcy': '次月', 'jxtppbdwdcz': '次周', 'jxtppbdwdgz': '隔周', }
+columnsrename = {'jxtppbdwdbz': '本周', 'jxtppbdwdcey': '次二月', 'jxtppbdwdcr': '次日', 'jxtppbdwdcsiy': '次四月',
+                 'jxtppbdwdcsany': '次三月', 'jxtppbdwdcy': '次月', 'jxtppbdwdcz': '次周', 'jxtppbdwdgz': '隔周', }

 # 北京环境数据库
 host = '192.168.101.27'
@@ -459,7 +466,7 @@ print("数据库连接成功", host, dbname, dbusername)

 # 数据截取日期
 start_year = 2000  # 数据开始年份
-end_time = ''  # 数据截取日期
+end_time = '2025-07-22'  # 数据截取日期
 freq = 'M'  # 时间频率,"D": 天 "W": 周 "M": 月 "Q": 季度 "A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日
 delweekenday = True if freq == 'B' else False  # 是否删除周末数据
 is_corr = False  # 特征是否参与滞后领先提升相关系数
diff --git a/lib/tools.py b/lib/tools.py
index 6f75219..24f5e64 100644
--- a/lib/tools.py
+++ b/lib/tools.py
@@ -864,8 +864,7 @@ def find_best_models(date='', global_config=None):

     # 获取真实价格数据
     try:
-        true_price = pd.read_csv(os.path.join(
-            global_config['dataset'], '指标数据.csv'))[['ds', 'y']]
+        true_price = pd.read_csv('juxitingdataset/指标数据.csv')[['ds', 'y']]
     except FileNotFoundError:
         global_config['logger'].error(
             f"未找到文件: {os.path.join(global_config['dataset'], '指标数据.csv')}")
@@ -1082,23 +1081,34 @@ def plot_pp_predict_result(y_hat, global_config):
     import seaborn as sns

     # 获取y的真实值
-    y = pd.read_csv(os.path.join(
-        global_config['dataset'], '指标数据.csv'))[['ds', 'y']]
+    # y = pd.read_csv(os.path.join(
+    #     global_config['dataset'], '指标数据.csv'))[['ds', 'y']]
+    y = pd.read_csv('juxitingdataset/指标数据.csv')[['ds', 'y']]
     y['ds'] = pd.to_datetime(y['ds'])
     y = y[y['ds'] < y_hat['ds'].iloc[0]][-30:]

+    # 取y的最后一行数据追加到y_hat(将真实值最后一行作为预测值起点)
+    if not y.empty:
+        # 获取y的最后一行并将'y'列重命名为'predictresult'以匹配y_hat结构
+        y_last_row = y.tail(1).rename(columns={'y': 'predictresult'})
+        # 追加到y_hat
+        y_y_hat = pd.concat([y_last_row, y_hat], ignore_index=True)
+    else:
+        # y为空时直接使用y_hat,避免下方引用未定义的y_y_hat
+        y_y_hat = y_hat.copy()
+
     # 创建图表和子图布局,为表格预留空间
     fig, ax = plt.subplots(figsize=(16, 9))
     # 对日期列进行排序,确保日期大的在右边
-    y_hat = y_hat.sort_values(by='ds')
+    y_y_hat = y_y_hat.sort_values(by='ds')
     y = y.sort_values(by='ds')

     # 绘制 y_hat 的折线图,颜色为橙色
-    sns.lineplot(x=y_hat['ds'], y=y_hat['predictresult'],
-                 color='orange', label='y_hat', ax=ax, linestyle='--')
+    sns.lineplot(x=y_y_hat['ds'], y=y_y_hat['predictresult'],
+                 color='orange', label='预测值', ax=ax, linestyle='--')
     # 绘制 y 的折线图,颜色为蓝色
-    sns.lineplot(x=y['ds'], y=y['y'], color='blue', label='y', ax=ax)
+    sns.lineplot(x=y['ds'], y=y['y'], color='blue', label='真实值', ax=ax)

     # date_str = pd.Timestamp(y_hat["ds"].iloc[0]).strftime('%Y-%m-%d')
     ax.set_title(f'{global_config["end_time"]} PP期货八大维度 预测价格走势')
@@ -1129,7 +1139,8 @@ def plot_pp_predict_result(y_hat, global_config):
     table.set_fontsize(14)

     plt.tight_layout(rect=[0, 0.1, 1, 1])  # 调整布局,为表格留出空间
-    plt.savefig('pp_predict_result.png')
+    plt.savefig(os.path.join(
+        global_config['dataset'], 'pp_predict_result.png'))


 if __name__ == '__main__':
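Note on the plot_pp_predict_result change above: prepending the last observed row of y to y_hat makes the dashed forecast line start exactly where the solid actuals line ends, so the two curves join up in the chart. A minimal, self-contained sketch of that stitching trick (toy data; only the ds/y/predictresult column convention is taken from the function):

    import pandas as pd

    y = pd.DataFrame({'ds': pd.to_datetime(['2025-07-20', '2025-07-21']),
                      'y': [7100.0, 7150.0]})
    y_hat = pd.DataFrame({'ds': pd.to_datetime(['2025-07-22', '2025-07-23']),
                          'predictresult': [7180.0, 7210.0]})

    # Rename the last true row so its value lands in the forecast column,
    # then prepend it: the plotted forecast now begins at the last actual.
    last_row = y.tail(1).rename(columns={'y': 'predictresult'})
    y_y_hat = pd.concat([last_row, y_hat], ignore_index=True).sort_values(by='ds')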
diff --git a/main_juxiting.py b/main_juxiting.py
index 9894cb7..b747c34 100644
--- a/main_juxiting.py
+++ b/main_juxiting.py
@@ -554,14 +554,7 @@ if __name__ == '__main__':
             #     logger.info(f'预测失败:{e}')
             #     continue

-    # predict_main()
+    predict_main()
     # push_market_value()
     # sql_inset_predict(global_config)
-    from lib.tools import find_best_models
-    best_bdwd_price = find_best_models(
-        date='2025-07-22', global_config=global_config)
-    y_hat = pd.DataFrame(best_bdwd_price).T[['date', 'predictresult']]
-    y_hat['ds'] = pd.to_datetime(y_hat['date'])
-    # 绘制PP期货预测结果的图表
-    plot_pp_predict_result(y_hat, global_config)
diff --git a/main_juxiting_yuedu.py b/main_juxiting_yuedu.py
index 011ae4d..c4c2edf 100644
--- a/main_juxiting_yuedu.py
+++ b/main_juxiting_yuedu.py
@@ -3,7 +3,7 @@ from lib.dataread import *
 from config_juxiting_yuedu import *
 from lib.tools import SendMail, convert_df_to_pydantic_pp, exception_logger, get_modelsname
-from models.nerulforcastmodels import ex_Model, model_losss_juxiting, pp_export_pdf
+from models.nerulforcastmodels import ex_Model, model_losss_juxiting, pp_bdwd_png, pp_export_pdf
 import datetime
 import torch
 torch.set_float32_matmul_precision("high")
@@ -24,6 +24,7 @@ global_config.update({
     'settings': settings,
     'bdwdname': bdwdname,
     'columnsrename': columnsrename,
+    'price_columns': price_columns,

     # 模型参数
@@ -291,210 +292,218 @@ def predict_main():

     返回:
         None
     """
-    # end_time = global_config['end_time']
-    # signature = BinanceAPI(APPID, SECRET)
-    # etadata = EtaReader(signature=signature,
-    #                     classifylisturl=global_config['classifylisturl'],
-    #                     classifyidlisturl=global_config['classifyidlisturl'],
-    #                     edbcodedataurl=global_config['edbcodedataurl'],
-    #                     edbcodelist=global_config['edbcodelist'],
-    #                     edbdatapushurl=global_config['edbdatapushurl'],
-    #                     edbdeleteurl=global_config['edbdeleteurl'],
-    #                     edbbusinessurl=global_config['edbbusinessurl'],
-    #                     classifyId=global_config['ClassifyId'],
-    #                     )
-    # # 获取数据
-    # if is_eta:
-    #     logger.info('从eta获取数据...')
+    end_time = global_config['end_time']
+    signature = BinanceAPI(APPID, SECRET)
+    etadata = EtaReader(signature=signature,
+                        classifylisturl=global_config['classifylisturl'],
+                        classifyidlisturl=global_config['classifyidlisturl'],
+                        edbcodedataurl=global_config['edbcodedataurl'],
+                        edbcodelist=global_config['edbcodelist'],
+                        edbdatapushurl=global_config['edbdatapushurl'],
+                        edbdeleteurl=global_config['edbdeleteurl'],
+                        edbbusinessurl=global_config['edbbusinessurl'],
+                        classifyId=global_config['ClassifyId'],
+                        )
+    # 获取数据
+    if is_eta:
+        logger.info('从eta获取数据...')

-    #     df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(
-    #         data_set=data_set, dataset=dataset)  # 原始数据,未处理
+        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(
+            data_set=data_set, dataset=dataset)  # 原始数据,未处理

-    #     if is_market:
-    #         logger.info('从市场信息平台获取数据...')
-    #         try:
-    #             # 如果是测试环境,最高价最低价取excel文档
-    #             if server_host == '192.168.100.53':
-    #                 logger.info('从excel文档获取最高价最低价')
-    #                 df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
-    #             else:
-    #                 logger.info('从市场信息平台获取数据')
-    #                 df_zhibiaoshuju = get_market_data(
-    #                     end_time, df_zhibiaoshuju)
+        if is_market:
+            logger.info('从市场信息平台获取数据...')
+            try:
+                # 如果是测试环境,最高价最低价取excel文档
+                if server_host == '192.168.100.53':
+                    logger.info('从excel文档获取最高价最低价')
+                    df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
+                else:
+                    logger.info('从市场信息平台获取数据')
+                    df_zhibiaoshuju = get_market_data(
+                        end_time, df_zhibiaoshuju)

-    #         except:
-    #             logger.info('最高最低价拼接失败')
+            except:
+                logger.info('最高最低价拼接失败')

-    #         # 保存到xlsx文件的sheet表
-    #         with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
-    #             df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
-    #             df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)
+        # 保存到xlsx文件的sheet表
+        with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
+            df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
+            df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)

-    #     # 数据处理
-    #     df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
-    #                             end_time=end_time)
+        # 数据处理
+        df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
+                                end_time=end_time)

-    # else:
-    #     # 读取数据
-    #     logger.info('读取本地数据:' + os.path.join(dataset, data_set))
-    #     df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
-    #                                                     is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理
+    else:
+        # 读取数据
+        logger.info('读取本地数据:' + os.path.join(dataset, data_set))
+        df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
+                                                        is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理

-    # # 更改预测列名称
-    # df.rename(columns={y: 'y'}, inplace=True)
+    # 更改预测列名称
+    df.rename(columns={y: 'y'}, inplace=True)

-    # if is_edbnamelist:
-    #     df = df[edbnamelist]
-    # df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False)
-    # # 保存最新日期的y值到数据库
-    # # 取第一行数据存储到数据库中
-    # first_row = df[['ds', 'y']].tail(1)
-    # # 判断y的类型是否为float
-    # if not isinstance(first_row['y'].values[0], float):
-    #     logger.info(f'{end_time}预测目标数据为空,跳过')
-    #     return None
+    if is_edbnamelist:
+        df = df[edbnamelist]
+    df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False)
+    # 保存最新日期的y值到数据库
+    # 取第一行数据存储到数据库中
+    first_row = df[['ds', 'y']].tail(1)
+    # 判断y的类型是否为float
+    if not isinstance(first_row['y'].values[0], float):
+        logger.info(f'{end_time}预测目标数据为空,跳过')
+        return None

-    # # 将最新真实值保存到数据库
-    # if not sqlitedb.check_table_exists('trueandpredict'):
-    #     first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
-    # else:
-    #     for row in first_row.itertuples(index=False):
-    #         row_dict = row._asdict()
-    #         config.logger.info(f'要保存的真实值:{row_dict}')
-    #         # 判断ds是否为字符串类型,如果不是则转换为字符串类型
-    #         if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
-    #             row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
-    #         elif not isinstance(row_dict['ds'], str):
-    #             try:
-    #                 row_dict['ds'] = pd.to_datetime(
-    #                     row_dict['ds']).strftime('%Y-%m-%d')
-    #             except:
-    #                 logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
-    #         # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
-    #         # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
-    #         check_query = sqlitedb.select_data(
-    #             'trueandpredict', where_condition=f"ds = '{row.ds}'")
-    #         if len(check_query) > 0:
-    #             set_clause = ", ".join(
-    #                 [f"{key} = '{value}'" for key, value in row_dict.items()])
-    #             sqlitedb.update_data(
-    #                 'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
-    #             continue
-    #         sqlitedb.insert_data('trueandpredict', tuple(
-    #             row_dict.values()), columns=row_dict.keys())
+    # 将最新真实值保存到数据库
+    if not sqlitedb.check_table_exists('trueandpredict'):
+        first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
+    else:
+        for row in first_row.itertuples(index=False):
+            row_dict = row._asdict()
+            config.logger.info(f'要保存的真实值:{row_dict}')
+            # 判断ds是否为字符串类型,如果不是则转换为字符串类型
+            if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
+                row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
+            elif not isinstance(row_dict['ds'], str):
+                try:
+                    row_dict['ds'] = pd.to_datetime(
+                        row_dict['ds']).strftime('%Y-%m-%d')
+                except:
+                    logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
+            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
+            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
+            check_query = sqlitedb.select_data(
+                'trueandpredict', where_condition=f"ds = '{row.ds}'")
+            if len(check_query) > 0:
+                set_clause = ", ".join(
+                    [f"{key} = '{value}'" for key, value in row_dict.items()])
+                sqlitedb.update_data(
+                    'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
+                continue
+            sqlitedb.insert_data('trueandpredict', tuple(
+                row_dict.values()), columns=row_dict.keys())

-    # # 更新accuracy表的y值
-    # if not sqlitedb.check_table_exists('accuracy'):
-    #     pass
-    # else:
-    #     update_y = sqlitedb.select_data(
-    #         'accuracy', where_condition="y is null")
-    #     if len(update_y) > 0:
-    #         logger.info('更新accuracy表的y值')
-    #         # 找到update_y 中ds且df中的y的行
-    #         update_y = update_y[update_y['ds'] <= end_time]
-    #         logger.info(f'要更新y的信息:{update_y}')
-    #         # try:
-    #         for row in update_y.itertuples(index=False):
-    #             try:
-    #                 row_dict = row._asdict()
-    #                 yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
-    #                 LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
-    #                 HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
-    #                 sqlitedb.update_data(
-    #                     'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
-    #             except:
-    #                 logger.info(f'更新accuracy表的y值失败:{row_dict}')
-    #         # except Exception as e:
-    #         #     logger.info(f'更新accuracy表的y值失败:{e}')
+    # 更新accuracy表的y值
+    if not sqlitedb.check_table_exists('accuracy'):
+        pass
+    else:
+        update_y = sqlitedb.select_data(
+            'accuracy', where_condition="y is null")
+        if len(update_y) > 0:
+            logger.info('更新accuracy表的y值')
+            # 找到update_y 中ds且df中的y的行
+            update_y = update_y[update_y['ds'] <= end_time]
+            logger.info(f'要更新y的信息:{update_y}')
+            # try:
+            for row in update_y.itertuples(index=False):
+                try:
+                    row_dict = row._asdict()
+                    yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
+                    LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
+                    HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
+                    sqlitedb.update_data(
+                        'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
+                except:
+                    logger.info(f'更新accuracy表的y值失败:{row_dict}')
+            # except Exception as e:
+            #     logger.info(f'更新accuracy表的y值失败:{e}')

-    # # 判断当前日期是不是周一
-    # is_weekday = datetime.datetime.now().weekday() == 0
-    # if is_weekday:
-    #     logger.info('今天是周一,更新预测模型')
-    #     # 计算最近60天预测残差最低的模型名称
-    #     model_results = sqlitedb.select_data(
-    #         'trueandpredict', order_by="ds DESC", limit="60")
-    #     # 删除空值率为90%以上的列
-    #     if len(model_results) > 10:
-    #         model_results = model_results.dropna(
-    #             thresh=len(model_results)*0.1, axis=1)
-    #     # 删除空行
-    #     model_results = model_results.dropna()
-    #     modelnames = model_results.columns.to_list()[2:-1]
-    #     for col in model_results[modelnames].select_dtypes(include=['object']).columns:
-    #         model_results[col] = model_results[col].astype(np.float32)
-    #     # 计算每个预测值与真实值之间的偏差率
-    #     for model in modelnames:
-    #         model_results[f'{model}_abs_error_rate'] = abs(
-    #             model_results['y'] - model_results[model]) / model_results['y']
-    #     # 获取每行对应的最小偏差率值
-    #     min_abs_error_rate_values = model_results.apply(
-    #         lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
-    #     # 获取每行对应的最小偏差率值对应的列名
-    #     min_abs_error_rate_column_name = model_results.apply(
-    #         lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
-    #     # 将列名索引转换为列名
-    #     min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
-    #         lambda x: x.split('_')[0])
-    #     # 取出现次数最多的模型名称
-    #     most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
-    #     logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}")
-    #     # 保存结果到数据库
-    #     if not sqlitedb.check_table_exists('most_model'):
-    #         sqlitedb.create_table(
-    #             'most_model', columns="ds datetime, most_common_model TEXT")
-    #     sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
-    #         '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
+    # 判断当前日期是不是周一
+    is_weekday = datetime.datetime.now().weekday() == 0
+    if is_weekday:
+        logger.info('今天是周一,更新预测模型')
+        # 计算最近60天预测残差最低的模型名称
+        model_results = sqlitedb.select_data(
+            'trueandpredict', order_by="ds DESC", limit="60")
+        # 删除空值率为90%以上的列
+        if len(model_results) > 10:
+            model_results = model_results.dropna(
+                thresh=len(model_results)*0.1, axis=1)
+        # 删除空行
+        model_results = model_results.dropna()
+        modelnames = model_results.columns.to_list()[2:-1]
+        for col in model_results[modelnames].select_dtypes(include=['object']).columns:
+            model_results[col] = model_results[col].astype(np.float32)
+        # 计算每个预测值与真实值之间的偏差率
+        for model in modelnames:
+            model_results[f'{model}_abs_error_rate'] = abs(
+                model_results['y'] - model_results[model]) / model_results['y']
+        # 获取每行对应的最小偏差率值
+        min_abs_error_rate_values = model_results.apply(
+            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
+        # 获取每行对应的最小偏差率值对应的列名
+        min_abs_error_rate_column_name = model_results.apply(
+            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
+        # 将列名索引转换为列名
+        min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
+            lambda x: x.split('_')[0])
+        # 取出现次数最多的模型名称
+        most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
+        logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}")
+        # 保存结果到数据库
+        if not sqlitedb.check_table_exists('most_model'):
+            sqlitedb.create_table(
+                'most_model', columns="ds datetime, most_common_model TEXT")
+        sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
+            '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))

-    # if is_corr:
-    #     df = corr_feature(df=df)
+    if is_corr:
+        df = corr_feature(df=df)

-    # df1 = df.copy()  # 备份一下,后面特征筛选完之后加入ds y 列用
-    # logger.info(f"开始训练模型...")
-    # row, col = df.shape
+    df1 = df.copy()  # 备份一下,后面特征筛选完之后加入ds y 列用
+    logger.info(f"开始训练模型...")
+    row, col = df.shape

-    # now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
-    # ex_Model(df,
-    #          horizon=global_config['horizon'],
-    #          input_size=global_config['input_size'],
-    #          train_steps=global_config['train_steps'],
-    #          val_check_steps=global_config['val_check_steps'],
-    #          early_stop_patience_steps=global_config['early_stop_patience_steps'],
-    #          is_debug=global_config['is_debug'],
-    #          dataset=global_config['dataset'],
-    #          is_train=global_config['is_train'],
-    #          is_fivemodels=global_config['is_fivemodels'],
-    #          val_size=global_config['val_size'],
-    #          test_size=global_config['test_size'],
-    #          settings=global_config['settings'],
-    #          now=now,
-    #          etadata=etadata,
-    #          modelsindex=global_config['modelsindex'],
-    #          data=data,
-    #          is_eta=global_config['is_eta'],
-    #          end_time=global_config['end_time'],
-    #          )
+    now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+    ex_Model(df,
+             horizon=global_config['horizon'],
+             input_size=global_config['input_size'],
+             train_steps=global_config['train_steps'],
+             val_check_steps=global_config['val_check_steps'],
+             early_stop_patience_steps=global_config['early_stop_patience_steps'],
+             is_debug=global_config['is_debug'],
+             dataset=global_config['dataset'],
+             is_train=global_config['is_train'],
+             is_fivemodels=global_config['is_fivemodels'],
+             val_size=global_config['val_size'],
+             test_size=global_config['test_size'],
+             settings=global_config['settings'],
+             now=now,
+             etadata=etadata,
+             modelsindex=global_config['modelsindex'],
+             data=data,
+             is_eta=global_config['is_eta'],
+             end_time=global_config['end_time'],
+             )

-    # logger.info('模型训练完成')
+    logger.info('模型训练完成')

-    # logger.info('训练数据绘图ing')
-    # model_results3 = model_losss_juxiting(sqlitedb, end_time=global_config['end_time'],is_fivemodels=global_config['is_fivemodels'])
-    # logger.info('训练数据绘图end')
+    logger.info('训练数据绘图ing')
+    model_results3 = model_losss_juxiting(
+        sqlitedb, end_time=global_config['end_time'], is_fivemodels=global_config['is_fivemodels'])
+    logger.info('训练数据绘图end')

-    # push_market_value()
-    # # 模型报告
-    # logger.info('制作报告ing')
-    # title = f'{settings}--{end_time}-预测报告'  # 报告标题
-    # reportname = f'聚烯烃PP大模型月度预测--{end_time}.pdf'  # 报告文件名
-    # reportname = reportname.replace(':', '-')  # 替换冒号
-    # pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
-    #               reportname=reportname, sqlitedb=sqlitedb),
+    push_market_value()

-    # logger.info('制作报告end')
-    # logger.info('模型训练完成')
     sql_inset_predict(global_config)
+    # 模型报告
+    logger.info('制作报告ing')
+    title = f'{settings}--{end_time}-预测报告'  # 报告标题
+    reportname = f'聚烯烃PP大模型月度预测--{end_time}.pdf'  # 报告文件名
+    reportname = reportname.replace(':', '-')  # 替换冒号
+    pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
+                  reportname=reportname, sqlitedb=sqlitedb),
+
+    logger.info('制作报告end')
+    logger.info('模型训练完成')
+
+    # 图片报告
+    logger.info('图片报告ing')
+    pp_bdwd_png(global_config=global_config)
+    logger.info('图片报告end')

     # # LSTM 单变量模型
     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)
diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py
index 739379a..b803cb4 100644
--- a/models/nerulforcastmodels.py
+++ b/models/nerulforcastmodels.py
@@ -6,7 +6,7 @@ import seaborn as sns
 import matplotlib.pyplot as plt
 import matplotlib.dates as mdates
 import datetime
-from lib.tools import Graphs, mse, rmse, mae, exception_logger
+from lib.tools import Graphs, find_best_models, mse, plot_pp_predict_result, rmse, mae, exception_logger
 from lib.tools import save_to_database, get_week_date
 from lib.dataread import *
 from neuralforecast import NeuralForecast
@@ -165,7 +165,8 @@ def ex_Model(df, horizon, input_size, train_steps, val_check_steps, early_stop_p
         # VanillaTransformer(h=horizon, input_size=input_size, max_steps=train_steps, val_check_steps=val_check_steps, scaler_type='standard', ),  //报错了
         # Autoformer(h=horizon, input_size=input_size, max_steps=train_steps, val_check_steps=val_check_steps, scaler_type='standard', ),  //报错了

-        NBEATS(h=horizon, input_size=input_size, max_steps=train_steps, val_check_steps=val_check_steps, scaler_type='standard', ),
+        NBEATS(h=horizon, input_size=input_size, max_steps=train_steps,
+               val_check_steps=val_check_steps, scaler_type='standard', ),
         # NBEATSx(h=horizon, input_size=input_size, max_steps=train_steps, val_check_steps=val_check_steps, scaler_type='standard',activation='ReLU', ),  //报错
         # HINT(h=horizon),
@@ -2359,7 +2360,8 @@ def brent_export_pdf(num_indicators=475, num_models=21, num_dayindicator=202, in
         stime = df3['ds'].iloc[0]
         etime = df3['ds'].iloc[-1]
         # 添加偏差率表格
-        fivemodels = '、'.join(eval_df['模型(Model)'].values[:5])  # 字符串形式,后面写入字符串使用
+        fivemodels = '、'.join(
+            eval_df['模型(Model)'].values[:5])  # 字符串形式,后面写入字符串使用
         content.append(Graphs.draw_text(
             f'预测使用了{num_models}个模型进行训练,使用评估结果MAE前五的模型分别是 {fivemodels} ,模型上一预测区间 {stime} -- {etime}的偏差率(%)分别是:'))
         # # 添加偏差率表格
@@ -2370,7 +2372,8 @@ def brent_export_pdf(num_indicators=475, num_models=21, num_dayindicator=202, in
         content.append(Graphs.draw_table(col_width, *data))

         content.append(Graphs.draw_little_title('上一周预测准确率:'))
-        df4 = sqlitedb.select_data('accuracy_rote', order_by='结束日期 desc', limit=1)
+        df4 = sqlitedb.select_data(
+            'accuracy_rote', order_by='结束日期 desc', limit=1)
         df4 = df4.T
         df4 = df4.reset_index()
         df4 = df4.T
@@ -3524,6 +3527,17 @@ def pp_export_pdf(num_indicators=475, num_models=21, num_dayindicator=202, input
         print(f"请求超时: {e}")


+@exception_logger
+def pp_bdwd_png(global_config):
+    best_bdwd_price = find_best_models(
+        date=global_config['end_time'], global_config=global_config)
+    # y_hat = pd.DataFrame(best_bdwd_price).T[['date', 'predictresult']][-4:]
+    y_hat = pd.DataFrame(best_bdwd_price).T[['date', 'predictresult']]
+    y_hat['ds'] = pd.to_datetime(y_hat['date'])
+    # 绘制PP期货预测结果的图表
+    plot_pp_predict_result(y_hat, global_config)
+
+
 def pp_export_pdf_v1(num_indicators=475, num_models=21, num_dayindicator=202, inputsize=5, dataset='dataset', time='2024-07-30', reportname='report.pdf'):
     global y
     # 创建内容对应的空列表
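Note on the weekly model-selection block restored in main_juxiting_yuedu.py: over the last 60 rows of trueandpredict it computes each model's absolute error rate against y, finds the best model per row, and keeps the one that wins most often. A hedged toy sketch of that heuristic (model names and values are illustrative only):

    import pandas as pd

    df = pd.DataFrame({'y': [100.0, 110.0, 120.0],
                       'NHITS': [98.0, 112.0, 119.0],
                       'NBEATS': [105.0, 109.0, 125.0]})
    modelnames = ['NHITS', 'NBEATS']
    for model in modelnames:
        df[f'{model}_abs_error_rate'] = (df['y'] - df[model]).abs() / df['y']

    # Per-row winner: the column with the smallest error rate, mapped back to
    # the bare model name; the overall pick is the most frequent winner.
    best_per_row = df[[f'{m}_abs_error_rate' for m in modelnames]].idxmin(axis=1)
    best_per_row = best_per_row.map(lambda x: x.split('_')[0])
    most_common_model = best_per_row.value_counts().idxmax()  # -> 'NHITS'

The production code reaches the same result with row-wise DataFrame.apply plus idxmin; the vectorized idxmin(axis=1) above is an equivalent shortcut.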