@ -29,6 +29,7 @@ global_config.update({
' test_size ' : test_size ,
' modelsindex ' : modelsindex ,
' rote ' : rote ,
' bdwd_items ' : bdwd_items ,
# 特征工程开关
' is_del_corr ' : is_del_corr ,
@ -36,6 +37,7 @@ global_config.update({
' is_eta ' : is_eta ,
' is_update_eta ' : is_update_eta ,
' is_fivemodels ' : is_fivemodels ,
' is_update_predict_value ' : is_update_predict_value ,
' early_stop_patience_steps ' : early_stop_patience_steps ,
# 时间参数
@ -112,243 +114,296 @@ def predict_main():
返回 :
None
"""
end_time = global_config [ ' end_time ' ]
# 获取数据
if is_eta :
logger . info ( ' 从eta获取数据... ' )
signature = BinanceAPI ( APPID , SECRET )
etadata = EtaReader ( signature = signature ,
classifylisturl = global_config [ ' classifylisturl ' ] ,
classifyidlisturl = global_config [ ' classifyidlisturl ' ] ,
edbcodedataurl = global_config [ ' edbcodedataurl ' ] ,
edbcodelist = global_config [ ' edbcodelist ' ] ,
edbdatapushurl = global_config [ ' edbdatapushurl ' ] ,
edbdeleteurl = global_config [ ' edbdeleteurl ' ] ,
edbbusinessurl = global_config [ ' edbbusinessurl ' ] ,
classifyId = global_config [ ' ClassifyId ' ] ,
)
df_zhibiaoshuju , df_zhibiaoliebiao = etadata . get_eta_api_yuanyou_data (
data_set = data_set , dataset = dataset ) # 原始数据,未处理
# end_time = global_config['end_time']
# # 获取数据
# if is_eta :
# logger.info('从eta获取数据...')
# signature = BinanceAPI(APPID, SECRET)
# etadata = EtaReader(signature=signature,
# classifylisturl=global_config['classifylisturl'],
# classifyidlisturl=global_config['classifyidlisturl'],
# edbcodedataurl=global_config['edbcodedataurl'],
# edbcodelist=global_config['edbcodelist'],
# edbdatapushurl=global_config['edbdatapushurl'],
# edbdeleteurl=global_config['edbdeleteurl'],
# edbbusinessurl=global_config['edbbusinessurl'],
# classifyId=global_config['ClassifyId'],
# )
# df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(
# data_set=data_set, dataset=dataset) # 原始数据,未处理
if is_market :
logger . info ( ' 从市场信息平台获取数据... ' )
try :
# 如果是测试环境, 最高价最低价取excel文档
if server_host == ' 192.168.100.53 ' :
logger . info ( ' 从excel文档获取最高价最低价 ' )
df_zhibiaoshuju = get_high_low_data ( df_zhibiaoshuju )
else :
logger . info ( ' 从市场信息平台获取数据 ' )
df_zhibiaoshuju = get_market_data (
end_time , df_zhibiaoshuju )
except :
logger . info ( ' 最高最低价拼接失败 ' )
# 保存到xlsx文件的sheet表
with pd . ExcelWriter ( os . path . join ( dataset , data_set ) ) as file :
df_zhibiaoshuju . to_excel ( file , sheet_name = ' 指标数据 ' , index = False )
df_zhibiaoliebiao . to_excel ( file , sheet_name = ' 指标列表 ' , index = False )
# 数据处理
df = datachuli ( df_zhibiaoshuju , df_zhibiaoliebiao , y = y , dataset = dataset , add_kdj = add_kdj , is_timefurture = is_timefurture ,
end_time = end_time )
else :
# 读取数据
logger . info ( ' 读取本地数据: ' + os . path . join ( dataset , data_set ) )
df , df_zhibiaoliebiao = getdata ( filename = os . path . join ( dataset , data_set ) , y = y , dataset = dataset , add_kdj = add_kdj ,
is_timefurture = is_timefurture , end_time = end_time ) # 原始数据,未处理
# 更改预测列名称
df . rename ( columns = { y : ' y ' } , inplace = True )
if is_edbnamelist :
df = df [ edbnamelist ]
df . to_csv ( os . path . join ( dataset , ' 指标数据.csv ' ) , index = False )
# 保存最新日期的y值到数据库
# 取第一行数据存储到数据库中
first_row = df [ [ ' ds ' , ' y ' ] ] . tail ( 1 )
# 判断y的类型是否为float
if not isinstance ( first_row [ ' y ' ] . values [ 0 ] , float ) :
logger . info ( f ' { end_time } 预测目标数据为空,跳过 ' )
return None
# 将最新真实值保存到数据库
if not sqlitedb . check_table_exists ( ' trueandpredict ' ) :
first_row . to_sql ( ' trueandpredict ' , sqlitedb . connection , index = False )
else :
for row in first_row . itertuples ( index = False ) :
row_dict = row . _asdict ( )
config . logger . info ( f ' 要保存的真实值: { row_dict } ' )
# 判断ds是否为字符串类型,如果不是则转换为字符串类型
if isinstance ( row_dict [ ' ds ' ] , ( pd . Timestamp , datetime . datetime ) ) :
row_dict [ ' ds ' ] = row_dict [ ' ds ' ] . strftime ( ' % Y- % m- %d ' )
elif not isinstance ( row_dict [ ' ds ' ] , str ) :
try :
row_dict [ ' ds ' ] = pd . to_datetime (
row_dict [ ' ds ' ] ) . strftime ( ' % Y- % m- %d ' )
except :
logger . warning ( f " 无法解析的时间格式: { row_dict [ ' ds ' ] } " )
# row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
# row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
check_query = sqlitedb . select_data (
' trueandpredict ' , where_condition = f " ds = ' { row . ds } ' " )
if len ( check_query ) > 0 :
set_clause = " , " . join (
[ f " { key } = ' { value } ' " for key , value in row_dict . items ( ) ] )
sqlitedb . update_data (
' trueandpredict ' , set_clause , where_condition = f " ds = ' { row . ds } ' " )
continue
sqlitedb . insert_data ( ' trueandpredict ' , tuple (
row_dict . values ( ) ) , columns = row_dict . keys ( ) )
# 更新accuracy表的y值
if not sqlitedb . check_table_exists ( ' accuracy ' ) :
pass
else :
update_y = sqlitedb . select_data (
' accuracy ' , where_condition = " y is null " )
if len ( update_y ) > 0 :
logger . info ( ' 更新accuracy表的y值 ' )
# 找到update_y 中ds且df中的y的行
update_y = update_y [ update_y [ ' ds ' ] < = end_time ]
logger . info ( f ' 要更新y的信息: { update_y } ' )
# if is_market:
# logger.info('从市场信息平台获取数据...')
# try:
for row in update_y . itertuples ( index = False ) :
try :
row_dict = row . _asdict ( )
yy = df [ df [ ' ds ' ] == row_dict [ ' ds ' ] ] [ ' y ' ] . values [ 0 ]
LOW = df [ df [ ' ds ' ] == row_dict [ ' ds ' ] ] [ ' Brentzdj ' ] . values [ 0 ]
HIGH = df [ df [ ' ds ' ] == row_dict [ ' ds ' ] ] [ ' Brentzgj ' ] . values [ 0 ]
sqlitedb . update_data (
' accuracy ' , f " y = { yy } ,LOW_PRICE = { LOW } ,HIGH_PRICE = { HIGH } " , where_condition = f " ds = ' { row_dict [ ' ds ' ] } ' " )
except :
logger . info ( f ' 更新accuracy表的y值失败: { row_dict } ' )
# except Exception as e:
# logger.info(f'更新accuracy表的y值失败: {e}')
# # 如果是测试环境, 最高价最低价取excel文档
# if server_host == '192.168.100.53':
# logger.info('从excel文档获取最高价最低价')
# df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
# else:
# logger.info('从市场信息平台获取数据')
# df_zhibiaoshuju = get_market_data(
# end_time, df_zhibiaoshuju)
# 判断当前日期是不是周一
is_weekday = datetime . datetime . now ( ) . weekday ( ) == 0
if is_weekday :
logger . info ( ' 今天是周一,更新预测模型 ' )
# 计算最近60天预测残差最低的模型名称
model_results = sqlitedb . select_data (
' trueandpredict ' , order_by = " ds DESC " , limit = " 60 " )
# 删除空值率为90%以上的列
if len ( model_results ) > 10 :
model_results = model_results . dropna (
thresh = len ( model_results ) * 0.1 , axis = 1 )
# 删除空行
model_results = model_results . dropna ( )
modelnames = model_results . columns . to_list ( ) [ 2 : - 1 ]
for col in model_results [ modelnames ] . select_dtypes ( include = [ ' object ' ] ) . columns :
model_results [ col ] = model_results [ col ] . astype ( np . float32 )
# 计算每个预测值与真实值之间的偏差率
for model in modelnames :
model_results [ f ' { model } _abs_error_rate ' ] = abs (
model_results [ ' y ' ] - model_results [ model ] ) / model_results [ ' y ' ]
# 获取每行对应的最小偏差率值
min_abs_error_rate_values = model_results . apply (
lambda row : row [ [ f ' { model } _abs_error_rate ' for model in modelnames ] ] . min ( ) , axis = 1 )
# 获取每行对应的最小偏差率值对应的列名
min_abs_error_rate_column_name = model_results . apply (
lambda row : row [ [ f ' { model } _abs_error_rate ' for model in modelnames ] ] . idxmin ( ) , axis = 1 )
# 将列名索引转换为列名
min_abs_error_rate_column_name = min_abs_error_rate_column_name . map (
lambda x : x . split ( ' _ ' ) [ 0 ] )
# 取出现次数最多的模型名称
most_common_model = min_abs_error_rate_column_name . value_counts ( ) . idxmax ( )
logger . info ( f " 最近60天预测残差最低的模型名称: { most_common_model } " )
# 保存结果到数据库
if not sqlitedb . check_table_exists ( ' most_model ' ) :
sqlitedb . create_table (
' most_model ' , columns = " ds datetime, most_common_model TEXT " )
sqlitedb . insert_data ( ' most_model ' , ( datetime . datetime . now ( ) . strftime (
' % Y- % m- %d % H: % M: % S ' ) , most_common_model , ) , columns = ( ' ds ' , ' most_common_model ' , ) )
# except:
# logger.info('最高最低价拼接失败')
try :
if is_weekday :
# if True:
logger . info ( ' 今天是周一,发送特征预警 ' )
# 上传预警信息到数据库
warning_data_df = df_zhibiaoliebiao . copy ( )
warning_data_df = warning_data_df [ warning_data_df [ ' 停更周期 ' ] > 3 ] [ [
' 指标名称 ' , ' 指标id ' , ' 频度 ' , ' 更新周期 ' , ' 指标来源 ' , ' 最后更新时间 ' , ' 停更周期 ' ] ]
# 重命名列名
warning_data_df = warning_data_df . rename ( columns = { ' 指标名称 ' : ' INDICATOR_NAME ' , ' 指标id ' : ' INDICATOR_ID ' , ' 频度 ' : ' FREQUENCY ' ,
' 更新周期 ' : ' UPDATE_FREQUENCY ' , ' 指标来源 ' : ' DATA_SOURCE ' , ' 最后更新时间 ' : ' LAST_UPDATE_DATE ' , ' 停更周期 ' : ' UPDATE_SUSPENSION_CYCLE ' } )
from sqlalchemy import create_engine
import urllib
global password
if ' @ ' in password :
password = urllib . parse . quote_plus ( password )
# # 保存到xlsx文件的sheet表
# with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
# df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
# df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)
engine = create_engine (
f ' mysql+pymysql:// { dbusername } : { password } @ { host } : { port } / { dbname } ' )
warning_data_df [ ' WARNING_DATE ' ] = datetime . date . today ( ) . strftime (
" % Y- % m- %d % H: % M: % S " )
warning_data_df [ ' TENANT_CODE ' ] = ' T0004 '
# 插入数据之前查询表数据然后新增id列
existing_data = pd . read_sql ( f " SELECT * FROM { table_name } " , engine )
if not existing_data . empty :
max_id = existing_data [ ' ID ' ] . astype ( int ) . max ( )
warning_data_df [ ' ID ' ] = range (
max_id + 1 , max_id + 1 + len ( warning_data_df ) )
else :
warning_data_df [ ' ID ' ] = range ( 1 , 1 + len ( warning_data_df ) )
warning_data_df . to_sql (
table_name , con = engine , if_exists = ' append ' , index = False )
if is_update_warning_data :
upload_warning_info ( len ( warning_data_df ) )
except :
logger . info ( ' 上传预警信息到数据库失败 ' )
# # 数据处理
# df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
# end_time=end_time)
if is_corr :
df = corr_feature ( df = df )
# else:
# # 读取数据
# logger.info('读取本地数据:' + os.path.join(dataset, data_set))
# df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
# is_timefurture=is_timefurture, end_time=end_time) # 原始数据,未处理
df1 = df . copy ( ) # 备份一下, 后面特征筛选完之后加入ds y 列用
logger . info ( f " 开始训练模型... " )
row , col = df . shape
# # 更改预测列名称
# df.rename(columns={y: 'y'}, inplace=True)
now = datetime . datetime . now ( ) . strftime ( ' % Y % m %d % H % M % S ' )
ex_Model ( df ,
horizon = global_config [ ' horizon ' ] ,
input_size = global_config [ ' input_size ' ] ,
train_steps = global_config [ ' train_steps ' ] ,
val_check_steps = global_config [ ' val_check_steps ' ] ,
early_stop_patience_steps = global_config [ ' early_stop_patience_steps ' ] ,
is_debug = global_config [ ' is_debug ' ] ,
dataset = global_config [ ' dataset ' ] ,
is_train = global_config [ ' is_train ' ] ,
is_fivemodels = global_config [ ' is_fivemodels ' ] ,
val_size = global_config [ ' val_size ' ] ,
test_size = global_config [ ' test_size ' ] ,
settings = global_config [ ' settings ' ] ,
now = now ,
etadata = global_config [ ' etadata ' ] ,
modelsindex = global_config [ ' modelsindex ' ] ,
data = data ,
is_eta = global_config [ ' is_eta ' ] ,
end_time = global_config [ ' end_time ' ] ,
)
# if is_edbnamelist:
# df = df[edbnamelist]
# df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False)
# # 保存最新日期的y值到数据库
# # 取第一行数据存储到数据库中
# first_row = df[['ds', 'y']].tail(1)
# # 判断y的类型是否为float
# if not isinstance(first_row['y'].values[0], float):
# logger.info(f'{end_time}预测目标数据为空,跳过')
# return None
# # 将最新真实值保存到数据库
# if not sqlitedb.check_table_exists('trueandpredict'):
# first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
# else:
# for row in first_row.itertuples(index=False):
# row_dict = row._asdict()
# config.logger.info(f'要保存的真实值:{row_dict}')
# # 判断ds是否为字符串类型,如果不是则转换为字符串类型
# if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
# row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
# elif not isinstance(row_dict['ds'], str):
# try:
# row_dict['ds'] = pd.to_datetime(
# row_dict['ds']).strftime('%Y-%m-%d')
# except:
# logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
# # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
# # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
# check_query = sqlitedb.select_data(
# 'trueandpredict', where_condition=f"ds = '{row.ds}'")
# if len(check_query) > 0:
# set_clause = ", ".join(
# [f"{key} = '{value}'" for key, value in row_dict.items()])
# sqlitedb.update_data(
# 'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
# continue
# sqlitedb.insert_data('trueandpredict', tuple(
# row_dict.values()), columns=row_dict.keys())
# # 更新accuracy表的y值
# if not sqlitedb.check_table_exists('accuracy'):
# pass
# else:
# update_y = sqlitedb.select_data(
# 'accuracy', where_condition="y is null")
# if len(update_y) > 0:
# logger.info('更新accuracy表的y值')
# # 找到update_y 中ds且df中的y的行
# update_y = update_y[update_y['ds'] <= end_time]
# logger.info(f'要更新y的信息: {update_y}')
# # try:
# for row in update_y.itertuples(index=False):
# try:
# row_dict = row._asdict()
# yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
# LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
# HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
# sqlitedb.update_data(
# 'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
# except:
# logger.info(f'更新accuracy表的y值失败: {row_dict}')
# # except Exception as e:
# # logger.info(f'更新accuracy表的y值失败: {e}')
# # 判断当前日期是不是周一
# is_weekday = datetime.datetime.now().weekday() == 0
# if is_weekday:
# logger.info('今天是周一,更新预测模型')
# # 计算最近60天预测残差最低的模型名称
# model_results = sqlitedb.select_data(
# 'trueandpredict', order_by="ds DESC", limit="60")
# # 删除空值率为90%以上的列
# if len(model_results) > 10:
# model_results = model_results.dropna(
# thresh=len(model_results)*0.1, axis=1)
# # 删除空行
# model_results = model_results.dropna()
# modelnames = model_results.columns.to_list()[2:-1]
# for col in model_results[modelnames].select_dtypes(include=['object']).columns:
# model_results[col] = model_results[col].astype(np.float32)
# # 计算每个预测值与真实值之间的偏差率
# for model in modelnames:
# model_results[f'{model}_abs_error_rate'] = abs(
# model_results['y'] - model_results[model]) / model_results['y']
# # 获取每行对应的最小偏差率值
# min_abs_error_rate_values = model_results.apply(
# lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
# # 获取每行对应的最小偏差率值对应的列名
# min_abs_error_rate_column_name = model_results.apply(
# lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
# # 将列名索引转换为列名
# min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
# lambda x: x.split('_')[0])
# # 取出现次数最多的模型名称
# most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
# logger.info(f"最近60天预测残差最低的模型名称: {most_common_model}")
# # 保存结果到数据库
# if not sqlitedb.check_table_exists('most_model'):
# sqlitedb.create_table(
# 'most_model', columns="ds datetime, most_common_model TEXT")
# sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
# '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
# try:
# if is_weekday:
# # if True:
# logger.info('今天是周一,发送特征预警')
# # 上传预警信息到数据库
# warning_data_df = df_zhibiaoliebiao.copy()
# warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[
# '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']]
# # 重命名列名
# warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY',
# '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
# from sqlalchemy import create_engine
# import urllib
# global password
# if '@' in password:
# password = urllib.parse.quote_plus(password)
# engine = create_engine(
# f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
# warning_data_df['WARNING_DATE'] = datetime.date.today().strftime(
# "%Y-%m-%d %H:%M:%S")
# warning_data_df['TENANT_CODE'] = 'T0004'
# # 插入数据之前查询表数据然后新增id列
# existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine)
# if not existing_data.empty:
# max_id = existing_data['ID'].astype(int).max()
# warning_data_df['ID'] = range(
# max_id + 1, max_id + 1 + len(warning_data_df))
# else:
# warning_data_df['ID'] = range(1, 1 + len(warning_data_df))
# warning_data_df.to_sql(
# table_name, con=engine, if_exists='append', index=False)
# if is_update_warning_data:
# upload_warning_info(len(warning_data_df))
# except:
# logger.info('上传预警信息到数据库失败')
# if is_corr:
# df = corr_feature(df=df)
# df1 = df.copy() # 备份一下, 后面特征筛选完之后加入ds y 列用
# logger.info(f"开始训练模型...")
# row, col = df.shape
# now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
# ex_Model(df,
# horizon=global_config['horizon'],
# input_size=global_config['input_size'],
# train_steps=global_config['train_steps'],
# val_check_steps=global_config['val_check_steps'],
# early_stop_patience_steps=global_config['early_stop_patience_steps'],
# is_debug=global_config['is_debug'],
# dataset=global_config['dataset'],
# is_train=global_config['is_train'],
# is_fivemodels=global_config['is_fivemodels'],
# val_size=global_config['val_size'],
# test_size=global_config['test_size'],
# settings=global_config['settings'],
# now=now,
# etadata=global_config['etadata'],
# modelsindex=global_config['modelsindex'],
# data=data,
# is_eta=global_config['is_eta'],
# end_time=global_config['end_time'],
# )
# logger.info('模型训练完成')
logger . info ( ' 训练数据绘图ing ' )
model_results3 = model_losss ( sqlitedb , end_time = end_time )
logger . info ( ' 训练数据绘图end ' )
# logger.info('训练数据绘图ing')
# model_results3 = model_losss(sqlitedb, end_time=end_time)
# logger.info('训练数据绘图end')
# # 模型报告
logger . info ( ' 制作报告ing ' )
title = f ' { settings } -- { end_time } -预测报告 ' # 报告标题
reportname = f ' Brent原油大模型日度预测-- { end_time } .pdf ' # 报告文件名
reportname = reportname . replace ( ' : ' , ' - ' ) # 替换冒号
brent_export_pdf ( dataset = dataset , num_models = 5 if is_fivemodels else 22 , time = end_time ,
reportname = reportname , sqlitedb = sqlitedb ) ,
# logger.info('制作报告ing' )
# title = f'{settings}--{end_time}-预测报告' # 报告标题
# reportname = f'Brent原油大模型日度预测--{end_time}.pdf' # 报告文件名
# reportname = reportname.replace(':', '-') # 替换冒号
# brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time ,
# reportname=reportname, sqlitedb=sqlitedb),
logger . info ( ' 制作报告end ' )
logger . info ( ' 模型训练完成 ' )
# logger.info('制作报告end')
# logger.info('模型训练完成')
logger . info ( ' 发送预测结果到市场信息平台 ' )
# 读取预测数据和模型评估数据
predict_file_path = os . path . join ( config . dataset , ' predict.csv ' )
model_eval_file_path = os . path . join ( config . dataset , ' model_evaluation.csv ' )
try :
predictdata_df = pd . read_csv ( predict_file_path )
top_models_df = pd . read_csv ( model_eval_file_path )
except FileNotFoundError as e :
logger . error ( f " 文件未找到: { e } " )
return
predictdata = predictdata_df . copy ( )
# 取模型前十
top_models = top_models_df [ ' 模型(Model) ' ] . head ( 10 ) . tolist ( )
# 计算前十模型的均值
predictdata_df [ ' top_models_mean ' ] = predictdata_df [ top_models ] . mean ( axis = 1 )
# 打印日期和前十模型均值
print ( predictdata_df [ [ ' ds ' , ' top_models_mean ' ] ] )
# 准备要推送的数据
first_date = predictdata_df [ ' ds ' ] . iloc [ 0 ] . replace ( ' - ' , ' ' )
last_date = predictdata_df [ ' ds ' ] . iloc [ - 1 ] . replace ( ' - ' , ' ' )
first_mean = predictdata_df [ ' top_models_mean ' ] . iloc [ 0 ]
last_mean = predictdata_df [ ' top_models_mean ' ] . iloc [ - 1 ]
predictdata = [
{
" dataItemNo " : global_config [ ' bdwd_items ' ] [ ' ciri ' ] ,
" dataDate " : first_date ,
" dataStatus " : " add " ,
" dataValue " : first_mean
} ,
{
" dataItemNo " : global_config [ ' bdwd_items ' ] [ ' benzhou ' ] ,
" dataDate " : last_date ,
" dataStatus " : " add " ,
" dataValue " : last_mean
}
]
print ( predictdata )
# 推送数据到市场信息平台
try :
push_market_data ( predictdata )
except Exception as e :
logger . error ( f " 推送数据失败: { e } " )
# # LSTM 单变量模型
# ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)