聚烯烃周度预测调试完成
This commit is contained in:
		
							parent
							
								
									3b0011ceeb
								
							
						
					
					
						commit
						95da926b3a
					
				| @ -202,15 +202,15 @@ table_name = 'v_tbl_crude_oil_warning' | |||||||
| 
 | 
 | ||||||
| ### 开关 | ### 开关 | ||||||
| is_train = False # 是否训练 | is_train = False # 是否训练 | ||||||
| is_debug = False # 是否调试 | is_debug = True # 是否调试 | ||||||
| is_eta = False # 是否使用eta接口 | is_eta = True # 是否使用eta接口 | ||||||
| is_market = False # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | is_market = False # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | ||||||
| is_timefurture = True # 是否使用时间特征 | is_timefurture = True # 是否使用时间特征 | ||||||
| is_fivemodels = False # 是否使用之前保存的最佳的5个模型 | is_fivemodels = False # 是否使用之前保存的最佳的5个模型 | ||||||
| is_edbcode = False # 特征使用edbcoding列表中的 | is_edbcode = False # 特征使用edbcoding列表中的 | ||||||
| is_edbnamelist = False # 自定义特征,对应上面的edbnamelist | is_edbnamelist = False # 自定义特征,对应上面的edbnamelist | ||||||
| is_update_eta  = True  # 预测结果上传到eta | is_update_eta  = False  # 预测结果上传到eta | ||||||
| is_update_report = True # 是否上传报告 | is_update_report = False # 是否上传报告 | ||||||
| is_update_warning_data =  False # 是否上传预警数据 | is_update_warning_data =  False # 是否上传预警数据 | ||||||
| is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 | is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 | ||||||
| is_del_tow_month = True # 是否删除两个月不更新的特征 | is_del_tow_month = True # 是否删除两个月不更新的特征 | ||||||
| @ -224,9 +224,9 @@ print("数据库连接成功",host,dbname,dbusername) | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # 数据截取日期 | # 数据截取日期 | ||||||
| start_year = 2020 # 数据开始年份 | start_year = 2015 # 数据开始年份 | ||||||
| end_time = '2025-01-27' # 数据截取日期 | end_time = '' # 数据截取日期 | ||||||
| freq = 'W'  # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 | freq = 'WW'  # 时间频率,"D": 天 "W": 周"M": 月"Q": 季度"A": 年 "H": 小时 "T": 分钟 "S": 秒 "B": 工作日 | ||||||
| delweekenday = True if freq == 'B' else False # 是否删除周末数据 | delweekenday = True if freq == 'B' else False # 是否删除周末数据 | ||||||
| is_corr = False # 特征是否参与滞后领先提升相关系数 | is_corr = False # 特征是否参与滞后领先提升相关系数 | ||||||
| add_kdj = False # 是否添加kdj指标 | add_kdj = False # 是否添加kdj指标 | ||||||
| @ -243,8 +243,8 @@ avg_cols = [ | |||||||
| ] | ] | ||||||
| offsite = 80 | offsite = 80 | ||||||
| offsite_col = ['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'] | offsite_col = ['PP:拉丝:HP550J:市场价:青岛:金能化学(日)'] | ||||||
| horizon =1 # 预测的步长 | horizon =2 # 预测的步长 | ||||||
| input_size = 7  # 输入序列长度 | input_size = 14  # 输入序列长度 | ||||||
| train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数 | train_steps = 50 if is_debug else 1000 # 训练步数,用来限定epoch次数 | ||||||
| val_check_steps = 30  # 评估频率 | val_check_steps = 30  # 评估频率 | ||||||
| early_stop_patience_steps = 5 # 早停的耐心步数    | early_stop_patience_steps = 5 # 早停的耐心步数    | ||||||
| @ -263,10 +263,10 @@ weight_dict = [0.4,0.15,0.1,0.1,0.25] # 权重 | |||||||
| 
 | 
 | ||||||
