PriceForecast/lib/tools.py

450 lines
16 KiB
Python
Raw Normal View History

2024-11-01 16:38:21 +08:00
import base64
import functools
import hashlib
import hmac
import os
import random
import secrets
import smtplib
import sqlite3
import string
import time
import tkinter as tk
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from tkinter import messagebox

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from reportlab.graphics.charts.barcharts import VerticalBarChart  # bar chart class
from reportlab.graphics.charts.legends import Legend  # legend class
from reportlab.graphics.shapes import Drawing  # drawing canvas
from reportlab.lib import colors  # colour constants
from reportlab.lib.pagesizes import letter  # standard page size (8.5*inch, 11*inch)
from reportlab.lib.styles import getSampleStyleSheet  # text styles
from reportlab.lib.units import cm  # centimetre unit
from reportlab.pdfbase import pdfmetrics  # font registration
from reportlab.pdfbase.ttfonts import TTFont  # TrueType font class
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image  # report content flowables
from scipy.stats import spearmanr
from sklearn import metrics

from config_jingbo import logger
def timeit(func):
    """Decorator that logs how long each call to *func* takes.

    The elapsed time in seconds is written with ``logger.info`` after a call
    returns normally; if *func* raises, the exception propagates and nothing
    is logged. ``functools.wraps`` keeps the wrapped function's name,
    docstring and ``__wrapped__`` attribute, so introspection still works.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is intended for measuring durations; unlike time.time()
        # it is not affected by adjustments to the system clock.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        logger.info(f"{func.__name__} 函数的执行时间为: {execution_time}")
        return result
    return wrapper
class BinanceAPI:
    '''
    Build the signature fields for a Binance API request header.

    Constructing an instance immediately generates a nonce and a timestamp
    and stores, on the instance:
      - ``nonce``: random alphanumeric string
      - ``timestamp``: current Unix time in whole seconds
      - ``sign_str``: ``appid=<APPID>&nonce=<nonce>&timestamp=<timestamp>``
      - ``signature``: URL-safe base64 of HMAC-SHA256(SECRET, sign_str)
    '''

    def __init__(self, APPID, SECRET):
        self.APPID = APPID
        self.SECRET = SECRET
        self.get_signature()

    def generate_nonce(self, length=32):
        """Return a random alphanumeric string of *length* chars and store it in ``self.nonce``."""
        alphabet = string.ascii_letters + string.digits
        # The nonce is part of a request signature, so draw it from the OS CSPRNG
        # (secrets) rather than the predictable Mersenne Twister behind `random`.
        self.nonce = ''.join(secrets.choice(alphabet) for _ in range(length))
        return self.nonce

    def get_timestamp(self):
        """Return the current Unix timestamp in whole seconds."""
        return int(time.time())

    def build_sign_str(self):
        """Return the string to sign, built from APPID, nonce and timestamp."""
        return f'appid={self.APPID}&nonce={self.nonce}&timestamp={self.timestamp}'

    def calculate_signature(self, secret, message):
        """Return URL-safe base64 of the HMAC-SHA256 of *message* keyed by *secret*."""
        digest = hmac.new(secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8')

    def get_signature(self):
        """Regenerate nonce/timestamp, recompute ``self.signature`` and return it."""
        self.nonce = self.generate_nonce()
        self.timestamp = self.get_timestamp()
        self.sign_str = self.build_sign_str()
        self.signature = self.calculate_signature(self.SECRET, self.sign_str)
        return self.signature
class Graphs:
    '''
    PDF building helpers: each static method returns a ReportLab flowable
    (Paragraph, Table, Drawing or Image) to be appended to a document story.

    NOTE(review): every text style uses the font name 'SimSun'; it must be
    registered with pdfmetrics before the document is built. That is not done
    in this class — confirm it happens elsewhere.
    '''

    @staticmethod
    def draw_title(title: str):
        """Return *title* as a centred, green, 18pt SimSun heading paragraph."""
        style = getSampleStyleSheet()
        ct = style['Heading1']
        ct.fontName = 'SimSun'
        ct.fontSize = 18
        ct.leading = 50  # line spacing
        ct.textColor = colors.green
        ct.alignment = 1  # 1 == centred
        # NOTE(review): ParagraphStyle has no `bold` attribute, so this assignment
        # appears to have no effect; bold text would need a bold font face.
        ct.bold = True
        return Paragraph(title, ct)

    @staticmethod
    def draw_little_title(title: str):
        """Return *title* as a red, 15pt SimSun sub-heading paragraph."""
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 15
        ct.leading = 30  # line spacing
        ct.textColor = colors.red
        return Paragraph(title, ct)

    @staticmethod
    def draw_text(text: str):
        """Return *text* as a 12pt, left-aligned body paragraph with CJK wrapping."""
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 12
        ct.wordWrap = 'CJK'  # allow line breaks between CJK characters
        ct.alignment = 0  # 0 == left
        ct.firstLineIndent = 32  # indent of the first line
        ct.leading = 25
        return Paragraph(text, ct)

    @staticmethod
    def draw_table(col_width, *args):
        """Return a grid Table whose rows are *args*; the first row is the header.

        :param col_width: column width(s), passed through as ``Table(colWidths=...)``
        """
        style = [
            ('FONTNAME', (0, 0), (-1, -1), 'SimSun'),
            ('FONTSIZE', (0, 0), (-1, 0), 10),  # header row
            ('FONTSIZE', (0, 1), (-1, -1), 8),  # all remaining rows
            ('BACKGROUND', (0, 0), (-1, 0), '#d5dae6'),  # header background
            # Centre every cell, then left-align rows 1..end: the later command
            # wins, so only the header row stays centred.
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
            ('ALIGN', (0, 1), (-1, -1), 'LEFT'),
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('TEXTCOLOR', (0, 0), (-1, -1), colors.darkslategray),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),  # 0.5pt grey grid lines
        ]
        return Table(args, colWidths=col_width, style=style)

    @staticmethod
    def draw_bar(bar_data: list, ax: list, items: list,
                 value_min=5000, value_max=26000, value_step=2000):
        """Return a 500x250 Drawing holding a vertical bar chart plus legend.

        :param bar_data: series values, assigned to ``VerticalBarChart.data``
        :param ax: category (x-axis) labels
        :param items: legend ``(colour, name)`` pairs
        :param value_min: y-axis minimum (default 5000)
        :param value_max: y-axis maximum (default 26000)
        :param value_step: y-axis tick step (default 2000)
        """
        drawing = Drawing(500, 250)
        bc = VerticalBarChart()
        bc.x = 45  # chart origin inside the drawing
        bc.y = 45
        bc.height = 200
        bc.width = 350
        bc.data = bar_data
        bc.strokeColor = colors.black
        bc.valueAxis.valueMin = value_min
        bc.valueAxis.valueMax = value_max
        bc.valueAxis.valueStep = value_step
        bc.categoryAxis.labels.dx = 2
        bc.categoryAxis.labels.dy = -8
        bc.categoryAxis.labels.angle = 20  # slant labels so long names do not overlap
        bc.categoryAxis.categoryNames = ax
        # Legend anchored at the top-right corner of the drawing.
        leg = Legend()
        leg.fontName = 'SimSun'
        leg.alignment = 'right'
        leg.boxAnchor = 'ne'
        leg.x = 475
        leg.y = 240
        leg.dxTextSpace = 10
        leg.columnMaximum = 3
        leg.colorNamePairs = items
        drawing.add(leg)
        drawing.add(bc)
        return drawing

    @staticmethod
    def draw_img(path, width=20 * cm, height=10 * cm):
        """Return an Image flowable for *path*, sized *width* x *height* (default 20cm x 10cm)."""
        img = Image(path)
        img.drawWidth = width
        img.drawHeight = height
        return img
# The evaluation metrics live in different libraries; every metric used by
# this module is collected below.
# MSE
def mse(y_true, y_pred):
    """Return the mean squared error of *y_pred* against *y_true* (via sklearn)."""
    return metrics.mean_squared_error(y_true=y_true, y_pred=y_pred)
# RMSE
def rmse(y_true, y_pred):
    """Return the root mean squared error: the square root of sklearn's MSE."""
    squared_error = metrics.mean_squared_error(y_true=y_true, y_pred=y_pred)
    return np.sqrt(squared_error)
