From 5d5b2ae251c2c6a74c1fe425f1204bd36186abb0 Mon Sep 17 00:00:00 2001 From: workpc Date: Wed, 23 Jul 2025 15:30:24 +0800 Subject: [PATCH] =?UTF-8?q?=E8=81=9A=E7=83=AF=E7=83=83=E6=97=A5=E5=BA=A6?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B6=A8=E8=B7=8C=E5=81=9C=E4=BB=B7=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_juxiting.py | 1 + lib/dataread.py | 26 +++++++++++++++++++++++--- main_juxiting.py | 21 +++++++++++---------- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/config_juxiting.py b/config_juxiting.py index 7603e39..928c722 100644 --- a/config_juxiting.py +++ b/config_juxiting.py @@ -439,6 +439,7 @@ is_update_predict_value = True # 是否上传预测值到市场信息平台 is_del_corr = 0.6 # 是否删除相关性高的特征,取值为 0-1 ,0 为不删除,0.6 表示删除相关性小于0.6的特征 is_del_tow_month = True # 是否删除两个月不更新的特征 is_bdwd = False # 是否使用八大维度 +is_add_zt_price = True # 是否添加涨跌停价格 # 连接到数据库 diff --git a/lib/dataread.py b/lib/dataread.py index 41b254e..36cf617 100644 --- a/lib/dataread.py +++ b/lib/dataread.py @@ -67,6 +67,7 @@ global_config = { 'is_update_eta_data': None, # ETA数据更新开关 'early_stop_patience_steps': None, # 早停步数 'is_update_report': None, # 是否更新报告开关 + 'is_add_zt_price': None, # 是否添加涨跌停价格 # 时间参数 'start_year': None, # 起始年份 @@ -516,8 +517,7 @@ def featureAnalysis(df, dataset, y): import matplotlib.pyplot as plt # 选择特征和标签列 X = df.drop(['ds', 'y'], axis=1) # 特征集,排除时间戳和标签列 - yy = df['y'] # 标签集 - + yy = df['y'] # 标签集 # 标签集自相关函数分析 from statsmodels.graphics.tsaplots import plot_acf @@ -1037,6 +1037,9 @@ def datachuli_juxiting(df_zhibiaoshuju, df_zhibiaoliebiao, datecol='date', end_t if is_timefurture: df = addtimecharacteristics(df=df, dataset=dataset) + if config.is_add_zt_price: + df = addztprice(df=df) + if config.freq == 'WW': # 自定义周数据 # 按weekofmothe分组取均值得到新的数据 @@ -1391,6 +1394,10 @@ class Config: @property def columnsrename(self): return global_config['columnsrename'] + # 涨跌停价格 + @property + def is_add_zt_price(self): return global_config['is_add_zt_price'] + config = Config() @@ -2440,11 +2447,13 @@ def get_high_low_data(df): df = pd.merge(df, df1, how='left', on='date') return df + def get_shujuxiang_data(df): # 读取excel 从第五行开始 df1 = pd.read_excel(os.path.join(config.dataset, '数据项下载.xls'), header=5, names=[ 'numid', 'date', 'MAIN_CONFT_SETTLE_PRICE']) - df1['MAIN_CONFT_SETTLE_PRICE'] = df1['MAIN_CONFT_SETTLE_PRICE'].str.replace(',', '').astype(float) + df1['MAIN_CONFT_SETTLE_PRICE'] = df1['MAIN_CONFT_SETTLE_PRICE'].str.replace( + ',', '').astype(float) # 合并数据 df = pd.merge(df, df1, how='left', on='date') return df @@ -2523,6 +2532,17 @@ def addtimecharacteristics(df, dataset): return df +def addztprice(df): + try: + df['ztprice'] = df['y'] * 1.06 + df['dtprice'] = df['y'] * 0.94 + df['ztprice'] = df['ztprice'].shift(1) + df['dtprice'] = df['dtprice'].shift(1) + except: + config.logger.info('添加涨跌停价格失败') + return df + + # 从数据库获取百川数据,接收一个百川id列表,返回df格式的数据 def get_baichuan_data(baichuanidnamedict): baichuanidlist = [str(k) for k in baichuanidnamedict.keys()] diff --git a/main_juxiting.py b/main_juxiting.py index a202f29..7b2ae63 100644 --- a/main_juxiting.py +++ b/main_juxiting.py @@ -45,6 +45,7 @@ global_config.update({ 'is_fivemodels': is_fivemodels, 'is_update_predict_value': is_update_predict_value, 'early_stop_patience_steps': early_stop_patience_steps, + 'is_add_zt_price': is_add_zt_price, # 时间参数 'start_year': start_year, @@ -95,6 +96,7 @@ global_config.update({ 'is_bdwd': is_bdwd, 'db_mysql': db_mysql, 'DEFAULT_CONFIG': DEFAULT_CONFIG, + }) @@ -545,17 +547,16 @@ def predict_main(): if __name__ == '__main__': # global end_time # 遍历2024-11-25 到 2024-12-3 之间的工作日日期 - for i_time in pd.date_range('2025-6-2', '2025-7-23', freq='B'): - try: - global_config['end_time'] = i_time.strftime('%Y-%m-%d') - global_config['db_mysql'].connect() - predict_main() - except Exception as e: - logger.info(f'预测失败:{e}') - continue - - # predict_main() + # for i_time in pd.date_range('2025-6-2', '2025-7-23', freq='B'): + # try: + # global_config['end_time'] = i_time.strftime('%Y-%m-%d') + # global_config['db_mysql'].connect() + # predict_main() + # except Exception as e: + # logger.info(f'预测失败:{e}') + # continue + predict_main() # push_market_value() # sql_inset_predict(global_config)