PriceForecast/lib/duojinchengpredict.py
2024-11-01 16:38:21 +08:00

191 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import re
import os
import pandas as pd
import multiprocessing
import time
import joblib
import torch
# 定义函数
def loadcsv(filename):
    """Load the CSV file ``filename`` into a pandas DataFrame.

    The file is first decoded as UTF-8. If that raises
    ``UnicodeDecodeError``, the whole file is read again with the GBK codec;
    any error from that second attempt propagates to the caller.
    """
    try:
        return pd.read_csv(filename, encoding='utf-8')
    except UnicodeDecodeError:
        pass
    # Not valid UTF-8 -- retry the whole file as GBK.
    return pd.read_csv(filename, encoding='gbk')
def datachuli(df, datecol='date'):
    """Clean a raw indicator table into the ``ds``/``y`` layout.

    Steps:
      1. drop columns that are entirely NaN;
      2. forward-fill gaps from earlier rows, then backward-fill any
         remaining leading gaps from later rows;
      3. rename ``datecol`` to ``ds`` and parse it with ``pd.to_datetime``;
      4. rename the target column ``Brent连1合约价格`` to ``y`` (if present).

    Parameters
    ----------
    df : pandas.DataFrame
        Raw table, e.g. as returned by ``loadcsv``.
    datecol : str, default 'date'
        Name of the date column; it is renamed to ``ds``.

    Returns
    -------
    pandas.DataFrame
        A new, cleaned frame. The caller's ``df`` is not modified.
    """
    # Remove columns that contain no data at all.
    df = df.dropna(axis=1, how='all')
    # The methods must actually be called: referencing ``df.ffill`` /
    # ``df.bfill`` without parentheses fills nothing.
    df = df.ffill().bfill()
    # Two separate renames, matching the original order of operations.
    df = df.rename(columns={datecol: 'ds'})
    df['ds'] = pd.to_datetime(df['ds'])
    df = df.rename(columns={'Brent连1合约价格': 'y'})
    return df
def getdata(filename, datecol='date'):
    """Read ``filename`` and return it cleaned for forecasting.

    The file is decoded with ``loadcsv`` (UTF-8, falling back to GBK). The
    result is passed through ``datachuli``, which drops all-empty columns,
    renames ``datecol`` to a datetime ``ds`` column and renames the Brent
    price column to ``y``.
    """
    raw = loadcsv(filename)
    return datachuli(raw, datecol)
def predict(X_test, nf, result_list):
    """Run ``nf.predict`` on a single input and record the result rows.

    Parameters
    ----------
    X_test : object
        Passed unchanged to ``nf.predict``. In this file it is one
        DataFrame window sliced from the test set.
    nf : object
        Fitted model exposing ``predict``. It is presumably a NeuralForecast
        instance loaded with joblib; confirm against the training code.
    result_list : list-like
        Anything with ``append``. In this file it is a
        ``multiprocessing.Manager().list()`` proxy shared by pool workers.

    Returns
    -------
    pandas.DataFrame
        The prediction frame with its index moved into a column. The same
        rows, as a list of lists, are appended to ``result_list``.
    """
    raw_forecast = nf.predict(X_test)
    forecast = raw_forecast.reset_index()
    rows = forecast.values.tolist()
    result_list.append(rows)
    return forecast
# Extra output columns produced for DeepAR (five more than the single
# point-forecast column), inserted right after 'DeepAR'.
_DEEPAR_EXTRA_COLUMNS = [
    'DeepAR-median',
    'DeepAR-lo-90',
    'DeepAR-lo-80',
    'DeepAR-hi-80',
    'DeepAR-hi-90',
]


def _with_deepar_quantiles(columns):
    """Return a copy of ``columns`` with the DeepAR extra columns inserted after 'DeepAR'."""
    expanded = []
    for name in columns:
        expanded.append(name)
        if name == 'DeepAR':
            expanded.extend(_DEEPAR_EXTRA_COLUMNS)
    return expanded