| ### 文件 | ### 文件 | ||||||
| data_set = 'PP指标数据.xlsx'  # 数据集文件   | data_set = 'PP指标数据.xlsx'  # 数据集文件   | ||||||
| dataset = 'juxitingzhududataset' # 数据集文件夹 | dataset = 'juxitingzhoududataset' # 数据集文件夹 | ||||||
| 
 | 
 | ||||||
| # 数据库名称 | # 数据库名称 | ||||||
| db_name = os.path.join(dataset,'jbsh_juxiting.db') | db_name = os.path.join(dataset,'jbsh_juxiting_zhoudu.db') | ||||||
| sqlitedb = SQLiteHandler(db_name)  | sqlitedb = SQLiteHandler(db_name)  | ||||||
| sqlitedb.connect() | sqlitedb.connect() | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -1063,6 +1063,22 @@ def getdata_juxiting(filename, datecol='date', y='y', dataset='', add_kdj=False, | |||||||
| 
 | 
 | ||||||
|     return df, df_zhibiaoliebiao |     return df, df_zhibiaoliebiao | ||||||
| 
 | 
 | ||||||
|  | def getdata_zhoudu_juxiting(filename, datecol='date', y='y', dataset='', add_kdj=False, is_timefurture=False, end_time=''): | ||||||
|  |     config.logger.info('getdata接收:'+filename+' '+datecol+' '+end_time) | ||||||
|  |     # 判断后缀名 csv或excel | ||||||
|  |     if filename.endswith('.csv'): | ||||||
|  |         df = loadcsv(filename) | ||||||
|  |     else: | ||||||
|  |         # 读取excel 指标数据 | ||||||
|  |         df_zhibiaoshuju = pd.read_excel(filename, sheet_name='指标数据') | ||||||
|  |         df_zhibiaoliebiao = pd.read_excel(filename, sheet_name='指标列表') | ||||||
|  | 
 | ||||||
|  |     # 日期字符串转为datatime | ||||||
|  |     df = zhoududatachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol, y=y, dataset=dataset, | ||||||
|  |                             add_kdj=add_kdj, is_timefurture=is_timefurture, end_time=end_time) | ||||||
|  | 
 | ||||||
|  |     return df, df_zhibiaoliebiao | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def sanitize_filename(filename): | def sanitize_filename(filename): | ||||||
|     # 使用正则表达式替换不合规的字符 |     # 使用正则表达式替换不合规的字符 | ||||||
|  | |||||||
| @ -1,12 +1,80 @@ | |||||||
| # 读取配置 | # 读取配置 | ||||||
| from lib.dataread import * |  | ||||||
| from lib.tools import SendMail,exception_logger |  | ||||||
| from models.nerulforcastmodels import ex_Model_Juxiting,model_losss,model_losss_juxiting,brent_export_pdf,tansuanli_export_pdf,pp_export_pdf,model_losss_juxiting |  | ||||||
| 
 | 
 | ||||||
| import glob | from lib.dataread import * | ||||||
|  | from config_juxiting_zhoudu import * | ||||||
|  | from lib.tools import SendMail, exception_logger | ||||||
|  | from models.nerulforcastmodels import ex_Model, model_losss_juxiting, tansuanli_export_pdf, pp_export_pdf | ||||||
|  | import datetime | ||||||
| import torch | import torch | ||||||
| torch.set_float32_matmul_precision("high") | torch.set_float32_matmul_precision("high") | ||||||
| 
 | 
 | ||||||