# MAE
def mae(y_true, y_pred):
    """Return the mean absolute error of *y_pred* against *y_true* (via sklearn)."""
    return metrics.mean_absolute_error(y_true=y_true, y_pred=y_pred)
# MAPE and SMAPE are implemented directly from their formulas below.
# MAPE
def mape(y_true, y_pred):
    """Return the mean absolute percentage error, in percent.

    MAPE = mean(|(y_pred - y_true) / y_true|) * 100

    Plain lists/tuples are converted to NumPy arrays so the metric accepts
    Python sequences, as ``mse``/``mae`` do; NumPy arrays and pandas Series are
    used unchanged (Series arithmetic still aligns on the index).
    Any zero in ``y_true`` makes the result ``inf`` or ``nan`` (division by zero).
    """
    # `list - list` raises TypeError, so convert bare sequences first.
    if isinstance(y_true, (list, tuple)):
        y_true = np.asarray(y_true)
    if isinstance(y_pred, (list, tuple)):
        y_pred = np.asarray(y_pred)
    return np.mean(np.abs((y_pred - y_true) / y_true)) * 100
# SMAPE
def smape(y_true, y_pred):
    """Return the symmetric mean absolute percentage error, in percent.

    SMAPE = 2 * mean(|y_pred - y_true| / (|y_pred| + |y_true|)) * 100

    Plain lists/tuples are converted to NumPy arrays so the metric accepts
    Python sequences; NumPy arrays and pandas Series are used unchanged.
    A position where both values are 0 contributes ``nan`` (0/0).
    """
    # `list - list` raises TypeError, so convert bare sequences first.
    if isinstance(y_true, (list, tuple)):
        y_true = np.asarray(y_true)
    if isinstance(y_pred, (list, tuple)):
        y_pred = np.asarray(y_pred)
    return 2.0 * np.mean(np.abs(y_pred - y_true) / (np.abs(y_pred) + np.abs(y_true))) * 100
