# PriceForecast/lib/tools.py
# Timestamp: 2025-07-28 18:08:52 +08:00 (1137 lines, 44 KiB, Python)
#
# NOTE: the repository viewer flagged this file as containing ambiguous
# Unicode characters that might be confused with similar-looking characters.

import datetime
from decimal import Decimal
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredURLLoader
from langchain_core.prompts import PromptTemplate
from tkinter import messagebox
import tkinter as tk
import pymysql
import sqlite3
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import smtplib
from reportlab.lib.units import cm # 单位cm
from reportlab.graphics.shapes import Drawing # 绘图工具
from reportlab.graphics.charts.legends import Legend # 图例类
from reportlab.graphics.charts.barcharts import VerticalBarChart # 图表类
from reportlab.lib import colors # 颜色模块
from reportlab.lib.styles import getSampleStyleSheet # 文本样式
from reportlab.lib.pagesizes import letter # 页面的标志尺寸(8.5*inch, 11*inch)
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image # 报告内容相关类
from reportlab.pdfbase.ttfonts import TTFont # 字体类
from reportlab.pdfbase import pdfmetrics # 注册字体
import hashlib
import hmac
import base64
import string
import random
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import time
import logging
from dotenv import load_dotenv
from lib.pydantic_models import PredictionResult, PpPredictionResult
load_dotenv()
# Default module logger, used by timeit() and SendMail.send_mail().
# Other modules may rebind `tools.logger` (presumably at application start-up;
# confirm with callers); such a rebinding still takes precedence over this.
logger = logging.getLogger(__name__)
def timeit(func):
    '''Decorator that logs how long each call to *func* takes.

    The elapsed time in seconds is written through the module-level
    ``logger`` after *func* returns. If *func* raises, the exception
    propagates and nothing is logged.
    '''
    import functools

    @functools.wraps(func)  # keep the wrapped function's __name__ / __doc__
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()  # monotonic, high-resolution clock
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        logger.info(f"{func.__name__} 函数的执行时间为: {execution_time}")
        return result
    return wrapper
class BinanceAPI:
    '''
    Build the signed request values (nonce, timestamp, signature) for the API.

    Construction immediately generates a nonce and timestamp and computes the
    signature, stored on the instance as ``nonce``, ``timestamp``,
    ``sign_str`` and ``signature``.
    '''

    def __init__(self, APPID, SECRET):
        self.APPID = APPID
        self.SECRET = SECRET
        self.get_signature()

    def generate_nonce(self, length=32):
        '''Return a random alphanumeric nonce of *length* characters and store it on self.

        Uses the ``secrets`` CSPRNG: a nonce that goes into a request
        signature must be unpredictable, which ``random`` does not guarantee.
        '''
        import secrets
        alphabet = string.ascii_letters + string.digits
        self.nonce = ''.join(secrets.choice(alphabet) for _ in range(length))
        return self.nonce

    def get_timestamp(self):
        '''Current Unix time in whole seconds.'''
        return int(time.time())

    def build_sign_str(self):
        '''Return the string to sign: ``appid=<APPID>&nonce=<nonce>&timestamp=<timestamp>``.'''
        return f'appid={self.APPID}&nonce={self.nonce}&timestamp={self.timestamp}'

    def calculate_signature(self, secret, message):
        '''Return the URL-safe base64 of HMAC-SHA256(secret, message), both UTF-8 encoded.'''
        digest = hmac.new(secret.encode('utf-8'), message.encode('utf-8'), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8')

    def get_signature(self):
        '''Regenerate nonce and timestamp, then recompute ``sign_str`` and ``signature``.'''
        self.nonce = self.generate_nonce()
        self.timestamp = self.get_timestamp()
        self.sign_str = self.build_sign_str()
        self.signature = self.calculate_signature(self.SECRET, self.sign_str)
class Graphs:
    '''
    Static helpers that build reportlab flowables (paragraphs, tables, bar
    charts, images) for the PDF report.

    NOTE(review): the text helpers use the font name 'SimSun', which this class
    does not register; presumably the caller registers it with
    pdfmetrics.registerFont(TTFont(...)) before building the PDF — confirm.
    '''

    @staticmethod
    def draw_title(title: str):
        '''Return *title* as a centred, green, 18pt SimSun heading paragraph.'''
        style = getSampleStyleSheet()
        ct = style['Heading1']
        ct.fontName = 'SimSun'
        ct.fontSize = 18
        ct.leading = 50            # line spacing
        ct.textColor = colors.green
        ct.alignment = 1           # centred
        # NOTE(review): ParagraphStyle may ignore `bold`; a bold font name is
        # the usual way to get bold text in reportlab — confirm.
        ct.bold = True
        return Paragraph(title, ct)

    @staticmethod
    def draw_little_title(title: str):
        '''Return *title* as a red, 15pt SimSun sub-heading paragraph.'''
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 15
        ct.leading = 30            # line spacing
        ct.textColor = colors.red
        return Paragraph(title, ct)

    @staticmethod
    def draw_text(text: str):
        '''Return *text* as a 12pt, left-aligned SimSun body paragraph with CJK wrapping.'''
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 12
        ct.wordWrap = 'CJK'        # wrap between CJK characters
        ct.alignment = 0           # left aligned
        ct.firstLineIndent = 32    # indent the first line
        ct.leading = 25
        return Paragraph(text, ct)

    @staticmethod
    def draw_table(col_width, *args):
        '''
        Build a grid table. Each positional argument after *col_width* is one
        row; the first row is styled as the header.

        :param col_width: column width(s), passed to ``Table(colWidths=...)``
        '''
        style = [
            ('FONTNAME', (0, 0), (-1, -1), 'SimSun'),
            ('FONTSIZE', (0, 0), (-1, 0), 10),               # header row
            ('FONTSIZE', (0, 1), (-1, -1), 8),               # body rows
            ('BACKGROUND', (0, 0), (-1, 0), '#d5dae6'),      # header background
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),           # all cells centred ...
            ('ALIGN', (0, 1), (-1, -1), 'LEFT'),             # ... body rows then set to left
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),          # vertical centring everywhere
            ('TEXTCOLOR', (0, 0), (-1, -1), colors.darkslategray),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),    # 0.5pt grey grid lines
        ]
        return Table(args, colWidths=col_width, style=style)

    @staticmethod
    def draw_bar(bar_data: list, ax: list, items: list,
                 value_min=5000, value_max=26000, value_step=2000):
        '''
        Build a 500x250 Drawing holding a vertical bar chart and its legend.

        :param bar_data: chart data assigned to ``VerticalBarChart.data``
            (presumably one sequence of values per series — confirm)
        :param ax: category (x-axis) labels
        :param items: legend ``colorNamePairs``
        :param value_min: y-axis minimum (default 5000, the previous fixed value)
        :param value_max: y-axis maximum (default 26000, the previous fixed value)
        :param value_step: y-axis tick step (default 2000, the previous fixed value)
        '''
        drawing = Drawing(500, 250)
        bc = VerticalBarChart()
        bc.x = 45                    # chart origin x
        bc.y = 45                    # chart origin y
        bc.height = 200
        bc.width = 350
        bc.data = bar_data
        bc.strokeColor = colors.black
        bc.valueAxis.valueMin = value_min
        bc.valueAxis.valueMax = value_max
        bc.valueAxis.valueStep = value_step
        bc.categoryAxis.labels.dx = 2
        bc.categoryAxis.labels.dy = -8
        bc.categoryAxis.labels.angle = 20
        bc.categoryAxis.categoryNames = ax

        leg = Legend()
        leg.fontName = 'SimSun'
        leg.alignment = 'right'
        leg.boxAnchor = 'ne'
        leg.x = 475                  # legend anchor x
        leg.y = 240
        leg.dxTextSpace = 10
        leg.columnMaximum = 3
        leg.colorNamePairs = items
        drawing.add(leg)
        drawing.add(bc)
        return drawing

    @staticmethod
    def draw_img(path, width=20 * cm, height=10 * cm):
        '''Load the image at *path*, drawn at *width* x *height* (default 20cm x 10cm).'''
        img = Image(path)
        img.drawWidth = width
        img.drawHeight = height
        return img
