Пример #1
0
def input_from_feature_columns(columns_to_tensors,
                               feature_columns,
                               weight_collections,
                               scope,
                               trainable=True,
                               output_rank=2,
                               default_name='input_from_feature_columns'):
    """Create one DNN input tensor per feature column.

    Every column is first run through a feature-column ``_Transformer``. The
    transformed value is then turned into a tensor via the embedding path
    (``column._deep_embedding_lookup_arguments`` followed by
    ``fc._embeddings_from_arguments``). If anything in that path raises
    ``NotImplementedError``, ``column._to_dnn_input_layer`` is used instead.

    Args:
        columns_to_tensors: Mapping handed to the ``_Transformer``. A shallow
            copy is used, so the caller's object is not the one passed on.
        feature_columns: Iterable of feature columns. Duplicates are dropped
            and the remaining columns are processed in ascending ``key`` order.
        weight_collections: Collection names forwarded to the per-column
            builders. If truthy, ``ops.GraphKeys.GLOBAL_VARIABLES`` is added
            to it first and the result is deduplicated.
        scope: Scope argument of the outer ``variable_scope``.
        trainable: Forwarded to the per-column builders.
        output_rank: Forwarded to the per-column builders.
        default_name: ``default_name`` of the outer ``variable_scope``.

    Returns:
        A dict mapping each column's ``key`` to the tensor built for it.
        NOTE(review): two distinct columns sharing the same ``key`` would
        overwrite each other in this dict.

    Raises:
        ValueError: When the embedding path raises ``NotImplementedError``
            and ``_to_dnn_input_layer`` then raises ``ValueError``. The
            message includes both underlying errors.
    """

    def _tensor_for(column, transformed, var_collections):
        # Try the embedding route first. A NotImplementedError anywhere in it
        # sends the column to its direct DNN-input conversion instead.
        # pylint: disable=protected-access
        try:
            lookup_args = column._deep_embedding_lookup_arguments(transformed)
            return fc._embeddings_from_arguments(
                column,
                lookup_args,
                var_collections,
                trainable,
                output_rank=output_rank)
        except NotImplementedError as embedding_err:
            try:
                return column._to_dnn_input_layer(
                    transformed,
                    var_collections,
                    trainable,
                    output_rank=output_rank)
            except ValueError as layer_err:
                raise ValueError('Error creating input layer for column: {}.\n'
                                 '{}, {}'.format(column.name, layer_err,
                                                 embedding_err))

    # Work on a private copy; only the copy is handed to the transformer.
    tensors = columns_to_tensors.copy()
    feature_column_ops.check_feature_columns(feature_columns)
    with variable_scope.variable_scope(scope,
                                       default_name=default_name,
                                       values=tensors.values()):
        outputs = {}
        column_transformer = feature_column_ops._Transformer(tensors)  # pylint: disable=protected-access
        if weight_collections:
            # Always register created variables as global variables too,
            # without listing that collection twice.
            names = list(weight_collections)
            names.append(ops.GraphKeys.GLOBAL_VARIABLES)
            weight_collections = list(set(names))
        # Deduplicate columns and visit them in a stable, key-sorted order.
        for column in sorted(set(feature_columns), key=lambda col: col.key):
            with variable_scope.variable_scope(None,
                                               default_name=column.name,
                                               values=tensors.values()):
                transformed = column_transformer.transform(column)
                outputs[column.key] = _tensor_for(
                    column, transformed, weight_collections)
        return outputs