def test_addition_combines_names_and_tags():
    """Adding a name-based selector and a tag-based selector keeps both parts."""
    by_name = ColumnSelector(["a", "b", "c"])
    by_tag = ColumnSelector(tags=["g", "h", "i"])

    merged = by_name + by_tag

    # Names come from the left operand, tags from the right operand
    assert merged.names == ["a", "b", "c"]
    assert merged.tags == ["g", "h", "i"]
예제 #2
0
    def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
        """Build the schema of the columns this operator produces.

        Parameters
        -----------
        input_schema: Schema
            The schemas of the columns to apply this operator to
        col_selector: ColumnSelector
            The column selector to apply to the input schema. A falsy selector
            is replaced by one naming every column of ``input_schema``.

        Returns
        -------
        Schema
            The schemas of the columns produced by this operator
        """
        selector = col_selector or ColumnSelector(input_schema.column_names)

        if selector.tags:
            # Resolve the tags into concrete column names, then clear the tags
            # since the matching columns are now selected by name
            tagged_schema = input_schema.apply(ColumnSelector(tags=selector.tags))
            selector += ColumnSelector(tagged_schema.column_names)
            selector._tags = []

        selector = self.output_column_names(selector)

        # Add a schema entry for any selected name the input schema lacks
        for name in selector.names:
            if name not in input_schema.column_schemas:
                input_schema += Schema([name])

        result = Schema()
        for column_schema in input_schema.apply(selector):
            result += Schema([self.transformed_schema(column_schema)])
        return result
def test_constructor_too_many_level():
    """Building a selector with subgroups nested inside subgroups must fail."""
    with pytest.raises(AttributeError) as exc_info:
        innermost = ColumnSelector(["a"])
        middle = ColumnSelector(names=["b"], subgroups=innermost)
        ColumnSelector(["h", "i"], subgroups=middle)

    assert "Too many" in str(exc_info.value)
def test_addition_combines_names_and_subgroups():
    """Adding two selectors concatenates their names and keeps both subgroups."""
    left = ColumnSelector(["a", "b", "c", ["d", "e", "f"]])
    right = ColumnSelector(["g", "h", "i", ["j", "k", "l"]])

    total = left + right

    # Top-level names are concatenated in operand order
    assert total._names == ["a", "b", "c", "g", "h", "i"]
    # Each operand contributes its own subgroup, in order
    assert total.subgroups[0]._names == ["d", "e", "f"]
    assert total.subgroups[1]._names == ["j", "k", "l"]
    assert len(total.subgroups) == 2
def test_constructor_creates_subgroups_from_nesting():
    """A nested list or tuple inside the names becomes a subgroup selector."""
    # The nested group is given once as a list and once as a tuple
    for nested in (["d", "e", "f"], ("d", "e", "f")):
        selector = ColumnSelector(["a", "b", "c", nested])

        assert selector._names == ["a", "b", "c"]
        assert selector.subgroups == [ColumnSelector(["d", "e", "f"])]
def test_addition_enum_tags():
    """Adding a ``Tags`` enum member to a selector appends it to the tags."""
    tag_only = ColumnSelector(tags=["a", "b", "c"])
    with_enum = tag_only + Tags.CATEGORICAL

    assert with_enum.tags == ["a", "b", "c", Tags.CATEGORICAL]

    # Names and subgroups of the left operand must be left untouched
    named = ColumnSelector(["a", "b", "c", ["d", "e", "f"]])
    with_enum = named + Tags.CATEGORICAL

    assert with_enum._names == ["a", "b", "c"]
    assert with_enum.subgroups == [ColumnSelector(["d", "e", "f"])]
    assert with_enum.tags == [Tags.CATEGORICAL]
def test_constructor_rejects_workflow_nodes():
    """A ``Node`` is not accepted as a selector's names, alone or in a list."""
    node = Node(ColumnSelector(["a"]))

    # Passing the node directly
    with pytest.raises(TypeError) as direct_error:
        ColumnSelector(node)
    assert "Node" in str(direct_error.value)

    # Passing the node as one element of the names list
    with pytest.raises(ValueError) as nested_error:
        ColumnSelector(["a", "b", node])
    assert "Node" in str(nested_error.value)
예제 #8
0
def test_selecting_columns_sets_selector_and_kind():
    """Indexing a node with a list or a single name sets the child's selector."""
    parent = ColumnSelector(["a", "b", "c"]) >> Operator()

    by_list = parent[["a", "b"]]
    assert by_list.selector.names == ["a", "b"]

    by_name = parent["b"]
    assert by_name.selector.names == ["b"]
