2024-11-01 16:38:21 +08:00
|
|
|
|
import time
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import seaborn as sns
|
|
|
|
|
from sklearn import metrics
|
|
|
|
|
import random, string, base64, hmac, hashlib
|
|
|
|
|
from reportlab.pdfbase import pdfmetrics # 注册字体
|
|
|
|
|
from reportlab.pdfbase.ttfonts import TTFont # 字体类
|
|
|
|
|
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image # 报告内容相关类
|
|
|
|
|
from reportlab.lib.pagesizes import letter # 页面的标志尺寸(8.5*inch, 11*inch)
|
|
|
|
|
from reportlab.lib.styles import getSampleStyleSheet # 文本样式
|
|
|
|
|
from reportlab.lib import colors # 颜色模块
|
|
|
|
|
from reportlab.graphics.charts.barcharts import VerticalBarChart # 图表类
|
|
|
|
|
from reportlab.graphics.charts.legends import Legend # 图例类
|
|
|
|
|
from reportlab.graphics.shapes import Drawing # 绘图工具
|
|
|
|
|
from reportlab.lib.units import cm # 单位:cm
|
|
|
|
|
import smtplib
|
|
|
|
|
from email.mime.text import MIMEText
|
|
|
|
|
from email.mime.multipart import MIMEMultipart
|
|
|
|
|
import sqlite3
|
2024-12-02 13:56:46 +08:00
|
|
|
|
import pymysql
|
2024-11-01 16:38:21 +08:00
|
|
|
|
import tkinter as tk
|
|
|
|
|
from tkinter import messagebox
|
|
|
|
|
|
2024-12-02 13:56:46 +08:00
|
|
|
|
global logger
|
2024-11-01 16:38:21 +08:00
|
|
|
|
def timeit(func):
|
|
|
|
|
'''计时装饰器'''
|
|
|
|
|
def wrapper(*args, **kwargs):
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
result = func(*args, **kwargs)
|
|
|
|
|
end_time = time.time()
|
|
|
|
|
execution_time = end_time - start_time
|
|
|
|
|
logger.info(f"{func.__name__} 函数的执行时间为: {execution_time} 秒")
|
|
|
|
|
return result
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
class BinanceAPI:
|
|
|
|
|
'''
|
|
|
|
|
获取 Binance API 请求头签名
|
|
|
|
|
'''
|
|
|
|
|
def __init__(self, APPID, SECRET):
|
|
|
|
|
self.APPID = APPID
|
|
|
|
|
self.SECRET = SECRET
|
|
|
|
|
self.get_signature()
|
|
|
|
|
|
|
|
|
|
# 生成随机字符串作为 nonce
|
|
|
|
|
def generate_nonce(self, length=32):
|
|
|
|
|
self.nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
|
|
|
|
|
return self.nonce
|
|
|
|
|
|
|
|
|
|
# 获取当前时间戳(秒)
|
|
|
|
|
def get_timestamp(self):
|
|
|
|
|
return int(time.time())
|
|
|
|
|
|
|
|
|
|
# 构建待签名字符串
|
|
|
|
|
def build_sign_str(self):
|
|
|
|
|
return f'appid={self.APPID}&nonce={self.nonce}×tamp={self.timestamp}'
|
|
|
|
|
|
|
|
|
|
# 使用 HMAC SHA-256 计算签名
|
|
|
|
|
def calculate_signature(self, secret, message):
|
|
|
|
|
return base64.urlsafe_b64encode(hmac.new(secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256).digest()).decode('utf-8')
|
|
|
|
|
|
|
|
|
|
def get_signature(self):
|
|
|
|
|
# 调用上述方法生成签名
|
|
|
|
|
self.nonce = self.generate_nonce()
|
|
|
|
|
self.timestamp = self.get_timestamp()
|
|
|
|
|
self.sign_str = self.build_sign_str()
|
|
|
|
|
self.signature = self.calculate_signature(self.SECRET, self.sign_str)
|
|
|
|
|
# return self.signature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Graphs:
|
|
|
|
|
'''
|
|
|
|
|
pdf生成类
|
|
|
|
|
'''
|
|
|
|
|
# 绘制标题
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_title(title: str):
|
|
|
|
|
# 获取所有样式表
|
|
|
|
|
style = getSampleStyleSheet()
|
|
|
|
|
# 拿到标题样式
|
|
|
|
|
ct = style['Heading1']
|
|
|
|
|
# 单独设置样式相关属性
|
|
|
|
|
ct.fontName = 'SimSun' # 字体名
|
|
|
|
|
ct.fontSize = 18 # 字体大小
|
|
|
|
|
ct.leading = 50 # 行间距
|
|
|
|
|
ct.textColor = colors.green # 字体颜色
|
|
|
|
|
ct.alignment = 1 # 居中
|
|
|
|
|
ct.bold = True
|
|
|
|
|
# 创建标题对应的段落,并且返回
|
|
|
|
|
return Paragraph(title, ct)
|
|
|
|
|
|
|
|
|
|
# 绘制小标题
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_little_title(title: str):
|
|
|
|
|
# 获取所有样式表
|
|
|
|
|
style = getSampleStyleSheet()
|
|
|
|
|
# 拿到标题样式
|
|
|
|
|
ct = style['Normal']
|
|
|
|
|
# 单独设置样式相关属性
|
|
|
|
|
ct.fontName = 'SimSun' # 字体名
|
|
|
|
|
ct.fontSize = 15 # 字体大小
|
|
|
|
|
ct.leading = 30 # 行间距
|
|
|
|
|
ct.textColor = colors.red # 字体颜色
|
|
|
|
|
# 创建标题对应的段落,并且返回
|
|
|
|
|
return Paragraph(title, ct)
|
|
|
|
|
|
|
|
|
|
# 绘制普通段落内容
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_text(text: str):
|
|
|
|
|
# 获取所有样式表
|
|
|
|
|
style = getSampleStyleSheet()
|
|
|
|
|
# 获取普通样式
|
|
|
|
|
ct = style['Normal']
|
|
|
|
|
ct.fontName = 'SimSun'
|
|
|
|
|
ct.fontSize = 12
|
|
|
|
|
ct.wordWrap = 'CJK' # 设置自动换行
|
|
|
|
|
ct.alignment = 0 # 左对齐
|
|
|
|
|
ct.firstLineIndent = 32 # 第一行开头空格
|
|
|
|
|
ct.leading = 25
|
|
|
|
|
return Paragraph(text, ct)
|
|
|
|
|
|
|
|
|
|
# 绘制表格
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_table(col_width,*args):
|
|
|
|
|
# 列宽度
|
|
|
|
|
col_width = col_width
|
|
|
|
|
style = [
|
|
|
|
|
('FONTNAME', (0, 0), (-1, -1), 'SimSun'), # 字体
|
|
|
|
|
('FONTSIZE', (0, 0), (-1, 0), 10), # 第一行的字体大小
|
|
|
|
|
('FONTSIZE', (0, 1), (-1, -1), 8), # 第二行到最后一行的字体大小
|
|
|
|
|
('BACKGROUND', (0, 0), (-1, 0), '#d5dae6'), # 设置第一行背景颜色
|
|
|
|
|
('ALIGN', (0, 0), (-1, -1), 'CENTER'), # 第一行水平居中
|
|
|
|
|
('ALIGN', (0, 1), (-1, -1), 'LEFT'), # 第二行到最后一行左右左对齐
|
|
|
|
|
('VALIGN', (0, 0), (-1, -1), 'MIDDLE'), # 所有表格上下居中对齐
|
|
|
|
|
('TEXTCOLOR', (0, 0), (-1, -1), colors.darkslategray), # 设置表格内文字颜色
|
|
|
|
|
('GRID', (0, 0), (-1, -1), 0.5, colors.grey), # 设置表格框线为grey色,线宽为0.5
|
|
|
|
|
# ('SPAN', (0, 1), (0, 2)), # 合并第一列二三行
|
|
|
|
|
# ('SPAN', (0, 3), (0, 4)), # 合并第一列三四行
|
|
|
|
|
# ('SPAN', (0, 5), (0, 6)), # 合并第一列五六行
|
|
|
|
|
# ('SPAN', (0, 7), (0, 8)), # 合并第一列五六行
|
|
|
|
|
]
|
|
|
|
|
table = Table(args, colWidths=col_width, style=style)
|
|
|
|
|
return table
|
|
|
|
|
|
|
|
|
|
# 创建图表
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_bar(bar_data: list, ax: list, items: list):
|
|
|
|
|
drawing = Drawing(500, 250)
|
|
|
|
|
bc = VerticalBarChart()
|
|
|
|
|
bc.x = 45 # 整个图表的x坐标
|
|
|
|
|
bc.y = 45 # 整个图表的y坐标
|
|
|
|
|
bc.height = 200 # 图表的高度
|
|
|
|
|
bc.width = 350 # 图表的宽度
|
|
|
|
|
bc.data = bar_data
|
|
|
|
|
bc.strokeColor = colors.black # 顶部和右边轴线的颜色
|
|
|
|
|
bc.valueAxis.valueMin = 5000 # 设置y坐标的最小值
|
|
|
|
|
bc.valueAxis.valueMax = 26000 # 设置y坐标的最大值
|
|
|
|
|
bc.valueAxis.valueStep = 2000 # 设置y坐标的步长
|
|
|
|
|
bc.categoryAxis.labels.dx = 2
|
|
|
|
|
bc.categoryAxis.labels.dy = -8
|
|
|
|
|
bc.categoryAxis.labels.angle = 20
|
|
|
|
|
bc.categoryAxis.categoryNames = ax
|
|
|
|
|
|
|
|
|
|
# 图示
|
|
|
|
|
leg = Legend()
|
|
|
|
|
leg.fontName = 'SimSun'
|
|
|
|
|
leg.alignment = 'right'
|
|
|
|
|
leg.boxAnchor = 'ne'
|
|
|
|
|
leg.x = 475 # 图例的x坐标
|
|
|
|
|
leg.y = 240
|
|
|
|
|
leg.dxTextSpace = 10
|
|
|
|
|
leg.columnMaximum = 3
|
|
|
|
|
leg.colorNamePairs = items
|
|
|
|
|
drawing.add(leg)
|
|
|
|
|
drawing.add(bc)
|
|
|
|
|
return drawing
|
|
|
|
|
|
|
|
|
|
# 绘制图片
|
|
|
|
|
@staticmethod
|
|
|
|
|
def draw_img(path):
|
|
|
|
|
img = Image(path) # 读取指定路径下的图片
|
|
|
|
|
img.drawWidth = 20*cm # 设置图片的宽度
|
|
|
|
|
img.drawHeight = 10*cm # 设置图片的高度
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 评估指标不在一个库,这里列出所有用到的指标的公式
|
|
|
|
|
|
|
|
|
|
# MSE
|
|
|
|
|
def mse(y_true, y_pred):
|
|
|
|
|
|
|
|
|
|
res_mse = metrics.mean_squared_error(y_true, y_pred)
|
|
|
|
|
|
|
|
|
|
return res_mse
|
|
|
|
|
# RMSE
|
|
|
|
|
def rmse(y_true, y_pred):
|
|
|
|
|
|
|
|
|
|
res_rmse = np.sqrt(metrics.mean_squared_error(y_true, y_pred))
|
|
|
|
|
|
|
|
|
|
return res_rmse
|
|
|
|
|
|
|
|
|
|
# MAE
|
|
|
|
|
def mae(y_true, y_pred):
|
|
|
|
|
|
|
|
|
|
res_mae = metrics.mean_absolute_error(y_true, y_pred)
|
|
|
|
|
|
|
|
|
|
return res_mae
|
|
|
|
|
|
|
|
|
|
# sklearn的库中没有MAPE和SMAPE,下面根据公式给出算法实现
|
|
|
|
|
# MAPE
|
|
|
|
|
def mape(y_true, y_pred):
|
|
|
|
|
|
|
|
|
|
res_mape = np.mean(np.abs((y_pred - y_true) / y_true)) * 100
|
|
|
|
|
|
|
|
|
|
return res_mape
|
|
|
|
|
|
|
|
|
|
# SMAPE
|
|
|
|
|
def smape(y_true, y_pred):
|
|
|
|
|
|
|
|
|
|
res_smape = 2.0 * np.mean(np.abs(y_pred - y_true) / (np.abs(y_pred) + np.abs(y_true))) * 100
|
|
|
|
|
|
|
|
|
|
return res_smape
|
|
|
|
|
|
|
|
|
|
# 相关系数绘制
|
|
|
|
|
def plot_corr(data, size=11):
|
|
|
|
|
# 去掉ds列
|
|
|
|
|
data.drop(columns=['ds'], inplace=True)
|
|
|
|
|
|
|
|
|
|
# 创建一个空的 DataFrame 来保存相关系数
|
|
|
|
|
correlation_df = pd.DataFrame(columns=['Feature', 'Correlation'])
|
|
|
|
|
|
|
|
|
|
# 计算各特征与目标列的皮尔逊相关系数,并保存到新的 DataFrame 中
|
|
|
|
|
for col in data.columns:
|
|
|
|
|
if col!= 'y':
|
|
|
|
|
pearson_correlation = np.corrcoef(data[col], data['y'])[0, 1]
|
|
|
|
|
spearman_correlation, _ = spearmanr(data[col], data['y'])
|
|
|
|
|
new_row = {'Feature': col, 'Pearson_Correlation': round(pearson_correlation,3), 'Spearman_Correlation': round(spearman_correlation,2)}
|
|
|
|
|
correlation_df = correlation_df._append(new_row, ignore_index=True)
|
|
|
|
|
# 删除空列
|
|
|
|
|
correlation_df.drop('Correlation', axis=1, inplace=True)
|
|
|
|
|
correlation_df.dropna(inplace=True)
|
|
|
|
|
correlation_df.to_csv('指标相关性分析.csv', index=False)
|
|
|
|
|
|
|
|
|
|
data = correlation_df['Pearson_Correlation'].values.tolist()
|
|
|
|
|
# 生成 -1 到 1 的 20 个区间
|
|
|
|
|
bins = np.linspace(-1, 1, 21)
|
|
|
|
|
# 计算每个区间的统计数(这里是区间内数据的数量)
|
|
|
|
|
hist_values = [np.sum((data >= bins[i]) & (data < bins[i + 1])) for i in range(len(bins) - 1)]
|
|
|
|
|
|
|
|
|
|
#设置画布大小
|
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
|
# 绘制直方图
|
|
|
|
|
plt.bar(bins[:-1], hist_values, width=(bins[1] - bins[0]))
|
|
|
|
|
|
|
|
|
|
# 添加标题和坐标轴标签
|
|
|
|
|
plt.title('皮尔逊相关系数分布图')
|
|
|
|
|
plt.xlabel('区间')
|
|
|
|
|
plt.ylabel('统计数')
|
|
|
|
|
plt.savefig('皮尔逊相关性系数.png')
|
|
|
|
|
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#设置画布大小
|
|
|
|
|
plt.figure(figsize=(10, 6))
|
|
|
|
|
data = correlation_df['Spearman_Correlation'].values.tolist()
|
|
|
|
|
# 计算每个区间的统计数(这里是区间内数据的数量)
|
|
|
|
|
hist_values = [np.sum((data >= bins[i]) & (data < bins[i + 1])) for i in range(len(bins) - 1)]
|
|
|
|
|
|
|
|
|
|
# 绘制直方图
|
|
|
|
|
plt.bar(bins[:-1], hist_values, width=(bins[1] - bins[0]))
|
|
|
|
|
|
|
|
|
|
# 添加标题和坐标轴标签
|
|
|
|
|
plt.title('斯皮尔曼相关系数分布图')
|
|
|
|
|
plt.xlabel('区间')
|
|
|
|
|
plt.ylabel('统计数')
|
|
|
|
|
plt.savefig('斯皮尔曼相关性系数.png')
|
|
|
|
|
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 邮件封装
|
|
|
|
|
class SendMail(object):
|
|
|
|
|
def __init__(self,username,passwd,recv,title,content,
|
|
|
|
|
file=None,ssl=False,
|
|
|
|
|
email_host='smtp.qq.com',port=25,ssl_port=465):
|
|
|
|
|
'''
|
|
|
|
|
:param username: 用户名
|
|
|
|
|
:param passwd: 密码
|
|
|
|
|
:param recv: 收件人,多个要传list ['a@qq.com','b@qq.com]
|
|
|
|
|
:param title: 邮件标题
|
|
|
|
|
:param content: 邮件正文
|
|
|
|
|
:param file: 附件路径,如果不在当前目录下,要写绝对路径,默认没有附件
|
|
|
|
|
:param ssl: 是否安全链接,默认为普通
|
|
|
|
|
:param email_host: smtp服务器地址,默认为163服务器
|
|
|
|
|
:param port: 非安全链接端口,默认为25
|
|
|
|
|
:param ssl_port: 安全链接端口,默认为465
|
|
|
|
|
'''
|
|
|
|
|
self.username = username #用户名
|
|
|
|
|
self.passwd = passwd #密码
|
|
|
|
|
self.recv = recv #收件人,多个要传list ['a@qq.com','b@qq.com]
|
|
|
|
|
self.title = title #邮件标题
|
|
|
|
|
self.content = content #邮件正文
|
|
|
|
|
self.file = file #附件路径,如果不在当前目录下,要写绝对路径
|
|
|
|
|
self.email_host = email_host #smtp服务器地址
|
|
|
|
|
self.port = port #普通端口
|
|
|
|
|
self.ssl = ssl #是否安全链接
|
|
|
|
|
self.ssl_port = ssl_port #安全链接端口
|
|
|
|
|
def send_mail(self):
|
|
|
|
|
msg = MIMEMultipart()
|
|
|
|
|
#发送内容的对象
|
|
|
|
|
if self.file:#处理附件的
|
|
|
|
|
file_name = os.path.split(self.file)[-1]#只取文件名,不取路径
|
|
|
|
|
try:
|
|
|
|
|
f = open(self.file, 'rb').read()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise Exception('附件打不开!!!!')
|
|
|
|
|
else:
|
|
|
|
|
att = MIMEText(f,"base64", "utf-8")
|
|
|
|
|
att["Content-Type"] = 'application/octet-stream'
|
|
|
|
|
#base64.b64encode(file_name.encode()).decode()
|
|
|
|
|
new_file_name='=?utf-8?b?' + base64.b64encode(file_name.encode()).decode() + '?='
|
|
|
|
|
#这里是处理文件名为中文名的,必须这么写
|
|
|
|
|
att["Content-Disposition"] = 'attachment; filename="%s"'%(new_file_name)
|
|
|
|
|
msg.attach(att)
|
|
|
|
|
msg.attach(MIMEText(self.content))#邮件正文的内容
|
|
|
|
|
msg['Subject'] = self.title # 邮件主题
|
|
|
|
|
msg['From'] = self.username # 发送者账号
|
|
|
|
|
msg['To'] = ','.join(self.recv) # 接收者账号列表
|
|
|
|
|
if self.ssl:
|
|
|
|
|
self.smtp = smtplib.SMTP_SSL(self.email_host,port=self.ssl_port)
|
|
|
|
|
else:
|
|
|
|
|
self.smtp = smtplib.SMTP(self.email_host,port=self.port)
|
|
|
|
|
#发送邮件服务器的对象
|
|
|
|
|
self.smtp.login(self.username,self.passwd)
|
|
|
|
|
try:
|
|
|
|
|
self.smtp.sendmail(self.username,self.recv,msg.as_string())
|
|
|
|
|
pass
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print('出错了。。',e)
|
|
|
|
|
logger.info('邮件服务出错了。。',e)
|
|
|
|
|
else:
|
|
|
|
|
print('发送成功!')
|
|
|
|
|
self.smtp.quit()
|
|
|
|
|
|
|
|
|
|
def dateConvert(df, datecol='ds'):
|
|
|
|
|
# 将date列转换为datetime类型
|
|
|
|
|
try:
|
|
|
|
|
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y-%m-%d')
|
|
|
|
|
except:
|
|
|
|
|
df[datecol] = pd.to_datetime(df[datecol],format=r'%Y/%m/%d')
|
|
|
|
|
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SQLiteHandler:
|
|
|
|
|
def __init__(self, db_name):
|
|
|
|
|
self.db_name = db_name
|
|
|
|
|
self.connection = None
|
|
|
|
|
self.cursor = None
|
|
|
|
|
|
|
|
|
|
def connect(self):
|
|
|
|
|
self.connection = sqlite3.connect(self.db_name)
|
|
|
|
|
self.cursor = self.connection.cursor()
|
|
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
|
if self.connection:
|
|
|
|
|
self.connection.close()
|
|
|
|
|
self.connection = None
|
|
|
|
|
self.cursor = None
|
|
|
|
|
|
|
|
|
|
def execute_query(self, query, params=None):
|
|
|
|
|
if params:
|
|
|
|
|
return self.cursor.execute(query, params)
|
|
|
|
|
else:
|
|
|
|
|
return self.cursor.execute(query)
|
|
|
|
|
|
|
|
|
|
def commit(self):
|
|
|
|
|
self.connection.commit()
|
|
|
|
|
|
|
|
|
|
def create_table(self, table_name, columns):
|
|
|
|
|
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
self.commit()
|
|
|
|
|
|
|
|
|
|
def insert_data(self, table_name, values, columns=None):
|
|
|
|
|
if columns:
|
|
|
|
|
placeholders = ', '.join(['?'] * len(values))
|
|
|
|
|
query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
|
|
|
|
|
else:
|
|
|
|
|
placeholders = ', '.join(['?'] * len(values))
|
|
|
|
|
query = f"INSERT INTO {table_name} VALUES ({placeholders})"
|
|
|
|
|
self.execute_query(query, values)
|
|
|
|
|
self.commit()
|
|
|
|
|
|
|
|
|
|
def select_data(self, table_name, columns=None, where_condition=None, order_by=None, limit=None):
|
|
|
|
|
query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table_name}"
|
|
|
|
|
if where_condition:
|
|
|
|
|
query += f" WHERE {where_condition}"
|
|
|
|
|
if order_by:
|
|
|
|
|
query += f" ORDER BY {order_by}"
|
|
|
|
|
if limit:
|
|
|
|
|
query += f" LIMIT {limit}"
|
|
|
|
|
results = self.execute_query(query).fetchall()
|
|
|
|
|
if results:
|
|
|
|
|
headers = [description[0] for description in self.execute_query(query).description]
|
|
|
|
|
return pd.DataFrame(results, columns=headers)
|
|
|
|
|
else:
|
|
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
def update_data(self, table_name, set_values, where_condition):
|
|
|
|
|
query = f"UPDATE {table_name} SET {set_values} WHERE {where_condition}"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
self.commit()
|
|
|
|
|
|
|
|
|
|
def delete_data(self, table_name, where_condition):
|
|
|
|
|
query = f"DELETE FROM {table_name} WHERE {where_condition}"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
self.commit()
|
|
|
|
|
|
|
|
|
|
def check_table_exists(self, table_name):
|
|
|
|
|
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'"
|
|
|
|
|
result = self.execute_query(query).fetchone()
|
|
|
|
|
return result is not None
|
|
|
|
|
|
2024-11-26 13:18:25 +08:00
|
|
|
|
def drop_table(self, table_name):
|
|
|
|
|
query = f"DROP TABLE IF EXISTS {table_name}"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
self.commit()
|
|
|
|
|
|
2024-11-01 16:38:21 +08:00
|
|
|
|
def add_column_if_not_exists(self, table_name, column_name, column_type):
|
|
|
|
|
# 查询表结构
|
|
|
|
|
query = f"PRAGMA table_info({table_name})"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
columns = [column[1] for column in self.cursor.fetchall()]
|
|
|
|
|
|
|
|
|
|
# 判断列是否存在
|
|
|
|
|
if column_name not in columns:
|
|
|
|
|
# 如果列不存在,则添加列
|
|
|
|
|
query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
|
|
|
|
|
self.execute_query(query)
|
|
|
|
|
self.commit()
|
|
|
|
|
print(f"Column '{column_name}' added to table '{table_name}' successfully.")
|
|
|
|
|
else:
|
|
|
|
|
print(f"Column '{column_name}' already exists in table '{table_name}'.")
|
|
|
|
|
|
2024-12-02 13:56:46 +08:00
|
|
|
|
import logging
|
|
|
|
|
class MySQLDB:
|
|
|
|
|
def __init__(self, host, user, password, database):
|
|
|
|
|
self.host = host
|
|
|
|
|
self.user = user
|
|
|
|
|
self.password = password
|
|
|
|
|
self.database = database
|
|
|
|
|
self.connection = None
|
|
|
|
|
self.cursor = None
|
|
|
|
|
|
|
|
|
|
def connect(self):
|
|
|
|
|
try:
|
|
|
|
|
self.connection = pymysql.connect(
|
|
|
|
|
host=self.host,
|
|
|
|
|
user=self.user,
|
|
|
|
|
password=self.password,
|
|
|
|
|
database=self.database,
|
|
|
|
|
charset='utf8mb4',
|
|
|
|
|
cursorclass=pymysql.cursors.DictCursor
|
|
|
|
|
)
|
|
|
|
|
self.cursor = self.connection.cursor()
|
|
|
|
|
logging.info("Connected to the database successfully.")
|
|
|
|
|
except pymysql.Error as e:
|
|
|
|
|
logging.error(f"Error connecting to the database: {e}")
|
|
|
|
|
|
|
|
|
|
def execute_query(self, query):
|
|
|
|
|
try:
|
|
|
|
|
self.cursor.execute(query)
|
|
|
|
|
result = self.cursor.fetchall()
|
|
|
|
|
return result
|
|
|
|
|
except pymysql.Error as e:
|
|
|
|
|
logging.error(f"Error executing query: {e}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def execute_insert(self, query, values):
|
|
|
|
|
try:
|
|
|
|
|
self.cursor.execute(query, values)
|
|
|
|
|
self.connection.commit()
|
|
|
|
|
logging.info("Insert operation successful.")
|
|
|
|
|
except pymysql.Error as e:
|
|
|
|
|
logging.error(f"Error executing insert: {e}")
|
|
|
|
|
self.connection.rollback()
|
|
|
|
|
|
|
|
|
|
def execute_update(self, query, values):
|
|
|
|
|
try:
|
|
|
|
|
self.cursor.execute(query, values)
|
|
|
|
|
self.connection.commit()
|
|
|
|
|
logging.info("Update operation successful.")
|
|
|
|
|
except pymysql.Error as e:
|
|
|
|
|
logging.error(f"Error executing update: {e}")
|
|
|
|
|
self.connection.rollback()
|
|
|
|
|
|
|
|
|
|
def execute_delete(self, query, values):
|
|
|
|
|
try:
|
|
|
|
|
self.cursor.execute(query, values)
|
|
|
|
|
self.connection.commit()
|
|
|
|
|
logging.info("Delete operation successful.")
|
|
|
|
|
except pymysql.Error as e:
|
|
|
|
|
logging.error(f"Error executing delete: {e}")
|
|
|
|
|
self.connection.rollback()
|
|
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
|
if self.cursor:
|
|
|
|
|
self.cursor.close()
|
|
|
|
|
if self.connection:
|
|
|
|
|
self.connection.close()
|
|
|
|
|
logging.info("Database connection closed.")
|
|
|
|
|
|
2024-12-18 17:49:23 +08:00
|
|
|
|
def exception_logger(func):
|
|
|
|
|
def wrapper(*args, **kwargs):
|
|
|
|
|
try:
|
|
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# 记录异常日志
|
|
|
|
|
logging.error(f"An error occurred in function {func.__name__}: {str(e)}")
|
|
|
|
|
# 可以选择重新抛出异常,或者在这里处理异常
|
|
|
|
|
raise e # 重新抛出异常
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-11-01 16:38:21 +08:00
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
print('This is a tool, not a script.')
|