# The evaluation metrics come from different libraries; every metric used in
# this project is collected below behind the same (y_true, y_pred) signature.
# MSE
def mse(y_true, y_pred):
    '''Mean squared error, delegated to sklearn.metrics.mean_squared_error.'''
    return metrics.mean_squared_error(y_true, y_pred)
# RMSE
def rmse(y_true, y_pred):
    '''Root mean squared error: the square root of sklearn's MSE.'''
    squared_error = metrics.mean_squared_error(y_true, y_pred)
    return np.sqrt(squared_error)
# MAE
def mae(y_true, y_pred):
    '''Mean absolute error, delegated to sklearn.metrics.mean_absolute_error.'''
    return metrics.mean_absolute_error(y_true, y_pred)
# MAPE and SMAPE are implemented directly from their formulas below.
# MAPE
def mape(y_true, y_pred):
    '''Mean absolute percentage error, in percent.

    MAPE = 100 * mean(|(y_pred - y_true) / y_true|)

    Both inputs are converted to float NumPy arrays first. Plain lists are
    therefore accepted, and pandas Series are compared by position instead of
    being aligned on their index (alignment of differently indexed Series
    produced an all-NaN result). A zero in *y_true* yields inf/nan, exactly
    as the formula does.
    '''
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs((y_pred - y_true) / y_true)) * 100
# SMAPE
def smape(y_true, y_pred):
    '''Symmetric mean absolute percentage error, in percent.

    SMAPE = 100 * 2 * mean(|y_pred - y_true| / (|y_pred| + |y_true|))

    Inputs are converted to float NumPy arrays, so lists work and pandas
    Series are compared by position rather than aligned on their index.
    A pair where both values are 0 is a perfect prediction and contributes 0
    (instead of 0/0 turning the whole mean into NaN).
    '''
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    diff = np.abs(y_pred - y_true)
    denom = np.abs(y_pred) + np.abs(y_true)
    ratio = np.divide(diff, denom, out=np.zeros_like(diff), where=denom != 0)
    return 2.0 * np.mean(ratio) * 100
# Correlation analysis plots
def _plot_corr_hist(values, bins, title, filename):
    '''Save a histogram of correlation *values* over the bin edges *bins* to *filename*.'''
    # np.histogram: half-open bins except the last, which includes its right
    # edge, so a correlation of exactly 1.0 is counted.
    counts, _ = np.histogram(np.asarray(values, dtype=float), bins=bins)
    plt.figure(figsize=(10, 6))
    # align='edge' draws each bar over [edge, edge + width) instead of
    # centring it on the left edge.
    plt.bar(bins[:-1], counts, width=(bins[1] - bins[0]), align='edge')
    plt.title(title)
    plt.xlabel('区间')
    plt.ylabel('统计数')
    plt.savefig(filename)
    plt.close()


def plot_corr(data, size=11):
    '''
    Correlate every feature column of *data* with 'y' and save the results.

    For each column other than 'ds' and 'y', computes the Pearson (rounded
    to 3 places) and Spearman (rounded to 2 places) correlation with 'y'.
    Rows with NaN correlations are dropped. The table is written to
    '指标相关性分析.csv'. Histograms of both coefficients over 20 bins in
    [-1, 1] are saved as '皮尔逊相关性系数.png' and '斯皮尔曼相关性系数.png'
    (all in the current working directory).

    *data* itself is not modified. *size* is unused and kept only for
    backward compatibility.
    '''
    from scipy.stats import spearmanr

    features = data.drop(columns=['ds'], errors='ignore')
    rows = []
    for col in features.columns:
        if col == 'y':
            continue
        pearson_correlation = np.corrcoef(features[col], features['y'])[0, 1]
        spearman_correlation, _ = spearmanr(features[col], features['y'])
        rows.append({
            'Feature': col,
            'Pearson_Correlation': round(pearson_correlation, 3),
            'Spearman_Correlation': round(spearman_correlation, 2),
        })
    correlation_df = pd.DataFrame(
        rows, columns=['Feature', 'Pearson_Correlation', 'Spearman_Correlation'])
    # Constant columns give NaN correlations; drop those rows.
    correlation_df.dropna(inplace=True)
    correlation_df.to_csv('指标相关性分析.csv', index=False)

    bins = np.linspace(-1, 1, 21)  # 20 equal-width bins over [-1, 1]
    _plot_corr_hist(correlation_df['Pearson_Correlation'], bins,
                    '皮尔逊相关系数分布图', '皮尔逊相关性系数.png')
    _plot_corr_hist(correlation_df['Spearman_Correlation'], bins,
                    '斯皮尔曼相关系数分布图', '斯皮尔曼相关性系数.png')
