Beispiel #1
0
def test_add_time_features():
    """Check structural equality and cloning of ``transform.AddTimeFeatures``.

    A plain clone must compare equal to the original. A clone whose
    ``time_features`` list omits ``MonthOfYear`` must compare unequal.
    """
    all_features = [
        time_feature.DayOfWeek(),
        time_feature.DayOfMonth(),
        time_feature.MonthOfYear(),
    ]
    original = transform.AddTimeFeatures(
        start_field=FieldName.START,
        target_field=FieldName.TARGET,
        output_field="time_feat",
        time_features=all_features,
        pred_length=10,
    )

    # Override only the ``time_features`` argument through clone's override
    # mapping; the shorter list should make the result differ.
    fewer_features = [
        time_feature.DayOfWeek(),
        time_feature.DayOfMonth(),
    ]
    reduced = clone(original, {"time_features": fewer_features})

    assert equals(original, clone(original))
    assert not equals(original, reduced)
Beispiel #2
0
def test_gluon_predictor():
    """Check structural equality and cloning of ``CanonicalRNNEstimator``.

    Note: despite the test name, the object under test is an estimator.
    A plain clone must compare equal. A clone with ``freq`` overridden to
    ``"1h"`` must compare unequal.
    """
    freq, train_length, pred_length = "5min", 100, 10

    estimator = CanonicalRNNEstimator(freq, train_length, pred_length)

    identical_copy = clone(estimator)
    assert equals(estimator, identical_copy)

    hourly_copy = clone(estimator, {"freq": "1h"})
    assert not equals(estimator, hourly_copy)
Beispiel #3
0
def test_map_transformation():
    """Check structural equality and cloning of ``transform.VstackFeatures``.

    Flipping ``drop_inputs`` from ``True`` to ``False`` in a clone must make
    the clone compare unequal to the original.
    """
    stacker = transform.VstackFeatures(
        output_field="dynamic_feat",
        input_fields=["age", "time_feat"],
        drop_inputs=True,
    )

    same_stacker = clone(stacker)
    assert equals(stacker, same_stacker)

    keeping_stacker = clone(stacker, {"drop_inputs": False})
    assert not equals(stacker, keeping_stacker)
Beispiel #4
0
def test_chain():
    """Check structural equality and cloning of ``transform.Chain``.

    Verifies that a chain:
      * equals its plain clone,
      * differs from a clone whose ``trans`` list is overridden to empty,
      * differs from an otherwise identical chain whose ``AddAgeFeature``
        step uses ``log_scale=False`` instead of ``True``.
    """

    def make_chain(log_scale):
        # Build the three-step chain used by this test. Only the
        # ``log_scale`` flag of ``AddAgeFeature`` varies between the chains
        # compared below. Factoring it out makes that single difference
        # explicit instead of hiding it in a duplicated literal.
        return transform.Chain(trans=[
            transform.AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field="time_feat",
                time_features=[
                    time_feature.DayOfWeek(),
                    time_feature.DayOfMonth(),
                    time_feature.MonthOfYear(),
                ],
                pred_length=10,
            ),
            transform.AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field="age",
                pred_length=10,
                log_scale=log_scale,
            ),
            transform.AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field="observed_values",
            ),
        ])

    chain = make_chain(log_scale=True)

    assert equals(chain, clone(chain))
    assert not equals(chain, clone(chain, {"trans": []}))

    # Same steps, differing only in AddAgeFeature's log_scale flag.
    another_chain = make_chain(log_scale=False)
    assert not equals(chain, another_chain)
Beispiel #5
0
def test_instance_splitter():
    """Check structural equality and cloning of ``transform.InstanceSplitter``.

    A clone that swaps in an ``ExpectedNumInstanceSampler`` with
    ``num_instances=5`` (the original uses 4) must compare unequal.
    """
    base_sampler = transform.ExpectedNumInstanceSampler(num_instances=4)
    base_splitter = transform.InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=base_sampler,
        past_length=100,
        future_length=10,
        time_series_fields=["dynamic_feat", "observed_values"],
    )

    alt_sampler = transform.ExpectedNumInstanceSampler(num_instances=5)
    alt_splitter = clone(base_splitter, {"instance_sampler": alt_sampler})

    assert equals(base_splitter, clone(base_splitter))
    assert not equals(base_splitter, alt_splitter)
Beispiel #6
0
def test_continuous_time_splitter():
    """Check structural equality of ``transform.ContinuousTimeInstanceSplitter``.

    Two splitters that share interval lengths but whose point samplers differ
    (default vs. ``min_past=1.0``) must compare unequal. A plain clone must
    compare equal.
    """
    # Interval lengths shared by both splitters.
    intervals = dict(past_interval_length=1, future_interval_length=1)

    default_splitter = transform.ContinuousTimeInstanceSplitter(
        **intervals,
        instance_sampler=transform.ContinuousTimePointSampler(),
    )
    min_past_splitter = transform.ContinuousTimeInstanceSplitter(
        **intervals,
        instance_sampler=transform.ContinuousTimePointSampler(min_past=1.0),
    )

    assert equals(default_splitter, clone(default_splitter))
    assert not equals(default_splitter, min_past_splitter)
Beispiel #7
0
def test_continuous_time_sampler():
    """Check equality and cloning of ``ContinuousTimeUniformSampler``.

    Overriding ``num_instances`` (4 -> 5) in a clone must break equality.
    """
    four = transform.ContinuousTimeUniformSampler(num_instances=4)

    copy_of_four = clone(four)
    assert equals(four, copy_of_four)

    five = clone(four, {"num_instances": 5})
    assert not equals(four, five)
Beispiel #8
0
def test_exp_num_sampler():
    """Check equality and cloning of ``ExpectedNumInstanceSampler``.

    Overriding ``num_instances`` (4 -> 5) in a clone must break equality.
    """
    original_sampler = transform.ExpectedNumInstanceSampler(num_instances=4)

    duplicate = clone(original_sampler)
    assert equals(original_sampler, duplicate)

    modified = clone(original_sampler, {"num_instances": 5})
    assert not equals(original_sampler, modified)