Exemplo n.º 1
0
    def testMatchWithAffineTransform(self):
        """Tanh agrees with the affine-sigmoid identity 2 * sigmoid(2x) - 1."""
        tanh = tfb.Tanh()
        f64 = tf.float64
        # `Chain` applies its bijectors right-to-left: scale by 2, sigmoid,
        # scale by 2, then shift by -1.
        affine_sigmoid = tfb.Chain([
            tfb.Shift(tf.cast(-1.0, dtype=f64)),
            tfb.Scale(tf.cast(2.0, dtype=f64)),
            tfb.Sigmoid(),
            tfb.Scale(tf.cast(2.0, dtype=f64)),
        ])

        xs = np.linspace(-3.0, 3.0, 100)
        ys = np.tanh(xs)

        def check(method_name, value, **kwargs):
            # Evaluate the direct bijector first, then the composed one.
            direct = self.evaluate(getattr(tanh, method_name)(value, **kwargs))
            composed = self.evaluate(
                getattr(affine_sigmoid, method_name)(value, **kwargs))
            self.assertAllClose(direct, composed)

        check('forward', xs)
        check('inverse', ys)
        check('inverse_log_det_jacobian', ys, event_ndims=0)
        check('forward_log_det_jacobian', xs, event_ndims=0)
 def testTransformedKLDifferentBijectorFails(self):
     """KL between transformed distributions with unequal bijectors raises."""
     def scaled_exponential(factor):
         return self._cls()(tfd.Exponential(rate=0.25),
                            bijector=tfb.Scale(scale=factor),
                            validate_args=True)

     first = scaled_exponential(2.)
     second = scaled_exponential(3.)
     with self.assertRaisesRegex(NotImplementedError,
                                 r'their bijectors are not equal'):
         tfd.kl_divergence(first, second)
Exemplo n.º 3
0
  def testMixedDtypeLogDetJacobian(self):
    """Per-part LDJs of float16/32/64 scales are combined into a float64."""
    scale_by_key = {
        'a': tf.constant(1, dtype=tf.float16),
        'b': tf.constant(2, dtype=tf.float32),
        'c': tf.constant(3, dtype=tf.float64),
    }
    bij = tfb.JointMap(
        {key: tfb.Scale(value) for key, value in scale_by_key.items()})

    fldj = bij.forward_log_det_jacobian(
        x=dict(a=4, b=5, c=6),
        event_ndims={key: 0 for key in 'abc'})
    self.assertDTypeEqual(fldj, np.float64)
    expected = np.log(1) + np.log(2) + np.log(3)
    self.assertAllClose(expected, self.evaluate(fldj))
Exemplo n.º 4
0
 def testCdfDescendingChained(self):
   """CDF stays increasing through chains mixing positive and negative scales."""
   def shifted_scale(scale):
     return tfb.Shift(shift=1.)(tfb.Scale(scale=scale))

   bij1 = shifted_scale([1., -2.])
   bij2 = shifted_scale([[3.], [-5.]])
   bij3 = shifted_scale([[[7.]], [[-11.]]])
   candidate_chains = (bij2(bij1), bij3(bij2(bij1)))
   for chain in candidate_chains:
     td = self._cls()(
         distribution=tfd.Normal(loc=0., scale=tf.ones([2, 2, 2])),
         bijector=chain,
         validate_args=True)
     reference = tfd.Normal(loc=1., scale=3., validate_args=True)
     lower_cdf = td.cdf(reference.quantile(.4))
     upper_cdf = td.cdf(reference.quantile(.6))
     self.assertAllEqual(tf.ones(td.batch_shape, dtype=tf.bool),
                         lower_cdf < upper_cdf,
                         msg=chain.name)
Exemplo n.º 5
0
 def testTinyScale(self, dtype):
   """`log_scale` keeps the LDJ accurate where `exp(log_scale)` saturates."""
   log_scale = tf.cast(-2000., dtype)
   x = tf.cast(1., dtype)

   def fldj_for(**scale_kwargs):
     return tfb.Scale(**scale_kwargs).forward_log_det_jacobian(
         x, event_ndims=0)

   linear_value, log_value = self.evaluate([
       fldj_for(scale=tf.math.exp(log_scale)),
       fldj_for(log_scale=log_scale),
   ])
   # exp(-2000) saturates to 0, so the linear parameterization cannot
   # reproduce the true log-det-Jacobian of -2000.
   self.assertNotEqual(linear_value, log_value)
   self.assertAllClose(-2000., log_value)
