八大维度自有指标
This commit is contained in:
		
							parent
							
								
									05bfeebcb0
								
							
						
					
					
						commit
						8074647329
					
				| @ -89,8 +89,7 @@ data = { | |||||||
| ClassifyId = 1214 | ClassifyId = 1214 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | # 变量定义--线上环境 | ||||||
| ################################################################################################################  变量定义--线上环境 |  | ||||||
| # server_host = '10.200.32.39' | # server_host = '10.200.32.39' | ||||||
| # login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" | # login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" | ||||||
| # upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" | # upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" | ||||||
| @ -111,7 +110,6 @@ ClassifyId = 1214 | |||||||
| # } | # } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| # upload_data = { | # upload_data = { | ||||||
| #     "funcModule":'研究报告信息', | #     "funcModule":'研究报告信息', | ||||||
| #     "funcOperation":'上传原油价格预测报告', | #     "funcOperation":'上传原油价格预测报告', | ||||||
| @ -151,7 +149,6 @@ ClassifyId = 1214 | |||||||
| # } | # } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| # push_data_value_list_data = { | # push_data_value_list_data = { | ||||||
| #     "funcModule": "数据表信息列表", | #     "funcModule": "数据表信息列表", | ||||||
| #     "funcOperation": "新增", | #     "funcOperation": "新增", | ||||||
| @ -186,9 +183,6 @@ ClassifyId = 1214 | |||||||
| # } | # } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # # 生产环境数据库 | # # 生产环境数据库 | ||||||
| # host = 'rm-2zehj3r1n60ttz9x5.mysql.rds.aliyuncs.com' | # host = 'rm-2zehj3r1n60ttz9x5.mysql.rds.aliyuncs.com' | ||||||
| # port = 3306 | # port = 3306 | ||||||
| @ -198,8 +192,6 @@ ClassifyId = 1214 | |||||||
| # table_name = 'v_tbl_crude_oil_warning' | # table_name = 'v_tbl_crude_oil_warning' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # # 变量定义--测试环境 | # # 变量定义--测试环境 | ||||||
| server_host = '192.168.100.53:8080'  # 内网 | server_host = '192.168.100.53:8080'  # 内网 | ||||||
| # server_host = '183.242.74.28'  # 外网 | # server_host = '183.242.74.28'  # 外网 | ||||||
| @ -307,7 +299,7 @@ table_name = 'v_tbl_crude_oil_warning' | |||||||
| # 开关 | # 开关 | ||||||
| is_train = True  # 是否训练 | is_train = True  # 是否训练 | ||||||
| is_debug = False  # 是否调试 | is_debug = False  # 是否调试 | ||||||
| is_eta = False  # 是否使用eta接口 | is_eta = True  # 是否使用eta接口 | ||||||
| is_market = True  # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | is_market = True  # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | ||||||
| is_timefurture = True  # 是否使用时间特征 | is_timefurture = True  # 是否使用时间特征 | ||||||
| is_fivemodels = False  # 是否使用之前保存的最佳的5个模型 | is_fivemodels = False  # 是否使用之前保存的最佳的5个模型 | ||||||
| @ -415,4 +407,3 @@ logger.addHandler(file_handler) | |||||||
| logger.addHandler(console_handler) | logger.addHandler(console_handler) | ||||||
| 
 | 
 | ||||||
| # logger.info('当前配置:'+settings) | # logger.info('当前配置:'+settings) | ||||||
| 
 |  | ||||||
|  | |||||||
| @ -89,8 +89,7 @@ data = { | |||||||
| ClassifyId = 1214 | ClassifyId = 1214 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | # 变量定义--线上环境 | ||||||
| ################################################################################################################  变量定义--线上环境 |  | ||||||
| # server_host = '10.200.32.39' | # server_host = '10.200.32.39' | ||||||
| # login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" | # login_pushreport_url = "http://10.200.32.39/jingbo-api/api/server/login" | ||||||
| # upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" | # upload_url = "http://10.200.32.39/jingbo-api/api/analysis/reportInfo/researchUploadReportSave" | ||||||
| @ -111,7 +110,6 @@ ClassifyId = 1214 | |||||||
| # } | # } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| # upload_data = { | # upload_data = { | ||||||
| #     "funcModule":'研究报告信息', | #     "funcModule":'研究报告信息', | ||||||
| #     "funcOperation":'上传原油价格预测报告', | #     "funcOperation":'上传原油价格预测报告', | ||||||
| @ -151,7 +149,6 @@ ClassifyId = 1214 | |||||||
| # } | # } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| # push_data_value_list_data = { | # push_data_value_list_data = { | ||||||
| #     "funcModule": "数据表信息列表", | #     "funcModule": "数据表信息列表", | ||||||
| #     "funcOperation": "新增", | #     "funcOperation": "新增", | ||||||
| @ -195,9 +192,6 @@ ClassifyId = 1214 | |||||||
| # table_name = 'v_tbl_crude_oil_warning' | # table_name = 'v_tbl_crude_oil_warning' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # # 变量定义--测试环境 | # # 变量定义--测试环境 | ||||||
| server_host = '192.168.100.53:8080'  # 内网 | server_host = '192.168.100.53:8080'  # 内网 | ||||||
| # server_host = '183.242.74.28'  # 外网 | # server_host = '183.242.74.28'  # 外网 | ||||||
| @ -304,7 +298,7 @@ table_name = 'v_tbl_crude_oil_warning' | |||||||
| # 开关 | # 开关 | ||||||
| is_train = True  # 是否训练 | is_train = True  # 是否训练 | ||||||
| is_debug = False  # 是否调试 | is_debug = False  # 是否调试 | ||||||
| is_eta = True  # 是否使用eta接口 | is_eta = False  # 是否使用eta接口 | ||||||
| is_market = True  # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | is_market = True  # 是否通过市场信息平台获取特征 ,在is_eta 为true 的情况下生效 | ||||||
| is_timefurture = True  # 是否使用时间特征 | is_timefurture = True  # 是否使用时间特征 | ||||||
| is_fivemodels = False  # 是否使用之前保存的最佳的5个模型 | is_fivemodels = False  # 是否使用之前保存的最佳的5个模型 | ||||||
|  | |||||||
| @ -265,7 +265,7 @@ def upload_report_data(token, upload_data): | |||||||
|     config.logger.info(f"token:{token}") |     config.logger.info(f"token:{token}") | ||||||
| 
 | 
 | ||||||
|     # 打印日志,显示要上传的报告数据 |     # 打印日志,显示要上传的报告数据 | ||||||
|     config.logger.info(f"upload_data:{upload_data}") |     # config.logger.info(f"upload_data:{upload_data}") | ||||||
| 
 | 
 | ||||||
|     # 发送POST请求,上传报告数据 |     # 发送POST请求,上传报告数据 | ||||||
|     upload_res = requests.post( |     upload_res = requests.post( | ||||||
| @ -275,7 +275,7 @@ def upload_report_data(token, upload_data): | |||||||
|     upload_res = json.loads(upload_res.text) |     upload_res = json.loads(upload_res.text) | ||||||
| 
 | 
 | ||||||
|     # 打印日志,显示响应内容 |     # 打印日志,显示响应内容 | ||||||
|     config.logger.info(upload_res) |     # config.logger.info(upload_res) | ||||||
| 
 | 
 | ||||||
|     # 如果上传成功,返回响应对象 |     # 如果上传成功,返回响应对象 | ||||||
|     if upload_res: |     if upload_res: | ||||||
| @ -808,7 +808,8 @@ def datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_time='', y | |||||||
|             # 判断对应的 'ds' 是否大于 start_date |             # 判断对应的 'ds' 是否大于 start_date | ||||||
|             if df.loc[first_valid_index, 'ds'] > start_date: |             if df.loc[first_valid_index, 'ds'] > start_date: | ||||||
|                 df.drop(columns=[col], inplace=True) |                 df.drop(columns=[col], inplace=True) | ||||||
|                 config.logger.info(f'删除开始时间没有数据的列:{col},第一条数据日期为:{df.loc[first_valid_index, "ds"]}') |                 config.logger.info( | ||||||
|  |                     f'删除开始时间没有数据的列:{col},第一条数据日期为:{df.loc[first_valid_index, "ds"]}') | ||||||
| 
 | 
 | ||||||
|     config.logger.info(f'删除开始时间没有数据的列后数据量:{df.shape}') |     config.logger.info(f'删除开始时间没有数据的列后数据量:{df.shape}') | ||||||
| 
 | 
 | ||||||
| @ -1200,6 +1201,7 @@ class Config: | |||||||
|     @property |     @property | ||||||
|     def warning_data(self): return global_config['warning_data'] |     def warning_data(self): return global_config['warning_data'] | ||||||
|     # 查询接口 |     # 查询接口 | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def query_data_list_item_nos_url( |     def query_data_list_item_nos_url( | ||||||
|         self): return global_config['query_data_list_item_nos_url'] |         self): return global_config['query_data_list_item_nos_url'] | ||||||
| @ -1219,8 +1221,8 @@ class Config: | |||||||
|     @property |     @property | ||||||
|     def bdwd_items(self): return global_config['bdwd_items'] |     def bdwd_items(self): return global_config['bdwd_items'] | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     # 字段映射 |     # 字段映射 | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def offsite_col(self): return global_config['offsite_col'] |     def offsite_col(self): return global_config['offsite_col'] | ||||||
|     @property |     @property | ||||||
| @ -1247,7 +1249,6 @@ class Config: | |||||||
|     def is_bdwd(self): return global_config['is_bdwd'] |     def is_bdwd(self): return global_config['is_bdwd'] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| config = Config() | config = Config() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -2303,7 +2304,6 @@ def get_baichuan_data(baichuanidnamedict): | |||||||
|     return df1 |     return df1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| def get_bdwd_predict_data(): | def get_bdwd_predict_data(): | ||||||
|     # 获取认证头部信息 |     # 获取认证头部信息 | ||||||
|     token = get_head_auth_report() |     token = get_head_auth_report() | ||||||
| @ -2315,21 +2315,20 @@ def get_bdwd_predict_data(): | |||||||
|     config.logger.info("获取八大维度数据...") |     config.logger.info("获取八大维度数据...") | ||||||
| 
 | 
 | ||||||
|     # 打印日志,显示上传的URL |     # 打印日志,显示上传的URL | ||||||
|     config.logger.info(f"query_data_list_item_nos_url:{config.query_data_list_item_nos_url}") |     config.logger.info( | ||||||
|  |         f"query_data_list_item_nos_url:{config.query_data_list_item_nos_url}") | ||||||
| 
 | 
 | ||||||
|     # 打印日志,显示认证头部信息 |     # 打印日志,显示认证头部信息 | ||||||
|     config.logger.info(f"token:{token}") |     config.logger.info(f"token:{token}") | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     # 打印日志,显示要查询的数据项 |     # 打印日志,显示要查询的数据项 | ||||||
|     config.logger.info(f"query_data_list_item_nos_data:{query_data_list_item_nos_data}") |     config.logger.info( | ||||||
|  |         f"query_data_list_item_nos_data:{query_data_list_item_nos_data}") | ||||||
| 
 | 
 | ||||||
|     # 发送POST请求,上传预警数据 |     # 发送POST请求,上传预警数据 | ||||||
|     respose = requests.post( |     respose = requests.post( | ||||||
|         url=config.upload_warning_url, headers=headers, json=query_data_list_item_nos_data, timeout=(3, 15)) |         url=config.upload_warning_url, headers=headers, json=query_data_list_item_nos_data, timeout=(3, 15)) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     # 如果上传成功,返回响应对象 |     # 如果上传成功,返回响应对象 | ||||||
|     if respose: |     if respose: | ||||||
|         # 处理返回结果为df |         # 处理返回结果为df | ||||||
| @ -2381,11 +2380,13 @@ def get_bdwd_predict_data(): | |||||||
|     df2['date'] = df2['date'].dt.strftime('%Y-%m-%d') |     df2['date'] = df2['date'].dt.strftime('%Y-%m-%d') | ||||||
|     # df = pd.merge(df, df2, how='left', on='date') |     # df = pd.merge(df, df2, how='left', on='date') | ||||||
|     # 更改列名: |     # 更改列名: | ||||||
|     df2.rename(columns={'yyycbdwdbz':'本周','yyycbdwdcey':'次二月','yyycbdwdcr':'次日','yyycbdwdcsiy':'次四月','yyycbdwdcsy':'次三月','yyycbdwdcy':'次月','yyycbdwdcz':'次周','yyycbdwdgz':'隔周',}, inplace=True) |     df2.rename(columns={'yyycbdwdbz': '本周', 'yyycbdwdcey': '次二月', 'yyycbdwdcr': '次日', 'yyycbdwdcsiy': '次四月', | ||||||
|  |                'yyycbdwdcsy': '次三月', 'yyycbdwdcy': '次月', 'yyycbdwdcz': '次周', 'yyycbdwdgz': '隔周', }, inplace=True) | ||||||
|     # df2.rename(columns={'原油大数据预测|FORECAST|PRICE|W':'本周','原油大数据预测|FORECAST|PRICE|M_2':'次二月','原油大数据预测|FORECAST|PRICE|T':'次日','原油大数据预测|FORECAST|PRICE|M_4':'次四月','原油大数据预测|FORECAST|PRICE|M_3':'次三月','原油大数据预测|FORECAST|PRICE|M_1':'次月','原油大数据预测|FORECAST|PRICE|W_1':'次周','原油大数据预测|FORECAST|PRICE|W_2':'隔周',}, inplace=True) |     # df2.rename(columns={'原油大数据预测|FORECAST|PRICE|W':'本周','原油大数据预测|FORECAST|PRICE|M_2':'次二月','原油大数据预测|FORECAST|PRICE|T':'次日','原油大数据预测|FORECAST|PRICE|M_4':'次四月','原油大数据预测|FORECAST|PRICE|M_3':'次三月','原油大数据预测|FORECAST|PRICE|M_1':'次月','原油大数据预测|FORECAST|PRICE|W_1':'次周','原油大数据预测|FORECAST|PRICE|W_2':'隔周',}, inplace=True) | ||||||
|     # 更改显示顺序 |     # 更改显示顺序 | ||||||
|     # 过滤掉不存在的列 |     # 过滤掉不存在的列 | ||||||
|     desired_columns = ['date','次日','本周','次周','隔周','次月','次二月','次三月','次四月'] |     desired_columns = ['date', '次日', '本周', | ||||||
|  |                        '次周', '隔周', '次月', '次二月', '次三月', '次四月'] | ||||||
|     existing_columns = [col for col in desired_columns if col in df2.columns] |     existing_columns = [col for col in desired_columns if col in df2.columns] | ||||||
| 
 | 
 | ||||||
|     # 更改显示顺序 |     # 更改显示顺序 | ||||||
|  | |||||||
| @ -45,7 +45,7 @@ global_config.update({ | |||||||
| 
 | 
 | ||||||
|     # 时间参数 |     # 时间参数 | ||||||
|     'start_year': start_year, |     'start_year': start_year, | ||||||
|     'end_time': end_time , |     'end_time': end_time, | ||||||
|     'freq': freq,  # 保持列表结构 |     'freq': freq,  # 保持列表结构 | ||||||
| 
 | 
 | ||||||
|     # 接口配置 |     # 接口配置 | ||||||
| @ -59,7 +59,7 @@ global_config.update({ | |||||||
|     'query_data_list_item_nos_url': query_data_list_item_nos_url, |     'query_data_list_item_nos_url': query_data_list_item_nos_url, | ||||||
|     'query_data_list_item_nos_data': query_data_list_item_nos_data, |     'query_data_list_item_nos_data': query_data_list_item_nos_data, | ||||||
| 
 | 
 | ||||||
|      # 上传数据项 |     # 上传数据项 | ||||||
|     'push_data_value_list_url': push_data_value_list_url, |     'push_data_value_list_url': push_data_value_list_url, | ||||||
|     'push_data_value_list_data': push_data_value_list_data, |     'push_data_value_list_data': push_data_value_list_data, | ||||||
| 
 | 
 | ||||||
| @ -82,9 +82,6 @@ global_config.update({ | |||||||
| }) | }) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def push_market_value(): | def push_market_value(): | ||||||
|     config.logger.info('发送预测结果到市场信息平台') |     config.logger.info('发送预测结果到市场信息平台') | ||||||
|     # 读取预测数据和模型评估数据 |     # 读取预测数据和模型评估数据 | ||||||
| @ -121,13 +118,13 @@ def push_market_value(): | |||||||
|     predictdata = [ |     predictdata = [ | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['ciri'], |             "dataItemNo": global_config['bdwd_items']['ciri'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": first_mean |             "dataValue": first_mean | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['benzhou'], |             "dataItemNo": global_config['bdwd_items']['benzhou'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": last_mean |             "dataValue": last_mean | ||||||
|         } |         } | ||||||
| @ -142,8 +139,6 @@ def push_market_value(): | |||||||
|         config.logger.error(f"推送数据失败: {e}") |         config.logger.error(f"推送数据失败: {e}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def predict_main(): | def predict_main(): | ||||||
|     """ |     """ | ||||||
|     主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。 |     主预测函数,用于从 ETA 获取数据、处理数据、训练模型并进行预测。 | ||||||
| @ -385,25 +380,25 @@ def predict_main(): | |||||||
| 
 | 
 | ||||||
|     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') |     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') | ||||||
|     ex_Model_Juxiting(df, |     ex_Model_Juxiting(df, | ||||||
|              horizon=global_config['horizon'], |                       horizon=global_config['horizon'], | ||||||
|              input_size=global_config['input_size'], |                       input_size=global_config['input_size'], | ||||||
|              train_steps=global_config['train_steps'], |                       train_steps=global_config['train_steps'], | ||||||
|              val_check_steps=global_config['val_check_steps'], |                       val_check_steps=global_config['val_check_steps'], | ||||||
|              early_stop_patience_steps=global_config['early_stop_patience_steps'], |                       early_stop_patience_steps=global_config['early_stop_patience_steps'], | ||||||
|              is_debug=global_config['is_debug'], |                       is_debug=global_config['is_debug'], | ||||||
|              dataset=global_config['dataset'], |                       dataset=global_config['dataset'], | ||||||
|              is_train=global_config['is_train'], |                       is_train=global_config['is_train'], | ||||||
|              is_fivemodels=global_config['is_fivemodels'], |                       is_fivemodels=global_config['is_fivemodels'], | ||||||
|              val_size=global_config['val_size'], |                       val_size=global_config['val_size'], | ||||||
|              test_size=global_config['test_size'], |                       test_size=global_config['test_size'], | ||||||
|              settings=global_config['settings'], |                       settings=global_config['settings'], | ||||||
|              now=now, |                       now=now, | ||||||
|              etadata=etadata, |                       etadata=etadata, | ||||||
|              modelsindex=global_config['modelsindex'], |                       modelsindex=global_config['modelsindex'], | ||||||
|              data=data, |                       data=data, | ||||||
|              is_eta=global_config['is_eta'], |                       is_eta=global_config['is_eta'], | ||||||
|              end_time=global_config['end_time'], |                       end_time=global_config['end_time'], | ||||||
|              ) |                       ) | ||||||
| 
 | 
 | ||||||
|     logger.info('模型训练完成') |     logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
| @ -450,15 +445,14 @@ def predict_main(): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # global end_time |     # global end_time | ||||||
|     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 |     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 | ||||||
|     for i_time in pd.date_range('2025-4-7', '2025-4-8', freq='B'): |     # for i_time in pd.date_range('2025-4-14', '2025-4-15', freq='B'): | ||||||
|         try: |     #     try: | ||||||
|             global_config['end_time'] = i_time.strftime('%Y-%m-%d') |     #         global_config['end_time'] = i_time.strftime('%Y-%m-%d') | ||||||
|             predict_main() |     #         predict_main() | ||||||
|         except Exception as e: |     #     except Exception as e: | ||||||
|             logger.info(f'预测失败:{e}') |     #         logger.info(f'预测失败:{e}') | ||||||
|             continue |     #         continue | ||||||
| 
 | 
 | ||||||
|     # predict_main() |     predict_main() | ||||||
| 
 | 
 | ||||||
|     # push_market_value() |     # push_market_value() | ||||||
| 
 |  | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ | |||||||
| from lib.dataread import * | from lib.dataread import * | ||||||
| from config_jingbo_yuedu import * | from config_jingbo_yuedu import * | ||||||
| from lib.tools import SendMail, exception_logger | from lib.tools import SendMail, exception_logger | ||||||
| from models.nerulforcastmodels import ex_Model, model_losss,brent_export_pdf | from models.nerulforcastmodels import ex_Model, model_losss, brent_export_pdf | ||||||
| import datetime | import datetime | ||||||
| import torch | import torch | ||||||
| torch.set_float32_matmul_precision("high") | torch.set_float32_matmul_precision("high") | ||||||
| @ -121,25 +121,25 @@ def push_market_value(): | |||||||
|     predictdata = [ |     predictdata = [ | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['ciyue'], |             "dataItemNo": global_config['bdwd_items']['ciyue'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": ciyue_mean |             "dataValue": ciyue_mean | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['cieryue'], |             "dataItemNo": global_config['bdwd_items']['cieryue'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": cieryue_mean |             "dataValue": cieryue_mean | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['cisanyue'], |             "dataItemNo": global_config['bdwd_items']['cisanyue'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": cisanyue_mean |             "dataValue": cisanyue_mean | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['cisiyue'], |             "dataItemNo": global_config['bdwd_items']['cisiyue'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": cisieryue_mean |             "dataValue": cisieryue_mean | ||||||
|         } |         } | ||||||
| @ -415,7 +415,7 @@ def predict_main(): | |||||||
|              end_time=global_config['end_time'], |              end_time=global_config['end_time'], | ||||||
|              ) |              ) | ||||||
| 
 | 
 | ||||||
|     logger.info('模型训练完成') |     # logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
|     logger.info('训练数据绘图ing') |     logger.info('训练数据绘图ing') | ||||||
|     model_results3 = model_losss(sqlitedb, end_time=end_time) |     model_results3 = model_losss(sqlitedb, end_time=end_time) | ||||||
| @ -423,21 +423,19 @@ def predict_main(): | |||||||
| 
 | 
 | ||||||
|     push_market_value() |     push_market_value() | ||||||
| 
 | 
 | ||||||
|     # 模型报告 |     # # 模型报告 | ||||||
|     logger.info('制作报告ing') |     # logger.info('制作报告ing') | ||||||
|     title = f'{settings}--{end_time}-预测报告'  # 报告标题 |     # title = f'{settings}--{end_time}-预测报告'  # 报告标题 | ||||||
|     reportname = f'Brent原油大模型月度预测--{end_time}.pdf'  # 报告文件名 |     # reportname = f'Brent原油大模型月度预测--{end_time}.pdf'  # 报告文件名 | ||||||
|     reportname = reportname.replace(':', '-')  # 替换冒号 |     # reportname = reportname.replace(':', '-')  # 替换冒号 | ||||||
|     brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, |     # brent_export_pdf(dataset=dataset, num_models=5 if is_fivemodels else 22, time=end_time, | ||||||
|                         reportname=reportname, |     #                     reportname=reportname, | ||||||
|                         inputsize = global_config['horizon'], |     #                     inputsize = global_config['horizon'], | ||||||
|                         sqlitedb=sqlitedb |     #                     sqlitedb=sqlitedb | ||||||
|                       ), |     #                   ), | ||||||
| 
 |  | ||||||
|     logger.info('制作报告end') |  | ||||||
|     logger.info('模型训练完成') |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
|  |     # logger.info('制作报告end') | ||||||
|  |     # logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
|     # # LSTM 单变量模型 |     # # LSTM 单变量模型 | ||||||
|     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) |     # ex_Lstm(df,input_seq_len=input_size,output_seq_len=horizon,is_debug=is_debug,dataset=dataset) | ||||||
| @ -464,7 +462,7 @@ def predict_main(): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # global end_time |     # global end_time | ||||||
|     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 |     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 | ||||||
|     for i_time in pd.date_range('2025-3-29', '2025-4-8', freq='B'): |     for i_time in pd.date_range('2025-3-13', '2025-3-31', freq='B'): | ||||||
|         try: |         try: | ||||||
|             global_config['end_time'] = i_time.strftime('%Y-%m-%d') |             global_config['end_time'] = i_time.strftime('%Y-%m-%d') | ||||||
|             predict_main() |             predict_main() | ||||||
| @ -473,4 +471,3 @@ if __name__ == '__main__': | |||||||
|             continue |             continue | ||||||
| 
 | 
 | ||||||
|     # predict_main() |     # predict_main() | ||||||
|      |  | ||||||
|  | |||||||
| @ -116,13 +116,13 @@ def push_market_value(): | |||||||
|     predictdata = [ |     predictdata = [ | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['cizhou'], |             "dataItemNo": global_config['bdwd_items']['cizhou'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": first_mean |             "dataValue": first_mean | ||||||
|         }, |         }, | ||||||
|         { |         { | ||||||
|             "dataItemNo": global_config['bdwd_items']['gezhou'], |             "dataItemNo": global_config['bdwd_items']['gezhou'], | ||||||
|             "dataDate": global_config['end_time'].replace('-',''), |             "dataDate": global_config['end_time'].replace('-', ''), | ||||||
|             "dataStatus": "add", |             "dataStatus": "add", | ||||||
|             "dataValue": last_mean |             "dataValue": last_mean | ||||||
|         } |         } | ||||||
| @ -176,234 +176,234 @@ def predict_main(): | |||||||
|     返回: |     返回: | ||||||
|         None |         None | ||||||
|     """ |     """ | ||||||
|     # end_time = global_config['end_time'] |     end_time = global_config['end_time'] | ||||||
| 
 | 
 | ||||||
|     # signature = BinanceAPI(APPID, SECRET) |     signature = BinanceAPI(APPID, SECRET) | ||||||
|     # etadata = EtaReader(signature=signature, |     etadata = EtaReader(signature=signature, | ||||||
|     #                     classifylisturl=global_config['classifylisturl'], |                         classifylisturl=global_config['classifylisturl'], | ||||||
|     #                     classifyidlisturl=global_config['classifyidlisturl'], |                         classifyidlisturl=global_config['classifyidlisturl'], | ||||||
|     #                     edbcodedataurl=global_config['edbcodedataurl'], |                         edbcodedataurl=global_config['edbcodedataurl'], | ||||||
|     #                     edbcodelist=global_config['edbcodelist'], |                         edbcodelist=global_config['edbcodelist'], | ||||||
|     #                     edbdatapushurl=global_config['edbdatapushurl'], |                         edbdatapushurl=global_config['edbdatapushurl'], | ||||||
|     #                     edbdeleteurl=global_config['edbdeleteurl'], |                         edbdeleteurl=global_config['edbdeleteurl'], | ||||||
|     #                     edbbusinessurl=global_config['edbbusinessurl'], |                         edbbusinessurl=global_config['edbbusinessurl'], | ||||||
|     #                     classifyId=global_config['ClassifyId'], |                         classifyId=global_config['ClassifyId'], | ||||||
|     #                     ) |                         ) | ||||||
|     # # 获取数据 |     # 获取数据 | ||||||
|     # if is_eta: |     if is_eta: | ||||||
|     #     logger.info('从eta获取数据...') |         logger.info('从eta获取数据...') | ||||||
| 
 | 
 | ||||||
|     #     df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data( |         df_zhibiaoshuju, df_zhibiaoliebiao = etadata.get_eta_api_yuanyou_data( | ||||||
|     #         data_set=data_set, dataset=dataset)  # 原始数据,未处理 |             data_set=data_set, dataset=dataset)  # 原始数据,未处理 | ||||||
| 
 | 
 | ||||||
|     #     if is_market: |         if is_market: | ||||||
|     #         logger.info('从市场信息平台获取数据...') |             logger.info('从市场信息平台获取数据...') | ||||||
|     #         try: |             try: | ||||||
|     #             # 如果是测试环境,最高价最低价取excel文档 |                 # 如果是测试环境,最高价最低价取excel文档 | ||||||
|     #             if server_host == '192.168.100.53': |                 if server_host == '192.168.100.53': | ||||||
|     #                 logger.info('从excel文档获取最高价最低价') |                     logger.info('从excel文档获取最高价最低价') | ||||||
|     #                 df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) |                     df_zhibiaoshuju = get_high_low_data(df_zhibiaoshuju) | ||||||
|     #             else: |                 else: | ||||||
|     #                 logger.info('从市场信息平台获取数据') |                     logger.info('从市场信息平台获取数据') | ||||||
|     #                 df_zhibiaoshuju = get_market_data( |                     df_zhibiaoshuju = get_market_data( | ||||||
|     #                     end_time, df_zhibiaoshuju) |                         end_time, df_zhibiaoshuju) | ||||||
| 
 | 
 | ||||||
|     #         except: |             except: | ||||||
|     #             logger.info('最高最低价拼接失败') |                 logger.info('最高最低价拼接失败') | ||||||
| 
 | 
 | ||||||
|     #     # 保存到xlsx文件的sheet表 |         # 保存到xlsx文件的sheet表 | ||||||
|     #     with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: |         with pd.ExcelWriter(os.path.join(dataset, data_set)) as file: | ||||||
|     #         df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) |             df_zhibiaoshuju.to_excel(file, sheet_name='指标数据', index=False) | ||||||
|     #         df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) |             df_zhibiaoliebiao.to_excel(file, sheet_name='指标列表', index=False) | ||||||
| 
 | 
 | ||||||
|     #     # 数据处理 |         # 数据处理 | ||||||
|     #     df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, |         df = datachuli(df_zhibiaoshuju, df_zhibiaoliebiao, y=y, dataset=dataset, add_kdj=add_kdj, is_timefurture=is_timefurture, | ||||||
|     #                    end_time=end_time) |                        end_time=end_time) | ||||||
| 
 | 
 | ||||||
|     # else: |     else: | ||||||
|     #     # 读取数据 |         # 读取数据 | ||||||
|     #     logger.info('读取本地数据:' + os.path.join(dataset, data_set)) |         logger.info('读取本地数据:' + os.path.join(dataset, data_set)) | ||||||
|     #     df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, |         df, df_zhibiaoliebiao = getdata(filename=os.path.join(dataset, data_set), y=y, dataset=dataset, add_kdj=add_kdj, | ||||||
|     #                                     is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理 |                                         is_timefurture=is_timefurture, end_time=end_time)  # 原始数据,未处理 | ||||||
| 
 | 
 | ||||||
|     # # 更改预测列名称 |     # 更改预测列名称 | ||||||
|     # df.rename(columns={y: 'y'}, inplace=True) |     df.rename(columns={y: 'y'}, inplace=True) | ||||||
| 
 | 
 | ||||||
|     # if is_edbnamelist: |     if is_edbnamelist: | ||||||
|     #     df = df[edbnamelist] |         df = df[edbnamelist] | ||||||
|     # df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False) |     df.to_csv(os.path.join(dataset, '指标数据.csv'), index=False) | ||||||
|     # # 保存最新日期的y值到数据库 |     # 保存最新日期的y值到数据库 | ||||||
|     # # 取第一行数据存储到数据库中 |     # 取第一行数据存储到数据库中 | ||||||
|     # first_row = df[['ds', 'y']].tail(1) |     first_row = df[['ds', 'y']].tail(1) | ||||||
|     # # 判断y的类型是否为float |     # 判断y的类型是否为float | ||||||
|     # if not isinstance(first_row['y'].values[0], float): |     if not isinstance(first_row['y'].values[0], float): | ||||||
|     #     logger.info(f'{end_time}预测目标数据为空,跳过') |         logger.info(f'{end_time}预测目标数据为空,跳过') | ||||||
|     #     return None |         return None | ||||||
| 
 | 
 | ||||||
|     # # 将最新真实值保存到数据库 |     # 将最新真实值保存到数据库 | ||||||
|     # if not sqlitedb.check_table_exists('trueandpredict'): |     if not sqlitedb.check_table_exists('trueandpredict'): | ||||||
|     #     first_row.to_sql('trueandpredict', sqlitedb.connection, index=False) |         first_row.to_sql('trueandpredict', sqlitedb.connection, index=False) | ||||||
|     # else: |     else: | ||||||
|     #     for row in first_row.itertuples(index=False): |         for row in first_row.itertuples(index=False): | ||||||
|     #         row_dict = row._asdict() |             row_dict = row._asdict() | ||||||
|     #         config.logger.info(f'要保存的真实值:{row_dict}') |             config.logger.info(f'要保存的真实值:{row_dict}') | ||||||
|     #         # 判断ds是否为字符串类型,如果不是则转换为字符串类型 |             # 判断ds是否为字符串类型,如果不是则转换为字符串类型 | ||||||
|     #         if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)): |             if isinstance(row_dict['ds'], (pd.Timestamp, datetime.datetime)): | ||||||
|     #             row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') |                 row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') | ||||||
|     #         elif not isinstance(row_dict['ds'], str): |             elif not isinstance(row_dict['ds'], str): | ||||||
|     #             try: |                 try: | ||||||
|     #                 row_dict['ds'] = pd.to_datetime( |                     row_dict['ds'] = pd.to_datetime( | ||||||
|     #                     row_dict['ds']).strftime('%Y-%m-%d') |                         row_dict['ds']).strftime('%Y-%m-%d') | ||||||
|     #             except: |                 except: | ||||||
|     #                 logger.warning(f"无法解析的时间格式: {row_dict['ds']}") |                     logger.warning(f"无法解析的时间格式: {row_dict['ds']}") | ||||||
|     #         # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') |             # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d') | ||||||
|     #         # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') |             # row_dict['ds'] = row_dict['ds'].strftime('%Y-%m-%d %H:%M:%S') | ||||||
|     #         check_query = sqlitedb.select_data( |             check_query = sqlitedb.select_data( | ||||||
|     #             'trueandpredict', where_condition=f"ds = '{row.ds}'") |                 'trueandpredict', where_condition=f"ds = '{row.ds}'") | ||||||
|     #         if len(check_query) > 0: |             if len(check_query) > 0: | ||||||
|     #             set_clause = ", ".join( |                 set_clause = ", ".join( | ||||||
|     #                 [f"{key} = '{value}'" for key, value in row_dict.items()]) |                     [f"{key} = '{value}'" for key, value in row_dict.items()]) | ||||||
|     #             sqlitedb.update_data( |                 sqlitedb.update_data( | ||||||
|     #                 'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") |                     'trueandpredict', set_clause, where_condition=f"ds = '{row.ds}'") | ||||||
|     #             continue |                 continue | ||||||
|     #         sqlitedb.insert_data('trueandpredict', tuple( |             sqlitedb.insert_data('trueandpredict', tuple( | ||||||
|     #             row_dict.values()), columns=row_dict.keys()) |                 row_dict.values()), columns=row_dict.keys()) | ||||||
| 
 | 
 | ||||||
|     # # 更新accuracy表的y值 |     # 更新accuracy表的y值 | ||||||
|     # if not sqlitedb.check_table_exists('accuracy'): |     if not sqlitedb.check_table_exists('accuracy'): | ||||||
|     #     pass |         pass | ||||||
|     # else: |     else: | ||||||
|     #     update_y = sqlitedb.select_data( |         update_y = sqlitedb.select_data( | ||||||
|     #         'accuracy', where_condition="y is null") |             'accuracy', where_condition="y is null") | ||||||
|     #     if len(update_y) > 0: |         if len(update_y) > 0: | ||||||
|     #         logger.info('更新accuracy表的y值') |             logger.info('更新accuracy表的y值') | ||||||
|     #         # 找到update_y 中ds且df中的y的行 |             # 找到update_y 中ds且df中的y的行 | ||||||
|     #         update_y = update_y[update_y['ds'] <= end_time] |             update_y = update_y[update_y['ds'] <= end_time] | ||||||
|     #         logger.info(f'要更新y的信息:{update_y}') |             logger.info(f'要更新y的信息:{update_y}') | ||||||
|     #         # try: |             # try: | ||||||
|     #         for row in update_y.itertuples(index=False): |             for row in update_y.itertuples(index=False): | ||||||
|     #             try: |                 try: | ||||||
|     #                 row_dict = row._asdict() |                     row_dict = row._asdict() | ||||||
|     #                 yy = df[df['ds'] == row_dict['ds']]['y'].values[0] |                     yy = df[df['ds'] == row_dict['ds']]['y'].values[0] | ||||||
|     #                 LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] |                     LOW = df[df['ds'] == row_dict['ds']]['Brentzdj'].values[0] | ||||||
|     #                 HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] |                     HIGH = df[df['ds'] == row_dict['ds']]['Brentzgj'].values[0] | ||||||
|     #                 sqlitedb.update_data( |                     sqlitedb.update_data( | ||||||
|     #                     'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") |                         'accuracy', f"y = {yy},LOW_PRICE = {LOW},HIGH_PRICE = {HIGH}", where_condition=f"ds = '{row_dict['ds']}'") | ||||||
|     #             except: |                 except: | ||||||
|     #                 logger.info(f'更新accuracy表的y值失败:{row_dict}') |                     logger.info(f'更新accuracy表的y值失败:{row_dict}') | ||||||
|     #         # except Exception as e: |             # except Exception as e: | ||||||
|     #         #     logger.info(f'更新accuracy表的y值失败:{e}') |             #     logger.info(f'更新accuracy表的y值失败:{e}') | ||||||
| 
 | 
 | ||||||
|     # # 判断当前日期是不是周一 |     # 判断当前日期是不是周一 | ||||||
|     # is_weekday = datetime.datetime.now().weekday() == 0 |     is_weekday = datetime.datetime.now().weekday() == 0 | ||||||
|     # if is_weekday: |     if is_weekday: | ||||||
|     #     logger.info('今天是周一,更新预测模型') |         logger.info('今天是周一,更新预测模型') | ||||||
|     #     # 计算最近60天预测残差最低的模型名称 |         # 计算最近60天预测残差最低的模型名称 | ||||||
|     #     model_results = sqlitedb.select_data( |         model_results = sqlitedb.select_data( | ||||||
|     #         'trueandpredict', order_by="ds DESC", limit="60") |             'trueandpredict', order_by="ds DESC", limit="60") | ||||||
|     #     # 删除空值率为90%以上的列 |         # 删除空值率为90%以上的列 | ||||||
|     #     if len(model_results) > 10: |         if len(model_results) > 10: | ||||||
|     #         model_results = model_results.dropna( |             model_results = model_results.dropna( | ||||||
|     #             thresh=len(model_results)*0.1, axis=1) |                 thresh=len(model_results)*0.1, axis=1) | ||||||
|     #     # 删除空行 |         # 删除空行 | ||||||
|     #     model_results = model_results.dropna() |         model_results = model_results.dropna() | ||||||
|     #     modelnames = model_results.columns.to_list()[2:-2] |         modelnames = model_results.columns.to_list()[2:-2] | ||||||
|     #     for col in model_results[modelnames].select_dtypes(include=['object']).columns: |         for col in model_results[modelnames].select_dtypes(include=['object']).columns: | ||||||
|     #         model_results[col] = model_results[col].astype(np.float32) |             model_results[col] = model_results[col].astype(np.float32) | ||||||
|     #     # 计算每个预测值与真实值之间的偏差率 |         # 计算每个预测值与真实值之间的偏差率 | ||||||
|     #     for model in modelnames: |         for model in modelnames: | ||||||
|     #         model_results[f'{model}_abs_error_rate'] = abs( |             model_results[f'{model}_abs_error_rate'] = abs( | ||||||
|     #             model_results['y'] - model_results[model]) / model_results['y'] |                 model_results['y'] - model_results[model]) / model_results['y'] | ||||||
|     #     # 获取每行对应的最小偏差率值 |         # 获取每行对应的最小偏差率值 | ||||||
|     #     min_abs_error_rate_values = model_results.apply( |         min_abs_error_rate_values = model_results.apply( | ||||||
|     #         lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) |             lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].min(), axis=1) | ||||||
|     #     # 获取每行对应的最小偏差率值对应的列名 |         # 获取每行对应的最小偏差率值对应的列名 | ||||||
|     #     min_abs_error_rate_column_name = model_results.apply( |         min_abs_error_rate_column_name = model_results.apply( | ||||||
|     #         lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) |             lambda row: row[[f'{model}_abs_error_rate' for model in modelnames]].idxmin(), axis=1) | ||||||
|     #     # 将列名索引转换为列名 |         # 将列名索引转换为列名 | ||||||
|     #     min_abs_error_rate_column_name = min_abs_error_rate_column_name.map( |         min_abs_error_rate_column_name = min_abs_error_rate_column_name.map( | ||||||
|     #         lambda x: x.split('_')[0]) |             lambda x: x.split('_')[0]) | ||||||
|     #     # 取出现次数最多的模型名称 |         # 取出现次数最多的模型名称 | ||||||
|     #     most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() |         most_common_model = min_abs_error_rate_column_name.value_counts().idxmax() | ||||||
|     #     logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") |         logger.info(f"最近60天预测残差最低的模型名称:{most_common_model}") | ||||||
|     #     # 保存结果到数据库 |         # 保存结果到数据库 | ||||||
|     #     if not sqlitedb.check_table_exists('most_model'): |         if not sqlitedb.check_table_exists('most_model'): | ||||||
|     #         sqlitedb.create_table( |             sqlitedb.create_table( | ||||||
|     #             'most_model', columns="ds datetime, most_common_model TEXT") |                 'most_model', columns="ds datetime, most_common_model TEXT") | ||||||
|     #     sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( |         sqlitedb.insert_data('most_model', (datetime.datetime.now().strftime( | ||||||
|     #         '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) |             '%Y-%m-%d %H:%M:%S'), most_common_model,), columns=('ds', 'most_common_model',)) | ||||||
| 
 | 
 | ||||||
|     # try: |     try: | ||||||
|     #     if is_weekday: |         if is_weekday: | ||||||
|     #         # if True: |             # if True: | ||||||
|     #         logger.info('今天是周一,发送特征预警') |             logger.info('今天是周一,发送特征预警') | ||||||
|     #         # 上传预警信息到数据库 |             # 上传预警信息到数据库 | ||||||
|     #         warning_data_df = df_zhibiaoliebiao.copy() |             warning_data_df = df_zhibiaoliebiao.copy() | ||||||
|     #         warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ |             warning_data_df = warning_data_df[warning_data_df['停更周期'] > 3][[ | ||||||
|     #             '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] |                 '指标名称', '指标id', '频度', '更新周期', '指标来源', '最后更新时间', '停更周期']] | ||||||
|     #         # 重命名列名 |             # 重命名列名 | ||||||
|     #         warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', |             warning_data_df = warning_data_df.rename(columns={'指标名称': 'INDICATOR_NAME', '指标id': 'INDICATOR_ID', '频度': 'FREQUENCY', | ||||||
|     #                                                  '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) |                                                      '更新周期': 'UPDATE_FREQUENCY', '指标来源': 'DATA_SOURCE', '最后更新时间': 'LAST_UPDATE_DATE', '停更周期': 'UPDATE_SUSPENSION_CYCLE'}) | ||||||
|     #         from sqlalchemy import create_engine |             from sqlalchemy import create_engine | ||||||
|     #         import urllib |             import urllib | ||||||
|     #         global password |             global password | ||||||
|     #         if '@' in password: |             if '@' in password: | ||||||
|     #             password = urllib.parse.quote_plus(password) |                 password = urllib.parse.quote_plus(password) | ||||||
| 
 | 
 | ||||||
|     #         engine = create_engine( |             engine = create_engine( | ||||||
|     #             f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') |                 f'mysql+pymysql://{dbusername}:{password}@{host}:{port}/{dbname}') | ||||||
|     #         warning_data_df['WARNING_DATE'] = datetime.date.today().strftime( |             warning_data_df['WARNING_DATE'] = datetime.date.today().strftime( | ||||||
|     #             "%Y-%m-%d %H:%M:%S") |                 "%Y-%m-%d %H:%M:%S") | ||||||
|     #         warning_data_df['TENANT_CODE'] = 'T0004' |             warning_data_df['TENANT_CODE'] = 'T0004' | ||||||
|     #         # 插入数据之前查询表数据然后新增id列 |             # 插入数据之前查询表数据然后新增id列 | ||||||
|     #         existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine) |             existing_data = pd.read_sql(f"SELECT * FROM {table_name}", engine) | ||||||
|     #         if not existing_data.empty: |             if not existing_data.empty: | ||||||
|     #             max_id = existing_data['ID'].astype(int).max() |                 max_id = existing_data['ID'].astype(int).max() | ||||||
|     #             warning_data_df['ID'] = range( |                 warning_data_df['ID'] = range( | ||||||
|     #                 max_id + 1, max_id + 1 + len(warning_data_df)) |                     max_id + 1, max_id + 1 + len(warning_data_df)) | ||||||
|     #         else: |             else: | ||||||
|     #             warning_data_df['ID'] = range(1, 1 + len(warning_data_df)) |                 warning_data_df['ID'] = range(1, 1 + len(warning_data_df)) | ||||||
|     #         warning_data_df.to_sql( |             warning_data_df.to_sql( | ||||||
|     #             table_name,  con=engine, if_exists='append', index=False) |                 table_name,  con=engine, if_exists='append', index=False) | ||||||
|     #         if is_update_warning_data: |             if is_update_warning_data: | ||||||
|     #             upload_warning_info(len(warning_data_df)) |                 upload_warning_info(len(warning_data_df)) | ||||||
|     # except: |     except: | ||||||
|     #     logger.info('上传预警信息到数据库失败') |         logger.info('上传预警信息到数据库失败') | ||||||
| 
 | 
 | ||||||
|     # if is_corr: |     if is_corr: | ||||||
|     #     df = corr_feature(df=df) |         df = corr_feature(df=df) | ||||||
| 
 | 
 | ||||||
|     # df1 = df.copy()  # 备份一下,后面特征筛选完之后加入ds y 列用 |     df1 = df.copy()  # 备份一下,后面特征筛选完之后加入ds y 列用 | ||||||
|     # logger.info(f"开始训练模型...") |     logger.info(f"开始训练模型...") | ||||||
|     # row, col = df.shape |     row, col = df.shape | ||||||
| 
 | 
 | ||||||
|     # now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') |     now = datetime.datetime.now().strftime('%Y%m%d%H%M%S') | ||||||
|     # ex_Model(df, |     ex_Model(df, | ||||||
|     #          horizon=global_config['horizon'], |              horizon=global_config['horizon'], | ||||||
|     #          input_size=global_config['input_size'], |              input_size=global_config['input_size'], | ||||||
|     #          train_steps=global_config['train_steps'], |              train_steps=global_config['train_steps'], | ||||||
|     #          val_check_steps=global_config['val_check_steps'], |              val_check_steps=global_config['val_check_steps'], | ||||||
|     #          early_stop_patience_steps=global_config['early_stop_patience_steps'], |              early_stop_patience_steps=global_config['early_stop_patience_steps'], | ||||||
|     #          is_debug=global_config['is_debug'], |              is_debug=global_config['is_debug'], | ||||||
|     #          dataset=global_config['dataset'], |              dataset=global_config['dataset'], | ||||||
|     #          is_train=global_config['is_train'], |              is_train=global_config['is_train'], | ||||||
|     #          is_fivemodels=global_config['is_fivemodels'], |              is_fivemodels=global_config['is_fivemodels'], | ||||||
|     #          val_size=global_config['val_size'], |              val_size=global_config['val_size'], | ||||||
|     #          test_size=global_config['test_size'], |              test_size=global_config['test_size'], | ||||||
|     #          settings=global_config['settings'], |              settings=global_config['settings'], | ||||||
|     #          now=now, |              now=now, | ||||||
|     #          etadata=etadata, |              etadata=etadata, | ||||||
|     #          modelsindex=global_config['modelsindex'], |              modelsindex=global_config['modelsindex'], | ||||||
|     #          data=data, |              data=data, | ||||||
|     #          is_eta=global_config['is_eta'], |              is_eta=global_config['is_eta'], | ||||||
|     #          end_time=global_config['end_time'], |              end_time=global_config['end_time'], | ||||||
|     #          ) |              ) | ||||||
| 
 | 
 | ||||||
|     # logger.info('模型训练完成') |     logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
|     # logger.info('训练数据绘图ing') |     logger.info('训练数据绘图ing') | ||||||
|     # model_results3 = model_losss(sqlitedb, end_time=end_time) |     model_results3 = model_losss(sqlitedb, end_time=end_time) | ||||||
|     # logger.info('训练数据绘图end') |     logger.info('训练数据绘图end') | ||||||
| 
 | 
 | ||||||
|     # # 模型报告 |     # # 模型报告 | ||||||
|     logger.info('制作报告ing') |     logger.info('制作报告ing') | ||||||
| @ -411,17 +411,17 @@ def predict_main(): | |||||||
|     reportname = f'Brent原油大模型周度预测--{end_time}.pdf'  # 报告文件名 |     reportname = f'Brent原油大模型周度预测--{end_time}.pdf'  # 报告文件名 | ||||||
|     reportname = reportname.replace(':', '-')  # 替换冒号 |     reportname = reportname.replace(':', '-')  # 替换冒号 | ||||||
|     brent_export_pdf(dataset=dataset, |     brent_export_pdf(dataset=dataset, | ||||||
|                         num_models=5 if is_fivemodels else 22, |                      num_models=5 if is_fivemodels else 22, | ||||||
|                         time=end_time, |                      time=end_time, | ||||||
|                         reportname=reportname, |                      reportname=reportname, | ||||||
|                         inputsize = global_config['horizon'], |                      inputsize=global_config['horizon'], | ||||||
|                         sqlitedb=sqlitedb |                      sqlitedb=sqlitedb | ||||||
|                       ), |                      ), | ||||||
| 
 | 
 | ||||||
|     logger.info('制作报告end') |     logger.info('制作报告end') | ||||||
|     logger.info('模型训练完成') |     logger.info('模型训练完成') | ||||||
| 
 | 
 | ||||||
|     # push_market_value() |     push_market_value() | ||||||
| 
 | 
 | ||||||
|     # 发送邮件 |     # 发送邮件 | ||||||
|     # m = SendMail( |     # m = SendMail( | ||||||
| @ -439,12 +439,12 @@ def predict_main(): | |||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     # global end_time |     # global end_time | ||||||
|     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 |     # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 | ||||||
|     # for i_time in pd.date_range('2025-1-1', '2025-3-29', freq='B'): |     for i_time in pd.date_range('2025-2-1', '2025-3-31', freq='B'): | ||||||
|     #     try: |         try: | ||||||
|     #         global_config['end_time'] = i_time.strftime('%Y-%m-%d') |             global_config['end_time'] = i_time.strftime('%Y-%m-%d') | ||||||
|     #         predict_main() |             predict_main() | ||||||
|     #     except Exception as e: |         except Exception as e: | ||||||
|     #         logger.info(f'预测失败:{e}') |             logger.info(f'预测失败:{e}') | ||||||
|     #         continue |             continue | ||||||
| 
 | 
 | ||||||
|     predict_main() |     # predict_main() | ||||||
|  | |||||||
| @ -173,10 +173,14 @@ if __name__ == '__main__': | |||||||
|     # cal_time_series(df, 7) # 模型调用 |     # cal_time_series(df, 7) # 模型调用 | ||||||
|     # 数据测试2(从excel中读取): |     # 数据测试2(从excel中读取): | ||||||
|     path = r'D:\code\PriceForecast-svn\yuanyoudataset\指标数据.csv' |     path = r'D:\code\PriceForecast-svn\yuanyoudataset\指标数据.csv' | ||||||
|     y = 'Brent活跃合约' |     # y = 'Brent活跃合约' | ||||||
|  |     y = 'y' | ||||||
|     df = pd.read_csv(path) |     df = pd.read_csv(path) | ||||||
|     df.rename(columns={f'{y}': 'deal_data'}, inplace=True) |     print(df.columns) | ||||||
|     df = df[['ds', f'{y}']] |     # df.rename(columns={f'{y}': 'deal_data'}, inplace=True) | ||||||
|  |     df.rename(columns={'y': 'deal_data'}, inplace=True) | ||||||
|  |     # df = df[['ds', f'{y}']] | ||||||
|  |     df = df[['ds', 'deal_data']] | ||||||
|     print(df.tail()) |     print(df.tail()) | ||||||
|     df.set_index(['ds'], inplace=True)  # 设置索引 |     df.set_index(['ds'], inplace=True)  # 设置索引 | ||||||
|     cal_time_series(df, 7) # 模型调用 |     cal_time_series(df, 7) # 模型调用 | ||||||
| @ -1243,7 +1243,11 @@ def model_losss(sqlitedb, end_time): | |||||||
|         df4.to_sql("accuracy_rote",  con=sqlitedb.connection, |         df4.to_sql("accuracy_rote",  con=sqlitedb.connection, | ||||||
|                    if_exists='append', index=False) |                    if_exists='append', index=False) | ||||||
|     create_dates, ds_dates = get_week_date(end_time) |     create_dates, ds_dates = get_week_date(end_time) | ||||||
|     _get_accuracy_rate(df, create_dates, ds_dates) |     try: | ||||||
|  |         _get_accuracy_rate(df, create_dates, ds_dates) | ||||||
|  |     except Exception as e: | ||||||
|  |         config.logger.info(f'准确率计算错误{e}') | ||||||
|  |      | ||||||
| 
 | 
 | ||||||
|     def _add_abs_error_rate(): |     def _add_abs_error_rate(): | ||||||
|         # 计算每个预测值与真实值之间的偏差率 |         # 计算每个预测值与真实值之间的偏差率 | ||||||
| @ -2502,16 +2506,16 @@ def brent_export_pdf(num_indicators=475, num_models=21, num_dayindicator=202, in | |||||||
|         config.dataset, reportname), pagesize=letter) |         config.dataset, reportname), pagesize=letter) | ||||||
|     doc.build(content) |     doc.build(content) | ||||||
|     # pdf 上传到数字化信息平台 |     # pdf 上传到数字化信息平台 | ||||||
|     # try: |     try: | ||||||
|     #     if config.is_update_report: |         if config.is_update_report: | ||||||
|     #         with open(os.path.join(config.dataset, reportname), 'rb') as f: |             with open(os.path.join(config.dataset, reportname), 'rb') as f: | ||||||
|     #             base64_data = base64.b64encode(f.read()).decode('utf-8') |                 base64_data = base64.b64encode(f.read()).decode('utf-8') | ||||||
|     #             config.upload_data["data"]["fileBase64"] = base64_data |                 config.upload_data["data"]["fileBase64"] = base64_data | ||||||
|     #         config.upload_data["data"]["fileName"] = reportname |             config.upload_data["data"]["fileName"] = reportname | ||||||
|     #         token = get_head_auth_report() |             token = get_head_auth_report() | ||||||
|     #         upload_report_data(token, config.upload_data) |             upload_report_data(token, config.upload_data) | ||||||
|     # except TimeoutError as e: |     except TimeoutError as e: | ||||||
|     #     print(f"请求超时: {e}") |         print(f"请求超时: {e}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @exception_logger | @exception_logger | ||||||
|  | |||||||
							
								
								
									
										237
									
								
								test/pushdata.py
									
									
									
									
									
								
							
							
						
						
									
										237
									
								
								test/pushdata.py
									
									
									
									
									
								
							| @ -3,25 +3,163 @@ from config_jingbo import * | |||||||
| # from config_tansuanli import * | # from config_tansuanli import * | ||||||
| from lib.tools import * | from lib.tools import * | ||||||
| from lib.dataread import * | from lib.dataread import * | ||||||
| from models.nerulforcastmodels import ex_Model,model_losss,brent_export_pdf,tansuanli_export_pdf | from models.nerulforcastmodels import ex_Model, model_losss, brent_export_pdf, tansuanli_export_pdf | ||||||
| from models.lstmmodels import ex_Lstm_M,ex_Lstm | from models.lstmmodels import ex_Lstm_M, ex_Lstm | ||||||
| from models.grumodels import ex_GRU | from models.grumodels import ex_GRU | ||||||
| import glob | import glob | ||||||
| import torch | import torch | ||||||
| torch.set_float32_matmul_precision("high") | torch.set_float32_matmul_precision("high") | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | names = [ | ||||||
|  |     '价格预测NHITS模型-次周', | ||||||
|  |     '价格预测Informer模型-次周', | ||||||
|  |     '价格预测LSTM模型-次周', | ||||||
|  |     '价格预测iTransformer模型-次周', | ||||||
|  |     '价格预测TSMixer模型-次周', | ||||||
|  |     '价格预测TSMixerx模型-次周', | ||||||
|  |     '价格预测PatchTST模型-次周', | ||||||
|  |     '价格预测RNN模型-次周', | ||||||
|  |     '价格预测GRU模型-次周', | ||||||
|  |     '价格预测TCN模型-次周', | ||||||
|  |     '价格预测BiTCN模型-次周', | ||||||
|  |     '价格预测DilatedRNN模型-次周', | ||||||
|  |     '价格预测MLP模型-次周', | ||||||
|  |     '价格预测DLinear模型-次周', | ||||||
|  |     '价格预测NLinear模型-次周', | ||||||
|  |     '价格预测TFT模型-次周', | ||||||
|  |     '价格预测FEDformer模型-次周', | ||||||
|  |     '价格预测StemGNN模型-次周', | ||||||
|  |     '价格预测MLPMultivariate模型-次周', | ||||||
|  |     '价格预测TiDE模型-次周', | ||||||
|  |     '价格预测DeepNPTS模型-次周', | ||||||
|  |     '价格预测NBEATS模型-次周', | ||||||
|  |     '价格预测NHITS模型-隔周', | ||||||
|  |     '价格预测Informer模型-隔周', | ||||||
|  |     '价格预测LSTM模型-隔周', | ||||||
|  |     '价格预测iTransformer模型-隔周', | ||||||
|  |     '价格预测TSMixer模型-隔周', | ||||||
|  |     '价格预测TSMixerx模型-隔周', | ||||||
|  |     '价格预测PatchTST模型-隔周', | ||||||
|  |     '价格预测RNN模型-隔周', | ||||||
|  |     '价格预测GRU模型-隔周', | ||||||
|  |     '价格预测TCN模型-隔周', | ||||||
|  |     '价格预测BiTCN模型-隔周', | ||||||
|  |     '价格预测DilatedRNN模型-隔周', | ||||||
|  |     '价格预测MLP模型-隔周', | ||||||
|  |     '价格预测DLinear模型-隔周', | ||||||
|  |     '价格预测NLinear模型-隔周', | ||||||
|  |     '价格预测TFT模型-隔周', | ||||||
|  |     '价格预测FEDformer模型-隔周', | ||||||
|  |     '价格预测StemGNN模型-隔周', | ||||||
|  |     '价格预测MLPMultivariate模型-隔周', | ||||||
|  |     '价格预测TiDE模型-隔周', | ||||||
|  |     '价格预测DeepNPTS模型-隔周', | ||||||
|  |     '价格预测NBEATS模型-隔周', | ||||||
|  |     '价格预测NHITS模型-次月', | ||||||
|  |     '价格预测Informer模型-次月', | ||||||
|  |     '价格预测LSTM模型-次月', | ||||||
|  |     '价格预测iTransformer模型-次月', | ||||||
|  |     '价格预测TSMixer模型-次月', | ||||||
|  |     '价格预测TSMixerx模型-次月', | ||||||
|  |     '价格预测PatchTST模型-次月', | ||||||
|  |     '价格预测RNN模型-次月', | ||||||
|  |     '价格预测GRU模型-次月', | ||||||
|  |     '价格预测TCN模型-次月', | ||||||
|  |     '价格预测BiTCN模型-次月', | ||||||
|  |     '价格预测DilatedRNN模型-次月', | ||||||
|  |     '价格预测MLP模型-次月', | ||||||
|  |     '价格预测DLinear模型-次月', | ||||||
|  |     '价格预测NLinear模型-次月', | ||||||
|  |     '价格预测TFT模型-次月', | ||||||
|  |     '价格预测FEDformer模型-次月', | ||||||
|  |     '价格预测StemGNN模型-次月', | ||||||
|  |     '价格预测MLPMultivariate模型-次月', | ||||||
|  |     '价格预测TiDE模型-次月', | ||||||
|  |     '价格预测DeepNPTS模型-次月', | ||||||
|  |     '价格预测NBEATS模型-次月', | ||||||
|  |     '价格预测NHITS模型-次二月', | ||||||
|  |     '价格预测Informer模型-次二月', | ||||||
|  |     '价格预测LSTM模型-次二月', | ||||||
|  |     '价格预测iTransformer模型-次二月', | ||||||
|  |     '价格预测TSMixer模型-次二月', | ||||||
|  |     '价格预测TSMixerx模型-次二月', | ||||||
|  |     '价格预测PatchTST模型-次二月', | ||||||
|  |     '价格预测RNN模型-次二月', | ||||||
|  |     '价格预测GRU模型-次二月', | ||||||
|  |     '价格预测TCN模型-次二月', | ||||||
|  |     '价格预测BiTCN模型-次二月', | ||||||
|  |     '价格预测DilatedRNN模型-次二月', | ||||||
|  |     '价格预测MLP模型-次二月', | ||||||
|  |     '价格预测DLinear模型-次二月', | ||||||
|  |     '价格预测NLinear模型-次二月', | ||||||
|  |     '价格预测TFT模型-次二月', | ||||||
|  |     '价格预测FEDformer模型-次二月', | ||||||
|  |     '价格预测StemGNN模型-次二月', | ||||||
|  |     '价格预测MLPMultivariate模型-次二月', | ||||||
|  |     '价格预测TiDE模型-次二月', | ||||||
|  |     '价格预测DeepNPTS模型-次二月', | ||||||
|  |     '价格预测NBEATS模型-次二月', | ||||||
|  |     '价格预测NHITS模型-次三月', | ||||||
|  |     '价格预测Informer模型-次三月', | ||||||
|  |     '价格预测LSTM模型-次三月', | ||||||
|  |     '价格预测iTransformer模型-次三月', | ||||||
|  |     '价格预测TSMixer模型-次三月', | ||||||
|  |     '价格预测TSMixerx模型-次三月', | ||||||
|  |     '价格预测PatchTST模型-次三月', | ||||||
|  |     '价格预测RNN模型-次三月', | ||||||
|  |     '价格预测GRU模型-次三月', | ||||||
|  |     '价格预测TCN模型-次三月', | ||||||
|  |     '价格预测BiTCN模型-次三月', | ||||||
|  |     '价格预测DilatedRNN模型-次三月', | ||||||
|  |     '价格预测MLP模型-次三月', | ||||||
|  |     '价格预测DLinear模型-次三月', | ||||||
|  |     '价格预测NLinear模型-次三月', | ||||||
|  |     '价格预测TFT模型-次三月', | ||||||
|  |     '价格预测FEDformer模型-次三月', | ||||||
|  |     '价格预测StemGNN模型-次三月', | ||||||
|  |     '价格预测MLPMultivariate模型-次三月', | ||||||
|  |     '价格预测TiDE模型-次三月', | ||||||
|  |     '价格预测DeepNPTS模型-次三月', | ||||||
|  |     '价格预测NBEATS模型-次三月', | ||||||
|  |     '价格预测NHITS模型-次四月', | ||||||
|  |     '价格预测Informer模型-次四月', | ||||||
|  |     '价格预测LSTM模型-次四月', | ||||||
|  |     '价格预测iTransformer模型-次四月', | ||||||
|  |     '价格预测TSMixer模型-次四月', | ||||||
|  |     '价格预测TSMixerx模型-次四月', | ||||||
|  |     '价格预测PatchTST模型-次四月', | ||||||
|  |     '价格预测RNN模型-次四月', | ||||||
|  |     '价格预测GRU模型-次四月', | ||||||
|  |     '价格预测TCN模型-次四月', | ||||||
|  |     '价格预测BiTCN模型-次四月', | ||||||
|  |     '价格预测DilatedRNN模型-次四月', | ||||||
|  |     '价格预测MLP模型-次四月', | ||||||
|  |     '价格预测DLinear模型-次四月', | ||||||
|  |     '价格预测NLinear模型-次四月', | ||||||
|  |     '价格预测TFT模型-次四月', | ||||||
|  |     '价格预测FEDformer模型-次四月', | ||||||
|  |     '价格预测StemGNN模型-次四月', | ||||||
|  |     '价格预测MLPMultivariate模型-次四月', | ||||||
|  |     '价格预测TiDE模型-次四月', | ||||||
|  |     '价格预测DeepNPTS模型-次四月', | ||||||
|  |     '价格预测NBEATS模型-次四月', | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
| 
 | 
 | ||||||
|     signature = BinanceAPI(APPID, SECRET) |     signature = BinanceAPI(APPID, SECRET) | ||||||
|     etadata = EtaReader(signature=signature, |     etadata = EtaReader(signature=signature, | ||||||
|                             classifylisturl = classifylisturl, |                         classifylisturl=classifylisturl, | ||||||
|                             classifyidlisturl=classifyidlisturl, |                         classifyidlisturl=classifyidlisturl, | ||||||
|                             edbcodedataurl=edbcodedataurl, |                         edbcodedataurl=edbcodedataurl, | ||||||
|                             edbcodelist=edbcodelist, |                         edbcodelist=edbcodelist, | ||||||
|                             edbdatapushurl = edbdatapushurl, |                         edbdatapushurl=edbdatapushurl, | ||||||
|                             edbdeleteurl = edbdeleteurl, |                         edbdeleteurl=edbdeleteurl, | ||||||
|                             edbbusinessurl = edbbusinessurl |                         edbbusinessurl=edbbusinessurl, | ||||||
|                             ) |                         classifyId=ClassifyId, | ||||||
|  |                         ) | ||||||
| 
 | 
 | ||||||
|     models = [ |     models = [ | ||||||
|         'NHITS', |         'NHITS', | ||||||
| @ -48,28 +186,28 @@ if __name__ == '__main__': | |||||||
| 
 | 
 | ||||||
|     # eta自由数据指标编码 |     # eta自由数据指标编码 | ||||||
|     modelsindex = { |     modelsindex = { | ||||||
|             'NHITS': 'SELF0000001', |         'NHITS': 'SELF0000001', | ||||||
|             'Informer':'SELF0000057', |         'Informer': 'SELF0000057', | ||||||
|             'LSTM':'SELF0000058', |         'LSTM': 'SELF0000058', | ||||||
|             'iTransformer':'SELF0000059', |         'iTransformer': 'SELF0000059', | ||||||
|             'TSMixer':'SELF0000060', |         'TSMixer': 'SELF0000060', | ||||||
|             'TSMixerx':'SELF0000061', |         'TSMixerx': 'SELF0000061', | ||||||
|             'PatchTST':'SELF0000062', |         'PatchTST': 'SELF0000062', | ||||||
|             'RNN':'SELF0000063', |         'RNN': 'SELF0000063', | ||||||
|             'GRU':'SELF0000064', |         'GRU': 'SELF0000064', | ||||||
|             'TCN':'SELF0000065', |         'TCN': 'SELF0000065', | ||||||
|             'BiTCN':'SELF0000066', |         'BiTCN': 'SELF0000066', | ||||||
|             'DilatedRNN':'SELF0000067', |         'DilatedRNN': 'SELF0000067', | ||||||
|             'MLP':'SELF0000068', |         'MLP': 'SELF0000068', | ||||||
|             'DLinear':'SELF0000069', |         'DLinear': 'SELF0000069', | ||||||
|             'NLinear':'SELF0000070', |         'NLinear': 'SELF0000070', | ||||||
|             'TFT':'SELF0000071', |         'TFT': 'SELF0000071', | ||||||
|             'FEDformer':'SELF0000072', |         'FEDformer': 'SELF0000072', | ||||||
|             'StemGNN':'SELF0000073', |         'StemGNN': 'SELF0000073', | ||||||
|             'MLPMultivariate':'SELF0000074', |         'MLPMultivariate': 'SELF0000074', | ||||||
|             'TiDE':'SELF0000075', |         'TiDE': 'SELF0000075', | ||||||
|             'DeepNPT':'SELF0000076' |         'DeepNPT': 'SELF0000076' | ||||||
|         } |     } | ||||||
| 
 | 
 | ||||||
|     # df_predict = pd.read_csv('dataset/predict.csv',encoding='gbk') |     # df_predict = pd.read_csv('dataset/predict.csv',encoding='gbk') | ||||||
|     # # df_predict.rename(columns={'ds':'Date'},inplace=True) |     # # df_predict.rename(columns={'ds':'Date'},inplace=True) | ||||||
| @ -84,21 +222,32 @@ if __name__ == '__main__': | |||||||
|     #     # print(data['DataList']) |     #     # print(data['DataList']) | ||||||
|     #     etadata.push_data(data) |     #     etadata.push_data(data) | ||||||
| 
 | 
 | ||||||
|  |     # 新增eta自有指标 | ||||||
|  |     # list = [{'Date': '2025-04-21', 'Value': 100}] | ||||||
|  |     # for name in names: | ||||||
|  |     #     data['DataList'] = list | ||||||
|  |     #     data['IndexName'] = name | ||||||
|  |     #     data['Remark'] = name | ||||||
|  |     #     # print(data['DataList']) | ||||||
|  |     #     etadata.push_data(data) | ||||||
|  |     #     time.sleep(1) | ||||||
|  | 
 | ||||||
|     # 删除指标 |     # 删除指标 | ||||||
|     # IndexCodeList = ['SELF0000055'] |     # SELF0000098 | ||||||
|     # for i in range(1,57): |     # IndexCodeList = ['SELF0000098'] | ||||||
|     #     if i < 10 : i = f'0{i}' |     # # for i in range(1,57): | ||||||
|     #     IndexCodeList.append(f'SELF00000{i}') |     # #     if i < 10 : i = f'0{i}' | ||||||
|  |     # #     IndexCodeList.append(f'SELF00000{i}') | ||||||
|     # print(IndexCodeList) |     # print(IndexCodeList) | ||||||
|     # etadata.del_zhibiao(IndexCodeList) |     # etadata.del_zhibiao(IndexCodeList) | ||||||
| 
 | 
 | ||||||
|     # 删除特定日期的值 |     # 删除特定日期的值 | ||||||
|     indexcodelist = modelsindex.values() |     # indexcodelist = modelsindex.values() | ||||||
|     for indexcode in indexcodelist: |     # for indexcode in indexcodelist: | ||||||
|         data = { |     #     data = { | ||||||
|             "IndexCode": indexcode, #指标编码 |     #         "IndexCode": indexcode,  # 指标编码 | ||||||
|             "StartDate": "2020-04-20", #指标需要删除的开始日期(>=),如果开始日期和结束日期相等,那么就是删除该日期 |     #         "StartDate": "2020-04-20",  # 指标需要删除的开始日期(>=),如果开始日期和结束日期相等,那么就是删除该日期 | ||||||
|             "EndDate": "2024-05-28" #指标需要删除的结束日期(<=),如果开始日期和结束日期相等,那么就是删除该日期 |     #         "EndDate": "2024-05-28"  # 指标需要删除的结束日期(<=),如果开始日期和结束日期相等,那么就是删除该日期 | ||||||
|         } |     #     } | ||||||
| 
 | 
 | ||||||
|         # etadata.del_business(data) |     # etadata.del_business(data) | ||||||
|  | |||||||
| @ -9,14 +9,15 @@ import time | |||||||
| def run_predictions(target_date): | def run_predictions(target_date): | ||||||
|     """执行三个预测脚本""" |     """执行三个预测脚本""" | ||||||
|     scripts = [ |     scripts = [ | ||||||
|         "main_yuanyou.py", |         # "main_yuanyou.py", | ||||||
|         "main_yuanyou_zhoudu.py", |         "main_yuanyou_zhoudu.py", | ||||||
|         "main_yuanyou_yuedu.py" |         "main_yuanyou_yuedu.py" | ||||||
|     ] |     ] | ||||||
| 
 | 
 | ||||||
|     # 依次执行每个脚本 |     # 依次执行每个脚本 | ||||||
|     for script in scripts: |     for script in scripts: | ||||||
|         command = [r"C:\Users\Hello\.conda\envs\predict\python", script] |         # command = [r"C:\Users\Hello\.conda\envs\predict\python", script] | ||||||
|  |         command = [r"C:\Users\EDY\.conda\envs\predict\python", script] | ||||||
|         subprocess.run(command, check=True) |         subprocess.run(command, check=True) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -26,10 +27,10 @@ def is_weekday(date): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # start_date = datetime.date(2025, 3, 13) |     start_date = datetime.date(2025, 2, 1) | ||||||
|     # 开始时间取当前时间 |     # 开始时间取当前时间 | ||||||
|     start_date = datetime.date.today() |     # start_date = datetime.date.today() | ||||||
|     # end_date = datetime.date(2100, 12, 31) |     end_date = datetime.date(2025, 3, 31) | ||||||
| 
 | 
 | ||||||
|     current_date = start_date |     current_date = start_date | ||||||
|     # while current_date <= end_date: |     # while current_date <= end_date: | ||||||
| @ -46,5 +47,7 @@ if __name__ == "__main__": | |||||||
| 
 | 
 | ||||||
|     #     current_date += datetime.timedelta(days=1) |     #     current_date += datetime.timedelta(days=1) | ||||||
| 
 | 
 | ||||||
|     print(f"开始执行 {current_date} 的预测任务") |     while current_date <= end_date: | ||||||
|     run_predictions(current_date) |         print(f"开始执行 {current_date} 的预测任务") | ||||||
|  |         run_predictions(current_date) | ||||||
|  |         current_date += datetime.timedelta(days=1) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user