Example #1
def embedding_lookup_train(variational_params,
                           ids,
                           name=None,
                           clip_alpha=None,
                           eps=common.EPSILON):
    R"""Embedding trained with variational dropout.

  In a standard embedding lookup, `ids` are looked up in a list of embedding
  tensors. In an embedding trained with variational dropout, we look up the
  parameters of the fully-factorized Gaussian posterior over the embedding
  tensor for each index in `ids` and return a sample drawn from this
  distribution.

  The `ids` argument is analogous to that in the standard tf.embedding_lookup.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the \theta
      values and the second contains the log of the \sigma^2 values.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    clip_alpha: Int or None. If integer, we clip the log \alpha values
      to [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    The output Tensor result of the embedding lookup.

  Raises:
    RuntimeError: If the input variational_params is not a 2-tuple of Tensors
      that have the same shape.
  """
  theta, log_sigma2 = _verify_variational_params(variational_params)

  # Before we do anything, look up the means and log variances of the
  # embedding vectors we are going to output and do all our operations in
  # this lower dimensional space
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_sigma2 = layer_utils.gather(log_sigma2, ids)

  if clip_alpha:
    # Compute the log_alphas and then compute the
    # log_sigma2 again so that we can clip on the
    # log alpha magnitudes
    embedding_log_alpha = common.compute_log_alpha(
        embedding_log_sigma2, embedding_theta, eps, clip_alpha)
    embedding_log_sigma2 = common.compute_log_sigma2(
        embedding_log_alpha, embedding_theta, eps)

  # Calculate the standard deviation from the log variance
  embedding_std = tf.sqrt(tf.exp(embedding_log_sigma2) + eps)

  # Output samples from the distribution over the embedding vectors
  output_shape = tf.shape(embedding_std)
  embedding = embedding_theta + embedding_std * tf.random_normal(output_shape)
  return tf.identity(embedding, name=name)
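
The clipping branch above calls two helpers from `common` that are not shown in this listing. They convert between the log variance and the dropout parameter via the identity log \alpha = log \sigma^2 - log \theta^2. A minimal sketch consistent with that identity and with the call sites above (the actual implementations in `common` may differ, and the EPSILON value here is illustrative) could look like:

import tensorflow as tf

EPSILON = 1e-8  # illustrative; the real common.EPSILON may differ


def compute_log_alpha(log_sigma2, theta, eps=EPSILON, value_limit=8):
  # log \alpha = log \sigma^2 - log \theta^2, optionally clipped for stability
  log_alpha = log_sigma2 - tf.log(tf.square(theta) + eps)
  if value_limit is not None:
    log_alpha = tf.clip_by_value(log_alpha, -value_limit, value_limit)
  return log_alpha


def compute_log_sigma2(log_alpha, theta, eps=EPSILON):
  # Invert the transform above: log \sigma^2 = log \alpha + log \theta^2
  return log_alpha + tf.log(tf.square(theta) + eps)

Clipping log \alpha bounds the per-weight dropout rate, which is why the training lookup recomputes log \sigma^2 from the clipped values before sampling.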
Example #2
def embedding_lookup_eval(variational_params,
                          ids,
                          name=None,
                          threshold=3.0,
                          eps=common.EPSILON):
    R"""Evaluation mode embedding trained with variational dropout.

  In a standard embedding lookup, `ids` are looked up in a list of embedding
  tensors. In an embedding trained with variational dropout, we look up the
  parameters of the fully-factorized Gaussian posterior over the embedding
  tensor for each index in `ids` and return a sample drawn from this
  distribution. At evaluation time, we use the mean of the posterior over
  each embedding tensor instead of sampling.

  The `ids` argument is analogous to that in the standard tf.embedding_lookup.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the \theta
      values and the second contains the log of the \sigma^2 values.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    threshold: Weights with a log \alpha_{ij} value greater than this will be
      set to zero.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    The output Tensor result of the embedding lookup.

  Raises:
    RuntimeError: If the input variational_params is not a 2-tuple of Tensors
      that have the same shape.
  """
  theta, log_sigma2 = _verify_variational_params(variational_params)

  # Rather than mask the whole embedding every iteration, we can do a second
  # embedding lookup on the log \sigma^2 values, compute the log \alpha values
  # for each output embedding vector, and then mask the much lower dimensional
  # output embedding vectors
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_sigma2 = layer_utils.gather(log_sigma2, ids)

  # Compute the weight mask by thresholding on the log-space alpha values
  embedding_log_alpha = common.compute_log_alpha(
      embedding_log_sigma2, embedding_theta, eps, value_limit=None)
  embedding_mask = tf.cast(tf.less(embedding_log_alpha, threshold), tf.float32)

  # Return the masked embedding vectors
  return tf.identity(embedding_theta * embedding_mask, name=name)
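
Putting the two variational-dropout lookups together, here is a usage sketch assuming `theta` and `log_sigma2` are trainable variables of the same shape. The variable names, shapes, and the -10.0 initializer are illustrative assumptions, not taken from the listing above.

import tensorflow as tf

vocab_size, embed_dim = 10000, 128

theta = tf.get_variable("theta", shape=[vocab_size, embed_dim])
log_sigma2 = tf.get_variable(
    "log_sigma2",
    shape=[vocab_size, embed_dim],
    initializer=tf.constant_initializer(-10.0))  # start near zero variance

ids = tf.constant([[3, 17, 42], [8, 0, 99]], dtype=tf.int32)

# Training: one Gaussian sample per looked-up embedding vector
train_embeddings = embedding_lookup_train(
    (theta, log_sigma2), ids, clip_alpha=8, name="vd_embedding_train")

# Evaluation: posterior means, with high-dropout weights zeroed out
eval_embeddings = embedding_lookup_eval(
    (theta, log_sigma2), ids, threshold=3.0, name="vd_embedding_eval")

At evaluation time, entries whose log \alpha exceeds `threshold` are masked to zero, so heavily dropped-out embedding weights are effectively pruned.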
Example #3
def embedding_lookup_train(
    weight_parameters,
    ids,
    name=None,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for a l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from Louizos et al. (2018),
      "Learning Sparse Neural Networks through L0 Regularization".
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the same paper.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the same paper.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Before we do anything, lookup the theta values and log_alpha
  # values so that we can do our sampling and weight scaling in
  # the lower dimensional output batch
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_alpha = layer_utils.gather(log_alpha, ids)

  # Sample the z values for the output batch from the hard-concrete
  embedding_noise = common.hard_concrete_sample(
      embedding_log_alpha,
      beta,
      gamma,
      zeta,
      eps)
  return tf.identity(embedding_theta * embedding_noise, name=name)
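
`common.hard_concrete_sample` is not shown in this listing. A minimal sketch following the hard concrete reparameterization of Louizos et al. (2018), matching the call signature above but not necessarily the exact helper in `common`:

import tensorflow as tf


def hard_concrete_sample(log_alpha, beta, gamma, zeta, eps):
  # u ~ Uniform(0, 1), reparameterized through a sigmoid gate
  u = tf.random_uniform(tf.shape(log_alpha), minval=0.0, maxval=1.0)

  # Binary concrete sample with temperature beta
  s = tf.sigmoid((tf.log(u + eps) - tf.log(1.0 - u + eps) + log_alpha) / beta)

  # Stretch to (gamma, zeta) and clip to [0, 1] to allow exact zeros and ones
  s_bar = s * (zeta - gamma) + gamma
  return tf.clip_by_value(s_bar, 0.0, 1.0)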
Example #4
def embedding_lookup_eval(
    weight_parameters,
    ids,
    name=None,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for a l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values and the second is the log of the alpha values
      for the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from Louizos et al. (2018),
      "Learning Sparse Neural Networks through L0 Regularization".
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the same paper.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Before we do anything, lookup the theta values and log_alpha
  # values so that we can do our sampling and weight scaling in
  # the lower dimensional output batch
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_alpha = layer_utils.gather(log_alpha, ids)

  # Calculate the mean of the learned hard-concrete distribution
  # and scale the output embedding vectors
  embedding_noise = common.hard_concrete_mean(
      embedding_log_alpha,
      gamma,
      zeta)
  return tf.identity(embedding_theta * embedding_noise, name=name)
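
`common.hard_concrete_mean` is likewise not shown. Under the standard hard concrete formulation, the evaluation-time gate is the stretched and clipped sigmoid of log \alpha; the sketch below is an assumption about that helper, not its verbatim implementation:

import tensorflow as tf


def hard_concrete_mean(log_alpha, gamma, zeta):
  # Expected gate value: stretch sigmoid(log_alpha) to (gamma, zeta), then
  # clip to [0, 1] so that confidently-off gates become exact zeros
  stretched = tf.sigmoid(log_alpha) * (zeta - gamma) + gamma
  return tf.clip_by_value(stretched, 0.0, 1.0)

Entries whose mean gate clips to exactly zero contribute nothing to the output, which is what makes the learned embedding sparse at inference time.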