def dsinfo(request):
    """Build the dataset description selected by ``request.param``.

    ``request`` is presumably a pytest ``request`` object for a parametrized
    fixture (only its ``param`` attribute is read). Supported values:

    * ``"constant"``: data from gluonts ``constant_dataset()``. It also carries
      a fixed list of time features (day-of-week and hour-of-day).
    * ``"synthetic"``: data from gluonts ``default_synthetic()``. It also
      carries ``batch_size=32``, ``context_length=2`` and
      ``time_features=None``.

    Both variants include the dataset name, the cardinality of the first
    static categorical feature, the frequency, the prediction length,
    ``num_parallel_samples=2``, and the train/test datasets.

    Returns:
        An ``AttrDict`` holding the fields listed above.

    Raises:
        ValueError: if ``request.param`` is neither ``"constant"`` nor
            ``"synthetic"``.
    """
    dataset_kind = request.param

    if dataset_kind == "constant":
        from gluonts import time_feature
        from gluonts.dataset.artificial import constant_dataset

        ds_info, train_ds, test_ds = constant_dataset()
        return AttrDict(
            name="constant",
            cardinality=int(ds_info.metadata.feat_static_cat[0].cardinality),
            freq=ds_info.metadata.freq,
            num_parallel_samples=2,
            prediction_length=ds_info.prediction_length,
            # FIXME: time features should not be needed for GP
            time_features=[time_feature.DayOfWeek(), time_feature.HourOfDay()],
            train_ds=train_ds,
            test_ds=test_ds,
        )

    if dataset_kind == "synthetic":
        from gluonts.dataset.artificial import default_synthetic

        ds_info, train_ds, test_ds = default_synthetic()
        return AttrDict(
            name="synthetic",
            batch_size=32,
            cardinality=int(ds_info.metadata.feat_static_cat[0].cardinality),
            context_length=2,
            freq=ds_info.metadata.freq,
            prediction_length=ds_info.prediction_length,
            num_parallel_samples=2,
            train_ds=train_ds,
            test_ds=test_ds,
            time_features=None,
        )

    # Without this, an unknown parameter would fall through and return None.
    # That would likely only be noticed later, far from the real cause.
    raise ValueError(
        f"unknown dataset parameter {dataset_kind!r}; "
        "expected 'constant' or 'synthetic'"
    )
from gluonts.model.gp_forecaster import GaussianProcessEstimator from gluonts.model.predictor import Predictor from gluonts.model.seasonal_naive import SeasonalNaiveEstimator from gluonts.model.seq2seq import ( MQCNNEstimator, MQRNNEstimator, Seq2SeqEstimator, ) from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator dataset_info, train_ds, test_ds = constant_dataset() freq = dataset_info.metadata.time_granularity prediction_length = dataset_info.prediction_length cardinality = int(dataset_info.metadata.feat_static_cat[0].cardinality) # FIXME: Should time features should not be needed for GP time_features = [time_feature.DayOfWeek(), time_feature.HourOfDay()] num_eval_samples = 2 epochs = 1 def seq2seq_base(seq2seq_model, hybridize: bool = True, batches_per_epoch=1): return ( seq2seq_model, dict( ctx='cpu', epochs=epochs, learning_rate=1e-2, hybridize=hybridize, prediction_length=prediction_length, context_length=prediction_length, num_eval_samples=num_eval_samples,