Пример #1
0
    def __init__(self,
                 config: de_config_pb2.DynamicEmbeddingConfig,
                 var_name: typing.Text,
                 service_address: typing.Text = "",
                 timeout_ms: int = -1):
        """Sets up the lookup object and its embedding manager resource.

        Args:
          config: A DynamicEmbeddingConfig proto that configures the
            embedding.
          var_name: A unique name for the given embedding.
          service_address: The address of a knowledge bank service. If empty,
            the value passed from the --kbs_address flag (defined in
            .../carls/dynamic_embedding_manager.cc) will be used instead.
          timeout_ms: Timeout in milliseconds for the connection. If negative,
            never time out.

        Raises:
          ValueError: if var_name is `None` or empty.
        """
        super(DynamicEmbeddingLookup, self).__init__()

        # Reject a missing name before anything is registered or created.
        if not var_name:
            raise ValueError("Must specify a non-empty var_name.")

        self.embedding_dimension = config.embedding_dimension

        # Record the (name, config) pair in the shared collection first, then
        # hand the serialized proto to the op that builds the resource.
        context.add_to_collection(var_name, config)
        serialized_config = config.SerializeToString()
        self.resource = gen_carls_ops.dynamic_embedding_manager_resource(
            serialized_config, var_name, service_address, timeout_ms)
 def testClearAllCollection(self):
     """Clearing the collection removes every previously added entry."""
     config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=5)
     for name in ('first', 'second'):
         context.add_to_collection(name, config)

     # Both names must be visible before the clear...
     before_clear = context.get_all_collection()
     self.assertLen(before_clear, 2)

     # ...and nothing may remain afterwards.
     context.clear_all_collection()
     after_clear = context.get_all_collection()
     self.assertLen(after_clear, 0)
 def test_get_all_collection(self):
     """get_all_collection yields one (name, config) pair per added name."""
     config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=5)
     expected_names = {'first', 'second'}
     context.add_to_collection('first', config)
     context.add_to_collection('second', config)

     entries = context.get_all_collection()
     self.assertLen(entries, 2)

     # Every returned pair must carry a known name and the original proto.
     for name, stored_config in entries:
         self.assertIn(name, expected_names)
         self.assertProtoEquals(stored_config, config)
def dynamic_gaussian_memory_lookup(
        inputs: tf.Tensor,
        mode: typing.Union[int, tf.Tensor],
        config: de_config_pb2.DynamicEmbeddingConfig,
        var_name: typing.Text,
        service_address: typing.Text = "",
        timeout_ms: int = -1):
    """Applies dynamic Gaussian memory to given inputs.

    A Gaussian memory assumes the input pattern can be represented by a number
    of Gaussian clusters. This function returns the closest Gaussian mean,
    variance and the distance between each data point and the closest Gaussian
    center.

    This function can be used in conjunction with a DynamicNormalization layer
    in a DNN. The distance between the input and the Gaussian cluster can be
    used for model uncertainty inference.

    Note that the memory data is only based on the last dimension of the input.
    Hence if the input shape is [d1, d2, ..., dn], it is assumed to contain
    d1*d2*...*dn-1 data points.

    Args:
      inputs: A float `Tensor` of shape [d1, d2, ..., dn] with n > 0.
      mode: An int or a `Tensor` whose value must be one of
        {LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE, LOOKUP_WITH_GROW}. Only a
        Python int is validated here; a `Tensor` is passed through as is.
      config: An instance of DynamicEmbeddingConfig.
      var_name: A unique name for the given op.
      service_address: The address of a knowledge bank service. If empty, the
        value passed from --kbs_address flag will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never time out.

    Returns:
      - A `Tensor` with the same shape of input representing the mean values.
      - A `Tensor` with the same shape of input representing the variance
        values.
      - A `Tensor` with the shape [d1, d2, ..., dn-1] representing the
        distance to the cluster center.
      - An int `Tensor` with the shape [d1, d2, ..., dn-1] representing the
        cluster ids.

    Raises:
      ValueError: If var_name is not specified, or if mode is a Python int
        that is not one of the supported lookup modes.
    """
    # Validate the cheap, pure-Python argument first so that no TensorFlow op
    # is created for a call that is going to be rejected anyway.
    if not var_name:
        raise ValueError("Must specify a valid var_name.")
    if isinstance(mode, int) and mode not in {
            LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE, LOOKUP_WITH_GROW
    }:
        raise ValueError("Invalid mode: %r" % mode)

    # Both a valid Python int and a Tensor are normalized to an int32 tensor
    # before being handed to the op.
    mode = tf.cast(mode, tf.int32)

    context.add_to_collection(var_name, config)
    resource = gen_carls_ops.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)

    return gen_carls_ops.dynamic_gaussian_memory_lookup(inputs, mode, resource)
