"""Shared utilities: request signing, PDF helpers, forecast metrics, mail
sending, SQLite/MySQL access, date helpers and best-model selection."""
import base64
import datetime
import functools
import hashlib
import hmac
import logging
import os
import random
import secrets
import smtplib
import sqlite3
import string
import time
import tkinter as tk
from decimal import Decimal
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from tkinter import messagebox

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymysql
import seaborn as sns
from dotenv import load_dotenv
from langchain.chains.summarize import load_summarize_chain
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from reportlab.graphics.charts.barcharts import VerticalBarChart  # bar chart class
from reportlab.graphics.charts.legends import Legend  # legend class
from reportlab.graphics.shapes import Drawing  # drawing canvas
from reportlab.lib import colors  # colour constants
from reportlab.lib.pagesizes import letter  # page size (8.5*inch, 11*inch)
from reportlab.lib.styles import getSampleStyleSheet  # text styles
from reportlab.lib.units import cm  # unit: cm
from reportlab.pdfbase import pdfmetrics  # font registration
from reportlab.pdfbase.ttfonts import TTFont  # TrueType font class
from reportlab.platypus import Table, SimpleDocTemplate, Paragraph, Image  # report flowables
from scipy.stats import spearmanr
from sklearn import metrics

from lib.pydantic_models import PredictionResult, PpPredictionResult

load_dotenv()

# Module logger used by `timeit` and `SendMail`. Callers may still replace it
# by assigning `<module>.logger = ...`; both look the name up at call time.
logger = logging.getLogger(__name__)


def timeit(func):
    """Decorator that logs how long each successful call of *func* took."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        logger.info(f"{func.__name__} 函数的执行时间为: {execution_time} 秒")
        return result
    return wrapper


class BinanceAPI:
    """Build the HMAC-SHA256 signature fields for an API request header.

    After construction the instance exposes ``nonce``, ``timestamp``,
    ``sign_str`` and ``signature``.
    """

    def __init__(self, APPID, SECRET):
        self.APPID = APPID
        self.SECRET = SECRET
        self.get_signature()

    def generate_nonce(self, length=32):
        """Return (and store on ``self.nonce``) a random alphanumeric nonce."""
        alphabet = string.ascii_letters + string.digits
        # `secrets`, not `random`: a signing nonce must be unpredictable.
        self.nonce = ''.join(secrets.choice(alphabet) for _ in range(length))
        return self.nonce

    def get_timestamp(self):
        """Return the current Unix time in whole seconds."""
        return int(time.time())

    def build_sign_str(self):
        """Return the canonical ``appid=..&nonce=..&timestamp=..`` string to sign."""
        return f'appid={self.APPID}&nonce={self.nonce}&timestamp={self.timestamp}'

    def calculate_signature(self, secret, message):
        """Return the URL-safe base64 HMAC-SHA256 of *message* keyed by *secret*."""
        digest = hmac.new(secret.encode('utf-8'), message.encode('utf-8'),
                          hashlib.sha256).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8')

    def get_signature(self):
        """Regenerate nonce and timestamp, then recompute ``self.signature``."""
        self.nonce = self.generate_nonce()
        self.timestamp = self.get_timestamp()
        self.sign_str = self.build_sign_str()
        self.signature = self.calculate_signature(self.SECRET, self.sign_str)


class Graphs:
    """Factory of reportlab flowables for PDF reports.

    NOTE(review): every style uses the 'SimSun' font, which this file never
    registers; presumably the caller registers it via pdfmetrics — confirm.
    """

    @staticmethod
    def draw_title(title: str):
        """Return a centred, green, 18pt Heading1 paragraph."""
        style = getSampleStyleSheet()
        ct = style['Heading1']
        ct.fontName = 'SimSun'
        ct.fontSize = 18
        ct.leading = 50  # line spacing
        ct.textColor = colors.green
        ct.alignment = 1  # centre
        ct.bold = True
        return Paragraph(title, ct)

    @staticmethod
    def draw_little_title(title: str):
        """Return a red 15pt sub-title paragraph."""
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 15
        ct.leading = 30
        ct.textColor = colors.red
        return Paragraph(title, ct)

    @staticmethod
    def draw_text(text: str):
        """Return a left-aligned 12pt body paragraph with CJK word wrapping."""
        style = getSampleStyleSheet()
        ct = style['Normal']
        ct.fontName = 'SimSun'
        ct.fontSize = 12
        ct.wordWrap = 'CJK'  # wrap on CJK characters
        ct.alignment = 0  # left
        ct.firstLineIndent = 32  # indent the first line
        ct.leading = 25
        return Paragraph(text, ct)

    @staticmethod
    def draw_table(col_width, *args):
        """Return a Table whose rows are *args*; row 0 is styled as a header."""
        style = [
            ('FONTNAME', (0, 0), (-1, -1), 'SimSun'),
            ('FONTSIZE', (0, 0), (-1, 0), 10),  # header row font size
            ('FONTSIZE', (0, 1), (-1, -1), 8),  # body rows font size
            ('BACKGROUND', (0, 0), (-1, 0), '#d5dae6'),  # header background
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),  # centre everything ...
            ('ALIGN', (0, 1), (-1, -1), 'LEFT'),  # ... then left-align body rows
            ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
            ('TEXTCOLOR', (0, 0), (-1, -1), colors.darkslategray),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),  # 0.5pt grey grid
        ]
        return Table(args, colWidths=col_width, style=style)

    @staticmethod
    def draw_bar(bar_data: list, ax: list, items: list,
                 value_min=5000, value_max=26000, value_step=2000):
        """Return a 500x250 Drawing with a vertical bar chart and a legend.

        :param bar_data: data series for the chart (one sequence per series)
        :param ax: category (x-axis) labels
        :param items: legend ``colorNamePairs`` entries
        :param value_min: y-axis minimum (default keeps the historical 5000)
        :param value_max: y-axis maximum (default keeps the historical 26000)
        :param value_step: y-axis tick step (default keeps the historical 2000)
        """
        drawing = Drawing(500, 250)
        bc = VerticalBarChart()
        bc.x = 45
        bc.y = 45
        bc.height = 200
        bc.width = 350
        bc.data = bar_data
        bc.strokeColor = colors.black
        bc.valueAxis.valueMin = value_min
        bc.valueAxis.valueMax = value_max
        bc.valueAxis.valueStep = value_step
        bc.categoryAxis.labels.dx = 2
        bc.categoryAxis.labels.dy = -8
        bc.categoryAxis.labels.angle = 20
        bc.categoryAxis.categoryNames = ax

        leg = Legend()
        leg.fontName = 'SimSun'
        leg.alignment = 'right'
        leg.boxAnchor = 'ne'
        leg.x = 475
        leg.y = 240
        leg.dxTextSpace = 10
        leg.columnMaximum = 3
        leg.colorNamePairs = items
        drawing.add(leg)
        drawing.add(bc)
        return drawing

    @staticmethod
    def draw_img(path):
        """Return the image at *path* scaled to 20cm x 10cm."""
        img = Image(path)
        img.drawWidth = 20 * cm
        img.drawHeight = 10 * cm
        return img


# Evaluation metrics. Inputs may be any array-like of equal length.

def mse(y_true, y_pred):
    """Mean squared error."""
    return metrics.mean_squared_error(y_true, y_pred)


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return np.sqrt(metrics.mean_squared_error(y_true, y_pred))


def mae(y_true, y_pred):
    """Mean absolute error."""
    return metrics.mean_absolute_error(y_true, y_pred)


def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent.

    A zero in *y_true* yields ``inf``/``nan`` (division by zero).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs((y_pred - y_true) / y_true)) * 100


def smape(y_true, y_pred):
    """Symmetric mean absolute percentage error, in percent.

    A position where both values are 0 yields ``nan``.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 2.0 * np.mean(np.abs(y_pred - y_true) / (np.abs(y_pred) + np.abs(y_true))) * 100