Exemplo n.º 6
0
    def test_end_to_end_works_correctly(self):
        """HMC run through per-part `Scale` bijectors recovers a 2-D Gaussian.

        Wraps `HamiltonianMonteCarlo` in `TransformedTransitionKernel`, samples
        an unnormalized bivariate normal target with `sample_chain`, and checks
        the acceptance rate, sample mean and sample covariance against the
        known `true_mean` / `true_cov`.
        """
        true_mean = self.dtype([0, 0])
        true_cov = self.dtype([[1, 0.5], [0.5, 1]])
        num_results = 500

        def target_log_prob(x, y):
            # Corresponds to unnormalized MVN.
            # z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
            z = tf.stack([x, y], axis=-1) - true_mean
            z = tf.squeeze(tf.linalg.triangular_solve(
                np.linalg.cholesky(true_cov), z[..., tf.newaxis]),
                           axis=-1)
            return -0.5 * tf.reduce_sum(z**2., axis=-1)

        transformed_hmc = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=tf.function(target_log_prob,
                                               autograph=False),
                # Affine scaling means we have to change the step_size
                # in order to get 60% acceptance, as was done in mcmc/hmc_test.py.
                step_size=[1.23 / 0.75, 1.23 / 0.5],
                num_leapfrog_steps=2),
            # One bijector per state part (x and y).
            bijector=[
                tfb.Scale(scale=0.75),
                tfb.Scale(scale=0.5),
            ])
        # Recall, tfp.mcmc.sample_chain calls
        # transformed_hmc.bootstrap_results too.
        states, kernel_results = tfp.mcmc.sample_chain(
            num_results=num_results,
            # The initial state is used by inner_kernel.bootstrap_results.
            # Note the input is *after* `bijector.forward`.
            current_state=[self.dtype(-2), self.dtype(2)],
            kernel=transformed_hmc,
            num_burnin_steps=200,
            num_steps_between_results=1,
            seed=test_util.test_seed())
        # Stack the per-part chains along a new last axis; the leading axis is
        # checked to be `num_results` just below.
        states = tf.stack(states, axis=-1)
        self.assertEqual(num_results,
                         tf.compat.dimension_value(states.shape[0]))
        sample_mean = tf.reduce_mean(states, axis=0)
        x = states - sample_mean
        # Covariance of the centered draws, normalized by N (not N - 1).
        sample_cov = tf.matmul(x, x,
                               transpose_a=True) / self.dtype(num_results)
        [sample_mean_, sample_cov_, is_accepted_] = self.evaluate([
            sample_mean, sample_cov, kernel_results.inner_results.is_accepted
        ])
        # Tolerances are loose because the estimates come from a finite chain
        # of `num_results` draws.
        self.assertAllClose(0.6, is_accepted_.mean(), atol=0.15, rtol=0.)
        self.assertAllClose(sample_mean_, true_mean, atol=0.2, rtol=0.)
        self.assertAllClose(sample_cov_, true_cov, atol=0., rtol=0.4)
 def test_dist_fn_takes_kwargs(self):
   """Distribution-making callables may receive their parents via **kwargs."""
   def make_b(**kwargs):
     return tfd.Normal(loc=kwargs['negative'],
                       scale=kwargs['positive'],
                       validate_args=True)

   def make_a(**kwargs):
     scale_by_b = tfb.Scale(kwargs['b'])
     return scale_by_b(tfd.Gamma(concentration=-kwargs['negative'],
                                 rate=kwargs['positive'],
                                 validate_args=True))

   model = tfd.JointDistributionNamed(
       {'positive': tfd.Exponential(rate=1.),
        'negative': tfb.Scale(-1.)(tfd.Exponential(rate=1.)),
        'b': make_b,
        'a': make_a,
       }, validate_args=True)
   draws = model.sample(5, seed=test_util.test_seed())
   log_probs = model.log_prob(draws)
   self.assertAllEqual(log_probs.shape, [5])