예제 #9
0
def test_workflow_node_select():
    """Column subsets taken from one node can be transformed separately and recombined."""
    frame = dispatch._make_df(
        {
            "a": [1, 4, 9, 16, 25],
            "b": [0, 1, 2, 3, 4],
            "c": [25, 16, 9, 4, 1],
        }
    )
    dataset = Dataset(frame)

    source = WorkflowNode(ColumnSelector(["a", "b", "c"]))
    # pylint: disable=unnecessary-lambda
    rooted = source[["a", "c"]] >> (lambda col: np.sqrt(col))
    incremented = source["b"] >> (lambda col: col + 1)

    workflow = Workflow(rooted + incremented)
    workflow.fit(dataset)

    actual = workflow.transform(dataset).to_ddf().compute(scheduler="synchronous")

    # Build the expected frame in the same column order: a, c, then b
    expected = dispatch._make_df()
    for name in ("a", "c"):
        expected[name] = np.sqrt(frame[name])
    expected["b"] = frame["b"] + 1

    assert_eq(expected, actual)
예제 #10
0
def test_workflow_move_saved(tmpdir):
    """A saved workflow still loads and transforms identically after its directory is moved."""
    raw = """US>SC>519 US>CA>807 US>MI>505 US>CA>510 CA>NB US>CA>534""".split()
    data = nvt.dispatch._make_df({"geo": raw})

    geo = ColumnSelector(["geo"])
    state_slice = ops.LambdaOp(lambda col: col.str.slice(0, 5))
    country_slice = ops.LambdaOp(lambda col: col.str.slice(0, 2))
    state = geo >> state_slice >> ops.Rename(postfix="_state")
    country = geo >> country_slice >> ops.Rename(postfix="_country")
    geo_features = state + country + geo >> ops.Categorify()

    # Fit on the raw data and keep the output as the reference result
    workflow = Workflow(geo_features)
    expected = workflow.fit_transform(Dataset(data)).to_ddf().compute()

    # Save the workflow (including categorical mapping parquet files), then move
    # the saved directory and load it from the new location
    output_root = os.path.join(tmpdir, "output")
    saved_path = os.path.join(output_root, "workflow")
    workflow.save(saved_path)

    relocated_path = os.path.join(output_root, "workflow2")
    shutil.move(saved_path, relocated_path)
    reloaded = Workflow.load(relocated_path)

    # The reloaded workflow must reproduce the reference result
    transformed = reloaded.transform(Dataset(data)).to_ddf().compute()
    assert_eq(expected, transformed)
예제 #11
0
    def output_columns(self):
        """Return a selector naming the columns of this node's output schema.

        Raises
        ------
        RuntimeError
            If the output schema has not been computed yet
        """
        schema = self.output_schema
        if schema is not None:
            return ColumnSelector(schema.column_names)

        raise RuntimeError(
            "The output columns aren't computed until the workflow "
            "is fit to a dataset or input schema."
        )
예제 #12
0
def test_applying_selector_to_schema_selects_by_name_or_tags():
    """A selector with both names and tags is expected to select every column here."""
    untagged = ColumnSchema("col1")
    tagged = ColumnSchema("col2", tags=["b", "c", "d"])
    schema = Schema([untagged, tagged])

    # "col1" is requested by name; "col2" shares the "b" tag with the selector
    selected = schema.apply(ColumnSelector(["col1"], tags=["a", "b"]))

    assert selected.column_names == schema.column_names
예제 #13
0
def test_spec_set(tmpdir, client):
    """Smoke test: a workflow mixing categorical, continuous and target-encoded features runs."""
    frame = nvt.dispatch._make_df(
        {
            "ad_id": [1, 2, 2, 6, 6, 8, 3, 3],
            "source_id": [2, 4, 4, 7, 5, 2, 5, 2],
            "platform": [1, 2, np.nan, 2, 1, 3, 3, 1],
            "cont": [1, 2, np.nan, 2, 1, 3, 3, 1],
            "clicked": [1, 0, 1, 0, 0, 1, 1, 0],
        }
    )

    categorical = ColumnSelector(["ad_id", "source_id", "platform"])
    continuous = ColumnSelector(["cont"])

    # Operators are passed as classes here, not instances
    cat_features = categorical >> ops.Categorify
    cont_features = continuous >> ops.FillMissing >> ops.Normalize
    target_encoder = ops.TargetEncoding("clicked", kfold=5, fold_seed=42, p_smooth=20)
    te_features = categorical >> target_encoder

    workflow = Workflow(cat_features + cont_features + te_features, client=client)
    workflow.fit_transform(nvt.Dataset(frame)).to_ddf().compute()