|  | global_config.update({ | ||||||
|  |     # 核心参数 | ||||||
|  |     'logger': logger, | ||||||
|  |     'dataset': dataset, | ||||||
|  |     'y': y, | ||||||
|  |     'offsite_col': offsite_col, | ||||||
|  |     'avg_cols': avg_cols, | ||||||
|  |     'offsite': offsite, | ||||||
|  |     'edbcodenamedict': edbcodenamedict, | ||||||
|  |     'is_debug': is_debug, | ||||||
|  |     'is_train': is_train, | ||||||
|  |     'is_fivemodels': is_fivemodels, | ||||||
|  |     'settings': settings, | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     # 模型参数 | ||||||
|  |     'data_set': data_set, | ||||||
|  |     'input_size': input_size, | ||||||
|  |     'horizon': horizon, | ||||||
|  |     'train_steps': train_steps, | ||||||
|  |     'val_check_steps': val_check_steps, | ||||||
|  |     'val_size': val_size, | ||||||
|  |     'test_size': test_size, | ||||||
|  |     'modelsindex': modelsindex, | ||||||
|  |     'rote': rote, | ||||||
|  | 
 | ||||||
|  |     # 特征工程开关 | ||||||
|  |     'is_del_corr': is_del_corr, | ||||||
|  |     'is_del_tow_month': is_del_tow_month, | ||||||
|  |     'is_eta': is_eta, | ||||||
|  |     'is_update_eta': is_update_eta, | ||||||
|  |     'is_fivemodels': is_fivemodels, | ||||||
|  |     'early_stop_patience_steps': early_stop_patience_steps, | ||||||
|  | 
 | ||||||
|  |     # 时间参数 | ||||||
|  |     'start_year': start_year, | ||||||
|  |     'end_time': end_time or datetime.datetime.now().strftime("%Y-%m-%d"), | ||||||
|  |     'freq': freq,  # 保持列表结构 | ||||||
|  | 
 | ||||||
|  |     # 接口配置 | ||||||
|  |     'login_pushreport_url': login_pushreport_url, | ||||||
|  |     'login_data': login_data, | ||||||
|  |     'upload_url': upload_url, | ||||||
|  |     'upload_warning_url': upload_warning_url, | ||||||
|  |     'warning_data': warning_data, | ||||||
|  | 
 | ||||||
|  |     # 查询接口 | ||||||
|  |     'query_data_list_item_nos_url': query_data_list_item_nos_url, | ||||||
|  |     'query_data_list_item_nos_data': query_data_list_item_nos_data, | ||||||
|  | 
 | ||||||
|  |     # eta 配置 | ||||||
|  |     'APPID': APPID, | ||||||
|  |     'SECRET': SECRET, | ||||||
|  |     'etadata': data, | ||||||
|  |     'edbcodelist': edbcodelist, | ||||||
|  |     'ClassifyId': ClassifyId, | ||||||
|  |     'edbcodedataurl': edbcodedataurl, | ||||||
|  |     'classifyidlisturl': classifyidlisturl, | ||||||
|  |     'edbdatapushurl': edbdatapushurl, | ||||||
|  |     'edbdeleteurl': edbdeleteurl, | ||||||
|  |     'edbbusinessurl': edbbusinessurl, | ||||||
|  |     'ClassifyId':   ClassifyId, | ||||||
|  |     'classifylisturl': classifylisturl, | ||||||
|  | 
 | ||||||
|  |     # 数据库配置 | ||||||
|  |     'sqlitedb': sqlitedb, | ||||||
|  | }) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def predict_main(): | def predict_main(): | ||||||
| @ -48,31 +116,23 @@ def predict_main(): | |||||||
|     返回: |     返回: | ||||||
|         None |         None | ||||||
|     """ |     """ | ||||||
|     global end_time |     end_time = global_config['end_time'] | ||||||
|     signature = BinanceAPI(APPID, SECRET) |  | ||||||
|     etadata = EtaReader(signature=signature, |  | ||||||
|                         classifylisturl=classifylisturl, |  | ||||||
|                         classifyidlisturl=classifyidlisturl, |  | ||||||
|                         edbcodedataurl=edbcodedataurl, |  | ||||||
|                         edbcodelist=edbcodelist, |  | ||||||
|                         edbdatapushurl=edbdatapushurl, |  | ||||||
|                         edbdeleteurl=edbdeleteurl, |  | ||||||
|                         edbbusinessurl=edbbusinessurl |  | ||||||
|                         ) |  | ||||||
|     # 获取数据 |     # 获取数据 | ||||||
|     if is_eta: |     if is_eta: | ||||||
|         logger.info('从eta获取数据...') |         logger.info('从eta获取数据...') | ||||||
|         signature = BinanceAPI(APPID, SECRET) |         signature = BinanceAPI(APPID, SECRET) | ||||||
|         etadata = EtaReader(signature=signature, |         etadata = EtaReader(signature=signature, | ||||||
|                             classifylisturl=classifylisturl, |                             classifylisturl=global_config['classifylisturl'], | ||||||
|                             classifyidlisturl=classifyidlisturl, |                             classifyidlisturl=global_config['classifyidlisturl'], | ||||||
|                             edbcodedataurl=edbcodedataurl, |                             edbcodedataurl=global_config['edbcodedataurl'], | ||||||
|                             edbcodelist=edbcodelist, |                             edbcodelist=global_config['edbcodelist'], | ||||||
|                             edbdatapushurl=edbdatapushurl, |                             edbdatapushurl=global_config['edbdatapushurl'], | ||||||
|                             edbdeleteurl=edbdeleteurl, |                             edbdeleteurl=global_config['edbdeleteurl'], | ||||||
|                             edbbusinessurl=edbbusinessurl, |                             edbbusinessurl=global_config['edbbusinessurl'], | ||||||
|  |                             classifyId=global_config['ClassifyId'], | ||||||
|                             ) |                             ) | ||||||
|         df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data(data_set=data_set, dataset=dataset)  # 原始数据,未处理 |         df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_pp_data( | ||||||
|  |             data_set=data_set, dataset=dataset)  # 原始数据,未处理 | ||||||
| 
 | 
 | ||||||