Exemplo n.º 8
0
    def test_nested_transform(self):
        """Nested transformed kernels behave like one kernel with a `Chain`.

        Builds HMC wrapped in `TransformedTransitionKernel(b1)`, wraps that
        again in `TransformedTransitionKernel(b2)`, and compares its bootstrap
        results and a single seeded step against HMC wrapped once with
        `Chain([b2, b1])`.
        """
        target_dist = tfd.Normal(loc=0., scale=1.)
        b1 = tfb.Scale(0.5)
        b2 = tfb.Exp()
        chain = tfb.Chain([b2, b1
                           ])  # applies bijectors right to left (b1 then b2).
        inner_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_dist.log_prob,
                num_leapfrog_steps=27,
                step_size=10),
            bijector=b1)
        outer_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=inner_kernel, bijector=b2)
        # Same HMC settings as the nested version, but a single Chain bijector.
        chain_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_dist.log_prob,
                num_leapfrog_steps=27,
                step_size=10),
            bijector=chain)
        outer_pkr_one, outer_pkr_two = self.evaluate([
            outer_kernel.bootstrap_results(2.),
            outer_kernel.bootstrap_results(9.),
        ])

        # the outermost kernel only applies the outermost bijector
        # (b2 is Exp, so the transformed state is log of the input state).
        self.assertNear(np.log(2.), outer_pkr_one.transformed_state, err=1e-6)
        self.assertNear(np.log(9.), outer_pkr_two.transformed_state, err=1e-6)

        chain_pkr_one, chain_pkr_two = self.evaluate([
            chain_kernel.bootstrap_results(2.),
            chain_kernel.bootstrap_results(9.),
        ])

        # all bijectors are applied to the inner kernel, from innermost to outermost
        # this behavior is completely analogous to a bijector Chain
        self.assertNear(chain_pkr_one.transformed_state,
                        outer_pkr_one.inner_results.transformed_state,
                        err=1e-6)
        self.assertEqual(
            chain_pkr_one.inner_results.accepted_results,
            outer_pkr_one.inner_results.inner_results.accepted_results)
        self.assertNear(chain_pkr_two.transformed_state,
                        outer_pkr_two.inner_results.transformed_state,
                        err=1e-6)
        self.assertEqual(
            chain_pkr_two.inner_results.accepted_results,
            outer_pkr_two.inner_results.inner_results.accepted_results)

        # One stateless seed is shared by both kernels so their single steps
        # can be compared directly.
        seed = test_util.test_seed(sampler_type='stateless')
        outer_results_one, outer_results_two = self.evaluate([
            outer_kernel.one_step(2., outer_pkr_one, seed=seed),
            outer_kernel.one_step(9., outer_pkr_two, seed=seed)
        ])
        chain_results_one, chain_results_two = self.evaluate([
            chain_kernel.one_step(2., chain_pkr_one, seed=seed),
            chain_kernel.one_step(9., chain_pkr_two, seed=seed)
        ])
        # Element [0] of each `one_step` result is compared.
        self.assertNear(chain_results_one[0], outer_results_one[0], err=1e-6)
        self.assertNear(chain_results_two[0], outer_results_two[0], err=1e-6)
Exemplo n.º 9
0
  def testCompositeTensor(self):
    """A `JointMap` bijector round-trips through the CompositeTensor APIs."""
    bij = tfb.JointMap(
        bijectors=[tfb.Exp(), tfb.Softplus(), tfb.Scale(scale=2.)])
    self.assertIsInstance(bij, tf.__internal__.CompositeTensor)

    # Decompose into component `Tensor`s, then reassemble from them.
    components = tf.nest.flatten(bij, expand_composites=True)
    rebuilt = tf.nest.pack_sequence_as(bij, components, expand_composites=True)
    self.assertIsInstance(rebuilt, tfb.JointMap)

    # The reassembled bijector can be handed to a `tf.function`.
    @tf.function
    def call_forward(bij, x):
      return bij.forward(x)

    inputs = [1., 2., 3.]
    self.assertAllClose(call_forward(rebuilt, inputs), bij.forward(inputs))

    # The type spec survives an encode/decode round trip.
    coder = tf.__internal__.saved_model.StructureCoder()
    encoded = coder.encode_structure(bij._type_spec)
    decoded = coder.decode_proto(encoded)
    self.assertEqual(bij._type_spec, decoded)
