Code example #1
  def testExcessiveConcretizationOfParams(self):
    loc = tfp_hps.defer_and_count_usage(
        tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
    scale = tfp_hps.defer_and_count_usage(
        tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
    bij_scale = tfp_hps.defer_and_count_usage(
        tf.Variable(2., name='bij_scale', dtype=tf.float32, shape=self.shape))
    event_shape = tfp_hps.defer_and_count_usage(
        tf.Variable([2, 2], name='input_event_shape', dtype=tf.int32,
                    shape=self.shape))
    batch_shape = tfp_hps.defer_and_count_usage(
        tf.Variable([4, 3, 5], name='input_batch_shape', dtype=tf.int32,
                    shape=self.shape))

    dist = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
        bijector=tfb.Scale(scale=bij_scale, validate_args=True),
        event_shape=event_shape,
        batch_shape=batch_shape,
        validate_args=True)

    for method in ('mean', 'entropy', 'event_shape_tensor',
                   'batch_shape_tensor'):
      with tfp_hps.assert_no_excessive_var_usage(
          method, max_permissible=self.max_permissible[method]):
        getattr(dist, method)()

    with tfp_hps.assert_no_excessive_var_usage(
        'sample', max_permissible=self.max_permissible['sample']):
      dist.sample(seed=test_util.test_seed())

    for method in ('log_prob', 'prob'):
      with tfp_hps.assert_no_excessive_var_usage(
          method, max_permissible=self.max_permissible[method]):
        getattr(dist, method)(np.ones((4, 3, 5, 2, 2)) / 3.)
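
All of these examples follow the same basic pattern: wrap each `tf.Variable` parameter with `tfp_hps.defer_and_count_usage` so that every read (concretization) is counted, then assert a read budget with the `tfp_hps.assert_no_excessive_var_usage` context manager. Below is a minimal, self-contained sketch of that pattern; the import paths are assumptions based on common TensorFlow Probability test conventions, since the listing never shows them.

# Minimal sketch of the concretization-counting pattern shared by these
# examples. NOTE: the import paths are assumptions; the listing never shows them.
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps

tfd = tfp.distributions

# Wrap the variable so that every read (concretization) is counted.
scale = tfp_hps.defer_and_count_usage(
    tf.Variable(2., name='scale', dtype=tf.float32))
dist = tfd.Normal(loc=0., scale=scale, validate_args=True)

# Fail if computing the mean reads `scale` more than twice (assumed budget:
# one read in the statistic itself, one in the validation assertions).
with tfp_hps.assert_no_excessive_var_usage('mean', max_permissible=2):
  dist.mean()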
Code example #2
  def testExcessiveConcretizationWithDefaultReinterpretedBatchNdims(self):
    loc = tfp_hps.defer_and_count_usage(
        tf.Variable(np.zeros((5, 2, 3)), shape=tf.TensorShape(None)))
    scale = tfp_hps.defer_and_count_usage(
        tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
    dist = tfd.Independent(
        tfd.Logistic(loc=loc, scale=scale, validate_args=True),
        reinterpreted_batch_ndims=None, validate_args=True)

    for method in ('batch_shape_tensor', 'event_shape_tensor',
                   'mean', 'variance', 'sample'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
        getattr(dist, method)()

    # In addition to the four reads of `loc`, `scale` described above in
    # `testExcessiveConcretizationOfParams`, the methods below have two more
    # reads of these parameters -- from computing a default value for
    # `reinterpreted_batch_ndims`, which requires calling
    # `dist.distribution.batch_shape_tensor()`.

    for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=6):
        getattr(dist, method)(np.zeros((4, 5, 2, 3)))

    with tfp_hps.assert_no_excessive_var_usage('entropy', max_permissible=6):
      dist.entropy()

    # `Distribution.survival_function` and `Distribution.log_survival_function`
    # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
    # one additional call to `Independent._parameter_control_dependencies`,
    # and thus two additional concretizations of the parameters.

    for method in ('survival_function', 'log_survival_function'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=8):
        getattr(dist, method)(np.zeros((4, 5, 2, 3)))
Code example #3
    def testExcessiveConcretizationOfParamsBatchShapeOverride(self):
        # Test methods that are not implemented if event_shape is overridden.
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
        bij_scale = tfp_hps.defer_and_count_usage(
            tf.Variable(2.,
                        name='bij_scale',
                        dtype=tf.float32,
                        shape=self.shape))
        batch_shape = tfp_hps.defer_and_count_usage(
            tf.Variable([4, 3, 5],
                        name='input_batch_shape',
                        dtype=tf.int32,
                        shape=self.shape))
        dist = tfd.TransformedDistribution(
            distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
            bijector=tfb.Scale(scale=bij_scale, validate_args=True),
            batch_shape=batch_shape,
            validate_args=True)

        for method in ('log_cdf', 'cdf', 'survival_function',
                       'log_survival_function'):
            with tfp_hps.assert_no_excessive_var_usage(
                    method, max_permissible=self.max_permissible[method]):
                getattr(dist, method)(np.ones((4, 3, 2)) / 3.)

        with tfp_hps.assert_no_excessive_var_usage(
                'quantile', max_permissible=self.max_permissible['quantile']):
            dist.quantile(.1)
Code example #4
  def testExcessiveConcretizationOfParams(self):
    logits = tfp_hps.defer_and_count_usage(
        self._build_variable(np.zeros((4, 4, 5)), name='logits'))
    concentration = tfp_hps.defer_and_count_usage(
        self._build_variable(np.zeros((4, 4, 5, 3)), name='concentration'))
    dist = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.Dirichlet(concentration=concentration),
        validate_args=True)

    # Many methods use mixture_distribution and components_distribution at most
    # once, and thus incur no extra reads/concretizations of parameters.

    for method in ('batch_shape_tensor', 'event_shape_tensor',
                   'mean'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
        getattr(dist, method)()

    with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=2):
      dist.sample(seed=test_util.test_seed())

    for method in ('log_prob', 'prob'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=2):
        getattr(dist, method)(np.ones((4, 4, 3)) / 3.)

    # TODO(b/140579567): The `variance()` and `covariance()` methods require
    # calling both:
    #  - `self.components_distribution.mean()`
    #  - `self.components_distribution.variance()` or `.covariance()`
    # Thus, these methods incur an additional concretization (or two if
    # `validate_args=True` for `self.components_distribution`).

    for method in ('variance', 'covariance'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=3):
        getattr(dist, method)()
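
Examples #4 and #8 call a test-class helper `self._build_variable` that is not included in this listing. The sketch below is purely hypothetical (written as a free function for brevity), but it is consistent with how the helper is called: a NumPy initial value, a `name`, and an optional `static_rank` flag controlling how much of the shape is statically known.

# Hypothetical sketch of the `_build_variable` helper used by examples #4 and
# #8; the real definition is a method of the test class and is not shown here.
import numpy as np
import tensorflow.compat.v2 as tf

def _build_variable(initial_value, name=None, static_rank=False):
  if static_rank:
    # Rank is statically known, but each dimension size is left unknown.
    shape = tf.TensorShape([None] * len(np.shape(initial_value)))
  else:
    # Fully dynamic: neither rank nor dimension sizes are statically known.
    shape = tf.TensorShape(None)
  return tf.Variable(initial_value, name=name, dtype=tf.float32, shape=shape)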
Code example #5
    def testExcessiveConcretizationOfParams(self):
        logits = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((3, 5, 2)),
                        dtype=tf.float32,
                        shape=tf.TensorShape([None, None, 2]),
                        name='logits'))
        concentration = tfp_hps.defer_and_count_usage(
            tf.Variable(np.ones((3, 5, 4)),
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='concentration'))
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((3, 5, 4)),
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='loc'))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(1.,
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='scale'))

        dist = tfd.Mixture(tfd.Categorical(logits=logits),
                           components=[
                               tfd.Dirichlet(concentration),
                               tfd.Independent(tfd.Normal(loc=loc,
                                                          scale=scale),
                                               reinterpreted_batch_ndims=1)
                           ],
                           use_static_graph=self.use_static_graph,
                           validate_args=True)

        for method in ('batch_shape_tensor', 'event_shape_tensor',
                       'entropy_lower_bound'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=2):
                getattr(dist, method)()

        with tfp_hps.assert_no_excessive_var_usage('sample',
                                                   max_permissible=2):
            dist.sample(seed=test_util.test_seed())

        for method in ('prob', 'log_prob'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=2):
                getattr(dist, method)(tf.ones((3, 5, 4)) / 4.)

        # TODO(b/140579567): The `stddev()` and `variance()` methods require
        # calling both:
        #  - `self.components[i].mean()`
        #  - `self.components[i].stddev()`
        # Thus, these methods incur an additional concretization (or two if
        # `validate_args=True` for `self.components[i]`).

        for method in ('stddev', 'variance'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=3):
                getattr(dist, method)()
Code example #6
File: reshape_test.py  Project: xzxzmmnn/probability
  def testConcretizationLimits(self):
    shape_out = tfp_hps.defer_and_count_usage(tf.Variable([1]))
    reshape = tfb.Reshape(shape_out, validate_args=True)
    x = [1]  # Pun: valid input or output, and valid input or output shape
    for method in ['forward', 'inverse', 'forward_event_shape',
                   'inverse_event_shape', 'forward_event_shape_tensor',
                   'inverse_event_shape_tensor']:
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=7):
        getattr(reshape, method)(x)
    for method in ['forward_log_det_jacobian', 'inverse_log_det_jacobian']:
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
        getattr(reshape, method)(x, event_ndims=1)
Code example #7
    def testExcessiveConcretizationOfParams(self):
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((4, 2, 2)), shape=tf.TensorShape(None)))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
        ndims = tf.Variable(1, trainable=False, shape=tf.TensorShape(None))
        dist = tfd.Independent(tfd.Logistic(loc=loc,
                                            scale=scale,
                                            validate_args=True),
                               reinterpreted_batch_ndims=ndims,
                               validate_args=True)

        # TODO(b/140579567): All methods of `dist` may require four concretizations
        # of parameters `loc` and `scale`:
        #  - `Independent._parameter_control_dependencies` calls
        #    `Logistic.batch_shape_tensor`, which:
        #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
        #    * Reads `loc`, `scale` in `Logistic._batch_shape_tensor`.
        #  - The method `dist.m` will call `dist.distribution.m`, which:
        #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
        #    * Reads `loc`, `scale` in the implementation of method `Logistic._m`.
        #
        # NOTE: If `dist.distribution` had dynamic batch shape and event shape,
        # there could be two more reads of the parameters of `dist.distribution`
        # in `dist.event_shape_tensor`, from calling
        # `dist.distribution.event_shape_tensor()`.

        for method in ('batch_shape_tensor', 'event_shape_tensor', 'mode',
                       'stddev', 'entropy'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=4):
                getattr(dist, method)()

        with tfp_hps.assert_no_excessive_var_usage('sample',
                                                   max_permissible=4):
            dist.sample(seed=test_util.test_seed())

        for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=4):
                getattr(dist, method)(np.zeros((3, 4, 2, 2)))

        # `Distribution.survival_function` and `Distribution.log_survival_function`
        # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
        # one additional call to `Independent._parameter_control_dependencies`,
        # and thus two additional concretizations of the parameters.

        for method in ('survival_function', 'log_survival_function'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=6):
                getattr(dist, method)(np.zeros((3, 4, 2, 2)))
Code example #8
  def testExcessiveConcretizationOfParamsWithReparameterization(self):
    logits = tfp_hps.defer_and_count_usage(self._build_variable(
        np.zeros(5), name='logits', static_rank=True))
    loc = tfp_hps.defer_and_count_usage(self._build_variable(
        np.zeros((4, 4, 5)), name='loc', static_rank=True))
    scale = tfp_hps.defer_and_count_usage(self._build_variable(
        1., name='scale', static_rank=True))
    dist = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.Logistic(loc=loc, scale=scale),
        reparameterize=True, validate_args=True)

    # TODO(b/140579567): With reparameterization, there are additional reads of
    # the parameters of the underlying mixture and components distributions when
    # sampling, from calls in `_distributional_transform` to:
    #
    #  - `self.mixture_distribution.logits_parameter`
    #  - `self.components_distribution.log_prob`
    #  - `self.components_distribution.cdf`
    #
    # NOTE: In the unlikely case that samples have a statically-known rank but
    # the rank of `self.components_distribution.event_shape` is not known
    # statically, there can be additional reads in `_distributional_transform`
    # from calling `self.components_distribution.is_scalar_event`.

    with tfp_hps.assert_no_excessive_var_usage('sample', max_permissible=4):
      dist.sample(seed=test_util.test_seed())
Code example #9
    def testKernelGradient(self, kernel_name, data):
        event_dim = data.draw(hps.integers(min_value=2, max_value=3))
        feature_ndims = data.draw(hps.integers(min_value=1, max_value=2))
        feature_dim = data.draw(hps.integers(min_value=2, max_value=4))
        batch_shape = data.draw(tfp_hps.shapes(max_ndims=2))

        kernel, kernel_parameter_variable_names = data.draw(
            kernel_hps.kernels(batch_shape=batch_shape,
                               kernel_name=kernel_name,
                               event_dim=event_dim,
                               feature_dim=feature_dim,
                               feature_ndims=feature_ndims,
                               enable_vars=True))

        # Check that the variable-valued parameters show up in `kernel.variables`.
        kernel_variables_names = [
            v.name.strip('_0123456789:') for v in kernel.variables
        ]
        kernel_parameter_variable_names = [
            n.strip('_0123456789:') for n in kernel_parameter_variable_names
        ]
        self.assertEqual(set(kernel_parameter_variable_names),
                         set(kernel_variables_names))

        example_ndims = data.draw(hps.integers(min_value=1, max_value=2))
        input_batch_shape = data.draw(
            tfp_hps.broadcast_compatible_shape(kernel.batch_shape))
        xs = tf.identity(
            data.draw(
                kernel_hps.kernel_input(batch_shape=input_batch_shape,
                                        example_ndims=example_ndims,
                                        feature_dim=feature_dim,
                                        feature_ndims=feature_ndims)))

        # Check that we pick up all relevant kernel parameters.
        wrt_vars = [xs] + list(kernel.variables)
        self.evaluate([v.initializer for v in kernel.variables])

        max_permissible = 2 + EXTRA_TENSOR_CONVERSION_KERNELS.get(
            kernel_name, 0)

        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `apply` of {}'.format(kernel),
                    max_permissible=max_permissible):
                tape.watch(wrt_vars)
                with tfp_hps.no_tf_rank_errors():
                    diag = kernel.apply(xs, xs, example_ndims=example_ndims)
        grads = tape.gradient(diag, wrt_vars)
        assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

        # Check that copying the kernel works.
        with tfp_hps.no_tf_rank_errors():
            diag2 = self.evaluate(kernel.copy().apply(
                xs, xs, example_ndims=example_ndims))
        self.assertAllClose(diag, diag2)
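
The kernel and bijector examples also call `assert_no_none_grad`, which the listing omits. A hedged sketch, modeled on the inline `grad is None` checks of examples #15, #18, and #19:

# Hedged sketch of `assert_no_none_grad`; the real helper is not shown in this
# listing. It mirrors the inline gradient checks in examples #15, #18 and #19.
def assert_no_none_grad(obj, method, wrt_vars, grads):
  for var, grad in zip(wrt_vars, grads):
    if grad is None:
      raise AssertionError(
          'Missing `{}` gradient with respect to {} for {}'.format(
              method, var, obj))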
Code example #10
  def testExcessiveConcretizationInLogProb(self, process_name, data):
    # Check that log_prob computations avoid reading process parameters
    # more than once.
    process = data.draw(stochastic_processes(
        process_name=process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
    sample = process.sample()
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `log_prob` of `{}`'.format(process),
          max_permissible=excessive_usage_count(process_name)):
        process.log_prob(sample)
    except NotImplementedError:
      pass
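
`excessive_usage_count(process_name)` (used here and again in examples #12 and #18) is another helper the listing omits. Judging by the `MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)` lookup in example #13 below, it is plausibly a per-class table lookup with a default of one; the sketch below, including the (empty) table, is an assumption.

# Hypothetical sketch of `excessive_usage_count`, modeled on the
# `MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)` pattern of example #13.
# The table contents and the default are assumptions.
MAX_CONVERSIONS_BY_CLASS = {
    # e.g. 'VariationalGaussianProcess': 2,
}

def excessive_usage_count(process_name):
  return MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)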
Code example #11
  def testKernelGradient(self, kernel_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
    event_dim = data.draw(hps.integers(min_value=2, max_value=6))
    feature_ndims = data.draw(hps.integers(min_value=1, max_value=4))
    feature_dim = data.draw(hps.integers(min_value=2, max_value=6))

    kernel, kernel_parameter_variable_names = data.draw(
        kernel_hps.kernels(
            kernel_name=kernel_name,
            event_dim=event_dim,
            feature_dim=feature_dim,
            feature_ndims=feature_ndims,
            enable_vars=True))

    # Check that the variable-valued parameters show up in `kernel.variables`.
    kernel_variables_names = [
        v.name.strip('_0123456789:') for v in kernel.variables]
    self.assertEqual(
        set(kernel_parameter_variable_names),
        set(kernel_variables_names))

    example_ndims = data.draw(hps.integers(min_value=1, max_value=3))
    input_batch_shape = data.draw(tfp_hps.broadcast_compatible_shape(
        kernel.batch_shape))
    xs = tf.identity(data.draw(kernel_hps.kernel_input(
        batch_shape=input_batch_shape,
        example_ndims=example_ndims,
        feature_dim=feature_dim,
        feature_ndims=feature_ndims)))

    # Check that we pick up all relevant kernel parameters.
    wrt_vars = [xs] + list(kernel.variables)
    self.evaluate([v.initializer for v in kernel.variables])

    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `apply` of {}'.format(kernel)):
        tape.watch(wrt_vars)
        diag = kernel.apply(xs, xs, example_ndims=example_ndims)
    grads = tape.gradient(diag, wrt_vars)
    assert_no_none_grad(kernel, 'apply', wrt_vars, grads)

    self.assertAllClose(
        diag,
        type(kernel)(**kernel._parameters).apply(
            xs, xs, example_ndims=example_ndims))
Code example #12
  def testExcessiveConcretizationInZeroArgPublicMethods(
      self, process_name, data):
    # Check that standard statistics do not concretize variables/deferred
    # tensors more than the allowed amount.
    process = data.draw(stochastic_processes(process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
      hp.note('Testing excessive concretization in {}.{}'.format(process_name,
                                                                 stat))
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'method `{}` of `{}`'.format(stat, process),
            max_permissible=excessive_usage_count(process_name)):
          getattr(process, stat)()

      except NotImplementedError:
        pass
Code example #13
  def testExcessiveConcretizationInLogProb(self, process_name, data):
    # Check that log_prob computations avoid reading process parameters
    # more than once.
    tfp_hps.guitar_skip_if_matches(
        'VariationalGaussianProcess', process_name, 'b/147770193')
    process = data.draw(stochastic_processes(
        process_name=process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
    sample = process.sample()
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `log_prob` of `{}`'.format(process),
          max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
        process.log_prob(sample)
    except NotImplementedError:
      pass
Code example #14
  def testExcessiveConcretizationInZeroArgPublicMethods(
      self, process_name, data):
    tfp_hps.guitar_skip_if_matches(
        'VariationalGaussianProcess', process_name, 'b/147770193')
    # Check that standard statistics do not concretize variables/deferred
    # tensors more than the allowed amount.
    process = data.draw(stochastic_processes(process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
      hp.note('Testing excessive concretization in {}.{}'.format(process_name,
                                                                 stat))
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'method `{}` of `{}`'.format(stat, process),
            max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)
            ), kernel_hps.no_pd_errors():
          getattr(process, stat)()

      except NotImplementedError:
        pass
Code example #15
    def testDistribution(self, dist_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        tf1.set_random_seed(
            data.draw(
                hpnp.arrays(dtype=np.int64,
                            shape=[]).filter(lambda x: x != 0)))
        dist = data.draw(distributions(dist_name=dist_name, enable_vars=True))
        batch_shape = dist.batch_shape
        batch_shape2 = data.draw(
            tfp_hps.broadcast_compatible_shape(batch_shape))
        dist2 = data.draw(
            distributions(dist_name=dist_name,
                          batch_shape=batch_shape2,
                          event_dim=get_event_dim(dist),
                          enable_vars=True))
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.evaluate([var.initializer for var in dist.variables])

        # Check that the distribution passes Variables through to the accessor
        # properties (without converting them to Tensor or anything like that).
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_ref(v):
                continue
            self.assertIs(getattr(dist, k), v)

        # Check that standard statistics do not read distribution parameters more
        # than once.
        for stat in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'covariance', 'entropy', 'mean', 'mode', 'stddev',
                        'variance'
                    ])),
                         min_size=3,
                         max_size=3)):
            logging.info('%s.%s', dist_name, stat)
            try:
                with tfp_hps.assert_no_excessive_var_usage(
                        'statistic `{}` of `{}`'.format(stat, dist)):
                    getattr(dist, stat)()

            except NotImplementedError:
                pass

        # Check that `sample` doesn't read distribution parameters more than once,
        # and that it produces non-None gradients (if the distribution is fully
        # reparameterized).
        with tf.GradientTape() as tape:
            # TDs do bijector assertions twice (once by distribution.sample, and once
            # by bijector.forward).
            max_permissible = (3 if isinstance(
                dist, tfd.TransformedDistribution) else 2)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `sample` of `{}`'.format(dist),
                    max_permissible=max_permissible):
                sample = dist.sample()
        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, dist.variables)
            for grad, var in zip(grads, dist.variables):
                var_name = var.name.rstrip('_0123456789:')
                if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
                    continue
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var_name, dist_name))

        # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
        # own samples.  Also, to relax conversion counts for KL (might do >2 w/
        # validate_args).
        dist = dist.copy(validate_args=False)
        dist2 = dist2.copy(validate_args=False)

        # Test that KL divergence reads distribution parameters at most once, and
        # that it produces non-None gradients.
        try:
            for d1, d2 in (dist, dist2), (dist2, dist):
                with tf.GradientTape() as tape:
                    with tfp_hps.assert_no_excessive_var_usage(
                            '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'
                            .format(d1, d1.variables, d2, d2.variables),
                            max_permissible=1
                    ):  # No validation => 1 convert per var.
                        kl = d1.kl_divergence(d2)
                wrt_vars = list(d1.variables) + list(d2.variables)
                grads = tape.gradient(kl, wrt_vars)
                for grad, var in zip(grads, wrt_vars):
                    if grad is None and dist_name not in NO_KL_PARAM_GRADS:
                        raise AssertionError(
                            'Missing KL({} || {}) -> {} grad:\n'
                            '{} vars: {}\n{} vars: {}'.format(
                                d1, d2, var, d1, d1.variables, d2,
                                d2.variables))
        except NotImplementedError:
            pass

        # Test that log_prob produces non-None gradients, except for distributions
        # on the NO_LOG_PROB_PARAM_GRADS blacklist.
        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            with tf.GradientTape() as tape:
                lp = dist.log_prob(tf.stop_gradient(sample))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))

        # Test that all forms of probability evaluation avoid reading distribution
        # parameters more than once.
        for evaluative in data.draw(
                hps.sets(hps.one_of(
                    map(hps.just, [
                        'log_prob', 'prob', 'log_cdf', 'cdf',
                        'log_survival_function', 'survival_function'
                    ])),
                         min_size=3,
                         max_size=3)):
            logging.info('%s.%s', dist_name, evaluative)
            try:
                # No validation => 1 convert. But for TD we allow 2:
                # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
                max_permissible = (2 if isinstance(
                    dist, tfd.TransformedDistribution) else 1)
                with tfp_hps.assert_no_excessive_var_usage(
                        'evaluative `{}` of `{}`'.format(evaluative, dist),
                        max_permissible=max_permissible):
                    getattr(dist, evaluative)(sample)
            except NotImplementedError:
                pass
Code example #16
  def testBijector(self, bijector_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
    event_dim = data.draw(hps.integers(min_value=2, max_value=6))
    bijector = data.draw(
        bijectors(bijector_name=bijector_name, event_dim=event_dim,
                  enable_vars=True))

    # Forward mapping: Check differentiation through forward mapping with
    # respect to the input and parameter variables.  Also check that any
    # variables are not referenced overmuch.
    # TODO(axch): Would be nice to get rid of all this shape inference logic and
    # just rely on a notion of batch and event shape for bijectors, so we can
    # pass those through `domain_tensors` and `codomain_tensors` and use
    # `tensors_in_support`.  However, `RationalQuadraticSpline` behaves weirdly
    # somehow and I got confused.
    shp = bijector.inverse_event_shape([event_dim] *
                                       bijector.inverse_min_event_ndims)
    shp = tensorshape_util.concatenate(
        data.draw(
            tfp_hps.broadcast_compatible_shape(
                shp[:shp.ndims - bijector.forward_min_event_ndims])),
        shp[shp.ndims - bijector.forward_min_event_ndims:])
    xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)), name='xs')
    wrt_vars = [xs] + [v for v in bijector.trainable_variables
                       if v.dtype.is_floating]
    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `forward` of {}'.format(bijector)):
        tape.watch(wrt_vars)
        # TODO(b/73073515): Fix graph mode gradients with bijector caching.
        ys = bijector.forward(xs + 0)
    grads = tape.gradient(ys, wrt_vars)
    assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

    # FLDJ: Check differentiation through forward log det jacobian with
    # respect to the input and parameter variables.  Also check that any
    # variables are not referenced overmuch.
    event_ndims = data.draw(
        hps.integers(
            min_value=bijector.forward_min_event_ndims,
            max_value=bijector.forward_event_shape(xs.shape).ndims))
    with tf.GradientTape() as tape:
      max_permitted = 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
      if is_invert(bijector):
        max_permitted = (2 if hasattr(bijector.bijector,
                                      '_inverse_log_det_jacobian') else 4)
      with tfp_hps.assert_no_excessive_var_usage(
          'method `forward_log_det_jacobian` of {}'.format(bijector),
          max_permissible=max_permitted):
        tape.watch(wrt_vars)
        # TODO(b/73073515): Fix graph mode gradients with bijector caching.
        ldj = bijector.forward_log_det_jacobian(xs + 0, event_ndims=event_ndims)
    grads = tape.gradient(ldj, wrt_vars)
    assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars, grads)

    # Inverse mapping: Check differentiation through inverse mapping with
    # respect to the codomain "input" and parameter variables.  Also check that
    # any variables are not referenced overmuch.
    shp = bijector.forward_event_shape([event_dim] *
                                       bijector.forward_min_event_ndims)
    shp = tensorshape_util.concatenate(
        data.draw(
            tfp_hps.broadcast_compatible_shape(
                shp[:shp.ndims - bijector.inverse_min_event_ndims])),
        shp[shp.ndims - bijector.inverse_min_event_ndims:])
    ys = tf.identity(
        data.draw(codomain_tensors(bijector, shape=shp)), name='ys')
    wrt_vars = [ys] + [v for v in bijector.trainable_variables
                       if v.dtype.is_floating]
    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `inverse` of {}'.format(bijector)):
        tape.watch(wrt_vars)
        # TODO(b/73073515): Fix graph mode gradients with bijector caching.
        xs = bijector.inverse(ys + 0)
    grads = tape.gradient(xs, wrt_vars)
    assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

    # ILDJ: Check differentiation through inverse log det jacobian with respect
    # to the codomain "input" and parameter variables.  Also check that any
    # variables are not referenced overmuch.
    event_ndims = data.draw(
        hps.integers(
            min_value=bijector.inverse_min_event_ndims,
            max_value=bijector.inverse_event_shape(ys.shape).ndims))
    with tf.GradientTape() as tape:
      max_permitted = 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4
      if is_invert(bijector):
        max_permitted = (2 if hasattr(bijector.bijector,
                                      '_forward_log_det_jacobian') else 4)
      with tfp_hps.assert_no_excessive_var_usage(
          'method `inverse_log_det_jacobian` of {}'.format(bijector),
          max_permissible=max_permitted):
        tape.watch(wrt_vars)
        # TODO(b/73073515): Fix graph mode gradients with bijector caching.
        xs = bijector.inverse_log_det_jacobian(ys + 0, event_ndims=event_ndims)
    grads = tape.gradient(xs, wrt_vars)
    assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)
Code example #17
    def testBijector(self, bijector_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        bijector, batch_shape = data.draw(
            bijectors(bijector_name=bijector_name, enable_vars=True))
        del batch_shape

        event_dim = data.draw(hps.integers(min_value=2, max_value=6))

        # Forward mapping.
        shp = bijector.inverse_event_shape([event_dim] *
                                           bijector.inverse_min_event_ndims)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.forward_min_event_ndims])),
            shp[shp.ndims - bijector.forward_min_event_ndims:])
        xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)),
                         name='xs')
        wrt_vars = [xs] + list(bijector.variables)
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # FLDJ.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=bijector.forward_event_shape(
                             xs.shape).ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_forward_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_inverse_log_det_jacobian') else 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping.
        shp = bijector.forward_event_shape([event_dim] *
                                           bijector.forward_min_event_ndims)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.inverse_min_event_ndims])),
            shp[shp.ndims - bijector.inverse_min_event_ndims:])
        ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                         name='ys')
        wrt_vars = [ys] + list(bijector.variables)
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=bijector.inverse_event_shape(
                             ys.shape).ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_inverse_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_forward_log_det_jacobian') else 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse_log_det_jacobian(ys + 0,
                                                       event_ndims=event_ndims)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)
Code example #18
    def testProcess(self, process_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        seed = tfp_test_util.test_seed()
        process = data.draw(
            stochastic_processes(process_name=process_name, enable_vars=True))
        self.evaluate([var.initializer for var in process.variables])

        # Check that the process passes Variables through to the accessor
        # properties (without converting them to Tensor or anything like that).
        for k, v in six.iteritems(process.parameters):
            if not tensor_util.is_ref(v):
                continue
            self.assertIs(getattr(process, k), v)

        # Check that standard statistics do not read process parameters more
        # than twice (once in the stat itself and up to once in any validation
        # assertions).
        for stat in ['mean', 'covariance', 'stddev', 'variance']:
            hp.note('Testing excessive var usage in {}.{}'.format(
                process_name, stat))
            try:
                with tfp_hps.assert_no_excessive_var_usage(
                        'statistic `{}` of `{}`'.format(stat, process),
                        max_permissible=excessive_usage_count(process_name)):
                    getattr(process, stat)()

            except NotImplementedError:
                pass

        # Check that `sample` doesn't read process parameters more than twice,
        # and that it produces non-None gradients (if the process is fully
        # reparameterized).
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `sample` of `{}`'.format(process),
                    max_permissible=excessive_usage_count(process_name)):
                sample = process.sample(seed=seed)
        if process.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, process.variables)
            for grad, var in zip(grads, process.variables):
                var_name = var.name.rstrip('_0123456789:')
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for process {}'.format(
                            var_name, process_name))

        # Test that log_prob produces non-None gradients.
        with tf.GradientTape() as tape:
            lp = process.log_prob(tf.stop_gradient(sample))
        grads = tape.gradient(lp, process.variables)
        for grad, var in zip(grads, process.variables):
            if grad is None:
                raise AssertionError(
                    'Missing log_prob -> {} grad for process {}'.format(
                        var, process_name))

        # Check that log_prob computations avoid reading process parameters
        # more than once.
        hp.note(
            'Testing excessive var usage in {}.log_prob'.format(process_name))
        try:
            with tfp_hps.assert_no_excessive_var_usage(
                    'evaluative `log_prob` of `{}`'.format(process),
                    max_permissible=excessive_usage_count(process_name)):
                process.log_prob(sample)
        except NotImplementedError:
            pass
Code example #19
  def testDistribution(self, dist_name, data):
    seed = test_util.test_seed()
    # Explicitly draw event_dim here to avoid relying on _params_event_ndims
    # later, so this test can support distributions that do not implement the
    # slicing protocol.
    event_dim = data.draw(hps.integers(min_value=2, max_value=6))
    dist = data.draw(dhps.distributions(
        dist_name=dist_name, event_dim=event_dim, enable_vars=True))
    batch_shape = dist.batch_shape
    batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
    dist2 = data.draw(
        dhps.distributions(
            dist_name=dist_name,
            batch_shape=batch_shape2,
            event_dim=event_dim,
            enable_vars=True))
    self.evaluate([var.initializer for var in dist.variables])

    # Check that the distribution passes Variables through to the accessor
    # properties (without converting them to Tensor or anything like that).
    for k, v in six.iteritems(dist.parameters):
      if not tensor_util.is_ref(v):
        continue
      self.assertIs(getattr(dist, k), v)

    # Check that standard statistics do not read distribution parameters more
    # than twice (once in the stat itself and up to once in any validation
    # assertions).
    max_permissible = 2 + extra_tensor_conversions_allowed(dist)
    for stat in sorted(data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ])),
            min_size=3,
            max_size=3))):
      hp.note('Testing excessive var usage in {}.{}'.format(dist_name, stat))
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'statistic `{}` of `{}`'.format(stat, dist),
            max_permissible=max_permissible):
          getattr(dist, stat)()

      except NotImplementedError:
        pass

    # Check that `sample` doesn't read distribution parameters more than twice,
    # and that it produces non-None gradients (if the distribution is fully
    # reparameterized).
    with tf.GradientTape() as tape:
      # TDs do bijector assertions twice (once by distribution.sample, and once
      # by bijector.forward).
      max_permissible = 2 + extra_tensor_conversions_allowed(dist)
      with tfp_hps.assert_no_excessive_var_usage(
          'method `sample` of `{}`'.format(dist),
          max_permissible=max_permissible):
        sample = dist.sample(seed=seed)
    if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
      grads = tape.gradient(sample, dist.variables)
      for grad, var in zip(grads, dist.variables):
        var_name = var.name.rstrip('_0123456789:')
        if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
          continue
        if grad is None:
          raise AssertionError(
              'Missing sample -> {} grad for distribution {}'.format(
                  var_name, dist_name))

    # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
    # own samples.  Also, to relax conversion counts for KL (might do >2 w/
    # validate_args).
    dist = dist.copy(validate_args=False)
    dist2 = dist2.copy(validate_args=False)

    # Test that KL divergence reads distribution parameters at most once, and
    # that it produces non-None gradients.
    try:
      for d1, d2 in (dist, dist2), (dist2, dist):
        with tf.GradientTape() as tape:
          with tfp_hps.assert_no_excessive_var_usage(
              '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                  d1, d1.variables, d2, d2.variables),
              max_permissible=1):  # No validation => 1 convert per var.
            kl = d1.kl_divergence(d2)
        wrt_vars = list(d1.variables) + list(d2.variables)
        grads = tape.gradient(kl, wrt_vars)
        for grad, var in zip(grads, wrt_vars):
          if grad is None and dist_name not in NO_KL_PARAM_GRADS:
            raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                                 '{} vars: {}\n{} vars: {}'.format(
                                     d1, d2, var, d1, d1.variables, d2,
                                     d2.variables))
    except NotImplementedError:
      pass

    # Test that log_prob produces non-None gradients, except for distributions
    # on the NO_LOG_PROB_PARAM_GRADS blacklist.
    if dist_name not in NO_LOG_PROB_PARAM_GRADS:
      with tf.GradientTape() as tape:
        lp = dist.log_prob(tf.stop_gradient(sample))
      grads = tape.gradient(lp, dist.variables)
      for grad, var in zip(grads, dist.variables):
        if grad is None:
          raise AssertionError(
              'Missing log_prob -> {} grad for distribution {}'.format(
                  var, dist_name))

    # Test that all forms of probability evaluation avoid reading distribution
    # parameters more than once.
    for evaluative in sorted(data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ])),
            min_size=3,
            max_size=3))):
      hp.note('Testing excessive var usage in {}.{}'.format(
          dist_name, evaluative))
      try:
        # No validation => 1 convert. But for TD we allow 2:
        # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
        max_permissible = 2 + extra_tensor_conversions_allowed(dist)
        with tfp_hps.assert_no_excessive_var_usage(
            'evaluative `{}` of `{}`'.format(evaluative, dist),
            max_permissible=max_permissible):
          getattr(dist, evaluative)(sample)
      except NotImplementedError:
        pass
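
This example budgets tensor conversions through `extra_tensor_conversions_allowed(dist)`, which the listing does not define. A hypothetical sketch, assumed to match the inline logic of example #15 (a TransformedDistribution gets one extra conversion because its bijector assertions run twice):

# Hypothetical sketch of `extra_tensor_conversions_allowed`; the real helper is
# not in this listing. The TransformedDistribution case mirrors example #15's
# inline `3 if isinstance(dist, tfd.TransformedDistribution) else 2`.
def extra_tensor_conversions_allowed(dist):
  if isinstance(dist, tfd.TransformedDistribution):
    # Bijector assertions run once in `sample` and once in `bijector.forward`.
    return 1
  return 0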
Code example #20
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        event_dim = data.draw(hps.integers(min_value=2, max_value=6))
        bijector = data.draw(
            bijectors(bijector_name=bijector_name,
                      event_dim=event_dim,
                      enable_vars=True))
        self.evaluate(tf.group(*[v.initializer for v in bijector.variables]))

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        # TODO(axch): Would be nice to get rid of all this shape inference logic and
        # just rely on a notion of batch and event shape for bijectors, so we can
        # pass those through `domain_tensors` and `codomain_tensors` and use
        # `tensors_in_support`.  However, `RationalQuadraticSpline` behaves weirdly
        # somehow and I got confused.
        codomain_event_shape = [event_dim] * bijector.inverse_min_event_ndims
        codomain_event_shape = constrain_inverse_shape(bijector,
                                                       codomain_event_shape)
        shp = bijector.inverse_event_shape(codomain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.forward_min_event_ndims])),
            shp[shp.ndims - bijector.forward_min_event_ndims:])
        xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)),
                         name='xs')
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_forward_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_inverse_log_det_jacobian') else 4)
            elif is_transform_diagonal(bijector):
                max_permitted = (2 if hasattr(bijector.diag_bijector,
                                              '_forward_log_det_jacobian') else
                                 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        domain_event_shape = [event_dim] * bijector.forward_min_event_ndims
        domain_event_shape = constrain_forward_shape(bijector,
                                                     domain_event_shape)
        shp = bijector.forward_event_shape(domain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.inverse_min_event_ndims])),
            shp[shp.ndims - bijector.inverse_min_event_ndims:])
        ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                         name='ys')
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_inverse_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_forward_log_det_jacobian') else 4)
            elif is_transform_diagonal(bijector):
                max_permitted = (2 if hasattr(bijector.diag_bijector,
                                              '_inverse_log_det_jacobian') else
                                 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)
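
The next example factors the log-det-Jacobian conversion budget into a `_ldj_tensor_conversions_allowed(bijector, is_forward)` helper. A sketch, assumed to be equivalent to the inline budget logic of examples #16, #17, and #20:

# Sketch of `_ldj_tensor_conversions_allowed`, assumed to reproduce the inline
# budgets of examples #16, #17 and #20: two conversions when the relevant
# `_*_log_det_jacobian` is implemented directly, otherwise four (the LDJ is
# then derived from the opposite direction, concretizing parameters again).
def _ldj_tensor_conversions_allowed(bijector, is_forward):
  attr = ('_forward_log_det_jacobian' if is_forward
          else '_inverse_log_det_jacobian')
  opposite = ('_inverse_log_det_jacobian' if is_forward
              else '_forward_log_det_jacobian')
  if is_invert(bijector):
    # `Invert` swaps forward and inverse, so look at the wrapped bijector.
    return 2 if hasattr(bijector.bijector, opposite) else 4
  if is_transform_diagonal(bijector):
    return 2 if hasattr(bijector.diag_bijector, attr) else 4
  return 2 if hasattr(bijector, attr) else 4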
Code example #21
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

        bijector, event_dim = self._draw_bijector(bijector_name, data)

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
        # of numerical problem.
        def exception(bijector):
            if not tfp_hps.running_under_guitar():
                return False
            if isinstance(bijector, tfb.Softfloor):
                return True
            if is_invert(bijector):
                return exception(bijector.bijector)
            return False

        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0
                and not exception(bijector)):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
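            # Wherever the slope is finite, `_internal_is_increasing` must agree
            # with the sign of dy/dx; a slope of exactly zero is accepted either
            # way.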
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through the forward log det Jacobian with
        # respect to the input and parameter variables.  Also check that no
        # variables are referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through the inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check
        # that no variables are referenced overmuch.
        ys = self._draw_codomain_tensor(bijector, data, event_dim)
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through the inverse log det Jacobian with
        # respect to the codomain "input" and parameter variables.  Also check
        # that no variables are referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)

        # Verify that `_is_permutation` implies constant zero Jacobian.
        if bijector._is_permutation:
            self.assertTrue(bijector._is_constant_jacobian)
            self.assertAllEqual(ldj, 0.)

        # Verify correctness of batch shape.
        xs_batch_shapes = tf.nest.map_structure(
            lambda x, nd: ps.shape(x)[:ps.rank(x) - nd], xs,
            bijector.inverse_event_ndims(event_ndims))
        empirical_batch_shape = functools.reduce(
            ps.broadcast_shape,
            nest.flatten_up_to(bijector.forward_min_event_ndims,
                               xs_batch_shapes))
        batch_shape = bijector.experimental_batch_shape(
            y_event_ndims=event_ndims)
        if tensorshape_util.is_fully_defined(batch_shape):
            self.assertAllEqual(empirical_batch_shape, batch_shape)
        self.assertAllEqual(
            empirical_batch_shape,
            bijector.experimental_batch_shape_tensor(
                y_event_ndims=event_ndims))

        # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
        # of the outputs of forward and inverse.
        self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
        self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
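The helper `_ldj_tensor_conversions_allowed` used in the FLDJ and ILDJ blocks above is not shown on this page. Judging from the inline `max_permitted` computation in the fragment at the top of this section, it presumably amounts to something like the sketch below (an approximation, not the exact helper from the test suite):

def _ldj_tensor_conversions_allowed(bijector, is_forward):
    # Sketch only: mirrors the inline `max_permitted` logic shown in the
    # fragment above.
    if is_invert(bijector):
        return _ldj_tensor_conversions_allowed(bijector.bijector, not is_forward)
    elif is_transform_diagonal(bijector):
        return _ldj_tensor_conversions_allowed(bijector.diag_bijector, is_forward)
    elif is_forward:
        return 2 if hasattr(bijector, '_forward_log_det_jacobian') else 4
    else:
        return 2 if hasattr(bijector, '_inverse_log_det_jacobian') else 4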
Code Example #22
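Like the bijector test above, the method below receives a Hypothesis `data` object and a parameterized `dist_name`, so it is presumably wrapped in decorators roughly like the following (an assumption based on TFP's property-test conventions; `DIST_NAMES` is a placeholder for whatever list of distribution names the suite iterates over):

  # @parameterized.named_parameters(
  #     {'testcase_name': dname, 'dist_name': dname} for dname in DIST_NAMES)
  # @hp.given(hps.data())
  # @tfp_hps.tfp_hp_settings()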
  def testDistribution(self, dist_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
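    # Seed the graph-level RNG with a Hypothesis-drawn, nonzero scalar.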
    tf1.set_random_seed(
        data.draw(
            hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
    dist, batch_shape = data.draw(
        distributions(dist_name=dist_name, enable_vars=True))
    batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
    dist2, _ = data.draw(
        distributions(
            dist_name=dist_name,
            batch_shape=batch_shape2,
            event_dim=get_event_dim(dist),
            enable_vars=True))
    del batch_shape
    logging.info(
        'distribution: %s; parameters used: %s', dist,
        [k for k, v in six.iteritems(dist.parameters) if v is not None])
    self.evaluate([var.initializer for var in dist.variables])
    for k, v in six.iteritems(dist.parameters):
      if not tensor_util.is_mutable(v):
        continue
      try:
        self.assertIs(getattr(dist, k), v)
      except AssertionError as e:
        raise AssertionError(
            'No attr found for parameter {} of distribution {}: \n{}'.format(
                k, dist_name, e))

    for stat in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ])),
            min_size=3,
            max_size=3)):
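      # `stat` is one of three statistics drawn from the six listed above for
      # this Hypothesis run.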
      logging.info('%s.%s', dist_name, stat)
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'statistic `{}` of `{}`'.format(stat, dist)):
          getattr(dist, stat)()

      except NotImplementedError:
        pass

    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `sample` of `{}`'.format(dist)):
        sample = dist.sample()
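    # Only fully reparameterized distributions are expected to propagate
    # gradients from sample() back to their parameter variables.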
    if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
      grads = tape.gradient(sample, dist.variables)
      for grad, var in zip(grads, dist.variables):
        var_name = var.name.rstrip('_0123456789:')
        if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
          continue
        if grad is None:
          raise AssertionError(
              'Missing sample -> {} grad for distribution {}'.format(
                  var_name, dist_name))

    # Turn off validation, since log_prob can choke on the dist's own samples.
    # This also relaxes the conversion counts for KL (which might do >2 with
    # validate_args).
    dist = dist.copy(validate_args=False)
    dist2 = dist2.copy(validate_args=False)

    try:
      for d1, d2 in (dist, dist2), (dist2, dist):
        with tf.GradientTape() as tape:
          with tfp_hps.assert_no_excessive_var_usage(
              '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                  d1, d1.variables, d2, d2.variables),
              max_permissible=1):  # No validation => 1 convert per var.
            kl = d1.kl_divergence(d2)
        wrt_vars = list(d1.variables) + list(d2.variables)
        grads = tape.gradient(kl, wrt_vars)
        for grad, var in zip(grads, wrt_vars):
          if grad is None and dist_name not in NO_KL_PARAM_GRADS:
            raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                                 '{} vars: {}\n{} vars: {}'.format(
                                     d1, d2, var, d1, d1.variables, d2,
                                     d2.variables))
    except NotImplementedError:
      pass

    if dist_name not in NO_LOG_PROB_PARAM_GRADS:
      with tf.GradientTape() as tape:
        lp = dist.log_prob(tf.stop_gradient(sample))
      grads = tape.gradient(lp, dist.variables)
      for grad, var in zip(grads, dist.variables):
        if grad is None:
          raise AssertionError(
              'Missing log_prob -> {} grad for distribution {}'.format(
                  var, dist_name))

    for evaluative in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ])),
            min_size=3,
            max_size=3)):
      logging.info('%s.%s', dist_name, evaluative)
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'evaluative `{}` of `{}`'.format(evaluative, dist),
            max_permissible=1):  # No validation => 1 convert
          getattr(dist, evaluative)(sample)
      except NotImplementedError:
        pass
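The whitelists `NO_SAMPLE_PARAM_GRADS`, `NO_KL_PARAM_GRADS`, and `NO_LOG_PROB_PARAM_GRADS` consulted above are defined elsewhere in the test module. From the way they are used (a `.get(dist_name, ())` lookup and plain membership checks), they presumably look roughly like the sketch below; the entries shown are illustrative guesses, not the real lists:

# Maps distribution name -> parameter-variable name prefixes whose sample()
# gradients are known to be missing and therefore skipped.
NO_SAMPLE_PARAM_GRADS = {
    'Deterministic': ('atol', 'rtol'),  # illustrative entry only
}

# Distribution names for which missing KL / log_prob parameter gradients are
# tolerated; only membership is checked, so simple containers suffice.
NO_KL_PARAM_GRADS = ('Deterministic',)        # illustrative entry only
NO_LOG_PROB_PARAM_GRADS = ('Deterministic',)  # illustrative entry only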