def run_test(use_sample_shape, use_pfor):
  with tf.Session(graph=tf.Graph()) as sess:
    sample_shape = [30, 30]
    X = tf.Variable(tf.random_normal(sample_shape))
    A = tf.Variable(tf.random_uniform(minval=1, maxval=10, shape=sample_shape))
    y = tf.multiply(X, A)
    init = tf.global_variables_initializer()
    sess.run(init)
    if use_pfor:
      # Now let's try to compute the jacobian
      if use_sample_shape:
        dydx = diag_jacobian_pfor(xs=X, ys=y, sample_shape=sample_shape,
                                  use_pfor=True)
      else:
        dydx = diag_jacobian_pfor(xs=X, ys=y, use_pfor=True)
    else:
      from tensorflow_probability.python.math import diag_jacobian
      if use_sample_shape:
        dydx = diag_jacobian(xs=X, ys=y, sample_shape=sample_shape)
      else:
        dydx = diag_jacobian(xs=X, ys=y)
    sess.run([dydx, A])
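# A minimal usage sketch (assumptions: a TF1-style graph/session environment
# with `import tensorflow as tf` already done and a `diag_jacobian_pfor`
# helper defined elsewhere): exercise `run_test` over all four flag
# combinations.
for use_sample_shape in (True, False):
  for use_pfor in (True, False):
    run_test(use_sample_shape=use_sample_shape, use_pfor=use_pfor)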
def _apply_noisy_update(self, mom, grad, var):
  # Compute and apply the gradient update following
  # preconditioned Langevin dynamics.
  stddev = tf.where(
      tf.squeeze(self._counter > self._burnin),
      tf.cast(tf.rsqrt(self._learning_rate), grad.dtype),
      tf.zeros([], grad.dtype))
  # Keep an exponentially weighted moving average of squared gradients.
  # Not thread safe.
  decay_tensor = tf.cast(self._decay_tensor, grad.dtype)
  new_mom = decay_tensor * mom + (1. - decay_tensor) * tf.square(grad)
  preconditioner = tf.rsqrt(
      new_mom + tf.cast(self._diagonal_bias, grad.dtype))
  # Compute gradients of the preconditioner.
  _, preconditioner_grads = diag_jacobian(
      xs=var,
      ys=preconditioner,
      parallel_iterations=self._parallel_iterations)
  mean = 0.5 * (
      preconditioner * grad * tf.cast(self._data_size, grad.dtype)
      - preconditioner_grads[0])
  stddev *= tf.sqrt(preconditioner)
  result_shape = tf.broadcast_dynamic_shape(tf.shape(mean), tf.shape(stddev))
  with tf.control_dependencies([tf.assign(mom, new_mom)]):
    return tf.random_normal(
        shape=result_shape, mean=mean, stddev=stddev, dtype=grad.dtype)
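# For reference, a compact restatement of the update implemented above (a
# summary of the code, not an independent derivation): the returned
# perturbation is drawn from N(mean, stddev^2) with
#   mean   = 0.5 * (G * grad * data_size - dG/dvar)
#   stddev = rsqrt(learning_rate) * sqrt(G)   (after burn-in, zero otherwise)
# where G = rsqrt(new_mom + diagonal_bias) is the diagonal preconditioner and
# dG/dvar is the diagonal Jacobian `preconditioner_grads[0]` computed by
# `diag_jacobian`.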
def testPreconditionerComputedCorrectly(self):
  """Test that SGLD step is computed correctly for a 3D Gaussian energy."""
  with self.cached_session(graph=tf.Graph()) as sess:
    tf.compat.v1.set_random_seed(42)
    dtype = np.float32
    # Target function is the energy function of a normal distribution.
    true_mean = dtype([0, 0, 0])
    true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
    # Target distribution is defined through the Cholesky decomposition.
    chol = tf.linalg.cholesky(true_cov)
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)
    var_1 = tf.compat.v1.get_variable('var_1', initializer=[1., 1.])
    var_2 = tf.compat.v1.get_variable('var_2', initializer=[1.])
    var = [var_1, var_2]
    # Set up the learning rate and the optimizer.
    learning_rate = .5
    optimizer_kernel = tfp.optimizer.StochasticGradientLangevinDynamics(
        learning_rate=learning_rate, burnin=1)

    # Target function.
    def target_fn(x, y):
      # Stack the input tensors together.
      z = tf.concat([x, y], axis=-1) - true_mean
      return -target.log_prob(z)

    grads = tf.gradients(ys=target_fn(*var), xs=var)
    # Update value of `var` with one iteration of the SGLD (without the
    # normal perturbation, since `burnin > 0`).
    step = optimizer_kernel.apply_gradients(zip(grads, var))

    # True theoretical value of `var` after one iteration.
    decay_tensor = tf.cast(optimizer_kernel._decay_tensor, var[0].dtype)
    diagonal_bias = tf.cast(optimizer_kernel._diagonal_bias, var[0].dtype)
    learning_rate = tf.cast(optimizer_kernel._learning_rate, var[0].dtype)
    velocity = [(decay_tensor * tf.ones_like(v)
                 + (1 - decay_tensor) * tf.square(g))
                for v, g in zip(var, grads)]
    preconditioner = [tf.math.rsqrt(vel + diagonal_bias) for vel in velocity]
    # Compute second order gradients.
    _, grad_grads = diag_jacobian(xs=var, ys=grads)
    # Compute gradient of the preconditioner (compute the gradient manually).
    preconditioner_grads = [-(g * g_g * (1. - decay_tensor) * p**3.)
                            for g, g_g, p in zip(grads, grad_grads,
                                                 preconditioner)]
    # True theoretical value of `var` after one iteration.
    var_true = [v - learning_rate * 0.5 * (p * g - p_g)
                for v, p, g, p_g in zip(var, preconditioner, grads,
                                        preconditioner_grads)]

    sess.run(tf.compat.v1.initialize_all_variables())
    var_true_ = sess.run(var_true)
    sess.run(step)
    var_ = sess.run(var)  # New `var` after one SGLD step.
    self.assertAllClose(var_true_, var_, atol=0.001, rtol=0.001)
def _maybe_call_volatility_fn_and_grads(volatility_fn,
                                        state,
                                        volatility_fn_results=None,
                                        grads_volatility_fn=None,
                                        sample_shape=None,
                                        parallel_iterations=10):
  """Helper which computes `volatility_fn` results and grads, if needed."""
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  needs_volatility_fn_gradients = grads_volatility_fn is None

  # Convert `volatility_fn_results` to a list.
  if volatility_fn_results is None:
    volatility_fn_results = volatility_fn(*state_parts)
  volatility_fn_results = (list(volatility_fn_results)
                           if mcmc_util.is_list_like(volatility_fn_results)
                           else [volatility_fn_results])
  if len(volatility_fn_results) == 1:
    volatility_fn_results *= len(state_parts)
  if len(state_parts) != len(volatility_fn_results):
    raise ValueError('`volatility_fn` should return a tensor or a list '
                     'of the same length as `current_state`.')

  # The shape of `volatility_parts` needs to have the number of chains as a
  # leading dimension. For determinism we broadcast `volatility_parts` to the
  # shape of `state_parts` since each dimension of `state_parts` could have a
  # different volatility value.
  volatility_fn_results = _maybe_broadcast_volatility(volatility_fn_results,
                                                      state_parts)
  if grads_volatility_fn is None:
    [
        _,
        grads_volatility_fn,
    ] = diag_jacobian(
        xs=state_parts,
        ys=volatility_fn_results,
        sample_shape=sample_shape,
        parallel_iterations=parallel_iterations,
        fn=volatility_fn)

  # Compute gradient of `volatility_parts**2`.
  if needs_volatility_fn_gradients:
    grads_volatility_fn = [
        2. * g * volatility if g is not None else tf.zeros_like(
            fn_arg, dtype=fn_arg.dtype.base_dtype)
        for g, volatility, fn_arg in zip(
            grads_volatility_fn, volatility_fn_results, state_parts)
    ]

  return volatility_fn_results, grads_volatility_fn
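# A minimal check (assumptions: eager-mode TF2; `sigma` is an arbitrary
# positive toy volatility) of the chain-rule identity the helper relies on
# when `needs_volatility_fn_gradients` is True:
#   d(sigma(x)**2)/dx == 2 * sigma(x) * d(sigma(x))/dx.
import tensorflow as tf

x = tf.constant([0.3, 1.7])
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  sigma = tf.exp(0.5 * x)        # toy volatility function
  sigma_sq = tf.square(sigma)
grad_sigma = tape.gradient(sigma, x)        # d sigma / dx
grad_sigma_sq = tape.gradient(sigma_sq, x)  # d sigma**2 / dx
tf.debugging.assert_near(grad_sigma_sq, 2. * sigma * grad_sigma)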
def augmented_ode_fn(time, state_log_det_jac):
  """Computes both time derivative and trace of the jacobian."""
  state, _ = state_log_det_jac
  ode_fn_with_time = lambda x: ode_fn(time, x)
  batch_shape = [prefer_static.size0(state)]
  state_time_derivative, diag_jac = tfp_math.diag_jacobian(
      xs=state, fn=ode_fn_with_time, sample_shape=batch_shape)
  # tfp_math.diag_jacobian returns lists.
  if isinstance(state_time_derivative, list):
    state_time_derivative = state_time_derivative[0]
  if isinstance(diag_jac, list):
    diag_jac = diag_jac[0]
  trace_value = diag_jac
  return state_time_derivative, trace_value
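# A minimal check (assumptions: eager-mode TF2 with TFP installed; the toy
# `fn` stands in for `ode_fn_with_time`) that `tfp.math.diag_jacobian`
# returns the elementwise diagonal d y[i] / d x[i], i.e. the per-dimension
# terms whose sum gives the Jacobian trace used for the log-det update above.
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([[1.0, 2.0, 3.0]])
fn = lambda z: tf.sin(z) * z                 # elementwise toy dynamics
_, diag = tfp.math.diag_jacobian(xs=x, fn=fn, sample_shape=[1])
diag = diag[0] if isinstance(diag, list) else diag
with tf.GradientTape() as tape:
  tape.watch(x)
  y = fn(x)
full_jac = tape.batch_jacobian(y, x)         # shape [1, 3, 3]
tf.debugging.assert_near(diag, tf.linalg.diag_part(full_jac))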
def augmented_ode_fn(time, state_log_det_jac):
  """Computes both time derivative and trace of the jacobian."""
  state, _ = state_log_det_jac
  ode_fn_with_time = lambda x: ode_fn(time, x)
  batch_shape = [prefer_static.size0(state)]

  # Only watch the tensors needed for the requested regularization terms.
  if dv_dt_reg is not None:
    watched_vars = [time, state]
  elif (kinetic_reg is not None) or (jacobian_reg is not None):
    watched_vars = [state]
  else:
    watched_vars = []

  with tf.GradientTape(watch_accessed_variables=False, persistent=True) as g:
    g.watch(watched_vars)
    state_time_derivative, diag_jac = tfp_math.diag_jacobian(
        xs=state, fn=ode_fn_with_time, sample_shape=batch_shape)
    # tfp_math.diag_jacobian returns lists.
    if isinstance(state_time_derivative, list):
      state_time_derivative = state_time_derivative[0]
    if isinstance(diag_jac, list):
      diag_jac = diag_jac[0]
    trace_value = diag_jac

  # Calculate regularization terms.
  if (dv_dt_reg is not None) or (jacobian_reg is not None):
    delv_delx = g.batch_jacobian(state_time_derivative, state)
  if dv_dt_reg is not None:
    delv_delt = g.gradient(state_time_derivative, time)
    vnabla_v = tf.linalg.matvec(delv_delx, state_time_derivative)
    dv_dt = delv_delt + vnabla_v
    trace_value = trace_value - dv_dt_reg * dv_dt**2
  if kinetic_reg is not None:
    trace_value = trace_value - kinetic_reg * state_time_derivative**2
  if jacobian_reg is not None:
    jacobian_norm2 = tf.math.reduce_sum(delv_delx**2, axis=-1)
    trace_value = trace_value - jacobian_reg * jacobian_norm2

  return state_time_derivative, trace_value
def _apply_noisy_update(self, mom, grad, var, indices=None):
  # Compute and apply the gradient update following
  # preconditioned Langevin dynamics.
  stddev = tf1.where(
      tf.squeeze(self.iterations > tf.cast(self._burnin, tf.int64)),
      tf.cast(tf.math.rsqrt(self._learning_rate), grad.dtype),
      tf.zeros([], grad.dtype))
  # Keep an exponentially weighted moving average of squared gradients.
  # Not thread safe.
  decay_tensor = tf.cast(self._decay_tensor, grad.dtype)
  new_mom = decay_tensor * mom + (1. - decay_tensor) * tf.square(grad)
  preconditioner = tf.math.rsqrt(
      new_mom + tf.cast(self._diagonal_bias, grad.dtype))
  # Compute gradients of the preconditioner.
  # Note: Since the preconditioner depends indirectly on `var` through `grad`,
  # in Eager mode, `diag_jacobian` would need access to the loss function.
  # This is the only blocker to supporting Eager mode for the SGLD optimizer.
  _, preconditioner_grads = diag_jacobian(
      xs=var,
      ys=preconditioner,
      parallel_iterations=self._parallel_iterations)
  mean = 0.5 * (
      preconditioner * grad * tf.cast(self._data_size, grad.dtype)
      - preconditioner_grads[0])
  stddev *= tf.sqrt(preconditioner)
  result_shape = tf.broadcast_dynamic_shape(tf.shape(input=mean),
                                            tf.shape(input=stddev))
  update_ops = []
  if indices is None:
    update_ops.append(mom.assign(new_mom))
  else:
    update_ops.append(self._resource_scatter_update(mom, indices, new_mom))
  with tf.control_dependencies(update_ops):
    return tf.random.normal(
        shape=result_shape, mean=mean, stddev=stddev, dtype=grad.dtype)
def testPreconditionerComputedCorrectly(self):
  """Test that SGLD step is computed correctly for a 3D Gaussian energy."""
  with self.test_session(graph=tf.Graph()) as sess:
    tf.set_random_seed(42)
    dtype = np.float32
    # Target function is the energy function of a normal distribution.
    true_mean = dtype([0, 0, 0])
    true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
    # Target distribution is defined through the Cholesky decomposition.
    chol = tf.linalg.cholesky(true_cov)
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)
    var_1 = tf.get_variable('var_1', initializer=[1., 1.])
    var_2 = tf.get_variable('var_2', initializer=[1.])
    var = [var_1, var_2]
    # Set up the learning rate and the optimizer.
    learning_rate = .5
    optimizer_kernel = tfp.optimizer.StochasticGradientLangevinDynamics(
        learning_rate=learning_rate, burnin=1)

    # Target function.
    def target_fn(x, y):
      # Stack the input tensors together.
      z = tf.concat([x, y], axis=-1) - true_mean
      return -target.log_prob(z)

    grads = tf.gradients(target_fn(*var), var)
    # Update value of `var` with one iteration of the SGLD (without the
    # normal perturbation, since `burnin > 0`).
    step = optimizer_kernel.apply_gradients(zip(grads, var))

    # True theoretical value of `var` after one iteration.
    decay_tensor = tf.cast(optimizer_kernel._decay_tensor, var[0].dtype)
    diagonal_bias = tf.cast(optimizer_kernel._diagonal_bias, var[0].dtype)
    learning_rate = tf.cast(optimizer_kernel._learning_rate, var[0].dtype)
    velocity = [(decay_tensor * tf.ones_like(v)
                 + (1 - decay_tensor) * tf.square(g))
                for v, g in zip(var, grads)]
    preconditioner = [tf.rsqrt(vel + diagonal_bias) for vel in velocity]
    # Compute second order gradients.
    _, grad_grads = diag_jacobian(xs=var, ys=grads)
    # Compute gradient of the preconditioner (compute the gradient manually).
    preconditioner_grads = [-(g * g_g * (1. - decay_tensor) * p**3.)
                            for g, g_g, p in zip(grads, grad_grads,
                                                 preconditioner)]
    # True theoretical value of `var` after one iteration.
    var_true = [v - learning_rate * 0.5 * (p * g - p_g)
                for v, p, g, p_g in zip(var, preconditioner, grads,
                                        preconditioner_grads)]

    sess.run(tf.initialize_all_variables())
    var_true_ = sess.run(var_true)
    sess.run(step)
    var_ = sess.run(var)  # New `var` after one SGLD step.
    self.assertAllClose(var_true_, var_, atol=0.001, rtol=0.001)