Exemplo n.º 10
0
  def test_single_part_str_repr_match_expected(self):
    """`str` and `repr` of several bijectors match their golden strings."""
    cases = [
        (tfb.Exp(),
         'tfp.bijectors.Exp("exp", batch_shape=[], min_event_ndims=0)',
         "<tfp.bijectors.Exp 'exp' batch_shape=[] forward_min_event_ndims=0 "
         "inverse_min_event_ndims=0 dtype_x=? dtype_y=?>"),
        (tfb.Scale([1., 1.], name='myscale'),
         'tfp.bijectors.Scale("myscale", batch_shape=[2], min_event_ndims=0, '
         'dtype=float32)',
         "<tfp.bijectors.Scale 'myscale' batch_shape=[2] "
         "forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=float32 "
         "dtype_y=float32>"),
        (tfb.Split([3, 4, 2], name='s_p_l_i_t'),
         'tfp.bijectors.Split("s_p_l_i_t", batch_shape=[], '
         'forward_min_event_ndims=1, inverse_min_event_ndims=[1, 1, 1])',
         "<tfp.bijectors.Split 's_p_l_i_t' batch_shape=[] "
         "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
         "dtype_x=? dtype_y=[?, ?, ?]>"),
    ]
    for bij, expected_str, expected_repr in cases:
      self.assertContainsInOrder([expected_str], str(bij))
      self.assertContainsInOrder([expected_repr], repr(bij))
Exemplo n.º 11
0
 def testNoBatchScale(self, is_static, dtype):
     """A scalar `scale` multiplies/divides elementwise; ILDJ is -log(scale)."""
     bijector = tfb.Scale(scale=dtype(2.))
     inputs = self.maybe_static(np.array([1., 2, 3], dtype))
     expected_forward = [2., 4, 6]
     expected_inverse = [.5, 1, 1.5]
     self.assertAllClose(expected_forward, bijector.forward(inputs))
     self.assertAllClose(expected_inverse, bijector.inverse(inputs))
     ildj = bijector.inverse_log_det_jacobian(inputs, event_ndims=0)
     self.assertAllClose(-np.log(2.), ildj)
Exemplo n.º 12
0
 def testModifiedVariableScaleAssertion(self):
     """Validation catches a `scale` Variable that is later assigned zero."""
     scale_var = tf.Variable(1.)
     self.evaluate(scale_var.initializer)
     bij = tfb.Scale(scale=scale_var, validate_args=True)
     expected_error = 'Argument `scale` must be non-zero'
     with self.assertRaisesOpError(expected_error), \
          tf.control_dependencies([scale_var.assign(0.)]):
         _ = self.evaluate(bij.forward(1.))
  def test_unknown_event_rank(self):
    """Shapes are statically unknown but resolve at runtime for unknown rank."""
    if tf.executing_eagerly():
      self.skipTest('Eager execution.')

    fully_unknown = tf.TensorShape(None)

    def check_shapes(dist, expected_batch, expected_event):
      self.assertEqual(dist.batch_shape, fully_unknown)
      self.assertEqual(dist.event_shape, fully_unknown)
      self.assertAllEqual(dist.batch_shape_tensor(), expected_batch)
      self.assertAllEqual(dist.event_shape_tensor(), expected_event)

    # `reinterpreted_batch_ndims` is a placeholder, so the batch/event split is
    # not known statically.
    unknown_rank_dist = tfd.Independent(
        tfd.Normal(loc=tf.ones([2, 1, 3]), scale=2.),
        reinterpreted_batch_ndims=tf1.placeholder_with_default(1, shape=[]))

    td = tfd.TransformedDistribution(
        distribution=unknown_rank_dist,
        bijector=tfb.Scale(1.),
        validate_args=True)
    check_shapes(td, [2, 1], [3])

    joint_td = tfd.TransformedDistribution(
        distribution=tfd.JointDistributionSequentialAutoBatched(
            [unknown_rank_dist, unknown_rank_dist]),
        bijector=tfb.Invert(tfb.Split(2)),
        validate_args=True)
    # Note that the current behavior is conservative; we could also correctly
    # return a batch shape of `[]` in this case.
    check_shapes(joint_td, [], [2, 1, 6])
