def dynamic_embedding_lookup(keys: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             skip_gradient_update: bool = False,
                             timeout_ms: int = -1) -> tf.Tensor:
    """Returns the embeddings of from given keys.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length] where an empty string would be mapped to an all zero
      embedding.
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    skip_gradient_update: A boolean indicating if gradient update is needed.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
    - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if the
      input is 2D.
  Raises:
    ValueError: If var_name is not specified.
  """
    if not var_name:
        raise ValueError("Must specify a valid var_name.")

    # Create a placeholder for the gradients. If skip_gradient_update is true,
    # use a constant so no gradient is propagated; otherwise use a dummy
    # variable so that the gradients can be passed in.
    if skip_gradient_update:
        grad_placeholder = tf.constant(0.0)
    else:
        grad_placeholder = tf.Variable(0.0)

    context.add_to_collection(var_name, config)
    resource = gen_de_op.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)

    return gen_de_op.dynamic_embedding_lookup(keys, grad_placeholder, resource)
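
# A minimal usage sketch for dynamic_embedding_lookup, assuming a knowledge
# bank service is reachable at the placeholder address below and that
# DynamicEmbeddingConfig exposes the `embedding_dimension` field referenced in
# the docstring above. All concrete values are illustrative only.
def _example_dynamic_embedding_lookup():
    config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=8)
    keys = tf.constant(["apple", "banana", ""])  # Empty key -> all-zero row.
    embeddings = dynamic_embedding_lookup(
        keys, config, var_name="fruit_embedding",
        service_address="localhost:10000")
    return embeddings  # Shape: [3, 8].
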
def restore_knowledge_bank(config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: Text,
                           saved_path: Text,
                           service_address: Text = '',
                           timeout_ms: int = -1) -> None:
    """Restores knowledge bank data (`config`, `name`) from given `saved_path`.

  Args:
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    saved_path: A string representing the path to the saved embedding data.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.
  """
    resource = gen_de_op.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)

    gen_io_ops.restore_knowledge_bank(saved_path, resource)
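
# A brief sketch of restoring previously saved embedding data into a running
# knowledge bank service. The path, variable name, and address are placeholder
# assumptions; in practice `saved_path` would come from save_knowledge_bank().
def _example_restore_knowledge_bank():
    config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=8)
    restore_knowledge_bank(
        config,
        var_name="fruit_embedding",
        saved_path="/tmp/knowledge_bank/fruit_embedding",
        service_address="localhost:10000")
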
def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
    """Computes logits for the top k closest embeddings to the inputs.

  Args:
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    k: An `int` denoting the number of returned keys.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    keys: A string `Tensor` of shape `[batch_size, k]` representing the top k
        keys relative to the input.
    logits: A float `Tensor` of shape `[batch_size, k]` representing the logits
        for the returned keys.

  Raises:
    ValueError: if k is not greater than zero.

  Note: The (keys, logits) pair returned here should not be used for training as
  they only represent biased sampling. Instead, use sampled_softmax_loss()
  for training.
  """
    if not var_name:
        raise ValueError("Must specify a valid var_name.")
    if k <= 0:
        raise ValueError("k must be greater than zero, got %d" % k)

    context.add_to_collection(var_name, de_config)
    resource = de_ops.dynamic_embedding_manager_resource(
        de_config.SerializeToString(), var_name, service_address, timeout_ms)
    return gen_topk_op.topk_lookup(inputs, k, resource)
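
# An illustrative call to top_k that retrieves the k closest embedding keys for
# a batch of activations. The dimensions and service address are assumptions of
# this sketch; per the note above, the result is for inference, not training.
def _example_top_k():
    de_config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=16)
    activations = tf.random.normal([4, 16])  # [batch_size, dim]
    keys, logits = top_k(
        activations, k=5, de_config=de_config, var_name="candidate_embedding",
        service_address="localhost:10000")
    return keys, logits  # Shapes: [4, 5] and [4, 5].
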
def save_knowledge_bank(output_directory: Text,
                        service_address: Text = '',
                        timeout_ms: int = -1,
                        append_timestamp: bool = True,
                        var_names=None):
    """Saves knowledge bank data to given output directory.

  Each knowledge bank's data will be saved in a subdirectory:
  `%output_directory%/%var_name%`.

  Args:
    output_directory: A string representing the output directory path.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.
    append_timestamp: A boolean variable indicating if a timestamped dir should
      be added when saving the data.
    var_names: A list of strings representing the variable names whose dynamic
      embedding data should be saved. If not specified, all data is saved.

  Returns:
    A list of paths to the saved embedding data, one per saved variable.

  Raises:
    ValueError: If output_directory is empty.
  """
    if not output_directory:
        raise ValueError('Empty output_directory.')

    saved_paths = []
    for name, config in context.get_all_collection():
        if var_names and (name not in var_names):
            continue
        resource = gen_de_op.dynamic_embedding_manager_resource(
            config.SerializeToString(), name, service_address, timeout_ms)

        saved_path = gen_io_ops.save_knowledge_bank(
            output_directory,
            append_timestamp=append_timestamp,
            handle=resource)
        saved_paths.append(saved_path)

    return saved_paths
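
# A sketch of saving every embedding variable registered through
# context.add_to_collection(). The output directory and address are
# placeholders; each variable is written under %output_directory%/%var_name%,
# optionally inside a timestamped subdirectory.
def _example_save_knowledge_bank():
    saved_paths = save_knowledge_bank(
        "/tmp/knowledge_bank",
        service_address="localhost:10000",
        append_timestamp=True)
    return saved_paths  # One saved path per registered variable.
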
def dynamic_embedding_update(keys: tf.Tensor,
                             values: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             timeout_ms: int = -1):
    """Updates the embeddings of given keys with given values.

  Args:
    keys: A string `Tensor` of shape [batch_size] or
      [batch_size, max_sequence_length].
    values: A `Tensor` of shape [batch_size, embedding_dimension] or
      [batch_size, max_sequence_length, embedding_dimension].
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
    - [batch_size, config.embedding_dimension] if the input Tensor is 1D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if the
      input is 2D.
  Raises:
    TypeError: If var_name is not specified.
  """
    if not var_name:
        raise TypeError("Must specify a valid var_name.")

    context.add_to_collection(var_name, config)
    resource = gen_de_op.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)

    return gen_de_op.dynamic_embedding_update(keys, values, resource)
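
# A hedged sketch of overwriting the stored embeddings for specific keys, e.g.
# to initialize entries. Shapes follow the docstring: for 1D keys, `values`
# must be [batch_size, config.embedding_dimension]. All values are illustrative.
def _example_dynamic_embedding_update():
    config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=4)
    keys = tf.constant(["apple", "banana"])
    values = tf.constant([[0.1, 0.2, 0.3, 0.4],
                          [0.5, 0.6, 0.7, 0.8]])
    return dynamic_embedding_update(
        keys, values, config, var_name="fruit_embedding",
        service_address="localhost:10000")
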
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
    """Computes sampled logits from given positive labels.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    logits: A float `Tensor` of shape `[batch_size, num_samples]` representing
        the logits for sampled labels.
    labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
        in {0, 1} indicating if the sample is positive or negative.
    keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
        the keys for each sample.
    mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
        of each batch. For example, if all keys in positive_keys[i] are empty,
        mask[i] = 0; otherwise mask[i] = 1.
    weights: A float `Tensor` representing the embeddings of the sampled keys.

  Raises:
    ValueError: If var_name is not specified or num_samples is less than 1.
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
  """
    if not var_name:
        raise ValueError("Must specify a valid name, got %s" % var_name)
    if num_samples < 1:
        raise ValueError("Invalid num_samples: %d" % num_samples)

    context.add_to_collection(var_name, de_config)
    resource = de_ops.dynamic_embedding_manager_resource(
        de_config.SerializeToString(), var_name, service_address, timeout_ms)

    # Create a dummy variable so that the gradients can be passed in.
    grad_placeholder = tf.Variable(0.0)

    keys, labels, expected_counts, mask, weights = (
        gen_sampled_logits_ops.sampled_logits_lookup(positive_keys, inputs,
                                                     num_samples,
                                                     grad_placeholder,
                                                     resource))

    # Compute sampled logits.
    # Shape of weights: [d1, ..., dn-1, num_samples, embed_dim]
    # Shape of inputs: [d1, ..., dn-1, embed_dim]
    # Shape of output logits: [d1, ..., dn-1, num_samples]

    # [d1, ..., dn-1, embed_dim] -> [d1, ..., dn-1, 1, embed_dim]
    tiled_inputs = tf.expand_dims(inputs, axis=-2)
    # [d1, ..., dn-1, embed_dim] -> [d1, ..., dn-1, num_samples, embed_dim]
    multiples = [1] * (inputs.ndim + 1)
    multiples[-2] = num_samples
    tiled_inputs = tf.tile(tiled_inputs, multiples)
    # [d1, ..., dn-1, num_samples, embed_dim] -> [d1, ..., dn-1, num_samples]
    logits = tf.reduce_sum(weights * tiled_inputs, -1)
    # Sampled logits.
    logits -= tf.math.log(expected_counts)

    return logits, labels, keys, mask, weights
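
# An illustrative way to turn the sampled (logits, labels, mask) outputs into a
# training loss. Using sigmoid cross-entropy here is an assumption of this
# sketch, not this module's API; the top_k docstring points to
# sampled_softmax_loss() as the intended training entry point.
def _example_sampled_loss():
    de_config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=16)
    positive_keys = tf.constant([["apple"], ["banana"]])  # [batch_size, 1]
    activations = tf.random.normal([2, 16])               # [batch_size, dim]
    logits, labels, unused_keys, mask, unused_weights = compute_sampled_logits(
        positive_keys, activations, num_samples=10, de_config=de_config,
        var_name="fruit_embedding", service_address="localhost:10000")
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)
    # Zero out examples whose positive keys were all empty (mask[i] == 0).
    return tf.reduce_sum(tf.reduce_mean(per_example, axis=-1) * mask)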