From 11270ead859023111c1816f1d7b0073c7ca20424 Mon Sep 17 00:00:00 2001 From: liurui Date: Thu, 14 Nov 2024 10:21:25 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/dataread.py | 34 ++++++++++++++++++++++++++-------- lib/tools.py | 3 --- main.py => main_yuanyou.py | 2 +- 3 files changed, 27 insertions(+), 12 deletions(-) rename main.py => main_yuanyou.py (99%) diff --git a/lib/dataread.py b/lib/dataread.py index 7d180b5..3c7678b 100644 --- a/lib/dataread.py +++ b/lib/dataread.py @@ -436,6 +436,11 @@ def calculate_kdj(data, n=9): def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y',dataset='dataset',delweekenday=False,add_kdj=False,is_timefurture=False): + ''' + 原油特征数据处理函数, + 接收的是两个df,一个是指标数据,一个是指标列表 + 输出的是一个df,包含ds,y,指标列 + ''' df = df_zhibiaoshuju.copy() if end_time == '': end_time = datetime.datetime.now().strftime('%Y-%m-%d') @@ -457,6 +462,11 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y' two_months_ago = current_date - timedelta(days=40) def check_column(col_name): + ''' + 判断两月不更新指标 + 输入:列名 + 输出:True or False + ''' if 'ds' in col_name or 'y' in col_name: return False df_check_column = df[['ds',col_name]] @@ -469,7 +479,6 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y' return corresponding_date < two_months_ago columns_to_drop = df.columns[df.columns.map(check_column)].tolist() df = df.drop(columns = columns_to_drop) - logger.info(f'删除两月不更新特征后数据量:{df.shape}') # 删除预测列空值的行 @@ -481,7 +490,7 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y' # 去掉指标列表中的columns_to_drop的行 df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(df.columns.tolist())] df_zhibiaoliebiao.to_csv(os.path.join(dataset,'特征处理后的指标名称及分类.csv'),index=False) - # 频度分析 + # 数据频度分析 featurePindu(dataset=dataset) # 向上填充 df = df.ffill() @@ -491,26 +500,35 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y' # 删除周六日的数据 if delweekenday: df = df[df['ds'].dt.weekday < 5] - + # kdj指标 if add_kdj: df = calculate_kdj(df) - + # 衍生时间特征 if is_timefurture: df = addtimecharacteristics(df=df,dataset=dataset) - + # 特征分析 featureAnalysis(df,dataset=dataset,y=y) return df def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y',dataset='dataset',delweekenday=False,add_kdj=False,is_timefurture=False): + ''' + 聚烯烃特征数据处理函数, + 接收的是两个df,一个是指标数据,一个是指标列表 + 输出的是一个df,包含ds,y,指标列 + ''' df = df_zhibiaoshuju.copy() if end_time == '': end_time = datetime.datetime.now().strftime('%Y-%m-%d') # date转为pddate df.rename(columns={datecol:'ds'},inplace=True) + # 指定列统一减少数值 df[offsite_col] = df[offsite_col]-offsite + # 预测列为avg_cols的均值 df[y] = df[avg_cols].mean(axis=1) - print(df[['ds',y]+avg_cols].head(20)) + # 去掉多余的列avg_cols + df = df.drop(columns=avg_cols) + # 重命名预测列 df.rename(columns={y:'y'},inplace=True) # 按时间顺序排列 @@ -521,10 +539,10 @@ def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time # 获取小于等于当前日期的数据 df = df[df['ds'] <= end_time] logger.info(f'删除两月不更新特征前数据量:{df.shape}') - # 去掉近最后数据对应的日期在两月以前的列,删除近2月的数据是常熟的列 + # 去掉近最后数据对应的日期在两月以前的列,删除近2月的数据是常数的列 current_date = datetime.datetime.now() two_months_ago = current_date - timedelta(days=40) - + # 检查两月不更新的特征 def check_column(col_name): if 'ds' in col_name or 'y' in col_name: return False diff --git a/lib/tools.py b/lib/tools.py index 61024d2..4f34b20 100644 --- a/lib/tools.py +++ b/lib/tools.py @@ -441,8 +441,5 @@ class SQLiteHandler: else: print(f"Column '{column_name}' already exists in table '{table_name}'.") - - - if __name__ == '__main__': print('This is a tool, not a script.') \ No newline at end of file diff --git a/main.py b/main_yuanyou.py similarity index 99% rename from main.py rename to main_yuanyou.py index 802697d..a51635d 100644 --- a/main.py +++ b/main_yuanyou.py @@ -26,7 +26,6 @@ def predict_main(): ) # 获取数据 if is_eta: - # eta数据 logger.info('从eta获取数据...') signature = BinanceAPI(APPID, SECRET) etadata = EtaReader(signature=signature, @@ -48,6 +47,7 @@ def predict_main(): df = datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) else: + # 读取数据 logger.info('读取本地数据:'+os.path.join(dataset,data_set)) df = getdata(filename=os.path.join(dataset,data_set),y=y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) # 原始数据,未处理