|         if is_market: |         if is_market: | ||||||
|             logger.info('从市场信息平台获取数据...') |             logger.info('从市场信息平台获取数据...') | ||||||
| @ -83,26 +143,26 @@ def predict_main(): | |||||||
|                     df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) |                     df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) | ||||||
|                 else: |                 else: | ||||||
|                     logger.info('从市场信息平台获取数据') |                     logger.info('从市场信息平台获取数据') | ||||||
|                     df_zhibiaoshuju = get_market_data(end_time,df_zhibiaoshuju) |                     df_zhibiaoshuju = get_market_data( | ||||||
|                      |                         end_time, df_zhibiaoshuju) | ||||||
|             except : | 
 | ||||||
|  |             except: | ||||||
|                 logger.info('最高最低价拼接失败') |                 logger.info('最高最低价拼接失败') | ||||||
|          | 
 | ||||||
|         # 保存到xlsx文件的sheet表 |         # 保存到xlsx文件的sheet表 | ||||||
|         with pd.ExcelWriter(os.path.join(dataset,data_set)) as file: |         with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: | ||||||
|             df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) |             df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) | ||||||
|             df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) |             df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) | ||||||
|          | 
 | ||||||
|          |  | ||||||
|         # 数据处理 |         # 数据处理 | ||||||
|         df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, |         df = datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, y=global_config['y'], dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, | ||||||
|                         end_time=end_time) |                                 end_time=end_time) | ||||||
| 
 | 
 | ||||||