Exemplo n.º 14
0
 def testScalarCongruency(self, dtype):
   """`Scale(0.42)` passes the generic scalar-congruency check on [-2, 2]."""
   scale_bijector = tfb.Scale(scale=dtype(0.42))
   lower, upper = dtype(-2.), dtype(2.)
   bijector_test_util.assert_scalar_congruency(
       scale_bijector,
       lower_x=lower,
       upper_x=upper,
       eval_func=self.evaluate)
    def testStddev(self):
        """Stddev/variance follow affine and split bijectors; TriL raises."""
        base_stddev = 2.
        shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
        scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
        expected_stddev = tf.abs(base_stddev * scale)

        make_dist = self._cls()
        base = tfd.Normal(loc=tf.zeros_like(shift),
                          scale=base_stddev * tf.ones_like(scale),
                          validate_args=True)
        affine = tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                           validate_args=True)
        normal = make_dist(distribution=base, bijector=affine,
                           validate_args=True)
        self.assertAllClose(expected_stddev, normal.stddev())
        self.assertAllClose(expected_stddev**2, normal.variance())

        def as_vector():
            # Fresh view of `normal` with its last batch axis as the event.
            return tfd.Independent(normal, reinterpreted_batch_ndims=1)

        split_normal = make_dist(distribution=as_vector(),
                                 bijector=tfb.Split(3),
                                 validate_args=True)
        self.assertAllCloseNested(
            tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
            split_normal.stddev())

        tril = [[1., 0.], [-1., 2.]]
        scaled_normal = make_dist(distribution=as_vector(),
                                  bijector=tfb.ScaleMatvecTriL(tril),
                                  validate_args=True)
        with self.assertRaisesRegex(NotImplementedError,
                                    'is a multivariate transformation'):
            scaled_normal.stddev()
Exemplo n.º 16
0
 def test_bijector_constant_underlying_ildj(self):
     """A constant underlying ILDJ is accumulated over the 3 sample draws."""
     scaled = tfb.Scale([2., 3.])(tfd.Normal([0., 0.], 1.))
     event_space_bij = (
         tfd.Sample(scaled, [3]).experimental_default_event_space_bijector())
     zeros = tf.zeros([2, 3])
     log_scales = np.log([2., 3.])

     ildj_rank1 = event_space_bij.inverse_log_det_jacobian(zeros, event_ndims=1)
     self.assertAllClose(-log_scales * 3, ildj_rank1)

     ildj_rank2 = event_space_bij.inverse_log_det_jacobian(zeros, event_ndims=2)
     self.assertAllClose(-log_scales.sum() * 3, ildj_rank2)
Exemplo n.º 17
0
    def testExcessiveConcretizationOfParamsBatchShapeOverride(self):
        """Limits variable reads by CDF-family methods under a batch override.

        Only `batch_shape` is overridden here. The original note says these
        methods are not implemented when `event_shape` is overridden.
        """
        # Test methods that are not implemented if event_shape is overridden.
        # Presumably `defer_and_count_usage` wraps each Variable so that
        # `assert_no_excessive_var_usage` can count how often it is read.
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
        bij_scale = tfp_hps.defer_and_count_usage(
            tf.Variable(2.,
                        name='bij_scale',
                        dtype=tf.float32,
                        shape=self.shape))
        batch_shape = tfp_hps.defer_and_count_usage(
            tf.Variable([4, 3, 5],
                        name='input_batch_shape',
                        dtype=tf.int32,
                        shape=self.shape))
        dist = tfd.TransformedDistribution(
            distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
            bijector=tfb.Scale(scale=bij_scale, validate_args=True),
            batch_shape=batch_shape,
            validate_args=True)

        # Each method's read budget comes from `self.max_permissible[method]`.
        # NOTE(review): the (4, 3, 2) input does not obviously broadcast
        # against batch_shape [4, 3, 5]; confirm this is intentional (the
        # check only counts variable reads).
        for method in ('log_cdf', 'cdf', 'survival_function',
                       'log_survival_function'):
            with tfp_hps.assert_no_excessive_var_usage(
                    method, max_permissible=self.max_permissible[method]):
                getattr(dist, method)(np.ones((4, 3, 2)) / 3.)

        with tfp_hps.assert_no_excessive_var_usage(
                'quantile', max_permissible=self.max_permissible['quantile']):
            dist.quantile(.1)
