    def test_scattered_embedding_lookup_sparse(self):
        with self.test_session():
            embedding_weights = self._random_weights(num_shards=3)
            sparse_tensor = sparse_tensor_lib.SparseTensor(
                values=["foo", "bar", "foo", "bar"],
                indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
                dense_shape=[5, 2])

            embedding_lookup_result = (
                embedding_ops.scattered_embedding_lookup_sparse(
                    embedding_weights,
                    sparse_tensor,
                    dimension=5,
                    combiner="mean").eval())

            self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
            # Empty rows (2 and 4) get the same non-zero default embedding.
            self.assertAllEqual(embedding_lookup_result[2],
                                embedding_lookup_result[4])
            embedding_norm = np.sum(embedding_lookup_result[2]**2)
            self.assertGreater(embedding_norm, 0)

            self.assertAllEqual(
                embedding_lookup_result[1], 0.5 *
                (embedding_lookup_result[0] + embedding_lookup_result[3]))
  def test_scattered_embedding_lookup_sparse(self):
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_tensor = sparse_tensor_lib.SparseTensor(
          values=["foo", "bar", "foo", "bar"],
          indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
          dense_shape=[5, 2])

      embedding_lookup_result = (
          embedding_ops.scattered_embedding_lookup_sparse(
              embedding_weights, sparse_tensor, dimension=5,
              combiner="mean").eval())

      self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
      # Empty rows (2 and 4) get the same non-zero default embedding.
      self.assertAllEqual(embedding_lookup_result[2],
                          embedding_lookup_result[4])
      embedding_norm = np.sum(embedding_lookup_result[2]**2)
      self.assertGreater(embedding_norm, 0)

      self.assertAllEqual(embedding_lookup_result[1], 0.5 * (
          embedding_lookup_result[0] + embedding_lookup_result[3]))
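
Both test variants above drive the op through a helper (self._random_weights) and a test session. As a rough standalone sketch of the same call, assuming TF 1.x with the contrib module path available and an arbitrary parameter-vector size of 100:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import embedding_ops

# A single flat parameter vector: scattered (hash-based) lookup hashes each
# (value, output index) pair into it, so no [vocab_size, dimension] matrix is
# materialized. The size 100 and the variable name are arbitrary for this sketch.
params = tf.get_variable(
    "scattered_params", shape=[100], dtype=tf.float32,
    initializer=tf.truncated_normal_initializer())

sparse_values = tf.SparseTensor(
    values=["foo", "bar", "foo", "bar"],
    indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
    dense_shape=[5, 2])

# Each of the 5 rows gets a 5-dimensional embedding; rows with several values
# are mean-combined, and empty rows fall back to a default value.
result = embedding_ops.scattered_embedding_lookup_sparse(
    params, sparse_values, dimension=5, combiner="mean")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(result).shape)  # (5, 5)
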
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Returns embeddings for a column based on the computed arguments.

  Args:
   column: the feature column; its name and checkpoint path are used.
   args: the _DeepEmbeddingLookupArguments for this column.
   weight_collections: collections to store weights in.
   trainable: whether these embeddings should be trainable.
   output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
     be combined to produce the desired rank.

  Returns:
   the embeddings.

  Raises:
   ValueError: if not possible to create.
  """
  # pylint: disable=protected-access
  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
  weight_tensor = None
  if args.weight_tensor is not None:
    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
  # pylint: enable=protected-access

  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

    return embedding_ops.scattered_embedding_lookup_sparse(
        embeddings, input_tensor, args.dimension,
        hash_key=args.hash_key,
        combiner=args.combiner, name='lookup')

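  # Columns that declare a shared_embedding_name reuse one embedding variable:
  # the first column creates it and registers it in a per-graph collection
  # keyed by the upper-cased name; later columns fetch and shape-check it.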
  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = (
        'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
    graph = ops.get_default_graph()
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      if len(shared_embedding_collection) > 1:
        raise ValueError('Collection %s can only contain one '
                         '(partitioned) variable.'
                         % shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if embeddings.get_shape() != shape:
          raise ValueError('The embedding variable with name {} already '
                           'exists, but its shape does not match the '
                           'required embedding shape here. Please make sure '
                           'to use a different shared_embedding_name for '
                           'each shared embedding.'.format(
                               args.shared_embedding_name))
    else:
      embeddings = contrib_variables.model_variable(
          name=args.shared_embedding_name,
          shape=shape,
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=trainable,
          collections=weight_collections)
      graph.add_to_collection(shared_embedding_collection_name, embeddings)
  else:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size, args.dimension],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=trainable,
        collections=weight_collections)

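  # Normalize to a list of variables: a plain Variable is wrapped in a list,
  # while a partitioned variable is unpacked into its shards.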
  if isinstance(embeddings, variables.Variable):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  _maybe_restore_from_checkpoint(
      column._checkpoint_path(), embeddings)
  return embedding_ops.safe_embedding_lookup_sparse(
      embeddings,
      input_tensor,
      sparse_weights=weight_tensor,
      combiner=args.combiner,
      name=column.name + 'weights',
      max_norm=args.max_norm)
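
The shared-embedding branch above is a get-or-create pattern keyed on a graph collection: the first requester creates the variable and registers it, later requesters only fetch and shape-check it. A stripped-down illustration of just that pattern, using a hypothetical helper name and plain tf.get_variable in place of model_variable:

import tensorflow as tf

# Hypothetical helper (not part of the library), mirroring the branch above.
def _get_or_create_shared_embedding(shared_name, shape):
  """Reuses one embedding variable per shared name via a graph collection."""
  collection_name = 'SHARED_EMBEDDING_COLLECTION_' + shared_name.upper()
  graph = tf.get_default_graph()
  existing = graph.get_collection_ref(collection_name)
  if existing:
    embeddings = existing[0]
    if embeddings.get_shape() != shape:
      raise ValueError('Shared embedding %s already exists with a different '
                       'shape.' % shared_name)
    return embeddings
  embeddings = tf.get_variable(shared_name, shape=shape, dtype=tf.float32)
  graph.add_to_collection(collection_name, embeddings)
  return embeddings

# Two columns asking for the same shared name get the same variable object.
a = _get_or_create_shared_embedding('colors', [100, 8])
b = _get_or_create_shared_embedding('colors', [100, 8])
assert a is b
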
Example #4
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
    """Returns embeddings for a column based on the computed arguments.

  Args:
   column: the column name.
   args: the _DeepEmbeddingLookupArguments for this column.
   weight_collections: collections to store weights in.
   trainable: whether these embeddings should be trainable.
   output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
     be combined to produce the desired rank.

  Returns:
   the embeddings.

  Raises:
   ValueError: if not possible to create.
  """
    # pylint: disable=protected-access
    input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
    weight_tensor = None
    if args.weight_tensor is not None:
        weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
    # pylint: enable=protected-access

    # This option is only enabled for scattered_embedding_column.
    if args.hash_key:
        embeddings = contrib_variables.model_variable(
            name='weights',
            shape=[args.vocab_size],
            dtype=dtypes.float32,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=weight_collections)

        return embedding_ops.scattered_embedding_lookup_sparse(
            embeddings,
            input_tensor,
            args.dimension,
            hash_key=args.hash_key,
            combiner=args.combiner,
            name='lookup')

    if args.shared_embedding_name is not None:
        shared_embedding_collection_name = ('SHARED_EMBEDDING_COLLECTION_' +
                                            args.shared_embedding_name.upper())
        graph = ops.get_default_graph()
        shared_embedding_collection = (
            graph.get_collection_ref(shared_embedding_collection_name))
        shape = [args.vocab_size, args.dimension]
        if shared_embedding_collection:
            if len(shared_embedding_collection) > 1:
                raise ValueError('Collection %s can only contain one '
                                 '(partitioned) variable.' %
                                 shared_embedding_collection_name)
            else:
                embeddings = shared_embedding_collection[0]
                if embeddings.get_shape() != shape:
                    raise ValueError(
                        'The embedding variable with name {} already '
                        'exists, but its shape does not match the required '
                        'embedding shape here. Please make sure to use a '
                        'different shared_embedding_name for each shared '
                        'embedding.'.format(
                            args.shared_embedding_name))
        else:
            embeddings = contrib_variables.model_variable(
                name=args.shared_embedding_name,
                shape=shape,
                dtype=dtypes.float32,
                initializer=args.initializer,
                trainable=(trainable and args.trainable),
                collections=weight_collections)
            graph.add_to_collection(shared_embedding_collection_name,
                                    embeddings)
    else:
        embeddings = contrib_variables.model_variable(
            name='weights',
            shape=[args.vocab_size, args.dimension],
            dtype=dtypes.float32,
            initializer=args.initializer,
            trainable=(trainable and args.trainable),
            collections=weight_collections)

    if isinstance(embeddings, variables.Variable):
        embeddings = [embeddings]
    else:
        embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    _maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)
    return embedding_ops.safe_embedding_lookup_sparse(
        embeddings,
        input_tensor,
        sparse_weights=weight_tensor,
        combiner=args.combiner,
        name=column.name + 'weights',
        max_norm=args.max_norm)
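
Neither version of _embeddings_from_arguments is meant to be called directly; it is reached when contrib feature columns are materialized into a dense input layer. A rough usage sketch under that assumption (TF 1.x with tf.contrib.layers available; the feature name and sizes below are made up):

import tensorflow as tf
from tensorflow.contrib import layers

# 'color', the bucket size, and the dimension are made-up values for this sketch.
# Building the input layer from an embedding column is what eventually runs the
# embedding lookup code shown above.
color = layers.sparse_column_with_hash_bucket('color', hash_bucket_size=100)
color_emb = layers.embedding_column(color, dimension=8, combiner='mean')

features = {
    'color': tf.SparseTensor(
        values=['red', 'blue', 'red'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2]),
}
# Dense batch of shape [2, 8]: one 8-dimensional embedding per example.
net = layers.input_from_feature_columns(features, [color_emb])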