예제 #14
0
def test_applying_selector_to_schema_selects_by_name():
    """Applying a name selector keeps the named columns; ``None`` keeps everything."""
    schema = Schema(["a", "b", "c", "d", "e"])

    subset = schema.apply(ColumnSelector(["a", "b"]))
    assert subset == Schema(["a", "b"])

    # No selector at all leaves the schema unchanged
    everything = schema.apply(None)
    assert everything == schema
예제 #15
0
def test_workflow_node_dependencies():
    """``partition_cols`` given as a node or as column names end up in the node's dependencies."""
    # Full WorkflowNode case
    upstream = ["a", "b"] >> Operator()
    lag_on_node = DifferenceLag(partition_cols=[upstream], shift=[1, -1])
    result = ["timestamp"] >> lag_on_node
    assert list(result.dependencies) == [upstream]

    # ColumnSelector case
    lag_on_names = DifferenceLag(partition_cols=["userid"], shift=[1, -1])
    result = ["timestamp"] >> lag_on_names
    assert result.dependencies[0].selector == ColumnSelector(["userid"])
예제 #16
0
def test_applying_inverse_selector_to_schema_selects_relevant_columns():
    """Inverse application drops the selected columns; ``None`` drops nothing."""
    schema = Schema(["a", "b", "c", "d", "e"])

    remaining = schema.apply_inverse(ColumnSelector(["a", "b"]))
    assert remaining == Schema(["c", "d", "e"])

    # No selector at all leaves the schema unchanged
    untouched = schema.apply_inverse(None)
    assert untouched == schema
예제 #17
0
    def compute_schemas(self, root_schema):
        """Compute and store this node's selector, input schema and output schema.

        The input schema is derived from the schemas of the node's parents and
        dependencies according to the kind of operator attached to the node; the
        output schema is then delegated to that operator (or copied from the
        input schema when the node has no operator).

        Parameters
        -----------
        root_schema: Schema
            Schema that is combined with the upstream schemas when this node is
            a selection node; it is not read in the other branches
        """
        # If parent is an addition node, we may need to propagate grouping
        # unless we're a node that already has a selector
        if not self.selector:
            if (
                len(self.parents) == 1
                and isinstance(self.parents[0].op, ConcatColumns)
                and self.parents[0].selector
                and (self.parents[0].selector.names)
            ):
                self.selector = self.parents[0].selector

        if isinstance(self.op, ConcatColumns):  # +
            # For addition nodes, some of the operands are parents and
            # others are dependencies so grab schemas from both
            self.selector = _combine_selectors(self.grouped_parents_with_dependencies)
            self.input_schema = _combine_schemas(self.parents_with_dependencies)

        elif isinstance(self.op, SubtractionOp):  # -
            left_operand = _combine_schemas(self.parents)

            # The right operand is either another node (held as a dependency)
            # or a selector stored on the op itself
            if self.dependencies:
                right_operand = _combine_schemas(self.dependencies)
                self.input_schema = left_operand - right_operand
            else:
                self.input_schema = left_operand.apply_inverse(self.op.selector)

            # The selector is rebuilt from whatever columns remain
            self.selector = ColumnSelector(self.input_schema.column_names)

        elif isinstance(self.op, SubsetColumns):  # []
            # NOTE(review): the dependency schemas are subtracted here, exactly
            # as in the subtraction branch -- confirm this is intended for subsets
            left_operand = _combine_schemas(self.parents)
            right_operand = _combine_schemas(self.dependencies)
            self.input_schema = left_operand - right_operand

        # If we have a selector, apply it to upstream schemas from nodes/dataset
        elif isinstance(self.op, SelectionOp):  # ^
            upstream_schema = root_schema + _combine_schemas(self.parents_with_dependencies)
            self.input_schema = upstream_schema.apply(self.selector)

        # If none of the above apply, then we don't have a selector
        # and we're not an add or sub node, so our input is just the
        # parents output
        else:
            self.input_schema = _combine_schemas(self.parents)

        # Then we delegate to the op (if there is one) to compute this node's
        # output schema. If there's no op, then outputs are just the inputs
        if self.op:
            self.output_schema = self.op.compute_output_schema(self.input_schema, self.selector)
        else:
            self.output_schema = self.input_schema
