def add_to_collection(name: Text, config: de_config_pb2.DynamicEmbeddingConfig):
  """Adds given (name, config) pair to global collections.

  Args:
    name: A string denoting the variable name.
    config: An instance of DynamicEmbeddingConfig.

  Raises:
    TypeError: Invalid input.
    ValueError: Name is empty, or a different config is added for an existing
      variable.
  """
  if not name:
    raise ValueError("Empty name.")
  if not isinstance(config, de_config_pb2.DynamicEmbeddingConfig):
    raise TypeError("Config is not an instance of DynamicEmbeddingConfig.")
  if name in _knowledge_bank_collections.keys():
    existing_config = _knowledge_bank_collections[name]
    if config.SerializeToString() != existing_config.SerializeToString():
      raise ValueError(
          "Adding a new config for the same var name is not allowed, existing:"
          " %r, new: %r." % (existing_config, config))
  with _lock:
    _knowledge_bank_collections[name] = de_config_pb2.DynamicEmbeddingConfig()
    _knowledge_bank_collections[name].CopyFrom(config)
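# A hedged usage sketch (not part of the original module): registering the same
# config twice is a no-op, while a conflicting config for an existing name
# raises. The field values below are minimal and hypothetical.
def _example_add_to_collection():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=8)
  add_to_collection("my_var", config)
  add_to_collection("my_var", config)  # Identical config: allowed.
  changed = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=16)
  try:
    add_to_collection("my_var", changed)  # Different config: raises.
  except ValueError as e:
    print(e)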
def __init__(self,
             config: de_config_pb2.DynamicEmbeddingConfig,
             var_name: typing.Text,
             service_address: typing.Text = "",
             timeout_ms: int = -1):
  """Constructor for DynamicEmbeddingLookup.

  Args:
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from the --kbs_address flag (defined in
      .../carls/dynamic_embedding_manager.cc) will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Raises:
    ValueError: if var_name is `None` or empty.
  """
  super(DynamicEmbeddingLookup, self).__init__()
  if not var_name:
    raise ValueError("Must specify a non-empty var_name.")
  self.embedding_dimension = config.embedding_dimension
  context.add_to_collection(var_name, config)
  self.resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)
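# A hedged sketch (an assumption, not from the source): the super().__init__()
# call suggests DynamicEmbeddingLookup is a Keras-style layer. If its call()
# accepts a string `Tensor` of keys, constructing and invoking it might look
# like this; the names below are hypothetical.
def _example_dynamic_embedding_lookup_layer():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=32)
  layer = DynamicEmbeddingLookup(config, var_name="token_embedding")
  return layer(tf.constant(["hello", "world"]))  # Assumed call signature.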
def dynamic_embedding_update(keys: tf.Tensor,
                             values: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             timeout_ms: int = -1):
  """Updates the embeddings of given keys with given values.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length].
    values: A `Tensor` of shape [batch_size, embedding_dimension] or
      [batch_size, max_sequence_length, embedding_dimension].
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
      - [batch_size, config.embedding_dimension] if the input `Tensor` is 1D.
      - [batch_size, max_sequence_length, config.embedding_dimension] if the
        input is 2D.

  Raises:
    TypeError: If var_name is not specified.
  """
  if not var_name:
    raise TypeError("Must specify a valid var_name.")
  return gen_dynamic_embedding_ops.dynamic_embedding_update(
      keys, values, config.SerializeToString(), var_name, service_address,
      timeout_ms)
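# A hedged sketch: writing explicit embedding values for a batch of keys,
# assuming a reachable knowledge bank service. Names and values are
# hypothetical; values must have one row per key.
def _example_dynamic_embedding_update():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=4)
  keys = tf.constant(["apple", "banana"])
  values = tf.random.uniform([2, 4])  # [batch_size, embedding_dimension].
  return dynamic_embedding_update(keys, values, config, var_name="fruit_embed")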
def dynamic_gaussian_memory_lookup(inputs: tf.Tensor,
                                   mode: typing.Union[int, tf.Tensor],
                                   config: de_config_pb2.DynamicEmbeddingConfig,
                                   var_name: typing.Text,
                                   service_address: typing.Text = "",
                                   timeout_ms: int = -1):
  """Applies dynamic Gaussian memory to given inputs.

  A Gaussian memory assumes the input pattern can be represented by a number
  of Gaussian clusters. This function returns the closest Gaussian mean,
  variance, and the distance between each data point and the closest Gaussian
  center. It can be used in conjunction with a DynamicNormalization layer in a
  DNN, and the distance between the input and the Gaussian cluster can be used
  for model uncertainty inference.

  Note that the memory data is only based on the last dimension of the input.
  Hence if the input shape is [d1, d2, ..., dn], it is assumed to contain
  d1*d2*...*d(n-1) data points.

  Args:
    inputs: A float `Tensor` of shape [d1, d2, ..., dn] with n > 0.
    mode: An int or a `Tensor` whose value must be one of
      {LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE, LOOKUP_WITH_GROW}.
    config: An instance of DynamicEmbeddingConfig.
    var_name: A unique name for the given op.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    - A `Tensor` with the same shape as the input representing the mean values.
    - A `Tensor` with the same shape as the input representing the variance
      values.
    - A `Tensor` of shape [d1, d2, ..., d(n-1)] representing the distance to
      the cluster center.
    - An int `Tensor` of shape [d1, d2, ..., d(n-1)] representing the cluster
      ids.

  Raises:
    TypeError: if config is not an instance of DynamicEmbeddingConfig.
    ValueError: If var_name is not specified or mode is not valid.
  """
  if isinstance(mode, int):
    if mode not in {LOOKUP_WITHOUT_UPDATE, LOOKUP_WITH_UPDATE,
                    LOOKUP_WITH_GROW}:
      raise ValueError("Invalid mode: %r" % mode)
  else:  # mode is a Tensor.
    mode = tf.cast(mode, tf.int32)
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  context.add_to_collection(var_name, config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_carls_ops.dynamic_gaussian_memory_lookup(inputs, mode, resource)
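# A hedged sketch: querying Gaussian memory statistics for a batch of
# activations without updating the memory, assuming LOOKUP_WITHOUT_UPDATE is
# the module-level constant referenced above and a service is reachable. The
# config below is minimal and hypothetical.
def _example_dynamic_gaussian_memory_lookup():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=16)
  activations = tf.random.normal([32, 16])  # 32 data points of dimension 16.
  mean, variance, distance, cluster_ids = dynamic_gaussian_memory_lookup(
      activations, LOOKUP_WITHOUT_UPDATE, config, var_name="gaussian_memory")
  return mean, variance, distance, cluster_ids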
def dynamic_embedding_lookup(keys: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             skip_gradient_update: bool = False,
                             timeout_ms: int = -1) -> tf.Tensor:
  """Returns the embeddings of given keys.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length] where an empty string would be mapped to an all
      zero embedding.
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    skip_gradient_update: A boolean indicating if gradient update is needed.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    A `Tensor` with one of the following shapes:
      - [batch_size, config.embedding_dimension] if the input `Tensor` is 1D.
      - [batch_size, max_sequence_length, config.embedding_dimension] if the
        input is 2D.

  Raises:
    ValueError: If var_name is not specified.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  # Use a constant placeholder when gradient updates are skipped; otherwise
  # create a dummy variable so that the gradients can be passed in.
  if skip_gradient_update:
    grad_placeholder = tf.constant(0.0)
  else:
    grad_placeholder = tf.Variable(0.0)
  context.add_to_collection(var_name, config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_carls_ops.dynamic_embedding_lookup(keys, grad_placeholder,
                                                resource,
                                                config.embedding_dimension)
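# A hedged sketch: fetching trainable embeddings for a batch of string keys,
# assuming a reachable knowledge bank service; names are hypothetical.
def _example_dynamic_embedding_lookup():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=64)
  keys = tf.constant(["apple", "banana", ""])  # "" maps to a zero embedding.
  embeddings = dynamic_embedding_lookup(keys, config, var_name="fruit_embed")
  return embeddings  # Shape [3, 64].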
def restore_knowledge_bank(config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: Text,
                           saved_path: Text,
                           service_address: Text = '',
                           timeout_ms: int = -1) -> None:
  """Restores knowledge bank data (`config`, `var_name`) from given `saved_path`.

  Args:
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique name for the given embedding.
    saved_path: A string representing the path to the saved embedding data.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.
  """
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)
  gen_carls_ops.restore_knowledge_bank(saved_path, resource)
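# A hedged sketch: restoring previously saved embedding data into the service.
# The path below is hypothetical; how the data was exported is outside the
# scope of this snippet.
def _example_restore_knowledge_bank():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=64)
  restore_knowledge_bank(config, var_name="fruit_embed",
                         saved_path="/tmp/knowledge_bank/fruit_embed")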
def top_k(inputs: tf.Tensor,
          k: int,
          de_config: de_config_pb2.DynamicEmbeddingConfig,
          var_name: typing.Text,
          service_address: typing.Text = "",
          timeout_ms: int = -1):
  """Computes logits for the top k closest embeddings to the inputs.

  Note: the (keys, logits) pair returned here should not be used for training
  as they only represent a biased sampling. Instead, use
  sampled_softmax_loss() for training.

  Args:
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    k: An `int` denoting the number of returned keys.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    keys: A string `Tensor` of shape `[batch_size, k]` representing the top k
      keys relative to the input.
    logits: A float `Tensor` of shape `[batch_size, k]` representing the
      logits for the returned keys.

  Raises:
    ValueError: if var_name is not specified or k is not greater than zero.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")
  if k <= 0:
    raise ValueError("k must be greater than zero, got %d" % k)
  context.add_to_collection(var_name, de_config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_carls_ops.topk_lookup(inputs, k, resource)
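# A hedged sketch: retrieving the k closest keys and their logits for
# inference-time candidate retrieval (not for training, per the note above).
# Names and dimensions below are hypothetical.
def _example_top_k():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=64)
  activations = tf.random.normal([8, 64])  # [batch_size, dim].
  keys, logits = top_k(activations, k=10, de_config=config,
                       var_name="fruit_embed")
  return keys, logits  # Both have shape [8, 10].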
def compute_sampled_logits(positive_keys,
                           inputs,
                           num_samples: int,
                           de_config: de_config_pb2.DynamicEmbeddingConfig,
                           var_name: typing.Text,
                           service_address: typing.Text = "",
                           timeout_ms: int = -1):
  """Computes sampled logits from given positive labels.

  Args:
    positive_keys: A string `Tensor` of shape `[batch_size, None]` representing
      input positive keys.
    inputs: A float `Tensor` of shape `[batch_size, dim]` representing the
      forward activations of the input network.
    num_samples: An int denoting the number of returned positive and negative
      samples.
    de_config: A DynamicEmbeddingConfig for configuring the dynamic embedding.
    var_name: A unique name for the operation.
    service_address: The address of a dynamic embedding service. If empty, the
      value passed from the --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, never
      times out.

  Returns:
    logits: A float `Tensor` of shape `[batch_size, num_samples]` representing
      the logits for sampled labels.
    labels: A float `Tensor` of shape `[batch_size, num_samples]` with values
      in {0, 1} indicating if the sample is positive or negative.
    keys: A string `Tensor` of shape `[batch_size, num_samples]` representing
      the keys for each sample.
    mask: A float `Tensor` of shape `[batch_size]` representing the 0/1 mask
      of each batch. For example, if all keys in positive_keys[i] are empty,
      mask[i] = 0; otherwise mask[i] = 1.
    weights: A float `Tensor` representing the embeddings of the sampled keys.

  Raises:
    ValueError: If var_name is not specified or num_samples < 1.
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name, got %s" % var_name)
  if num_samples < 1:
    raise ValueError("Invalid num_samples: %d" % num_samples)
  context.add_to_collection(var_name, de_config)
  resource = gen_carls_ops.dynamic_embedding_manager_resource(
      de_config.SerializeToString(), var_name, service_address, timeout_ms)

  # Create a dummy variable so that the gradients can be passed in.
  grad_placeholder = tf.Variable(0.0)
  keys, labels, expected_counts, mask, weights = (
      gen_carls_ops.sampled_logits_lookup(positive_keys, inputs, num_samples,
                                          grad_placeholder, resource))

  # Compute sampled logits.
  # Shape of weights: [d1, d2, ..., d(n-1), num_samples, embed_dim]
  # Shape of inputs:  [d1, d2, ..., d(n-1), embed_dim]
  # Shape of output logits: [d1, d2, ..., d(n-1), num_samples]

  # [d1, d2, ..., d(n-1), embed_dim] -> [d1, d2, ..., d(n-1), 1, embed_dim]
  tiled_inputs = tf.expand_dims(inputs, axis=-2)
  # -> [d1, d2, ..., d(n-1), num_samples, embed_dim]
  multiples = [1] * (inputs.ndim + 1)
  multiples[-2] = num_samples
  tiled_inputs = tf.tile(tiled_inputs, multiples)
  # [d1, ..., d(n-1), num_samples, embed_dim] -> [d1, ..., d(n-1), num_samples]
  logits = tf.reduce_sum(weights * tiled_inputs, -1)
  # Adjust by the log expected counts to obtain the sampled logits.
  logits -= tf.math.log(expected_counts)
  return logits, labels, keys, mask, weights
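# A hedged sketch: turning the sampled logits into a simple per-batch loss.
# The sigmoid cross-entropy choice is an illustrative assumption, not the
# library's prescribed loss; names below are hypothetical.
def _example_compute_sampled_logits():
  config = de_config_pb2.DynamicEmbeddingConfig(embedding_dimension=64)
  positive_keys = tf.constant([["cat"], ["dog"]])
  activations = tf.random.normal([2, 64])
  logits, labels, _, mask, _ = compute_sampled_logits(
      positive_keys, activations, num_samples=5, de_config=config,
      var_name="label_embed")
  per_example = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # Zero out batch entries whose positive keys are all empty.
  return tf.reduce_sum(tf.reduce_mean(per_example, axis=-1) * mask)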