Exemplo n.º 18
0
    def testBijector(self):
        """Sigmoid(low, high) matches Shift(low) . Scale(high - low) . Sigmoid."""
        low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
        high = 12.

        bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
        # `Chain` applies right-to-left: sigmoid, then scale, then shift.
        reference = tfb.Chain([
            tfb.Shift(shift=low),
            tfb.Scale(scale=high - low),
            tfb.Sigmoid(),
        ])

        x = [[[1., 2., -5., -0.3]]]
        y = self.evaluate(reference.forward(x))
        self.assertAllClose(y, self.evaluate(bijector.forward(x)))

        # `y` is broadcast against `low`; keep a single slice along the
        # second-to-last axis so the recovered value lines up with `x`.
        recovered = self.evaluate(bijector.inverse(y)[..., :1, :])
        self.assertAllClose(x, recovered, rtol=1e-5)

        reference_ildj = self.evaluate(
            reference.inverse_log_det_jacobian(y, event_ndims=1))
        actual_ildj = self.evaluate(
            bijector.inverse_log_det_jacobian(y, event_ndims=1))
        self.assertAllClose(reference_ildj, actual_ildj, rtol=1e-5)

        reference_fldj = self.evaluate(
            reference.forward_log_det_jacobian(x, event_ndims=1))
        actual_fldj = self.evaluate(
            bijector.forward_log_det_jacobian(x, event_ndims=1))
        self.assertAllClose(reference_fldj, actual_fldj)
Exemplo n.º 19
0
 def testQuantileDescending(self):
     """Quantiles stay increasing when one batch member has a negative scale."""
     make_dist = self._cls()
     base = tfd.Normal(loc=0., scale=[1., 1.])
     affine = tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.]))
     td = make_dist(distribution=base, bijector=affine, validate_args=True)
     all_true = tf.ones(td.batch_shape, dtype=tf.bool)
     self.assertAllEqual(all_true, td.quantile(.8) < td.quantile(.9))
Exemplo n.º 20
0
    def testCompositeTensor(self):
        """A `Blockwise` bijector round-trips through the CompositeTensor APIs."""
        blockwise = tfb.Blockwise(
            bijectors=[tfb.Exp(), tfb.Softplus(), tfb.Scale(scale=2.)])
        self.assertIsInstance(blockwise, tf.__internal__.CompositeTensor)

        # Decompose into component `Tensor`s, then reassemble from them.
        components = tf.nest.flatten(blockwise, expand_composites=True)
        rebuilt = tf.nest.pack_sequence_as(
            blockwise, components, expand_composites=True)
        self.assertIsInstance(rebuilt, tfb.Blockwise)

        # The reassembled bijector can be handed to a `tf.function`.
        @tf.function
        def call_forward(bij, x):
            return bij.forward(x)

        ones = tf.ones([2, 3], dtype=tf.float32)
        self.assertAllClose(call_forward(rebuilt, ones),
                            blockwise.forward(ones))

        # The type spec survives a saved-model encode/decode round trip.
        encoded = tf.__internal__.saved_model.encode_structure(
            blockwise._type_spec)
        decoded = tf.__internal__.saved_model.decode_proto(encoded)
        self.assertEqual(blockwise._type_spec, decoded)
Exemplo n.º 21
0
  def testExcessiveConcretizationOfParams(self):
    """Limits variable reads by methods of a shape-overridden distribution.

    Overrides both `event_shape` and `batch_shape` with Variables and checks
    each method against its budget in `self.max_permissible`.
    """
    # Presumably `defer_and_count_usage` wraps each Variable so that
    # `assert_no_excessive_var_usage` can count how often it is read.
    loc = tfp_hps.defer_and_count_usage(
        tf.Variable(0., name='loc', dtype=tf.float32, shape=self.shape))
    scale = tfp_hps.defer_and_count_usage(
        tf.Variable(2., name='scale', dtype=tf.float32, shape=self.shape))
    bij_scale = tfp_hps.defer_and_count_usage(
        tf.Variable(2., name='bij_scale', dtype=tf.float32, shape=self.shape))
    event_shape = tfp_hps.defer_and_count_usage(
        tf.Variable([2, 2], name='input_event_shape', dtype=tf.int32,
                    shape=self.shape))
    batch_shape = tfp_hps.defer_and_count_usage(
        tf.Variable([4, 3, 5], name='input_batch_shape', dtype=tf.int32,
                    shape=self.shape))

    dist = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=loc, scale=scale, validate_args=True),
        bijector=tfb.Scale(scale=bij_scale, validate_args=True),
        event_shape=event_shape,
        batch_shape=batch_shape,
        validate_args=True)

    # Argument-free methods.
    for method in ('mean', 'entropy', 'event_shape_tensor',
                   'batch_shape_tensor'):
      with tfp_hps.assert_no_excessive_var_usage(
          method, max_permissible=self.max_permissible[method]):
        getattr(dist, method)()

    with tfp_hps.assert_no_excessive_var_usage(
        'sample', max_permissible=self.max_permissible['sample']):
      dist.sample(seed=test_util.test_seed())

    # Density methods; the input is batch_shape [4, 3, 5] + event_shape [2, 2].
    for method in ('log_prob', 'prob'):
      with tfp_hps.assert_no_excessive_var_usage(
          method, max_permissible=self.max_permissible[method]):
        getattr(dist, method)(np.ones((4, 3, 5, 2, 2)) / 3.)