# Email sending wrapper
class SendMail(object):
    '''Send a plain-text email over SMTP, optionally with one file attachment.'''

    def __init__(self, username, passwd, recv, title, content,
                 file=None, ssl=False,
                 email_host='smtp.qq.com', port=25, ssl_port=465):
        '''
        :param username: SMTP login; also used as the sender (From) address
        :param passwd: SMTP password / authorisation code
        :param recv: one address string, or a list of addresses,
            e.g. ['a@qq.com', 'b@qq.com']
        :param title: subject line
        :param content: plain-text body
        :param file: optional attachment path; a relative path is resolved
            against the current working directory. Default: no attachment
        :param ssl: connect with SMTP over SSL (``ssl_port``) when True,
            plain SMTP (``port``) otherwise
        :param email_host: SMTP server host (default smtp.qq.com)
        :param port: plain SMTP port (default 25)
        :param ssl_port: SMTP-over-SSL port (default 465)
        '''
        self.username = username
        self.passwd = passwd
        self.recv = recv
        self.title = title
        self.content = content
        self.file = file
        self.email_host = email_host
        self.port = port
        self.ssl = ssl
        self.ssl_port = ssl_port

    def send_mail(self):
        '''
        Build the message and send it.

        :raises Exception: '附件打不开!!!!' if the attachment cannot be read;
            this happens before any network connection is opened.
        Connection and login errors propagate. An error from ``sendmail`` is
        printed and logged rather than raised. Once opened, the SMTP
        connection is always closed.
        '''
        # Accept a single address string too: ','.join on a bare string would
        # split it into individual characters.
        recipients = [self.recv] if isinstance(self.recv, str) else list(self.recv)

        msg = MIMEMultipart()
        if self.file:
            file_name = os.path.split(self.file)[-1]  # file name without directories
            try:
                with open(self.file, 'rb') as fh:
                    payload = fh.read()
            except OSError as e:
                raise Exception('附件打不开!!!!') from e
            att = MIMEText(payload, "base64", "utf-8")
            # Message.__setitem__ appends a header; replace the text/* type
            # MIMEText set instead of adding a second Content-Type header.
            att.replace_header('Content-Type', 'application/octet-stream')
            # RFC 2047 encoded-word so non-ASCII (e.g. Chinese) file names survive.
            new_file_name = '=?utf-8?b?' + \
                base64.b64encode(file_name.encode()).decode() + '?='
            att["Content-Disposition"] = 'attachment; filename="%s"' % (
                new_file_name)
            msg.attach(att)
        msg.attach(MIMEText(self.content))
        msg['Subject'] = self.title
        msg['From'] = self.username
        msg['To'] = ','.join(recipients)

        if self.ssl:
            self.smtp = smtplib.SMTP_SSL(self.email_host, port=self.ssl_port)
        else:
            self.smtp = smtplib.SMTP(self.email_host, port=self.port)
        try:
            self.smtp.login(self.username, self.passwd)
            try:
                self.smtp.sendmail(self.username, recipients, msg.as_string())
            except Exception as e:
                print('出错了。。', e)
                logger.error('邮件服务出错了。。%s', e)
            else:
                print('发送成功!')
        finally:
            try:
                self.smtp.quit()
            except smtplib.SMTPException:
                # Server already dropped the connection; just release the socket.
                self.smtp.close()
def dateConvert(df, datecol='ds'):
    '''Convert ``df[datecol]`` to datetime in place and return *df*.

    The column is parsed as 'YYYY-MM-DD' first, then as 'YYYY/MM/DD'. If
    neither format matches, the error from the second attempt propagates.
    '''
    try:
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y-%m-%d')
    except (ValueError, TypeError):
        # Only a parse failure falls back to the slash format; unrelated
        # errors are no longer silently swallowed by a bare except.
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y/%m/%d')
    return df
def save_to_database(sqlitedb, df, dbname, end_time):
    '''
    Upsert *df* into the SQLite table *dbname*, keyed on (ds, created_dt).

    If the table does not exist it is created from *df*. Otherwise, any
    column of *df* missing from the table is added as TEXT, and each row
    either updates the existing row with the same ``ds`` and
    ``created_dt = end_time`` or is inserted.

    A datetime ``ds`` column is first formatted as 'YYYY-MM-DD' (this
    modifies *df*).

    :param sqlitedb: connected SQLiteHandler-like object
    :param df: DataFrame with at least a 'ds' column
    :param dbname: target table name (interpolated into SQL; must be trusted)
    :param end_time: value matched against ``created_dt``
    '''
    def _quote(value):
        # SQL string literal with embedded single quotes doubled, so values
        # such as "O'Brien" cannot break (or inject into) the statement.
        return "'" + str(value).replace("'", "''") + "'"

    if pd.api.types.is_datetime64_any_dtype(df['ds']):
        df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
    if not sqlitedb.check_table_exists(dbname):
        df.to_sql(dbname, sqlitedb.connection, index=False)
        return
    for col in df.columns:
        sqlitedb.add_column_if_not_exists(dbname, col, 'TEXT')
    for row in df.itertuples(index=False):
        row_dict = row._asdict()
        # The same key condition is used for the existence check and the update.
        key_condition = f"ds = {_quote(row.ds)} and created_dt = {_quote(end_time)}"
        existing = sqlitedb.select_data(dbname, where_condition=key_condition)
        if len(existing) > 0:
            set_clause = ", ".join(
                f"{key} = {_quote(value)}" for key, value in row_dict.items())
            sqlitedb.update_data(dbname, set_clause, where_condition=key_condition)
        else:
            sqlitedb.insert_data(dbname, tuple(row_dict.values()),
                                 columns=row_dict.keys())