def _save_corr_histogram(values, bins, title, path):
    """Draw a histogram of *values* over *bins* as a bar chart and save it to *path*."""
    # np.histogram closes the last bin on the right, so a coefficient of
    # exactly 1.0 is counted instead of silently dropped.
    counts, _ = np.histogram(np.asarray(values, dtype=float), bins=bins)
    plt.figure(figsize=(10, 6))
    # align='edge' places each bar's left side on its bin's lower edge; the
    # default 'center' would shift every bar half a bin to the left.
    plt.bar(bins[:-1], counts, width=(bins[1] - bins[0]), align='edge')
    plt.title(title)
    plt.xlabel('区间')
    plt.ylabel('统计数')
    plt.savefig(path)
    plt.close()


# Correlation analysis and plots
def plot_corr(data, size=11):
    """Correlate every feature with ``data['y']`` and save a CSV plus two histograms.

    Writes, in the current working directory:
      - '指标相关性分析.csv': columns Feature, Pearson_Correlation (3 dp),
        Spearman_Correlation (2 dp); features with an undefined coefficient are dropped
      - '皮尔逊相关性系数.png' / '斯皮尔曼相关性系数.png': counts of the
        coefficients in 20 equal-width bins over [-1, 1]

    :param data: DataFrame with a target column 'y' and an optional date column
                 'ds' (ignored). The caller's DataFrame is not modified.
    :param size: unused; kept for backward compatibility.
    """
    # Drop 'ds' on a copy so the caller's frame is left untouched.
    features = data.drop(columns=['ds'], errors='ignore')

    rows = []
    for col in features.columns:
        if col == 'y':
            continue
        pearson_correlation = np.corrcoef(features[col], features['y'])[0, 1]
        spearman_correlation, _ = spearmanr(features[col], features['y'])
        rows.append({
            'Feature': col,
            'Pearson_Correlation': round(pearson_correlation, 3),
            'Spearman_Correlation': round(spearman_correlation, 2),
        })
    # Build the frame once instead of row-by-row via the private DataFrame._append.
    correlation_df = pd.DataFrame(
        rows, columns=['Feature', 'Pearson_Correlation', 'Spearman_Correlation'])
    # Rows with NaN (e.g. a constant feature column) carry no usable correlation.
    correlation_df = correlation_df.dropna()
    correlation_df.to_csv('指标相关性分析.csv', index=False)

    # 21 edges -> 20 equal-width bins over [-1, 1].
    bins = np.linspace(-1, 1, 21)
    _save_corr_histogram(correlation_df['Pearson_Correlation'], bins,
                         '皮尔逊相关系数分布图', '皮尔逊相关性系数.png')
    _save_corr_histogram(correlation_df['Spearman_Correlation'], bins,
                         '斯皮尔曼相关系数分布图', '斯皮尔曼相关性系数.png')
