def test_normalize_minmax(tmpdir, df, dataset, gpu_memory_frac, engine, op_columns):
    """Check the legacy, config-based ``NormalizeMinMax.apply_op`` path.

    A workflow is configured with ``ops.MinMax`` statistics on the continuous
    columns ``x`` and ``y`` and fitted on ``dataset``. ``NormalizeMinMax`` is
    then applied directly to ``df``. Every continuous column must equal
    ``(value - min) / (max - min)`` computed from the fitted statistics.

    ``tmpdir``, ``gpu_memory_frac`` and ``op_columns`` are pytest fixtures that
    this body does not use.

    NOTE(review): a later ``test_normalize_minmax`` in this module rebinds the
    name, so pytest only collects that one. This legacy version should be
    renamed or removed.
    """
    cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"]
    cont_names = ["x", "y"]
    label_name = ["label"]

    config = nvt.workflow.get_new_config()
    config["PP"]["continuous"] = [ops.MinMax()]

    processor = nvtabular.Workflow(
        cat_names=cat_names, cont_names=cont_names, label_name=label_name, config=config
    )
    processor.update_stats(dataset)

    mins = processor.stats["mins"]
    maxs = processor.stats["maxs"]
    # Compute the expected values BEFORE apply_op and without writing into the
    # ``df`` fixture. apply_op may write its result back into ``df`` and return
    # it; comparing against an overwritten ``df`` afterwards would then compare
    # the output with itself.
    expected = {col: (df[col] - mins[col]) / (maxs[col] - mins[col]) for col in cont_names}

    op = ops.NormalizeMinMax()
    columns_ctx = {"continuous": {"base": cont_names}}
    new_gdf = op.apply_op(df, columns_ctx, "continuous", stats_context=processor.stats)

    # Every base continuous column is handed to the op, so check all of them,
    # not only "x".
    for col in cont_names:
        assert new_gdf[col].equals(expected[col])
def test_normalize_minmax(tmpdir, df, dataset, gpu_memory_frac, engine, op_columns):
    """Fit ``NormalizeMinMax`` on ``op_columns`` and validate stats and output.

    For each column in ``op_columns`` this test:
      * compares the fitted minimum and maximum with the raw ``df`` column
        (relative tolerance 1e-2);
      * overwrites ``df[col]`` with ``(df[col] - min) / (max - min)`` and
        requires every transformed value to lie within 1e-2 of it.

    ``tmpdir``, ``gpu_memory_frac`` and ``engine`` are pytest fixtures that
    this body does not use.
    """
    workflow = nvtabular.Workflow(op_columns >> ops.NormalizeMinMax())
    workflow.fit(dataset)

    transformed = workflow.transform(dataset).to_ddf().compute()
    # Give the output the input's index so element-wise differences line up.
    transformed.index = df.index

    fitted_op = workflow.column_group.op
    for name in op_columns:
        lo = fitted_op.mins[name]
        hi = fitted_op.maxs[name]

        assert df[name].min() == pytest.approx(lo, rel=1e-2)
        assert df[name].max() == pytest.approx(hi, rel=1e-2)

        # Replace the raw column with the reference normalization, then compare.
        df[name] = (df[name] - lo) / (hi - lo)
        residual = (df[name] - transformed[name]).abs().values
        assert np.all(residual <= 1e-2)
def test_chaining_1():
    """Chain ``FillMissing`` into ``NormalizeMinMax`` and check the result.

    The first 10 rows of ``cont01`` are null. ``FillMissing`` is applied to
    ``cont01``, and the filled ``cont01`` plus ``cont02`` are then min-max
    normalized. ``cat01`` and ``label`` are selected without any op. After
    fit/transform, ``cont01`` must contain no nulls, and both normalized
    columns must lie in ``[0, 1]``.
    """
    # Inject the nulls into a standalone Series before building the frame.
    # Chained assignment (``df["cont01"][:10] = None``) only updates the frame
    # while the column Series is a view; under copy-on-write semantics it
    # silently leaves the frame untouched.
    cont01 = cudf.Series(np.random.randint(1, 100, 100))
    cont01[:10] = None
    df = cudf.DataFrame(
        {
            "cont01": cont01,
            "cont02": np.random.random(100) * 100,
            "cat01": np.random.randint(0, 10, 100),
            "label": np.random.randint(0, 3, 100),
        }
    )

    cont1 = "cont01" >> ops.FillMissing()
    # `+` binds tighter than `>>`, so NormalizeMinMax applies to both columns.
    conts = cont1 + "cont02" >> ops.NormalizeMinMax()
    workflow = Workflow(conts + "cat01" + "label")

    result = workflow.fit_transform(Dataset(df)).to_ddf().compute()

    # FillMissing must have run before normalization: no nulls survive.
    assert result["cont01"].isnull().sum() == 0
    for col in ("cont01", "cont02"):
        assert result[col].min() >= 0.0
        assert result[col].max() <= 1.0