def _plot_corr_hist(values, bins, title, filename):
    """Save a bar histogram of correlation *values* bucketed by *bins*."""
    # np.histogram closes the last bin, so a correlation of exactly 1.0 counts.
    counts, _ = np.histogram(np.asarray(values, dtype=float), bins=bins)
    plt.figure(figsize=(10, 6))
    plt.bar(bins[:-1], counts, width=(bins[1] - bins[0]), align='edge')
    plt.title(title)
    plt.xlabel('区间')
    plt.ylabel('统计数')
    plt.savefig(filename)
    plt.close()


def plot_corr(data, size=11):
    """Correlate every feature with column ``y`` and save the results.

    Writes '指标相关性分析.csv' (Pearson rounded to 3 d.p., Spearman to 2 d.p.;
    features with an undefined correlation are dropped) plus one histogram
    PNG each for the Pearson and Spearman coefficients over [-1, 1] in 20 bins.

    :param data: DataFrame with a ``ds`` column (ignored; no longer dropped
        in place from the caller's frame) and a target column ``y``
    :param size: unused; kept for backward compatibility
    """
    features = data.drop(columns=['ds'])
    rows = []
    for col in features.columns:
        if col == 'y':
            continue
        pearson = np.corrcoef(features[col], features['y'])[0, 1]
        spearman, _ = spearmanr(features[col], features['y'])
        rows.append({'Feature': col,
                     'Pearson_Correlation': round(pearson, 3),
                     'Spearman_Correlation': round(spearman, 2)})
    correlation_df = pd.DataFrame(
        rows, columns=['Feature', 'Pearson_Correlation', 'Spearman_Correlation'])
    correlation_df.dropna(inplace=True)
    correlation_df.to_csv('指标相关性分析.csv', index=False)

    bins = np.linspace(-1, 1, 21)  # 20 equal-width buckets over [-1, 1]
    _plot_corr_hist(correlation_df['Pearson_Correlation'], bins,
                    '皮尔逊相关系数分布图', '皮尔逊相关性系数.png')
    _plot_corr_hist(correlation_df['Spearman_Correlation'], bins,
                    '斯皮尔曼相关系数分布图', '斯皮尔曼相关性系数.png')