def test_constructor_works_with_single_subgroups_and_lists():
    """A single selector passed as ``subgroups`` is stored wrapped in a list."""
    from_string = ColumnSelector([], subgroups=ColumnSelector("a"))
    assert isinstance(from_string.subgroups, list)
    assert from_string.subgroups[0] == ColumnSelector("a")

    from_list = ColumnSelector([], subgroups=ColumnSelector(["a", "b", "c"]))
    assert isinstance(from_list.subgroups, list)
    assert from_list.subgroups[0] == ColumnSelector(["a", "b", "c"])
def test_rshift_operator_onto_selector_creates_selection_node():
    """``selector >> operator`` yields an operator node whose single parent is a selection node."""
    selector = ColumnSelector(["a", "b", "c"])

    result = selector >> BaseOperator()

    # The returned node carries the operator and has no selector of its own
    assert isinstance(result, Node)
    assert isinstance(result.op, BaseOperator)
    assert result._selector is None
    assert len(result.parents) == 1

    # Its only parent is a parentless selection node holding the selector
    assert isinstance(result.parents[0], Node)
    assert isinstance(result.parents[0].op, SelectionOp)
    assert result.parents[0]._selector == selector
    assert len(result.parents[0].parents) == 0
예제 #20
0
    def input_columns(self):
        """Return a selector for this node's input columns.

        The node's own selector is returned when it is truthy and neither it
        nor any of its subgroups carries tags; otherwise a selector is built
        from the column names of the input schema.

        Raises
        ------
        RuntimeError
            If the input schema has not been computed yet
        """
        if self.input_schema is None:
            raise RuntimeError(
                "The input columns aren't computed until the workflow "
                "is fit to a dataset or input schema."
            )

        selector = self.selector
        tag_free = (
            bool(selector)
            and not selector.tags
            and not any(group.tags for group in selector.subgroups)
        )

        # Returning the selector itself maintains column groupings
        return selector if tag_free else ColumnSelector(self.input_schema.column_names)
예제 #21
0
    def __getitem__(self, columns):
        """Select certain columns from this node and return a new child node
        restricted to those columns.

        Parameters
        -----------
        columns: str or list of str
            Columns to select

        Returns
        -------
        Node
            A new node of the same type as this one, with this node as its
            parent and a ``SubsetColumns`` op labelled with the selected columns
        """
        col_selector = ColumnSelector(columns)
        child = type(self)(col_selector)

        # A bare string must be wrapped rather than passed to ``list()``, which
        # would split it into single characters and produce a misleading label
        # (e.g. "geo" -> "['g', 'e', 'o']")
        label_columns = [columns] if isinstance(columns, str) else list(columns)
        child.op = SubsetColumns(label=str(label_columns))

        child.add_parent(self)
        return child
예제 #22
0
def test_nested_workflow_node():
    """Categorify can cross-encode a nested column group alongside its individual columns.

    Builds a ``geo_country`` column from ``geo``, then categorifies
    ``geo_country``, ``user`` and the ``(geo_country, user)`` combination, and
    checks that equal inputs map to equal categories. Also checks that nesting
    column groups too deeply raises ``ValueError``.
    """
    df = dispatch._make_df({
        "geo": ["US>CA", "US>NY", "CA>BC", "CA>ON"],
        "user": ["User_A", "User_A", "User_A", "User_B"],
    })
    dataset = Dataset(df)

    geo_selector = ColumnSelector(["geo"])
    country = (geo_selector >> LambdaOp(lambda col: col.str.slice(0, 2)) >>
               Rename(postfix="_country"))
    # country1 = geo_selector >> (lambda col: col.str.slice(0, 2)) >> Rename(postfix="_country1")
    # country2 = geo_selector >> (lambda col: col.str.slice(0, 2)) >> Rename(postfix="_country2")

    # Must name the real "user" column of the frame: the assertions below read
    # df_out["user"] and df_out["geo_country_user"]
    user = "user"
    # user2 = "user2"

    # make sure we can do a 'combo' categorify (cross based) of country+user
    # as well as categorifying the country and user columns on their own
    cats = country + user + [country + user] >> Categorify(encode_type="combo")

    workflow = Workflow(cats)
    workflow.fit_schema(dataset.infer_schema())

    df_out = workflow.fit_transform(dataset).to_ddf().compute(
        scheduler="synchronous")

    geo_country = df_out["geo_country"]
    assert geo_country[0] == geo_country[1]  # rows 0,1 are both 'US'
    assert geo_country[2] == geo_country[3]  # rows 2,3 are both 'CA'

    user = df_out["user"]
    assert user[0] == user[1] == user[2]
    assert user[3] != user[2]

    geo_country_user = df_out["geo_country_user"]
    assert geo_country_user[0] == geo_country_user[1]  # US / userA
    assert geo_country_user[2] != geo_country_user[
        0]  # same user but in canada

    # make sure we get an exception if we nest too deeply (can't handle arbitrarily deep
    # nested column groups - and the exceptions we would get in operators like Categorify
    # are super confusing for users)
    with pytest.raises(ValueError):
        cats = [[country + "user"] + country + "user"
                ] >> Categorify(encode_type="combo")
