def test_add_time_features():
    """AddTimeFeatures equals its own clone, but not a clone with fewer time features."""
    original = transform.AddTimeFeatures(
        start_field=FieldName.START,
        target_field=FieldName.TARGET,
        output_field="time_feat",
        time_features=[
            time_feature.DayOfWeek(),
            time_feature.DayOfMonth(),
            time_feature.MonthOfYear(),
        ],
        pred_length=10,
    )
    # Same transformation with MonthOfYear dropped from the feature list.
    reduced_features = [
        time_feature.DayOfWeek(),
        time_feature.DayOfMonth(),
    ]
    modified = clone(original, {"time_features": reduced_features})

    assert equals(original, clone(original))
    assert not equals(original, modified)
def test_gluon_predictor():
    """An estimator equals its clone, and differs once ``freq`` is overridden.

    Despite the test's name, the object under test is a CanonicalRNNEstimator.
    """
    freq = "5min"
    train_length, pred_length = 100, 10
    # Arguments are passed positionally, exactly as in the original test.
    estimator = CanonicalRNNEstimator(freq, train_length, pred_length)

    assert equals(estimator, clone(estimator))
    assert not equals(estimator, clone(estimator, {"freq": "1h"}))
def test_map_transformation():
    """VstackFeatures equality is sensitive to the ``drop_inputs`` flag."""
    stacker = transform.VstackFeatures(
        output_field="dynamic_feat",
        input_fields=["age", "time_feat"],
        drop_inputs=True,
    )

    # An unmodified clone must compare equal to the source object.
    assert equals(stacker, clone(stacker))
    # Flipping drop_inputs must be enough to break equality.
    assert not equals(stacker, clone(stacker, {"drop_inputs": False}))
def test_chain():
    """A Chain equals its clone, and differs when its steps change."""

    def build_chain(log_scale):
        # Three steps: time features, age feature, observed-values indicator.
        # Only AddAgeFeature's log_scale varies between the chains below.
        steps = [
            transform.AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field="time_feat",
                time_features=[
                    time_feature.DayOfWeek(),
                    time_feature.DayOfMonth(),
                    time_feature.MonthOfYear(),
                ],
                pred_length=10,
            ),
            transform.AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field="age",
                pred_length=10,
                log_scale=log_scale,
            ),
            transform.AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field="observed_values",
            ),
        ]
        return transform.Chain(trans=steps)

    chain = build_chain(log_scale=True)
    assert equals(chain, clone(chain))
    # Replacing the step list with an empty one must break equality.
    assert not equals(chain, clone(chain, {"trans": []}))

    # A chain differing only in a nested step's argument must also differ.
    another_chain = build_chain(log_scale=False)
    assert not equals(chain, another_chain)
def test_instance_splitter():
    """InstanceSplitter equality depends on its nested ``instance_sampler``."""
    splitter = transform.InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=transform.ExpectedNumInstanceSampler(num_instances=4),
        past_length=100,
        future_length=10,
        time_series_fields=["dynamic_feat", "observed_values"],
    )
    # Swap in a sampler that differs only by num_instances (4 -> 5).
    different_sampler = transform.ExpectedNumInstanceSampler(num_instances=5)
    splitter2 = clone(splitter, {"instance_sampler": different_sampler})

    assert equals(splitter, clone(splitter))
    assert not equals(splitter, splitter2)
def test_continuous_time_splitter():
    """Splitters that differ only in their point sampler's ``min_past`` are unequal."""

    def make_splitter(sampler):
        # Both interval lengths are fixed; only the sampler varies.
        return transform.ContinuousTimeInstanceSplitter(
            past_interval_length=1,
            future_interval_length=1,
            instance_sampler=sampler,
        )

    splitter = make_splitter(transform.ContinuousTimePointSampler())
    splitter2 = make_splitter(transform.ContinuousTimePointSampler(min_past=1.0))

    assert equals(splitter, clone(splitter))
    assert not equals(splitter, splitter2)
def test_continuous_time_sampler():
    """ContinuousTimeUniformSampler equality tracks ``num_instances``."""
    base = transform.ContinuousTimeUniformSampler(num_instances=4)

    assert equals(base, clone(base))
    # Overriding num_instances (4 -> 5) must produce an unequal object.
    assert not equals(base, clone(base, {"num_instances": 5}))
def test_exp_num_sampler():
    """ExpectedNumInstanceSampler equality tracks ``num_instances``."""
    base = transform.ExpectedNumInstanceSampler(num_instances=4)

    assert equals(base, clone(base))
    # Overriding num_instances (4 -> 5) must produce an unequal object.
    assert not equals(base, clone(base, {"num_instances": 5}))