Example #1
0
def test_bucketized(tmpdir, df, dataset, gpu_memory_frac, engine, use_dict):
    """Check that ``ops.Bucketize`` maps continuous values to in-range buckets.

    The op is built in one of two ways, depending on ``use_dict``:

    * a ``{column: boundaries}`` mapping, or
    * a list of boundaries plus the column names.

    It is then applied via ``apply_op`` to every frame yielded by
    ``dataset.to_iter()``. For each column, all output values must lie in
    the closed range ``[0, len(boundaries_for_that_column)]``.

    ``tmpdir``, ``df``, ``gpu_memory_frac`` and ``engine`` are not read here.
    Presumably they are pytest fixtures/parameters shared with sibling
    tests; confirm before removing.
    """
    cont_names = ["x", "y"]
    boundaries = [[-1, 0, 1], [-4, 100]]

    # Exercise both accepted constructor forms of Bucketize.
    if not use_dict:
        bucketize_op = ops.Bucketize(boundaries, cont_names)
    else:
        bucketize_op = ops.Bucketize(dict(zip(cont_names, boundaries)))

    # Minimal column context: the "continuous" group's base columns.
    columns_ctx = {"continuous": {"base": list(cont_names)}}

    for chunk in dataset.to_iter():
        bucketed = bucketize_op.apply_op(chunk, columns_ctx, "continuous")
        for name, edges in zip(cont_names, boundaries):
            values = bucketed[name].values
            assert np.all(values >= 0)
            assert np.all(values <= len(edges))
Example #2
0
def test_bucketized(tmpdir, df, dataset, gpu_memory_frac, engine):
    """Check that a fitted ``Bucketize`` workflow yields in-range bucket ids.

    Columns ``x`` and ``y`` are bucketized with per-column boundaries.
    The workflow is fitted on ``dataset`` and then transformed, and the
    result is computed into a single frame. For each column, every output
    value must lie in the closed range
    ``[0, len(boundaries_for_that_column)]``.

    ``tmpdir``, ``df``, ``gpu_memory_frac`` and ``engine`` are not read here.
    Presumably they are pytest fixtures/parameters shared with sibling
    tests; confirm before removing.
    """
    cont_names = ["x", "y"]
    boundaries = [[-1, 0, 1], [-4, 100]]

    bucketize_op = ops.Bucketize(dict(zip(cont_names, boundaries)))

    bucket_features = cont_names >> bucketize_op
    processor = nvtabular.Workflow(bucket_features)
    processor.fit(dataset)
    new_gdf = processor.transform(dataset).to_ddf().compute()

    for col, bs in zip(cont_names, boundaries):
        values = new_gdf[col].values
        assert np.all(values >= 0)
        assert np.all(values <= len(bs))


# initial column selector works with tags
# filter within the workflow by tags
# test tags correct at output
@pytest.mark.parametrize(
    "op",
    [
        ops.Bucketize([1]),
        ops.Rename(postfix="_trim"),
        ops.Categorify(),
        ops.Categorify(encode_type="combo"),
        ops.Clip(0),
        ops.DifferenceLag("col1"),
        ops.FillMissing(),
        ops.Groupby("col1"),
        ops.HashBucket(1),
        ops.HashedCross(1),
        ops.JoinGroupby("col1"),
        ops.ListSlice(0),
        ops.LogOp(),
        ops.Normalize(),
        ops.TargetEncoding("col1"),
    ],