def testCompositeTensor(self, kernel_name, data):
  """Checks that a drawn kernel behaves as a `tf` CompositeTensor.

  Draws a variable-backed kernel and checks that it is an instance of
  `tf.__internal__.CompositeTensor`. It then round-trips the kernel through
  `tf.nest.flatten` / `tf.nest.pack_sequence_as` with
  `expand_composites=True`. Finally it verifies that a `tf.function` taking
  the kernel as an argument reproduces the directly computed `apply` result,
  for both the original and the rebuilt kernel.

  Args:
    kernel_name: Name of the kernel under test, forwarded to
      `kernel_hps.kernels`.
    data: Hypothesis data object used to draw the kernel and its inputs.
  """
  kernel_strategy = kernel_hps.kernels(
      kernel_name=kernel_name,
      event_dim=2,
      feature_dim=2,
      feature_ndims=1,
      enable_vars=True)
  kernel, _ = data.draw(kernel_strategy)
  self.assertIsInstance(kernel, tf.__internal__.CompositeTensor)

  input_strategy = kernel_hps.kernel_input(
      batch_shape=[], example_ndims=1, feature_dim=2, feature_ndims=1)
  xs = tf.identity(data.draw(input_strategy))

  # NOTE(review): `no_tf_rank_errors` is a project helper; its exact handling
  # of rank errors is defined in tfp_hps.
  with tfp_hps.no_tf_rank_errors():
    expected_diag = kernel.apply(xs, xs, example_ndims=1)

  # Round-trip the kernel through the composite-tensor nest machinery.
  components = tf.nest.flatten(kernel, expand_composites=True)
  rebuilt_kernel = tf.nest.pack_sequence_as(
      kernel, components, expand_composites=True)

  # The kernel is passed as a tf.function argument, so it must be traceable
  # as a composite value.
  @tf.function
  def apply_in_graph(k):
    return k.apply(xs, xs, example_ndims=1)

  self.evaluate([var.initializer for var in kernel.variables])
  with tfp_hps.no_tf_rank_errors():
    for candidate in (kernel, rebuilt_kernel):
      self.assertAllClose(expected_diag, apply_in_graph(candidate))
def testKernelGradient(self, kernel_name, data):
  """Checks variable tracking, gradients and copying for a drawn kernel.

  Draws a variable-backed kernel with a random batch shape. It then checks:

  - the kernel's `variables` match the parameter-variable names reported by
    the strategy;
  - `apply` yields non-None gradients w.r.t. the inputs and every kernel
    variable, while staying within the permitted variable-usage budget;
  - `kernel.copy()` reproduces the same `apply` values.

  Args:
    kernel_name: Name of the kernel under test; also the lookup key into
      `EXTRA_TENSOR_CONVERSION_KERNELS`.
    data: Hypothesis data object used to draw dimensions, kernel and inputs.
  """
  event_dim = data.draw(hps.integers(min_value=2, max_value=3))
  feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
  feature_dim = data.draw(hps.integers(min_value=2, max_value=4))
  batch_shape = data.draw(tfp_hps.shapes(max_ndims=2))
  kernel, param_var_names = data.draw(
      kernel_hps.kernels(
          batch_shape=batch_shape,
          kernel_name=kernel_name,
          event_dim=event_dim,
          feature_dim=feature_dim,
          feature_ndims=feature_ndims,
          enable_vars=True))

  # Every variable parameter must show up in kernel.variables. Trailing
  # digits, underscores and colons (e.g. a `:0` suffix) are stripped from
  # both sides so the bare names can be compared.
  suffix_chars = '_0123456789:'
  variable_names = {v.name.strip(suffix_chars) for v in kernel.variables}
  parameter_names = {n.strip(suffix_chars) for n in param_var_names}
  self.assertEqual(parameter_names, variable_names)

  example_ndims = data.draw(hps.integers(min_value=1, max_value=2))
  input_batch_shape = data.draw(
      tfp_hps.broadcast_compatible_shape(kernel.batch_shape))
  xs = tf.identity(
      data.draw(
          kernel_hps.kernel_input(
              batch_shape=input_batch_shape,
              example_ndims=example_ndims,
              feature_dim=feature_dim,
              feature_ndims=feature_ndims)))

  # Gradients are requested w.r.t. the inputs and all kernel variables.
  wrt_vars = [xs, *kernel.variables]
  self.evaluate([v.initializer for v in kernel.variables])

  # Base budget of 2 plus any per-kernel allowance.
  allowed_usage = 2 + EXTRA_TENSOR_CONVERSION_KERNELS.get(kernel_name, 0)
  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `apply` of {}'.format(kernel),
        max_permissible=allowed_usage):
      tape.watch(wrt_vars)
      with tfp_hps.no_tf_rank_errors():
        diag = kernel.apply(xs, xs, example_ndims=example_ndims)
  grads = tape.gradient(diag, wrt_vars)
  assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

  # A copy of the kernel must produce the same values.
  with tfp_hps.no_tf_rank_errors():
    copied_diag = self.evaluate(
        kernel.copy().apply(xs, xs, example_ndims=example_ndims))
  self.assertAllClose(diag, copied_diag)
