PriceForecast/main_juxiting.py
2025-01-03 09:32:07 +08:00

177 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 读取配置
from lib.dataread import *
from lib.tools import *
from models.nerulforcastmodels import ex_Model,model_losss,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting
import glob
import torch
# Set PyTorch's float32 matrix-multiplication precision mode to "high".
torch.set_float32_matmul_precision("high")
# Module-level SQLite handler used by predict_main(). SQLiteHandler and db_name
# are not defined in this file (presumably provided by the star imports above
# -- TODO confirm).
sqlitedb = SQLiteHandler(db_name)
# Connect the handler at import time so it is ready when predict_main() runs.
sqlitedb.connect()
def _build_eta_reader():
    """Create the ETA API reader from the module-level URL/credential settings."""
    signature = BinanceAPI(APPID, SECRET)
    return EtaReader(
        signature=signature,
        classifylisturl=classifylisturl,
        classifyidlisturl=classifyidlisturl,
        edbcodedataurl=edbcodedataurl,
        edbcodelist=edbcodelist,
        edbdatapushurl=edbdatapushurl,
        edbdeleteurl=edbdeleteurl,
        edbbusinessurl=edbbusinessurl,
    )


def _load_dataset(etadata):
    """Load the indicator data, rename the target column to 'y' and dump it to CSV.

    Reads from the ETA API when ``is_eta`` is truthy, otherwise from the local
    file ``os.path.join(dataset, data_set)``. Returns the resulting DataFrame.
    """
    if is_eta:
        logger.info('从eta获取数据...')
        # Raw, unprocessed indicator data and indicator list.
        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(
            data_set=data_set, dataset=dataset)
        df = datachuli_juxiting(
            df_zhibiaoshuju, df_zhibiaoliebiao,
            y=y, dataset=dataset, add_kdj=add_kdj,
            is_timefurture=is_timefurture, end_time=end_time)
    else:
        logger.info('读取本地数据:' + os.path.join(dataset, data_set))
        df = getdata_juxiting(
            filename=os.path.join(dataset, data_set),
            y=y, dataset=dataset, add_kdj=add_kdj,
            is_timefurture=is_timefurture, end_time=end_time)

    # Downstream code addresses the target column as 'y'.
    df.rename(columns={y: 'y'}, inplace=True)
    if is_edbnamelist:
        df = df[edbnamelist]
    df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False)
    return df


def _save_latest_truth(df):
    """Upsert the last (ds, y) row of ``df`` into the 'trueandpredict' table."""
    latest = df[['ds', 'y']].tail(1)
    if not sqlitedb.check_table_exists('trueandpredict'):
        # First run: let pandas create the table from the row.
        latest.to_sql('trueandpredict', sqlitedb.connection, index=False)
        return

    for row in latest.itertuples(index=False):
        row_dict = row._asdict()
        row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
        where = f"ds = '{row.ds}'"
        existing = sqlitedb.select_data('trueandpredict', where_condition=where)
        if len(existing) > 0:
            set_clause = ", ".join(
                [f"{key} = '{value}'" for key, value in row_dict.items()])
            sqlitedb.update_data('trueandpredict', set_clause, where_condition=where)
        else:
            sqlitedb.insert_data(
                'trueandpredict', tuple(row_dict.values()), columns=row_dict.keys())