예제 #23
0
    def __init__(self, selector=None):
        """Create a node, optionally rooted at a column selection.

        Parameters
        -----------
        selector: list or ColumnSelector, optional
            Columns this node selects. A list is converted to a
            ``ColumnSelector``; a truthy selector also attaches a
            ``SelectionOp`` to the node.

        Raises
        ------
        TypeError
            If a truthy ``selector`` is neither a list nor a ``ColumnSelector``
        """
        # Graph links
        self.parents = []
        self.children = []
        self.dependencies = []

        # Filled in later, once schemas are computed
        self.input_schema = None
        self.output_schema = None
        self.op = None

        if isinstance(selector, list):
            selector = ColumnSelector(selector)

        if selector:
            if not isinstance(selector, ColumnSelector):
                raise TypeError("The selector argument must be a list or a ColumnSelector")
            self.op = SelectionOp(selector)

        self._selector = selector
예제 #24
0
def test_join_external_workflow(tmpdir, df, dataset, engine):
    """JoinExternal inside a workflow adds only the requested external columns.

    An external table derived from ``df["id"]`` is left-joined on ``id``; the
    output must have the same number of rows as a plain ``merge``, contain
    ``new_col == id + shift`` on every row, and contain ``new_col_2`` but not
    ``new_col_3`` (which is excluded through ``columns_ext``).
    """

    # Define "external" table
    how = "left"
    drop_duplicates = True
    cache = "device"
    shift = 100
    df_ext = df[["id"]].copy().sort_values("id")
    df_ext["new_col"] = df_ext["id"] + shift
    df_ext["new_col_2"] = "keep"
    df_ext["new_col_3"] = "ignore"
    df_ext_check = df_ext.copy()

    # Define Op
    on = "id"
    columns_left = list(df.columns)
    columns_ext = ["id", "new_col", "new_col_2"]
    df_ext_check = df_ext_check[columns_ext]
    if drop_duplicates:
        df_ext_check.drop_duplicates(ignore_index=True, inplace=True)
    joined = ColumnSelector(columns_left) >> nvt.ops.JoinExternal(
        df_ext,
        on,
        how=how,
        columns_ext=columns_ext,
        cache=cache,
        drop_duplicates_ext=drop_duplicates,
    )

    # Define Workflow
    gdf = df.reset_index()
    dataset = nvt.Dataset(gdf)
    processor = nvt.Workflow(joined)
    processor.fit(dataset)
    new_gdf = processor.transform(dataset).to_ddf().compute().reset_index()

    # Validate
    check_gdf = gdf.merge(df_ext_check, how=how, on=on)
    assert len(check_gdf) == len(new_gdf)

    # Compare element-wise *before* reducing with .all(); reducing each side
    # first only compares two booleans and passes regardless of the values
    assert (new_gdf["id"] + shift == new_gdf["new_col"]).all()

    # Row order is not guaranteed to survive the join, so compare the sorted ids
    expected_ids = gdf["id"].sort_values().reset_index(drop=True)
    actual_ids = new_gdf["id"].sort_values().reset_index(drop=True)
    assert (expected_ids == actual_ids).all()

    assert "new_col_2" in new_gdf.columns
    assert "new_col_3" not in new_gdf.columns
예제 #25
0
def _nodify(nodable):
    """Convert a string, selector, node, or list of those into node form.

    A string or ``ColumnSelector`` is wrapped in a new ``Node``; a ``Node`` is
    returned unchanged. A list is converted element by element and returned as
    a list: the nodes without a selector first, followed by a single node that
    combines the selectors of all the remaining elements (if there are any).

    Raises
    ------
    TypeError
        If ``nodable`` is of any other type
    """
    # TODO: Update to use abstract nodes
    if isinstance(nodable, str):
        return Node(ColumnSelector([nodable]))
    if isinstance(nodable, ColumnSelector):
        return Node(nodable)
    if isinstance(nodable, Node):
        return nodable
    if not isinstance(nodable, list):
        raise TypeError(
            f"Unsupported type: Cannot convert object of type {type(nodable)} to Node."
        )

    # Split the converted elements by whether they carry a selector
    plain_nodes, selectors = [], []
    for node in [_nodify(item) for item in nodable]:
        if node.selector:
            selectors.append(node.selector)
        else:
            plain_nodes.append(node)

    # All selectors collapse into one trailing selection node
    if selectors:
        plain_nodes.append(Node(_combine_selectors(selectors)))
    return plain_nodes
