PriceForecast/八个维度demo copy.py

63 lines
2.5 KiB
Python
Raw Normal View History

2024-11-01 16:38:21 +08:00
import logging
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.utils import AirPassengersPanel
from mlforecast.utils import PredictionIntervals
from neuralforecast.losses.pytorch import DistributionLoss, MAE
# Keep `unique_id` as a regular column (not the index) in NeuralForecast outputs,
# so `preds` can be concatenated with the training frame and filtered by id below.
os.environ['NIXTLA_ID_AS_COL'] = '1'

# Forecast horizon and lookback window (in time steps). The train/test split
# below is derived from `horizon` so the held-out window always has exactly
# as many steps as the models are asked to forecast.
horizon = 12
input_size = 24

# Hold out the last `horizon` timestamps of the panel as the test window.
# NOTE(review): the cutoff is read from the panel's final rows, which assumes
# every series ends on the same dates (presumably true for both airlines in
# AirPassengersPanel) - confirm before reusing with other panels.
_cutoff = AirPassengersPanel['ds'].values[-horizon]
AirPassengersPanel_train = AirPassengersPanel[AirPassengersPanel['ds'] < _cutoff].reset_index(drop=True)
AirPassengersPanel_test = AirPassengersPanel[AirPassengersPanel['ds'] >= _cutoff].reset_index(drop=True)
# Blank out the target and the `y_[lag12]` column (presumably a lag of y) in the
# future frame so no true values for the test window are handed to `predict`.
AirPassengersPanel_test['y'] = np.nan
AirPassengersPanel_test['y_[lag12]'] = np.nan

# Conformal prediction intervals, passed to `fit`; the first plot uses them for
# the MAE model's band.
prediction_intervals = PredictionIntervals()
models = [
    # Point forecaster trained with MAE; its 90% band comes from conformal prediction.
    NHITS(h=horizon, input_size=input_size, max_steps=100, loss=MAE(), scaler_type="robust"),
    # Probabilistic forecaster; its 90% band comes from the fitted Normal distribution.
    NHITS(h=horizon, input_size=input_size, max_steps=100,
          loss=DistributionLoss("Normal", level=[90]), scaler_type="robust"),
]
# 'ME' is pandas' month-end frequency alias (older pandas versions spell it 'M').
nf = NeuralForecast(models=models, freq='ME')
nf.fit(AirPassengersPanel_train, prediction_intervals=prediction_intervals)
# Produces point forecasts plus `-lo-90` / `-hi-90` columns for each model.
preds = nf.predict(futr_df=AirPassengersPanel_test, level=[90])
def _plot_forecast(ax, df, model, title, n_future, level=90):
    """Draw actuals, one model's forecast and its prediction band on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        df: Frame with columns ``ds``, ``y``, ``<model>``,
            ``<model>-lo-<level>`` and ``<model>-hi-<level>``. Its last
            *n_future* rows are the forecast window.
        model: Column name of the model's point forecast (e.g. ``'NHITS'``).
        title: Axes title.
        n_future: Number of trailing rows that carry interval columns.
        level: Interval level used in the ``-lo-``/``-hi-`` column suffixes.
    """
    ax.plot(df['ds'], df['y'], c='black', label='True')
    ax.plot(df['ds'], df[model], c='blue', label='median')
    # Use explicit positional selection: after pd.concat the integer index is
    # non-unique, so label-style slicing would be ambiguous.
    future = df.iloc[-n_future:]
    ax.fill_between(x=future['ds'],
                    y1=future[f'{model}-lo-{level}'].values,
                    y2=future[f'{model}-hi-{level}'].values,
                    alpha=0.4, label=f'level {level}')
    ax.set_title(title, fontsize=18)
    ax.set_ylabel('Monthly Passengers', fontsize=15)
    ax.legend(prop={'size': 10})
    ax.grid()


# Both panels show the same dates. Sharing the x-axis keeps them aligned and
# shows tick labels only on the bottom panel.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 7), sharex=True)

# History followed by the forecast rows, restricted to one series and to the
# last 50 rows.
plot_df = pd.concat([AirPassengersPanel_train, preds])
plot_df = (
    plot_df[plot_df['unique_id'] == 'Airline1']
    .drop(['unique_id', 'trend', 'y_[lag12]'], axis=1)
    .iloc[-50:]
)

_plot_forecast(ax1, plot_df, 'NHITS',
               'AirPassengers Forecast - Uncertainty quantification using Conformal Prediction',
               n_future=horizon)
_plot_forecast(ax2, plot_df, 'NHITS1',
               'AirPassengers Forecast - Uncertainty quantification using Normal distribution',
               n_future=horizon)
ax2.set_xlabel('Timestamp [t]', fontsize=15)