|     else: |     else: | ||||||
|         # 读取数据 |         # 读取数据 | ||||||
|         logger.info('读取本地数据:' + os.path.join(dataset, data_set)) |         logger.info('读取本地数据:' + os.path.join(dataset, data_set)) | ||||||
|         df,df_zhibiaoliebiao = getdata_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, |         df, df_zhibiaoliebiao = getdata_zhoudu_juxiting(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, | ||||||
|                      is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理 |                                                  is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理 | ||||||
| 
 | 
 | ||||||
|     # 更改预测列名称 |     # 更改预测列名称 | ||||||
|     df.rename(columns={y: 'y'}, inplace=True) |     df.rename(columns={y: 'y'}, inplace=True) | ||||||
| @ -124,47 +184,65 @@ def predict_main(): | |||||||
|     else: |     else: | ||||||
|         for row in first_row.itertuples(index=False): |         for row in first_row.itertuples(index=False): | ||||||
|             row_dict = row._asdict() |             row_dict = row._asdict() | ||||||
|             row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') |             config.logger.info(f'要保存的真实值:{row_dict}') | ||||||
|             check_query = sqlitedb.select_data('trueandpredict', where_condition=f"ds = '{row.ds}'") |             # 判断ds是否为字符串类型,如果不是则转换为字符串类型 | ||||||
|  |             if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)): | ||||||
|  |                 row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') | ||||||
|  |             elif not isinstance(row_dict['ds'], str): | ||||||
|  |                 try: | ||||||
|  |                     row_dict['ds'] = pd.to_datetime( | ||||||
|  |                         row_dict['ds']).strftime('%Y-%m-%d') | ||||||
|  |                 except: | ||||||
|  |                     logger.warning(f"无法解析的时间格式: {row_dict['ds']}") | ||||||
|  |             # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') | ||||||
|  |             # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') | ||||||
|  |             check_query = sqlitedb.select_data( | ||||||
|  |                 'trueandpredict', where_condition=f"ds = '{row.ds}'") | ||||||
|             if len(check_query) > 0: |             if len(check_query) > 0: | ||||||
|                 set_clause = ", ".join([f"{key} = '{value}'" for key, value in row_dict.items()]) |                 set_clause = ", ".join( | ||||||
|                 sqlitedb.update_data('trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") |                     [f"{key} = '{value}'" for key, value in row_dict.items()]) | ||||||
|  |                 sqlitedb.update_data( | ||||||
|  |                     'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") | ||||||
|                 continue |                 continue | ||||||
|             sqlitedb.insert_data('trueandpredict', tuple(row_dict.values()), columns=row_dict.keys()) |             sqlitedb.insert_data('trueandpredict', tuple( | ||||||
|  |                 row_dict.values()), columns=row_dict.keys()) | ||||||
| 
 | 
 | ||||||
|     # 更新accuracy表的y值 |     # 更新accuracy表的y值 | ||||||
|     if not sqlitedb.check_table_exists('accuracy'): |     if not sqlitedb.check_table_exists('accuracy'): | ||||||
|         pass |         pass | ||||||
|     else: |     else: | ||||||
|         update_y = sqlitedb.select_data('accuracy',where_condition="y is null") |         update_y = sqlitedb.select_data( | ||||||
|  |             'accuracy', where_condition="y is null") | ||||||
|         if len(update_y) > 0: |         if len(update_y) > 0: | ||||||
|             logger.info('更新accuracy表的y值') |             logger.info('更新accuracy表的y值') | ||||||
|             # 找到update_y 中ds且df中的y的行 |             # 找到update_y 中ds且df中的y的行 | ||||||
|             update_y = update_y[update_y['ds']<=end_time] |             update_y = update_y[update_y['ds'] <= end_time] | ||||||
|             logger.info(f'要更新y的信息:{update_y}') |             logger.info(f'要更新y的信息:{update_y}') | ||||||
|             # try: |             # try: | ||||||
|             for row in update_y.itertuples(index=False): |             for row in update_y.itertuples(index=False): | ||||||
|                 try: |                 try: | ||||||
|                     row_dict = row._asdict()  	 |                     row_dict = row._asdict() | ||||||
|                     yy = df[df['ds']==row_dict['ds']]['y'].values[0] |                     yy = df[df['ds'] == row_dict['ds']]['y'].values[0] | ||||||
|                     LOW = df[df['ds']==row_dict['ds']]['Brentzdj'].values[0] |                     LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] | ||||||
|                     HIGH = df[df['ds']==row_dict['ds']]['Brentzgj'].values[0] |                     HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] | ||||||
|                     sqlitedb.update_data('accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") |                     sqlitedb.update_data( | ||||||
|  |                         'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") | ||||||
|                 except: |                 except: | ||||||
|                     logger.info(f'更新accuracy表的y值失败:{row_dict}') |                     logger.info(f'更新accuracy表的y值失败:{row_dict}') | ||||||
|             # except Exception as e: |             # except Exception as e: | ||||||
|             #     logger.info(f'更新accuracy表的y值失败:{e}') |             #     logger.info(f'更新accuracy表的y值失败:{e}') | ||||||
| 
 | 
 | ||||||
|     import datetime |  | ||||||
|     # 判断当前日期是不是周一 |     # 判断当前日期是不是周一 | ||||||
|     is_weekday = datetime.datetime.now().weekday() == 0 |     is_weekday = datetime.datetime.now().weekday() == 0 | ||||||
|     if is_weekday: |     if is_weekday: | ||||||
|         logger.info('今天是周一,更新预测模型') |         logger.info('今天是周一,更新预测模型') | ||||||
|         # 计算最近60天预测残差最低的模型名称 |         # 计算最近60天预测残差最低的模型名称 | ||||||
|         model_results = sqlitedb.select_data('trueandpredict', order_by="ds DESC", limit="60") |         model_results = sqlitedb.select_data( | ||||||
|  |             'trueandpredict', order_by="ds DESC", limit="60") | ||||||
|         # 删除空值率为90%以上的列 |         # 删除空值率为90%以上的列 | ||||||
|         if len(model_results) > 10: |         if len(model_results) > 10: | ||||||
|             model_results = model_results.dropna(thresh=len(model_results)*0.1,axis=1) |             model_results = model_results.dropna( | ||||||
|  |                 thresh=len(model_results)*0.1, axis=1) | ||||||
|         # 删除空行 |         # 删除空行 | ||||||
|         model_results = model_results.dropna() |         model_results = model_results.dropna() | ||||||
|         modelnames = model_results.columns.to_list()[2:-1] |         modelnames = model_results.columns.to_list()[2:-1] | ||||||
| @ -172,47 +250,59 @@ def predict_main(): | |||||||
|             model_results[col] = model_results[col].astype(np.float32) |             model_results[col] = model_results[col].astype(np.float32) | ||||||
|         # 计算每个预测值与真实值之间的偏差率 |         # 计算每个预测值与真实值之间的偏差率 | ||||||
|         for model in modelnames: |         for model in modelnames: | ||||||
|             model_results[f'{model}_abs_error_rate'] = abs(model_results['y'] - model_results[model]) / model_results['y'] |             model_results[f'{model}_abs_error_rate'] = abs( | ||||||
|  |                 model_results['y'] - model_results[model]) / model_results['y'] | ||||||
|         # 获取每行对应的最小偏差率值 |         # 获取每行对应的最小偏差率值 | ||||||
|         min_abs_error_rate_values = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) |         min_abs_error_rate_values = model_results.apply( | ||||||
|  |             lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) | ||||||
|         # 获取每行对应的最小偏差率值对应的列名 |         # 获取每行对应的最小偏差率值对应的列名 | ||||||
|         min_abs_error_rate_column_name = model_results.apply(lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) |         min_abs_error_rate_column_name = model_results.apply( | ||||||
|  |             lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) | ||||||
|         # 将列名索引转换为列名 |         # 将列名索引转换为列名 | ||||||
|         min_abs_error_rate_column_name = min_abs_error_rate_column_name.map(lambda x: x.split('_')[0]) |         min_abs_error_rate_column_name = min_abs_error_rate_column_name.map( | ||||||
|  |             lambda x: x.split('_')[0]) | ||||||
|         # 取出现次数最多的模型名称 |         # 取出现次数最多的模型名称 | ||||||
|         most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() |         most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() | ||||||
|         logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") |         logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") | ||||||
|         # 保存结果到数据库 |         # 保存结果到数据库 | ||||||
|         if not sqlitedb.check_table_exists('most_model'): |         if not sqlitedb.check_table_exists('most_model'): | ||||||
|             sqlitedb.create_table('most_model', columns="ds datetime, most_common_model TEXT") |             sqlitedb.create_table( | ||||||
|         sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) |                 'most_model', columns="ds datetime, most_common_model TEXT") | ||||||
|  |         sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( | ||||||
|  |             '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         if is_weekday: |         if is_weekday: | ||||||
|         # if True: |             # if True: | ||||||
|             logger.info('今天是周一,发送特征预警') |             logger.info('今天是周一,发送特征预警') | ||||||
|             # 上传预警信息到数据库 |             # 上传预警信息到数据库 | ||||||
|             warning_data_df = df_zhibiaoliebiao.copy() |             warning_data_df = df_zhibiaoliebiao.copy() | ||||||
|             warning_data_df = warning_data_df[warning_data_df['停更周期']> 3 ][['指标名称', '指标id', '频度','更新周期','指标来源','最后更新时间','停更周期']] |             warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ | ||||||
|  |                 '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] | ||||||
|             # 重命名列名 |             # 重命名列名 | ||||||
|             warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) |             warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', | ||||||
|  |                                                      '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) | ||||||
|             from sqlalchemy import create_engine |             from sqlalchemy import create_engine | ||||||
|             import urllib |             import urllib | ||||||
|             global password |             global password | ||||||
|             if '@' in password: |             if '@' in password: | ||||||
|                 password = urllib.parse.quote_plus(password) |                 password = urllib.parse.quote_plus(password) | ||||||
| 
 | 
 | ||||||
