def testExcessiveConcretizationOfParams(self):
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
  bij_scale = tfp_hps.defer_and_count_usage(
      tf.Variable(2., name='bij_scale', dtype=tf.float32, shape=self.shape))
  event_shape = tfp_hps.defer_and_count_usage(
      tf.Variable([2, 2], name='input_event_shape', dtype=tf.int32,
                  shape=self.shape))
  batch_shape = tfp_hps.defer_and_count_usage(
      tf.Variable([4, 3, 5], name='input_batch_shape', dtype=tf.int32,
                  shape=self.shape))

  dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
      bijector=tfb.Scale(scale=bij_scale, validate_args=True),
      event_shape=event_shape,
      batch_shape=batch_shape,
      validate_args=True)

  for method in ('mean', 'entropy', 'event_shape_tensor',
                 'batch_shape_tensor'):
    with tfp_hps.assert_no_excessive_var_usage(
        method, max_permissible=self.max_permissible[method]):
      getattr(dist, method)()

  with tfp_hps.assert_no_excessive_var_usage(
      'sample', max_permissible=self.max_permissible['sample']):
    dist.sample(seed=test_util.test_seed())

  for method in ('log_prob', 'prob'):
    with tfp_hps.assert_no_excessive_var_usage(
        method, max_permissible=self.max_permissible[method]):
      getattr(dist, method)(np.ones((4, 3, 5, 2, 2)) / 3.)

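# Illustrative sketch (not part of the test suites above): the pattern these
# tests rely on is to wrap each `tf.Variable` in
# `tfp_hps.defer_and_count_usage`, which counts every conversion of the
# variable to a Tensor, and then to run a distribution method under
# `tfp_hps.assert_no_excessive_var_usage`, which fails if any wrapped variable
# is concretized more than `max_permissible` times inside the block. The
# distribution, variable name, and count below are hypothetical; the `tf`,
# `tfd`, and `tfp_hps` aliases are assumed to be the same ones imported by the
# surrounding test files.
def _sketch_usage_counting():
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(0., name='loc', shape=tf.TensorShape(None)))
  dist = tfd.Normal(loc=loc, scale=1., validate_args=True)
  # With dynamic shape and `validate_args=True`, `mean()` is expected to read
  # `loc` at most twice: once in validation / shape inference and once in the
  # mean computation itself.
  with tfp_hps.assert_no_excessive_var_usage('mean', max_permissible=2):
    dist.mean()
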
def testExcessiveConcretizationWithDefaultReinterpretedBatchNdims(self):
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((5, 2, 3)), shape=tf.TensorShape(None)))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
  dist = tfd.Independent(
      tfd.Logistic(loc=loc, scale=scale, validate_args=True),
      reinterpreted_batch_ndims=None, validate_args=True)

  for method in ('batch_shape_tensor', 'event_shape_tensor',
                 'mean', 'variance', 'sample'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
      getattr(dist, method)()

  # In addition to the four reads of `loc`, `scale` described above in
  # `testExcessiveConcretizationOfParams`, the methods below have two more
  # reads of these parameters -- from computing a default value for
  # `reinterpreted_batch_ndims`, which requires calling
  # `dist.distribution.batch_shape_tensor()`.
  for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=6):
      getattr(dist, method)(np.zeros((4, 5, 2, 3)))

  with tfp_hps.assert_no_excessive_var_usage('entropy', max_permissible=6):
    dist.entropy()

  # `Distribution.survival_function` and `Distribution.log_survival_function`
  # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
  # one additional call to `Independent._parameter_control_dependencies`,
  # and thus two additional concretizations of the parameters.
  for method in ('survival_function', 'log_survival_function'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=8):
      getattr(dist, method)(np.zeros((4, 5, 2, 3)))

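# Illustrative sketch (an assumption, not code from `Independent` itself): the
# two extra reads counted above come from deriving the default
# `reinterpreted_batch_ndims` from the underlying distribution's batch rank,
# roughly in the spirit of the helper below, which concretizes `loc` and
# `scale` once more each through `batch_shape_tensor()`. The helper name is
# hypothetical; `tf` is assumed to be the alias imported by the surrounding
# test file.
def _sketch_default_reinterpreted_batch_ndims(independent_dist):
  # Treat all but the last underlying batch dimension as part of the event.
  batch_ndims = tf.size(independent_dist.distribution.batch_shape_tensor())
  return tf.maximum(0, batch_ndims - 1)
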
def testExcessiveConcretizationOfParamsBatchShapeOverride(self):
  # Test methods that are not implemented if event_shape is overridden.
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
  bij_scale = tfp_hps.defer_and_count_usage(
      tf.Variable(2., name='bij_scale', dtype=tf.float32, shape=self.shape))
  batch_shape = tfp_hps.defer_and_count_usage(
      tf.Variable([4, 3, 5], name='input_batch_shape', dtype=tf.int32,
                  shape=self.shape))

  dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
      bijector=tfb.Scale(scale=bij_scale, validate_args=True),
      batch_shape=batch_shape,
      validate_args=True)

  for method in ('log_cdf', 'cdf', 'survival_function',
                 'log_survival_function'):
    with tfp_hps.assert_no_excessive_var_usage(
        method, max_permissible=self.max_permissible[method]):
      getattr(dist, method)(np.ones((4, 3, 2)) / 3.)

  with tfp_hps.assert_no_excessive_var_usage(
      'quantile', max_permissible=self.max_permissible['quantile']):
    dist.quantile(.1)

def testExcessiveConcretizationOfParams(self):
  logits = tfp_hps.defer_and_count_usage(
      self._build_variable(np.zeros((4, 4, 5)), name='logits'))
  concentration = tfp_hps.defer_and_count_usage(
      self._build_variable(np.zeros((4, 4, 5, 3)), name='concentration'))
  dist = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(logits=logits),
      components_distribution=tfd.Dirichlet(concentration=concentration),
      validate_args=True)

  # Many methods use mixture_distribution and components_distribution at most
  # once, and thus incur no extra reads/concretizations of parameters.
  for method in ('batch_shape_tensor', 'event_shape_tensor', 'mean'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)()

  with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=2):
    dist.sample(seed=test_util.test_seed())

  for method in ('log_prob', 'prob'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)(np.ones((4, 4, 3)) / 3.)

  # TODO(b/140579567): The `variance()` and `covariance()` methods require
  # calling both:
  #  - `self.components_distribution.mean()`
  #  - `self.components_distribution.variance()` or `.covariance()`
  # Thus, these methods incur an additional concretization (or two if
  # `validate_args=True` for `self.components_distribution`).
  for method in ('variance', 'covariance'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=3):
      getattr(dist, method)()

def testExcessiveConcretizationOfParams(self):
  logits = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((3, 5, 2)), dtype=tf.float32,
                  shape=tf.TensorShape([None, None, 2]), name='logits'))
  concentration = tfp_hps.defer_and_count_usage(
      tf.Variable(np.ones((3, 5, 4)), dtype=tf.float32,
                  shape=tf.TensorShape(None), name='concentration'))
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((3, 5, 4)), dtype=tf.float32,
                  shape=tf.TensorShape(None), name='loc'))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(1., dtype=tf.float32,
                  shape=tf.TensorShape(None), name='scale'))

  dist = tfd.Mixture(
      tfd.Categorical(logits=logits),
      components=[
          tfd.Dirichlet(concentration),
          tfd.Independent(tfd.Normal(loc=loc, scale=scale),
                          reinterpreted_batch_ndims=1)
      ],
      use_static_graph=self.use_static_graph,
      validate_args=True)

  for method in ('batch_shape_tensor', 'event_shape_tensor',
                 'entropy_lower_bound'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)()

  with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=2):
    dist.sample(seed=test_util.test_seed())

  for method in ('prob', 'log_prob'):
    # Pass the method name (not the literal string 'method') so failures are
    # reported against the method actually being exercised.
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
      getattr(dist, method)(tf.ones((3, 5, 4)) / 4.)

  # TODO(b/140579567): The `stddev()` and `variance()` methods require
  # calling both:
  #  - `self.components[i].mean()`
  #  - `self.components[i].stddev()`
  # Thus, these methods incur an additional concretization (or two if
  # `validate_args=True` for `self.components[i]`).
  for method in ('stddev', 'variance'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=3):
      getattr(dist, method)()

def testConcretizationLimits(self):
  shape_out = tfp_hps.defer_and_count_usage(tf.Variable([1]))
  reshape = tfb.Reshape(shape_out, validate_args=True)
  x = [1]  # Pun: valid input or output, and valid input or output shape
  for method in ['forward', 'inverse', 'forward_event_shape',
                 'inverse_event_shape', 'forward_event_shape_tensor',
                 'inverse_event_shape_tensor']:
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=7):
      getattr(reshape, method)(x)
  for method in ['forward_log_det_jacobian', 'inverse_log_det_jacobian']:
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
      getattr(reshape, method)(x, event_ndims=1)

def testExcessiveConcretizationOfParams(self):
  loc = tfp_hps.defer_and_count_usage(
      tf.Variable(np.zeros((4, 2, 2)), shape=tf.TensorShape(None)))
  scale = tfp_hps.defer_and_count_usage(
      tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
  ndims = tf.Variable(1, trainable=False, shape=tf.TensorShape(None))
  dist = tfd.Independent(
      tfd.Logistic(loc=loc, scale=scale, validate_args=True),
      reinterpreted_batch_ndims=ndims, validate_args=True)

  # TODO(b/140579567): All methods of `dist` may require four concretizations
  # of parameters `loc` and `scale`:
  #  - `Independent._parameter_control_dependencies` calls
  #    `Logistic.batch_shape_tensor`, which:
  #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
  #    * Reads `loc`, `scale` in `Logistic._batch_shape_tensor`.
  #  - A method `dist.m` delegates to `dist.distribution.m`, which:
  #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
  #    * Reads `loc`, `scale` in the implementation of method `Logistic._m`.
  #
  # NOTE: If `dist.distribution` had dynamic batch shape and event shape,
  # there could be two more reads of the parameters of `dist.distribution`
  # in `dist.event_shape_tensor`, from calling
  # `dist.distribution.event_shape_tensor()`.
  for method in ('batch_shape_tensor', 'event_shape_tensor',
                 'mode', 'stddev', 'entropy'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
      getattr(dist, method)()

  with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=4):
    dist.sample(seed=test_util.test_seed())

  for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
      getattr(dist, method)(np.zeros((3, 4, 2, 2)))

  # `Distribution.survival_function` and `Distribution.log_survival_function`
  # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
  # one additional call to `Independent._parameter_control_dependencies`,
  # and thus two additional concretizations of the parameters.
  for method in ('survival_function', 'log_survival_function'):
    with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=6):
      getattr(dist, method)(np.zeros((3, 4, 2, 2)))

def testExcessiveConcretizationOfParamsWithReparameterization(self):
  logits = tfp_hps.defer_and_count_usage(self._build_variable(
      np.zeros(5), name='logits', static_rank=True))
  loc = tfp_hps.defer_and_count_usage(self._build_variable(
      np.zeros((4, 4, 5)), name='loc', static_rank=True))
  scale = tfp_hps.defer_and_count_usage(self._build_variable(
      1., name='scale', static_rank=True))
  dist = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(logits=logits),
      components_distribution=tfd.Logistic(loc=loc, scale=scale),
      reparameterize=True, validate_args=True)

  # TODO(b/140579567): With reparameterization, there are additional reads of
  # the parameters of the underlying mixture and components distributions
  # when sampling, from calls in `_distributional_transform` to:
  #
  #  - `self.mixture_distribution.logits_parameter`
  #  - `self.components_distribution.log_prob`
  #  - `self.components_distribution.cdf`
  #
  # NOTE: In the unlikely case that samples have a statically-known rank but
  # the rank of `self.components_distribution.event_shape` is not known
  # statically, there can be additional reads in `_distributional_transform`
  # from calling `self.components_distribution.is_scalar_event`.
  with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=4):
    dist.sample(seed=test_util.test_seed())

def testKernelGradient(self, kernel_name, data):
  event_dim = data.draw(hps.integers(min_value=2, max_value=3))
  feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
  feature_dim = data.draw(hps.integers(min_value=2, max_value=4))
  batch_shape = data.draw(tfp_hps.shapes(max_ndims=2))

  kernel, kernel_parameter_variable_names = data.draw(
      kernel_hps.kernels(
          batch_shape=batch_shape,
          kernel_name=kernel_name,
          event_dim=event_dim,
          feature_dim=feature_dim,
          feature_ndims=feature_ndims,
          enable_vars=True))

  # Check that variable parameters get passed to the kernel.variables
  kernel_variables_names = [
      v.name.strip('_0123456789:') for v in kernel.variables]
  kernel_parameter_variable_names = [
      n.strip('_0123456789:') for n in kernel_parameter_variable_names]
  self.assertEqual(
      set(kernel_parameter_variable_names),
      set(kernel_variables_names))

  example_ndims = data.draw(hps.integers(min_value=1, max_value=2))
  input_batch_shape = data.draw(tfp_hps.broadcast_compatible_shape(
      kernel.batch_shape))
  xs = tf.identity(data.draw(kernel_hps.kernel_input(
      batch_shape=input_batch_shape,
      example_ndims=example_ndims,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims)))

  # Check that we pick up all relevant kernel parameters.
  wrt_vars = [xs] + list(kernel.variables)
  self.evaluate([v.initializer for v in kernel.variables])

  max_permissible = 2 + EXTRA_TENSOR_CONVERSION_KERNELS.get(kernel_name, 0)

  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `apply` of {}'.format(kernel),
        max_permissible=max_permissible):
      tape.watch(wrt_vars)
      with tfp_hps.no_tf_rank_errors():
        diag = kernel.apply(xs, xs, example_ndims=example_ndims)
  grads = tape.gradient(diag, wrt_vars)
  assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

  # Check that copying the kernel works.
  with tfp_hps.no_tf_rank_errors():
    diag2 = self.evaluate(kernel.copy().apply(
        xs, xs, example_ndims=example_ndims))
  self.assertAllClose(diag, diag2)

def testExcessiveConcretizationInLogProb(self, process_name, data):
  # Check that log_prob computations avoid reading process parameters
  # more than once.
  process = data.draw(stochastic_processes(
      process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
  sample = process.sample()
  try:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `log_prob` of `{}`'.format(process),
        max_permissible=excessive_usage_count(process_name)):
      process.log_prob(sample)
  except NotImplementedError:
    pass

def testKernelGradient(self, kernel_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  feature_ndims = data.draw(hps.integers(min_value=1, max_value=4))
  feature_dim = data.draw(hps.integers(min_value=2, max_value=6))

  kernel, kernel_parameter_variable_names = data.draw(
      kernel_hps.kernels(
          kernel_name=kernel_name,
          event_dim=event_dim,
          feature_dim=feature_dim,
          feature_ndims=feature_ndims,
          enable_vars=True))

  # Check that variable parameters get passed to the kernel.variables
  kernel_variables_names = [
      v.name.strip('_0123456789:') for v in kernel.variables]
  self.assertEqual(
      set(kernel_parameter_variable_names),
      set(kernel_variables_names))

  example_ndims = data.draw(hps.integers(min_value=1, max_value=3))
  input_batch_shape = data.draw(tfp_hps.broadcast_compatible_shape(
      kernel.batch_shape))
  xs = tf.identity(data.draw(kernel_hps.kernel_input(
      batch_shape=input_batch_shape,
      example_ndims=example_ndims,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims)))

  # Check that we pick up all relevant kernel parameters.
  wrt_vars = [xs] + list(kernel.variables)
  self.evaluate([v.initializer for v in kernel.variables])

  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `apply` of {}'.format(kernel)):
      tape.watch(wrt_vars)
      diag = kernel.apply(xs, xs, example_ndims=example_ndims)
  grads = tape.gradient(diag, wrt_vars)
  assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

  self.assertAllClose(
      diag,
      type(kernel)(**kernel._parameters).apply(
          xs, xs, example_ndims=example_ndims))

def testExcessiveConcretizationInZeroArgPublicMethods(
    self, process_name, data):
  # Check that standard statistics do not concretize variables/deferred
  # tensors more than the allowed amount.
  process = data.draw(stochastic_processes(process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
    hp.note('Testing excessive concretization in {}.{}'.format(
        process_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `{}` of `{}`'.format(stat, process),
          max_permissible=excessive_usage_count(process_name)):
        getattr(process, stat)()
    except NotImplementedError:
      pass

def testExcessiveConcretizationInLogProb(self, process_name, data):
  # Check that log_prob computations avoid reading process parameters
  # more than once.
  tfp_hps.guitar_skip_if_matches(
      'VariationalGaussianProcess', process_name, 'b/147770193')
  process = data.draw(stochastic_processes(
      process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
  sample = process.sample()
  try:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `log_prob` of `{}`'.format(process),
        max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
      process.log_prob(sample)
  except NotImplementedError:
    pass

def testExcessiveConcretizationInZeroArgPublicMethods(
    self, process_name, data):
  tfp_hps.guitar_skip_if_matches(
      'VariationalGaussianProcess', process_name, 'b/147770193')
  # Check that standard statistics do not concretize variables/deferred
  # tensors more than the allowed amount.
  process = data.draw(stochastic_processes(process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
    hp.note('Testing excessive concretization in {}.{}'.format(
        process_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `{}` of `{}`'.format(stat, process),
          max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)
          ), kernel_hps.no_pd_errors():
        getattr(process, stat)()
    except NotImplementedError:
      pass

def testDistribution(self, dist_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  tf1.set_random_seed(
      data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
  dist = data.draw(distributions(dist_name=dist_name, enable_vars=True))
  batch_shape = dist.batch_shape
  batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2 = data.draw(
      distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=get_event_dim(dist),
          enable_vars=True))
  logging.info(
      'distribution: %s; parameters used: %s', dist,
      [k for k, v in six.iteritems(dist.parameters) if v is not None])
  self.evaluate([var.initializer for var in dist.variables])

  # Check that the distribution passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(getattr(dist, k), v)

  # Check that standard statistics do not read distribution parameters more
  # than once.
  for stat in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, stat)
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist)):
        getattr(dist, stat)()
    except NotImplementedError:
      pass

  # Check that `sample` doesn't read distribution parameters more than once,
  # and that it produces non-None gradients (if the distribution is fully
  # reparameterized).
  with tf.GradientTape() as tape:
    # TDs do bijector assertions twice (once by distribution.sample, and once
    # by bijector.forward).
    max_permissible = (
        3 if isinstance(dist, tfd.TransformedDistribution) else 2)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist),
        max_permissible=max_permissible):
      sample = dist.sample()
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
  # own samples. Also, to relax conversion counts for KL (might do >2 w/
  # validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  # Test that KL divergence reads distribution parameters at most once, and
  # that it produces non-None gradients.
  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError(
              'Missing KL({} || {}) -> {} grad:\n'
              '{} vars: {}\n{} vars: {}'.format(
                  d1, d2, var, d1, d1.variables, d2, d2.variables))
  except NotImplementedError:
    pass

  # Test that log_prob produces non-None gradients, except for distributions
  # on the NO_LOG_PROB_PARAM_GRADS blacklist.
  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  # Test that all forms of probability evaluation avoid reading distribution
  # parameters more than once.
  for evaluative in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, evaluative)
    try:
      # No validation => 1 convert. But for TD we allow 2:
      # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
      max_permissible = (
          2 if isinstance(dist, tfd.TransformedDistribution) else 1)
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=max_permissible):
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass

def testBijector(self, bijector_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  bijector = data.draw(
      bijectors(bijector_name=bijector_name, event_dim=event_dim,
                enable_vars=True))

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  # TODO(axch): Would be nice to get rid of all this shape inference logic and
  # just rely on a notion of batch and event shape for bijectors, so we can
  # pass those through `domain_tensors` and `codomain_tensors` and use
  # `tensors_in_support`. However, `RationalQuadraticSpline` behaves weirdly
  # somehow and I got confused.
  shp = bijector.inverse_event_shape([event_dim] *
                                     bijector.inverse_min_event_ndims)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.forward_min_event_ndims])),
      shp[shp.ndims - bijector.forward_min_event_ndims:])
  xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)), name='xs')
  wrt_vars = [xs] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(
          min_value=bijector.forward_min_event_ndims,
          max_value=bijector.forward_event_shape(xs.shape).ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_inverse_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables. Also check that
  # any variables are not referenced overmuch.
  shp = bijector.forward_event_shape([event_dim] *
                                     bijector.forward_min_event_ndims)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.inverse_min_event_ndims])),
      shp[shp.ndims - bijector.inverse_min_event_ndims:])
  ys = tf.identity(
      data.draw(codomain_tensors(bijector, shape=shp)), name='ys')
  wrt_vars = [ys] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with respect
  # to the codomain "input" and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(
          min_value=bijector.inverse_min_event_ndims,
          max_value=bijector.inverse_event_shape(ys.shape).ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_forward_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

def testBijector(self, bijector_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  bijector, batch_shape = data.draw(
      bijectors(bijector_name=bijector_name, enable_vars=True))
  del batch_shape
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))

  # Forward mapping.
  shp = bijector.inverse_event_shape([event_dim] *
                                     bijector.inverse_min_event_ndims)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.forward_min_event_ndims])),
      shp[shp.ndims - bijector.forward_min_event_ndims:])
  xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)), name='xs')
  wrt_vars = [xs] + list(bijector.variables)
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # FLDJ.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=bijector.forward_event_shape(xs.shape).ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_inverse_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping.
  shp = bijector.forward_event_shape([event_dim] *
                                     bijector.forward_min_event_ndims)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.inverse_min_event_ndims])),
      shp[shp.ndims - bijector.inverse_min_event_ndims:])
  ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                   name='ys')
  wrt_vars = [ys] + list(bijector.variables)
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=bijector.inverse_event_shape(ys.shape).ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_forward_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

def testProcess(self, process_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  seed = tfp_test_util.test_seed()
  process = data.draw(
      stochastic_processes(process_name=process_name, enable_vars=True))
  self.evaluate([var.initializer for var in process.variables])

  # Check that the process passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  for k, v in six.iteritems(process.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(getattr(process, k), v)

  # Check that standard statistics do not read process parameters more
  # than twice (once in the stat itself and up to once in any validation
  # assertions).
  for stat in ['mean', 'covariance', 'stddev', 'variance']:
    hp.note('Testing excessive var usage in {}.{}'.format(
        process_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, process),
          max_permissible=excessive_usage_count(process_name)):
        getattr(process, stat)()
    except NotImplementedError:
      pass

  # Check that `sample` doesn't read process parameters more than twice,
  # and that it produces non-None gradients (if the process is fully
  # reparameterized).
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(process),
        max_permissible=excessive_usage_count(process_name)):
      sample = process.sample(seed=seed)
  if process.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, process.variables)
    for grad, var in zip(grads, process.variables):
      var_name = var.name.rstrip('_0123456789:')
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for process {}'.format(
                var_name, process_name))

  # Test that log_prob produces non-None gradients.
  with tf.GradientTape() as tape:
    lp = process.log_prob(tf.stop_gradient(sample))
  grads = tape.gradient(lp, process.variables)
  for grad, var in zip(grads, process.variables):
    if grad is None:
      raise AssertionError(
          'Missing log_prob -> {} grad for process {}'.format(
              var, process_name))

  # Check that log_prob computations avoid reading process parameters
  # more than once.
  hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
  try:
    with tfp_hps.assert_no_excessive_var_usage(
        'evaluative `log_prob` of `{}`'.format(process),
        max_permissible=excessive_usage_count(process_name)):
      process.log_prob(sample)
  except NotImplementedError:
    pass

def testDistribution(self, dist_name, data):
  seed = test_util.test_seed()
  # Explicitly draw event_dim here to avoid relying on _params_event_ndims
  # later, so this test can support distributions that do not implement the
  # slicing protocol.
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  dist = data.draw(dhps.distributions(
      dist_name=dist_name, event_dim=event_dim, enable_vars=True))
  batch_shape = dist.batch_shape
  batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2 = data.draw(
      dhps.distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=event_dim,
          enable_vars=True))
  self.evaluate([var.initializer for var in dist.variables])

  # Check that the distribution passes Variables through to the accessor
  # properties (without converting them to Tensor or anything like that).
  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_ref(v):
      continue
    self.assertIs(getattr(dist, k), v)

  # Check that standard statistics do not read distribution parameters more
  # than twice (once in the stat itself and up to once in any validation
  # assertions).
  max_permissible = 2 + extra_tensor_conversions_allowed(dist)
  for stat in sorted(data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3))):
    hp.note('Testing excessive var usage in {}.{}'.format(dist_name, stat))
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist),
          max_permissible=max_permissible):
        getattr(dist, stat)()
    except NotImplementedError:
      pass

  # Check that `sample` doesn't read distribution parameters more than twice,
  # and that it produces non-None gradients (if the distribution is fully
  # reparameterized).
  with tf.GradientTape() as tape:
    # TDs do bijector assertions twice (once by distribution.sample, and once
    # by bijector.forward).
    max_permissible = 2 + extra_tensor_conversions_allowed(dist)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist),
        max_permissible=max_permissible):
      sample = dist.sample(seed=seed)
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
  # own samples. Also, to relax conversion counts for KL (might do >2 w/
  # validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  # Test that KL divergence reads distribution parameters at most once, and
  # that it produces non-None gradients.
  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                               '{} vars: {}\n{} vars: {}'.format(
                                   d1, d2, var, d1, d1.variables, d2,
                                   d2.variables))
  except NotImplementedError:
    pass

  # Test that log_prob produces non-None gradients, except for distributions
  # on the NO_LOG_PROB_PARAM_GRADS blacklist.
  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  # Test that all forms of probability evaluation avoid reading distribution
  # parameters more than once.
  for evaluative in sorted(data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3))):
    hp.note('Testing excessive var usage in {}.{}'.format(
        dist_name, evaluative))
    try:
      # No validation => 1 convert. But for TD we allow 2:
      # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
      max_permissible = 2 + extra_tensor_conversions_allowed(dist)
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=max_permissible):
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass

def testBijector(self, bijector_name, data):
  tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  bijector = data.draw(
      bijectors(bijector_name=bijector_name, event_dim=event_dim,
                enable_vars=True))
  self.evaluate(tf.group(*[v.initializer for v in bijector.variables]))

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  # TODO(axch): Would be nice to get rid of all this shape inference logic and
  # just rely on a notion of batch and event shape for bijectors, so we can
  # pass those through `domain_tensors` and `codomain_tensors` and use
  # `tensors_in_support`. However, `RationalQuadraticSpline` behaves weirdly
  # somehow and I got confused.
  codomain_event_shape = [event_dim] * bijector.inverse_min_event_ndims
  codomain_event_shape = constrain_inverse_shape(
      bijector, codomain_event_shape)
  shp = bijector.inverse_event_shape(codomain_event_shape)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.forward_min_event_ndims])),
      shp[shp.ndims - bijector.forward_min_event_ndims:])
  xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)), name='xs')
  wrt_vars = [xs] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # For scalar bijectors, verify correctness of the _is_increasing method.
  if (bijector.forward_min_event_ndims == 0 and
      bijector.inverse_min_event_ndims == 0):
    dydx = grads[0]
    hp.note('dydx: {}'.format(dydx))
    isfinite = tf.math.is_finite(dydx)
    incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
        dydx, 0)  # pylint: disable=protected-access
    self.assertAllEqual(
        isfinite & incr_or_slope_eq0,
        isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=xs.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_inverse_log_det_jacobian') else 4)
    elif is_transform_diagonal(bijector):
      max_permitted = (2 if hasattr(bijector.diag_bijector,
                                    '_forward_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables. Also check that
  # any variables are not referenced overmuch.
  domain_event_shape = [event_dim] * bijector.forward_min_event_ndims
  domain_event_shape = constrain_forward_shape(bijector, domain_event_shape)
  shp = bijector.forward_event_shape(domain_event_shape)
  shp = tensorshape_util.concatenate(
      data.draw(
          tfp_hps.broadcast_compatible_shape(
              shp[:shp.ndims - bijector.inverse_min_event_ndims])),
      shp[shp.ndims - bijector.inverse_min_event_ndims:])
  ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                   name='ys')
  wrt_vars = [ys] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with respect
  # to the codomain "input" and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=ys.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4
    if is_invert(bijector):
      max_permitted = (2 if hasattr(bijector.bijector,
                                    '_forward_log_det_jacobian') else 4)
    elif is_transform_diagonal(bijector):
      max_permitted = (2 if hasattr(bijector.diag_bijector,
                                    '_inverse_log_det_jacobian') else 4)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

def testBijector(self, bijector_name, data):
  tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
  bijector, event_dim = self._draw_bijector(bijector_name, data)

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  xs = self._draw_domain_tensor(bijector, data, event_dim)
  wrt_vars = [xs] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ys = bijector.forward(xs + 0)
  grads = tape.gradient(ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

  # For scalar bijectors, verify correctness of the _is_increasing method.
  # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
  # of numerical problem.
  def exception(bijector):
    if not tfp_hps.running_under_guitar():
      return False
    if isinstance(bijector, tfb.Softfloor):
      return True
    if is_invert(bijector):
      return exception(bijector.bijector)
    return False

  if (bijector.forward_min_event_ndims == 0 and
      bijector.inverse_min_event_ndims == 0 and
      not exception(bijector)):
    dydx = grads[0]
    hp.note('dydx: {}'.format(dydx))
    isfinite = tf.math.is_finite(dydx)
    incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
        dydx, 0)  # pylint: disable=protected-access
    self.assertAllEqual(
        isfinite & incr_or_slope_eq0,
        isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

  # FLDJ: Check differentiation through forward log det jacobian with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.forward_min_event_ndims,
                   max_value=xs.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(bijector, is_forward=True)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `forward_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

  # Inverse mapping: Check differentiation through inverse mapping with
  # respect to the codomain "input" and parameter variables. Also check that
  # any variables are not referenced overmuch.
  ys = self._draw_codomain_tensor(bijector, data, event_dim)
  wrt_vars = [ys] + [v for v in bijector.trainable_variables
                     if v.dtype.is_floating]
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse` of {}'.format(bijector)):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      xs = bijector.inverse(ys + 0)
  grads = tape.gradient(xs, wrt_vars)
  assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

  # ILDJ: Check differentiation through inverse log det jacobian with respect
  # to the codomain "input" and parameter variables. Also check that any
  # variables are not referenced overmuch.
  event_ndims = data.draw(
      hps.integers(min_value=bijector.inverse_min_event_ndims,
                   max_value=ys.shape.ndims))
  with tf.GradientTape() as tape:
    max_permitted = _ldj_tensor_conversions_allowed(
        bijector, is_forward=False)
    with tfp_hps.assert_no_excessive_var_usage(
        'method `inverse_log_det_jacobian` of {}'.format(bijector),
        max_permissible=max_permitted):
      tape.watch(wrt_vars)
      # TODO(b/73073515): Fix graph mode gradients with bijector caching.
      ldj = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
  grads = tape.gradient(ldj, wrt_vars)
  assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)

  # Verify that `_is_permutation` implies constant zero Jacobian.
  if bijector._is_permutation:
    self.assertTrue(bijector._is_constant_jacobian)
    self.assertAllEqual(ldj, 0.)

  # Verify correctness of batch shape.
  xs_batch_shapes = tf.nest.map_structure(
      lambda x, nd: ps.shape(x)[:ps.rank(x) - nd],
      xs,
      bijector.inverse_event_ndims(event_ndims))
  empirical_batch_shape = functools.reduce(
      ps.broadcast_shape,
      nest.flatten_up_to(bijector.forward_min_event_ndims, xs_batch_shapes))
  batch_shape = bijector.experimental_batch_shape(y_event_ndims=event_ndims)
  if tensorshape_util.is_fully_defined(batch_shape):
    self.assertAllEqual(empirical_batch_shape, batch_shape)
  self.assertAllEqual(
      empirical_batch_shape,
      bijector.experimental_batch_shape_tensor(y_event_ndims=event_ndims))

  # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
  # of the outputs of forward and inverse.
  self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
  self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))

def testDistribution(self, dist_name, data):
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  tf1.set_random_seed(
      data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
  dist, batch_shape = data.draw(
      distributions(dist_name=dist_name, enable_vars=True))
  batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
  dist2, _ = data.draw(
      distributions(
          dist_name=dist_name,
          batch_shape=batch_shape2,
          event_dim=get_event_dim(dist),
          enable_vars=True))
  del batch_shape
  logging.info(
      'distribution: %s; parameters used: %s', dist,
      [k for k, v in six.iteritems(dist.parameters) if v is not None])
  self.evaluate([var.initializer for var in dist.variables])

  for k, v in six.iteritems(dist.parameters):
    if not tensor_util.is_mutable(v):
      continue
    try:
      self.assertIs(getattr(dist, k), v)
    except AssertionError as e:
      raise AssertionError(
          'No attr found for parameter {} of distribution {}: \n{}'.format(
              k, dist_name, e))

  for stat in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'covariance', 'entropy', 'mean', 'mode', 'stddev',
                  'variance'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, stat)
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'statistic `{}` of `{}`'.format(stat, dist)):
        getattr(dist, stat)()
    except NotImplementedError:
      pass

  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `sample` of `{}`'.format(dist)):
      sample = dist.sample()
  if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
    grads = tape.gradient(sample, dist.variables)
    for grad, var in zip(grads, dist.variables):
      var_name = var.name.rstrip('_0123456789:')
      if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
        continue
      if grad is None:
        raise AssertionError(
            'Missing sample -> {} grad for distribution {}'.format(
                var_name, dist_name))

  # Turn off validations, since log_prob can choke on dist's own samples.
  # Also, to relax conversion counts for KL (might do >2 w/ validate_args).
  dist = dist.copy(validate_args=False)
  dist2 = dist2.copy(validate_args=False)

  try:
    for d1, d2 in (dist, dist2), (dist2, dist):
      with tf.GradientTape() as tape:
        with tfp_hps.assert_no_excessive_var_usage(
            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                d1, d1.variables, d2, d2.variables),
            max_permissible=1):  # No validation => 1 convert per var.
          kl = d1.kl_divergence(d2)
      wrt_vars = list(d1.variables) + list(d2.variables)
      grads = tape.gradient(kl, wrt_vars)
      for grad, var in zip(grads, wrt_vars):
        if grad is None and dist_name not in NO_KL_PARAM_GRADS:
          raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                               '{} vars: {}\n{} vars: {}'.format(
                                   d1, d2, var, d1, d1.variables, d2,
                                   d2.variables))
  except NotImplementedError:
    pass

  if dist_name not in NO_LOG_PROB_PARAM_GRADS:
    with tf.GradientTape() as tape:
      lp = dist.log_prob(tf.stop_gradient(sample))
    grads = tape.gradient(lp, dist.variables)
    for grad, var in zip(grads, dist.variables):
      if grad is None:
        raise AssertionError(
            'Missing log_prob -> {} grad for distribution {}'.format(
                var, dist_name))

  for evaluative in data.draw(
      hps.sets(
          hps.one_of(
              map(hps.just, [
                  'log_prob', 'prob', 'log_cdf', 'cdf',
                  'log_survival_function', 'survival_function'
              ])),
          min_size=3,
          max_size=3)):
    logging.info('%s.%s', dist_name, evaluative)
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'evaluative `{}` of `{}`'.format(evaluative, dist),
          max_permissible=1):  # No validation => 1 convert
        getattr(dist, evaluative)(sample)
    except NotImplementedError:
      pass