Example #1
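# Imports assumed by this excerpt (TF 1.x contrib layout; the exact module
# paths may vary by version). `_maybe_restore_from_checkpoint` is a private
# helper defined elsewhere in the same module and is not reproduced here.
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
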
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Returns embeddings for a column based on the computed arguments.

  Args:
   column: the feature column (a `_FeatureColumn`, not just its name); its
     `name` and checkpoint path are used below.
   args: the _DeepEmbeddingLookupArguments for this column.
   weight_collections: collections to store weights in.
   trainable: whether these embeddings should be trainable.
   output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
     be combined to produce the desired rank.

  Returns:
   the embeddings.

  Raises:
   ValueError: if the embeddings cannot be created, e.g. because a shared
     embedding variable with the same name already exists with a different
     shape.
  """
  # pylint: disable=protected-access
  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
  weight_tensor = None
  if args.weight_tensor is not None:
    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
  # pylint: enable=protected-access

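  # Hashed path: weights are stored as a single flat vector and each
  # (id, dimension) pair is hashed into it, so no full
  # [vocab_size, dimension] matrix is ever materialized.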
  if args.hashed:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

    return embedding_ops.hashed_embedding_lookup_sparse(
        embeddings, input_tensor, args.dimension,
        combiner=args.combiner, name='lookup')

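  # Shared path: reuse (or create and register) one embedding variable in a
  # per-name graph collection so that multiple columns can share it.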
  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = (
        'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
    graph = ops.get_default_graph()
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      if len(shared_embedding_collection) > 1:
        raise ValueError('Collection %s can only contain one '
                         '(partitioned) variable.'
                         % shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if embeddings.get_shape() != shape:
          raise ValueError('The embedding variable with name {} already '
                           'exists, but its shape does not match the required '
                           'embedding shape here. Please make sure to use a '
                           'different shared_embedding_name for different '
                           'shared embeddings.'.format(
                               args.shared_embedding_name))
    else:
      embeddings = contrib_variables.model_variable(
          name=args.shared_embedding_name,
          shape=shape,
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=trainable,
          collections=weight_collections)
      graph.add_to_collection(shared_embedding_collection_name, embeddings)
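  # Default path: a dedicated [vocab_size, dimension] embedding matrix for
  # this column.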
  else:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size, args.dimension],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

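  # `model_variable` may return a single `Variable` or a `PartitionedVariable`;
  # normalize to a list of variables for the lookup ops below.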
  if isinstance(embeddings, variables.Variable):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
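  # Warm-start the embedding weights from the column's checkpoint, if one is
  # configured.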
  # pylint: disable=protected-access
  _maybe_restore_from_checkpoint(
      column._checkpoint_path(), embeddings)
  return embedding_ops.safe_embedding_lookup_sparse(
      embeddings,
      input_tensor,
      sparse_weights=weight_tensor,
      combiner=args.combiner,
      name=column.name + '_weights',
      max_norm=args.max_norm)
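

# --- Usage sketch (not part of the original module) ---
# The field names on `args` below are inferred from how the function reads
# them above; the `_DeepEmbeddingLookupArguments` constructor, `sparse_ids`,
# and `my_column` are assumptions for illustration only.
#
#   args = _DeepEmbeddingLookupArguments(
#       input_tensor=sparse_ids,   # `SparseTensor` of int64 ids
#       weight_tensor=None,        # optional `SparseTensor` of float weights
#       vocab_size=10000,
#       dimension=32,
#       combiner='mean',
#       initializer=tf.truncated_normal_initializer(stddev=0.1),
#       hashed=False,
#       shared_embedding_name=None,
#       max_norm=None)
#   embedded = _embeddings_from_arguments(
#       my_column, args, weight_collections=None, trainable=True)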