import pandas as pd
import re
import os
import pandas as pd

import multiprocessing
import time
import joblib
import torch

# 定义函数
def loadcsv(filename):
    """Read ``filename`` as CSV, decoding as UTF-8 first and GBK second.

    Args:
        filename: Path of the CSV file (anything ``pd.read_csv`` accepts).

    Returns:
        The parsed ``pd.DataFrame``.

    Raises:
        UnicodeDecodeError: if the file decodes neither as UTF-8 nor as GBK.
    """
    try:
        return pd.read_csv(filename, encoding='utf-8')
    except UnicodeDecodeError:
        # Not valid UTF-8 -- fall through to the GBK attempt below.
        pass
    return pd.read_csv(filename, encoding='gbk')


def datachuli(df, datecol='date'):
    """Clean a raw indicator DataFrame into the ``ds``/``y`` forecasting layout.

    Steps:
      1. drop columns that are entirely empty;
      2. fill gaps with a forward fill (propagate the last valid value down),
         then a backward fill (covers leading gaps the forward fill cannot);
      3. rename ``datecol`` to ``ds`` and parse it with ``pd.to_datetime``;
      4. rename the target column ``Brent连1合约价格`` to ``y``.

    Args:
        df: Raw DataFrame as read from CSV. It is not modified.
        datecol: Name of the date column to rename to ``ds``.

    Returns:
        A new, cleaned DataFrame.
    """
    # Drop columns that contain no data at all (returns a new frame, so the
    # caller's ``df`` is left untouched).
    df = df.dropna(axis=1, how='all')
    # ffill/bfill return new frames. The previous bare ``df.ffill`` /
    # ``df.bfill`` never called the methods, so no filling happened.
    df = df.ffill().bfill()
    # Rename the date column to ``ds`` and parse it as datetime.
    df = df.rename(columns={datecol: 'ds'})
    df['ds'] = pd.to_datetime(df['ds'])
    # Rename the prediction target column to ``y``.
    df = df.rename(columns={'Brent连1合约价格': 'y'})
    return df


def getdata(filename, datecol='date'):
    """Load ``filename`` with ``loadcsv`` and clean it with ``datachuli``.

    Args:
        filename: Path of the CSV file to read.
        datecol: Name of the date column, forwarded to ``datachuli``.

    Returns:
        The cleaned DataFrame produced by ``datachuli``.
    """
    return datachuli(loadcsv(filename), datecol)





# 预测函数
def predict(X_test, nf, result_list):
    """Forecast one input window and record the rows in ``result_list``.

    Args:
        X_test: Input passed straight to ``nf.predict``.
        nf: Object exposing ``predict(X_test)`` that returns a DataFrame
            (presumably a fitted NeuralForecast model -- confirm with caller).
        result_list: List-like sink. The forecast, with its index reset into a
            column, is appended to it as a list of row lists.

    Returns:
        The forecast DataFrame after ``reset_index()``.
    """
    forecast = nf.predict(X_test)
    flat = forecast.reset_index()
    rows = flat.values.tolist()
    result_list.append(rows)
    return flat


def _deepar_columns(columns):
    """Return ``columns`` with DeepAR's five extra output columns after 'DeepAR'."""
    extra = ['DeepAR-median', 'DeepAR-lo-90', 'DeepAR-lo-80',
             'DeepAR-hi-80', 'DeepAR-hi-90']
    expanded = []
    for name in columns:
        expanded.append(name)
        if name == 'DeepAR':
            expanded.extend(extra)
    return expanded


def _combine_results(results, columns, columns2):
    """Split raw prediction row-lists by width into two concatenated DataFrames.

    A result whose rows have ``2 + len(columns)`` fields goes to the first
    frame, and one with ``2 + len(columns2)`` fields goes to the second. An
    empty result goes to the first frame.

    Raises:
        ValueError: if a result's row width matches neither layout.
    """
    main_cols = ['index', 'ds'] + list(columns)
    alt_cols = ['index', 'ds'] + list(columns2)
    main_frames, alt_frames = [], []
    for result in results:
        width = len(result[0]) if result else len(main_cols)
        if width == len(main_cols):
            main_frames.append(pd.DataFrame(result, columns=main_cols))
        elif width == len(alt_cols):
            # Presumably the DeepAR variant with extra quantile columns.
            alt_frames.append(pd.DataFrame(result, columns=alt_cols))
        else:
            raise ValueError(
                f"prediction row width {width} matches neither "
                f"{len(main_cols)} nor {len(alt_cols)} columns")
    # A single concat instead of one per result avoids quadratic copying.
    df_main = pd.concat(main_frames, ignore_index=True) if main_frames else pd.DataFrame()
    df_alt = pd.concat(alt_frames, ignore_index=True) if alt_frames else pd.DataFrame()
    return df_main, df_alt


