# Module-level imports used by the helpers below, as in the surrounding
# tensorflow.contrib.layers.python.layers.feature_column_ops module.
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables


def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
  """Reshape the input tensor by the following rule.

  1. If `output_rank > input_rank + 1`, raise a `ValueError`.
  2. If `output_rank == input_rank + 1`, expand the tensor by one dimension.
  3. If `output_rank == input_rank`, do nothing.
  4. If `output_rank < input_rank`, flatten the inner dimensions of the tensor.

  Args:
    tensor: A `Tensor` or `SparseTensor` to be reshaped.
    column_name: A string name of the feature column for the tensor.
    output_rank: the desired rank of the tensor.

  Returns:
    A reshaped `Tensor` or `SparseTensor`.

  Raises:
    ValueError: if `output_rank > input_rank + 1` for the input tensor.
  """
  input_rank = tensor.get_shape().ndims

  if input_rank is None and isinstance(tensor, sparse_tensor_py.SparseTensor):
    # Try to recover the rank of a sparse tensor from the static shape of its
    # dense_shape, which is a 1-D tensor of length `rank`.
    input_rank = tensor.dense_shape.get_shape().as_list()[0]

  if input_rank is None:
    raise ValueError('Error while processing column {}. Rank of input Tensor '
                     'cannot be None.'.format(column_name))

  if output_rank > input_rank + 1:
    raise ValueError('Error while processing column {}. Rank of input Tensor '
                     '({}) must be at least output_rank - 1, where '
                     'output_rank is {}. For example, sequence data should '
                     'typically be 3 dimensional (rank 3) while non-sequence '
                     'data is typically 2 dimensional (rank 2).'.format(
                         column_name, input_rank, output_rank))
  elif output_rank == input_rank + 1:
    # Expand the tensor's shape by 1 dimension.
    if isinstance(tensor, sparse_tensor_py.SparseTensor):
      output_shape = array_ops.concat([tensor.dense_shape, [1]], 0)
      return sparse_ops.sparse_reshape(tensor, output_shape)
    else:
      reshaped = array_ops.expand_dims(tensor, -1)
      # Try to propagate the static shape of the input to the output.
      static_shape = tensor.get_shape()
      if static_shape is not None and static_shape.dims is not None:
        reshaped.set_shape(static_shape.as_list() + [1])
      return reshaped
  elif output_rank < input_rank:
    return layers._inner_flatten(tensor, output_rank)  # pylint: disable=protected-access
  else:
    return tensor
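# A minimal sketch of how `_maybe_reshape_input_tensor` behaves for each of
# the four docstring cases; illustrative only, not part of the module:
#
#   dense = array_ops.ones([4, 3])                # rank 2
#   _maybe_reshape_input_tensor(dense, 'col', 3)  # -> dense Tensor [4, 3, 1]
#   _maybe_reshape_input_tensor(dense, 'col', 2)  # -> returned unchanged
#   seq = array_ops.ones([4, 3, 5])               # rank 3
#   _maybe_reshape_input_tensor(seq, 'col', 2)    # -> inner-flattened [4, 15]
#   _maybe_reshape_input_tensor(dense, 'col', 4)  # -> raises ValueError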
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Returns embeddings for a column based on the computed arguments.

  Args:
    column: the feature column the embeddings are created for (used for its
      name and checkpoint path).
    args: the _DeepEmbeddingLookupArguments for this column.
    weight_collections: collections to store weights in.
    trainable: whether these embeddings should be trainable.
    output_rank: the desired rank of the returned `Tensor`. Inner dimensions
      will be combined to produce the desired rank.

  Returns:
    the embeddings.

  Raises:
    ValueError: if not possible to create.
  """
  # pylint: disable=protected-access
  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
  weight_tensor = None
  if args.weight_tensor is not None:
    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
  # pylint: enable=protected-access

  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

    return embedding_ops.scattered_embedding_lookup_sparse(
        embeddings,
        input_tensor,
        args.dimension,
        hash_key=args.hash_key,
        combiner=args.combiner,
        name='lookup')

  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = (
        'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
    graph = ops.get_default_graph()
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      if len(shared_embedding_collection) > 1:
        raise ValueError('Collection %s can only contain one '
                         '(partitioned) variable.' %
                         shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if embeddings.get_shape() != shape:
          raise ValueError('The embedding variable with name {} already '
                           'exists, but its shape does not match required '
                           'embedding shape here. Please make sure to use '
                           'different shared_embedding_name for different '
                           'shared embeddings.'.format(
                               args.shared_embedding_name))
    else:
      embeddings = contrib_variables.model_variable(
          name=args.shared_embedding_name,
          shape=shape,
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=trainable,
          collections=weight_collections)
      graph.add_to_collection(shared_embedding_collection_name, embeddings)
  else:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size, args.dimension],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

  if isinstance(embeddings, variables.Variable):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  _maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)
  return embedding_ops.safe_embedding_lookup_sparse(
      embeddings,
      input_tensor,
      sparse_weights=weight_tensor,
      combiner=args.combiner,
      name=column.name + 'weights',
      max_norm=args.max_norm)
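# A hedged usage sketch for `_embeddings_from_arguments`. This helper is
# internal: callers do not invoke it directly, but reach it through
# `input_from_feature_columns`, which builds the lookup arguments from a
# column definition along the lines of:
#
#   country = tf.contrib.layers.sparse_column_with_hash_bucket(
#       'country', hash_bucket_size=100)
#   emb_col = tf.contrib.layers.embedding_column(country, dimension=8)
#   # Internally, emb_col produces the _DeepEmbeddingLookupArguments `args`
#   # and this helper returns a [batch_size, 8] embeddings Tensor (with
#   # output_rank=2), sharing or creating the 'weights' variable as needed.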