# Email sending wrapper
class SendMail(object):
    """Send a plain-text e-mail, optionally with one attachment, via SMTP."""

    def __init__(self, username, passwd, recv, title, content,
                 file=None, ssl=False,
                 email_host='smtp.qq.com', port=25, ssl_port=465):
        '''
        :param username: SMTP login name; also used as the From address
        :param passwd: SMTP password / authorization code
        :param recv: one recipient address (str) or a list of addresses,
                     e.g. ['a@qq.com', 'b@qq.com']
        :param title: subject line
        :param content: plain-text body
        :param file: path of a file to attach (absolute unless it is in the
                     current directory); no attachment by default
        :param ssl: when True connect with SMTP_SSL on ssl_port, otherwise
                    plain SMTP on port (default)
        :param email_host: SMTP server host, default 'smtp.qq.com'
        :param port: port for plain SMTP, default 25
        :param ssl_port: port for SMTP over SSL, default 465
        '''
        self.username = username
        self.passwd = passwd
        self.recv = recv
        self.title = title
        self.content = content
        self.file = file
        self.email_host = email_host
        self.port = port
        self.ssl = ssl
        self.ssl_port = ssl_port

    @staticmethod
    def _recipient_list(recv):
        """Return *recv* as a list of addresses; a bare string becomes a one-item list."""
        # ','.join on a plain string would join its characters ('a,@,q,q,...').
        if isinstance(recv, str):
            return [recv]
        return list(recv)

    def _build_message(self):
        """Assemble the MIME message: optional attachment, body text, headers."""
        msg = MIMEMultipart()
        if self.file:
            file_name = os.path.split(self.file)[-1]  # only the base name goes into the header
            try:
                with open(self.file, 'rb') as fh:  # ensures the file handle is closed
                    file_bytes = fh.read()
            except (OSError, ValueError) as e:
                raise Exception('附件打不开!!!!') from e
            att = MIMEText(file_bytes, "base64", "utf-8")
            # Item assignment on an email Message appends a header rather than
            # replacing it, so remove the 'text/base64' Content-Type that MIMEText
            # set; otherwise the part carries two conflicting Content-Type headers.
            del att["Content-Type"]
            att["Content-Type"] = 'application/octet-stream'
            # RFC 2047-encode the file name so non-ASCII (e.g. Chinese) names survive.
            new_file_name = '=?utf-8?b?' + base64.b64encode(file_name.encode()).decode() + '?='
            att["Content-Disposition"] = 'attachment; filename="%s"' % (new_file_name)
            msg.attach(att)
        msg.attach(MIMEText(self.content))
        msg['Subject'] = self.title
        msg['From'] = self.username
        msg['To'] = ','.join(self._recipient_list(self.recv))
        return msg

    def send_mail(self):
        """Build the message and send it.

        Connection and login errors propagate to the caller; errors raised by
        ``sendmail`` itself are printed and logged, not re-raised.
        """
        msg = self._build_message()
        if self.ssl:
            self.smtp = smtplib.SMTP_SSL(self.email_host, port=self.ssl_port)
        else:
            self.smtp = smtplib.SMTP(self.email_host, port=self.port)
        try:
            self.smtp.login(self.username, self.passwd)
        except Exception:
            self.smtp.close()  # do not leak the socket when authentication fails
            raise
        try:
            self.smtp.sendmail(self.username, self._recipient_list(self.recv), msg.as_string())
        except Exception as e:
            print('出错了。。', e)
            # %s placeholder: without it logging fails to format the extra argument.
            logger.info('邮件服务出错了。。%s', e)
        else:
            print('发送成功!')
            logger.info('邮件发送成功!')
        self.smtp.quit()
def dateConvert(df, datecol='ds'):
    """Parse ``df[datecol]`` into datetimes in place and return *df*.

    Tries the dash format ``%Y-%m-%d`` first and falls back to the slash
    format ``%Y/%m/%d``. If neither format matches, the error from the second
    attempt propagates. A missing column raises KeyError.
    """
    try:
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y-%m-%d')
    except (ValueError, TypeError):
        # Catch only parse failures; a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit.
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y/%m/%d')
    return df