def _update_best_model():
    """Record which model most often had the lowest error rate in the last 20 rows.

    Reads the 20 most recent rows of 'trueandpredict', computes
    ``abs(y - prediction) / y`` per model column, finds the best model per row
    and stores the most frequent winner in the 'most_model' table.
    Skips (with a warning) when no usable rows or model columns remain.
    """
    import datetime

    suffix = '_abs_error_rate'
    model_results = sqlitedb.select_data(
        'trueandpredict', order_by="ds DESC", limit="20")
    # Drop columns that are more than 40% empty, then any row still holding NaN.
    model_results = model_results.dropna(thresh=len(model_results) * 0.6, axis=1)
    model_results = model_results.dropna()
    # The first two columns are skipped as non-model columns (ds and y, assuming
    # the table layout created by _save_latest_truth -- TODO confirm).
    modelnames = model_results.columns.to_list()[2:]
    if model_results.empty or not modelnames:
        # Nothing to rank yet (e.g. no predictions stored); idxmin/idxmax on
        # empty data would raise, so skip instead of aborting the whole run.
        logger.warning('最近20天没有可用的模型预测结果,跳过最优模型更新')
        return

    for col in model_results[modelnames].select_dtypes(include=['object']).columns:
        model_results[col] = model_results[col].astype(np.float32)

    # Map each error-rate column back to its model so that model names
    # containing underscores are recovered intact.
    error_col_to_model = {f'{model}{suffix}': model for model in modelnames}
    for error_col, model in error_col_to_model.items():
        model_results[error_col] = (
            abs(model_results['y'] - model_results[model]) / model_results['y'])

    # Column-wise reduction over the numeric error columns only.
    best_per_row = model_results[list(error_col_to_model)].idxmin(axis=1)
    best_per_row = best_per_row.map(error_col_to_model)
    most_common_model = best_per_row.value_counts().idxmax()
    logger.info(f"最近20天预测残差最低的模型名称{most_common_model}")

    if not sqlitedb.check_table_exists('most_model'):
        sqlitedb.create_table(
            'most_model', columns="ds datetime, most_common_model TEXT")
    sqlitedb.insert_data(
        'most_model',
        (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), most_common_model,),
        columns=('ds', 'most_common_model',))


def _send_report(title):
    """Email the most recently created PDF found in the ``dataset`` directory."""
    latest_pdf = max(glob.glob(os.path.join(dataset, '*.pdf')), key=os.path.getctime)
    m = SendMail(
        username=username,
        passwd=passwd,
        recv=recv,
        title=title,
        content=content,
        file=latest_pdf,
        ssl=ssl,
    )
    m.send_mail()


def predict_main():
    """Run one end-to-end forecasting job.

    Steps, in order: load the data set, store the latest true value, refresh
    the best-model record on Mondays, optionally apply ``corr_feature``, call
    ``ex_Model`` (training), ``model_losss_juxiting`` (evaluation/plots),
    ``pp_export_pdf`` (report) and finally email the newest PDF.
    All settings are read from module-level names; nothing is returned.
    """
    # Function-scope import so `datetime` is guaranteed to be the module here,
    # whatever the star imports at the top of the file may bind.
    import datetime

    # The reader is built even when is_eta is false: ex_Model receives it.
    etadata = _build_eta_reader()
    df = _load_dataset(etadata)
    _save_latest_truth(df)

    # weekday() == 0 means Monday.
    if datetime.datetime.now().weekday() == 0:
        logger.info('今天是周一,更新预测模型')
        _update_best_model()

    if is_corr:
        df = corr_feature(df=df)

    logger.info("开始训练模型...")
    now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    ex_Model(
        df,
        horizon=horizon,
        input_size=input_size,
        train_steps=train_steps,
        val_check_steps=val_check_steps,
        early_stop_patience_steps=early_stop_patience_steps,
        is_debug=is_debug,
        dataset=dataset,
        is_train=is_train,
        is_fivemodels=is_fivemodels,
        val_size=val_size,
        test_size=test_size,
        settings=settings,
        now=now,
        etadata=etadata,
        modelsindex=modelsindex,
        data=data,
        is_eta=is_eta,
        end_time=end_time,
    )
    logger.info('模型训练完成')

    # Model evaluation / plotting; the return value is not used here.
    logger.info('训练数据绘图ing')
    model_losss_juxiting(sqlitedb)
    logger.info('训练数据绘图end')

    # Report generation.
    logger.info('制作报告ing')
    title = f'{settings}--{now}-预测报告'  # report / email title
    pp_export_pdf(
        dataset=dataset,
        num_models=5 if is_fivemodels else 22,
        time=end_time,
        reportname=reportname,
        sqlitedb=sqlitedb,
    )
    logger.info('制作报告end')
    logger.info('模型训练完成')

    _send_report(title)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    predict_main()