def testSetPredict(X_test, nf, columns, dataset='.', columns2=None):
    """Predict every test window in parallel and save the combined results.

    Each element of ``X_test`` is sent to ``predict`` as its own process-pool
    task. The rows every task appends to a shared list are then regrouped by
    width into two DataFrames and written as CSV.

    Args:
        X_test: Sequence of inputs, each passed to ``nf.predict``.
        nf: Picklable model with a ``predict`` method. It is pickled for every
            task, which may be costly for large models.
        columns: Model/output column names expected after ``index`` and ``ds``.
        dataset: Output directory, created if missing. Defaults to the current
            directory.
        columns2: Alternative column layout for wider outputs. Defaults to
            ``columns`` with DeepAR's median/lo-90/lo-80/hi-80/hi-90 columns
            inserted after 'DeepAR'.

    Returns:
        ``(df_combined, df_combined2)``. These are also written to
        ``dataset/df_combined.csv`` and ``dataset/df_combined2.csv``. Row order
        follows task completion order, which is not deterministic.
    """
    start_time = time.time()
    if columns2 is None:
        columns2 = _deepar_columns(columns)

    num_processes = multiprocessing.cpu_count()
    # Use the Manager as a context manager so its server process is shut down.
    # Copy the shared list out before the manager exits.
    with multiprocessing.Manager() as manager:
        result_list = manager.list()
        with multiprocessing.Pool(num_processes) as pool:
            tasks = [pool.apply_async(predict, args=(X, nf, result_list))
                     for X in X_test]
            for task in tasks:
                task.get()  # re-raises any worker exception here
        results = list(result_list)

    df_combined, df_combined2 = _combine_results(results, columns, columns2)

    os.makedirs(dataset, exist_ok=True)
    df_combined.to_csv(os.path.join(dataset, 'df_combined.csv'), index=False)
    # Separate file: writing this to 'df_combined.csv' too clobbered the first output.
    df_combined2.to_csv(os.path.join(dataset, 'df_combined2.csv'), index=False)

    end_time = time.time()
    # Elapsed wall-clock time in seconds.
    print("运行时间:", end_time - start_time, "秒")
    return df_combined, df_combined2


if __name__ == '__main__':
    # Input CSV. An alternative source kept by the author: '指标数据处理.csv'
    file = 'brentpricepredict.csv'
    # Output directory for testSetPredict's CSV files (current directory).
    dataset = '.'
    df = getdata(file)

    # Chronological 80/20 split (no shuffling for time series).
    split_index = int(0.8 * df.shape[0])
    # .copy() so adding 'unique_id' edits real frames, not views of ``df``
    # (avoids SettingWithCopyWarning).
    df_train = df.iloc[:split_index].copy()
    df_test = df.iloc[split_index:].copy()
    df_train['unique_id'] = 1
    df_test['unique_id'] = 1
    # For a quick smoke test, restrict to the tail: df_test = df_test[-20:]

    # Model names, used as prediction column names.
    columns = [
        'NHITS',
        'Informer',
        'LSTM',
        'iTransformer',
        'TSMixer',
        'TSMixerx',
        'PatchTST',
        'RNN',
        'GRU',
        'TCN',
        'DeepAR',
        'BiTCN',
        'DilatedRNN',
        'MLP',
        'DLinear',
        'NLinear',
        'TFT',
        'FEDformer',
        'StemGNN',
        'MLPMultivariate',
        'TiDE',
        'DeepNPTS',
    ]

    # DeepAR output has five extra columns, so those results use this layout.
    columns2 = [
        'NHITS',
        'Informer',
        'LSTM',
        'iTransformer',
        'TSMixer',
        'TSMixerx',
        'PatchTST',
        'RNN',
        'GRU',
        'TCN',
        'DeepAR',
        'DeepAR-median',
        'DeepAR-lo-90',
        'DeepAR-lo-80',
        'DeepAR-hi-80',
        'DeepAR-hi-90',
        'BiTCN',
        'DilatedRNN',
        'MLP',
        'DLinear',
        'NLinear',
        'TFT',
        'FEDformer',
        'StemGNN',
        'MLPMultivariate',
        'TiDE',
        'DeepNPTS',
    ]

    # Sliding windows of ``input_size`` consecutive test rows (stride 1).
    input_size = 14
    X_test = [df_test.iloc[i:i + input_size]
              for i in range(len(df_test) - input_size + 1)]

    # NOTE(review): joblib.load unpickles arbitrary objects -- only load
    # trusted model files.
    nf = joblib.load('model_reg.joblib')

    testSetPredict(X_test, nf, columns, dataset)