|             engine = create_engine(f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') |             engine = create_engine( | ||||||
|             warning_data_df['WARNING_DATE'] =  datetime.date.today().strftime("%Y-%m-%d %H:%M:%S") |                 f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') | ||||||
|             warning_data_df['TENANT_CODE'] =  'T0004' |             warning_data_df['WARNING_DATE'] = datetime.date.today().strftime( | ||||||
|  |                 "%Y-%m-%d %H:%M:%S") | ||||||
|  |             warning_data_df['TENANT_CODE'] = 'T0004' | ||||||
|             # 插入数据之前查询表数据然后新增id列 |             # 插入数据之前查询表数据然后新增id列 | ||||||
|             existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine) |             existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine) | ||||||
|             if not existing_data.empty: |             if not existing_data.empty: | ||||||
|                 max_id = existing_data['ID'].astype(int).max() |                 max_id = existing_data['ID'].astype(int).max() | ||||||
|                 warning_data_df['ID'] = range(max_id + 1, max_id + 1 + len(warning_data_df)) |                 warning_data_df['ID'] = range( | ||||||
|  |                     max_id + 1, max_id + 1 + len(warning_data_df)) | ||||||
|             else: |             else: | ||||||
|                 warning_data_df['ID'] = range(1, 1 + len(warning_data_df)) |                 warning_data_df['ID'] = range(1, 1 + len(warning_data_df)) | ||||||
|             warning_data_df.to_sql(table_name,  con=engine, if_exists='append', index=False) |             warning_data_df.to_sql( | ||||||
|  |                 table_name,  con=engine, if_exists='append', index=False) | ||||||
|             if is_update_warning_data: |             if is_update_warning_data: | ||||||
|                 upload_warning_info(len(warning_data_df)) |                 upload_warning_info(len(warning_data_df)) | ||||||
|     except: |     except: | ||||||
| @ -226,76 +316,75 @@ def predict_main(): | |||||||
|     row, col = df.shape |     row, col = df.shape | ||||||
| 
 | 
 | ||||||