예제 #26
0
def test_transform_geolocation():
    """Hash-bucketing sliced geo columns matches hashing the slices by hand."""
    raw = """US>SC>519 US>CA>807 US>MI>505 US>CA>510 CA>NB US>CA>534""".split()
    data = nvt.dispatch._make_df({"geo_location": raw})

    geo = ColumnSelector(["geo_location"])
    state_slice = ops.LambdaOp(lambda col: col.str.slice(0, 5))
    country_slice = ops.LambdaOp(lambda col: col.str.slice(0, 2))
    state = geo >> state_slice >> ops.Rename(postfix="_state")
    country = geo >> country_slice >> ops.Rename(postfix="_country")
    geo_features = state + country + geo >> ops.HashBucket(num_buckets=100)

    # for this workflow we don't have any statoperators, so we can get away without fitting
    workflow = Workflow(geo_features)
    transformed = workflow.transform(Dataset(data)).to_ddf().compute()

    # Recompute each output column directly from the raw strings
    source = data["geo_location"]
    expected = nvt.dispatch._make_df()
    expected["geo_location_state"] = source.str.slice(0, 5).hash_values() % 100
    expected["geo_location_country"] = source.str.slice(0, 2).hash_values() % 100
    expected["geo_location"] = source.hash_values() % 100
    assert_eq(expected, transformed)
예제 #27
0
def _combine_selectors(elements):
    """Fold nodes, selectors, strings and nested lists into one ``ColumnSelector``.

    For a ``Node`` the contribution is taken from the first available source:
    its selector (passed through the op's ``output_column_names``), its output
    schema, or its input schema (also passed through ``output_column_names``);
    a node with none of these contributes an empty selector. A nested list
    becomes a subgroup. Elements of any other type are skipped.
    """
    result = ColumnSelector()
    for item in elements:
        if isinstance(item, Node):
            if item.selector:
                addition = item.op.output_column_names(item.selector)
            elif item.output_schema:
                addition = ColumnSelector(item.output_schema.column_names)
            elif item.input_schema:
                input_names = ColumnSelector(item.input_schema.column_names)
                addition = item.op.output_column_names(input_names)
            else:
                addition = ColumnSelector()
        elif isinstance(item, ColumnSelector):
            addition = item
        elif isinstance(item, str):
            addition = ColumnSelector(item)
        elif isinstance(item, list):
            addition = ColumnSelector(subgroups=_combine_selectors(item))
        else:
            # Unsupported element types contribute nothing
            continue

        result += addition
    return result
예제 #28
0
def test_workflow_node_converts_lists_to_selectors():
    """Lists given at construction or assigned to ``selector`` become ``ColumnSelector``s."""
    node = WorkflowNode([])
    assert node.selector == ColumnSelector([])

    # Assignment through the attribute converts as well
    names = ["a", "b", "c"]
    node.selector = names
    assert node.selector == ColumnSelector(["a", "b", "c"])