class SQLiteHandler:
    '''
    Minimal wrapper around one sqlite3 connection and its cursor.

    Call ``connect()`` before any other method, or use the instance as a
    context manager (``with SQLiteHandler(path) as db: ...``). Every write
    helper commits immediately.

    Security: table names, column names and the WHERE / SET / ORDER BY
    fragments are spliced into the SQL text verbatim, so they must come from
    trusted code. Only ``insert_data`` values, ``execute_query`` params and
    the ``check_table_exists`` name are bound as parameters.
    '''

    def __init__(self, db_name):
        self.db_name = db_name       # path handed to sqlite3.connect
        self.connection = None
        self.cursor = None

    def __enter__(self):
        if self.connection is None:
            self.connect()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False                 # never suppress exceptions

    def connect(self):
        self.connection = sqlite3.connect(self.db_name)
        self.cursor = self.connection.cursor()

    def close(self):
        if self.connection is not None:
            self.connection.close()
            self.connection = None
            self.cursor = None

    def execute_query(self, query, params=None):
        '''Execute *query* with optional bound *params*; return the cursor.'''
        if params:
            return self.cursor.execute(query, params)
        return self.cursor.execute(query)

    def commit(self):
        self.connection.commit()

    def create_table(self, table_name, columns):
        '''Create *table_name* if missing; *columns* is the raw column-definition SQL.'''
        query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})"
        self.execute_query(query)
        self.commit()

    def insert_data(self, table_name, values, columns=None):
        '''Insert one row; *values* are bound as parameters, *columns* optional.'''
        placeholders = ', '.join(['?'] * len(values))
        if columns:
            query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
        else:
            query = f"INSERT INTO {table_name} VALUES ({placeholders})"
        self.execute_query(query, values)
        self.commit()

    def select_data(self, table_name, columns=None, where_condition=None, order_by=None, limit=None):
        '''Run a SELECT and return the rows as a DataFrame (empty DataFrame if no rows).'''
        query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table_name}"
        if where_condition:
            query += f" WHERE {where_condition}"
        if order_by:
            query += f" ORDER BY {order_by}"
        if limit:
            query += f" LIMIT {limit}"
        cursor = self.execute_query(query)
        results = cursor.fetchall()
        if not results:
            return pd.DataFrame()
        # Column names come from the same execution; no need to run the query twice.
        headers = [description[0] for description in cursor.description]
        return pd.DataFrame(results, columns=headers)

    def update_data(self, table_name, set_values, where_condition):
        '''Run ``UPDATE table SET <set_values> WHERE <where_condition>`` and commit.'''
        query = f"UPDATE {table_name} SET {set_values} WHERE {where_condition}"
        self.execute_query(query)
        self.commit()

    def delete_data(self, table_name, where_condition):
        '''Run ``DELETE FROM table WHERE <where_condition>`` and commit.'''
        query = f"DELETE FROM {table_name} WHERE {where_condition}"
        self.execute_query(query)
        self.commit()

    def check_table_exists(self, table_name):
        '''Return True if a table named *table_name* exists.'''
        # Bound parameter: the name is data here, so it must not be spliced in.
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        return self.execute_query(query, (table_name,)).fetchone() is not None

    def drop_table(self, table_name):
        query = f"DROP TABLE IF EXISTS {table_name}"
        self.execute_query(query)
        self.commit()

    def add_column_if_not_exists(self, table_name, column_name, column_type):
        '''Add *column_name* of *column_type* to *table_name* unless it already exists.'''
        self.execute_query(f"PRAGMA table_info({table_name})")
        columns = [column[1] for column in self.cursor.fetchall()]  # index 1 = column name
        if column_name not in columns:
            query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
            self.execute_query(query)
            self.commit()
            print(
                f"Column '{column_name}' added to table '{table_name}' successfully.")
        else:
            print(
                f"Column '{column_name}' already exists in table '{table_name}'.")