class SendMail(object):
    """Send one e-mail, optionally with a single attachment, over SMTP."""

    def __init__(self, username, passwd, recv, title, content, file=None, ssl=False,
                 email_host='smtp.qq.com', port=25, ssl_port=465):
        """
        :param username: SMTP login, also used as the sender address
        :param passwd: SMTP password / authorisation code
        :param recv: one recipient address (str) or a list of addresses
        :param title: subject line
        :param content: plain-text body
        :param file: attachment path (absolute if not in the CWD); None for none
        :param ssl: use SMTP over SSL instead of plain SMTP
        :param email_host: SMTP server host, default smtp.qq.com
        :param port: plain SMTP port, default 25
        :param ssl_port: SSL SMTP port, default 465
        """
        self.username = username
        self.passwd = passwd
        self.recv = recv
        self.title = title
        self.content = content
        self.file = file
        self.email_host = email_host
        self.port = port
        self.ssl = ssl
        self.ssl_port = ssl_port

    def send_mail(self):
        """Build and send the message.

        Raises if the attachment cannot be read or the login fails; a failure
        in ``sendmail`` itself is printed and logged, not raised. The SMTP
        connection is always closed.
        """
        msg = MIMEMultipart()
        if self.file:
            file_name = os.path.split(self.file)[-1]  # bare file name only
            try:
                with open(self.file, 'rb') as fh:
                    payload = fh.read()
            except OSError as e:
                raise Exception('附件打不开!!!!') from e
            att = MIMEApplication(payload)  # application/octet-stream, base64
            # RFC 2047 encoded-word so non-ASCII (e.g. Chinese) names survive.
            new_file_name = '=?utf-8?b?' + \
                base64.b64encode(file_name.encode()).decode() + '?='
            att["Content-Disposition"] = 'attachment; filename="%s"' % (new_file_name)
            msg.attach(att)
        msg.attach(MIMEText(self.content))

        # A bare string would otherwise be joined/iterated character by character.
        recipients = [self.recv] if isinstance(self.recv, str) else list(self.recv)
        msg['Subject'] = self.title
        msg['From'] = self.username
        msg['To'] = ','.join(recipients)

        if self.ssl:
            self.smtp = smtplib.SMTP_SSL(self.email_host, port=self.ssl_port)
        else:
            self.smtp = smtplib.SMTP(self.email_host, port=self.port)
        try:
            self.smtp.login(self.username, self.passwd)
            try:
                self.smtp.sendmail(self.username, recipients, msg.as_string())
            except Exception as e:
                print('出错了。。', e)
                logger.error('邮件服务出错了。。%s', e)
            else:
                print('发送成功!')
        finally:
            try:
                self.smtp.quit()
            except smtplib.SMTPException:
                self.smtp.close()  # server already gone; just drop the socket


def dateConvert(df, datecol='ds'):
    """Parse *datecol* as 'YYYY-MM-DD', falling back to 'YYYY/MM/DD'.

    Mutates and returns *df*.
    """
    try:
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y-%m-%d')
    except (ValueError, TypeError):
        df[datecol] = pd.to_datetime(df[datecol], format=r'%Y/%m/%d')
    return df


def save_to_database(sqlitedb, df, dbname, end_time):
    """Upsert the rows of *df* into SQLite table *dbname*.

    The table is created from *df* if missing. Otherwise any missing columns
    are added as TEXT, and each row is updated where
    ``ds = row.ds and created_dt = end_time`` matches, or inserted if nothing
    matches.

    :param sqlitedb: connected SQLiteHandler
    :param df: DataFrame with a ``ds`` column; datetime ``ds`` is written as
        'YYYY-MM-DD' (this mutates *df*)
    :param dbname: table name (interpolated into SQL; must be trusted)
    :param end_time: value matched against ``created_dt``
    """
    if pd.api.types.is_datetime64_any_dtype(df['ds']):
        df['ds'] = df['ds'].dt.strftime('%Y-%m-%d')
    if not sqlitedb.check_table_exists(dbname):
        df.to_sql(dbname, sqlitedb.connection, index=False)
        return

    columns = list(df.columns)
    for col in columns:
        sqlitedb.add_column_if_not_exists(dbname, col, 'TEXT')

    key_condition = "ds = ? and created_dt = ?"
    set_clause = ", ".join(f"{col} = ?" for col in columns)
    # name=None yields plain tuples, so column names are never mangled.
    for values in df.itertuples(index=False, name=None):
        values = tuple(values)
        key = (dict(zip(columns, values))['ds'], end_time)
        existing = sqlitedb.select_data(dbname, where_condition=key_condition, params=key)
        if len(existing) > 0:
            sqlitedb.update_data(dbname, set_clause, where_condition=key_condition,
                                 params=values + key)
        else:
            sqlitedb.insert_data(dbname, values, columns=columns)