class SQLiteHandler:
    """Convenience wrapper around one sqlite3 connection and its cursor.

    Call ``connect()`` first: the other methods use ``self.cursor`` /
    ``self.connection`` directly and raise AttributeError while they are None.

    Table and column names, ``where_condition``, ``set_values`` and
    ``order_by`` are pasted into the SQL text verbatim, so they must come from
    trusted code, never from user input. Only ``insert_data`` values,
    ``execute_query`` params and the ``check_table_exists`` name are bound
    as parameters.
    """

    def __init__(self, db_name):
        self.db_name = db_name
        self.connection = None
        self.cursor = None

    def connect(self):
        """Open the database file *db_name* and create a cursor."""
        self.connection = sqlite3.connect(self.db_name)
        self.cursor = self.connection.cursor()

    def close(self):
        """Close the connection (if open) and reset connection/cursor to None."""
        if self.connection:
            self.connection.close()
            self.connection = None
            self.cursor = None

    def execute_query(self, query, params=None):
        """Execute *query* (with optional bound *params*) and return the cursor."""
        if params:
            return self.cursor.execute(query, params)
        else:
            return self.cursor.execute(query)

    def commit(self):
        self.connection.commit()

    def create_table(self, table_name, columns):
        """Create *table_name* if absent; *columns* is the raw column-definition SQL."""
        query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})"
        self.execute_query(query)
        self.commit()

    def insert_data(self, table_name, values, columns=None):
        """Insert one row of *values* (bound as parameters), optionally naming *columns*."""
        placeholders = ', '.join(['?'] * len(values))
        if columns:
            query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
        else:
            query = f"INSERT INTO {table_name} VALUES ({placeholders})"
        self.execute_query(query, values)
        self.commit()

    def select_data(self, table_name, columns=None, where_condition=None, order_by=None, limit=None):
        """Run a SELECT and return the rows as a DataFrame (empty DataFrame if no rows)."""
        query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table_name}"
        if where_condition:
            query += f" WHERE {where_condition}"
        if order_by:
            query += f" ORDER BY {order_by}"
        if limit:
            query += f" LIMIT {limit}"
        cursor = self.execute_query(query)
        results = cursor.fetchall()
        if not results:
            return pd.DataFrame()
        # Column names come from the same execution; no need to run the query twice.
        headers = [description[0] for description in cursor.description]
        return pd.DataFrame(results, columns=headers)

    def update_data(self, table_name, set_values, where_condition):
        """Run ``UPDATE table SET <set_values> WHERE <where_condition>`` and commit."""
        query = f"UPDATE {table_name} SET {set_values} WHERE {where_condition}"
        logger.info('更新数据sql'+ query)
        self.execute_query(query)
        self.commit()

    def delete_data(self, table_name, where_condition):
        """Delete rows matching *where_condition* and commit."""
        query = f"DELETE FROM {table_name} WHERE {where_condition}"
        self.execute_query(query)
        self.commit()

    def check_table_exists(self, table_name):
        """Return True if a table named *table_name* exists."""
        # Bind the name as a parameter: interpolating it into the quoted literal
        # let a value such as "x' OR '1'='1" make every lookup return True.
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        result = self.execute_query(query, (table_name,)).fetchone()
        return result is not None

    def drop_table(self, table_name):
        """Drop *table_name* if it exists and commit."""
        query = f"DROP TABLE IF EXISTS {table_name}"
        self.execute_query(query)
        self.commit()

    def add_column_if_not_exists(self, table_name, column_name, column_type):
        """Add *column_name* of *column_type* to *table_name* unless it already exists."""
        # PRAGMA table_info yields one row per column; index 1 is the column name.
        query = f"PRAGMA table_info({table_name})"
        self.execute_query(query)
        columns = [column[1] for column in self.cursor.fetchall()]
        if column_name not in columns:
            query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
            self.execute_query(query)
            self.commit()
            print(f"Column '{column_name}' added to table '{table_name}' successfully.")
        else:
            print(f"Column '{column_name}' already exists in table '{table_name}'.")
def _main():
    """Entry point when the file is run directly: it only provides helpers."""
    print('This is a tool, not a script.')


if __name__ == '__main__':
    _main()