Exemplo n.º 22
0
    def testBijectorWithDeepStructure(self):
        """JointMap handles nested structures of bijectors, inputs and ndims."""
        bij = tfb.JointMap({
            'a': tfb.Exp(),
            'bc': tfb.JointMap([tfb.Scale(2.), tfb.Shift(3.)]),
        })

        a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
        b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
        c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

        # The same structure serves as input to both forward and inverse.
        inputs = dict(a=a, bc=[b, c])
        event_ndims = dict(a=1, bc=[0, 0])

        self.assertStartsWith(bij.name, 'jointmap_of_exp_and_jointmap_of_')

        expected_forward = {'a': np.exp(a), 'bc': [b * 2., c + 3]}
        expected_inverse = {'a': np.log(a), 'bc': [b / 2., c - 3]}
        self.assertAllCloseNested(expected_forward,
                                  self.evaluate(bij.forward(inputs)))
        self.assertAllCloseNested(expected_inverse,
                                  self.evaluate(bij.inverse(inputs)))

        forward_ldj = self.evaluate(
            bij.forward_log_det_jacobian(inputs, event_ndims))
        self.assertEqual((1, 2), forward_ldj.shape)
        self.assertAllClose(np.sum(a, axis=-1) + np.log(2), forward_ldj)

        inverse_ldj = self.evaluate(
            bij.inverse_log_det_jacobian(inputs, event_ndims))
        self.assertEqual((1, 2), inverse_ldj.shape)
        self.assertAllClose(-np.log(a).sum(axis=-1) - np.log(2), inverse_ldj)
Exemplo n.º 23
0
    def testBatchShapeBroadcasts(self):
        """JointMap LDJs broadcast across parts with different batch shapes."""
        bij = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(10.)},
                           validate_args=True)
        self.assertStartsWith(bij.name, 'jointmap_of_exp_and_scale')

        a = np.asarray([[[1, 2]], [[2, 3]]],
                       dtype=np.float32)  # shape=[2, 1, 2]
        b = np.asarray([[0, 1, 2]], dtype=np.float32)  # shape=[1, 3]
        inputs = {'a': a, 'b': b}  # Could be inputs to forward or inverse.

        # With event_ndims=1 for 'b', each of its 3 trailing entries adds
        # one log(10).
        cases = (
            (0, a.sum(axis=-1) + np.log(10.)),
            (1, a.sum(axis=-1) + 3 * np.log(10.)),
        )
        for b_event_ndims, expected in cases:
            fldj = bij.forward_log_det_jacobian(
                inputs, {'a': 1, 'b': b_event_ndims})
            self.assertAllClose(expected, self.evaluate(fldj))
Exemplo n.º 24
0
    def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
        """
        Return a linear bijection $[0, 1]^{2} <--> [0, 4]^{2}$ (scale by 4).

        `dtype` defaults to `default_float()` when not given; extra keyword
        arguments are accepted and ignored.
        """
        resolved_dtype = default_float() if dtype is None else dtype
        return tfb.Scale(tf.cast(4.0, dtype=resolved_dtype))
Exemplo n.º 25
0
 def testCdfDescending(self):
   """CDF stays increasing when one batch member is flipped by a negative scale."""
   affine = tfb.Shift(shift=1.)(tfb.Scale(scale=[2., -2.]))
   td = tfd.TransformedDistribution(
       distribution=tfd.Normal(loc=0., scale=[1., 1.]),
       bijector=affine,
       validate_args=True)
   reference = tfd.Normal(loc=1., scale=2., validate_args=True)
   lower_cdf, upper_cdf = (td.cdf(reference.quantile(q)) for q in (.8, .9))
   self.assertAllEqual(tf.ones(td.batch_shape, dtype=tf.bool),
                       lower_cdf < upper_cdf)