Пример #5
0
def dynamic_embedding_lookup(keys: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             skip_gradient_update: bool = False,
                             timeout_ms: int = -1) -> tf.Tensor:
    """Looks up the embeddings for the given keys.

    Args:
      keys: A string `Tensor` of shape [batch_size] or [batch_size,
        max_sequence_length] where an empty string would be mapped to an all
        zero embedding.
      config: A DynamicEmbeddingConfig proto that configures the embedding.
      var_name: A unique name for the given embedding.
      service_address: The address of a knowledge bank service. If empty, the
        value passed from --kbs_address flag will be used instead.
      skip_gradient_update: If True, a constant is fed to the op in place of a
        variable, so no variable is created for gradient updates.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never time out.

    Returns:
      A `Tensor` whose shape is one of:
      - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
      - [batch_size, max_sequence_length, config.embedding_dimension] if the
        input is 2D.

    Raises:
      ValueError: If var_name is not specified.
    """
    if not var_name:
        raise ValueError("Must specify a valid var_name.")

    # The op takes a scalar float input alongside the keys. A dummy variable
    # lets gradients be passed in; when gradient updates are skipped a plain
    # constant is used instead.
    grad_placeholder = (
        tf.constant(0.0) if skip_gradient_update else tf.Variable(0.0))

    # Register the config under its name, then build the manager resource
    # from the serialized proto.
    context.add_to_collection(var_name, config)
    serialized_config = config.SerializeToString()
    manager = gen_carls_ops.dynamic_embedding_manager_resource(
        serialized_config, var_name, service_address, timeout_ms)

    embedding_dim = config.embedding_dimension
    return gen_carls_ops.dynamic_embedding_lookup(keys, grad_placeholder,
                                                  manager, embedding_dim)
Пример #6
0
def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
    """Returns the k closest embedding keys to the inputs with their logits.

    Note: The (keys, logits) pair returned here should not be used for training
    as they only represent biased sampling. Instead, use sampled_softmax_loss()
    for training.

    Args:
      inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
        forward activations of the input network.
      k: An `int` denoting the number of returned keys.
      de_config: A DynamicEmbeddingConfig for configuring the dynamic
        embedding.
      var_name: A unique name for the operation.
      service_address: The address of a dynamic embedding service. If empty,
        the value passed from --kbs_address flag will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never time out.

    Returns:
      keys: A string `Tensor` of shape `[batch_size, k]` representing the top
        k keys relative to the input.
      logits: A float `Tensor` of shape `[batch_size, k]` representing the
        logits for the returned keys.

    Raises:
      ValueError: if var_name is not specified, or if k is not greater than
        zero.
    """
    # Argument checks: the name is checked first, then the requested count.
    if not var_name:
        raise ValueError("Must specify a valid var_name.")
    if k <= 0:
        raise ValueError("k must be greater than zero, got %d" % k)

    # Register the config under its name, then build the manager resource
    # from the serialized proto.
    context.add_to_collection(var_name, de_config)
    serialized_config = de_config.SerializeToString()
    manager = gen_carls_ops.dynamic_embedding_manager_resource(
        serialized_config, var_name, service_address, timeout_ms)

    return gen_carls_ops.topk_lookup(inputs, k, manager)
Пример #7
0
def dynamic_embedding_update(keys: tf.Tensor,
                             values: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             timeout_ms: int = -1):
    """Writes the given values as the embeddings of the given keys.

    Args:
      keys: A string `Tensor` of shape [batch] or [batch_size,
        max_sequence_length].
      values: A `Tensor` of shape [batch_size, embedding_dimension] or
        [batch_size, max_sequence_length, embedding_dimension].
      config: A DynamicEmbeddingConfig proto that configures the embedding.
      var_name: A unique name for the given embedding.
      service_address: The address of a dynamic embedding service. If empty,
        the value passed from --kbs_address flag will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never time out.

    Returns:
      The output of the underlying `dynamic_embedding_update` op. It is
      presumably a `Tensor` of shape
      - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
      - [batch_size, max_sequence_length, config.embedding_dimension] if the
        input is 2D
      (TODO: confirm against the op definition).

    Raises:
      TypeError: If var_name is not specified.
    """
    # NOTE: unlike the sibling lookup functions, a missing name is reported
    # as TypeError here; callers may rely on that, so it is kept.
    if not var_name:
        raise TypeError("Must specify a valid var_name.")

    # Register the config under its name, then build the manager resource
    # from the serialized proto.
    context.add_to_collection(var_name, config)
    serialized_config = config.SerializeToString()
    manager = gen_carls_ops.dynamic_embedding_manager_resource(
        serialized_config, var_name, service_address, timeout_ms)

    embedding_dim = config.embedding_dimension
    return gen_carls_ops.dynamic_embedding_update(keys, values, manager,
                                                  embedding_dim)
