def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Builds the embedding lookup for one column from its lookup arguments.

  Three cases are handled, selected by `args`:

  * `args.hashed` is truthy: a 1-D `weights` variable of shape
    `[args.vocab_size]` is created and looked up with
    `hashed_embedding_lookup_sparse`. This path returns immediately; it does
    not use the weight tensor, `args.max_norm`, or the checkpoint restore.
  * `args.shared_embedding_name` is set: the variable is shared through the
    graph collection `SHARED_EMBEDDING_COLLECTION_<NAME>`; it is created and
    registered on first use and reused afterwards.
  * otherwise: a fresh `[args.vocab_size, args.dimension]` `weights` variable
    is created.

  Args:
    column: the feature column object. Its `_checkpoint_path()` and `name`
      are read on the non-hashed path.
    args: the _DeepEmbeddingLookupArguments for this column.
    weight_collections: collections passed to `model_variable` for any
      variable created here.
    trainable: whether variables created here should be trainable.
    output_rank: rank that `args.input_tensor` (and `args.weight_tensor`, if
      present) are flattened to before the lookup.

  Returns:
    The result of the embedding lookup op.

  Raises:
    ValueError: if the shared embedding collection holds more than one entry,
      or if the already-registered shared variable's shape differs from
      `[args.vocab_size, args.dimension]`.
  """

  def _new_model_variable(name, shape):
    """Creates a float32 model variable using this call's settings."""
    return contrib_variables.model_variable(
        name=name,
        shape=shape,
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

  # Flatten ids first, then weights, before any variable is created.
  # pylint: disable=protected-access
  flat_ids = layers._inner_flatten(args.input_tensor, output_rank)
  flat_weights = (
      None if args.weight_tensor is None else
      layers._inner_flatten(args.weight_tensor, output_rank))
  # pylint: enable=protected-access

  if args.hashed:
    hashed_table = _new_model_variable('weights', [args.vocab_size])
    return embedding_ops.hashed_embedding_lookup_sparse(
        hashed_table,
        flat_ids,
        args.dimension,
        combiner=args.combiner,
        name='lookup')

  shared_name = args.shared_embedding_name
  if shared_name is None:
    table = _new_model_variable('weights', [args.vocab_size, args.dimension])
  else:
    collection_key = 'SHARED_EMBEDDING_COLLECTION_' + shared_name.upper()
    graph = ops.get_default_graph()
    registered = graph.get_collection_ref(collection_key)
    table_shape = [args.vocab_size, args.dimension]
    if not registered:
      # First column to use this shared name: create and register the variable.
      table = _new_model_variable(shared_name, table_shape)
      graph.add_to_collection(collection_key, table)
    elif len(registered) > 1:
      raise ValueError('Collection %s can only contain one '
                       '(partitioned) variable.' % collection_key)
    else:
      table = registered[0]
      if table.get_shape() != table_shape:
        raise ValueError('The embedding variable with name {} already '
                         'exists, but its shape does not match required '
                         'embedding shape here. Please make sure to use '
                         'different shared_embedding_name for different '
                         'shared embeddings.'.format(shared_name))

  # The restore and lookup below take a list of variables: wrap a plain
  # Variable, otherwise ask the object for its variable list.
  if isinstance(table, variables.Variable):
    shards = [table]
  else:
    shards = table._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  _maybe_restore_from_checkpoint(column._checkpoint_path(), shards)
  return embedding_ops.safe_embedding_lookup_sparse(
      shards,
      flat_ids,
      sparse_weights=flat_weights,
      combiner=args.combiner,
      name=column.name + 'weights',
      max_norm=args.max_norm)