PriceForecast/lib/tools.py
2024-11-14 10:21:25 +08:00

445 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from config_jingbo import logger
from sklearn import metrics
import random, string, base64, hmac, hashlib
from reportlab.pdfbase import pdfmetrics # 注册字体
from reportlab.pdfbase.ttfonts import TTFont # 字体类
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image # 报告内容相关类
from reportlab.lib.pagesizes import letter # 页面的标志尺寸(8.5*inch, 11*inch)
from reportlab.lib.styles import getSampleStyleSheet # 文本样式
from reportlab.lib import colors # 颜色模块
from reportlab.graphics.charts.barcharts import VerticalBarChart # 图表类
from reportlab.graphics.charts.legends import Legend # 图例类
from reportlab.graphics.shapes import Drawing # 绘图工具
from reportlab.lib.units import cm # 单位cm
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import sqlite3
import tkinter as tk
from tkinter import messagebox
def timeit(func):
'''计时装饰器'''
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
execution_time = end_time - start_time
logger.info(f"{func.__name__} 函数的执行时间为: {execution_time}")
return result
return wrapper
class BinanceAPI:
'''
获取 Binance API 请求头签名
'''
def __init__(self, APPID, SECRET):
self.APPID = APPID
self.SECRET = SECRET
self.get_signature()
# 生成随机字符串作为 nonce
def generate_nonce(self, length=32):
self.nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
return self.nonce
# 获取当前时间戳(秒)
def get_timestamp(self):
return int(time.time())
# 构建待签名字符串
def build_sign_str(self):
return f'appid={self.APPID}&nonce={self.nonce}&timestamp={self.timestamp}'
# 使用 HMAC SHA-256 计算签名
def calculate_signature(self, secret, message):
return base64.urlsafe_b64encode(hmac.new(secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256).digest()).decode('utf-8')
def get_signature(self):
# 调用上述方法生成签名
self.nonce = self.generate_nonce()
self.timestamp = self.get_timestamp()
self.sign_str = self.build_sign_str()
self.signature = self.calculate_signature(self.SECRET, self.sign_str)
# return self.signature
class Graphs:
'''
pdf生成类
'''
# 绘制标题
@staticmethod
def draw_title(title: str):
# 获取所有样式表
style = getSampleStyleSheet()
# 拿到标题样式
ct = style['Heading1']
# 单独设置样式相关属性
ct.fontName = 'SimSun' # 字体名
ct.fontSize = 18 # 字体大小
ct.leading = 50 # 行间距
ct.textColor = colors.green # 字体颜色
ct.alignment = 1 # 居中
ct.bold = True
# 创建标题对应的段落,并且返回
return Paragraph(title, ct)
# 绘制小标题
@staticmethod
def draw_little_title(title: str):
# 获取所有样式表
style = getSampleStyleSheet()
# 拿到标题样式
ct = style['Normal']
# 单独设置样式相关属性
ct.fontName = 'SimSun' # 字体名
ct.fontSize = 15 # 字体大小
ct.leading = 30 # 行间距
ct.textColor = colors.red # 字体颜色
# 创建标题对应的段落,并且返回
return Paragraph(title, ct)
# 绘制普通段落内容
@staticmethod
def draw_text(text: str):
# 获取所有样式表
style = getSampleStyleSheet()
# 获取普通样式
ct = style['Normal']
ct.fontName = 'SimSun'
ct.fontSize = 12
ct.wordWrap = 'CJK' # 设置自动换行
ct.alignment = 0 # 左对齐
ct.firstLineIndent = 32 # 第一行开头空格
ct.leading = 25
return Paragraph(text, ct)
# 绘制表格
@staticmethod
def draw_table(col_width,*args):
# 列宽度
col_width = col_width
style = [
('FONTNAME', (0, 0), (-1, -1), 'SimSun'), # 字体
('FONTSIZE', (0, 0), (-1, 0), 10), # 第一行的字体大小
('FONTSIZE', (0, 1), (-1, -1), 8), # 第二行到最后一行的字体大小
('BACKGROUND', (0, 0), (-1, 0), '#d5dae6'), # 设置第一行背景颜色
('ALIGN', (0, 0), (-1, -1), 'CENTER'), # 第一行水平居中
('ALIGN', (0, 1), (-1, -1), 'LEFT'), # 第二行到最后一行左右左对齐
('VALIGN', (0, 0), (-1, -1), 'MIDDLE'), # 所有表格上下居中对齐
('TEXTCOLOR', (0, 0), (-1, -1), colors.darkslategray), # 设置表格内文字颜色
('GRID', (0, 0), (-1, -1), 0.5, colors.grey), # 设置表格框线为grey色线宽为0.5
# ('SPAN', (0, 1), (0, 2)), # 合并第一列二三行
# ('SPAN', (0, 3), (0, 4)), # 合并第一列三四行
# ('SPAN', (0, 5), (0, 6)), # 合并第一列五六行
# ('SPAN', (0, 7), (0, 8)), # 合并第一列五六行
]
table = Table(args, colWidths=col_width, style=style)
return table
# 创建图表
@staticmethod
def draw_bar(bar_data: list, ax: list, items: list):
drawing = Drawing(500, 250)
bc = VerticalBarChart()
bc.x = 45 # 整个图表的x坐标
bc.y = 45 # 整个图表的y坐标
bc.height = 200 # 图表的高度
bc.width = 350 # 图表的宽度
bc.data = bar_data
bc.strokeColor = colors.black # 顶部和右边轴线的颜色
bc.valueAxis.valueMin = 5000 # 设置y坐标的最小值
bc.valueAxis.valueMax = 26000 # 设置y坐标的最大值
bc.valueAxis.valueStep = 2000 # 设置y坐标的步长
bc.categoryAxis.labels.dx = 2
bc.categoryAxis.labels.dy = -8
bc.categoryAxis.labels.angle = 20
bc.categoryAxis.categoryNames = ax
# 图示
leg = Legend()
leg.fontName = 'SimSun'
leg.alignment = 'right'
leg.boxAnchor = 'ne'
leg.x = 475 # 图例的x坐标
leg.y = 240
leg.dxTextSpace = 10
leg.columnMaximum = 3
leg.colorNamePairs = items
drawing.add(leg)
drawing.add(bc)
return drawing
# 绘制图片
@staticmethod
def draw_img(path):
img = Image(path) # 读取指定路径下的图片
img.drawWidth = 20*cm # 设置图片的宽度
img.drawHeight = 10*cm # 设置图片的高度
return img
# 评估指标不在一个库,这里列出所有用到的指标的公式
# MSE
def mse(y_true, y_pred):
res_mse = metrics.mean_squared_error(y_true, y_pred)
return res_mse
# RMSE
def rmse(y_true, y_pred):
res_rmse = np.sqrt(metrics.mean_squared_error(y_true, y_pred))
return res_rmse
# MAE
def mae(y_true, y_pred):
res_mae = metrics.mean_absolute_error(y_true, y_pred)
return res_mae
# sklearn的库中没有MAPE和SMAPE下面根据公式给出算法实现
# MAPE
def mape(y_true, y_pred):
res_mape = np.mean(np.abs((y_pred - y_true) / y_true)) * 100
return res_mape
# SMAPE
def smape(y_true, y_pred):
res_smape = 2.0 * np.mean(np.abs(y_pred - y_true) / (np.abs(y_pred) + np.abs(y_true))) * 100
return res_smape
# 相关系数绘制
def plot_corr(data, size=11):
# 去掉ds列
data.drop(columns=['ds'], inplace=True)
# 创建一个空的 DataFrame 来保存相关系数
correlation_df = pd.DataFrame(columns=['Feature', 'Correlation'])
# 计算各特征与目标列的皮尔逊相关系数,并保存到新的 DataFrame 中
for col in data.columns:
if col!= 'y':
pearson_correlation = np.corrcoef(data[col], data['y'])[0, 1]
spearman_correlation, _ = spearmanr(data[col], data['y'])
new_row = {'Feature': col, 'Pearson_Correlation': round(pearson_correlation,3), 'Spearman_Correlation': round(spearman_correlation,2)}
correlation_df = correlation_df._append(new_row, ignore_index=True)
# 删除空列
correlation_df.drop('Correlation', axis=1, inplace=True)
correlation_df.dropna(inplace=True)
correlation_df.to_csv('指标相关性分析.csv', index=False)
data = correlation_df['Pearson_Correlation'].values.tolist()
# 生成 -1 到 1 的 20 个区间
bins = np.linspace(-1, 1, 21)
# 计算每个区间的统计数(这里是区间内数据的数量)
hist_values = [np.sum((data >= bins[i]) & (data < bins[i + 1])) for i in range(len(bins) - 1)]
#设置画布大小
plt.figure(figsize=(10, 6))
# 绘制直方图
plt.bar(bins[:-1], hist_values, width=(bins[1] - bins[0]))
# 添加标题和坐标轴标签
plt.title('皮尔逊相关系数分布图')
plt.xlabel('区间')
plt.ylabel('统计数')
plt.savefig('皮尔逊相关性系数.png')
plt.close()
#设置画布大小
plt.figure(figsize=(10, 6))
data = correlation_df['Spearman_Correlation'].values.tolist()
# 计算每个区间的统计数(这里是区间内数据的数量)
hist_values = [np.sum((data >= bins[i]) & (data < bins[i + 1])) for i in range(len(bins) - 1)]
# 绘制直方图
plt.bar(bins[:-1], hist_values, width=(bins[1] - bins[0]))
# 添加标题和坐标轴标签
plt.title('斯皮尔曼相关系数分布图')
plt.xlabel('区间')
plt.ylabel('统计数')
plt.savefig('斯皮尔曼相关性系数.png')
plt.close()
# 邮件封装
class SendMail(object):
def __init__(self,username,passwd,recv,title,content,
file=None,ssl=False,
email_host='smtp.qq.com',port=25,ssl_port=465):
'''
:param username: 用户名
:param passwd: 密码
:param recv: 收件人多个要传list ['a@qq.com','b@qq.com]
:param title: 邮件标题
:param content: 邮件正文
:param file: 附件路径,如果不在当前目录下,要写绝对路径,默认没有附件
:param ssl: 是否安全链接,默认为普通
:param email_host: smtp服务器地址默认为163服务器
:param port: 非安全链接端口默认为25
:param ssl_port: 安全链接端口默认为465
'''
self.username = username #用户名
self.passwd = passwd #密码
self.recv = recv #收件人多个要传list ['a@qq.com','b@qq.com]
self.title = title #邮件标题
self.content = content #邮件正文
self.file = file #附件路径,如果不在当前目录下,要写绝对路径
self.email_host = email_host #smtp服务器地址
self.port = port #普通端口
self.ssl = ssl #是否安全链接
self.ssl_port = ssl_port #安全链接端口
def send_mail(self):
msg = MIMEMultipart()
#发送内容的对象
if self.file:#处理附件的
file_name = os.path.split(self.file)[-1]#只取文件名,不取路径
try:
f = open(self.file, 'rb').read()
except Exception as e:
raise Exception('附件打不开!!!!')
else:
att = MIMEText(f,"base64", "utf-8")
att["Content-Type"] = 'application/octet-stream'
#base64.b64encode(file_name.encode()).decode()
new_file_name='=?utf-8?b?' + base64.b64encode(file_name.encode()).decode() + '?='
#这里是处理文件名为中文名的,必须这么写
att["Content-Disposition"] = 'attachment; filename="%s"'%(new_file_name)
msg.attach(att)
msg.attach(MIMEText(self.content))#邮件正文的内容
msg['Subject'] = self.title # 邮件主题
msg['From'] = self.username # 发送者账号
msg['To'] = ','.join(self.recv) # 接收者账号列表
if self.ssl:
self.smtp = smtplib.SMTP_SSL(self.email_host,port=self.ssl_port)
else:
self.smtp = smtplib.SMTP(self.email_host,port=self.port)
#发送邮件服务器的对象
self.smtp.login(self.username,self.passwd)
try:
self.smtp.sendmail(self.username,self.recv,msg.as_string())
pass
except Exception as e:
print('出错了。。',e)
logger.info('邮件服务出错了。。',e)
else:
print('发送成功!')
logger.info('邮件发送成功!')
self.smtp.quit()
def dateConvert(df, datecol='ds'):
# 将date列转换为datetime类型
try:
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y-%m-%d')
except:
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y/%m/%d')
return df
class SQLiteHandler:
def __init__(self, db_name):
self.db_name = db_name
self.connection = None
self.cursor = None
def connect(self):
self.connection = sqlite3.connect(self.db_name)
self.cursor = self.connection.cursor()
def close(self):
if self.connection:
self.connection.close()
self.connection = None
self.cursor = None
def execute_query(self, query, params=None):
if params:
return self.cursor.execute(query, params)
else:
return self.cursor.execute(query)
def commit(self):
self.connection.commit()
def create_table(self, table_name, columns):
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})"
self.execute_query(query)
self.commit()
def insert_data(self, table_name, values, columns=None):
if columns:
placeholders = ', '.join(['?'] * len(values))
query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
else:
placeholders = ', '.join(['?'] * len(values))
query = f"INSERT INTO {table_name} VALUES ({placeholders})"
self.execute_query(query, values)
self.commit()
def select_data(self, table_name, columns=None, where_condition=None, order_by=None, limit=None):
query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table_name}"
if where_condition:
query += f" WHERE {where_condition}"
if order_by:
query += f" ORDER BY {order_by}"
if limit:
query += f" LIMIT {limit}"
results = self.execute_query(query).fetchall()
if results:
headers = [description[0] for description in self.execute_query(query).description]
return pd.DataFrame(results, columns=headers)
else:
return pd.DataFrame()
def update_data(self, table_name, set_values, where_condition):
query = f"UPDATE {table_name} SET {set_values} WHERE {where_condition}"
logger.info('更新数据sql'+ query)
self.execute_query(query)
self.commit()
def delete_data(self, table_name, where_condition):
query = f"DELETE FROM {table_name} WHERE {where_condition}"
self.execute_query(query)
self.commit()
def check_table_exists(self, table_name):
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'"
result = self.execute_query(query).fetchone()
return result is not None
def add_column_if_not_exists(self, table_name, column_name, column_type):
# 查询表结构
query = f"PRAGMA table_info({table_name})"
self.execute_query(query)
columns = [column[1] for column in self.cursor.fetchall()]
# 判断列是否存在
if column_name not in columns:
# 如果列不存在,则添加列
query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
self.execute_query(query)
self.commit()
print(f"Column '{column_name}' added to table '{table_name}' successfully.")
else:
print(f"Column '{column_name}' already exists in table '{table_name}'.")
if __name__ == '__main__':
print('This is a tool, not a script.')