|     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') |     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') | ||||||
|     ex_Model_Juxiting(df, |     ex_Model(df, | ||||||
|              horizon=horizon, |              horizon=global_config['horizon'], | ||||||
|              input_size=input_size, |              input_size=global_config['input_size'], | ||||||
|              train_steps=train_steps, |              train_steps=global_config['train_steps'], | ||||||
|              val_check_steps=val_check_steps, |              val_check_steps=global_config['val_check_steps'], | ||||||
|              early_stop_patience_steps=early_stop_patience_steps, |              early_stop_patience_steps=global_config['early_stop_patience_steps'], | ||||||
|              is_debug=is_debug, |              is_debug=global_config['is_debug'], | ||||||
|              dataset=dataset, |              dataset=global_config['dataset'], | ||||||
|              is_train=is_train, |              is_train=global_config['is_train'], | ||||||
|              is_fivemodels=is_fivemodels, |              is_fivemodels=global_config['is_fivemodels'], | ||||||
|              val_size=val_size, |              val_size=global_config['val_size'], | ||||||
|              test_size=test_size, |              test_size=global_config['test_size'], | ||||||
|              settings=settings, |              settings=global_config['settings'], | ||||||
|              now=now, |              now=now, | ||||||
|              etadata=etadata, |              etadata=global_config['etadata'], | ||||||
|              modelsindex=modelsindex, |              modelsindex=global_config['modelsindex'], | ||||||
|              data=data, |              data=data, | ||||||
|              is_eta=is_eta, |              is_eta=global_config['is_eta'], | ||||||
|              end_time=end_time, |              end_time=global_config['end_time'], | ||||||
|              ) |              ) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     logger.info('模型训练完成') |     logger.info('模型训练完成') | ||||||
|   | 
 | ||||||
