Format code

workpc 2025-03-05 09:47:02 +08:00
parent e49ab6dd02
commit fe1e99b075
4 changed files with 2006 additions and 1641 deletions


@ -39,7 +39,6 @@ edbnamelist = [
]
# eta自有数据指标编码
modelsindex = {
'NHITS': 'SELF0000001',
@ -90,8 +89,7 @@ data = {
ClassifyId = 1214
############################################################################################################### 变量定义--测试环境
# 变量定义--测试环境
server_host = '192.168.100.53'
login_pushreport_url = f"http://{server_host}:8080/jingbo-dev/api/server/login"
@ -158,7 +156,7 @@ dbname = 'jingbo_test'
table_name = 'v_tbl_crude_oil_warning'
### 开关
# 开关
is_train = False # 是否训练
is_debug = False # 是否调试
is_eta = False # 是否使用eta接口
@ -174,9 +172,9 @@ is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 0 为不
is_del_tow_month = True # 是否删除两个月不更新的特征
# 连接到数据库
db_mysql = MySQLDB(host=host, user=dbusername, password=password, database=dbname)
db_mysql = MySQLDB(host=host, user=dbusername,
password=password, database=dbname)
db_mysql.connect()
print("数据库连接成功", host, dbname, dbusername)
@ -191,7 +189,7 @@ add_kdj = False # 是否添加kdj指标
if add_kdj and is_edbnamelist:
edbnamelist = edbnamelist+['K', 'D', 'J']
### 模型参数
# 模型参数
y = 'Brent连1合约价格' # 原油指标数据的目标变量 Brent连1合约价格 Brent活跃合约
horizon = 2 # 预测的步长
input_size = 12 # 输入序列长度
@ -202,16 +200,16 @@ early_stop_patience_steps = 5 # 早停的耐心步数
test_size = 100 # 测试集大小定义100后面使用的时候重新赋值
val_size = test_size # 验证集大小,同测试集大小
### 特征筛选用到的参数
# 特征筛选用到的参数
k = 100 # 特征筛选数量如果是0或者值比特征数量大代表全部特征
corr_threshold = 0.6 # 相关性大于0.6的特征
rote = 0.06 # 绘图上下界阈值
### 计算准确率
# 计算准确率
weight_dict = [0.4, 0.15, 0.1, 0.1, 0.25] # 权重
### 文件
# 文件
data_set = '原油指标数据.xlsx' # 数据集文件
dataset = 'yuanyouzhoududataset' # 数据集文件夹
@ -228,7 +226,7 @@ reportname = f'Brent原油大模型周度预测--{end_time}.pdf' # 报告文件
reportname = reportname.replace(':', '-') # 替换冒号
if end_time == '':
end_time = now
### 邮件配置
# 邮件配置
username = '1321340118@qq.com'
passwd = 'wgczgyhtyyyyjghi'
# recv=['liurui_test@163.com','52585119@qq.com']
@ -241,7 +239,7 @@ file=os.path.join(dataset,'reportname')
ssl = True
### 日志配置
# 日志配置
# 创建日志目录(如果不存在)
log_dir = 'logs'
@ -253,8 +251,10 @@ logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)
# 配置文件处理器,将日志记录到文件
file_handler = logging.handlers.RotatingFileHandler(os.path.join(log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
file_handler = logging.handlers.RotatingFileHandler(os.path.join(
log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
# 配置控制台处理器,将日志打印到控制台
console_handler = logging.StreamHandler()
@ -265,4 +265,3 @@ logger.addHandler(file_handler)
logger.addHandler(console_handler)
# logger.info('当前配置:'+settings)
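For reference, a minimal self-contained sketch of the rotating-file-plus-console logger that this hunk reformats (log_dir, the log file name, and the 1 MiB / 5-backup limits come from the diff above; everything else is illustrative):

import logging
import logging.handlers
import os

log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True)

logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)
fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Rotate at 1 MiB and keep 5 backups, matching the diffed config.
file_handler = logging.handlers.RotatingFileHandler(
    os.path.join(log_dir, 'pricepredict.log'), maxBytes=1024 * 1024, backupCount=5)
file_handler.setFormatter(fmt)

# Mirror the same records to the console.
console_handler = logging.StreamHandler()
console_handler.setFormatter(fmt)

logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.info('logger ready')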


@ -1,5 +1,19 @@
# 导入模块
from config_jingbo_zhoudu import *
from reportlab.lib.units import cm # 单位cm
from reportlab.graphics.shapes import Drawing # 绘图工具
from reportlab.graphics.charts.legends import Legend # 图例类
from reportlab.graphics.charts.barcharts import VerticalBarChart # 图表类
from reportlab.lib import colors # 颜色模块
from reportlab.lib.styles import getSampleStyleSheet # 文本样式
from reportlab.lib.pagesizes import letter # 页面的标志尺寸(8.5*inch, 11*inch)
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image # 报告内容相关类
from reportlab.pdfbase.ttfonts import TTFont # 字体类
from reportlab.pdfbase import pdfmetrics # 注册字体
from sklearn import metrics
from datetime import timedelta
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import datetime
@ -16,23 +30,10 @@ import json
import math
import torch
torch.set_float32_matmul_precision("high")
import matplotlib.pyplot as plt
# 设置plt显示中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
from datetime import timedelta
from sklearn import metrics
from reportlab.pdfbase import pdfmetrics # 注册字体
from reportlab.pdfbase.ttfonts import TTFont # 字体类
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image # 报告内容相关类
from reportlab.lib.pagesizes import letter # 页面的标志尺寸(8.5*inch, 11*inch)
from reportlab.lib.styles import getSampleStyleSheet # 文本样式
from reportlab.lib import colors # 颜色模块
from reportlab.graphics.charts.barcharts import VerticalBarChart # 图表类
from reportlab.graphics.charts.legends import Legend # 图例类
from reportlab.graphics.shapes import Drawing # 绘图工具
from reportlab.lib.units import cm # 单位cm
# 注册字体(提前准备好字体文件, 如果同一个文件需要多种字体可以注册多个)
pdfmetrics.registerFont(TTFont('SimSun', 'SimSun.ttf'))
@ -42,13 +43,13 @@ plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
# from config_jingbo_pro import *
# from config_jingbo import *
from config_jingbo_zhoudu import *
# from config_jingbo_yuedu import *
# from config_yongan import *
# from config_juxiting import *
# from config_juxiting_zhoudu import *
# from config_juxiting_pro import *
# from config_jingbo import logger
# 定义函数
@ -147,7 +148,8 @@ def get_head_auth_report():
logger.info("获取token中...")
logger.info(f'url:{login_pushreport_url},login_data:{login_data}')
# 发送 POST 请求到登录 URL携带登录数据
login_res = requests.post(url=login_pushreport_url, json=login_data, timeout=(3, 30))
login_res = requests.post(url=login_pushreport_url,
json=login_data, timeout=(3, 30))
# 将响应内容转换为 JSON 格式
text = json.loads(login_res.text)
@ -187,7 +189,8 @@ def upload_report_data(token, upload_data):
logger.info(f"upload_data:{upload_data}")
# 发送POST请求上传报告数据
upload_res = requests.post(url=upload_url, headers=headers, json=upload_data, timeout=(3, 15))
upload_res = requests.post(
url=upload_url, headers=headers, json=upload_data, timeout=(3, 15))
# 将响应内容转换为 JSON 格式
upload_res = json.loads(upload_res.text)
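Taken together, get_head_auth_report and upload_report_data implement a login-then-upload flow. A hedged sketch of that pattern (the URLs, payload fields, and the 'data' key in the login response are placeholders, not the real jingbo-dev API):

import json
import requests

login_url = 'http://127.0.0.1:8080/api/server/login'    # placeholder URL
upload_url = 'http://127.0.0.1:8080/api/server/upload'  # placeholder URL
login_data = {'account': 'user', 'password': 'secret'}  # placeholder payload

# Log in first; timeout=(3, 30) means 3 s to connect, 30 s to read.
login_res = requests.post(url=login_url, json=login_data, timeout=(3, 30))
token = json.loads(login_res.text).get('data')  # assumed response field

# Then POST the report with the token in the Authorization header.
headers = {'Authorization': token}
upload_res = requests.post(url=upload_url, headers=headers,
                           json={'report': '...'}, timeout=(3, 15))
print(json.loads(upload_res.text))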
@ -233,7 +236,8 @@ def upload_warning_data(warning_data):
logger.info(f"warning_data:{warning_data}")
# 发送POST请求上传预警数据
upload_res = requests.post(url=upload_warning_url, headers=headers, json=warning_data, timeout=(3, 15))
upload_res = requests.post(
url=upload_warning_url, headers=headers, json=warning_data, timeout=(3, 15))
# 如果上传成功,返回响应对象
if upload_res:
@ -244,7 +248,6 @@ def upload_warning_data(warning_data):
return None
def upload_warning_info(df_count):
"""
上传预警信息到指定的URL
@ -280,7 +283,6 @@ def upload_warning_info(df_count):
logger.error(f'上传预警信息失败:{e}')
def create_feature_last_update_time(df):
"""
计算特征停更信息用
@ -293,10 +295,12 @@ def create_feature_last_update_time(df):
df1 = df.copy()
# 找到每列的最后更新时间
df1.set_index('ds', inplace=True)
last_update_times = df1.apply(lambda x: x.dropna().index.max().strftime('%Y-%m-%d') if not x.dropna().empty else None)
last_update_times = df1.apply(lambda x: x.dropna().index.max().strftime(
'%Y-%m-%d') if not x.dropna().empty else None)
# 保存每列的最后更新时间到文件
last_update_times_df = pd.DataFrame(columns = ['feature', 'last_update_time','is_value','update_period','warning_date','stop_update_period'])
last_update_times_df = pd.DataFrame(columns=[
'feature', 'last_update_time', 'is_value', 'update_period', 'warning_date', 'stop_update_period'])
# 打印每列的最后更新时间
for column, last_update_time in last_update_times.items():
@ -309,12 +313,17 @@ def create_feature_last_update_time(df):
# 计算特征数据值的时间差
try:
# 计算预警日期
time_diff = (df1[column].dropna().index.to_series().diff().mode()[0]).total_seconds() / 3600 / 24
last_update_time_datetime = datetime.datetime.strptime(last_update_time, '%Y-%m-%d')
time_diff = (df1[column].dropna().index.to_series().diff().mode()[
0]).total_seconds() / 3600 / 24
last_update_time_datetime = datetime.datetime.strptime(
last_update_time, '%Y-%m-%d')
last_update_date = end_time if end_time != '' else datetime.datetime.now().strftime('%Y-%m-%d')
end_time_datetime = datetime.datetime.strptime(last_update_date, '%Y-%m-%d')
early_warning_date = last_update_time_datetime + timedelta(days=time_diff)*2 + timedelta(days=1)
stop_update_period = int(math.ceil((end_time_datetime-last_update_time_datetime).days / time_diff))
end_time_datetime = datetime.datetime.strptime(
last_update_date, '%Y-%m-%d')
early_warning_date = last_update_time_datetime + \
timedelta(days=time_diff)*2 + timedelta(days=1)
stop_update_period = int(
math.ceil((end_time_datetime-last_update_time_datetime).days / time_diff))
early_warning_date = early_warning_date.strftime('%Y-%m-%d')
except KeyError:
time_diff = 0
@ -325,13 +334,14 @@ def create_feature_last_update_time(df):
last_update_times_df.loc[len(last_update_times_df)] = values
logger.info(f"Column {column} was last updated at {last_update_time}")
y_last_update_time = last_update_times_df[last_update_times_df['feature']=='y']['warning_date'].values[0]
last_update_times_df.to_csv(os.path.join(dataset,'last_update_times.csv'), index=False)
y_last_update_time = last_update_times_df[last_update_times_df['feature']
== 'y']['warning_date'].values[0]
last_update_times_df.to_csv(os.path.join(
dataset, 'last_update_times.csv'), index=False)
logger.info('特征停更信息保存到文件last_update_times.csv')
return last_update_times_df, y_last_update_time
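The wrapped arithmetic above is easier to follow in isolation. A small sketch of the same warning-date logic, assuming a feature that updates weekly: the mode of the gaps between observations is taken as the update period, the warning date is the last update plus two periods plus one day, and the stalled-period count is the elapsed days divided by the period, rounded up.

import math
from datetime import timedelta

import pandas as pd

# Toy weekly feature whose last value arrived on 2025-01-17.
s = pd.Series([1.0, 2.0, 3.0],
              index=pd.to_datetime(['2025-01-03', '2025-01-10', '2025-01-17']))
end_time = pd.Timestamp('2025-02-14')

last_update = s.dropna().index.max()
# Most common gap between consecutive observations, in days (7 here).
time_diff = s.dropna().index.to_series().diff().mode()[0].total_seconds() / 3600 / 24
early_warning_date = last_update + timedelta(days=time_diff) * 2 + timedelta(days=1)
stop_update_period = int(math.ceil((end_time - last_update).days / time_diff))

print(early_warning_date.date(), stop_update_period)  # 2025-02-01 4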
# 统计特征频度
def featurePindu(dataset):
# 读取文件
@ -377,7 +387,8 @@ def featurePindu(dataset):
pindu_dfs = pd.DataFrame()
# 根据count分组
# 输出特征频度统计
pindudict = {'1':'日度','3':'日度','7':'周度','30':'月度','90':'季度','180':'半年度','365':'年度'}
pindudict = {'1': '日度', '3': '日度', '7': '周度',
'30': '月度', '90': '季度', '180': '半年度', '365': '年度'}
for i in df.groupby('count'):
# 获取 i[1] 的索引值
index = i[1].index
@ -454,7 +465,6 @@ def featureAnalysis(df,dataset,y):
# plt.close()
def corr_feature(df):
# 重新命名列名,列名排序,y在第一个
df.reindex(['y'] + sorted(df.columns.difference(['y'])))
@ -479,7 +489,8 @@ def corr_feature(df):
# 读取滞后周期文件,更改特征
characteristic_period = pd.read_csv('dataset/特征滞后周期.csv', encoding='utf-8')
# 去掉周期为0的行
characteristic_period = characteristic_period.drop(characteristic_period[characteristic_period['滞后周期'] == 0].index)
characteristic_period = characteristic_period.drop(
characteristic_period[characteristic_period['滞后周期'] == 0].index)
for col in df.columns:
# 跳过y列
if col in ['y']:
@ -487,12 +498,12 @@ def corr_feature(df):
# 特征滞后n个周期计算与y的相关性
if col in characteristic_period['特征'].values:
# 获取特征对应的周期
period = characteristic_period[characteristic_period['特征'] == col]['滞后周期'].values[0]
period = characteristic_period[characteristic_period['特征']
== col]['滞后周期'].values[0]
# 滞后处理
df[col] = df[col].shift(period)
df.to_csv(os.path.join(dataset, '滞后处理后的数据集.csv'))
# corr_feture_noscaler = {} # 保存相关性最大的周期
# 遍历df_test的每一列计算相关性
# for col in df_noscaler.columns:
@ -582,7 +593,6 @@ def corr_feature(df):
# logger.info(max(corr_dict, key=corr_dict.get), corr_dict[max(corr_dict, key=corr_dict.get)])
# corr_feture[col] = max(corr_dict, key=corr_dict.get).split('_')[-1]
# # 结果保存到txt文件
# with open('dataset/标准化的特征滞后相关性.txt', 'w') as f:
# for key, value in corr_feture.items():
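The commented-out experiments above search for the lag at which a shifted feature correlates best with y, which is what the 特征滞后周期.csv lookup then applies. A compact sketch of that search (column names, the lag range, and the synthetic data are illustrative):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
y = pd.Series(rng.normal(size=200))
x = y.shift(-3) + rng.normal(scale=0.1, size=200)  # x leads y by 3 steps
df = pd.DataFrame({'y': y, 'x': x})

# Correlate y against the feature shifted by each candidate lag.
corr_by_lag = {lag: df['y'].corr(df['x'].shift(lag)) for lag in range(13)}
best_lag = max(corr_by_lag, key=lambda k: abs(corr_by_lag[k]))
print(best_lag, round(corr_by_lag[best_lag], 3))  # expect lag 3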
@ -649,6 +659,7 @@ def calculate_kdj(data, n=9):
# data = data.dropna()
return data
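The body of calculate_kdj is elided in this diff. For orientation only, a standard KDJ (stochastic oscillator) computation looks roughly like the sketch below; the project's actual implementation may differ, and the high/low column names plus using y as the close are assumptions:

import pandas as pd

def calculate_kdj_sketch(data: pd.DataFrame, n: int = 9) -> pd.DataFrame:
    # RSV: where the close sits inside the last n bars' high-low range, 0-100.
    low_n = data['low'].rolling(n, min_periods=1).min()
    high_n = data['high'].rolling(n, min_periods=1).max()
    rsv = (data['y'] - low_n) / (high_n - low_n) * 100
    # K and D smooth RSV with the customary 1/3 weight; J = 3K - 2D.
    data['K'] = rsv.ewm(com=2, adjust=False).mean()
    data['D'] = data['K'].ewm(com=2, adjust=False).mean()
    data['J'] = 3 * data['K'] - 2 * data['D']
    return data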
def check_column(df, col_name, two_months_ago):
'''
检查列是否需要删除
@ -681,6 +692,7 @@ def check_column(df,col_name,two_months_ago):
corresponding_date = df_check_column.iloc[-1]['ds']
return corresponding_date < two_months_ago
def datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_time='', y='y', dataset='dataset', delweekenday=False, add_kdj=False, is_timefurture=False):
'''
原油特征数据处理函数
@ -744,8 +756,10 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
logger.info(f'删除全为空值的列后数据量:{df.shape}')
df.to_csv(os.path.join(dataset, '未填充的特征数据.csv'), index=False)
# 去掉指标列表中的columns_to_drop的行
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(dataset,'特征处理后的指标名称及分类.csv'),index=False)
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(
df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(
dataset, '特征处理后的指标名称及分类.csv'), index=False)
# 数据频度分析
featurePindu(dataset=dataset)
# 向上填充
@ -765,6 +779,7 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
featureAnalysis(df, dataset=dataset, y=y)
return df
def zhoududatachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_time='', y='y', dataset='dataset', delweekenday=False, add_kdj=False, is_timefurture=False):
'''
原油特征周度数据处理函数
@ -818,8 +833,10 @@ def zhoududatachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time=''
logger.info(f'删除全为空值的列后数据量:{df.shape}')
df.to_csv(os.path.join(dataset, '未填充的特征数据.csv'), index=False)
# 去掉指标列表中的columns_to_drop的行
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(dataset,'特征处理后的指标名称及分类.csv'),index=False)
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(
df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(
dataset, '特征处理后的指标名称及分类.csv'), index=False)
# 数据频度分析
featurePindu(dataset=dataset)
# 向上填充
@ -842,8 +859,6 @@ def zhoududatachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time=''
return df
def datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_time='', y='y', dataset='dataset', delweekenday=False, add_kdj=False, is_timefurture=False):
'''
聚烯烃特征数据处理函数
@ -877,6 +892,7 @@ def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time
current_date = datetime.datetime.now()
two_months_ago = current_date - timedelta(days=40)
# 检查两月不更新的特征
def check_column(col_name):
if 'ds' in col_name or 'y' in col_name:
return False
@ -900,8 +916,10 @@ def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time
logger.info(f'删除全为空值的列后数据量:{df.shape}')
df.to_csv(os.path.join(dataset, '未填充的特征数据.csv'), index=False)
# 去掉指标列表中的columns_to_drop的行
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(dataset,'特征处理后的指标名称及分类.csv'),index=False)
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(
df.columns.tolist())]
df_zhibiaoliebiao.to_csv(os.path.join(
dataset, '特征处理后的指标名称及分类.csv'), index=False)
# 频度分析
featurePindu(dataset=dataset)
# 向上填充
@ -922,6 +940,7 @@ def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time
featureAnalysis(df, dataset=dataset, y=y)
return df
def getdata(filename, datecol='date', y='y', dataset='', add_kdj=False, is_timefurture=False, end_time=''):
logger.info('getdata接收'+filename+' '+datecol+' '+end_time)
# 判断后缀名 csv或excel
@ -932,12 +951,13 @@ def getdata(filename, datecol='date',y='y',dataset='',add_kdj=False,is_timefurtu
df_zhibiaoshuju = pd.read_excel(filename, sheet_name='指标数据')
df_zhibiaoliebiao = pd.read_excel(filename, sheet_name='指标列表')
# 日期字符串转为datetime
df = datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time)
df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, y=y, dataset=dataset,
add_kdj=add_kdj, is_timefurture=is_timefurture, end_time=end_time)
return df, df_zhibiaoliebiao
def getzhoududata(filename, datecol='date', y='y', dataset='', add_kdj=False, is_timefurture=False, end_time=''):
logger.info('getdata接收'+filename+' '+datecol+' '+end_time)
# 判断后缀名 csv或excel
@ -948,15 +968,13 @@ def getzhoududata(filename, datecol='date',y='y',dataset='',add_kdj=False,is_tim
df_zhibiaoshuju = pd.read_excel(filename, sheet_name='指标数据')
df_zhibiaoliebiao = pd.read_excel(filename, sheet_name='指标列表')
# 日期字符串转为datetime
df = zhoududatachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time)
df = zhoududatachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, y=y, dataset=dataset,
add_kdj=add_kdj, is_timefurture=is_timefurture, end_time=end_time)
return df, df_zhibiaoliebiao
def getdata_juxiting(filename, datecol='date', y='y', dataset='', add_kdj=False, is_timefurture=False, end_time=''):
logger.info('getdata接收'+filename+' '+datecol+' '+end_time)
# 判断后缀名 csv或excel
@ -968,7 +986,8 @@ def getdata_juxiting(filename, datecol='date',y='y',dataset='',add_kdj=False,is_
df_zhibiaoliebiao = pd.read_excel(filename, sheet_name='指标列表')
# 日期字符串转为datetime
df = datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time)
df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, y=y, dataset=dataset,
add_kdj=add_kdj, is_timefurture=is_timefurture, end_time=end_time)
return df, df_zhibiaoliebiao
@ -982,10 +1001,12 @@ def sanitize_filename(filename):
# 如果需要,可以添加更多替换规则
return sanitized
class BinanceAPI:
'''
获取 Binance API 请求头签名
'''
def __init__(self, APPID, SECRET):
self.APPID = APPID
self.SECRET = SECRET
@ -993,7 +1014,8 @@ class BinanceAPI:
# 生成随机字符串作为 nonce
def generate_nonce(self, length=32):
self.nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
self.nonce = ''.join(random.choices(
string.ascii_letters + string.digits, k=length))
return self.nonce
# 获取当前时间戳(秒)
@ -1016,6 +1038,7 @@ class BinanceAPI:
self.signature = self.calculate_signature(self.SECRET, self.sign_str)
# return self.signature
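BinanceAPI builds per-request auth headers from a nonce, a timestamp, and a signature, but calculate_signature itself is not shown in this diff. A hedged sketch of the overall shape, assuming an HMAC-SHA256 hex digest and a simple concatenation for sign_str (both assumptions):

import hashlib
import hmac
import random
import string
import time

APPID, SECRET = 'demo-appid', 'demo-secret'  # placeholders

nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
timestamp = int(time.time())
sign_str = f'{APPID}{nonce}{timestamp}'  # assumed concatenation order

# Assumed scheme; the real calculate_signature is not visible in the diff.
signature = hmac.new(SECRET.encode(), sign_str.encode(),
                     hashlib.sha256).hexdigest()

headers = {
    'nonce': nonce,
    'timestamp': str(timestamp),
    'appid': APPID,
    'signature': signature,
}
print(headers)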
class Graphs:
# 绘制标题
@staticmethod
@ -1129,6 +1152,8 @@ class Graphs:
return img
# 定义样式函数
def style_row(row):
if '日' in row['频度']:
return ['background-color: yellow'] * len(row)
@ -1136,7 +1161,6 @@ def style_row(row):
return ['background-color: gray'] * len(row)
class EtaReader():
def __init__(self, signature, classifylisturl, classifyidlisturl, edbcodedataurl, edbcodelist, edbdatapushurl, edbdeleteurl, edbbusinessurl):
'''
@ -1164,7 +1188,6 @@ class EtaReader():
self.edbdeleteurl = edbdeleteurl
self.edbbusinessurl = edbbusinessurl
def filter_yuanyou_data(self, ClassifyName, data):
'''
指标名称保留规则
@ -1306,8 +1329,6 @@ class EtaReader():
if any(keyword in data for keyword in ['拉丝']):
return True
# 检查需要的特征
# 去掉 期货市场 分类下的数据
if ClassifyName == '期货市场':
@ -1339,14 +1360,12 @@ class EtaReader():
else:
pass
# 保留 需求 下所有指标
if ClassifyName == '需求':
return True
else:
pass
return True
# 通过edbcode 获取指标数据
@ -1361,7 +1380,8 @@ class EtaReader():
data = response.json() # 假设接口返回的是JSON数据
all_data_items = data.get('Data')
# 列表转换为DataFrame
df3 = pd.DataFrame(all_data_items, columns=['DataTime', 'Value', 'UpdateTime'])
df3 = pd.DataFrame(all_data_items, columns=[
'DataTime', 'Value', 'UpdateTime'])
# df3 = pd.read_json(all_data_items, orient='records')
# 去掉UpdateTime 列
@ -1369,7 +1389,8 @@ class EtaReader():
# df3.set_index('DataTime')
df3.rename(columns={'Value': EdbName}, inplace=True)
# 将数据存储df1
df = pd.merge(df, df3, how='outer',on='DataTime',suffixes= ('', '_y'))
df = pd.merge(df, df3, how='outer',
on='DataTime', suffixes=('', '_y'))
# 按时间排序
df = df.sort_values(by='DataTime', ascending=True)
return df
@ -1396,7 +1417,8 @@ class EtaReader():
# 定义你的headers这里可以包含多个参数
self.headers = {
'nonce': self.signature.nonce, # 例如,一个认证令牌
'timestamp': str(self.signature.timestamp), # 自定义的header参数
# 自定义的header参数
'timestamp': str(self.signature.timestamp),
'appid': self.signature.APPID, # 另一个自定义的header参数
'signature': self.signature.signature
}
@ -1410,10 +1432,10 @@ class EtaReader():
'''
# 构建新的DataFrame df df1
df = pd.DataFrame(columns=['指标分类', '指标名称', '指标id', '频度','指标来源','来源id','最后更新时间','更新周期','预警日期','停更周期'])
df = pd.DataFrame(columns=[
'指标分类', '指标名称', '指标id', '频度', '指标来源', '来源id', '最后更新时间', '更新周期', '预警日期', '停更周期'])
df1 = pd.DataFrame(columns=['DataTime'])
# 外网环境无法访问,请确认是否为内网环境
try:
# 发送GET请求 获取指标分类列表
@ -1432,7 +1454,8 @@ class EtaReader():
fixed_value = 1214
# 遍历列表,只保留那些'category' key的值为固定值的数据项
filtered_data = [item for item in data.get('Data') if item.get('ParentId') == fixed_value]
filtered_data = [item for item in data.get(
'Data') if item.get('ParentId') == fixed_value]
# 然后循环filtered_data去获取list数据才能获取到想要获取的ClassifyId
n = 0
@ -1452,7 +1475,8 @@ class EtaReader():
for i in Data:
# s+= 1
EdbCode = i.get('EdbCode')
EdbName = i.get('EdbName') # 指标名称要保存到df2的指标名称列,df的指标名称列
# 指标名称要保存到df2的指标名称列,df的指标名称列
EdbName = i.get('EdbName')
Frequency = i.get('Frequency') # 频度要保存到df的频度列
SourceName = i.get('SourceName') # 来源名称要保存到df的频度列
Source = i.get('Source') # 来源ID要保存到df的频度列
@ -1469,9 +1493,9 @@ class EtaReader():
if Source == 2:
continue
# 判断名称是否需要保存
isSave = self.filter_yuanyou_data(ClassifyName,EdbName)
isSave = self.filter_yuanyou_data(
ClassifyName, EdbName)
if isSave:
# 保存到df
df1 = self.edbcodegetdata(df1, EdbCode, EdbName)
@ -1483,27 +1507,35 @@ class EtaReader():
logger.info(f'指标名称:{EdbName} 没有数据')
continue
try:
time_sequence = edbname_df['DataTime'].values.tolist()[-10:]
time_sequence = edbname_df['DataTime'].values.tolist(
)[-10:]
except IndexError:
time_sequence = edbname_df['DataTime'].values.tolist()
time_sequence = edbname_df['DataTime'].values.tolist(
)
# 使用Counter来统计每个星期几的出现次数
from collections import Counter
weekday_counter = Counter(datetime.datetime.strptime(time_str, "%Y-%m-%d").strftime('%A') for time_str in time_sequence)
weekday_counter = Counter(datetime.datetime.strptime(
time_str, "%Y-%m-%d").strftime('%A') for time_str in time_sequence)
# 打印出现次数最多的星期几
try:
most_common_weekday = weekday_counter.most_common(1)[0][0]
most_common_weekday = weekday_counter.most_common(1)[
0][0]
# 计算两周后的日期
warning_date = (datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d") + datetime.timedelta(weeks=2)).strftime("%Y-%m-%d")
stop_update_period = (datetime.datetime.strptime(today, "%Y-%m-%d") - datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d")).days // 7
warning_date = (datetime.datetime.strptime(
time_sequence[-1], "%Y-%m-%d") + datetime.timedelta(weeks=2)).strftime("%Y-%m-%d")
stop_update_period = (datetime.datetime.strptime(
today, "%Y-%m-%d") - datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d")).days // 7
except IndexError:
most_common_weekday = '其他'
stop_update_period = 0
if '日' in Frequency:
most_common_weekday = '每天'
warning_date = (datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d") + datetime.timedelta(days=3)).strftime("%Y-%m-%d")
stop_update_period = (datetime.datetime.strptime(today, "%Y-%m-%d") - datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d")).days
warning_date = (datetime.datetime.strptime(
time_sequence[-1], "%Y-%m-%d") + datetime.timedelta(days=3)).strftime("%Y-%m-%d")
stop_update_period = (datetime.datetime.strptime(
today, "%Y-%m-%d") - datetime.datetime.strptime(time_sequence[-1], "%Y-%m-%d")).days
# 保存频度 指标名称 分类 指标id 到 df
df2 = pd.DataFrame({'指标分类': ClassifyName,
@ -1525,7 +1557,8 @@ class EtaReader():
logger.info(f'跳过指标 {EdbName}')
# 找到列表中不在指标列中的指标id保存成新的list
new_list = [item for item in self.edbcodelist if item not in df['指标id'].tolist()]
new_list = [
item for item in self.edbcodelist if item not in df['指标id'].tolist()]
logger.info(new_list)
# 遍历new_list获取指标数据保存到df1
for item in new_list:
@ -1537,7 +1570,8 @@ class EtaReader():
itemname = item
df1 = self.edbcodegetdata(df1, item, itemname)
df = pd.concat([df, pd.DataFrame({'指标分类': '其他', '指标名称': itemname, '指标id': item, '频度': '其他','指标来源':'其他','来源id':'其他'},index=[0])])
df = pd.concat([df, pd.DataFrame(
{'指标分类': '其他', '指标名称': itemname, '指标id': item, '频度': '其他', '指标来源': '其他', '来源id': '其他'}, index=[0])])
# 按时间排序
df1.sort_values('DataTime', inplace=True, ascending=False)
@ -1559,7 +1593,8 @@ class EtaReader():
# 定义你的headers这里可以包含多个参数
self.headers = {
'nonce': self.signature.nonce, # 例如,一个认证令牌
'timestamp': str(self.signature.timestamp), # 自定义的header参数
# 自定义的header参数
'timestamp': str(self.signature.timestamp),
'appid': self.signature.APPID, # 另一个自定义的header参数
'signature': self.signature.signature
}
@ -1576,7 +1611,6 @@ class EtaReader():
df = pd.DataFrame(columns=['指标分类', '指标名称', '指标id', '频度'])
df1 = pd.DataFrame(columns=['DataTime'])
# 外网环境无法访问,请确认是否为内网环境
try:
# 发送GET请求 获取指标分类列表
@ -1595,7 +1629,8 @@ class EtaReader():
fixed_value = ClassifyId
# 遍历列表,只保留那些'category' key的值为固定值的数据项
filtered_data = [item for item in data.get('Data') if item.get('ParentId') == fixed_value]
filtered_data = [item for item in data.get(
'Data') if item.get('ParentId') == fixed_value]
# 然后循环filtered_data去获取list数据才能获取到想要获取的ClassifyId
n = 0
@ -1615,7 +1650,8 @@ class EtaReader():
for i in Data:
# s+= 1
EdbCode = i.get('EdbCode')
EdbName = i.get('EdbName') # 指标名称要保存到df2的指标名称列,df的指标名称列
# 指标名称要保存到df2的指标名称列,df的指标名称列
EdbName = i.get('EdbName')
Frequency = i.get('Frequency') # 频度要保存到df的频度列
# 频度不是 日 或者 周的 跳过
if Frequency not in ['日度', '周度', '日', '周']:
@ -1626,7 +1662,8 @@ class EtaReader():
if isSave:
# 保存到df
# 保存频度 指标名称 分类 指标id 到 df
df2 = pd.DataFrame({'指标分类': ClassifyName, '指标名称': EdbName, '指标id': EdbCode, '频度': Frequency},index=[0])
df2 = pd.DataFrame(
{'指标分类': ClassifyName, '指标名称': EdbName, '指标id': EdbCode, '频度': Frequency}, index=[0])
# df = pd.merge(df, df2, how='outer')
df = pd.concat([df, df2])
@ -1635,7 +1672,8 @@ class EtaReader():
logger.info(f'跳过指标 {EdbName}')
# 找到列表中不在指标列中的指标id保存成新的list
new_list = [item for item in self.edbcodelist if item not in df['指标id'].tolist()]
new_list = [
item for item in self.edbcodelist if item not in df['指标id'].tolist()]
logger.info(new_list)
# 遍历new_list获取指标数据保存到df1
for item in new_list:
@ -1647,7 +1685,8 @@ class EtaReader():
itemname = item
df1 = self.edbcodegetdata(df1, item, itemname)
df = pd.concat([df, pd.DataFrame({'指标分类': '其他', '指标名称': itemname, '指标id': item, '频度': '其他'},index=[0])])
df = pd.concat([df, pd.DataFrame(
{'指标分类': '其他', '指标名称': itemname, '指标id': item, '频度': '其他'}, index=[0])])
# 按时间排序
df1.sort_values('DataTime', inplace=True, ascending=False)
@ -1673,14 +1712,16 @@ class EtaReader():
# 定义你的headers这里可以包含多个参数
self.headers = {
'nonce': self.signature.nonce, # 例如,一个认证令牌
'timestamp': str(self.signature.timestamp), # 自定义的header参数
# 自定义的header参数
'timestamp': str(self.signature.timestamp),
'appid': self.signature.APPID, # 另一个自定义的header参数
'signature': self.signature.signature
}
# 发送post请求 上传数据
logger.info(f'请求参数:{data}')
response = requests.post(self.edbdatapushurl, headers=self.headers,data=json.dumps(data))
response = requests.post(
self.edbdatapushurl, headers=self.headers, data=json.dumps(data))
# 检查响应状态码
if response.status_code == 200:
@ -1700,7 +1741,8 @@ class EtaReader():
# 定义你的headers这里可以包含多个参数
self.headers = {
'nonce': self.signature.nonce, # 例如,一个认证令牌
'timestamp': str(self.signature.timestamp), # 自定义的header参数
# 自定义的header参数
'timestamp': str(self.signature.timestamp),
'appid': self.signature.APPID, # 另一个自定义的header参数
'signature': self.signature.signature
}
@ -1709,8 +1751,8 @@ class EtaReader():
"IndexCodeList": IndexCodeList # 指标编码列表
}
# 发送post请求 上传数据
response = requests.post(self.edbdeleteurl, headers=self.headers,data=json.dumps(data))
response = requests.post(
self.edbdeleteurl, headers=self.headers, data=json.dumps(data))
# 检查响应状态码
if response.status_code == 200:
@ -1741,15 +1783,15 @@ class EtaReader():
# 定义你的headers这里可以包含多个参数
self.headers = {
'nonce': self.signature.nonce, # 例如,一个认证令牌
'timestamp': str(self.signature.timestamp), # 自定义的header参数
# 自定义的header参数
'timestamp': str(self.signature.timestamp),
'appid': self.signature.APPID, # 另一个自定义的header参数
'signature': self.signature.signature
}
# 发送post请求 上传数据
response = requests.post(self.edbbusinessurl, headers=self.headers,data=json.dumps(data))
response = requests.post(
self.edbbusinessurl, headers=self.headers, data=json.dumps(data))
# 检查响应状态码
if response.status_code == 200:
@ -1771,11 +1813,13 @@ def get_market_data(end_time,df):
# 获取token
token = get_head_auth_report()
# 定义请求参数
query_data_list_item_nos_data['data']['dateEnd'] = end_time.replace('-','')
query_data_list_item_nos_data['data']['dateEnd'] = end_time.replace(
'-', '')
# 发送请求
headers = {"Authorization": token}
logger.info('获取数据中...')
items_res = requests.post(url=query_data_list_item_nos_url, headers=headers, json=query_data_list_item_nos_data, timeout=(3, 35))
items_res = requests.post(url=query_data_list_item_nos_url, headers=headers,
json=query_data_list_item_nos_data, timeout=(3, 35))
json_data = json.loads(items_res.text)
logger.info(f"获取到的数据:{json_data}")
df3 = pd.DataFrame(json_data['data'])
@ -1799,7 +1843,8 @@ def get_market_data(end_time,df):
def get_high_low_data(df):
# 读取excel 从第五行开始
df1 = pd.read_excel(os.path.join(dataset,'数据项下载.xls'),header=5, names=['numid','date', 'Brentzdj', 'Brentzgj'])
df1 = pd.read_excel(os.path.join(dataset, '数据项下载.xls'), header=5, names=[
'numid', 'date', 'Brentzdj', 'Brentzgj'])
# 合并数据
df = pd.merge(df, df1, how='left', on='date')
return df
@ -1843,12 +1888,16 @@ def addtimecharacteristics(df,dataset):
df['is_year_end'] = df['ds'].dt.is_year_end.astype(int)
# 添加月度第几周周一到周日为一周每月1日所在的周为第一周
# 计算当前日期所在周的周一
df['current_monday'] = df['ds'] - pd.to_timedelta(df['ds'].dt.dayofweek, unit='D')
df['current_monday'] = df['ds'] - \
pd.to_timedelta(df['ds'].dt.dayofweek, unit='D')
# 计算当月1日所在周的周一
df['first_monday'] = df['ds'].dt.to_period('M').dt.start_time - pd.to_timedelta(df['ds'].dt.to_period('M').dt.start_time.dt.dayofweek, unit='D')
df['first_monday'] = df['ds'].dt.to_period('M').dt.start_time - pd.to_timedelta(
df['ds'].dt.to_period('M').dt.start_time.dt.dayofweek, unit='D')
# 计算周数差并+1得到周数
df['weekofmonth'] = ((df['current_monday'] - df['first_monday']).dt.days // 7) + 1
df['yearmonthweeks'] = df['year'].astype(str) + df['month'].astype(str) + df['weekofmonth'].astype(str)
df['weekofmonth'] = (
(df['current_monday'] - df['first_monday']).dt.days // 7) + 1
df['yearmonthweeks'] = df['year'].astype(
str) + df['month'].astype(str) + df['weekofmonth'].astype(str)
df.drop(columns=['current_monday', 'first_monday'], inplace=True)
# 去掉 quarter_start quarter
df.drop(columns=['quarter_start', 'quarter'], inplace=True)
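A quick check of the Monday-anchored week-of-month arithmetic above (a week runs Monday through Sunday, and the week containing the 1st counts as week 1):

import pandas as pd

df = pd.DataFrame({'ds': pd.to_datetime(['2025-03-01', '2025-03-03', '2025-03-10'])})
# Monday of each date's week, and Monday of the week containing the 1st.
current_monday = df['ds'] - pd.to_timedelta(df['ds'].dt.dayofweek, unit='D')
month_start = df['ds'].dt.to_period('M').dt.start_time
first_monday = month_start - pd.to_timedelta(month_start.dt.dayofweek, unit='D')
df['weekofmonth'] = ((current_monday - first_monday).dt.days // 7) + 1
print(df)  # 2025-03-01 is week 1; 2025-03-03 and 2025-03-10 are weeks 2 and 3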


@ -1,5 +1,7 @@
# 读取配置
from lib.dataread import *
# from config_jingbo_zhoudu import *
from lib.tools import SendMail, exception_logger
from models.nerulforcastmodels import ex_Model, model_losss, model_losss_juxiting, brent_export_pdf, tansuanli_export_pdf, pp_export_pdf, model_losss_juxiting
@ -8,7 +10,6 @@ import torch
torch.set_float32_matmul_precision("high")
def predict_main():
"""
主预测函数用于从 ETA 获取数据处理数据训练模型并进行预测
@ -72,7 +73,8 @@ def predict_main():
edbdeleteurl=edbdeleteurl,
edbbusinessurl=edbbusinessurl,
)
df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(data_set=data_set, dataset=dataset) # 原始数据,未处理
df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data(
data_set=data_set, dataset=dataset) # 原始数据,未处理
if is_market:
logger.info('从市场信息平台获取数据...')
@ -83,7 +85,8 @@ def predict_main():
df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju)
else:
logger.info('从市场信息平台获取数据')
df_zhibiaoshuju = get_market_data(end_time,df_zhibiaoshuju)
df_zhibiaoshuju = get_market_data(
end_time, df_zhibiaoshuju)
except:
logger.info('最高最低价拼接失败')
@ -93,7 +96,6 @@ def predict_main():
df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False)
df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False)
# 数据处理
df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture,
end_time=end_time)
@ -126,18 +128,23 @@ def predict_main():
row_dict = row._asdict()
# row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d')
# row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S')
check_query = sqlitedb.select_data('trueandpredict', where_condition=f"ds = '{row.ds}'")
check_query = sqlitedb.select_data(
'trueandpredict', where_condition=f"ds = '{row.ds}'")
if len(check_query) > 0:
set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()])
sqlitedb.update_data('trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
set_clause = ", ".join(
[f"{key} = '{value}'" for key, value in row_dict.items()])
sqlitedb.update_data(
'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'")
continue
sqlitedb.insert_data('trueandpredict', tuple(row_dict.values()), columns=row_dict.keys())
sqlitedb.insert_data('trueandpredict', tuple(
row_dict.values()), columns=row_dict.keys())
# 更新accuracy表的y值
if not sqlitedb.check_table_exists('accuracy'):
pass
else:
update_y = sqlitedb.select_data('accuracy',where_condition="y is null")
update_y = sqlitedb.select_data(
'accuracy', where_condition="y is null")
if len(update_y) > 0:
logger.info('更新accuracy表的y值')
# 找到update_y 中ds且df中的y的行
@ -150,7 +157,8 @@ def predict_main():
yy = df[df['ds'] == row_dict['ds']]['y'].values[0]
LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0]
HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0]
sqlitedb.update_data('accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
sqlitedb.update_data(
'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'")
except:
logger.info(f'更新accuracy表的y值失败{row_dict}')
# except Exception as e:
@ -162,10 +170,12 @@ def predict_main():
if is_weekday:
logger.info('今天是周一,更新预测模型')
# 计算最近60天预测残差最低的模型名称
model_results = sqlitedb.select_data('trueandpredict', order_by="ds DESC", limit="60")
model_results = sqlitedb.select_data(
'trueandpredict', order_by="ds DESC", limit="60")
# 删除空值率为90%以上的列
if len(model_results) > 10:
model_results = model_results.dropna(thresh=len(model_results)*0.1,axis=1)
model_results = model_results.dropna(
thresh=len(model_results)*0.1, axis=1)
# 删除空行
model_results = model_results.dropna()
modelnames = model_results.columns.to_list()[2:-1]
@ -173,20 +183,26 @@ def predict_main():
model_results[col] = model_results[col].astype(np.float32)
# 计算每个预测值与真实值之间的偏差率
for model in modelnames:
model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y']
model_results[f'{model}_abs_error_rate'] = abs(
model_results['y'] - model_results[model]) / model_results['y']
# 获取每行对应的最小偏差率值
min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
min_abs_error_rate_values = model_results.apply(
lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1)
# 获取每行对应的最小偏差率值对应的列名
min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
min_abs_error_rate_column_name = model_results.apply(
lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1)
# 将列名索引转换为列名
min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0])
min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(
lambda x: x.split('_')[0])
# 取出现次数最多的模型名称
most_common_model = min_abs_error_rate_column_name.value_counts().idxmax()
logger.info(f"最近60天预测残差最低的模型名称{most_common_model}")
# 保存结果到数据库
if not sqlitedb.check_table_exists('most_model'):
sqlitedb.create_table('most_model', columns="ds datetime, most_common_model TEXT")
sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
sqlitedb.create_table(
'most_model', columns="ds datetime, most_common_model TEXT")
sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime(
'%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',))
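In miniature, the Monday model-selection step above works like this (two made-up model columns; only the shape of the computation matches the code):

import pandas as pd

df = pd.DataFrame({
    'y':        [100.0, 102.0, 101.0],
    'NHITS':    [101.0, 101.0, 103.0],
    'Informer': [100.5, 104.0, 100.8],
})
modelnames = ['NHITS', 'Informer']

# Absolute error rate of every model against the truth, row by row.
for model in modelnames:
    df[f'{model}_abs_error_rate'] = abs(df['y'] - df[model]) / df['y']

rate_cols = [f'{model}_abs_error_rate' for model in modelnames]
# Per-row winner, then the model that wins most often overall.
best_per_row = df[rate_cols].idxmin(axis=1).map(lambda c: c.split('_')[0])
print(best_per_row.value_counts().idxmax())  # 'Informer' wins 2 of 3 rows here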
try:
if is_weekday:
@ -194,26 +210,32 @@ def predict_main():
logger.info('今天是周一,发送特征预警')
# 上传预警信息到数据库
warning_data_df = df_zhibiaoliebiao.copy()
warning_data_df = warning_data_df[warning_data_df['停更周期']> 3 ][['指标名称', '指标id', '频度','更新周期','指标来源','最后更新时间','停更周期']]
warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[
'指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']]
# 重命名列名
warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY',
'更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'})
from sqlalchemy import create_engine
import urllib
global password
if '@' in password:
password = urllib.parse.quote_plus(password)
engine = create_engine(f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
warning_data_df['WARNING_DATE'] = datetime.date.today().strftime("%Y-%m-%d %H:%M:%S")
engine = create_engine(
f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}')
warning_data_df['WARNING_DATE'] = datetime.date.today().strftime(
"%Y-%m-%d %H:%M:%S")
warning_data_df['TENANT_CODE'] = 'T0004'
# 插入数据之前查询表数据然后新增id列
existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine)
if not existing_data.empty:
max_id = existing_data['ID'].astype(int).max()
warning_data_df['ID'] = range(max_id + 1, max_id + 1 + len(warning_data_df))
warning_data_df['ID'] = range(
max_id + 1, max_id + 1 + len(warning_data_df))
else:
warning_data_df['ID'] = range(1, 1 + len(warning_data_df))
warning_data_df.to_sql(table_name, con=engine, if_exists='append', index=False)
warning_data_df.to_sql(
table_name, con=engine, if_exists='append', index=False)
if is_update_warning_data:
upload_warning_info(len(warning_data_df))
except:
@ -248,7 +270,6 @@ def predict_main():
end_time=end_time,
)
logger.info('模型训练完成')
logger.info('训练数据绘图ing')

File diff suppressed because it is too large