Пример #8
0
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
    """Samples keys around the positive labels and computes their logits.

    Args:
      positive_keys: A string `Tensor` of shape `[batch_size, None]`
        representing input positive keys.
      inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
        forward activations of the input network.
      num_samples: An int denoting the returned positive and negative samples.
      de_config: A DynamicEmbeddingConfig for configuring the dynamic
        embedding.
      var_name: A unique name for the operation.
      service_address: The address of a dynamic embedding service. If empty,
        the value passed from --kbs_address flag will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never time out.

    Returns:
      logits: A float `Tensor` of shape `[batch_size, num_samples]`
        representing the logits for sampled labels.
      labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
        in {0, 1} indicating if the sample is positive or negative.
      keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
        the keys for each sample.
      mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
        of each batch. For example, if all keys in positive_keys[i] are empty,
        mask[i] = 0; otherwise mask[i] = 1.
      weights: A float `Tensor` representing the embeddings of the sampled
        keys.

    Raises:
      ValueError: If var_name is not specified or num_samples is less than 1.
    """
    if not var_name:
        raise ValueError("Must specify a valid name, got %s" % var_name)
    if num_samples < 1:
        raise ValueError("Invalid num_samples: %d" % num_samples)

    # Register the config under its name, then build the manager resource
    # from the serialized proto.
    context.add_to_collection(var_name, de_config)
    serialized_config = de_config.SerializeToString()
    manager = gen_carls_ops.dynamic_embedding_manager_resource(
        serialized_config, var_name, service_address, timeout_ms)

    # Dummy variable handed to the op so that gradients can be passed in.
    grad_placeholder = tf.Variable(0.0)

    sampled = gen_carls_ops.sampled_logits_lookup(
        positive_keys, inputs, num_samples, grad_placeholder, manager)
    keys, labels, expected_counts, mask, weights = sampled

    # Shapes used below:
    #   weights:        [d1, d2, dn-1, num_samples, embed_dim]
    #   inputs:         [d1, d2, dn-1, embed_dim]
    #   output logits:  [d1, d2, dn-1, num_samples]

    # Insert a sample axis before the embedding axis:
    # [d1, d2, dn-1, embed_dim] -> [d1, d2, dn-1, 1, embed_dim]
    expanded_inputs = tf.expand_dims(inputs, axis=-2)

    # Repeat along the new axis only; every other axis keeps a multiple of 1:
    # [d1, d2, dn-1, 1, embed_dim] -> [d1, d2, dn-1, num_samples, embed_dim]
    multiples = [1] * (inputs.ndim - 1) + [num_samples, 1]
    tiled_inputs = tf.tile(expanded_inputs, multiples)

    # Dot product over the embedding axis:
    # [d1, d2, dn-1, num_samples, embed_dim] -> [d1, d2, dn-1, num_samples]
    dot_products = tf.reduce_sum(weights * tiled_inputs, -1)

    # Sampled logits: subtract the log of the expected counts.
    logits = dot_products - tf.math.log(expected_counts)

    return logits, labels, keys, mask, weights
    def test_add_to_collection(self):
        """Covers the accepted and rejected inputs of add_to_collection."""
        config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=5)

        # Re-adding 'first' with the same config is allowed and must not
        # create a third entry.
        for name in ('first', 'second', 'first'):
            context.add_to_collection(name, config)
        self.assertLen(context._knowledge_bank_collections, 2)

        # An empty name is rejected.
        with self.assertRaises(ValueError):
            context.add_to_collection('', config)

        # A config that is not a proto is rejected.
        with self.assertRaises(TypeError):
            context.add_to_collection('first', 'config')

        # An existing name cannot be re-registered with a different config.
        config.embedding_dimension = 10
        with self.assertRaises(ValueError):
            context.add_to_collection('first', config)