Example #1
def add_to_collection(name: Text,
                      config: de_config_pb2.DynamicEmbeddingConfig):
  """Adds the given (name, config) pair to the global collections.

  Args:
    name: A string denoting the variable name.
    config: An instance of DynamicEmbeddingConfig.

  Raises:
    TypeError: Invalid input.
    ValueError: Name is empty, or a different config is added for an existing
      variable.
  """
  if not name:
    raise ValueError("Empty name.")
  if not isinstance(config, de_config_pb2.DynamicEmbeddingConfig):
    raise TypeError("Config is not an instance of DynamicEmbeddingConfig.")
  if name in _knowledge_bank_collections.keys():
    existing_config = _knowledge_bank_collections[name]
    if config.SerializeToString() != existing_config.SerializeToString():
      raise ValueError(
          "Adding a new config for the same var name is not allowed, existing:"
          " %r, new: %r." % (existing_config, config))

  with _lock:
    _knowledge_bank_collections[name] = de_config_pb2.DynamicEmbeddingConfig()
    _knowledge_bank_collections[name].CopyFrom(config)
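
A minimal usage sketch (hypothetical variable name; embedding_dimension is the only config field assumed): re-adding an identical config is allowed, while a different config under the same name raises.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
add_to_collection("my_var", config)
add_to_collection("my_var", config)  # Identical config: allowed.

other = de_config_pb2.DynamicEmbeddingConfig()
other.embedding_dimension = 16
add_to_collection("my_var", other)  # Raises ValueError: configs differ.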
Example #2
  def __init__(self,
               config: de_config_pb2.DynamicEmbeddingConfig,
               var_name: typing.Text,
               service_address: typing.Text = "",
               timeout_ms: int = -1):
    """Constructor for DynamicEmbeddingLookup.

    Args:
      config: A DynamicEmbeddingConfig proto that configures the embedding.
      var_name: A unique name for the given embedding.
      service_address: The address of a knowledge bank service. If empty, the
        value passed from the --kbs_address flag (defined in
        .../carls/dynamic_embedding_manager.cc) will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never times out.

    Raises:
      ValueError: if var_name is `None` or empty.
    """
    super(DynamicEmbeddingLookup, self).__init__()
    if not var_name:
      raise ValueError("Must specify a non-empty var_name.")

    self.embedding_dimension = config.embedding_dimension
    context.add_to_collection(var_name, config)
    self.resource = gen_carls_ops.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)
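
A construction-time sketch with hypothetical names, assuming the module's own imports; the lookup itself is performed by the layer's call method, which is not shown here.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
lookup_layer = DynamicEmbeddingLookup(config, var_name="token_embedding")
assert lookup_layer.embedding_dimension == 8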
Example #3
def dynamic_embedding_update(keys: tf.Tensor,
                             values: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             timeout_ms: int = -1):
  """Updates the embeddings of given keys with given values.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length].
    values: A `Tensor` of shape [batch_size, embedding_dimension] or
      [batch_size, max_sequence_length, embedding_dimension].
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
    - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if the
      input is 2D.

  Raises:
    TypeError: If var_name is not specified.
  """
  if not var_name:
    raise TypeError("Must specify a valid var_name.")

  return gen_dynamic_embedding_ops.dynamic_embedding_update(
      keys, values, config.SerializeToString(), var_name, service_address,
      timeout_ms)
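
A usage sketch with hypothetical keys and values; the last dimension of values must equal config.embedding_dimension.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 4
keys = tf.constant(["apple", "banana"])  # Shape [batch_size].
values = tf.random.uniform([2, 4])       # Shape [batch_size, embedding_dimension].
dynamic_embedding_update(keys, values, config, var_name="fruit_embedding")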
Example #4
def dynamic_gaussian_memory_lookup(
    inputs: tf.Tensor,
    mode: typing.Union[int, tf.Tensor],
    config: de_config_pb2.DynamicEmbeddingConfig,
    var_name: typing.Text,
    service_address: typing.Text = "",
    timeout_ms: int = -1):
  """Applies dynamic Gaussian memory to given inputs.

  A Gaussian memory assumes the input pattern can be represented by a number of
  Gaussian clusters. This function returns the closest Gaussian mean, variance
  and the distance between each data point and the closest Gaussian center.

  This function can be used in conjunction with a DynamicNormalization layer in
  a DNN. The distance between the input and the Gaussian cluster can be used
  for model uncertainty inference.

  Note that the memory data is only based on the last dimension of the input.
  Hence if the input shape is [d1, d2, ..., dn], it is assumed to contain
  d1 * d2 * ... * d(n-1) data points.

  Args:
    inputs: A float `Tensor` of shape [d1, d2, ..., dn] with n > 0.
    mode: An int or a `Tensor` whose value must be one of
      {LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE, LOOKUP_WITH_GROW}.
    config: An instance of DynamicEmbeddingConfig.
    var_name: A unique name for the given op.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    - A `Tensor` with the same shape as the input representing the mean values.
    - A `Tensor` with the same shape as the input representing the variance
      values.
    - A `Tensor` of shape [d1, d2, ..., d(n-1)] representing the distance to
      the cluster center.
    - An int `Tensor` of shape [d1, d2, ..., d(n-1)] representing the cluster
      ids.

  Raises:
    TypeError: if config is not an instance of DynamicEmbeddingConfig.
    ValueError: if var_name is not specified or mode is not valid.
  """
  if isinstance(mode, int):
    if mode not in {LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE,
                    LOOKUP_WITH_GROW}:
      raise ValueError("Invalid mode: %r" % mode)
  # Accept either a Python int or a Tensor; normalize to an int32 Tensor.
  mode = tf.cast(mode, tf.int32)
  if not var_name:
    raise ValueError("Must specify a valid var_name.")

  context.add_to_collection(var_name, config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)

  return gen_carls_ops.dynamic_gaussian_memory_lookup(inputs, mode, resource)
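
A lookup-and-grow sketch with hypothetical names; any Gaussian-memory-specific fields of the config are assumed to be set elsewhere and are not shown.

config = de_config_pb2.DynamicEmbeddingConfig()  # Memory settings omitted here.
inputs = tf.random.normal([32, 16])  # 32 data points with 16-d features.
mean, variance, distance, cluster_ids = dynamic_gaussian_memory_lookup(
    inputs, LOOKUP_WITH_GROW, config, var_name="gaussian_memory")
# mean, variance: [32, 16]; distance, cluster_ids: [32].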
Example #5
def dynamic_embedding_lookup(keys: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             skip_gradient_update: bool = False,
                             timeout_ms: int = -1) -> tf.Tensor:
  """Returns the embeddings for the given keys.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length] where an empty string would be mapped to an all-zero
      embedding.
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    skip_gradient_update: A boolean indicating if gradient update is needed.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
    - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if the
      input is 2D.

  Raises:
    ValueError: If var_name is not specified.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")

  # If skip_gradient_update is true, use a constant so no gradient flows back;
  # otherwise create a dummy variable so that the gradients can be passed in.
  if skip_gradient_update:
    grad_placeholder = tf.constant(0.0)
  else:
    grad_placeholder = tf.Variable(0.0)

  context.add_to_collection(var_name, config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)

  return gen_carls_ops.dynamic_embedding_lookup(keys, grad_placeholder,
                                                resource,
                                                config.embedding_dimension)
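
A lookup sketch with hypothetical 2D keys; as documented above, empty strings map to all-zero embeddings.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
keys = tf.constant([["hello", "world"], ["foo", ""]])
embeddings = dynamic_embedding_lookup(keys, config, var_name="token_embedding")
# embeddings has shape [2, 2, 8].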
Example #6
def restore_knowledge_bank(config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: Text,
                           saved_path: Text,
                           service_address: Text = '',
                           timeout_ms: int = -1) -> None:
  """Restores knowledge bank data (`config`, `var_name`) from `saved_path`.

  Args:
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    saved_path: A string representing the path to the saved embedding data.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.
  """
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)

  gen_carls_ops.restore_knowledge_bank(saved_path, resource)
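
A restore sketch with a hypothetical checkpoint path; config and var_name should match the values used when the embedding data was exported.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
restore_knowledge_bank(config, var_name="token_embedding",
                       saved_path="/tmp/embedding_data")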
Example #7
def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
  """Computes logits for the top k closest embeddings to the inputs.

  Args:
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    k: An `int` denoting the number of returned keys.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    keys: A string `Tensor` of shape `[batch_size, k]` representing the top k
      keys relative to the input.
    logits: A float `Tensor` of shape `[batch_size, k]` representing the logits
      for the returned keys.

  Raises:
    ValueError: if var_name is not specified or k is not greater than zero.

  Note: The (keys, logits) pair returned here should not be used for training
  as they only represent biased sampling. Instead, use sampled_softmax_loss()
  for training.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  if k <= 0:
    raise ValueError("k must be greater than zero, got %d" % k)

  context.add_to_collection(var_name, de_config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_carls_ops.topk_lookup(inputs, k, resource)
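
An inference-time sketch with hypothetical names, retrieving the k closest keys and their logits for a batch of activations.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
activations = tf.random.normal([4, 8])  # [batch_size, dim].
keys, logits = top_k(activations, k=5, de_config=config, var_name="output_embedding")
# keys: [4, 5] string Tensor; logits: [4, 5] float Tensor.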
Example #8
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
  """Computes sampled logits from given positive labels.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    logits: A float `Tensor` of shape `[batch_size, num_samples]` representing
      the logits for sampled labels.
    labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
      in {0, 1} indicating if the sample is positive or negative.
    keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
      the keys for each sample.
    mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
      of each batch. For example, if all keys in positive_keys[i] are empty,
      mask[i] = 0; otherwise mask[i] = 1.
    weights: A float `Tensor` representing the embeddings of the sampled keys.

  Raises:
    ValueError: If var_name is not specified or num_samples < 1.
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
  """
  if not var_name:
    raise ValueError("Must specify a valid name, got %s" % var_name)
  if num_samples < 1:
    raise ValueError("Invalid num_samples: %d" % num_samples)

  context.add_to_collection(var_name, de_config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)

  # Create a dummy variable so that the gradients can be passed in.
  grad_placeholder = tf.Variable(0.0)

  keys, labels, expected_counts, mask, weights = (
      gen_carls_ops.sampled_logits_lookup(positive_keys, inputs, num_samples,
                                          grad_placeholder, resource))

  # Compute sampled logits.
  # Shape of weights: [d1, d2, ..., d(n-1), num_samples, embed_dim]
  # Shape of inputs: [d1, d2, ..., d(n-1), embed_dim]
  # Shape of output logits: [d1, d2, ..., d(n-1), num_samples]

  # [d1, ..., d(n-1), embed_dim] -> [d1, ..., d(n-1), 1, embed_dim]
  tiled_inputs = tf.expand_dims(inputs, axis=-2)
  # [d1, ..., d(n-1), 1, embed_dim] -> [d1, ..., d(n-1), num_samples, embed_dim]
  multiples = [1] * (inputs.ndim + 1)
  multiples[-2] = num_samples
  tiled_inputs = tf.tile(tiled_inputs, multiples)
  # [d1, ..., d(n-1), num_samples, embed_dim] -> [d1, ..., d(n-1), num_samples]
  logits = tf.reduce_sum(weights * tiled_inputs, -1)
  # Correct the raw logits by the expected sampling counts.
  logits -= tf.math.log(expected_counts)

  return logits, labels, keys, mask, weights
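
One illustrative way to consume the outputs in a loss, with hypothetical inputs; per the note under top_k() above, the library's sampled_softmax_loss() is the intended training entry point.

config = de_config_pb2.DynamicEmbeddingConfig()
config.embedding_dimension = 8
positive_keys = tf.constant([["cat"], ["dog"]])  # [batch_size, None].
activations = tf.random.normal([2, 8])           # [batch_size, dim].
logits, labels, keys, mask, weights = compute_sampled_logits(
    positive_keys, activations, num_samples=10, de_config=config,
    var_name="sampled_output")
per_example = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
loss = tf.reduce_sum(tf.reduce_mean(per_example, axis=-1) * mask)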