def __init__(self, decay, local=True, differentiable=False,
             name='snt_moving_average'):
  super(MovingAverage, self).__init__(name=name)
  self._differentiable = differentiable
  self._moving_average = snt.MovingAverage(
      decay=decay, local=local, name=name)


def moving_averages_baseline(
    dist, dist_samples, function, decay=0.9, grad_loss_fn=None):
  """Uses a moving average of the loss as the baseline control variate."""
  loss = tf.reduce_mean(function(dist_samples))
  moving_avg = tf.stop_gradient(snt.MovingAverage(decay=decay)(loss))
  control_variate = moving_avg
  expected_control_variate = moving_avg
  surrogate_cv, jacobians = grad_loss_fn(
      lambda x: moving_avg, dist_samples, dist)
  # Note: this has no effect on the gradient in the pathwise case.
  return control_variate, expected_control_variate, surrogate_cv, jacobians
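

# --- Illustrative usage sketch (not part of the original module) ---
# Shows how `moving_averages_baseline` could be wired to a gradient
# estimator. `_toy_score_function_grad_loss` below is a hypothetical
# stand-in for the repository's real `grad_loss_fn` implementations; it only
# follows the (function, dist_samples, dist) -> (surrogate_loss, jacobians)
# convention assumed above and returns an empty jacobian dict. It relies on
# the module-level `tf`, `snt` and `tfp` imports.
def _toy_score_function_grad_loss(function, dist_samples, dist):
  """Hypothetical score-function surrogate loss (sketch)."""
  per_sample_loss = tf.stop_gradient(function(dist_samples))
  surrogate_loss = tf.reduce_mean(
      per_sample_loss * dist.log_prob(dist_samples))
  # The real estimators also return a dict mapping distribution variables to
  # per-sample jacobians; this sketch omits it.
  return surrogate_loss, {}


def _example_moving_averages_baseline():
  """Sketch: moving-average baseline for a Gaussian with a quadratic loss."""
  mean = tf.Variable(0.5)
  dist = tfp.distributions.Normal(loc=mean, scale=1.0)
  dist_samples = dist.sample(100)
  quadratic_loss = lambda x: tf.square(x)
  return moving_averages_baseline(
      dist, dist_samples, quadratic_loss,
      decay=0.9, grad_loss_fn=_toy_score_function_grad_loss)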


def compute_control_variate_coeff(
    dist, dist_var, model_loss_fn, grad_loss_fn, control_variate_fn,
    num_samples, moving_averages=False, eps=1e-3):
  r"""Computes the control variate coefficient for the given variable.

  The coefficient is given by:
    \sum_k cov(df/d var_k, dcv/d var_k) / (\sum_k var(dcv/d var_k) + eps)

  where var_k is the k'th element of the variable `dist_var`. The covariance
  and variance calculations are done from samples obtained from the
  distribution `dist`.

  Args:
    dist: a tfp.distributions.Distribution instance.
    dist_var: the variable for which we are interested in computing the
      coefficient. The distribution samples should depend on this variable.
    model_loss_fn: A function with signature: lambda samples: f(samples).
      The model loss function.
    grad_loss_fn: The gradient estimator function. Needs to return both a
      surrogate loss and a dictionary of jacobians.
    control_variate_fn: The surrogate control variate function. Its gradient
      will be used as a control variate.
    num_samples: Int. The number of samples to use for the cov/var
      calculation.
    moving_averages: Bool. Whether or not to use moving averages for the
      calculation.
    eps: Float. Used to stabilize division.

  Returns:
    a tf.Tensor of rank 0. The coefficient for the input variable.
  """
  # Resample to avoid biased gradients.
  cv_dist_samples = dist.sample(num_samples)
  cv_jacobians = control_variate_fn(
      dist, cv_dist_samples, model_loss_fn, grad_loss_fn=grad_loss_fn)[-1]
  loss_jacobians = grad_loss_fn(model_loss_fn, cv_dist_samples, dist)[-1]

  cv_jacobians = cv_jacobians[dist_var]
  loss_jacobians = loss_jacobians[dist_var]
  # Num samples x num_variables.
  utils.assert_rank(loss_jacobians, 2)
  # Num samples x num_variables.
  utils.assert_rank(cv_jacobians, 2)

  mean_f = tf.reduce_mean(loss_jacobians, axis=0)
  mean_cv, var_cv = tf.nn.moments(cv_jacobians, axes=[0])
  cov = tf.reduce_mean(
      (loss_jacobians - mean_f) * (cv_jacobians - mean_cv), axis=0)

  utils.assert_rank(var_cv, 1)
  utils.assert_rank(cov, 1)

  # Compute the coefficient which minimizes variance.
  # Since we want to minimize the variance across parameter dimensions, the
  # optimal coefficient is given by the sum of covariances per dimension
  # over the sum of variances per dimension.
  cv_coeff = tf.reduce_sum(cov) / (tf.reduce_sum(var_cv) + eps)
  cv_coeff = tf.stop_gradient(cv_coeff)
  utils.assert_rank(cv_coeff, 0)
  if moving_averages:
    cv_coeff = tf.stop_gradient(snt.MovingAverage(decay=0.9)(cv_coeff))
  return cv_coeff
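

# --- Illustrative usage sketch (not part of the original module) ---
# Once `compute_control_variate_coeff` has produced a coefficient for a
# variable, the standard control-variate correction subtracts the scaled
# control-variate jacobian from the loss jacobian, i.e.
#     loss_jacobian - cv_coeff * cv_jacobian,
# with the expected-control-variate term that keeps the estimator unbiased
# assumed to be added back elsewhere in the pipeline. The helper below and
# its argument names are assumptions, not part of the original code.
def _apply_control_variate_coeff(
    loss_jacobians, cv_jacobians, dist_var, cv_coeff):
  """Combines loss and control-variate jacobians for one variable (sketch)."""
  corrected = loss_jacobians[dist_var] - cv_coeff * cv_jacobians[dist_var]
  # Averaging over the sample dimension yields the gradient estimate with
  # respect to `dist_var`.
  return tf.reduce_mean(corrected, axis=0)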