class MySQLDB:
    '''
    Small wrapper around a pymysql connection that uses DictCursor (rows are dicts).

    Errors are logged rather than raised. If ``connect`` fails, ``connection``
    stays None. ``execute_query`` returns None on error. The write helpers
    roll back on error.
    '''

    def __init__(self, host, user, password, database):
        self.host = host
        self.user = user
        self.password = password
        self.database = database
        self.connection = None
        self.cursor = None

    def connect(self):
        '''Open the connection and a shared DictCursor; log (do not raise) on failure.'''
        try:
            self.connection = pymysql.connect(
                host=self.host,
                user=self.user,
                password=self.password,
                database=self.database,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            self.cursor = self.connection.cursor()
            logging.info("Connected to the database successfully.")
        except pymysql.Error as e:
            logging.error(f"Error connecting to the database: {e}")

    def is_connected(self):
        '''True if a connection exists and answers a ping (pymysql may reconnect).'''
        if self.connection is None:
            # Never connected, or closed: nothing to ping.
            return False
        try:
            self.connection.ping(reconnect=True)
            return True
        except pymysql.Error:
            return False

    def execute_query(self, query):
        '''Run *query* on the shared cursor; return all rows, or None on error.'''
        try:
            self.cursor.execute(query)
            return self.cursor.fetchall()
        except pymysql.Error as e:
            logging.error(f"Error executing query: {e}")
            return None

    def execute_insert(self, query, values):
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Insert operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing insert: {e}")
            self.connection.rollback()

    def execute_batch_insert(self, query, params_list):
        """
        Batch insert data
        :param query: SQL insert statement
        :param params_list: Parameter list, each element is a parameter for a record
        :return: Number of affected rows (0 if not connected or on error)
        """
        if not self.is_connected():
            print("Database is not connected, please call the connect method first")
            return 0
        try:
            # Dedicated cursor, closed by the with-block even if executemany fails.
            with self.connection.cursor() as cursor:
                logging.info(f"Executing batch insert SQL: {query}")
                # Full parameter dumps can be huge; keep them at debug level.
                logging.debug(f"Batch insert parameters: {params_list}")
                cursor.executemany(query, params_list)
                self.connection.commit()
                return cursor.rowcount
        except pymysql.Error as err:
            print(f"Batch insert failed: {err}")
            if self.is_connected():
                self.connection.rollback()
            return 0

    def execute_update(self, query, values):
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Update operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing update: {e}")
            self.connection.rollback()

    def execute_delete(self, query, values):
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Delete operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing delete: {e}")
            self.connection.rollback()

    def close(self):
        '''Close cursor and connection; safe to call more than once.'''
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
        if self.connection is not None:
            self.connection.close()
            self.connection = None
            logging.info("Database connection closed.")
def exception_logger(func):
    '''Decorator: log any exception raised by *func* (with traceback) and re-raise it unchanged.'''
    import functools

    @functools.wraps(func)  # keep the wrapped function's __name__ / __doc__
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logging.error(
                f"An error occurred in function {func.__name__}: {str(e)}",
                exc_info=True)
            # Bare raise re-raises the active exception without adding a frame.
            raise
    return wrapper
def get_week_date(end_time):
    '''
    Return ``(create_dates, ds_dates)`` as lists of 'YYYY-MM-DD' strings.

    Let M be the Monday two weeks before the week containing *end_time*
    (a 'YYYY-MM-DD' string).
    - create_dates: the 7 consecutive days M+4 .. M+10, i.e. Friday of that
      week through Thursday of the following week (weekend days included).
    - ds_dates: M+7 .. M+11, i.e. Monday through Friday of last week.
    '''
    import datetime
    fmt = '%Y-%m-%d'
    anchor = datetime.datetime.strptime(end_time, fmt)
    monday_two_weeks_ago = anchor - datetime.timedelta(days=anchor.weekday() + 14)

    def day_after_monday(offset):
        return (monday_two_weeks_ago + datetime.timedelta(days=offset)).strftime(fmt)

    create_dates = [day_after_monday(k) for k in range(4, 11)]
    ds_dates = [day_after_monday(k) for k in range(7, 12)]
    return create_dates, ds_dates
def get_bdwd_date(date=''):
    '''
    Compute the target dates of the forecast horizons for *date*.

    :param date: 'YYYY-MM-DD' string; empty (or any falsy value) means today.
    :return: dict of 'YYYY-MM-DD' strings with these keys, in this order:
        'tomorrow'                   -- date + 1 day
        'five_working_days_later'    -- the 5th Mon-Fri day after date
                                        (public holidays are not considered)
        'next_next_sunday'           -- first Sunday strictly after date, + 7 days
        'next_next_next_sunday'      -- first Sunday strictly after date, + 14 days
        'next_two_months_last_day'   -- last day of the month after date's month
        'next_three_months_last_day' -- last day of date's month + 2
        'next_four_months_last_day'  -- last day of date's month + 3
        'next_five_months_last_day'  -- last day of date's month + 4
    Note that the month keys are offset by one from their names.
    '''
    import datetime

    one_day = datetime.timedelta(days=1)
    one_week = datetime.timedelta(days=7)
    if date:
        base = datetime.datetime.strptime(date, '%Y-%m-%d').date()
    else:
        base = datetime.date.today()

    # Step forward one day at a time until five Mon-Fri days have passed.
    workday = base
    remaining = 5
    while remaining:
        workday += one_day
        if workday.weekday() < 5:
            remaining -= 1

    # First Sunday strictly after base (1..7 days ahead; a Sunday base gives +7).
    upcoming_sunday = base + datetime.timedelta(days=(6 - base.weekday()) % 7 or 7)

    # month_ends[k] = last day of the month k months after base's month.
    # Day 28 + 4 days always lands in the following month (day 1-4); stepping
    # back one day from that month's 1st gives the end of the month before it.
    month_ends = []
    probe = base
    for _ in range(5):
        probe = probe.replace(day=28) + datetime.timedelta(days=4)
        month_ends.append(probe.replace(day=1) - one_day)

    def fmt(d):
        return d.strftime('%Y-%m-%d')

    return {
        'tomorrow': fmt(base + one_day),
        'five_working_days_later': fmt(workday),
        'next_next_sunday': fmt(upcoming_sunday + one_week),
        'next_next_next_sunday': fmt(upcoming_sunday + 2 * one_week),
        'next_two_months_last_day': fmt(month_ends[1]),
        'next_three_months_last_day': fmt(month_ends[2]),
        'next_four_months_last_day': fmt(month_ends[3]),
        'next_five_months_last_day': fmt(month_ends[4]),
    }
def get_bdwd_price(date, true_price, global_config):
    '''
    Look up the true 'day_price' / 'week_price' values on *date*.

    Every column listed in ``global_config['price_columns']`` is first
    converted to numeric in place on *true_price* (the caller's DataFrame is
    modified). The values are therefore read after conversion.

    :param date: value matched against ``true_price['ds']``
    :param true_price: DataFrame with a 'ds' column plus the price columns
    :param global_config: dict providing 'price_columns'
    :return: dict mapping 'day_price' / 'week_price' (those that are listed)
        to the value in the first row whose ``ds`` equals *date*
    :raises KeyError: if a listed column is missing from *true_price*
    :raises IndexError: if no row has ``ds == date``
    '''
    bdwd_price = {}
    for wd in global_config['price_columns']:
        true_price[wd] = pd.to_numeric(true_price[wd])
        if wd in ('day_price', 'week_price'):
            bdwd_price[wd] = true_price.loc[true_price['ds'] == date, wd].values[0]
    return bdwd_price
class DeepSeek:
    '''Ask a DeepSeek chat model, through LangChain, to analyse ARIMA forecast text.'''

    def __init__(self, model="deepseek-chat", base_url="https://api.deepseek.com/v1", temperature=0):
        '''
        :param model: chat model name (default "deepseek-chat")
        :param base_url: OpenAI-compatible endpoint (default DeepSeek's v1 API)
        :param temperature: sampling temperature (default 0)
        '''
        self.model = model
        self.base_url = base_url
        self.temperature = temperature

    def summary(self, text):
        '''
        Summarise *text* (ARIMA forecast output) as structured Chinese prose.

        Runs LangChain's summarize chain over a single Document. The API key is
        read from the OPENAI_API_KEY environment variable.

        :return: the chain's 'output_text' (also printed to stdout)
        '''
        prompt_template = '''请根据以下ARIMA预测结果分析未来的趋势
"{text}"
请用专业且结构清晰的中文撰写,重点数据用**加粗**显示
'''
        chinese_prompt = PromptTemplate(
            template=prompt_template, input_variables=['text'])
        docs = [Document(page_content=text, metadata={
            "source": "arima_forecast"})]
        llm = ChatOpenAI(
            model=self.model,
            temperature=self.temperature,
            base_url=self.base_url,
            api_key=os.environ.get('OPENAI_API_KEY')
        )
        chain = load_summarize_chain(llm, prompt=chinese_prompt)
        print('大语言模型分析预测结果')
        summary = chain.invoke({"input_documents": docs})['output_text']
        print('大语言模型分析结果:')
        print(summary)
        return summary
def get_model_id_name_dict(global_config):
    '''
    Return ``{model id: model name}`` for every row of the view v_tbl_predict_models.

    :param global_config: dict whose 'db_mysql' entry provides
        ``execute_query(sql)`` returning dict rows (MySQLDB returns None on error)
    :return: dict; empty when the query fails or returns no rows
    '''
    tb = 'v_tbl_predict_models'
    sql = f'select model_name,id from {tb} '
    # MySQLDB.execute_query returns None on error; treat that as "no models"
    # instead of failing with a TypeError while iterating None.
    rows = global_config['db_mysql'].execute_query(sql) or []
    return {row['id']: row['model_name'] for row in rows}
def get_modelsname(df, global_config):
    '''
    Find which columns of *df* are known model names.

    :param df: DataFrame whose columns may include model names
    :param global_config: dict passed through to get_model_id_name_dict
    :return: ``(model_name_list, model_id_name_dict)``. model_name_list holds
        the columns of *df* that appear as model names in
        v_tbl_predict_models, in df column order and without duplicates.
        model_id_name_dict maps model id -> model name.
    '''
    # One query is enough: the id->name dict already holds every model name.
    model_id_name_dict = get_model_id_name_dict(global_config=global_config)
    known_names = set(model_id_name_dict.values())
    model_name_list = [
        col for col in dict.fromkeys(df.columns.tolist()) if col in known_names]
    return model_name_list, model_id_name_dict
def convert_df_to_pydantic(df_predict, model_id_name_dict, global_config):
    '''
    Convert the first row of a wide prediction DataFrame into PredictionResult objects.

    Every column except 'ds' and 'created_dt' is taken to be a model name.
    Each one becomes a PredictionResult built from a copy of
    ``global_config['DEFAULT_CONFIG']`` plus:
      - data_date: the first ``created_dt`` value (numpy datetime64 -> datetime)
      - model_id: looked up by name in the inverted *model_id_name_dict*
      - predicted_price: the column's first value as a Decimal with 2 decimals

    :return: list of PredictionResult, in column order
    :raises KeyError: if a column is not a known model name
    '''
    name_to_id = {name: model_id for model_id, name in model_id_name_dict.items()}
    base = global_config['DEFAULT_CONFIG'].copy()
    data_date = df_predict['created_dt'].values[0]
    if isinstance(data_date, np.datetime64):
        data_date = pd.Timestamp(data_date).to_pydatetime()
    base['data_date'] = data_date

    results = []
    for c in df_predict.columns:
        if c in ('ds', 'created_dt'):
            continue
        # Build the Decimal from the rounded value's decimal string:
        # Decimal(float) keeps the float's full binary expansion
        # (e.g. 0.1 -> 0.1000000000000000055...) and rejects numpy ints.
        price = Decimal(str(round(float(df_predict[c].values[0]), 2)))
        results.append(PredictionResult(
            **{**base, 'model_id': name_to_id[c], 'predicted_price': price}))
    return results
def convert_df_to_pydantic_pp(df_predict, model_id_name_dict, global_config):
    '''
    Convert the first row of a wide PP prediction DataFrame into PpPredictionResult objects.

    Same conversion as convert_df_to_pydantic, producing PpPredictionResult:
    every column except 'ds' and 'created_dt' is a model name. Each one
    becomes a result built from a copy of ``global_config['DEFAULT_CONFIG']``
    plus data_date (first ``created_dt``), model_id (looked up by name) and
    predicted_price (first value as a Decimal with 2 decimals).

    :return: list of PpPredictionResult, in column order
    :raises KeyError: if a column is not a known model name
    '''
    name_to_id = {name: model_id for model_id, name in model_id_name_dict.items()}
    base = global_config['DEFAULT_CONFIG'].copy()
    data_date = df_predict['created_dt'].values[0]
    if isinstance(data_date, np.datetime64):
        data_date = pd.Timestamp(data_date).to_pydatetime()
    base['data_date'] = data_date

    results = []
    for c in df_predict.columns:
        if c in ('ds', 'created_dt'):
            continue
        # Decimal from the rounded value's string, not from the float itself,
        # so the value really has 2 decimal places.
        price = Decimal(str(round(float(df_predict[c].values[0]), 2)))
        results.append(PpPredictionResult(
            **{**base, 'model_id': name_to_id[c], 'predicted_price': price}))
    return results
def find_best_models(date='', global_config=None):
    '''
    For every price horizon, choose the model whose past predictions came
    closest to the realised price, and fetch that model's prediction made on
    *date*.

    :param date: 'YYYY-MM-DD'; empty means today. A malformed date is logged
        and an empty dict is returned.
    :param global_config: dict with 'db_mysql' (MySQLDB), 'dataset' (directory
        holding 指标数据.csv with 'ds'/'y'), 'logger' and 'price_columns'.
        Each horizon is handled according to its position i in 'price_columns':
          - i == 0: predictions made on the previous business day or on *date*,
            scored against the true price on *date*. Only models whose
            sign(true - predicted) equals the sign of the true price change
            since the previous business day are kept.
          - i == 1: the same, with the comparison day 7 calendar days earlier.
          - i in (2, 3): per-model mean of predictions made 1 or 2 weeks ago,
            against the mean true price from this week's Monday to *date*.
          - i in (4, 5, 6, 7): per-model mean of predictions made in an
            earlier month window, against the mean true price from the 1st
            of this month to *date*.
          - higher positions get an empty entry.
    :return: ``{price_column: {'model_id', 'model_name', 'predictresult', 'date'}}``,
        where 'date' is that horizon's target date and 'predictresult' is
        None when no prediction exists. Returns an empty dict if the CSV is
        missing or no predictions are found.
    :raises IndexError: if *date* (or the i == 0/1 comparison day) is not in the CSV.
    '''
    best_models = {}
    model_id_name_dict = get_model_id_name_dict(global_config=global_config)
    log = global_config['logger']

    # Normalise / validate the input date.
    if not date:
        date = datetime.datetime.now().strftime('%Y-%m-%d')
    else:
        try:
            date = datetime.datetime.strptime(
                date, '%Y-%m-%d').strftime('%Y-%m-%d')
        except ValueError:
            log.error(
                f"日期格式错误,期望格式为 '%Y-%m-%d',实际输入: {date}")
            return best_models
    current_date = datetime.datetime.strptime(date, '%Y-%m-%d')
    first_day_of_month = current_date.replace(day=1)
    date_monday = current_date - datetime.timedelta(days=current_date.weekday())

    # True prices.
    csv_path = os.path.join(global_config['dataset'], '指标数据.csv')
    try:
        true_price = pd.read_csv(csv_path)[['ds', 'y']]
    except FileNotFoundError:
        log.error(f"未找到文件: {csv_path}")
        return best_models

    # Load predictions made since the 1st of the month six months before
    # *date*; the oldest window used below starts at most four months back.
    year, month = map(int, date.split('-')[:2])
    month -= 6
    if month <= 0:  # wrap into the previous year (Jan -> Jul, ..., Jun -> Dec)
        month += 12
        year -= 1
    tb = 'v_tbl_predict_pp_prediction_results'
    sql = f"select * from {tb} where data_date >= '{year}-{month:02d}-01'"
    predictresult = global_config['db_mysql'].execute_query(sql)
    if not predictresult:
        log.info('没有预测结果')
        return best_models
    df = pd.DataFrame(predictresult)[
        ['data_date', 'model_id'] + global_config['price_columns']].copy()
    # Parse once so the filters below (against 'YYYY-MM-DD' strings and
    # datetime objects) compare dates rather than raw driver values.
    df['data_date'] = pd.to_datetime(df['data_date'])
    log.info(f'预测结果数量:{df.shape}')
    log.info(
        f'预测结果日期范围:{df["data_date"].min()}{df["data_date"].max()}')

    def true_price_on(day):
        '''True price 'y' on *day* (IndexError if the day is missing).'''
        return true_price[true_price['ds'] == day]['y'].values[0]

    def query_predict_result(date, model_id, global_config, wd):
        '''Return model *model_id*'s *wd* prediction made on *date* as a float, or None.'''
        if model_id is None:
            # No best model was found, so there is nothing to look up.
            global_config['logger'].info('没有预测结果')
            return None
        tb = 'v_tbl_predict_pp_prediction_results'
        sql = f'select {wd} from {tb} where data_date = \'{date}\' and model_id = {model_id}'
        rows = global_config['db_mysql'].execute_query(sql)
        if not rows:
            global_config['logger'].info('没有预测结果')
            return None
        value = rows[0][wd]
        return None if value is None else float(value)

    def calculate_best_model(price, col, trend, weektrueprice=None, monthtrueprice=None):
        """
        Pick the model whose prediction in column *col* is closest to the true price.

        :param price: DataFrame with 'model_id' and *col*
        :param col: the price column being scored
        :param trend: +1/-1 to keep only rows where sign(true - predicted)
            equals *trend*, or None to skip this filter
        :param weektrueprice: weekly mean true price (used when given)
        :param monthtrueprice: monthly mean true price (used when given and
            no weekly price is given); otherwise the true price on *date*
        :return: (model_id, model_name), or (None, None) if no row qualifies
        """
        price = price.copy()
        price[col] = price[col].astype(float)
        price = price.dropna(subset=[col])
        if weektrueprice is not None:
            true_price_value = weektrueprice
        elif monthtrueprice is not None:
            true_price_value = monthtrueprice
        else:
            true_price_value = true_price_on(date)
        if price.empty:
            return None, None
        diff = true_price_value - price[col]
        price = price.assign(trueprice=true_price_value,
                             trend=np.where(diff > 0, 1, -1),
                             abs=diff.abs())
        if trend is not None:
            price = price[price['trend'] == trend]
        # Also covers an all-NaN 'abs' (e.g. a NaN mean true price): then no
        # row equals the minimum, so check for emptiness after filtering.
        price = price[price['abs'] == price['abs'].min()]
        if price.empty:
            return None, None
        best_model_id = price['model_id'].iloc[0]
        if hasattr(best_model_id, 'item'):
            best_model_id = best_model_id.item()  # numpy scalar -> plain Python value
        return best_model_id, model_id_name_dict.get(best_model_id)

    def store_prediction(wd, best_model_id):
        '''Store the chosen model's *wd* prediction made on *date* (or None).'''
        result = query_predict_result(date, best_model_id, global_config, wd)
        if result is not None:
            log.info(f'最佳模型{best_models[wd]}{date}预测结果:{result}')
        best_models[wd]['predictresult'] = result

    for i, wd in enumerate(global_config['price_columns']):
        log.info(
            f'*********************************************************************************************************计算预测{date}{wd}最佳模型')
        best_models[wd] = {}
        if i == 0:
            # Compare against predictions made on the previous business day.
            ciridate = (pd.Timestamp(date) -
                        pd.tseries.offsets.BusinessDay(1)).strftime('%Y-%m-%d')
            log.info(f'计算预测{date}的次日{ciridate}最佳模型')
            log.info(f'{date}真实价格:{true_price_on(date)}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] == ciridate)
                          | (price['data_date'] == date)]
            trend = 1 if true_price_on(date) - true_price_on(ciridate) > 0 else -1
            best_model_id, best_model_name = calculate_best_model(price, wd, trend)
            best_models[wd]['model_id'] = best_model_id
            best_models[wd]['model_name'] = best_model_name
            log.info(f'{ciridate}预测最准确的模型:{best_model_id}')
            log.info(f'{ciridate}预测最准确的模型名称:{best_models[wd]}')
            store_prediction(wd, best_model_id)
            # Target date: the next business day after *date*.
            best_models[wd]['date'] = (pd.Timestamp(date) +
                                       pd.tseries.offsets.BusinessDay(1)).strftime('%Y-%m-%d')
        elif i == 1:
            # Compare against predictions made 7 calendar days earlier.
            benzhoudate = (pd.Timestamp(date) -
                           pd.Timedelta(days=7)).strftime('%Y-%m-%d')
            log.info(f'计算预测{date}的五天前{benzhoudate}最佳模型')
            log.info(f'{date}真实价格:{true_price_on(date)}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] == benzhoudate)
                          | (price['data_date'] == date)]
            trend = 1 if true_price_on(date) - true_price_on(benzhoudate) > 0 else -1
            best_model_id, best_model_name = calculate_best_model(price, wd, trend)
            best_models[wd]['model_id'] = best_model_id
            best_models[wd]['model_name'] = best_model_name
            log.info(f'{benzhoudate}预测最准确的模型名称:{best_models[wd]}')
            store_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(date) +
                                       pd.tseries.offsets.BusinessDay(5)).strftime('%Y-%m-%d')
        elif i in [2, 3]:
            # Mean of predictions made 1 (i == 2) or 2 (i == 3) weeks ago vs
            # this week's mean true price.
            weeks_ago = 1 if i == 2 else 2
            ago_monday = current_date - \
                datetime.timedelta(days=current_date.weekday() + 7 * weeks_ago)
            ago_sunday = ago_monday + datetime.timedelta(days=6)
            ago_date_str = f"{ago_monday.strftime('%Y-%m-%d')} - {ago_sunday.strftime('%Y-%m-%d')}"
            log.info(f'计算预测{date}的前{weeks_ago}{ago_date_str}最佳模型')
            weektrueprice = true_price[(true_price['ds'] >= date_monday.strftime(
                '%Y-%m-%d')) & (true_price['ds'] <= date)]['y'].mean()
            log.info(
                f'当周{date_monday.strftime("%Y-%m-%d")}---{date}真实价格的周均价:{weektrueprice}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] >= ago_monday) &
                          (price['data_date'] <= ago_sunday)]
            price = price.groupby('model_id')[wd].mean().reset_index()
            best_model_id, best_model_name = calculate_best_model(
                price, wd, None, weektrueprice=weektrueprice)
            best_models[wd]['model_id'] = best_model_id
            best_models[wd]['model_name'] = best_model_name
            log.info(f'{ago_date_str}预测最准确的模型名称:{best_models[wd]}')
            store_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(ago_sunday) +
                                       pd.tseries.offsets.Week(weeks_ago * 2)).strftime('%Y-%m-%d')
        elif i in [4, 5, 6, 7]:
            # Mean of predictions made in an earlier month window vs this
            # month's mean true price.
            months_ago = i - 3
            current_date_ts = pd.Timestamp(date)
            last_month_first_day = (
                current_date_ts - pd.offsets.MonthBegin(months_ago)).strftime('%Y-%m-%d')
            last_month_last_day = (pd.Timestamp(
                last_month_first_day) + pd.offsets.MonthEnd(0)).strftime('%Y-%m-%d')
            log.info(
                f'计算预测{date}{months_ago}月前{last_month_first_day}-{last_month_last_day}最佳模型')
            monthtrueprice = true_price[(true_price['ds'] >= first_day_of_month.strftime(
                '%Y-%m-%d')) & (true_price['ds'] <= date)]['y'].mean()
            log.info(
                f'当月{first_day_of_month.strftime("%Y-%m-%d")}-{date}真实价格的月均价:{monthtrueprice}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] >= last_month_first_day) & (
                price['data_date'] <= last_month_last_day)]
            price = price.groupby('model_id')[wd].mean().reset_index()
            best_model_id, best_model_name = calculate_best_model(
                price, wd, None, monthtrueprice=monthtrueprice)
            best_models[wd]['model_id'] = best_model_id
            best_models[wd]['model_name'] = best_model_name
            log.info(
                f'{last_month_first_day}-{last_month_last_day}预测最准确的模型名称:{best_models[wd]}')
            store_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(date) +
                                       pd.tseries.offsets.MonthEnd(months_ago + 1)).strftime('%Y-%m-%d')
    return best_models
