Example #1
    def testGradientsThroughSample(self, process_name, data):
        tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess',
                                       process_name, 'b/147770193')
        process = data.draw(
            stochastic_processes(process_name=process_name, enable_vars=True))
        self.evaluate([var.initializer for var in process.variables])

        with tf.GradientTape() as tape:
            sample = process.sample()
        if process.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, process.variables)
            for grad, var in zip(grads, process.variables):
                self.assertIsNotNone(
                    grad, 'Grad of sample was `None` for var: {}.'.format(var))
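
A minimal standalone sketch of the same pattern, assuming eager TF2 and using tfd.Normal (fully reparameterized, with variable parameters) as a stand-in for the Hypothesis-drawn process:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

loc = tf.Variable(0.0)
scale = tf.Variable(1.0)
dist = tfd.Normal(loc=loc, scale=scale)

# Tape through the reparameterized sample and differentiate w.r.t. the
# distribution's variables; every gradient should be non-None.
with tf.GradientTape() as tape:
    sample = dist.sample()
grads = tape.gradient(sample, dist.trainable_variables)
assert dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED
assert all(g is not None for g in grads)
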
Example #2
    def testGradientsThroughLogProb(self, process_name, data):
        tfp_hps.guitar_skip_if_matches('VariationalGaussianProcess',
                                       process_name, 'b/147770193')
        process = data.draw(
            stochastic_processes(process_name=process_name, enable_vars=True))
        self.evaluate([var.initializer for var in process.variables])

        # Test that log_prob produces non-None gradients.
        sample = process.sample()
        with tf.GradientTape() as tape:
            lp = process.log_prob(sample)
        grads = tape.gradient(lp, process.variables)
        for grad, var in zip(grads, process.variables):
            self.assertIsNotNone(
                grad, 'Grad of log_prob was `None` for var: {}.'.format(var))
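
The same log_prob check can be reproduced for a concrete stochastic process; the sketch below assumes a small tfd.GaussianProcess with trainable kernel variables (the parameter values are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

amplitude = tf.Variable(1.0)
length_scale = tf.Variable(0.5)
kernel = psd_kernels.ExponentiatedQuadratic(amplitude, length_scale)
index_points = tf.reshape(tf.linspace(-1., 1., 5), [5, 1])
gp = tfd.GaussianProcess(kernel, index_points=index_points)

# log_prob of a sample should produce a non-None gradient for every variable.
sample = gp.sample()
with tf.GradientTape() as tape:
    lp = gp.log_prob(sample)
grads = tape.gradient(lp, gp.trainable_variables)
assert all(g is not None for g in grads)
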
Example #3
  def testExcessiveConcretizationInLogProb(self, process_name, data):
    # Check that log_prob computations avoid reading process parameters
    # more than once.
    tfp_hps.guitar_skip_if_matches(
        'VariationalGaussianProcess', process_name, 'b/147770193')
    process = data.draw(stochastic_processes(
        process_name=process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    hp.note('Testing excessive var usage in {}.log_prob'.format(process_name))
    sample = process.sample()
    try:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `log_prob` of `{}`'.format(process),
          max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
        process.log_prob(sample)
    except NotImplementedError:
      pass
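
The "excessive concretization" this test guards against can be made visible with tfp.util.DeferredTensor, whose transform runs each time the parameter is converted to a tensor. The eager-mode sketch below counts how often a deferred scale is concretized by log_prob; the counting hook is purely illustrative and is not the mechanism tfp_hps.assert_no_excessive_var_usage uses:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

conversions = []

def counted_softplus(x):
    conversions.append(1)  # one concretization of the deferred parameter
    return tf.math.softplus(x)

raw_scale = tf.Variable(0.54)
scale = tfp.util.DeferredTensor(raw_scale, counted_softplus)
dist = tfd.Normal(loc=0., scale=scale)

conversions.clear()
dist.log_prob(0.25)
print('scale concretized %d time(s) by log_prob' % len(conversions))
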
Example #4
  def testExcessiveConcretizationInZeroArgPublicMethods(
      self, process_name, data):
    tfp_hps.guitar_skip_if_matches(
        'VariationalGaussianProcess', process_name, 'b/147770193')
    # Check that standard statistics do not concretize variables/deferred
    # tensors more than the allowed amount.
    process = data.draw(stochastic_processes(process_name, enable_vars=True))
    self.evaluate([var.initializer for var in process.variables])

    for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
      hp.note('Testing excessive concretization in {}.{}'.format(process_name,
                                                                 stat))
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'method `{}` of `{}`'.format(stat, process),
            max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
          getattr(process, stat)()

      except NotImplementedError:
        pass
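
The loop over zero-argument statistics can be exercised the same way; this sketch reuses the deferred-parameter idea above on a plain tfd.Normal and simply skips statistics the distribution does not implement:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

reads = []

def counted_exp(x):
    reads.append(1)  # one concretization of the deferred parameter
    return tf.math.exp(x)

scale = tfp.util.DeferredTensor(tf.Variable(-0.6), counted_exp)
dist = tfd.Normal(loc=0., scale=scale)

for stat in ['mean', 'covariance', 'stddev', 'variance', 'sample']:
    reads.clear()
    try:
        getattr(dist, stat)()
    except NotImplementedError:
        continue
    print('%s concretized scale %d time(s)' % (stat, len(reads)))
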
Example #5
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

        bijector, event_dim = self._draw_bijector(bijector_name, data)

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
        # of numerical problem.
        def exception(bijector):
            if not tfp_hps.running_under_guitar():
                return False
            if isinstance(bijector, tfb.Softfloor):
                return True
            if is_invert(bijector):
                return exception(bijector.bijector)
            return False

        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0
                and not exception(bijector)):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        ys = self._draw_codomain_tensor(bijector, data, event_dim)
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)

        # Verify that `_is_permutation` implies constant zero Jacobian.
        if bijector._is_permutation:
            self.assertTrue(bijector._is_constant_jacobian)
            self.assertAllEqual(ldj, 0.)

        # Verify correctness of batch shape.
        xs_batch_shapes = tf.nest.map_structure(
            lambda x, nd: ps.shape(x)[:ps.rank(x) - nd], xs,
            bijector.inverse_event_ndims(event_ndims))
        empirical_batch_shape = functools.reduce(
            ps.broadcast_shape,
            nest.flatten_up_to(bijector.forward_min_event_ndims,
                               xs_batch_shapes))
        batch_shape = bijector.experimental_batch_shape(
            y_event_ndims=event_ndims)
        if tensorshape_util.is_fully_defined(batch_shape):
            self.assertAllEqual(empirical_batch_shape, batch_shape)
        self.assertAllEqual(
            empirical_batch_shape,
            bijector.experimental_batch_shape_tensor(
                y_event_ndims=event_ndims))

        # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
        # of the outputs of forward and inverse.
        self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
        self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
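
Outside the test harness, the forward and FLDJ gradient checks reduce to a short pattern; the sketch below assumes tfb.Scale with a variable parameter as a stand-in for the Hypothesis-drawn bijector:

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

scale = tf.Variable(2.0)
bij = tfb.Scale(scale)  # scalar bijector: forward_min_event_ndims == 0
x = tf.constant([0.3, 1.7])

# Forward mapping: gradients w.r.t. the input and the parameter variable.
with tf.GradientTape() as tape:
    tape.watch(x)
    y = bij.forward(x)
grads = tape.gradient(y, [x, scale])
assert all(g is not None for g in grads)

# FLDJ: gradient of forward_log_det_jacobian w.r.t. the parameter variable.
with tf.GradientTape() as tape:
    ldj = bij.forward_log_det_jacobian(x, event_ndims=0)
assert tape.gradient(ldj, scale) is not None
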
Example #6
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        event_dim = data.draw(hps.integers(min_value=2, max_value=6))
        bijector = data.draw(
            bijectors(bijector_name=bijector_name,
                      event_dim=event_dim,
                      enable_vars=True))
        self.evaluate(tf.group(*[v.initializer for v in bijector.variables]))

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        # TODO(axch): Would be nice to get rid of all this shape inference logic and
        # just rely on a notion of batch and event shape for bijectors, so we can
        # pass those through `domain_tensors` and `codomain_tensors` and use
        # `tensors_in_support`.  However, `RationalQuadraticSpline` behaves weirdly
        # somehow and I got confused.
        codomain_event_shape = [event_dim] * bijector.inverse_min_event_ndims
        codomain_event_shape = constrain_inverse_shape(bijector,
                                                       codomain_event_shape)
        shp = bijector.inverse_event_shape(codomain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.forward_min_event_ndims])),
            shp[shp.ndims - bijector.forward_min_event_ndims:])
        xs = tf.identity(data.draw(domain_tensors(bijector, shape=shp)),
                         name='xs')
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_forward_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_inverse_log_det_jacobian') else 4)
            elif is_transform_diagonal(bijector):
                max_permitted = (2 if hasattr(bijector.diag_bijector,
                                              '_forward_log_det_jacobian') else
                                 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        domain_event_shape = [event_dim] * bijector.forward_min_event_ndims
        domain_event_shape = constrain_forward_shape(bijector,
                                                     domain_event_shape)
        shp = bijector.forward_event_shape(domain_event_shape)
        shp = tensorshape_util.concatenate(
            data.draw(
                tfp_hps.broadcast_compatible_shape(
                    shp[:shp.ndims - bijector.inverse_min_event_ndims])),
            shp[shp.ndims - bijector.inverse_min_event_ndims:])
        ys = tf.identity(data.draw(codomain_tensors(bijector, shape=shp)),
                         name='ys')
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = 2 if hasattr(bijector,
                                         '_inverse_log_det_jacobian') else 4
            if is_invert(bijector):
                max_permitted = (2 if hasattr(
                    bijector.bijector, '_forward_log_det_jacobian') else 4)
            elif is_transform_diagonal(bijector):
                max_permitted = (2 if hasattr(bijector.diag_bijector,
                                              '_inverse_log_det_jacobian') else
                                 4)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)
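
The scalar _is_increasing consistency check amounts to comparing the sign of dy/dx against the bijector's claimed monotonicity; here is a minimal sketch using tfb.Softplus, assumed as a representative monotonically increasing scalar bijector:

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Softplus()  # scalar bijector, monotonically increasing
x = tf.constant([-2.0, 0.0, 3.0])

with tf.GradientTape() as tape:
    tape.watch(x)
    y = bij.forward(x)
dydx = tape.gradient(y, x)

# Where finite, the slope of an increasing scalar bijector is non-negative.
finite = tf.math.is_finite(dydx)
tf.debugging.assert_non_negative(tf.boolean_mask(dydx, finite))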