class SQLiteHandler:
    """Thin wrapper over a sqlite3 connection and a single cursor.

    Table/column names and SQL fragments are interpolated verbatim and must be
    trusted; pass values through the ``params`` arguments (``?`` placeholders).
    """

    def __init__(self, db_name):
        self.db_name = db_name
        self.connection = None
        self.cursor = None

    def connect(self):
        """Open the database file and create the shared cursor."""
        self.connection = sqlite3.connect(self.db_name)
        self.cursor = self.connection.cursor()

    def close(self):
        """Close the connection if open."""
        if self.connection:
            self.connection.close()
            self.connection = None
            self.cursor = None

    def execute_query(self, query, params=None):
        """Execute *query* (with optional *params*) and return the cursor."""
        if params:
            return self.cursor.execute(query, params)
        return self.cursor.execute(query)

    def commit(self):
        self.connection.commit()

    def create_table(self, table_name, columns):
        """CREATE TABLE IF NOT EXISTS with a raw *columns* definition string."""
        self.execute_query(f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})")
        self.commit()

    def insert_data(self, table_name, values, columns=None):
        """Insert one row of *values*, optionally into named *columns*."""
        placeholders = ', '.join(['?'] * len(values))
        if columns:
            query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
        else:
            query = f"INSERT INTO {table_name} VALUES ({placeholders})"
        self.execute_query(query, values)
        self.commit()

    def select_data(self, table_name, columns=None, where_condition=None, order_by=None,
                    limit=None, params=None):
        """Run a SELECT and return the rows as a DataFrame (empty if no rows).

        :param params: values for ``?`` placeholders in *where_condition*
        """
        query = f"SELECT {', '.join(columns) if columns else '*'} FROM {table_name}"
        if where_condition:
            query += f" WHERE {where_condition}"
        if order_by:
            query += f" ORDER BY {order_by}"
        if limit:
            query += f" LIMIT {limit}"
        cursor = self.execute_query(query, params)
        results = cursor.fetchall()
        if not results:
            return pd.DataFrame()
        # Reuse this execution's description instead of running the query again.
        headers = [description[0] for description in cursor.description]
        return pd.DataFrame(results, columns=headers)

    def update_data(self, table_name, set_values, where_condition, params=None):
        """UPDATE with raw SET/WHERE fragments; *params* fill their ``?``s."""
        query = f"UPDATE {table_name} SET {set_values} WHERE {where_condition}"
        self.execute_query(query, params)
        self.commit()

    def delete_data(self, table_name, where_condition, params=None):
        """DELETE rows matching *where_condition*; *params* fill its ``?``s."""
        query = f"DELETE FROM {table_name} WHERE {where_condition}"
        self.execute_query(query, params)
        self.commit()

    def check_table_exists(self, table_name):
        """Return True if a table named *table_name* exists."""
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        return self.execute_query(query, (table_name,)).fetchone() is not None

    def drop_table(self, table_name):
        self.execute_query(f"DROP TABLE IF EXISTS {table_name}")
        self.commit()

    def add_column_if_not_exists(self, table_name, column_name, column_type):
        """ALTER TABLE ADD COLUMN unless *column_name* already exists."""
        self.execute_query(f"PRAGMA table_info({table_name})")
        existing = [column[1] for column in self.cursor.fetchall()]
        if column_name not in existing:
            self.execute_query(
                f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}")
            self.commit()
            print(f"Column '{column_name}' added to table '{table_name}' successfully.")
        else:
            print(f"Column '{column_name}' already exists in table '{table_name}'.")


