添加注释
This commit is contained in:
parent
d61259d3b9
commit
11270ead85
@ -436,6 +436,11 @@ def calculate_kdj(data, n=9):
|
||||
|
||||
|
||||
def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y',dataset='dataset',delweekenday=False,add_kdj=False,is_timefurture=False):
|
||||
'''
|
||||
原油特征数据处理函数,
|
||||
接收的是两个df,一个是指标数据,一个是指标列表
|
||||
输出的是一个df,包含ds,y,指标列
|
||||
'''
|
||||
df = df_zhibiaoshuju.copy()
|
||||
if end_time == '':
|
||||
end_time = datetime.datetime.now().strftime('%Y-%m-%d')
|
||||
@ -457,6 +462,11 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
|
||||
two_months_ago = current_date - timedelta(days=40)
|
||||
|
||||
def check_column(col_name):
|
||||
'''
|
||||
判断两月不更新指标
|
||||
输入:列名
|
||||
输出:True or False
|
||||
'''
|
||||
if 'ds' in col_name or 'y' in col_name:
|
||||
return False
|
||||
df_check_column = df[['ds',col_name]]
|
||||
@ -469,7 +479,6 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
|
||||
return corresponding_date < two_months_ago
|
||||
columns_to_drop = df.columns[df.columns.map(check_column)].tolist()
|
||||
df = df.drop(columns = columns_to_drop)
|
||||
|
||||
logger.info(f'删除两月不更新特征后数据量:{df.shape}')
|
||||
|
||||
# 删除预测列空值的行
|
||||
@ -481,7 +490,7 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
|
||||
# 去掉指标列表中的columns_to_drop的行
|
||||
df_zhibiaoliebiao = df_zhibiaoliebiao[df_zhibiaoliebiao['指标名称'].isin(df.columns.tolist())]
|
||||
df_zhibiaoliebiao.to_csv(os.path.join(dataset,'特征处理后的指标名称及分类.csv'),index=False)
|
||||
# 频度分析
|
||||
# 数据频度分析
|
||||
featurePindu(dataset=dataset)
|
||||
# 向上填充
|
||||
df = df.ffill()
|
||||
@ -491,26 +500,35 @@ def datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y'
|
||||
# 删除周六日的数据
|
||||
if delweekenday:
|
||||
df = df[df['ds'].dt.weekday < 5]
|
||||
|
||||
# kdj指标
|
||||
if add_kdj:
|
||||
df = calculate_kdj(df)
|
||||
|
||||
# 衍生时间特征
|
||||
if is_timefurture:
|
||||
df = addtimecharacteristics(df=df,dataset=dataset)
|
||||
|
||||
# 特征分析
|
||||
featureAnalysis(df,dataset=dataset,y=y)
|
||||
return df
|
||||
|
||||
def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time='',y='y',dataset='dataset',delweekenday=False,add_kdj=False,is_timefurture=False):
|
||||
'''
|
||||
聚烯烃特征数据处理函数,
|
||||
接收的是两个df,一个是指标数据,一个是指标列表
|
||||
输出的是一个df,包含ds,y,指标列
|
||||
'''
|
||||
df = df_zhibiaoshuju.copy()
|
||||
if end_time == '':
|
||||
end_time = datetime.datetime.now().strftime('%Y-%m-%d')
|
||||
# date转为pddate
|
||||
df.rename(columns={datecol:'ds'},inplace=True)
|
||||
|
||||
# 指定列统一减少数值
|
||||
df[offsite_col] = df[offsite_col]-offsite
|
||||
# 预测列为avg_cols的均值
|
||||
df[y] = df[avg_cols].mean(axis=1)
|
||||
print(df[['ds',y]+avg_cols].head(20))
|
||||
# 去掉多余的列avg_cols
|
||||
df = df.drop(columns=avg_cols)
|
||||
|
||||
# 重命名预测列
|
||||
df.rename(columns={y:'y'},inplace=True)
|
||||
# 按时间顺序排列
|
||||
@ -521,10 +539,10 @@ def datachuli_juxiting(df_zhibiaoshuju,df_zhibiaoliebiao,datecol='date',end_time
|
||||
# 获取小于等于当前日期的数据
|
||||
df = df[df['ds'] <= end_time]
|
||||
logger.info(f'删除两月不更新特征前数据量:{df.shape}')
|
||||
# 去掉近最后数据对应的日期在两月以前的列,删除近2月的数据是常熟的列
|
||||
# 去掉近最后数据对应的日期在两月以前的列,删除近2月的数据是常数的列
|
||||
current_date = datetime.datetime.now()
|
||||
two_months_ago = current_date - timedelta(days=40)
|
||||
|
||||
# 检查两月不更新的特征
|
||||
def check_column(col_name):
|
||||
if 'ds' in col_name or 'y' in col_name:
|
||||
return False
|
||||
|
@ -441,8 +441,5 @@ class SQLiteHandler:
|
||||
else:
|
||||
print(f"Column '{column_name}' already exists in table '{table_name}'.")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('This is a tool, not a script.')
|
@ -26,7 +26,6 @@ def predict_main():
|
||||
)
|
||||
# 获取数据
|
||||
if is_eta:
|
||||
# eta数据
|
||||
logger.info('从eta获取数据...')
|
||||
signature = BinanceAPI(APPID, SECRET)
|
||||
etadata = EtaReader(signature=signature,
|
||||
@ -48,6 +47,7 @@ def predict_main():
|
||||
df = datachuli(df_zhibiaoshuju,df_zhibiaoliebiao,y = y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time)
|
||||
|
||||
else:
|
||||
# 读取数据
|
||||
logger.info('读取本地数据:'+os.path.join(dataset,data_set))
|
||||
df = getdata(filename=os.path.join(dataset,data_set),y=y,dataset=dataset,add_kdj=add_kdj,is_timefurture=is_timefurture,end_time=end_time) # 原始数据,未处理
|
||||
|
Loading…
Reference in New Issue
Block a user