def make_feature_column_workflow(feature_columns,
                                 label_name,
                                 category_dir=None):
    """
    Maps a list of TensorFlow `feature_column`s to an NVTabular `Workflow` which
    imitates their preprocessing functionality. Returns both the finalized
    `Workflow` as well as a list of `feature_column`s that can be used to
    instantiate a `layers.ScalarDenseFeatures` layer to map from `Workflow`
    outputs to dense network inputs. Useful for replacing feature column
    online preprocessing with NVTabular GPU-accelerated online preprocessing
    for faster training.

    Parameters
    ----------
    feature_columns: list(tf.feature_column)
        List of TensorFlow feature columns to emulate preprocessing functions
        of. Doesn't support sequence columns.
    label_name: str
        Name of label column in dataset
    category_dir: str or None
        Directory in which to save categories from vocabulary list and
        vocabulary file columns. If left as None, will create directory
        `/tmp/categories` and save there.
        NOTE(review): this argument is not read anywhere in this function --
        confirm whether it is still needed.

    Returns
    -------
    workflow: nvtabular.Workflow
        An NVTabular `Workflow` which performs the preprocessing steps
        defined in `feature_columns`
    new_feature_columns: list(feature_columns)
        The numeric columns followed by the TensorFlow feature columns that
        correspond to the categorical outputs of `workflow`.

    Raises
    ------
    ValueError
        If a categorical column of an unsupported type is encountered
    RuntimeError
        If a crossed column refers to a bucketized column that was not seen
    """

    # TODO: should we support a dict input for feature columns
    # for multi-tower support?

    def _get_parents(column):
        """
        quick utility function for getting all the input tensors
        that will feed into a column
        """
        # column has no parents, so we've reached a terminal node
        if isinstance(column, str) or isinstance(column.parents[0], str):
            return [column]

        # else climb family tree
        parents = []
        for parent in column.parents:
            parents.extend(
                [i for i in _get_parents(parent) if i not in parents])
        return parents

    # could be more efficient with sets but this is deterministic which
    # might be helpful? Still not sure about this so being safe
    base_columns = []
    for column in feature_columns:
        parents = _get_parents(column)
        base_columns.extend(
            [col for col in parents if col not in base_columns])

    # NOTE(review): cat_names / cont_names are collected but not used further
    # down in this function
    cat_names, cont_names = [], []
    for column in base_columns:
        # append, not extend: the names are strings, and extending a list with
        # a string would add its individual characters
        if isinstance(column, str):
            # cross column input
            # TODO: this means we only accept categorical inputs to
            # cross? How do we generalize this? Probably speaks to
            # the inefficiencies of feature columns as a schema
            # representation
            cat_names.append(column)
        elif isinstance(column, fc.CategoricalColumn):
            cat_names.append(column.key)
        else:
            cont_names.append(column.key)

    _CATEGORIFY_COLUMNS = (fc.VocabularyListCategoricalColumn,
                           fc.VocabularyFileCategoricalColumn)
    categorifies, hashes, crosses, buckets, replaced_buckets = {}, {}, {}, {}, {}

    numeric_columns = []
    new_feature_columns = []
    for column in feature_columns:
        # TODO: check for shared embedding or weighted embedding columns?
        # Do they just inherit from EmbeddingColumn?
        if not isinstance(column, (fc.EmbeddingColumn, fc.IndicatorColumn)):
            if isinstance(column, (fc.BucketizedColumn)):
                # bucketized column being fed directly to model means it's
                # implicitly wrapped into an indicator column
                cat_column = column
                embedding_dim = None
            else:
                # can this be anything else? I don't think so
                assert isinstance(column, fc.NumericColumn)

                # check to see if we've seen a bucketized column
                # that gets fed by this feature. If we have, note
                # that it shouldn't be replaced
                if column.key in replaced_buckets:
                    buckets[column.key] = replaced_buckets.pop(column.key)

                numeric_columns.append(column)
                continue
        else:
            cat_column = column.categorical_column

            # use this to keep track of what should be embedding
            # and what should be indicator, makes the bucketized
            # checking easier
            if isinstance(column, fc.EmbeddingColumn):
                embedding_dim = column.dimension
            else:
                embedding_dim = None

        if isinstance(cat_column, fc.BucketizedColumn):
            key = cat_column.source_column.key

            # check if the source numeric column is being fed
            # directly to the model. Keep track of both the
            # boundaries and embedding dim so that we can wrap
            # with either indicator or embedding later.
            # The boundaries live on the bucketized column itself; `column`
            # may be an embedding/indicator wrapper without that attribute
            if key in [col.key for col in numeric_columns]:
                buckets[key] = (cat_column.boundaries, embedding_dim)
            else:
                replaced_buckets[key] = (cat_column.boundaries, embedding_dim)

            # put off dealing with these until the end so that
            # we know whether we need to replace numeric
            # columns or create a separate feature column
            # for them
            continue

        elif isinstance(cat_column, _CATEGORIFY_COLUMNS):
            if cat_column.num_oov_buckets > 1:
                warnings.warn(
                    "More than 1 oov bucket not supported for Categorify")

            if isinstance(cat_column, _CATEGORIFY_COLUMNS[1]):
                # TODO: how do we handle the case where it's too big to load?
                # The TensorFlow column stores the path as `vocabulary_file`
                with open(cat_column.vocabulary_file, "r") as f:
                    vocab = f.read().split("\n")
            else:
                vocab = cat_column.vocabulary_list
            categorifies[cat_column.key] = list(vocab)
            key = cat_column.key

        elif isinstance(cat_column, fc.HashedCategoricalColumn):
            hashes[cat_column.key] = cat_column.hash_bucket_size
            key = cat_column.key

        elif isinstance(cat_column, fc.CrossedColumn):
            keys = []
            for key in cat_column.keys:
                if isinstance(key, fc.BucketizedColumn):
                    keys.append(key.source_column.key + "_bucketized")
                elif isinstance(key, str):
                    keys.append(key)
                else:
                    keys.append(key.key)
            crosses[tuple(keys)] = (cat_column.hash_bucket_size, embedding_dim)

            # put off making the new columns here too so that we
            # make sure we have the key right after we check
            # for buckets later
            continue

        elif isinstance(cat_column, fc.IdentityCategoricalColumn):
            new_feature_columns.append(column)
            continue

        else:
            raise ValueError("Unknown column {}".format(cat_column))

        new_feature_columns.append(
            _make_categorical_embedding(key, cat_column.num_buckets,
                                        embedding_dim))

    # The workflow graph starts from the label column and grows one branch
    # per kind of preprocessing collected above
    features = ColumnSelector(label_name)

    if len(buckets) > 0:
        # Source numeric column is also fed to the model, so the bucketized
        # version is emitted under a new "_bucketized" name
        new_buckets = {}
        for key, (boundaries, embedding_dim) in buckets.items():
            new_feature_columns.append(
                _make_categorical_embedding(key + "_bucketized",
                                            len(boundaries) + 1,
                                            embedding_dim))
            new_buckets[key] = boundaries

        features_buckets = (new_buckets.keys() >> Bucketize(new_buckets) >>
                            Rename(postfix="_bucketized"))
        features += features_buckets

    if len(replaced_buckets) > 0:
        # Source numeric column is not used on its own, so the bucketized
        # values keep the original column name
        new_replaced_buckets = {}
        for key, (boundaries, embedding_dim) in replaced_buckets.items():
            new_feature_columns.append(
                _make_categorical_embedding(key,
                                            len(boundaries) + 1,
                                            embedding_dim))
            new_replaced_buckets[key] = boundaries
        features_replaced_buckets = new_replaced_buckets.keys() >> Bucketize(
            new_replaced_buckets)
        features += features_replaced_buckets

    if len(categorifies) > 0:
        vocabs = {
            column: pd.Series(vocab)
            for column, vocab in categorifies.items()
        }
        features += ColumnSelector(list(
            categorifies.keys())) >> Categorify(vocabs=vocabs)

    if len(hashes) > 0:
        features += ColumnSelector(list(hashes.keys())) >> HashBucket(hashes)

    if len(crosses) > 0:
        # need to check if any bucketized columns are coming from
        # the bucketized version or the raw version
        new_crosses = {}
        for keys, (hash_bucket_size, embedding_dim) in crosses.items():
            # if we're bucketizing the input we have to do more work here -
            if any(key.endswith("_bucketized") for key in keys):
                cross_columns = []
                for key in keys:
                    if key.endswith("_bucketized"):
                        bucketized_cols = []
                        bucketized_cols.append(key)
                        key = key.replace("_bucketized", "")
                        if key in buckets:
                            # find if there are different columns
                            diff_col = list(
                                set(features_buckets.columns)
                                ^ set(bucketized_cols))
                            if diff_col:
                                features_buckets.columns.remove(diff_col[0])
                            cross_columns.append(features_buckets)
                        elif key in replaced_buckets:
                            diff_col = list(
                                set(features_replaced_buckets.columns)
                                ^ set(bucketized_cols))
                            if diff_col:
                                features_replaced_buckets.columns.remove(
                                    diff_col[0])
                            cross_columns.append(features_replaced_buckets)
                        else:
                            raise RuntimeError(f"Unknown bucket column {key}")
                    else:
                        cross_columns.append(nvt.WorkflowNode(key))

                features += sum(
                    cross_columns[1:],
                    cross_columns[0]) >> HashedCross(hash_bucket_size)

            else:
                new_crosses[tuple(keys)] = hash_bucket_size
            key = "_X_".join(keys)
            new_feature_columns.append(
                _make_categorical_embedding(key, hash_bucket_size,
                                            embedding_dim))

        if new_crosses:
            features += new_crosses.keys() >> HashedCross(new_crosses)

    if numeric_columns:
        features += [col.key for col in numeric_columns]

    workflow = nvt.Workflow(features)

    return workflow, numeric_columns + new_feature_columns
def test_grouped_names_returns_nested_list():
    """``grouped_names`` lists the plain names followed by each subgroup as a tuple."""
    subgroup = ColumnSelector(["d", "e", "f"])
    selector = ColumnSelector(["a", "b", "c"], [subgroup])

    assert selector.grouped_names == ["a", "b", "c", ("d", "e", "f")]