class MySQLDB:
    """pymysql wrapper returning rows as dicts (DictCursor); errors are logged."""

    def __init__(self, host, user, password, database):
        self.host = host
        self.user = user
        self.password = password
        self.database = database
        self.connection = None
        self.cursor = None

    def connect(self):
        """Open the connection; on failure log and leave it as None."""
        try:
            self.connection = pymysql.connect(
                host=self.host,
                user=self.user,
                password=self.password,
                database=self.database,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            self.cursor = self.connection.cursor()
            logging.info("Connected to the database successfully.")
        except pymysql.Error as e:
            logging.error(f"Error connecting to the database: {e}")

    def is_connected(self):
        """Ping (reconnecting if needed); False if never connected or unreachable."""
        if self.connection is None:
            return False
        try:
            self.connection.ping(reconnect=True)
            return True
        except pymysql.Error:
            return False

    def execute_query(self, query):
        """Run *query* and return all rows, or None on a database error."""
        try:
            self.cursor.execute(query)
            return self.cursor.fetchall()
        except pymysql.Error as e:
            logging.error(f"Error executing query: {e}")
            return None

    def execute_insert(self, query, values):
        """Execute a parameterised INSERT and commit; roll back on error."""
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Insert operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing insert: {e}")
            self.connection.rollback()

    def execute_batch_insert(self, query, params_list):
        """
        Batch insert data
        :param query: SQL insert statement
        :param params_list: Parameter list, each element is a parameter for a record
        :return: Number of affected rows (0 on failure or when not connected)
        """
        if not self.is_connected():
            print("Database is not connected, please call the connect method first")
            return 0
        try:
            # Context manager closes the cursor even when executemany fails.
            with self.connection.cursor() as cursor:
                logging.info(f"Executing batch insert SQL: {query}")
                logging.info(f"Batch insert parameters: {params_list}")
                cursor.executemany(query, params_list)
                affected_rows = cursor.rowcount
            self.connection.commit()
            return affected_rows
        except pymysql.Error as err:
            print(f"Batch insert failed: {err}")
            if self.is_connected():
                self.connection.rollback()
            return 0

    def execute_update(self, query, values):
        """Execute a parameterised UPDATE and commit; roll back on error."""
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Update operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing update: {e}")
            self.connection.rollback()

    def execute_delete(self, query, values):
        """Execute a parameterised DELETE and commit; roll back on error."""
        try:
            self.cursor.execute(query, values)
            self.connection.commit()
            logging.info("Delete operation successful.")
        except pymysql.Error as e:
            logging.error(f"Error executing delete: {e}")
            self.connection.rollback()

    def close(self):
        """Close the cursor and the connection if they exist."""
        if self.cursor:
            self.cursor.close()
        if self.connection:
            self.connection.close()
            logging.info("Database connection closed.")


def exception_logger(func):
    """Decorator: log any exception raised by *func*, then re-raise it unchanged."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logging.error(f"An error occurred in function {func.__name__}: {str(e)}")
            raise  # bare raise keeps the original traceback
    return wrapper


def get_week_date(end_time):
    """Return ``(create_dates, ds_dates)`` relative to the week of *end_time*.

    :param end_time: 'YYYY-MM-DD'
    :return: create_dates - 7 consecutive dates from the Friday two weeks
        before *end_time*'s week through last week's Thursday;
        ds_dates - Monday..Friday of last week. All as 'YYYY-MM-DD'.
    """
    end_day = datetime.datetime.strptime(end_time, '%Y-%m-%d')
    # Monday of the week two weeks before end_time's week.
    start_monday = end_day - datetime.timedelta(days=end_day.weekday() + 14)
    fortnight = [start_monday + datetime.timedelta(days=i) for i in range(14)]
    create_dates = [d.strftime('%Y-%m-%d') for d in fortnight[4:-3]]
    ds_dates = [d.strftime('%Y-%m-%d') for d in fortnight[-7:-2]]
    return create_dates, ds_dates


def get_bdwd_date(date=''):
    """Return the target date of each forecast horizon relative to *date*.

    :param date: 'YYYY-MM-DD'; empty string means today
    :return: dict of 'YYYY-MM-DD' strings with keys
        'tomorrow' (calendar tomorrow), 'five_working_days_later' (Mon-Fri
        counted), 'next_next_sunday' (Sunday of next week),
        'next_next_next_sunday' (Sunday two weeks ahead), and
        'next_two_months_last_day' .. 'next_five_months_last_day'.
        NOTE: the month keys are offset by one - 'next_two_months_last_day'
        is the last day of NEXT month, 'next_three_months_last_day' the
        month after, and so on.
    """
    if not date:
        current_date = datetime.date.today()
    else:
        current_date = datetime.datetime.strptime(date, '%Y-%m-%d').date()

    tomorrow = current_date + datetime.timedelta(days=1)

    five_working_days_later = current_date
    working_days_count = 0
    while working_days_count < 5:
        five_working_days_later += datetime.timedelta(days=1)
        if five_working_days_later.weekday() < 5:  # Mon..Fri
            working_days_count += 1

    # Upcoming Sunday (a full week ahead if today is Sunday).
    days_to_next_sunday = (6 - current_date.weekday()) % 7
    if days_to_next_sunday == 0:
        days_to_next_sunday = 7
    next_sunday = current_date + datetime.timedelta(days=days_to_next_sunday)
    next_next_sunday = next_sunday + datetime.timedelta(days=7)
    next_next_next_sunday = next_next_sunday + datetime.timedelta(days=7)

    # day 28 + 4 days always lands in the following month; the day before that
    # month's 1st is the last day of the month we started from.
    next_month = current_date.replace(day=28) + datetime.timedelta(days=4)
    next_month_last_day = next_month.replace(day=1) - datetime.timedelta(days=1)
    next_two_months = next_month.replace(day=28) + datetime.timedelta(days=4)
    next_two_months_last_day = next_two_months.replace(day=1) - datetime.timedelta(days=1)
    next_three_months = next_two_months.replace(day=28) + datetime.timedelta(days=4)
    next_three_months_last_day = next_three_months.replace(day=1) - datetime.timedelta(days=1)
    next_four_months = next_three_months.replace(day=28) + datetime.timedelta(days=4)
    next_four_months_last_day = next_four_months.replace(day=1) - datetime.timedelta(days=1)
    next_five_months = next_four_months.replace(day=28) + datetime.timedelta(days=4)
    next_five_months_last_day = next_five_months.replace(day=1) - datetime.timedelta(days=1)

    return {
        'tomorrow': tomorrow.strftime('%Y-%m-%d'),
        'five_working_days_later': five_working_days_later.strftime('%Y-%m-%d'),
        'next_next_sunday': next_next_sunday.strftime('%Y-%m-%d'),
        'next_next_next_sunday': next_next_next_sunday.strftime('%Y-%m-%d'),
        'next_two_months_last_day': next_two_months_last_day.strftime('%Y-%m-%d'),
        'next_three_months_last_day': next_three_months_last_day.strftime('%Y-%m-%d'),
        'next_four_months_last_day': next_four_months_last_day.strftime('%Y-%m-%d'),
        'next_five_months_last_day': next_five_months_last_day.strftime('%Y-%m-%d'),
    }


def get_bdwd_price(date, true_price, global_config):
    """Collect the true price on *date* for the supported price columns.

    Only 'day_price' and 'week_price' are currently extracted. As before,
    every column in ``global_config['price_columns']`` of *true_price* is
    converted to numeric in place.

    :return: dict {price_column: value on *date*} (previously discarded)
    """
    bdwd_price = {}
    for wd in global_config['price_columns']:
        if wd in ('day_price', 'week_price'):
            bdwd_price[wd] = true_price[true_price['ds'] == date][wd].values[0]
        true_price[wd] = pd.to_numeric(true_price[wd])
    return bdwd_price


class DeepSeek():
    """Summarise forecast text with the DeepSeek chat model via LangChain."""

    def __init__(self):
        pass

    def summary(self, text):
        """Return an LLM-written Chinese analysis of *text*.

        Reads the API key from the OPENAI_API_KEY environment variable.
        """
        prompt_template = '''请根据以下ARIMA预测结果分析未来的趋势:
"{text}"
请用专业且结构清晰的中文撰写,重点数据用**加粗**显示
'''
        chinese_prompt = PromptTemplate(template=prompt_template, input_variables=['text'])
        docs = [Document(page_content=text, metadata={"source": "arima_forecast"})]
        llm = ChatOpenAI(
            model="deepseek-chat",
            temperature=0,
            base_url="https://api.deepseek.com/v1",
            api_key=os.environ.get('OPENAI_API_KEY')
        )
        chain = load_summarize_chain(llm, prompt=chinese_prompt)
        print('大语言模型分析预测结果')
        summary = chain.invoke({"input_documents": docs})['output_text']
        print('大语言模型分析结果:')
        print(summary)
        return summary


def get_model_id_name_dict(global_config):
    """Return ``{model id: model name}`` read from ``v_tbl_predict_models``."""
    tb = 'v_tbl_predict_models'
    sql = f'select model_name,id from {tb} '
    rows = global_config['db_mysql'].execute_query(sql) or []  # None on DB error
    model_id_name_dict = {row['id']: row['model_name'] for row in rows}
    global_config['logger'].info(f'模型id-name: {model_id_name_dict}')
    return model_id_name_dict


def get_modelsname(df, global_config):
    """Return ``(model names present as columns of df, {id: name})``."""
    model_id_name_dict = get_model_id_name_dict(global_config=global_config)
    model_name_list = list(set(df.columns.tolist()) & set(model_id_name_dict.values()))
    return model_name_list, model_id_name_dict


def _convert_df_to_results(df_predict, model_id_name_dict, global_config, result_cls):
    """Build one *result_cls* per model column from the first row of *df_predict*."""
    name_to_id = {name: model_id for model_id, name in model_id_name_dict.items()}
    data = global_config['DEFAULT_CONFIG'].copy()
    data_date = df_predict['created_dt'].values[0]
    if isinstance(data_date, np.datetime64):
        data_date = pd.Timestamp(data_date).to_pydatetime()
    data['data_date'] = data_date
    results = []
    for col in df_predict.columns:
        if col in ('ds', 'created_dt'):
            continue
        data['model_id'] = name_to_id[col]
        # Via str() so the Decimal holds exactly 2 decimals, not the float's
        # full binary expansion.
        data['predicted_price'] = Decimal(str(round(float(df_predict[col].values[0]), 2)))
        results.append(result_cls(**data))
    return results


def convert_df_to_pydantic(df_predict, model_id_name_dict, global_config):
    """Convert a one-row prediction frame into a list of PredictionResult."""
    return _convert_df_to_results(df_predict, model_id_name_dict, global_config,
                                  PredictionResult)


def convert_df_to_pydantic_pp(df_predict, model_id_name_dict, global_config):
    """Convert a one-row prediction frame into a list of PpPredictionResult."""
    return _convert_df_to_results(df_predict, model_id_name_dict, global_config,
                                  PpPredictionResult)


def find_best_models(date='', global_config=None):
    """Pick, per price column, the model whose past forecasts were most accurate.

    For column index i in ``global_config['price_columns']``:
      * 0: forecasts stored for the previous business day and *date*, against
        the true price on *date*; only models whose (true - predicted) sign
        matches the sign of the true day-over-day move are eligible;
      * 1: same, using forecasts stored 7 calendar days earlier;
      * 2, 3: per-model mean forecast stored 1 / 2 weeks ago against the mean
        true price of the current week so far;
      * 4..7: per-model mean forecast stored 1..4 months ago against the mean
        true price of the current month so far.
    The winner's own forecast stored for *date* is then looked up.

    :param date: 'YYYY-MM-DD'; empty means today
    :param global_config: dict providing 'logger', 'db_mysql', 'price_columns'
    :return: ``{wd: {'model_id', 'model_name', 'predictresult', 'date'}}``;
        empty dict on a bad date, missing CSV or no stored predictions
    """
    best_models = {}
    model_id_name_dict = get_model_id_name_dict(global_config=global_config)
    log = global_config['logger']

    if not date:
        date = datetime.datetime.now().strftime('%Y-%m-%d')
    else:
        try:
            date = datetime.datetime.strptime(date, '%Y-%m-%d').strftime('%Y-%m-%d')
        except ValueError:
            log.error(f"日期格式错误,期望格式为 '%Y-%m-%d',实际输入: {date}")
            return best_models

    current_date = datetime.datetime.strptime(date, '%Y-%m-%d')
    first_day_of_month = current_date.replace(day=1)
    date_monday = current_date - datetime.timedelta(days=current_date.weekday())

    true_price_path = 'juxitingdataset/指标数据.csv'
    try:
        true_price = pd.read_csv(true_price_path)[['ds', 'y']]
    except FileNotFoundError:
        log.error(f"未找到文件: {true_price_path}")
        return best_models

    def true_y(day):
        """True price ``y`` on *day* (IndexError if that day is absent)."""
        return true_price[true_price['ds'] == day]['y'].values[0]

    # Month six months before *date* (e.g. 2024-03 -> 2023-09).
    year, month = current_date.year, current_date.month - 6
    if month <= 0:
        year -= 1
        month += 12

    tb = 'v_tbl_predict_pp_prediction_results'
    sql = f"select * from {tb} where data_date >= '{year}-{month:02d}-01'"
    predictresult = global_config['db_mysql'].execute_query(sql)
    if not predictresult:
        log.info('没有预测结果')
        return best_models

    price_columns = global_config['price_columns']
    df = pd.DataFrame(predictresult)[['data_date', 'model_id'] + price_columns].copy()
    # Normalise so the ==/>=/<= comparisons below work whether the driver
    # returned str, date or datetime values.
    df['data_date'] = pd.to_datetime(df['data_date'])
    # DECIMAL columns arrive as Decimal objects; groupby().mean() needs floats.
    df[price_columns] = df[price_columns].apply(pd.to_numeric, errors='coerce')
    log.info(f'预测结果数量:{df.shape}')
    log.info(f'预测结果日期范围:{df["data_date"].min()} 到 {df["data_date"].max()}')

    def query_predict_result(date, model_id, global_config, wd):
        """Stored forecast of *model_id* for column *wd* on *date*, or None."""
        if model_id is None:
            global_config['logger'].info('没有预测结果')
            return None
        tb = 'v_tbl_predict_pp_prediction_results'
        sql = f'select {wd} from {tb} where data_date = \'{date}\' and model_id = {model_id}'
        rows = global_config['db_mysql'].execute_query(sql)
        if not rows:
            global_config['logger'].info('没有预测结果')
            return None
        value = rows[0][wd]
        return None if value is None else float(value)

    def calculate_best_model(price, trend, wd, weektrueprice=None, monthtrueprice=None):
        """Return ``(model_id, model_name)`` with the smallest |true - predicted|.

        :param price: frame with 'model_id' and column *wd*
        :param trend: 1 / -1 to keep only models whose (true - predicted)
            sign equals it; None to keep all
        :param wd: price column to evaluate
        :param weektrueprice: reference true price (weekly mean), if any
        :param monthtrueprice: reference true price (monthly mean), if any;
            without either, the true price on *date* is used
        :return: ``(None, None)`` when no candidate qualifies
        """
        price = price.copy()
        price[wd] = price[wd].astype(float)
        price = price.dropna(subset=[wd])
        if weektrueprice is not None:
            true_price_value = weektrueprice
        elif monthtrueprice is not None:
            true_price_value = monthtrueprice
        else:
            true_price_value = true_y(date)
        if price.empty or pd.isna(true_price_value):
            return None, None
        price['trueprice'] = true_price_value
        error = price['trueprice'] - price[wd]
        price['trend'] = np.where(error > 0, 1, -1)
        price['abs'] = error.abs()
        if trend is not None:
            price = price[price['trend'] == trend]
        if price.empty:
            return None, None
        best_label = price['abs'].idxmin()  # first row with the minimum error
        best_model_id = price.at[best_label, 'model_id']
        return best_model_id, model_id_name_dict.get(best_model_id)

    def record(wd, best_model_id, best_model_name):
        """Store the chosen model for *wd* in best_models."""
        best_models[wd]['model_id'] = best_model_id
        best_models[wd]['model_name'] = best_model_name

    def attach_prediction(wd, best_model_id):
        """Look up and store the winner's forecast for *date* (None if absent)."""
        predictresult = query_predict_result(date, best_model_id, global_config, wd)
        if predictresult is not None:
            log.info(f'最佳模型{best_models[wd]}在{date}预测结果:{predictresult}')
        best_models[wd]['predictresult'] = predictresult

    for i, wd in enumerate(price_columns):
        log.info(
            f'*********************************************************************************************************计算预测{date}的{wd}最佳模型')
        best_models[wd] = {}
        if i == 0:
            # Previous business day.
            ciridate = (pd.Timestamp(date) - pd.tseries.offsets.BusinessDay(1)).strftime('%Y-%m-%d')
            log.info(f'计算预测{date}的次日{ciridate}最佳模型')
            log.info(f'{date}真实价格:{true_y(date)}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] == ciridate) | (price['data_date'] == date)]
            trend = 1 if true_y(date) - true_y(ciridate) > 0 else -1
            best_model_id, best_model_name = calculate_best_model(price, trend, wd)
            record(wd, best_model_id, best_model_name)
            log.info(f'{ciridate}预测最准确的模型:{best_model_id}')
            log.info(f'{ciridate}预测最准确的模型名称:{best_models[wd]}')
            attach_prediction(wd, best_model_id)
            # Target date: the next business day after *date*.
            best_models[wd]['date'] = (pd.Timestamp(date) + pd.tseries.offsets.BusinessDay(1)).strftime('%Y-%m-%d')
        elif i == 1:
            # Same weekday one week earlier (five business days back).
            benzhoudate = (pd.Timestamp(date) - pd.Timedelta(days=7)).strftime('%Y-%m-%d')
            log.info(f'计算预测{date}的五天前{benzhoudate}最佳模型')
            log.info(f'{date}真实价格:{true_y(date)}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] == benzhoudate) | (price['data_date'] == date)]
            trend = 1 if true_y(date) - true_y(benzhoudate) > 0 else -1
            best_model_id, best_model_name = calculate_best_model(price, trend, wd)
            record(wd, best_model_id, best_model_name)
            log.info(f'{benzhoudate}预测最准确的模型名称:{best_models[wd]}')
            attach_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(date) + pd.tseries.offsets.BusinessDay(5)).strftime('%Y-%m-%d')
        elif i in [2, 3]:
            weeks_ago = 1 if i == 2 else 2
            ago_monday = current_date - datetime.timedelta(days=current_date.weekday() + 7 * weeks_ago)
            ago_sunday = ago_monday + datetime.timedelta(days=6)
            ago_date_str = f"{ago_monday.strftime('%Y-%m-%d')} - {ago_sunday.strftime('%Y-%m-%d')}"
            log.info(f'计算预测{date}的前{weeks_ago}周{ago_date_str}最佳模型')
            weektrueprice = true_price[(true_price['ds'] >= date_monday.strftime('%Y-%m-%d'))
                                       & (true_price['ds'] <= date)]['y'].mean()
            log.info(f'当周{date_monday.strftime("%Y-%m-%d")}---{date}真实价格的周均价:{weektrueprice}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] >= ago_monday) & (price['data_date'] <= ago_sunday)]
            price = price.groupby('model_id')[wd].mean().reset_index()
            best_model_id, best_model_name = calculate_best_model(
                price, None, wd, weektrueprice=weektrueprice)
            record(wd, best_model_id, best_model_name)
            log.info(f'{ago_date_str}预测最准确的模型名称:{best_models[wd]}')
            attach_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(ago_sunday) + pd.tseries.offsets.Week(weeks_ago * 2)).strftime('%Y-%m-%d')
        elif i in [4, 5, 6, 7]:
            months_ago = i - 3
            current_date_ts = pd.Timestamp(date)
            last_month_first_day = (current_date_ts - pd.offsets.MonthBegin(months_ago)).strftime('%Y-%m-%d')
            last_month_last_day = (pd.Timestamp(last_month_first_day) + pd.offsets.MonthEnd(0)).strftime('%Y-%m-%d')
            log.info(f'计算预测{date}的{months_ago}月前{last_month_first_day}-{last_month_last_day}最佳模型')
            monthtrueprice = true_price[(true_price['ds'] >= first_day_of_month.strftime('%Y-%m-%d'))
                                        & (true_price['ds'] <= date)]['y'].mean()
            log.info(f'当月{first_day_of_month.strftime("%Y-%m-%d")}-{date}真实价格的月均价:{monthtrueprice}')
            price = df[['data_date', wd, 'model_id']]
            price = price[(price['data_date'] >= last_month_first_day)
                          & (price['data_date'] <= last_month_last_day)]
            price = price.groupby('model_id')[wd].mean().reset_index()
            best_model_id, best_model_name = calculate_best_model(
                price, None, wd, monthtrueprice=monthtrueprice)
            record(wd, best_model_id, best_model_name)
            log.info(f'{last_month_first_day}-{last_month_last_day}预测最准确的模型名称:{best_models[wd]}')
            attach_prediction(wd, best_model_id)
            best_models[wd]['date'] = (pd.Timestamp(date) + pd.tseries.offsets.MonthEnd(months_ago + 1)).strftime('%Y-%m-%d')
    return best_models


def plot_pp_predict_result(y_hat, global_config):
    """Plot PP futures forecasts against recent true prices and save a PNG.

    :param y_hat: frame with 'ds' and 'predictresult' columns; its index
        presumably holds the price-column names used as table headers — confirm
    :param global_config: dict providing 'end_time' and 'dataset' (output dir)
    """
    y = pd.read_csv('juxitingdataset/指标数据.csv')[['ds', 'y']]
    y['ds'] = pd.to_datetime(y['ds'])
    y = y[y['ds'] < y_hat['ds'].iloc[0]][-30:]  # last 30 true points before the forecast
    if not y.empty:
        # Prepend the last true value so the forecast line starts from it.
        y_last_row = y.tail(1).rename(columns={'y': 'predictresult'})
        y_y_hat = pd.concat([y_last_row, y_hat], ignore_index=True)
    else:
        y_y_hat = y_hat.copy()

    fig, ax = plt.subplots(figsize=(16, 9))
    y_y_hat = y_y_hat.sort_values(by='ds')
    y = y.sort_values(by='ds')
    sns.lineplot(x=y_y_hat['ds'], y=y_y_hat['predictresult'], color='orange',
                 label='预测值', ax=ax, linestyle='--')
    sns.lineplot(x=y['ds'], y=y['y'], color='blue', label='真实值', ax=ax)
    ax.set_title(f'{global_config["end_time"]} PP期货八大维度 预测价格走势')
    ax.set_xlabel('日期')
    ax.set_ylabel('预测结果')
    ax.tick_params(axis='x', rotation=45)

    # One-row table of forecasts underneath the chart.
    y_hat = y_hat[['predictresult']].T
    print(y_hat)
    y_hat.rename(columns={'day_price': '次日', 'week_price': '本周',
                          'second_week_price': '次周', 'next_week_price': '隔周',
                          'next_month_price': '次月', 'next_february_price': '次二月',
                          'next_march_price': '次三月', 'next_april_price': '次四月',
                          }, inplace=True)
    columns = y_hat.columns.tolist()
    data = y_hat.values.tolist()
    for row in data:
        if isinstance(row[0], pd.Timestamp):
            row[0] = row[0].strftime('%Y-%m-%d')
    table = ax.table(cellText=data, colLabels=columns, loc='bottom', bbox=[0, -0.6, 1, 0.2])
    table.auto_set_font_size(False)
    table.set_fontsize(14)
    plt.tight_layout(rect=[0, 0.1, 1, 1])  # leave room for the table
    plt.savefig(os.path.join(global_config['dataset'], 'pp_predict_result.png'))
    plt.close(fig)  # release the figure; repeated calls otherwise leak memory


if __name__ == '__main__':
    print('This is a tool, not a script.')