|     logger.info('训练数据绘图ing') |     logger.info('训练数据绘图ing') | ||||||
|     model_results3 = model_losss_juxiting(sqlitedb) |     model_results3 = model_losss_juxiting(sqlitedb, end_time=global_config['end_time'],is_fivemodels=global_config['is_fivemodels']) | ||||||
|     logger.info('训练数据绘图end') |     logger.info('训练数据绘图end') | ||||||
|      | 
 | ||||||
|     # 模型报告 |     # # 模型报告 | ||||||
|     logger.info('制作报告ing') |     logger.info('制作报告ing') | ||||||
|     title = f'{settings}--{end_time}-预测报告' # 报告标题 |     title = f'{settings}--{end_time}-预测报告'  # 报告标题 | ||||||
|     reportname = f'PP大模型预测报告--{end_time}.pdf' # 报告文件名 |     reportname = f'Brent原油大模型周度预测--{end_time}.pdf'  # 报告文件名 | ||||||
|     reportname = reportname.replace(':', '-') # 替换冒号 |     reportname = reportname.replace(':', '-')  # 替换冒号 | ||||||
|     pp_export_pdf(dataset=dataset,num_models = 5 if is_fivemodels else 22,time=end_time, |     pp_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, | ||||||
|                 reportname=reportname,sqlitedb=sqlitedb), |                   reportname=reportname, sqlitedb=sqlitedb), | ||||||
| 
 | 
 | ||||||
|     logger.info('制作报告end') |     logger.info('制作报告end') | ||||||
|     logger.info('模型训练完成') |     logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
|     # # LSTM 单变量模型 |     # # LSTM 单变量模型 | ||||||
|     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) |     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) | ||||||
|      | 
 | ||||||
|     # # lstm 多变量模型 |     # # lstm 多变量模型 | ||||||
|     # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset) |     # ex_Lstm_M(df,n_days=input_size,out_days=horizon,is_debug=is_debug,datasetpath=dataset) | ||||||
|      | 
 | ||||||
|     # # GRU 模型 |     # # GRU 模型 | ||||||
|     # # ex_GRU(df) |     # # ex_GRU(df) | ||||||
| 
 | 
 | ||||||
|     # 发送邮件 |     # 发送邮件 | ||||||
|     m = SendMail( |     # m = SendMail( | ||||||
|         username=username, |     #     username=username, | ||||||
|         passwd=passwd, |     #     passwd=passwd, | ||||||
|         recv=recv, |     #     recv=recv, | ||||||
|         title=title, |     #     title=title, | ||||||
|         content=content, |     #     content=content, | ||||||
|         file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), |     #     file=max(glob.glob(os.path.join(dataset,'*.pdf')), key=os.path.getctime), | ||||||
|         ssl=ssl, |     #     ssl=ssl, | ||||||
|     ) |     # ) | ||||||
|     # m.send_mail()    |     # m.send_mail() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # global end_time |     # global end_time | ||||||
|     # is_on = True |     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 | ||||||
|     # # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 |     # for i_time in pd.date_range('2022-1-1', '2025-3-26', freq='M'): | ||||||
|     # for i_time in pd.date_range('2025-1-20', '2025-2-6', freq='B'): |  | ||||||
|     #     end_time = i_time.strftime('%Y-%m-%d') |  | ||||||
|     #     try: |     #     try: | ||||||
|  |     #         global_config['end_time'] = i_time.strftime('%Y-%m-%d') | ||||||
|     #         predict_main() |     #         predict_main() | ||||||
|     #     except: |     #     except Exception as e: | ||||||
|     #         pass |     #         logger.info(f'预测失败:{e}') | ||||||
|  |     #         continue | ||||||
| 
 | 
 | ||||||
|     predict_main() |     predict_main() | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user