Code example #1
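This test draws a random PSD kernel with tf.Variable parameters enabled and exercises the CompositeTensor contract: the kernel must flatten into component tensors with tf.nest.flatten, rebuild with tf.nest.pack_sequence_as, and both the original and the rebuilt kernel must produce the same apply diagonal when called through a tf.function.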
    def testCompositeTensor(self, kernel_name, data):
        kernel, _ = data.draw(
            kernel_hps.kernels(kernel_name=kernel_name,
                               event_dim=2,
                               feature_dim=2,
                               feature_ndims=1,
                               enable_vars=True))
        self.assertIsInstance(kernel, tf.__internal__.CompositeTensor)

        xs = tf.identity(
            data.draw(
                kernel_hps.kernel_input(batch_shape=[],
                                        example_ndims=1,
                                        feature_dim=2,
                                        feature_ndims=1)))
        with tfp_hps.no_tf_rank_errors():
            diag = kernel.apply(xs, xs, example_ndims=1)

        # Test flatten/unflatten.
        flat = tf.nest.flatten(kernel, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(kernel, flat, expand_composites=True)

        # Test tf.function.
        @tf.function
        def diag_fn(k):
            return k.apply(xs, xs, example_ndims=1)

        self.evaluate([v.initializer for v in kernel.variables])
        with tfp_hps.no_tf_rank_errors():
            self.assertAllClose(diag, diag_fn(kernel))
            self.assertAllClose(diag, diag_fn(unflat))
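
As a minimal standalone sketch of the same round-trip, assuming a TensorFlow Probability release in which PSD kernels are CompositeTensors (the concrete kernel and shapes below are illustrative, not taken from the test):

    import tensorflow as tf
    import tensorflow_probability as tfp

    kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
        amplitude=2., length_scale=0.5)
    xs = tf.random.normal([5, 2])  # 5 examples, feature_dim=2

    # Flatten the kernel into its component tensors and rebuild it.
    flat = tf.nest.flatten(kernel, expand_composites=True)
    unflat = tf.nest.pack_sequence_as(kernel, flat, expand_composites=True)

    # The rebuilt kernel computes the same diagonal as the original.
    tf.debugging.assert_near(kernel.apply(xs, xs), unflat.apply(xs, xs))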
Code example #2
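This version of the gradient test draws a kernel batch shape plus a broadcast-compatible input batch shape, bounds how many times apply may convert variables to tensors (a base allowance of 2 plus a per-kernel extra from EXTRA_TENSOR_CONVERSION_KERNELS), asserts that gradients with respect to the inputs and every kernel variable are non-None, and checks that kernel.copy() reproduces the same diagonal.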
    def testKernelGradient(self, kernel_name, data):
        event_dim = data.draw(hps.integers(min_value=2, max_value=3))
        feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
        feature_dim = data.draw(hps.integers(min_value=2, max_value=4))
        batch_shape = data.draw(tfp_hps.shapes(max_ndims=2))

        kernel, kernel_parameter_variable_names = data.draw(
            kernel_hps.kernels(batch_shape=batch_shape,
                               kernel_name=kernel_name,
                               event_dim=event_dim,
                               feature_dim=feature_dim,
                               feature_ndims=feature_ndims,
                               enable_vars=True))

        # Check that variable parameters appear in kernel.variables.
        kernel_variables_names = [
            v.name.strip('_0123456789:') for v in kernel.variables
        ]
        kernel_parameter_variable_names = [
            n.strip('_0123456789:') for n in kernel_parameter_variable_names
        ]
        self.assertEqual(set(kernel_parameter_variable_names),
                         set(kernel_variables_names))

        example_ndims = data.draw(hps.integers(min_value=1, max_value=2))
        input_batch_shape = data.draw(
            tfp_hps.broadcast_compatible_shape(kernel.batch_shape))
        xs = tf.identity(
            data.draw(
                kernel_hps.kernel_input(batch_shape=input_batch_shape,
                                        example_ndims=example_ndims,
                                        feature_dim=feature_dim,
                                        feature_ndims=feature_ndims)))

        # Check that we pick up all relevant kernel parameters.
        wrt_vars = [xs] + list(kernel.variables)
        self.evaluate([v.initializer for v in kernel.variables])

        max_permissible = 2 + EXTRA_TENSOR_CONVERSION_KERNELS.get(
            kernel_name, 0)

        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `apply` of {}'.format(kernel),
                    max_permissible=max_permissible):
                tape.watch(wrt_vars)
                with tfp_hps.no_tf_rank_errors():
                    diag = kernel.apply(xs, xs, example_ndims=example_ndims)
        grads = tape.gradient(diag, wrt_vars)
        assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

        # Check that copying the kernel works.
        with tfp_hps.no_tf_rank_errors():
            diag2 = self.evaluate(kernel.copy().apply(
                xs, xs, example_ndims=example_ndims))
        self.assertAllClose(diag, diag2)
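
The property at the heart of this test, as a minimal standalone sketch (the kernel, variable, and shapes are illustrative assumptions, not part of the test): gradients of apply should flow to both the inputs and the kernel's variable parameters.

    import tensorflow as tf
    import tensorflow_probability as tfp

    amplitude = tf.Variable(1.5)
    kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amplitude=amplitude)
    xs = tf.random.normal([4, 3])  # 4 examples, feature_dim=3

    with tf.GradientTape() as tape:
        tape.watch(xs)  # variables are watched automatically; plain tensors are not
        diag = kernel.apply(xs, xs)  # shape [4]
    grads = tape.gradient(diag, [xs, amplitude])
    assert all(g is not None for g in grads)  # gradients reach inputs and parameters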
Code example #3
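This is an earlier variant of the same gradient test. It skips unless the process's execution mode matches the --tf_mode flag, draws from wider Hypothesis ranges, relies on the default allowance in assert_no_excessive_var_usage, and checks reproducibility by rebuilding the kernel as type(kernel)(**kernel._parameters) instead of calling kernel.copy().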
  def testKernelGradient(self, kernel_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
    event_dim = data.draw(hps.integers(min_value=2, max_value=6))
    feature_ndims = data.draw(hps.integers(min_value=1, max_value=4))
    feature_dim = data.draw(hps.integers(min_value=2, max_value=6))

    kernel, kernel_parameter_variable_names = data.draw(
        kernel_hps.kernels(
            kernel_name=kernel_name,
            event_dim=event_dim,
            feature_dim=feature_dim,
            feature_ndims=feature_ndims,
            enable_vars=True))

    # Check that variable parameters appear in kernel.variables.
    kernel_variables_names = [
        v.name.strip('_0123456789:') for v in kernel.variables]
    self.assertEqual(
        set(kernel_parameter_variable_names),
        set(kernel_variables_names))

    example_ndims = data.draw(hps.integers(min_value=1, max_value=3))
    input_batch_shape = data.draw(tfp_hps.broadcast_compatible_shape(
        kernel.batch_shape))
    xs = tf.identity(data.draw(kernel_hps.kernel_input(
        batch_shape=input_batch_shape,
        example_ndims=example_ndims,
        feature_dim=feature_dim,
        feature_ndims=feature_ndims)))

    # Check that we pick up all relevant kernel parameters.
    wrt_vars = [xs] + list(kernel.variables)
    self.evaluate([v.initializer for v in kernel.variables])

    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `apply` of {}'.format(kernel)):
        tape.watch(wrt_vars)
        diag = kernel.apply(xs, xs, example_ndims=example_ndims)
    grads = tape.gradient(diag, wrt_vars)
    assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

    self.assertAllClose(
        diag,
        type(kernel)(**kernel._parameters).apply(
            xs, xs, example_ndims=example_ndims))
Code example #4
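With variables disabled, this test checks two structural properties of each drawn kernel: iter(kernel) must raise TypeError even though kernels define __getitem__, and batch slicing must behave correctly (delegated to the shared _test_slicing helper).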
    def testKernels(self, kernel_name, data):
        event_dim = data.draw(hps.integers(min_value=2, max_value=3))
        feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
        feature_dim = data.draw(hps.integers(min_value=2, max_value=4))

        kernel, _ = data.draw(
            kernel_hps.kernels(kernel_name=kernel_name,
                               event_dim=event_dim,
                               feature_dim=feature_dim,
                               feature_ndims=feature_ndims,
                               enable_vars=False))

        # Check that all kernels still register as non-iterable despite
        # defining __getitem__.  (Defining __getitem__ makes an object
        # iterable via Python's legacy iteration protocol unless iteration
        # is explicitly disabled.)
        with self.assertRaisesRegex(TypeError, 'not iterable'):
            iter(kernel)

        # Test slicing
        self._test_slicing(data, kernel_name, kernel, feature_dim,
                           feature_ndims)
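
A minimal sketch of both behaviors, assuming the public TFP kernel API; the concrete kernel and batch shape are illustrative:

    import tensorflow as tf
    import tensorflow_probability as tfp

    kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
        amplitude=tf.ones([3]))  # batch_shape [3]

    # __getitem__ slices the batch dimensions...
    print(kernel[1].batch_shape)  # ()

    # ...but iteration is still explicitly disabled.
    try:
        iter(kernel)
    except TypeError as e:
        print(e)  # message mentions 'not iterable'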