From 09e43f8ac1f02aa4506429ac8a3c8582df614f24 Mon Sep 17 00:00:00 2001 From: liurui Date: Tue, 24 Dec 2024 15:10:18 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=9B=8D=E5=AE=89=E7=8E=AF?= =?UTF-8?q?=E5=A2=83=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_jingbo.py | 6 +- main_yuanyou.py | 21 +- models/nerulforcastmodels.py | 429 ++++++++++++++++++++++++++++++++++- 3 files changed, 435 insertions(+), 21 deletions(-) diff --git a/config_jingbo.py b/config_jingbo.py index ff08ed8..1b1f2c7 100644 --- a/config_jingbo.py +++ b/config_jingbo.py @@ -223,9 +223,9 @@ table_name = 'v_tbl_crude_oil_warning' ### 开关 -is_train = True # 是否训练 +is_train = False # 是否训练 is_debug = False # 是否调试 -is_eta = True # 是否使用eta接口 +is_eta = False # 是否使用eta接口 is_timefurture = True # 是否使用时间特征 is_fivemodels = False # 是否使用之前保存的最佳的5个模型 is_edbcode = False # 特征使用edbcoding列表中的 @@ -246,7 +246,7 @@ print("数据库连接成功",host,dbname,dbusername) # 数据截取日期 start_year = 2020 # 数据开始年份 -end_time = '2024-12-04' # 数据截取日期 +end_time = '' # 数据截取日期 freq = 'B' # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 delweekenday = True if freq == 'B' else False # 是否删除周末数据 is_corr = False # 特征是否参与滞后领先提升相关系数 diff --git a/main_yuanyou.py b/main_yuanyou.py index 3e8f425..798d49c 100644 --- a/main_yuanyou.py +++ b/main_yuanyou.py @@ -48,6 +48,7 @@ def predict_main(): 返回: None """ + global end_time signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, classifylisturl=classifylisturl, @@ -196,6 +197,7 @@ def predict_main(): modelsindex=modelsindex, data=data, is_eta=is_eta, + end_time=end_time, ) @@ -225,19 +227,20 @@ def predict_main(): # # ex_GRU(df) # 发送邮件 - m = SendMail( - username=username, - passwd=passwd, - recv=recv, - title=title, - content=content, - file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), - ssl=ssl, - ) + # m = SendMail( + # username=username, + # passwd=passwd, + # recv=recv, + # title=title, + # content=content, + # file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), + # ssl=ssl, + # ) # m.send_mail() if __name__ == '__main__': + global end_time is_on = True # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 for i_time in pd.date_range('2024-10-07', '2024-12-16', freq='B'): diff --git a/models/nerulforcastmodels.py b/models/nerulforcastmodels.py index d0ceaeb..961fef2 100644 --- a/models/nerulforcastmodels.py +++ b/models/nerulforcastmodels.py @@ -39,7 +39,7 @@ pdfmetrics.registerFont(TTFont('SimSun', 'SimSun.ttf')) @exception_logger def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patience_steps, is_debug,dataset,is_train,is_fivemodels,val_size,test_size,settings,now, - etadata,modelsindex,data,is_eta): + etadata,modelsindex,data,is_eta,end_time): ''' 模型训练与预测 :param df: 数据集 @@ -186,10 +186,10 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien filename = max(glob.glob(os.path.join(dataset,'*.joblib')), key=os.path.getctime) logger.info('读取模型:'+ filename) nf = load(filename) - # # # 测试集预测 - nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None) - # # 测试集预测结果保存 - nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False) + # 测试集预测 + # nf_test_preds = nf.cross_validation(df=df_test, val_size=val_size, test_size=test_size, n_windows=None) + # 测试集预测结果保存 + # nf_test_preds.to_csv(os.path.join(dataset,"cross_validation.csv"),index=False) df_test['ds'] = pd.to_datetime(df_test['ds'], errors='coerce') @@ -205,6 +205,8 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien # 将预测结果保存到数据库 def save_to_database(df): + # ds列转为日期字符串 + df['ds'] = df['ds'].dt.strftime('%Y-%m-%d') if not sqlitedb.check_table_exists('predict'): df.to_sql('predict',sqlitedb.connection,index=False) else: @@ -213,13 +215,14 @@ def ex_Model(df,horizon,input_size,train_steps,val_check_steps,early_stop_patien for row in df.itertuples(index=False): row_dict = row._asdict() columns=row_dict.keys() - check_query = sqlitedb.select_data('predict',where_condition = f"ds = '{row.ds} and model = {row.model}'") + check_query = sqlitedb.select_data('predict',where_condition = f"ds = '{row.ds}' and created_dt = '{end_time}'") if len(check_query) > 0: set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()]) - sqlitedb.update_data('predict',set_clause,where_condition = f"ds = '{row.ds}'") + sqlitedb.update_data('predict',set_clause,where_condition = f"ds = '{row.ds} and created_dt = {end_time}'") continue - sqlitedb.insert_data('predict',tuple(row_dict.values()),columns=columns) - save_to_database(df_predict,) + else: + sqlitedb.insert_data('predict',tuple(row_dict.values()),columns=columns) + save_to_database(df_predict) # 把预测值上传到eta if is_update_eta: @@ -648,6 +651,414 @@ def model_losss(sqlitedb,end_time): return model_results3 +# 原油计算预测评估指数 +@exception_logger +def model_losss_bak(sqlitedb,end_time): + global dataset + global rote + most_model = [sqlitedb.select_data('most_model',columns=['most_common_model'],order_by='ds desc',limit=1).values[0][0]] + most_model_name = most_model[0] + + # 预测数据处理 predict + df_combined = loadcsv(os.path.join(dataset,"cross_validation.csv")) + df_combined = dateConvert(df_combined) + # 删除空列 + df_combined.dropna(axis=1,inplace=True) + # 删除缺失值,预测过程不能有缺失值 + df_combined.dropna(inplace=True) + # 其他列转为数值类型 + df_combined = df_combined.astype({col: 'float32' for col in df_combined.columns if col not in ['cutoff','ds'] }) + # 使用 groupby 和 transform 结合 lambda 函数来获取每个分组中 cutoff 的最小值,并创建一个新的列来存储这个最大值 + df_combined['max_cutoff'] = df_combined.groupby('ds')['cutoff'].transform('max') + + # 然后筛选出那些 cutoff 等于 max_cutoff 的行,这样就得到了每个分组中 cutoff 最大的行,并保留了其他列 + df_combined = df_combined[df_combined['cutoff'] == df_combined['max_cutoff']] + # 删除模型生成的cutoff列 + df_combined.drop(columns=['cutoff', 'max_cutoff'], inplace=True) + # 获取模型名称 + modelnames = df_combined.columns.to_list()[1:] + if 'y' in modelnames: + modelnames.remove('y') + df_combined3 = df_combined.copy() # 备份df_combined,后面画图需要 + + + # 空的列表存储每个模型的MSE、RMSE、MAE、MAPE、SMAPE + cellText = [] + + # 遍历模型名称,计算模型评估指标 + for model in modelnames: + modelmse = mse(df_combined['y'], df_combined[model]) + modelrmse = rmse(df_combined['y'], df_combined[model]) + modelmae = mae(df_combined['y'], df_combined[model]) + # modelmape = mape(df_combined['y'], df_combined[model]) + # modelsmape = smape(df_combined['y'], df_combined[model]) + # modelr2 = r2_score(df_combined['y'], df_combined[model]) + cellText.append([model,round(modelmse, 3), round(modelrmse, 3), round(modelmae, 3)]) + + model_results3 = pd.DataFrame(cellText,columns=['模型(Model)','平均平方误差(MSE)', '均方根误差(RMSE)', '平均绝对误差(MAE)']) + # 按MSE降序排列 + model_results3 = model_results3.sort_values(by='平均平方误差(MSE)', ascending=True) + model_results3.to_csv(os.path.join(dataset,"model_evaluation.csv"),index=False) + modelnames = model_results3['模型(Model)'].tolist() + allmodelnames = modelnames.copy() + # 保存5个最佳模型的名称 + if len(modelnames) > 5: + modelnames = modelnames[0:5] + if is_fivemodels: + pass + else: + with open(os.path.join(dataset,"best_modelnames.txt"), 'w') as f: + f.write(','.join(modelnames) + '\n') + + # 预测值与真实值对比图 + plt.rcParams['font.sans-serif'] = ['SimHei'] + plt.figure(figsize=(15, 10)) + for n,model in enumerate(modelnames[:5]): + plt.subplot(3, 2, n+1) + plt.plot(df_combined3['ds'], df_combined3['y'], label='真实值') + plt.plot(df_combined3['ds'], df_combined3[model], label=model) + plt.legend() + plt.xlabel('日期') + plt.ylabel('价格') + plt.title(model+'拟合') + plt.subplots_adjust(hspace=0.5) + plt.savefig(os.path.join(dataset,'预测值与真实值对比图.png'), bbox_inches='tight') + plt.close() + + + # # 历史数据+预测数据 + # # 拼接未来时间预测 + df_predict = pd.read_csv(os.path.join(dataset,'predict.csv')) + df_predict.drop('unique_id',inplace=True,axis=1) + df_predict.dropna(axis=1,inplace=True) + + try: + df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y-%m-%d') + except ValueError : + df_predict['ds'] = pd.to_datetime(df_predict['ds'],format=r'%Y/%m/%d') + + # def first_row_to_database(df): + # # # 取第一行数据存储到数据库中 + # first_row = df.head(1) + # first_row['ds'] = first_row['ds'].dt.strftime('%Y-%m-%d 00:00:00') + # # 将预测结果保存到数据库 + # if not sqlitedb.check_table_exists('trueandpredict'): + # first_row.to_sql('trueandpredict',sqlitedb.connection,index=False) + # else: + # for col in first_row.columns: + # sqlitedb.add_column_if_not_exists('trueandpredict',col,'TEXT') + # for row in first_row.itertuples(index=False): + # row_dict = row._asdict() + # columns=row_dict.keys() + # check_query = sqlitedb.select_data('trueandpredict',where_condition = f"ds = '{row.ds}'") + # if len(check_query) > 0: + # set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()]) + # sqlitedb.update_data('trueandpredict',set_clause,where_condition = f"ds = '{row.ds}'") + # continue + # sqlitedb.insert_data('trueandpredict',tuple(row_dict.values()),columns=columns) + + # first_row_to_database(df_predict) + + + + df_combined3 = pd.concat([df_combined3, df_predict]).reset_index(drop=True) + + # 计算每个模型与最佳模型的绝对误差比例,根据设置的阈值rote筛选预测值显示最大最小值 + names = [] + names_df = df_combined3.copy() + for col in allmodelnames: + names_df[f'{col}-{most_model_name}-误差比例'] = abs(names_df[col] - names_df[most_model_name]) / names_df[most_model_name] + names.append(f'{col}-{most_model_name}-误差比例') + + names_df = names_df[names] + def add_rote_column(row): + columns = [] + for r in names_df.columns: + if row[r] <= rote: + columns.append(r.split('-')[0]) + return pd.Series([columns], index=['columns']) + names_df['columns'] = names_df.apply(add_rote_column, axis=1) + + def add_upper_lower_bound(row): + + # 计算上边界值 + upper_bound = row.max() + # 计算下边界值 + lower_bound = row.min() + return pd.Series([lower_bound, upper_bound], index=['min_within_quantile', 'max_within_quantile']) + + # df_combined3[['min_within_quantile','max_within_quantile']] = names_df.apply(add_upper_lower_bound, axis=1) + + # 取前五最佳模型的最大最小值作为上下边界值 + df_combined3[['min_within_quantile','max_within_quantile']]= df_combined3[modelnames].apply(add_upper_lower_bound, axis=1) + + def find_closest_values(row): + x = row.y + if x is None or np.isnan(x): + return pd.Series([None, None], index=['min_price','max_price']) + # row = row.drop('ds') + row = row.values.tolist() + row.sort() + print(row) + # x 在row中的索引 + index = row.index(x) + if index == 0: + return pd.Series([row[index+1], row[index+2]], index=['min_price','max_price']) + elif index == len(row)-1: + return pd.Series([row[index-2], row[index-1]], index=['min_price','max_price']) + else: + return pd.Series([row[index-1], row[index+1]], index=['min_price','max_price']) + + + + def find_most_common_model(): + # 最多频率的模型名称 + min_model_max_frequency_model = df_combined3['min_model'].tail(60).value_counts().idxmax() + max_model_max_frequency_model = df_combined3['max_model'].tail(60).value_counts().idxmax() + if min_model_max_frequency_model == max_model_max_frequency_model: + # 取60天第二多的模型 + max_model_max_frequency_model = df_combined3['max_model'].tail(60).value_counts().nlargest(2).index[1] + + df_predict['min_model'] = min_model_max_frequency_model + df_predict['max_model'] = max_model_max_frequency_model + df_predict['min_within_quantile'] = df_predict[min_model_max_frequency_model] + df_predict['max_within_quantile'] = df_predict[max_model_max_frequency_model] + + + # find_most_common_model() + + df_combined3['ds'] = pd.to_datetime(df_combined3['ds']) + df_combined3['ds'] = df_combined3['ds'].dt.strftime('%Y-%m-%d') + df_predict2 = df_combined3.tail(horizon) + + # 保存到数据库 + if not sqlitedb.check_table_exists('accuracy'): + columns = ','.join(df_combined3.columns.to_list()+['id','CREAT_DATE','min_price','max_price']) + sqlitedb.create_table('accuracy',columns=columns) + existing_data = sqlitedb.select_data(table_name = "accuracy") + + if not existing_data.empty: + max_id = existing_data['id'].astype(int).max() + df_predict2['id'] = range(max_id + 1, max_id + 1 + len(df_predict2)) + else: + df_predict2['id'] = range(1, 1 + len(df_predict2)) + # df_predict2['CREAT_DATE'] = now if end_time == '' else end_time + df_predict2['CREAT_DATE'] = end_time + def get_common_columns(df1, df2): + # 获取两个DataFrame的公共列名 + return list(set(df1.columns).intersection(df2.columns)) + + common_columns = get_common_columns(df_predict2, existing_data) + try: + df_predict2[common_columns].to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) + except: + df_predict2.to_sql("accuracy", con=sqlitedb.connection, if_exists='append', index=False) + + # 更新accuracy表中的y值 + update_y = sqlitedb.select_data(table_name = "accuracy",where_condition='y is null') + if len(update_y) > 0: + df_combined4 = df_combined3[(df_combined3['ds'].isin(update_y['ds'])) & (df_combined3['y'].notnull())] + if len(df_combined4) > 0: + for index, row in df_combined4.iterrows(): + try: + sqlitedb.update_data('accuracy',f"y = {row['y']}",f"ds = '{row['ds']}'") + except: + logger.error(f'更新accuracy表中的y值失败,row={row}') + # 上周准确率计算 + predict_y = sqlitedb.select_data(table_name = "accuracy") + # ids = predict_y[predict_y['min_price'].isnull()]['id'].tolist() + ids = predict_y['id'].tolist() + # 准确率基准与绘图上下界逻辑一致 + # predict_y[['min_price','max_price']] = predict_y[['min_within_quantile','max_within_quantile']] + # 模型评估前五均值 + predict_y['min_price'] = predict_y[modelnames].mean(axis=1) -1 + predict_y['max_price'] = predict_y[modelnames].mean(axis=1) +1 + # 模型评估前十均值 + # predict_y['min_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) -1 + # predict_y['max_price'] = predict_y[allmodelnames[0:10]].mean(axis=1) +1 + # 模型评估前十最大最小 + # allmodelnames 和 predict_y 列 重复的 + # allmodelnames = [col for col in allmodelnames if col in predict_y.columns] + # predict_y['min_price'] = predict_y[allmodelnames[0:10]].min(axis=1) + # predict_y['max_price'] = predict_y[allmodelnames[0:10]].max(axis=1) + for id in ids: + row = predict_y[predict_y['id'] == id] + try: + sqlitedb.update_data('accuracy',f"min_price = {row['min_price'].values[0]},max_price = {row['max_price'].values[0]}",f"id = {id}") + except: + logger.error(f'更新accuracy表中的min_price,max_price值失败,row={row}') + # 拼接市场最高最低价 + xlsfilename = os.path.join(dataset,'数据项下载.xls') + df2 = pd.read_excel(xlsfilename)[5:] + df2 = df2.rename(columns = {'数据项名称':'ds','布伦特最低价':'LOW_PRICE','布伦特最高价':'HIGH_PRICE'}) + print(df2.shape) + df = pd.merge(predict_y,df2,on=['ds'],how='left') + df['ds'] = pd.to_datetime(df['ds']) + df = df.reindex() + + # 判断预测值在不在布伦特最高最低价范围内,准确率为1,否则为0 + def is_within_range(row): + for model in allmodelnames: + if row['LOW_PRICE'] <= row[col] <= row['HIGH_PRICE']: + return 1 + else: + return 0 + + # 比较真实最高最低,和预测最高最低 计算准确率 + def calculate_accuracy(row): + # 全子集情况: + if (row['max_price'] >= row['HIGH_PRICE'] and row['min_price'] <= row['LOW_PRICE']) or \ + (row['max_price'] <= row['HIGH_PRICE'] and row['min_price'] >= row['LOW_PRICE']): + return 1 + # 无交集情况: + if row['max_price'] < row['LOW_PRICE'] or \ + row['min_price'] > row['HIGH_PRICE']: + return 0 + # 有交集情况: + else: + sorted_prices = sorted([row['LOW_PRICE'], row['min_price'], row['max_price'], row['HIGH_PRICE']]) + middle_diff = sorted_prices[2] - sorted_prices[1] + price_range = row['HIGH_PRICE'] - row['LOW_PRICE'] + accuracy = middle_diff / price_range + return accuracy + + columns = ['HIGH_PRICE','LOW_PRICE','min_price','max_price'] + df[columns] = df[columns].astype(float) + df['ACCURACY'] = df.apply(calculate_accuracy, axis=1) + # df['ACCURACY'] = df.apply(is_within_range, axis=1) + # 取结束日期上一周的日期 + def get_week_date(end_time): + endtime = end_time + endtimeweek = datetime.datetime.strptime(endtime, '%Y-%m-%d') + up_week = endtimeweek - datetime.timedelta(days=endtimeweek.weekday() + 14) + up_week_dates = [up_week + datetime.timedelta(days=i) for i in range(14)][4:-2] + up_week_dates = [date.strftime('%Y-%m-%d') for date in up_week_dates] + return up_week_dates + up_week_dates = get_week_date(end_time) + + # 计算准确率并保存结果 + def _get_accuracy_rate(df,up_week_dates,endtime): + df3 = df.copy() + df3 = df3[df3['CREAT_DATE'].isin(up_week_dates)] + df3 = df3[df3['ds'].isin(up_week_dates)] + accuracy_rote = 0 + for i,group in df3.groupby('ds'): + print('权重:',weight_dict[len(group)-1]) + print('准确率:',(group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1]) + accuracy_rote += (group['ACCURACY'].sum()/len(group))*weight_dict[len(group)-1] + df3.to_csv(os.path.join(dataset,f'accuracy_{endtime}.csv'),index=False) + df4 = pd.DataFrame(columns=['开始日期','结束日期','准确率']) + df4.loc[len(df4)] = {'开始日期':up_week_dates[0],'结束日期':up_week_dates[-1],'准确率':accuracy_rote} + df4.to_sql("accuracy_rote", con=sqlitedb.connection, if_exists='append', index=False) + _get_accuracy_rate(df,up_week_dates,end_time) + + def _add_abs_error_rate(): + # 计算每个预测值与真实值之间的偏差率 + for model in allmodelnames: + df_combined3[f'{model}_abs_error_rate'] = abs(df_combined3['y'] - df_combined3[model]) / df_combined3['y'] + + # 获取每行对应的最小偏差率值 + min_abs_error_rate_values = df_combined3.apply(lambda row: row[[f'{model}_abs_error_rate' for model in allmodelnames]].min(), axis=1) + # 获取每行对应的最小偏差率值对应的列名 + min_abs_error_rate_column_name = df_combined3.apply(lambda row: row[[f'{model}_abs_error_rate' for model in allmodelnames]].idxmin(), axis=1) + # 将列名索引转换为列名 + min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0]) + # 获取最小偏差率对应的模型的预测值 + min_abs_error_rate_predictions = df_combined3.apply(lambda row: row[min_abs_error_rate_column_name[row.name]], axis=1) + # 将最小偏差率对应的模型的预测值添加到DataFrame中 + df_combined3['min_abs_error_rate_prediction'] = min_abs_error_rate_predictions + df_combined3['min_abs_error_rate_column_name'] = min_abs_error_rate_column_name + # _add_abs_error_rate() + + # 判断 df 的数值列转为float + for col in df_combined3.columns: + try: + if col != 'ds': + df_combined3[col] = df_combined3[col].astype(float) + df_combined3[col] = df_combined3[col].round(2) + except ValueError: + pass + df_combined3.to_csv(os.path.join(dataset,"testandpredict_groupby.csv"),index=False) + + + # 历史价格+预测价格 + sqlitedb.drop_table('testandpredict_groupby') + df_combined3.to_sql('testandpredict_groupby',sqlitedb.connection,index=False) + # 新增均值列 + df_combined3['mean'] = df_combined3[modelnames].mean(axis=1) + + def _plt_predict_ture(df): + lens = df.shape[0] if df.shape[0] < 180 else 90 + df = df[-lens:] # 取180个数据点画图 + # 历史价格 + plt.figure(figsize=(20, 10)) + plt.plot(df['ds'], df['y'], label='真实值') + # 均值线 + plt.plot(df['ds'], df['mean'], color='r', linestyle='--', label='前五模型预测均值') + # 颜色填充 + plt.fill_between(df['ds'], df['max_within_quantile'], df['min_within_quantile'], alpha=0.2) + markers = ['o', 's', '^', 'D', 'v', '*', 'p', 'h', 'H', '+', 'x', 'd'] + random_marker = random.choice(markers) + for model in modelnames: + # for model in ['BiTCN','RNN']: + plt.plot(df['ds'][-horizon:], df[model][-horizon:], label=model,marker=random_marker) + # plt.plot(df_combined3['ds'], df_combined3['min_abs_error_rate_prediction'], label='最小绝对误差', linestyle='--', color='orange') + # 网格 + plt.grid(True) + # 显示历史值 + for i, j in zip(df['ds'], df['y']): + plt.text(i, j, str(j), ha='center', va='bottom') + + # for model in most_model: + # plt.plot(df['ds'], df[model], label=model,marker='o') + # 当前日期画竖虚线 + plt.axvline(x=df['ds'].iloc[-horizon], color='r', linestyle='--') + plt.legend() + plt.xlabel('日期') + plt.ylabel('价格') + + plt.savefig(os.path.join(dataset,'历史价格-预测值.png'), bbox_inches='tight') + plt.close() + + def _plt_predict_table(df): + # 预测值表格 + fig, ax = plt.subplots(figsize=(20, 6)) + ax.axis('off') # 关闭坐标轴 + # 数值保留2位小数 + df = df.round(2) + df = df[-horizon:] + df['Day'] = [f'Day_{i}' for i in range(1,horizon+1)] + # Day列放到最前面 + df = df[['Day'] + list(df.columns[:-1])] + table = ax.table(cellText=df.values, colLabels=df.columns, loc='center') + #加宽表格 + table.auto_set_font_size(False) + table.set_fontsize(10) + + # 设置表格样式,列数据最小的用绿色标识 + plt.savefig(os.path.join(dataset,'预测值表格.png'), bbox_inches='tight') + plt.close() + + def _plt_model_results3(): + # 可视化评估结果 + plt.rcParams['font.sans-serif'] = ['SimHei'] + fig, ax = plt.subplots(figsize=(20, 10)) + ax.axis('off') # 关闭坐标轴 + table = ax.table(cellText=model_results3.values, colLabels=model_results3.columns, loc='center') + # 加宽表格 + table.auto_set_font_size(False) + table.set_fontsize(10) + + # 设置表格样式,列数据最小的用绿色标识 + plt.savefig(os.path.join(dataset,'模型评估.png'), bbox_inches='tight') + plt.close() + + _plt_predict_ture(df_combined3) + _plt_predict_table(df_combined3) + _plt_model_results3() + + return model_results3 + # 聚烯烃计算预测评估指数 @exception_logger