def get_rebar_gradient(self):
  """Get the rebar gradient."""
  hardELBO, nvil_gradient, logQHard = self._create_hard_elbo()
  if self.hparams.quadratic:
    gumbel_cv, _ = self._create_gumbel_control_variate_quadratic(logQHard)
  else:
    gumbel_cv, _ = self._create_gumbel_control_variate(logQHard)

  # REINFORCE-style gradient through the discrete (hard) samples.
  f_grads = self.optimizer_class.compute_gradients(
      tf.reduce_mean(-nvil_gradient))

  # Gradients of the Gumbel-Softmax control variate, scaled per layer by eta.
  eta = {}
  h_grads, eta_statistics = self.multiply_by_eta_per_layer(
      self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)),
      eta)

  # REBAR estimate: REINFORCE term plus the control-variate correction.
  model_grads = U.add_grads_and_vars(f_grads, h_grads)
  total_grads = model_grads

  # Construct the variance objective (minimized w.r.t. eta to reduce the
  # variance of the gradient estimator).
  variance_objective = tf.reduce_mean(
      tf.square(U.vectorize(model_grads, set_none_to_zero=True)))

  debug = {
      'ELBO': hardELBO,
      'etas': eta_statistics,
      'variance_objective': variance_objective,
  }
  return total_grads, debug, variance_objective
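
# --- Illustrative usage sketch (not part of the original code) ---
# A hypothetical example of consuming get_rebar_gradient's outputs: the model
# parameters are updated with the REBAR estimate, while the control-variate
# scales (eta) are tuned by minimizing the variance objective with a separate
# optimizer. The `eta_vars` argument and the Adam optimizers are assumptions,
# not part of the original interface.
def _example_rebar_train_ops(self, eta_vars, learning_rate=1e-3):
  """Hypothetical sketch: train ops for the fixed-temperature REBAR."""
  total_grads, debug, variance_objective = self.get_rebar_gradient()

  # Model parameters follow the REBAR gradient estimate.
  model_train_op = self.optimizer_class.apply_gradients(total_grads)

  # eta follows autodiff gradients of the estimator-variance proxy.
  eta_train_op = tf.train.AdamOptimizer(learning_rate).minimize(
      variance_objective, var_list=eta_vars)

  return tf.group(model_train_op, eta_train_op), debug
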
def get_dynamic_rebar_gradient(self):
  """Get the dynamic rebar gradient (t, eta optimized)."""
  # The temperature is parameterized through its log (pre_temperature) so it
  # stays positive, and tiled over the batch so per-example gradients of the
  # learning signal w.r.t. the temperature can be taken.
  tiled_pre_temperature = tf.tile([self.pre_temperature_variable],
                                  [self.batch_size])
  temperature = tf.exp(tiled_pre_temperature)

  hardELBO, nvil_gradient, logQHard = self._create_hard_elbo()
  if self.hparams.quadratic:
    gumbel_cv, extra = self._create_gumbel_control_variate_quadratic(
        logQHard, temperature=temperature)
  else:
    gumbel_cv, extra = self._create_gumbel_control_variate(
        logQHard, temperature=temperature)

  f_grads = self.optimizer_class.compute_gradients(
      tf.reduce_mean(-nvil_gradient))

  eta = {}
  h_grads, eta_statistics = self.multiply_by_eta_per_layer(
      self.optimizer_class.compute_gradients(tf.reduce_mean(gumbel_cv)),
      eta)

  model_grads = U.add_grads_and_vars(f_grads, h_grads)
  total_grads = model_grads

  # Construct the variance objective. An exponential moving average of the
  # gradient is maintained as an optional baseline (currently disabled).
  g = U.vectorize(model_grads, set_none_to_zero=True)
  self.maintain_ema_ops.append(self.ema.apply([g]))
  gbar = 0  # tf.stop_gradient(self.ema.average(g))
  variance_objective = tf.reduce_mean(tf.square(g - gbar))

  # Score-function (REINFORCE) part of the temperature gradient: the
  # derivative of the Gumbel learning signal w.r.t. the pre-temperature
  # enters through log q(b).
  reinf_g_t = 0
  if self.hparams.quadratic:
    for layer in xrange(self.hparams.n_layer):
      gumbel_learning_signal, _ = extra[layer]
      df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0]
      reinf_g_t_i, _ = self.multiply_by_eta_per_layer(
          self.optimizer_class.compute_gradients(
              tf.reduce_mean(tf.stop_gradient(df_dt) * logQHard[layer])),
          eta)
      reinf_g_t += U.vectorize(reinf_g_t_i, set_none_to_zero=True)

    reparam = tf.add_n([reparam_i for _, reparam_i in extra])
  else:
    gumbel_learning_signal, reparam = extra
    df_dt = tf.gradients(gumbel_learning_signal, tiled_pre_temperature)[0]
    reinf_g_t, _ = self.multiply_by_eta_per_layer(
        self.optimizer_class.compute_gradients(
            tf.reduce_mean(tf.stop_gradient(df_dt) * tf.add_n(logQHard))),
        eta)
    reinf_g_t = U.vectorize(reinf_g_t, set_none_to_zero=True)

  # Reparameterized part of the gradient and its derivative w.r.t. the
  # pre-temperature.
  reparam_g, _ = self.multiply_by_eta_per_layer(
      self.optimizer_class.compute_gradients(tf.reduce_mean(reparam)),
      eta)
  reparam_g = U.vectorize(reparam_g, set_none_to_zero=True)
  reparam_g_t = tf.gradients(
      tf.reduce_mean(2 * tf.stop_gradient(g - gbar) * reparam_g),
      self.pre_temperature_variable)[0]

  # Gradient of the variance objective w.r.t. the pre-temperature, combining
  # the score-function and reparameterization terms.
  variance_objective_grad = (tf.reduce_mean(2 * (g - gbar) * reinf_g_t)
                             + reparam_g_t)

  debug = {
      'ELBO': hardELBO,
      'etas': eta_statistics,
      'variance_objective': variance_objective,
  }
  return total_grads, debug, variance_objective, variance_objective_grad
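
# --- Illustrative usage sketch (not part of the original code) ---
# A hypothetical example for the dynamic variant: in addition to the model and
# eta updates sketched above, the pre-temperature is updated with the
# explicitly constructed variance_objective_grad (the variance objective is
# not differentiated w.r.t. the temperature by autodiff alone). The optimizer
# choices and the `eta_vars` argument are assumptions.
def _example_dynamic_rebar_train_ops(self, eta_vars, learning_rate=1e-3):
  """Hypothetical sketch: train ops for dynamic REBAR (eta and t optimized)."""
  (total_grads, debug, variance_objective,
   variance_objective_grad) = self.get_dynamic_rebar_gradient()

  # Model parameters follow the REBAR gradient estimate.
  model_train_op = self.optimizer_class.apply_gradients(total_grads)

  # eta follows autodiff gradients of the variance objective.
  eta_train_op = tf.train.AdamOptimizer(learning_rate).minimize(
      variance_objective, var_list=eta_vars)

  # The pre-temperature receives its gradient directly as a (grad, var) pair.
  temperature_train_op = tf.train.AdamOptimizer(learning_rate).apply_gradients(
      [(variance_objective_grad, self.pre_temperature_variable)])

  return tf.group(model_train_op, eta_train_op, temperature_train_op), debug
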