def testGradientsThroughSample(self, process_name, data):
  tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess', process_name,
                                 'b/147770193')
  process = data.draw(
      stochastic_processes(process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  # Test that sample produces non-None gradients with respect to the
  # process variables (for fully reparameterized processes).
  with tf.GradientTape() as tape:
    sample = process.sample()
  if process.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, process.variables)
    for grad, var in zip(grads, process.variables):
      self.assertIsNotNone(
          grad,
          'Grad of sample was `None` for var: {}.'.format(var))
def testGradientsThroughLogProb(self, process_name, data):
  tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess', process_name,
                                 'b/147770193')
  process = data.draw(
      stochastic_processes(process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  # Test that log_prob produces non-None gradients.
  sample = process.sample()
  with tf.GradientTape() as tape:
    lp = process.log_prob(sample)
  grads = tape.gradient(lp, process.variables)
  for grad, var in zip(grads, process.variables):
    self.assertIsNotNone(
        grad,
        'Grad of log_prob was `None` for var: {}.'.format(var))
def testExcessiveConcretizationInLogProb(self, process_name, data):
  # Check that log_prob computations avoid reading process parameters
  # more than once.
  tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess', process_name,
                                 'b/147770193')
  process = data.draw(stochastic_processes(
      process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
  sample = process.sample()
  try:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `log_prob` of `{}`'.format(process),
        max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
      process.log_prob(sample)
  except NotImplementedError:
    pass
def testExcessiveConcretizationInZeroArgPublicMethods(
    self, process_name, data):
  tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess', process_name,
                                 'b/147770193')
  # Check that standard statistics do not concretize variables/deferred
  # tensors more than the allowed amount.
  process = data.draw(stochastic_processes(process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
    hp.note('Testing excessive concretization in {}.{}'.format(
        process_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `{}` of `{}`'.format(stat, process),
          max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
        getattr(process, stat)()
    except NotImplementedError:
      pass
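# `MAX_CONVERSIONS_BY_CLASS`, used in the two concretization tests above, is
# assumed to be a module-level dict mapping a process class name to the number
# of variable reads tolerated per method call, defaulting to 1 for anything
# not listed.  A minimal illustrative sketch; the entry below is an
# assumption, not the value actually used by this suite:
MAX_CONVERSIONS_BY_CLASS_SKETCH = {
    # Hypothetical entry: a process whose methods legitimately need to read
    # each underlying variable twice would be allowed a budget of 2 here.
    'GaussianProcessRegressionModel': 2,
}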
def testBijector(self, bijector_name, data):
  tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
  bijector, event_dim = self._draw_bijector(bijector_name, data)

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  xs = self._draw_domain_tensor(bijector, data, event_dim)
  wrt_vars = [xs] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # For scalar bijectors, verify correctness of the _is_increasing method.
  # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
  # of numerical problem.
  def exception(bijector):
    if not tfp_hps.running_under_guitar():
      return False
    if isinstance(bijector, tfb.Softfloor):
      return True
    if is_invert(bijector):
      return exception(bijector.bijector)
    return False

  if (bijector.forward_min_event_ndims == 0 and
      bijector.inverse_min_event_ndims == 0 and
      not exception(bijector)):
    dydx = grads[0]
    hp.note('dydx: {}'.format(dydx))
    isfinite = tf.math.is_finite(dydx)
    incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
        dydx, 0)  # pylint: disable=protected-access
    self.assertAllEqual(
        isfinite & incr_or_slope_eq0,
        isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=xs.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=True)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables.  Also check
  # that any variables are not referenced overmuch.
  ys = self._draw_codomain_tensor(bijector, data, event_dim)
  wrt_vars = [ys] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with
  # respect to the codomain "input" and parameter variables.  Also check
  # that any variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=ys.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                    is_forward=False)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

  # Verify that `_is_permutation` implies constant zero Jacobian.
  if bijector._is_permutation:
    self.assertTrue(bijector._is_constant_jacobian)
    self.assertAllEqual(ldj, 0.)

  # Verify correctness of batch shape.
  xs_batch_shapes = tf.nest.map_structure(
      lambda x, nd: ps.shape(x)[:ps.rank(x) - nd],
      xs,
      bijector.inverse_event_ndims(event_ndims))
  empirical_batch_shape = functools.reduce(
      ps.broadcast_shape,
      nest.flatten_up_to(bijector.forward_min_event_ndims, xs_batch_shapes))
  batch_shape = bijector.experimental_batch_shape(y_event_ndims=event_ndims)
  if tensorshape_util.is_fully_defined(batch_shape):
    self.assertAllEqual(empirical_batch_shape, batch_shape)
  self.assertAllEqual(
      empirical_batch_shape,
      bijector.experimental_batch_shape_tensor(y_event_ndims=event_ndims))

  # Check that the outputs of forward_dtype and inverse_dtype match the
  # dtypes of the outputs of forward and inverse.
  self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
  self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
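# `_ldj_tensor_conversions_allowed`, used for the FLDJ/ILDJ budgets above, is
# assumed to encapsulate the same logic that the older test below spells out
# inline: a bijector implementing the requested log-det-jacobian directly may
# convert each variable at most twice, a fallback through the caching
# machinery gets four, and `Invert`/`TransformDiagonal` are judged by their
# wrapped bijector.  A sketch under those assumptions, not necessarily the
# exact helper used here:
def _ldj_tensor_conversions_allowed_sketch(bijector, is_forward):
  if is_invert(bijector):
    # Invert swaps forward and inverse, so consult the wrapped bijector with
    # the direction flipped.
    return _ldj_tensor_conversions_allowed_sketch(
        bijector.bijector, not is_forward)
  elif is_transform_diagonal(bijector):
    # TransformDiagonal delegates the Jacobian to its diagonal bijector.
    return _ldj_tensor_conversions_allowed_sketch(
        bijector.diag_bijector, is_forward)
  elif is_forward:
    return 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
  else:
    return 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4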
def testBijector(self, bijector_name, data):
  tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  bijector = data.draw(
      bijectors(bijector_name=bijector_name, event_dim=event_dim,
                enable_vars=True))
  self.evaluate(tf.group(*[v.initializer for v in bijector.variables]))

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  # TODO(axch): Would be nice to get rid of all this shape inference logic
  # and just rely on a notion of batch and event shape for bijectors, so we
  # can pass those through `domain_tensors` and `codomain_tensors` and use
  # `tensors_in_support`.  However, `RationalQuadraticSpline` behaves
  # weirdly somehow and I got confused.
  codomain_event_shape = [event_dim] * bijector.inverse_min_event_ndims
  codomain_event_shape = constrain_inverse_shape(bijector,
                                                 codomain_event_shape)
  shp = bijector.inverse_event_shape(codomain_event_shape)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.forward_min_event_ndims])),
      shp[shp.ndims - bijector.forward_min_event_ndims:])
  xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)), name='xs')
  wrt_vars = [xs] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # For scalar bijectors, verify correctness of the _is_increasing method.
  if (bijector.forward_min_event_ndims == 0 and
      bijector.inverse_min_event_ndims == 0):
    dydx = grads[0]
    hp.note('dydx: {}'.format(dydx))
    isfinite = tf.math.is_finite(dydx)
    incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
        dydx, 0)  # pylint: disable=protected-access
    self.assertAllEqual(
        isfinite & incr_or_slope_eq0,
        isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables.  Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=xs.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_inverse_log_det_jacobian') else 4)
    elif is_transform_diagonal(bijector):
      max_permitted = (2 if hasattr(bijector.diag_bijector,
                                    '_forward_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables.  Also check
  # that any variables are not referenced overmuch.
  domain_event_shape = [event_dim] * bijector.forward_min_event_ndims
  domain_event_shape = constrain_forward_shape(bijector, domain_event_shape)
  shp = bijector.forward_event_shape(domain_event_shape)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.inverse_min_event_ndims])),
      shp[shp.ndims - bijector.inverse_min_event_ndims:])
  ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                   name='ys')
  wrt_vars = [ys] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with
  # respect to the codomain "input" and parameter variables.  Also check
  # that any variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=ys.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_forward_log_det_jacobian') else 4)
    elif is_transform_diagonal(bijector):
      max_permitted = (2 if hasattr(bijector.diag_bijector,
                                    '_inverse_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)
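# The `is_invert` and `is_transform_diagonal` predicates used in the budget
# logic above are assumed to be plain isinstance checks along the following
# lines (a sketch, not necessarily the exact definitions in this file):
def is_invert_sketch(bijector):
  return isinstance(bijector, tfb.Invert)


def is_transform_diagonal_sketch(bijector):
  return isinstance(bijector, tfb.TransformDiagonal)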