예제 #1
0
def test_to_interval_size_format(
    transform, target, expected, convert_to_np, is_train
):
    """Check the ``"target"`` field produced by *transform* on one series.

    A single-entry ``ListDataset`` is built from *target* (converted to a
    NumPy array first when *convert_to_np* is true). *transform* is applied
    once, and the first output entry's ``"target"`` is compared with
    *expected* using ``np.allclose``.

    If *transform* yields no entries, that is accepted only when its
    ``drop_empty`` attribute is truthy; otherwise the test fails with an
    explicit assertion message.
    """
    if convert_to_np:
        target = np.array(target)

    data_set = ListDataset(
        [{"target": target, "start": "2010-01-01"}], freq="1m"
    )

    # Apply the transformation exactly once. This keeps the test from
    # depending on the transform being re-runnable with identical output.
    transformed = next(iter(transform(data_set, is_train=is_train)), None)
    if transformed is None:
        # Empty output is only legitimate when the transform drops empty
        # series. Otherwise report it clearly instead of leaking a
        # StopIteration out of the test.
        assert transform.drop_empty, "transform yielded no entries"
        return

    assert np.allclose(transformed["target"], expected)
예제 #2
0
def test_count_trailing_zeros(target, expected, convert_to_np, is_train):
    """Check the ``time_remaining`` field written by ``CountTrailingZeros``.

    The transformation is applied to a single-entry ``ListDataset`` built
    from *target*, optionally converted to a NumPy array first. For an empty
    target, the first output entry must not contain ``time_remaining``.
    Otherwise the field must be present and equal to *expected*.
    """
    series = np.array(target) if convert_to_np else target

    dataset = ListDataset(
        [{"target": series, "start": "2010-01-01"}], freq="1m"
    )
    counter = CountTrailingZeros(new_field="time_remaining")

    entry = next(counter(dataset, is_train=is_train))

    # ``len`` rather than truthiness: ``series`` may be a NumPy array, whose
    # truth value is ambiguous.
    if len(series) != 0:
        assert "time_remaining" in entry
        assert entry["time_remaining"] == expected
    else:
        assert "time_remaining" not in entry