def testKernelGradient(self, kernel_name, data):
  """Checks variable tracking, gradients and reconstruction for a kernel.

  Skipped (returns early) unless the current eager/graph mode matches
  `FLAGS.tf_mode`. Otherwise draws a variable-backed kernel, then:

  - checks its `variables` match the parameter-variable names reported by
    the strategy;
  - checks `apply` yields non-None gradients w.r.t. the inputs and every
    kernel variable, inside `tfp_hps.assert_no_excessive_var_usage`;
  - checks a kernel rebuilt from `kernel._parameters` gives the same values.

  Args:
    kernel_name: Name of the kernel under test.
    data: Hypothesis data object used to draw dimensions, kernel and inputs.
  """
  # Only run under the execution mode selected by --tf_mode.
  if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
    return
  event_dim = data.draw(hps.integers(min_value=2, max_value=6))
  feature_ndims = data.draw(hps.integers(min_value=1, max_value=4))
  feature_dim = data.draw(hps.integers(min_value=2, max_value=6))
  kernel, kernel_parameter_variable_names = data.draw(
      kernel_hps.kernels(
          kernel_name=kernel_name,
          event_dim=event_dim,
          feature_dim=feature_dim,
          feature_ndims=feature_ndims,
          enable_vars=True))

  # Check that variable parameters get passed to kernel.variables. Trailing
  # digits, underscores and colons are stripped from BOTH name lists.
  # Stripping only the variable names made the comparison asymmetric, so it
  # failed spuriously whenever a parameter name carried such a suffix.
  kernel_variables_names = [
      v.name.strip('_0123456789:') for v in kernel.variables]
  kernel_parameter_variable_names = [
      n.strip('_0123456789:') for n in kernel_parameter_variable_names]
  self.assertEqual(
      set(kernel_parameter_variable_names), set(kernel_variables_names))

  example_ndims = data.draw(hps.integers(min_value=1, max_value=3))
  input_batch_shape = data.draw(
      tfp_hps.broadcast_compatible_shape(kernel.batch_shape))
  xs = tf.identity(
      data.draw(
          kernel_hps.kernel_input(
              batch_shape=input_batch_shape,
              example_ndims=example_ndims,
              feature_dim=feature_dim,
              feature_ndims=feature_ndims)))

  # Check that we pick up all relevant kernel parameters.
  wrt_vars = [xs] + list(kernel.variables)
  self.evaluate([v.initializer for v in kernel.variables])

  with tf.GradientTape() as tape:
    with tfp_hps.assert_no_excessive_var_usage(
        'method `apply` of {}'.format(kernel)):
      tape.watch(wrt_vars)
      diag = kernel.apply(xs, xs, example_ndims=example_ndims)
  grads = tape.gradient(diag, wrt_vars)
  assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

  # A kernel rebuilt from its constructor parameters must give equal values.
  self.assertAllClose(
      diag,
      type(kernel)(**kernel._parameters).apply(
          xs, xs, example_ndims=example_ndims))
def testKernels(self, kernel_name, data):
  """Checks iteration and slicing behavior of a drawn (non-variable) kernel.

  Draws a kernel with `enable_vars=False`. It asserts that `iter(kernel)`
  raises a `TypeError` mentioning "not iterable", then delegates to
  `self._test_slicing` for the slicing checks.

  Args:
    kernel_name: Name of the kernel under test.
    data: Hypothesis data object used to draw dimensions and the kernel.
  """
  event_dim = data.draw(hps.integers(min_value=2, max_value=3))
  feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
  feature_dim = data.draw(hps.integers(min_value=2, max_value=4))
  kernel_strategy = kernel_hps.kernels(
      kernel_name=kernel_name,
      event_dim=event_dim,
      feature_dim=feature_dim,
      feature_ndims=feature_ndims,
      enable_vars=False)
  kernel, _unused_var_names = data.draw(kernel_strategy)

  # Defining __getitem__ normally makes an object iterable through Python's
  # legacy sequence-iteration fallback. Kernels must still reject iter().
  with self.assertRaisesRegex(TypeError, 'not iterable'):
    iter(kernel)

  self._test_slicing(data, kernel_name, kernel, feature_dim, feature_ndims)