def plot_pp_predict_result(y_hat, global_config):
    """
    Plot the PP futures forecast against recent true prices and save the chart.

    Draws the forecast (dashed orange) and the last 30 true prices before the
    earliest forecast date (blue). A one-row table of the forecasts is placed
    under the chart, and the figure is saved as 'pp_predict_result.png' in the
    working directory.

    :param y_hat: DataFrame with 'ds' and 'predictresult' columns. Its index
        becomes the table header after renaming day_price/week_price/... to
        their Chinese labels (presumably built from find_best_models output —
        confirm with the caller).
    :param global_config: dict providing 'dataset' (directory containing
        指标数据.csv) and 'end_time' (used in the title)
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    y = pd.read_csv(os.path.join(
        global_config['dataset'], '指标数据.csv'))[['ds', 'y']]
    y['ds'] = pd.to_datetime(y['ds'])
    # Sort by date so later dates are on the right.
    y_hat = y_hat.sort_values(by='ds')
    # Cut off at the earliest forecast date (not merely the first row), then
    # keep the 30 most recent true prices in date order.
    forecast_start = pd.to_datetime(y_hat['ds']).min()
    y = y[y['ds'] < forecast_start].sort_values(by='ds').tail(30)

    fig, ax = plt.subplots(figsize=(16, 9))
    try:
        sns.lineplot(x=y_hat['ds'], y=y_hat['predictresult'],
                     color='orange', label='y_hat', ax=ax, linestyle='--')
        sns.lineplot(x=y['ds'], y=y['y'], color='blue', label='y', ax=ax)
        ax.set_title(f'{global_config["end_time"]} PP期货八大维度 预测价格走势')
        ax.set_xlabel('日期')
        ax.set_ylabel('预测结果')
        ax.tick_params(axis='x', rotation=45)

        # One-row table: one column per horizon (y_hat's index), one forecast value each.
        table_df = y_hat[['predictresult']].T
        print(table_df)
        table_df = table_df.rename(columns={'day_price': '次日', 'week_price': '本周',
                                            'second_week_price': '次周', 'next_week_price': '隔周',
                                            'next_month_price': '次月', 'next_february_price': '次二月',
                                            'next_march_price': '次三月', 'next_april_price': '次四月',
                                            })
        columns = table_df.columns.tolist()
        data = table_df.values.tolist()
        for row in data:
            if isinstance(row[0], pd.Timestamp):
                row[0] = row[0].strftime('%Y-%m-%d')
        table = ax.table(cellText=data, colLabels=columns,
                         loc='bottom', bbox=[0, -0.6, 1, 0.2])
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        plt.tight_layout(rect=[0, 0.1, 1, 1])  # leave room for the table
        plt.savefig('pp_predict_result.png')
    finally:
        # Release the figure; pyplot keeps every open figure alive otherwise.
        plt.close(fig)
def _main():
    '''This module only provides helpers; running it directly prints a notice.'''
    print('This is a tool, not a script.')


if __name__ == '__main__':
    _main()