def _create_post_split_transform():
    """Build the chain of transformations applied after the split step.

    The returned ``Chain`` wraps six transformations, constructed and
    applied in this order:

    1. ``CountTrailingZeros`` reading ``past_target`` and writing the new
       field ``time_remaining`` (with ``as_array=True``).
    2. ``ToIntervalSizeFormat`` on ``past_target`` with
       ``discard_first=True``.
    3. ``RenameFields`` mapping ``future_target`` to ``sparse_future``.
    4. ``AsNumpyArray`` on ``past_target`` with ``expected_ndim=2``.
    5. ``SwapAxes`` exchanging axes 0 and 1 of ``past_target``.
    6. ``AddAxisLength`` on axis 0 of ``past_target``.

    Returns:
        A ``Chain`` instance holding the steps listed above.
    """
    past_field = "past_target"

    # Order matters: each step sees the output of the previous one.
    steps = [
        CountTrailingZeros(
            new_field="time_remaining",
            target_field=past_field,
            as_array=True,
        ),
        ToIntervalSizeFormat(target_field=past_field, discard_first=True),
        RenameFields({"future_target": "sparse_future"}),
        AsNumpyArray(field=past_field, expected_ndim=2),
        SwapAxes(input_fields=[past_field], axes=(0, 1)),
        AddAxisLength(target_field=past_field, axis=0),
    ]
    return Chain(steps)
def test_count_trailing_zeros(target, expected, convert_to_np, is_train):
    """Check the ``time_remaining`` field produced by ``CountTrailingZeros``.

    Args:
        target: Target values placed in the single dataset entry.
        expected: Value ``time_remaining`` must equal for a non-empty target.
        convert_to_np: When true, the target is wrapped in ``np.array`` before
            being put into the dataset.
        is_train: Forwarded to the transformation call as ``is_train``.

    For an empty target the test asserts that the field is *not* added;
    otherwise it asserts the field exists and matches ``expected``.
    """
    series = np.array(target) if convert_to_np else target

    entry = {"target": series, "start": "2010-01-01"}
    dataset = ListDataset([entry], freq="1m")

    counter = CountTrailingZeros(new_field="time_remaining")
    # Only the first (and only) transformed entry is inspected.
    result = next(counter(dataset, is_train=is_train))

    # len() rather than truthiness: the target may be a NumPy array.
    if len(series) == 0:
        assert "time_remaining" not in result
    else:
        assert "time_remaining" in result
        assert result["time_remaining"] == expected