Exemplo n.º 26
0
    def testLDJRatio(self):
        """LDJ-ratio helpers agree with naive differences of JointMap LDJs."""
        q = tfb.JointMap({
            'a': tfb.Exp(),
            'b': tfb.Scale(2.),
            'c': tfb.Shift(3.)
        })
        p = tfb.JointMap({
            'a': tfb.Exp(),
            'b': tfb.Scale(3.),
            'c': tfb.Shift(4.)
        })

        a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32)  # shape=[1, 2, 2]
        b = np.asarray([[0, 4]], dtype=np.float32)  # shape=[1, 2]
        c = np.asarray([[5, 6]], dtype=np.float32)  # shape=[1, 2]

        x = {'a': a, 'b': b, 'c': c}
        y = {part: value + 1 for part, value in x.items()}

        for event_ndims in ({'a': 1, 'b': 0, 'c': 0},
                            {'a': 1, 'b': 2, 'c': 0}):
            naive_fldj_ratio = (p.forward_log_det_jacobian(x, event_ndims) -
                                q.forward_log_det_jacobian(y, event_ndims))
            fldj_ratio = ldj_ratio.forward_log_det_jacobian_ratio(
                p, x, q, y, event_ndims)
            self.assertAllClose(naive_fldj_ratio, fldj_ratio)

            naive_ildj_ratio = (p.inverse_log_det_jacobian(x, event_ndims) -
                                q.inverse_log_det_jacobian(y, event_ndims))
            ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
                p, x, q, y, event_ndims)
            self.assertAllClose(naive_ildj_ratio, ildj_ratio)
Exemplo n.º 27
0
 def testBatchScale(self, is_static, dtype):
     """A batch of scales broadcasts against a single-element input."""
     scales = [2., 3.]
     bijector = tfb.Scale(scale=dtype(scales))
     inputs = self.maybe_static(np.array([1.], dtype=dtype))
     self.assertAllClose(scales, bijector.forward(inputs))
     self.assertAllClose([0.5, 1. / 3.], bijector.inverse(inputs))
     expected_ildj = [-np.log(2.), -np.log(3.)]
     self.assertAllClose(
         expected_ildj,
         bijector.inverse_log_det_jacobian(inputs, event_ndims=0))
Exemplo n.º 28
0
  def testNestedDtype(self):
    """A Chain with a float64 `Scale` between `Identity`s accepts int input."""
    sandwich = tfb.Chain([
        tfb.Identity(),
        tfb.Scale(tf.constant(2., tf.float64)),
        tfb.Identity(),
    ])
    expected = tf.constant([2, 4, 6], tf.float64)
    actual = self.evaluate(sandwich.forward([1, 2, 3]))
    self.assertAllClose(expected, actual)
Exemplo n.º 29
0
 def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
     """
     Return an affine bijection $[[0, 1], [0, 1]] <--> [[-2.5, 2.5], [-1.0, 2.0]]$.

     `dtype` defaults to `default_float()` when not given; extra keyword
     arguments are accepted and ignored.
     """
     if dtype is None:
         dtype = default_float()

     def as_tensor(values):
         return tf.convert_to_tensor(values, dtype=dtype)

     # `Chain` applies right-to-left: shift first, then scale, i.e.
     # y = scale * (x + shift).
     return tfb.Chain([
         tfb.Scale(as_tensor([5.0, 3.0])),
         tfb.Shift(as_tensor([-0.5, -1 / 3])),
     ])
Exemplo n.º 30
0
 def testScalarBatchScalarEventIdentityScale(self):
   """log_prob of 2 * Exponential(0.25) matches the change-of-variables formula."""
   rate = 0.25
   exp2 = self._cls()(
       tfd.Exponential(rate=rate),
       bijector=tfb.Scale(scale=2.),
       validate_args=True)
   observed = self.evaluate(exp2.log_prob(1.))
   # log p_Y(1) = log p_X(1 / 2) - log(2) for Y = 2 X.
   base_log_prob = np.log(rate) - rate * 0.5
   log_abs_scale = np.log(2.)
   self.assertAllClose(base_log_prob - log_abs_scale, observed,
                       rtol=1e-6, atol=0.)