Exemple #1
0
def test_mqcnn_covariate_smoke_test(
    use_feat_dynamic_real,
    add_time_feature,
    add_age_feature,
    enable_decoder_dynamic_feature,
    hybridize,
):
    """Smoke-test MQCNN training and prediction with covariate options.

    Dummy train/test datasets with static and dynamic features are
    generated, an ``MQCNNEstimator`` is built from hyperparameters and
    trained for a few short epochs, and the resulting predictor must
    yield exactly one forecast per entry of the test dataset.

    Every argument is passed through unchanged as the estimator
    hyperparameter of the same name.
    NOTE(review): the arguments are presumably supplied by a pytest
    parametrization that is not visible here -- confirm.
    """
    freq = "D"
    prediction_length = 3

    # Fixed settings that keep the run small and reproducible.
    fixed_settings = {
        "seed": 42,
        "freq": freq,
        "prediction_length": prediction_length,
        "quantiles": [0.5, 0.1],
        "epochs": 3,
        "num_batches_per_epoch": 3,
    }
    # Options under test, taken straight from the test arguments.
    covariate_flags = {
        "use_feat_dynamic_real": use_feat_dynamic_real,
        "add_time_feature": add_time_feature,
        "add_age_feature": add_age_feature,
        "enable_decoder_dynamic_feature": enable_decoder_dynamic_feature,
        "hybridize": hybridize,
    }
    hyperparameters = {**fixed_settings, **covariate_flags}

    # The datasets use the same frequency and horizon as the estimator.
    train_data, test_data = make_dummy_datasets_with_features(
        cardinality=[3, 10],
        num_feat_dynamic_real=2,
        freq=freq,
        prediction_length=prediction_length,
    )

    trained_predictor = MQCNNEstimator.from_hyperparameters(
        **hyperparameters
    ).train(train_data, num_workers=0)

    forecasts = list(trained_predictor.predict(test_data))
    assert len(forecasts) == len(test_data)
Exemple #2
0
def test_mqcnn_covariate_smoke_test(
    use_past_feat_dynamic_real,
    use_feat_dynamic_real,
    add_time_feature,
    add_age_feature,
    enable_encoder_dynamic_feature,
    enable_decoder_dynamic_feature,
    hybridize,
    quantiles,
    distr_output,
    is_iqf,
):
    """Smoke-test MQCNN training and prediction across covariate and output options.

    Dummy train/test datasets with static, dynamic and past-only dynamic
    features are generated, an ``MQCNNEstimator`` is built from
    hyperparameters and trained for a few short epochs, and the resulting
    predictor must yield exactly one forecast per entry of the test
    dataset.

    Every argument is passed through unchanged as the estimator
    hyperparameter of the same name.
    NOTE(review): the arguments are presumably supplied by a pytest
    parametrization that is not visible here -- confirm.
    """
    freq = "Y"
    prediction_length = 3

    # Run size, seed and output configuration.
    base_settings = {
        "seed": 42,
        "freq": freq,
        "context_length": 5,
        "prediction_length": prediction_length,
        "quantiles": quantiles,
        "distr_output": distr_output,
        "epochs": 3,
        "num_batches_per_epoch": 3,
    }
    # Feature and execution options under test, taken from the arguments.
    feature_flags = {
        "use_past_feat_dynamic_real": use_past_feat_dynamic_real,
        "use_feat_dynamic_real": use_feat_dynamic_real,
        "add_time_feature": add_time_feature,
        "add_age_feature": add_age_feature,
        "enable_encoder_dynamic_feature": enable_encoder_dynamic_feature,
        "enable_decoder_dynamic_feature": enable_decoder_dynamic_feature,
        "hybridize": hybridize,
        "is_iqf": is_iqf,
    }
    hyperparameters = {**base_settings, **feature_flags}

    # The datasets use the same frequency and horizon as the estimator.
    train_data, test_data = make_dummy_datasets_with_features(
        cardinality=[3, 10],
        num_feat_dynamic_real=2,
        num_past_feat_dynamic_real=4,
        freq=freq,
        prediction_length=prediction_length,
    )

    trained_predictor = MQCNNEstimator.from_hyperparameters(
        **hyperparameters
    ).train(train_data, num_workers=None)

    forecasts = list(trained_predictor.predict(test_data))
    assert len(forecasts) == len(test_data)