コード例 #1
0
def dsinfo(request):
    """Build the dataset/model settings selected by ``request.param``.

    ``request`` is presumably a pytest ``FixtureRequest`` for a parametrized
    fixture; only its ``param`` attribute is read here.

    Parameters
    ----------
    request
        Object exposing a ``param`` attribute naming the artificial dataset
        to load: ``"constant"`` or ``"synthetic"``.

    Returns
    -------
    AttrDict
        Dataset name, cardinality of the first static categorical feature,
        frequency, prediction length, sampling settings, optional time
        features, and the train/test datasets.

    Raises
    ------
    ValueError
        If ``request.param`` is not a supported dataset name. Previously an
        unknown name silently produced ``None``.
    """
    supported = ("constant", "synthetic")
    # Validate before importing gluonts so a typo in the parametrization
    # fails fast with a clear message instead of returning None.
    if request.param not in supported:
        raise ValueError(
            "unknown dataset {!r}; expected one of {}".format(
                request.param, supported
            )
        )

    from gluonts import time_feature
    from gluonts.dataset.artificial import constant_dataset, default_synthetic

    if request.param == "constant":
        ds_info, train_ds, test_ds = constant_dataset()

        return AttrDict(
            name="constant",
            cardinality=int(ds_info.metadata.feat_static_cat[0].cardinality),
            freq=ds_info.metadata.freq,
            num_parallel_samples=2,
            prediction_length=ds_info.prediction_length,
            # FIXME: time features should not be needed for GP
            time_features=[time_feature.DayOfWeek(),
                           time_feature.HourOfDay()],
            train_ds=train_ds,
            test_ds=test_ds,
        )

    # Only "synthetic" remains after the validation above.
    ds_info, train_ds, test_ds = default_synthetic()

    return AttrDict(
        name="synthetic",
        batch_size=32,
        cardinality=int(ds_info.metadata.feat_static_cat[0].cardinality),
        context_length=2,
        freq=ds_info.metadata.freq,
        prediction_length=ds_info.prediction_length,
        num_parallel_samples=2,
        train_ds=train_ds,
        test_ds=test_ds,
        time_features=None,
    )
コード例 #2
0
ファイル: test_models.py プロジェクト: zwbjtu123/gluon-ts
from gluonts.model.gp_forecaster import GaussianProcessEstimator
from gluonts.model.predictor import Predictor
from gluonts.model.seasonal_naive import SeasonalNaiveEstimator
from gluonts.model.seq2seq import (
    MQCNNEstimator,
    MQRNNEstimator,
    Seq2SeqEstimator,
)
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator

# Module-level configuration shared by the estimator setups in this file
# (e.g. seq2seq_base reads `epochs`, `prediction_length`, `num_eval_samples`),
# all derived from gluonts' constant artificial dataset.
# NOTE(review): `constant_dataset` and `time_feature` are not among the visible
# import lines; presumably they are imported earlier in the file — confirm.
dataset_info, train_ds, test_ds = constant_dataset()
# NOTE(review): dsinfo reads `metadata.freq` instead; `time_granularity` may be
# an older gluonts attribute name — confirm against the installed version.
freq = dataset_info.metadata.time_granularity
prediction_length = dataset_info.prediction_length
# Cardinality of the first static categorical feature.
cardinality = int(dataset_info.metadata.feat_static_cat[0].cardinality)
# FIXME: time features should not be needed for GP
time_features = [time_feature.DayOfWeek(), time_feature.HourOfDay()]
# Small values, presumably to keep the model tests fast.
num_eval_samples = 2
epochs = 1


def seq2seq_base(seq2seq_model, hybridize: bool = True, batches_per_epoch=1):
    return (
        seq2seq_model,
        dict(
            ctx='cpu',
            epochs=epochs,
            learning_rate=1e-2,
            hybridize=hybridize,
            prediction_length=prediction_length,
            context_length=prediction_length,
            num_eval_samples=num_eval_samples,