예제 #1
0
 def __init__(self,
              decay,
              local=True,
              differentiable=False,
              name='snt_moving_average'):
     """Builds a thin wrapper around `snt.MovingAverage`.

     Args:
       decay: forwarded unchanged as the `decay` argument of the wrapped
         `snt.MovingAverage`.
       local: forwarded unchanged as the `local` argument of the wrapped
         `snt.MovingAverage`.
       differentiable: stored on `self._differentiable`; it is not read in
         this constructor (NOTE(review): presumably consulted when the module
         is connected to the graph — confirm in the rest of the class).
       name: module name, used both for this wrapper and for the wrapped
         `snt.MovingAverage`.
     """
     super(MovingAverage, self).__init__(name=name)
     self._differentiable = differentiable
     # Gather the inner module's arguments in one place; note the inner
     # Sonnet module is given the same `name` as this wrapper.
     wrapped_kwargs = dict(decay=decay, local=local, name=name)
     self._moving_average = snt.MovingAverage(**wrapped_kwargs)
예제 #2
0
def moving_averages_baseline(dist,
                             dist_samples,
                             function,
                             decay=0.9,
                             grad_loss_fn=None):
    """Builds a moving-average baseline control variate.

    The baseline is an exponential moving average (via `snt.MovingAverage`)
    of the mean of `function` over `dist_samples`. `tf.stop_gradient` is
    applied to the average, so no gradient flows through it.

    Args:
      dist: the distribution the samples were drawn from; forwarded to
        `grad_loss_fn`.
      dist_samples: samples from `dist`.
      function: callable applied to `dist_samples`; its output is averaged
        with `tf.reduce_mean` before being fed to the moving average.
      decay: decay rate passed to `snt.MovingAverage`.
      grad_loss_fn: gradient estimator, called as
        `grad_loss_fn(fn, dist_samples, dist)` and expected to return a
        `(surrogate_loss, jacobians)` pair. It defaults to None only for
        signature compatibility; it is required.

    Returns:
      A tuple `(control_variate, expected_control_variate, surrogate_cv,
      jacobians)`. The first two are both the stopped-gradient moving
      average; the last two are what `grad_loss_fn` returns for a function
      that ignores its input and returns the moving average.

    Raises:
      TypeError: if `grad_loss_fn` is None.
    """
    # `grad_loss_fn` has a None default but is always called below; fail
    # early with a clear message instead of "'NoneType' object is not
    # callable" after the moving-average module has already been built.
    if grad_loss_fn is None:
        raise TypeError('moving_averages_baseline requires a grad_loss_fn.')
    loss = tf.reduce_mean(function(dist_samples))
    moving_avg = tf.stop_gradient(snt.MovingAverage(decay=decay)(loss))
    # The baseline is used as its own expected value.
    control_variate = moving_avg
    expected_control_variate = moving_avg
    # Note: this has no effect on the gradient in the pathwise case.
    surrogate_cv, jacobians = grad_loss_fn(lambda x: moving_avg, dist_samples,
                                           dist)
    return control_variate, expected_control_variate, surrogate_cv, jacobians
예제 #3
0
def compute_control_variate_coeff(dist,
                                  dist_var,
                                  model_loss_fn,
                                  grad_loss_fn,
                                  control_variate_fn,
                                  num_samples,
                                  moving_averages=False,
                                  eps=1e-3,
                                  moving_average_decay=0.9):
    r"""Computes the control variate coefficient for the given variable.

    The coefficient is given by:
      \sum_k cov(df/d var_k, dcv/d var_k) / (\sum_k var(dcv/d var_k) + eps)

    where var_k is the k'th element of the variable `dist_var`. The
    covariance and variance are estimated from a fresh batch of samples drawn
    from the distribution `dist`.

    Args:
      dist: a tfp.distributions.Distribution instance.
      dist_var: the variable for which we are interested in computing the
        coefficient. The distribution samples should depend on this variable.
      model_loss_fn: A function with signature: lambda samples: f(samples).
        The model loss function.
      grad_loss_fn: The gradient estimator function, called as
        `grad_loss_fn(fn, samples, dist)`. Needs to return both a surrogate
        loss and a dictionary of jacobians (the last element is used).
      control_variate_fn: The surrogate control variate function, called as
        `control_variate_fn(dist, samples, model_loss_fn,
        grad_loss_fn=grad_loss_fn)`. The last element of its return value
        must be a dictionary of jacobians; its gradient is used as the
        control variate.
      num_samples: Int. The number of samples to use for the cov/var
        calculation.
      moving_averages: Bool. Whether to smooth the resulting coefficient with
        an exponential moving average (`snt.MovingAverage`).
      eps: Float. Used to stabilize division.
      moving_average_decay: Float. Decay of the moving average applied when
        `moving_averages` is True. Defaults to 0.9, the previously hard-coded
        value.

    Returns:
      a tf.Tensor of rank 0. The coefficient for the input variable; it is
      wrapped in `tf.stop_gradient`, so no gradient flows through it.
    """
    # Resample to avoid biased gradients.
    cv_dist_samples = dist.sample(num_samples)
    cv_jacobians = control_variate_fn(dist,
                                      cv_dist_samples,
                                      model_loss_fn,
                                      grad_loss_fn=grad_loss_fn)[-1]
    loss_jacobians = grad_loss_fn(model_loss_fn, cv_dist_samples, dist)[-1]

    cv_jacobians = cv_jacobians[dist_var]
    loss_jacobians = loss_jacobians[dist_var]
    # Both are num_samples x num_variables.
    utils.assert_rank(loss_jacobians, 2)
    utils.assert_rank(cv_jacobians, 2)

    # Per-dimension moments over the sample axis (axis 0); the covariance and
    # the variance from tf.nn.moments are both normalized by the sample count.
    mean_f = tf.reduce_mean(loss_jacobians, axis=0)
    mean_cv, var_cv = tf.nn.moments(cv_jacobians, axes=[0])
    cov = tf.reduce_mean((loss_jacobians - mean_f) * (cv_jacobians - mean_cv),
                         axis=0)

    utils.assert_rank(var_cv, 1)
    utils.assert_rank(cov, 1)

    # Compute the coefficient which minimizes variance.
    # Since we want to minimize the variance summed across parameter
    # dimensions, the optimal (shared) coefficient is the sum of the
    # per-dimension covariances over the sum of the per-dimension variances.
    cv_coeff = tf.reduce_sum(cov) / (tf.reduce_sum(var_cv) + eps)
    cv_coeff = tf.stop_gradient(cv_coeff)
    utils.assert_rank(cv_coeff, 0)
    if moving_averages:
        cv_coeff = tf.stop_gradient(
            snt.MovingAverage(decay=moving_average_decay)(cv_coeff))

    return cv_coeff