def testSetPredict(X_test, nf, columns, dataset='.', columns2=None):
    """Predict every test window in parallel and save the combined results.

    Each element of ``X_test`` becomes one ``predict`` task in a
    ``multiprocessing.Pool`` sized to ``cpu_count()``. Workers append their
    rows to a shared ``Manager`` list.

    Each result block is first parsed with ``['index', 'ds'] + columns``. A
    block that does not fit that layout (``pd.DataFrame`` raises
    ``ValueError``) is parsed with ``['index', 'ds'] + columns2``. This is
    the case when DeepAR emits its five extra columns.

    Parameters
    ----------
    X_test : sequence
        Inputs, each passed to ``nf.predict``.
    nf : object
        Fitted model with a ``predict`` method. It is pickled and sent with
        every task, so it must be picklable.
    columns : list of str
        Model output column names, excluding 'index' and 'ds'.
    dataset : str, default '.'
        Directory where both CSV files are written.
    columns2 : list of str, optional
        Column names for the wider result blocks. Defaults to ``columns``
        with the DeepAR extra columns inserted after 'DeepAR'.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_combined, df_combined2)``. These are also written to
        ``<dataset>/df_combined.csv`` and ``<dataset>/df_combined2.csv``.

    Raises
    ------
    ValueError
        If a result block fits neither column layout.
    """
    start_time = time.time()
    if columns2 is None:
        columns2 = _with_deepar_quantiles(columns)
    names = ['index', 'ds'] + list(columns)
    names2 = ['index', 'ds'] + list(columns2)

    # Using the Manager as a context manager shuts its server process down.
    with multiprocessing.Manager() as manager:
        result_list = manager.list()  # shared between worker processes
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            # One task per window. Splitting X_test into per-process chunks
            # gains nothing here, because every sample would still be
            # submitted as its own task.
            tasks = [
                pool.apply_async(predict, args=(X, nf, result_list))
                for X in X_test
            ]
            for task in tasks:
                task.get()  # wait, and re-raise any worker exception
        # Copy out of the proxy before the manager shuts down.
        results = list(result_list)

    frames, frames2 = [], []
    for result in results:
        try:
            frame = pd.DataFrame(result, columns=names)
        except ValueError:
            # Width does not match ``columns``: try the wider layout.
            frames2.append(pd.DataFrame(result, columns=names2))
        else:
            frames.append(frame)
    # Concatenate once; concatenating inside the loop is quadratic.
    df_combined = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    df_combined2 = pd.concat(frames2, ignore_index=True) if frames2 else pd.DataFrame()

    # Separate file names, so the second file cannot overwrite the first.
    df_combined.to_csv(os.path.join(dataset, 'df_combined.csv'), index=False)
    df_combined2.to_csv(os.path.join(dataset, 'df_combined2.csv'), index=False)

    # Report the elapsed time as hours / minutes / seconds.
    elapsed = time.time() - start_time
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    print("运行时间:", f"{int(hours)}时{int(minutes)}分{seconds:.2f}秒")
    return df_combined, df_combined2
if __name__ == '__main__':
    # Script entry point: build sliding test-set windows from the Brent
    # indicator CSV and run the saved model over them in parallel.
    # Alternative input file: '指标数据处理.csv'.
    file = 'brentpricepredict.csv'
    df = getdata(file)

    # Chronological 80/20 split. The first 80% is the training period,
    # presumably the data model_reg.joblib was fitted on (confirm). Only the
    # remaining 20% is predicted here.
    split_index = int(0.8 * df.shape[0])
    # .copy() so adding 'unique_id' modifies an independent frame rather
    # than a slice of df (avoids pandas' SettingWithCopyWarning).
    df_test = df[split_index:].copy()
    df_test['unique_id'] = 1

    # Model names, used as the column names of the prediction results.
    columns = [
        'NHITS',
        'Informer',
        'LSTM',
        'iTransformer',
        'TSMixer',
        'TSMixerx',
        'PatchTST',
        'RNN',
        'GRU',
        'TCN',
        'DeepAR',
        'BiTCN',
        'DilatedRNN',
        'MLP',
        'DLinear',
        'NLinear',
        'TFT',
        'FEDformer',
        'StemGNN',
        'MLPMultivariate',
        'TiDE',
        'DeepNPTS',
    ]
    # DeepAR's predictions carry five extra columns, so those results need
    # their own column list.
    columns2 = [
        'NHITS',
        'Informer',
        'LSTM',
        'iTransformer',
        'TSMixer',
        'TSMixerx',
        'PatchTST',
        'RNN',
        'GRU',
        'TCN',
        'DeepAR',
        'DeepAR-median',
        'DeepAR-lo-90',
        'DeepAR-lo-80',
        'DeepAR-hi-80',
        'DeepAR-hi-90',
        'BiTCN',
        'DilatedRNN',
        'MLP',
        'DLinear',
        'NLinear',
        'TFT',
        'FEDformer',
        'StemGNN',
        'MLPMultivariate',
        'TiDE',
        'DeepNPTS',
    ]

    # Each model input is a window of input_size consecutive test rows.
    input_size = 14
    X_test = [
        df_test.iloc[i:i + input_size]
        for i in range(len(df_test) - input_size + 1)
    ]

    # joblib.load unpickles the file, so only load trusted model files.
    nf = joblib.load('model_reg.joblib')
    # 'dataset' (output directory) is a required argument: use the CWD.
    testSetPredict(X_test, nf, columns, '.')