Crude oil monthly forecast: debugging passed

workpc 2025-03-06 14:59:18 +08:00
parent f1fe4ec943
commit 765ca10b5b
4 changed files with 531 additions and 448 deletions

View File

@@ -2,7 +2,7 @@ import logging
import os
import logging.handlers
import datetime
from lib.tools import MySQLDB, SQLiteHandler
# ETA API token
@@ -10,66 +10,65 @@ APPID = "XNLDvxZHHugj7wJ7"
SECRET = "iSeU4s6cKKBVbt94htVY1p0sqUMqb2xa"

# ETA API URLs
sourcelisturl = 'http://10.189.2.78:8108/v1/edb/source/list'
classifylisturl = 'http://10.189.2.78:8108/v1/edb/classify/list?ClassifyType='
uniquecodedataurl = 'http://10.189.2.78:8108/v1/edb/data?UniqueCode=4991c37becba464609b409909fe4d992&StartDate=2024-02-01'
classifyidlisturl = 'http://10.189.2.78:8108/v1/edb/list?ClassifyId='
edbcodedataurl = 'http://10.189.2.78:8108/v1/edb/data?EdbCode='
edbdatapushurl = 'http://10.189.2.78:8108/v1/edb/push'
edbdeleteurl = 'http://10.189.2.78:8108/v1/edb/business/edb/del'
edbbusinessurl = 'http://10.189.2.78:8108/v1/edb/business/data/del'
edbcodelist = ['CO1 Comdty', 'ovx index', 'C2404194834', 'C2404199738', 'dxy curncy', 'C2403128043', 'C2403150124',
               'DOESCRUD Index', 'WTRBM1 EEGC Index', 'FVHCM1 INDEX', 'doedtprd index', 'CFFDQMMN INDEX',
               'C2403083739', 'C2404167878', 'C2403250571', 'lmcads03 lme comdty', 'GC1 COMB Comdty',
               'C2404171822', 'C2404167855',
               # 'W000825','W000826','G.IPE',  # US gasoline and diesel
               # 'S5131019','ID00135604','FSGAM1 Index','S5120408','ID00136724',  # Singapore gasoline and diesel
               ]
# Temporarily hard-coded column names matching the edbcodelist above; to be replaced later
edbnamelist = [
    'ds', 'y',
    'Brent c1-c6', 'Brent c1-c3', 'Brent-WTI', '美国商业原油库存',
    'DFL', '美国汽油裂解价差', 'ovx index', 'dxy curncy', 'lmcads03 lme comdty',
    'C2403128043', 'C2403150124', 'FVHCM1 INDEX', 'doedtprd index', 'CFFDQMMN INDEX',
    'C2403083739', 'C2404167878',
    'GC1 COMB Comdty', 'C2404167855',
    # 'A汽油价格','W000826','ICE柴油价格',
    # '新加坡(含硫0.05%) 柴油现货价','柴油10ppm国际市场FOB中间价新加坡','Bloomberg Commodity Fair Value Singapore Mogas 92 Swap Month 1','97#汽油FOB新加坡现货价','无铅汽油97#国际市场FOB中间价新加坡'
]
# ETA self-owned data indicator codes, one per model
modelsindex = {
    'NHITS': 'SELF0000001',
    'Informer': 'SELF0000057',
    'LSTM': 'SELF0000058',
    'iTransformer': 'SELF0000059',
    'TSMixer': 'SELF0000060',
    'TSMixerx': 'SELF0000061',
    'PatchTST': 'SELF0000062',
    'RNN': 'SELF0000063',
    'GRU': 'SELF0000064',
    'TCN': 'SELF0000065',
    'BiTCN': 'SELF0000066',
    'DilatedRNN': 'SELF0000067',
    'MLP': 'SELF0000068',
    'DLinear': 'SELF0000069',
    'NLinear': 'SELF0000070',
    'TFT': 'SELF0000071',
    'FEDformer': 'SELF0000072',
    'StemGNN': 'SELF0000073',
    'MLPMultivariate': 'SELF0000074',
    'TiDE': 'SELF0000075',
    'DeepNPTS': 'SELF0000076'
}
# Request body for uploading forecast results to ETA; the model and datalist
# fields are updated when each request is made
data = {
    "IndexCode": "",
    "IndexName": "价格预测模型",
    "Unit": "",
    "Frequency": "日度",
    "SourceName": "价格预测",
    "Remark": 'ddd',
@@ -79,19 +78,18 @@ data = {
            "Value": 333444
        }
    ]
}
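# Illustrative sketch only (not part of the original file): one way the template
# above could be filled in and pushed for a single model. The 'DataList' key name
# and the signature header are assumptions; the real push logic lives in the model code.
import copy
import requests

def push_model_forecast(model_name, datalist, token):
    """Fill the request template for one model and push it to the ETA endpoint."""
    payload = copy.deepcopy(data)                    # start from the template above
    payload['IndexCode'] = modelsindex[model_name]   # e.g. 'NHITS' -> 'SELF0000001'
    payload['DataList'] = datalist                   # assumed shape: [{'Date': ..., 'Value': ...}, ...]
    resp = requests.post(edbdatapushurl, json=payload,
                         headers={'x-signature': token}, timeout=30)  # header name is an assumption
    resp.raise_for_status()
    return resp.json()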
# ETA classification
# Only level-3 classifications return data, so every level-3 node under Energy & Chemicals has to be collected manually
# url = 'http://10.189.2.78:8108/v1/edb/list?ClassifyId=1214'
# ParentId: 1160, Energy & Chemicals
# ClassifyId: 1214, Crude Oil
# ParentId: 1214 selects all data under Crude Oil
ClassifyId = 1214
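# Illustrative sketch only: fetching the indicator list under a level-3 node
# such as 1214 (Crude Oil). The 'Data' field in the response is an assumption.
import requests

def list_edb_under_classify(classify_id, token):
    """List all indicators under one level-3 classify node."""
    resp = requests.get(f'{classifyidlisturl}{classify_id}',
                        headers={'x-signature': token}, timeout=30)  # header name is an assumption
    resp.raise_for_status()
    return resp.json().get('Data', [])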
# Variable definitions -- test environment
server_host = '192.168.100.53'

login_pushreport_url = f"http://{server_host}:8080/jingbo-dev/api/server/login"
@@ -103,7 +101,7 @@ login_data = {
    "data": {
        "account": "api_test",
        # "password": "MmVmNzNlOWI0MmY0ZDdjZGUwNzE3ZjFiMDJiZDZjZWU=",  # Shihua@123456
        "password": "ZTEwYWRjMzk0OWJhNTlhYmJlNTZlMDU3ZjIwZjg4M2U=",  # 123456
        "tenantHashCode": "8a4577dbd919675758d57999a1e891fe",
        "terminal": "API"
    },
@@ -112,39 +110,39 @@ login_data = {
}
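# Illustrative sketch only: how the login payload above is typically exchanged
# for a token used by the upload endpoints. The response field names are assumptions.
import requests

def get_platform_token():
    resp = requests.post(login_pushreport_url, json=login_data, timeout=30)
    resp.raise_for_status()
    return resp.json().get('data', {}).get('accessToken')  # field names are assumptions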
upload_data = {
    "funcModule": '研究报告信息',
    "funcOperation": '上传原油价格预测报告',
    "data": {
        "ownerAccount": 'arui',  # account that owns the report
        "reportType": 'OIL_PRICE_FORECAST',  # report type, fixed to OIL_PRICE_FORECAST
        "fileName": '2000-40-5-50--100-原油指标数据.xlsx-Brent活跃合约--2024-09-06-15-01-29-预测报告.pdf',  # file name
        "fileBase64": '',  # file content, base64-encoded
        "categoryNo": 'yyjgycbg',  # research report category code
        "smartBusinessClassCode": 'YCJGYCBG',  # analysis report category code
        "reportEmployeeCode": "E40116",  # reporter
        "reportDeptCode": "D0044",  # reporting department
        "productGroupCode": "RAW_MATERIAL"  # product group
    }
}

warning_data = {
    "funcModule": '原油特征停更预警',
    "funcOperation": '原油特征停更预警',
    "data": {
        'WARNING_TYPE_NAME': '特征数据停更预警',
        'WARNING_CONTENT': '',
        'WARNING_DATE': ''
    }
}

query_data_list_item_nos_data = {
    "funcModule": "数据项",
    "funcOperation": "查询",
    "data": {
        "dateStart": "20200101",
        "dateEnd": "20241231",
        "dataItemNoList": ["Brentzdj", "Brentzgj"]  # data item codes for the Brent low and high prices
    }
}
@@ -152,96 +150,96 @@ query_data_list_item_nos_data = {

# Beijing environment database
host = '192.168.101.27'
port = 3306
dbusername = 'root'
password = '123456'
dbname = 'jingbo_test'
table_name = 'v_tbl_crude_oil_warning'

# Switches
is_train = False  # whether to train
is_debug = True  # whether to run in debug mode
is_eta = False  # whether to use the ETA API
is_market = True  # whether to fetch features from the market information platform; only takes effect when is_eta is True
is_timefurture = True  # whether to add time features
is_fivemodels = False  # whether to reuse the five best models saved earlier
is_edbcode = False  # features come from the edbcodelist above
is_edbnamelist = False  # custom features matching the edbnamelist above
is_update_eta = False  # upload forecast results to ETA
is_update_report = False  # whether to upload the report
is_update_warning_data = False  # whether to upload warning data
is_del_corr = 0.6  # drop weakly correlated features; range 0-1, 0 disables, 0.6 drops features with correlation below 0.6
is_del_tow_month = True  # whether to drop features not updated for two months
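# Illustrative sketch only: what the is_del_corr filter could look like. The
# real implementation lives in lib.dataread and may differ.
def drop_low_corr_features(df, threshold=is_del_corr):
    """Drop features whose absolute correlation with y falls below the threshold;
    threshold == 0 disables the filter, matching the convention above."""
    if threshold == 0:
        return df
    corr = df.corr(numeric_only=True)['y'].abs()
    keep = corr[corr >= threshold].index.tolist()  # always retains 'y' itself
    return df[['ds'] + [c for c in keep if c != 'ds']]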
# Connect to the database
db_mysql = MySQLDB(host=host, user=dbusername,
                   password=password, database=dbname)
db_mysql.connect()
print("数据库连接成功", host, dbname, dbusername)
# Data slicing dates
start_year = 2005  # first year of data
end_time = ''  # data cutoff date
freq = 'M'  # time frequency: "D" day, "W" week, "M" month, "Q" quarter, "A" year, "H" hour, "T" minute, "S" second, "B" business day, "WW" custom week
delweekenday = True if freq == 'B' else False  # whether to drop weekend rows
is_corr = False  # whether features take part in lead-lag correlation boosting
add_kdj = False  # whether to add the KDJ indicator
if add_kdj and is_edbnamelist:
    edbnamelist = edbnamelist+['K', 'D', 'J']
# Model parameters
y = 'Brent连1合约价格'  # target variable of the crude oil indicator data: Brent连1合约价格 or Brent活跃合约
horizon = 4  # forecast horizon
input_size = 16  # input sequence length
train_steps = 50 if is_debug else 1000  # training steps, used to cap the number of epochs
val_check_steps = 30  # evaluation frequency
early_stop_patience_steps = 5  # early-stopping patience

# --- cross-validation parameters
test_size = 100  # test set size; defined as 100 and reassigned where used
val_size = test_size  # validation set size, same as the test set

# Feature-selection parameters
k = 100  # number of features to select; 0 or a value larger than the feature count means all features
corr_threshold = 0.6  # keep features with correlation above 0.6
rote = 0.06  # upper/lower bound threshold for plotting

# Accuracy calculation
weight_dict = [0.4, 0.15, 0.1, 0.1, 0.25]  # weights
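# Illustrative sketch only: one plausible reading of weight_dict is a weighted
# average of per-step accuracies across the forecast window; the production
# formula is defined in the model code and may differ.
def weighted_accuracy(step_accuracies, weights=weight_dict):
    """Combine per-step accuracies (one per forecast step) into a single score."""
    return sum(a * w for a, w in zip(step_accuracies, weights))
# e.g. weighted_accuracy([0.9, 0.8, 0.7, 0.85, 0.6]) -> 0.785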
# Files
data_set = '原油指标数据.xlsx'  # dataset file
dataset = 'yuanyouyuedudataset'  # dataset folder

# Database name
db_name = os.path.join(dataset, 'jbsh_yuanyou_yuedu.db')
sqlitedb = SQLiteHandler(db_name)
sqlitedb.connect()
settings = f'{input_size}-{horizon}-{train_steps}--{k}-{data_set}-{y}'

# Get the current date and time
# now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
now = datetime.datetime.now().strftime('%Y-%m-%d')
if end_time == '':
    end_time = now
reportname = f'Brent原油大模型月度预测--{end_time}.pdf'  # report file name
reportname = reportname.replace(':', '-')  # replace colons
# Email configuration
username = '1321340118@qq.com'
passwd = 'wgczgyhtyyyyjghi'
# recv=['liurui_test@163.com','52585119@qq.com']
recv = ['liurui_test@163.com', 'jin.wang@chambroad.com']
# recv=['liurui_test@163.com']
title = reportname
content = 'brent价格预测报告请看附件'
file = os.path.join(dataset, reportname)
# file=os.path.join(dataset,'14-7-50--100-原油指标数据.xlsx-Brent连1合约价格--20240731175936-预测报告.pdf')
ssl = True
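# For reference, the SendMail helper configured above is invoked later in this
# commit (currently commented out there) roughly as follows; the glob pattern
# picks the newest PDF report in the dataset folder.
# import glob
# m = SendMail(username=username, passwd=passwd, recv=recv, title=title,
#              content=content,
#              file=max(glob.glob(os.path.join(dataset, '*.pdf')), key=os.path.getctime),
#              ssl=ssl)
# m.send_mail()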
# Logging configuration
# Create the log directory if it does not exist
log_dir = 'logs'

@@ -253,8 +251,10 @@ logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)

# Configure a file handler to write logs to file
file_handler = logging.handlers.RotatingFileHandler(os.path.join(
    log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5)
file_handler.setFormatter(logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Configure a console handler to print logs to the console
console_handler = logging.StreamHandler()

@@ -265,4 +265,3 @@ logger.addHandler(file_handler)
logger.addHandler(console_handler)

# logger.info('当前配置:'+settings)

View File

@@ -838,7 +838,9 @@ def datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_time='', y
        df = df.resample('W', on='ds').mean().reset_index()
    elif config.freq == 'M':
        # Resample by month
        if 'yearmonthweeks' in df.columns:
            df.drop('yearmonthweeks', axis=1, inplace=True)
        df = df.resample('ME', on='ds').mean().reset_index()
    # Drop rows where the target column is null
    ''' Working days are missing; dropping them would affect the forecast and break the accuracy statistics '''
    # df = df.dropna(subset=['y'])
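    # Note: this hunk switches the monthly resample rule from 'M' to 'ME' because
    # pandas 2.2 deprecated the 'M' alias in favor of 'ME' (month-end).
    # Standalone sketch of the same downsampling step:
    # import pandas as pd
    # demo = pd.DataFrame({'ds': pd.date_range('2024-01-01', periods=90, freq='D'),
    #                      'y': range(90)})
    # monthly = demo.resample('ME', on='ds').mean().reset_index()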

View File

@@ -1,12 +1,66 @@
# Load configuration
from lib.dataread import *
from config_jingbo_yuedu import *
from lib.tools import SendMail, exception_logger
from models.nerulforcastmodels import ex_Model, model_losss, model_losss_juxiting, brent_export_pdf, tansuanli_export_pdf, pp_export_pdf
import datetime
import torch
torch.set_float32_matmul_precision("high")
global_config.update({
    # Core parameters
    'logger': logger,
    'dataset': dataset,
    'y': y,
    'is_debug': is_debug,
    'is_train': is_train,
    'is_fivemodels': is_fivemodels,
    'settings': settings,

    # Model parameters
    'data_set': data_set,
    'input_size': input_size,
    'horizon': horizon,
    'train_steps': train_steps,
    'val_check_steps': val_check_steps,
    'val_size': val_size,
    'test_size': test_size,
    'modelsindex': modelsindex,
    'rote': rote,

    # Feature engineering switches
    'is_del_corr': is_del_corr,
    'is_del_tow_month': is_del_tow_month,
    'is_eta': is_eta,
    'is_update_eta': is_update_eta,
    'early_stop_patience_steps': early_stop_patience_steps,

    # Time parameters
    'start_year': start_year,
    'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"),
    'freq': freq,  # keep the list structure

    # API configuration
    'login_pushreport_url': login_pushreport_url,
    'login_data': login_data,
    'upload_url': upload_url,
    'upload_warning_url': upload_warning_url,
    'warning_data': warning_data,

    # Query API
    'query_data_list_item_nos_url': query_data_list_item_nos_url,
    'query_data_list_item_nos_data': query_data_list_item_nos_data,

    # ETA configuration
    'APPID': APPID,
    'SECRET': SECRET,
    'etadata': data,

    # Database configuration
    'sqlitedb': sqlitedb,
})
def predict_main():
@@ -72,7 +126,8 @@ def predict_main():
                            edbdeleteurl=edbdeleteurl,
                            edbbusinessurl=edbbusinessurl,
                            )
        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(
            data_set=data_set, dataset=dataset)  # raw data, unprocessed

        if is_market:
            logger.info('从市场信息平台获取数据...')
@@ -83,26 +138,26 @@ def predict_main():
                    df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
                else:
                    logger.info('从市场信息平台获取数据')
                    df_zhibiaoshuju = get_market_data(
                        end_time, df_zhibiaoshuju)
            except Exception:
                logger.info('最高最低价拼接失败')

        # Save to sheets of the xlsx file
        with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
            df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
            df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)

        # Data processing
        df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
                       end_time=end_time)

    else:
        # Read local data
        logger.info('读取本地数据:' + os.path.join(dataset, data_set))
        df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
                                        is_timefurture=is_timefurture, end_time=end_time)  # raw data, unprocessed

    # Rename the target column
    df.rename(columns={y: 'y'}, inplace=True)
@@ -124,48 +179,65 @@ def predict_main():
    else:
        for row in first_row.itertuples(index=False):
            row_dict = row._asdict()
            config.logger.info(f'要保存的真实值:{row_dict}')
            # Check whether ds is a string; if not, convert it
            if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
                row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
            elif not isinstance(row_dict['ds'], str):
                try:
                    row_dict['ds'] = pd.to_datetime(
                        row_dict['ds']).strftime('%Y-%m-%d')
                except Exception:
                    logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
            check_query = sqlitedb.select_data(
                'trueandpredict', where_condition=f"ds = '{row.ds}'")
            if len(check_query) > 0:
                set_clause = ", ".join(
                    [f"{key} = '{value}'" for key, value in row_dict.items()])
                sqlitedb.update_data(
                    'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
                continue
            sqlitedb.insert_data('trueandpredict', tuple(
                row_dict.values()), columns=row_dict.keys())

    # Update the y values in the accuracy table
    if not sqlitedb.check_table_exists('accuracy'):
        pass
    else:
        update_y = sqlitedb.select_data(
            'accuracy', where_condition="y is null")
        if len(update_y) > 0:
            logger.info('更新accuracy表的y值')
            # Find the rows whose ds appears in update_y and whose y appears in df
            update_y = update_y[update_y['ds'] <= end_time]
            logger.info(f'要更新y的信息{update_y}')
            # try:
            for row in update_y.itertuples(index=False):
                try:
                    row_dict = row._asdict()
                    yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
                    LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
                    HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
                    sqlitedb.update_data(
                        'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
                except Exception:
                    logger.info(f'更新accuracy表的y值失败{row_dict}')
            # except Exception as e:
            #     logger.info(f'更新accuracy表的y值失败{e}')

    # Check whether today is Monday
    is_weekday = datetime.datetime.now().weekday() == 0
    if is_weekday:
        logger.info('今天是周一,更新预测模型')
        # Find the model with the lowest forecast residual over the last 60 days
        model_results = sqlitedb.select_data(
            'trueandpredict', order_by="ds DESC", limit="60")
        # Drop columns that are more than 90% null
        if len(model_results) > 10:
            model_results = model_results.dropna(
                thresh=len(model_results)*0.1, axis=1)
        # Drop empty rows
        model_results = model_results.dropna()
        modelnames = model_results.columns.to_list()[2:-1]
@@ -173,47 +245,59 @@ def predict_main():
            model_results[col] = model_results[col].astype(np.float32)
        # Compute each model's deviation rate against the true value
        for model in modelnames:
            model_results[f'{model}_abs_error_rate'] = abs(
                model_results['y'] - model_results[model]) / model_results['y']
        # Minimum deviation rate for each row
        min_abs_error_rate_values = model_results.apply(
            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
        # Column name of the minimum deviation rate for each row
        min_abs_error_rate_column_name = model_results.apply(
            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
        # Map the column index back to the model name
        min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
            lambda x: x.split('_')[0])
        # Take the most frequent model name
        most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
        logger.info(f"最近60天预测残差最低的模型名称{most_common_model}")
        # Save the result to the database
        if not sqlitedb.check_table_exists('most_model'):
            sqlitedb.create_table(
                'most_model', columns="ds datetime, most_common_model TEXT")
        sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
            '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
    try:
        if is_weekday:
            # if True:
            logger.info('今天是周一,发送特征预警')
            # Upload warning info to the database
            warning_data_df = df_zhibiaoliebiao.copy()
            warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[
                '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']]
            # Rename the columns
            warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY',
                                                              '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
            from sqlalchemy import create_engine
            import urllib
            global password
            if '@' in password:
                password = urllib.parse.quote_plus(password)
            engine = create_engine(
                f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
            warning_data_df['WARNING_DATE'] = datetime.date.today().strftime(
                "%Y-%m-%d %H:%M:%S")
            warning_data_df['TENANT_CODE'] = 'T0004'
            # Query the table before inserting, then add an ID column
            existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine)
            if not existing_data.empty:
                max_id = existing_data['ID'].astype(int).max()
                warning_data_df['ID'] = range(
                    max_id + 1, max_id + 1 + len(warning_data_df))
            else:
                warning_data_df['ID'] = range(1, 1 + len(warning_data_df))
            warning_data_df.to_sql(
                table_name, con=engine, if_exists='append', index=False)
            if is_update_warning_data:
                upload_warning_info(len(warning_data_df))
    except Exception:
@@ -228,72 +312,70 @@ def predict_main():
    now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    ex_Model(df,
             horizon=global_config['horizon'],
             input_size=global_config['input_size'],
             train_steps=global_config['train_steps'],
             val_check_steps=global_config['val_check_steps'],
             early_stop_patience_steps=global_config['early_stop_patience_steps'],
             is_debug=global_config['is_debug'],
             dataset=global_config['dataset'],
             is_train=global_config['is_train'],
             is_fivemodels=global_config['is_fivemodels'],
             val_size=global_config['val_size'],
             test_size=global_config['test_size'],
             settings=global_config['settings'],
             now=now,
             etadata=global_config['etadata'],
             modelsindex=global_config['modelsindex'],
             data=data,
             is_eta=global_config['is_eta'],
             end_time=global_config['end_time'],
             )
    # logger.info('模型训练完成')

    logger.info('训练数据绘图ing')
    model_results3 = model_losss(sqlitedb, end_time=end_time)
    logger.info('训练数据绘图end')

    # Model report
    logger.info('制作报告ing')
    title = f'{settings}--{end_time}-预测报告'  # report title
    reportname = f'Brent原油大模型月度预测--{end_time}.pdf'  # report file name
    reportname = reportname.replace(':', '-')  # replace colons
    brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time,
                     reportname=reportname, sqlitedb=sqlitedb)
    logger.info('制作报告end')
    logger.info('模型训练完成')
    # # LSTM univariate model
    # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset)

    # # LSTM multivariate model
    # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset)

    # # GRU model
    # # ex_GRU(df)

    # Send email
    # m = SendMail(
    #     username=username,
    #     passwd=passwd,
    #     recv=recv,
    #     title=title,
    #     content=content,
    #     file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime),
    #     ssl=ssl,
    # )
    # m.send_mail()
if __name__ == '__main__':
    # global end_time
    # # Iterate over the working-day dates between 2024-11-25 and 2024-12-3
    # for i_time in pd.date_range('2024-12-1', '2025-2-26', freq='W'):
    #     end_time = i_time.strftime('%Y-%m-%d')
    #     predict_main()

    predict_main()

View File

@@ -102,235 +102,235 @@ def predict_main():
    Returns:
        None
    """
    global end_time
    signature = BinanceAPI(APPID, SECRET)
    etadata = EtaReader(signature=signature,
                        classifylisturl=classifylisturl,
                        classifyidlisturl=classifyidlisturl,
                        edbcodedataurl=edbcodedataurl,
                        edbcodelist=edbcodelist,
                        edbdatapushurl=edbdatapushurl,
                        edbdeleteurl=edbdeleteurl,
                        edbbusinessurl=edbbusinessurl
                        )
    # Fetch the data
    if is_eta:
        logger.info('从eta获取数据...')
        signature = BinanceAPI(APPID, SECRET)
        etadata = EtaReader(signature=signature,
                            classifylisturl=classifylisturl,
                            classifyidlisturl=classifyidlisturl,
                            edbcodedataurl=edbcodedataurl,
                            edbcodelist=edbcodelist,
                            edbdatapushurl=edbdatapushurl,
                            edbdeleteurl=edbdeleteurl,
                            edbbusinessurl=edbbusinessurl,
                            )
        df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(
            data_set=data_set, dataset=dataset)  # raw data, unprocessed

        if is_market:
            logger.info('从市场信息平台获取数据...')
            try:
                # In the test environment, take the high and low prices from the excel file
                if server_host == '192.168.100.53':
                    logger.info('从excel文档获取最高价最低价')
                    df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
                else:
                    logger.info('从市场信息平台获取数据')
                    df_zhibiaoshuju = get_market_data(
                        end_time, df_zhibiaoshuju)
            except Exception:
                logger.info('最高最低价拼接失败')

        # Save to sheets of the xlsx file
        with pd.ExcelWriter(os.path.join(dataset, data_set)) as file:
            df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
            df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)

        # Data processing
        df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
                       end_time=end_time)
    else:
        # Read local data
        logger.info('读取本地数据:' + os.path.join(dataset, data_set))
        df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj,
                                        is_timefurture=is_timefurture, end_time=end_time)  # raw data, unprocessed

    # Rename the target column
    df.rename(columns={y: 'y'}, inplace=True)

    if is_edbnamelist:
        df = df[edbnamelist]
    df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False)

    # Save the latest y value to the database
    # Take the first row and store it
    first_row = df[['ds', 'y']].tail(1)
    # Check that y is a float
    if not isinstance(first_row['y'].values[0], float):
        logger.info(f'{end_time}预测目标数据为空,跳过')
        return None
    # Save the latest true value to the database
    if not sqlitedb.check_table_exists('trueandpredict'):
        first_row.to_sql('trueandpredict', sqlitedb.connection, index=False)
    else:
        for row in first_row.itertuples(index=False):
            row_dict = row._asdict()
            config.logger.info(f'要保存的真实值:{row_dict}')
            # Check whether ds is a string; if not, convert it
            if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)):
                row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
            elif not isinstance(row_dict['ds'], str):
                try:
                    row_dict['ds'] = pd.to_datetime(
                        row_dict['ds']).strftime('%Y-%m-%d')
                except Exception:
                    logger.warning(f"无法解析的时间格式: {row_dict['ds']}")
            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
            # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
            check_query = sqlitedb.select_data(
                'trueandpredict', where_condition=f"ds = '{row.ds}'")
            if len(check_query) > 0:
                set_clause = ", ".join(
                    [f"{key} = '{value}'" for key, value in row_dict.items()])
                sqlitedb.update_data(
                    'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
                continue
            sqlitedb.insert_data('trueandpredict', tuple(
                row_dict.values()), columns=row_dict.keys())
    # Update the y values in the accuracy table
    if not sqlitedb.check_table_exists('accuracy'):
        pass
    else:
        update_y = sqlitedb.select_data(
            'accuracy', where_condition="y is null")
        if len(update_y) > 0:
            logger.info('更新accuracy表的y值')
            # Find the rows whose ds appears in update_y and whose y appears in df
            update_y = update_y[update_y['ds'] <= end_time]
            logger.info(f'要更新y的信息{update_y}')
            # try:
            for row in update_y.itertuples(index=False):
                try:
                    row_dict = row._asdict()
                    yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
                    LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
                    HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
                    sqlitedb.update_data(
                        'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
                except Exception:
                    logger.info(f'更新accuracy表的y值失败{row_dict}')
            # except Exception as e:
            #     logger.info(f'更新accuracy表的y值失败{e}')
    # Check whether today is Monday
    is_weekday = datetime.datetime.now().weekday() == 0
    if is_weekday:
        logger.info('今天是周一,更新预测模型')
        # Find the model with the lowest forecast residual over the last 60 days
        model_results = sqlitedb.select_data(
            'trueandpredict', order_by="ds DESC", limit="60")
        # Drop columns that are more than 90% null
        if len(model_results) > 10:
            model_results = model_results.dropna(
                thresh=len(model_results)*0.1, axis=1)
        # Drop empty rows
        model_results = model_results.dropna()
        modelnames = model_results.columns.to_list()[2:-1]
        for col in model_results[modelnames].select_dtypes(include=['object']).columns:
            model_results[col] = model_results[col].astype(np.float32)
        # Compute each model's deviation rate against the true value
        for model in modelnames:
            model_results[f'{model}_abs_error_rate'] = abs(
                model_results['y'] - model_results[model]) / model_results['y']
        # Minimum deviation rate for each row
        min_abs_error_rate_values = model_results.apply(
            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
        # Column name of the minimum deviation rate for each row
        min_abs_error_rate_column_name = model_results.apply(
            lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
        # Map the column index back to the model name
        min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
            lambda x: x.split('_')[0])
        # Take the most frequent model name
        most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
        logger.info(f"最近60天预测残差最低的模型名称{most_common_model}")
        # Save the result to the database
        if not sqlitedb.check_table_exists('most_model'):
            sqlitedb.create_table(
                'most_model', columns="ds datetime, most_common_model TEXT")
        sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
            '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
    try:
        if is_weekday:
            # if True:
            logger.info('今天是周一,发送特征预警')
            # Upload warning info to the database
            warning_data_df = df_zhibiaoliebiao.copy()
            warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[
                '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']]
            # Rename the columns
            warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY',
                                                              '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
            from sqlalchemy import create_engine
            import urllib
            global password
            if '@' in password:
                password = urllib.parse.quote_plus(password)
            engine = create_engine(
                f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
            warning_data_df['WARNING_DATE'] = datetime.date.today().strftime(
                "%Y-%m-%d %H:%M:%S")
            warning_data_df['TENANT_CODE'] = 'T0004'
            # Query the table before inserting, then add an ID column
            existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine)
            if not existing_data.empty:
                max_id = existing_data['ID'].astype(int).max()
                warning_data_df['ID'] = range(
                    max_id + 1, max_id + 1 + len(warning_data_df))
            else:
                warning_data_df['ID'] = range(1, 1 + len(warning_data_df))
            warning_data_df.to_sql(
                table_name, con=engine, if_exists='append', index=False)
            if is_update_warning_data:
                upload_warning_info(len(warning_data_df))
    except Exception:
        logger.info('上传预警信息到数据库失败')
    if is_corr:
        df = corr_feature(df=df)

    df1 = df.copy()  # backup; the ds and y columns are added back after feature selection
    logger.info("开始训练模型...")
    row, col = df.shape

    now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    ex_Model(df,
             horizon=global_config['horizon'],
             input_size=global_config['input_size'],
             train_steps=global_config['train_steps'],
             val_check_steps=global_config['val_check_steps'],
             early_stop_patience_steps=global_config['early_stop_patience_steps'],
             is_debug=global_config['is_debug'],
             dataset=global_config['dataset'],
             is_train=global_config['is_train'],
             is_fivemodels=global_config['is_fivemodels'],
             val_size=global_config['val_size'],
             test_size=global_config['test_size'],
             settings=global_config['settings'],
             now=now,
             etadata=global_config['etadata'],
             modelsindex=global_config['modelsindex'],
             data=data,
             is_eta=global_config['is_eta'],
             end_time=global_config['end_time'],
             )

    # logger.info('模型训练完成')