Example #1
0
    def testMinEventNdimsWithPartiallyDependentJointMap(self):
        """Wrapping a dependent bijector in a JointMap preserves its metadata.

        `dependent` is a Split/Invert(Split) chain. Routing it through
        Restructure -> JointMap -> Invert(Restructure) must leave its forward and
        inverse min_event_ndims and its `_parts_interact` flag unchanged.
        """
        dependent = tfb.Chain([tfb.Split(2), tfb.Invert(tfb.Split(2))])
        to_nested_list = tfb.Restructure(input_structure=[0, 1],
                                         output_structure=[[0, 1]])
        wrapped = tfb.Chain([
            tfb.Invert(to_nested_list),
            tfb.JointMap([dependent]),
            to_nested_list,
        ])
        for attr in ('forward_min_event_ndims', 'inverse_min_event_ndims',
                     '_parts_interact'):
            self.assertAllEqualNested(getattr(dependent, attr),
                                      getattr(wrapped, attr))
Example #2
0
    def test_nested_transform(self):
        """Nested TransformedTransitionKernels behave like one kernel with a Chain.

        `outer_kernel` wraps HMC first with `b1` and then with `b2`.
        `chain_kernel` wraps an identically configured HMC once, with
        `Chain([b2, b1])`. Bootstrap results and a single step taken from the
        same stateless seed are compared between the two constructions.
        """
        target_dist = tfd.Normal(loc=0., scale=1.)
        b1 = tfb.Scale(0.5)
        b2 = tfb.Exp()
        chain = tfb.Chain([b2, b1
                           ])  # applies bijectors right to left (b1 then b2).
        # Nested construction: `b1` is applied by the inner transformed kernel
        # and `b2` by the outer one.
        inner_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_dist.log_prob,
                num_leapfrog_steps=27,
                step_size=10),
            bijector=b1)
        outer_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=inner_kernel, bijector=b2)
        # Flat construction: a single transformed kernel using the composed
        # Chain, around an HMC kernel configured like the nested one.
        chain_kernel = tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_dist.log_prob,
                num_leapfrog_steps=27,
                step_size=10),
            bijector=chain)
        outer_pkr_one, outer_pkr_two = self.evaluate([
            outer_kernel.bootstrap_results(2.),
            outer_kernel.bootstrap_results(9.),
        ])

        # the outermost kernel only applies the outermost bijector
        # (b2 is Exp, so the transformed state is the log of the input state).
        self.assertNear(np.log(2.), outer_pkr_one.transformed_state, err=1e-6)
        self.assertNear(np.log(9.), outer_pkr_two.transformed_state, err=1e-6)

        chain_pkr_one, chain_pkr_two = self.evaluate([
            chain_kernel.bootstrap_results(2.),
            chain_kernel.bootstrap_results(9.),
        ])

        # all bijectors are applied to the inner kernel, from innermost to outermost
        # this behavior is completely analogous to a bijector Chain
        self.assertNear(chain_pkr_one.transformed_state,
                        outer_pkr_one.inner_results.transformed_state,
                        err=1e-6)
        self.assertEqual(
            chain_pkr_one.inner_results.accepted_results,
            outer_pkr_one.inner_results.inner_results.accepted_results)
        self.assertNear(chain_pkr_two.transformed_state,
                        outer_pkr_two.inner_results.transformed_state,
                        err=1e-6)
        self.assertEqual(
            chain_pkr_two.inner_results.accepted_results,
            outer_pkr_two.inner_results.inner_results.accepted_results)

        # Both constructions step with the same stateless seed, so their
        # one_step outputs are expected to agree.
        seed = test_util.test_seed(sampler_type='stateless')
        outer_results_one, outer_results_two = self.evaluate([
            outer_kernel.one_step(2., outer_pkr_one, seed=seed),
            outer_kernel.one_step(9., outer_pkr_two, seed=seed)
        ])
        chain_results_one, chain_results_two = self.evaluate([
            chain_kernel.one_step(2., chain_pkr_one, seed=seed),
            chain_kernel.one_step(9., chain_pkr_two, seed=seed)
        ])
        # Index 0 of each one_step result is presumably the next chain state.
        self.assertNear(chain_results_one[0], outer_results_one[0], err=1e-6)
        self.assertNear(chain_results_two[0], outer_results_two[0], err=1e-6)
Example #3
0
    def testMatchWithAffineTransform(self):
        """Tanh agrees with the composition 2 * sigmoid(2 * x) - 1.

        Forward, inverse, and both log-det-Jacobians are compared at
        event_ndims=0.
        """
        def f64(value):
            return tf.cast(value, dtype=tf.float64)

        tanh = tfb.Tanh()
        # Chain applies right to left: scale by 2, sigmoid, scale by 2, shift -1.
        composed = tfb.Chain([
            tfb.Shift(f64(-1.0)),
            tfb.Scale(f64(2.0)),
            tfb.Sigmoid(),
            tfb.Scale(f64(2.0))
        ])

        x = np.linspace(-3.0, 3.0, 100)
        y = np.tanh(x)
        probes = (
            lambda bij: bij.forward(x),
            lambda bij: bij.inverse(y),
            lambda bij: bij.inverse_log_det_jacobian(y, event_ndims=0),
            lambda bij: bij.forward_log_det_jacobian(x, event_ndims=0),
        )
        for probe in probes:
            self.assertAllClose(self.evaluate(probe(tanh)),
                                self.evaluate(probe(composed)))
Example #4
0
    def testBijector(self):
        """Sigmoid(low, high) matches Shift(low) o Scale(high - low) o Sigmoid()."""
        low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
        high = 12.

        bounded = tfb.Sigmoid(low=low, high=high, validate_args=True)
        reference = tfb.Chain([
            tfb.Shift(shift=low),
            tfb.Scale(scale=high - low),
            tfb.Sigmoid(),
        ])

        x = [[[1., 2., -5., -0.3]]]
        y = self.evaluate(reference.forward(x))
        self.assertAllClose(y, self.evaluate(bounded.forward(x)))

        # `y` picks up the [3, 1] batch shape of `low`; keep a single row so the
        # recovered value has the same shape as `x`.
        x_recovered = self.evaluate(bounded.inverse(y)[..., :1, :])
        self.assertAllClose(x, x_recovered, rtol=1e-5)

        reference_ildj = self.evaluate(
            reference.inverse_log_det_jacobian(y, event_ndims=1))
        bounded_ildj = self.evaluate(
            bounded.inverse_log_det_jacobian(y, event_ndims=1))
        self.assertAllClose(reference_ildj, bounded_ildj, rtol=1e-5)

        reference_fldj = self.evaluate(
            reference.forward_log_det_jacobian(x, event_ndims=1))
        bounded_fldj = self.evaluate(
            bounded.forward_log_det_jacobian(x, event_ndims=1))
        self.assertAllClose(reference_fldj, bounded_fldj)
  def test_batch_broadcast_vector_to_parts(self):
    """A batch shape on one JointMap part propagates to every split part."""
    batch_shape = [4, 2]
    true_split_sizes = [1, 3, 2]
    event_size = sum(true_split_sizes)

    # The base distribution has no batch shape, so broadcasting is required.
    loc = tf.random.normal([event_size], seed=test_util.test_seed())
    log_scale = tf.random.normal([event_size], seed=test_util.test_seed())
    base_dist = tfd.MultivariateNormalDiag(
        loc=loc, scale_diag=tf.exp(log_scale))

    # Only the final part's bijector carries a batch shape.
    last_shift = tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]]))
    per_part = tfb.JointMap([tfb.Identity(), tfb.Identity(), last_shift])
    bijector = tfb.Chain([per_part, tfb.Split(true_split_sizes, axis=-1)])
    split_dist = tfd.TransformedDistribution(base_dist, bijector)
    self.assertAllEqual(split_dist.batch_shape, batch_shape)

    # With one batched branch, TD should feed batches of base samples *into*
    # the split, so every branch ends up with the batch shape.
    samples = split_dist.sample(seed=test_util.test_seed())
    expected_shapes = [batch_shape + [size] for size in true_split_sizes]
    self.assertAllEqualNested(
        tf.nest.map_structure(lambda part: part.shape, samples),
        expected_shapes)
Example #6
0
  def testNameScopeRefersToInitialScope(self):
    """Op names record the name scope a bijector was constructed in.

    When a bijector method is called under a different name scope than the
    one the bijector was built in, the bijector's scope name gets a
    `_CONSTRUCTED_AT_<scope>` suffix (`top_level` for the root scope). This
    test pins the resulting op-name prefixes. It runs only in graph mode and
    is skipped when executing eagerly.
    """
    if tf.executing_eagerly():
      self.skipTest('Eager mode.')

    # Built at the top level.
    outer_bijector = tfb.Exp(name='Exponential')
    self.assertStartsWith(outer_bijector.name, 'Exponential')

    with tf.name_scope('inside'):
      # Built inside the 'inside' scope.
      inner_bijector = tfb.Exp(name='Exponential')
      self.assertStartsWith(inner_bijector.name, 'Exponential')

      # Called in its construction scope: no suffix.
      self.assertStartsWith(inner_bijector.forward(0., name='x').name,
                            'inside/Exponential/x')
      # Called away from its construction scope: suffix names the origin.
      self.assertStartsWith(outer_bijector.forward(0., name='x').name,
                            'inside/Exponential_CONSTRUCTED_AT_top_level/x')

      meta_bijector = tfb.Chain([inner_bijector], name='meta_bijector')
      # Check for spurious `_CONSTRUCTED_AT_`.
      self.assertStartsWith(
          meta_bijector.forward(0., name='x').name,
          'inside/meta_bijector/x/Exponential/forward')

    # Outside the scope.
    self.assertStartsWith(inner_bijector.forward(0., name='x').name,
                          'Exponential_CONSTRUCTED_AT_inside/x')
    self.assertStartsWith(outer_bijector.forward(0., name='x').name,
                          'Exponential/x')
    # Check that init scope is annotated only for the toplevel bijector.
    self.assertStartsWith(
        meta_bijector.forward(0., name='x').name,
        'meta_bijector_CONSTRUCTED_AT_inside/x/Exponential/forward')
Example #7
0
 def _make_reshaped_bijector(b, s):
     """Wrap bijector `b` with reshapes so its output is a flat vector.

     `s` is treated as `b`'s output event shape; `b`'s input event shape is
     recovered with `b.inverse_event_shape(s)`. The input side is presumably
     flat as well, via `tfb.Reshape`'s default `event_shape_in` (TODO confirm).
     """
     flatten_output = tfb.Reshape(event_shape_in=s,
                                  event_shape_out=[ps.reduce_prod(s)])
     to_input_shape = tfb.Reshape(event_shape_out=b.inverse_event_shape(s))
     # Chain applies right to left: reshape into b's input shape, apply b,
     # then flatten b's output of shape `s`.
     return tfb.Chain([flatten_output, b, to_input_shape])
Example #8
0
 def testChainIldjWithPlaceholder(self):
   """ILDJ of a Chain can be built for a placeholder with an unknown batch dim."""
   chain = tfb.Chain((tfb.Exp(), tfb.Exp()))
   samples = tf.placeholder(dtype=np.float32, shape=[None, 10], name="samples")
   ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
   # assertIsNotNone reports the offending value on failure, unlike
   # assertTrue(x is not None), which only reports "False is not true".
   self.assertIsNotNone(ildj)
   with self.cached_session():
     ildj.eval({samples: np.zeros([2, 10], np.float32)})
Example #9
0
 def testScalarCongruency(self):
     """Exp(Softplus(x)) passes the scalar-congruency check on [1e-3, 1.5].

     Uses `cached_session()` in place of the deprecated `test_session()`.
     """
     with self.cached_session():
         chain = tfb.Chain((tfb.Exp(), tfb.Softplus()))
         assert_scalar_congruency(chain,
                                  lower_x=1e-3,
                                  upper_x=1.5,
                                  rtol=0.05)
Example #10
0
 def testMinEventNdimsShapeChangingAddRemoveDims(self):
     """min_event_ndims through a chain of rank-changing bijectors."""
     # ShapeChanging(a, b) presumably has forward/inverse min_event_ndims a/b.
     links = [ShapeChanging(2, 1), ShapeChanging(3, 0), ShapeChanging(1, 2)]
     chain = tfb.Chain(links)
     expected = (('forward_min_event_ndims', 4),
                 ('inverse_min_event_ndims', 1))
     for attr, ndims in expected:
         self.assertEqual(ndims, getattr(chain, attr))
    def testStddev(self):
        """stddev/variance pass through an affine Chain and an event Split."""
        base_stddev = 2.
        shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
        scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
        expected_stddev = tf.abs(base_stddev * scale)

        base = tfd.Normal(loc=tf.zeros_like(shift),
                          scale=base_stddev * tf.ones_like(scale),
                          validate_args=True)
        affine = tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                           validate_args=True)
        normal = self._cls()(distribution=base, bijector=affine,
                             validate_args=True)
        self.assertAllClose(expected_stddev, normal.stddev())
        self.assertAllClose(expected_stddev**2, normal.variance())

        # Splitting the event into parts keeps the per-element stddev.
        event_normal = tfd.Independent(normal, reinterpreted_batch_ndims=1)
        split_normal = self._cls()(distribution=event_normal,
                                   bijector=tfb.Split(3),
                                   validate_args=True)
        self.assertAllCloseNested(
            tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
            split_normal.stddev())

        # A matrix-vector scaling is a multivariate transformation, so
        # stddev() is expected to raise.
        other_event_normal = tfd.Independent(normal,
                                             reinterpreted_batch_ndims=1)
        matvec = tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]])
        scaled_normal = self._cls()(distribution=other_event_normal,
                                    bijector=matvec,
                                    validate_args=True)
        with self.assertRaisesRegex(NotImplementedError,
                                    'is a multivariate transformation'):
            scaled_normal.stddev()
Example #12
0
  def testCompositeTensor(self):
    """A Chain of CompositeTensor bijectors is itself a CompositeTensor."""
    exp = tfb.Exp()
    sp = tfb.Softplus()
    aff = tfb.Scale(scale=2.)
    chain = tfb.Chain(bijectors=[exp, sp, aff])
    self.assertIsInstance(chain, tf.__internal__.CompositeTensor)

    # Round-trip through flatten / pack_sequence_as with expand_composites.
    components = tf.nest.flatten(chain, expand_composites=True)
    rebuilt = tf.nest.pack_sequence_as(
        chain, components, expand_composites=True)
    self.assertIsInstance(rebuilt, tfb.Chain)

    # The rebuilt bijector can be passed as an argument to a tf.function.
    @tf.function
    def apply_forward(bijector, value):
      return bijector.forward(value)

    ones = tf.ones([2, 3], dtype=tf.float32)
    self.assertAllClose(apply_forward(rebuilt, ones), chain.forward(ones))

    # The TypeSpec survives a StructureCoder encode/decode round trip.
    coder = tf.__internal__.saved_model.StructureCoder()
    decoded = coder.decode_proto(coder.encode_structure(chain._type_spec))
    self.assertEqual(chain._type_spec, decoded)
Example #13
0
    def testMinEventNdimsShapeChangingRemoveDims(self):
        """Chain min_event_ndims when a link removes event dimensions.

        `tfb.Affine` has been removed from TFP. `tfb.ScaleMatvecDiag` stands in
        as a vector bijector with the same min_event_ndims of 1, so the
        expected values are unchanged.
        """
        chain = tfb.Chain([ShapeChanging(3, 0)])
        self.assertEqual(3, chain.forward_min_event_ndims)
        self.assertEqual(0, chain.inverse_min_event_ndims)

        chain = tfb.Chain(
            [ShapeChanging(3, 0),
             tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
        self.assertEqual(3, chain.forward_min_event_ndims)
        self.assertEqual(0, chain.inverse_min_event_ndims)

        chain = tfb.Chain(
            [tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
             ShapeChanging(3, 0)])
        self.assertEqual(4, chain.forward_min_event_ndims)
        self.assertEqual(1, chain.inverse_min_event_ndims)

        chain = tfb.Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)])
        self.assertEqual(6, chain.forward_min_event_ndims)
        self.assertEqual(0, chain.inverse_min_event_ndims)
Example #14
0
  def testNonCompositeTensor(self):
    """A Chain containing a non-`CompositeTensor` bijector is not one itself."""

    class NonCompositeScale(tfb.Bijector):
      """Bijector that is not a `CompositeTensor`."""

      def __init__(self, scale):
        parameters = dict(locals())
        self.scale = scale
        super(NonCompositeScale, self).__init__(
            validate_args=True,
            # min_event_ndims is a rank, so pass an int rather than `0.`.
            forward_min_event_ndims=0,
            parameters=parameters,
            name="non_composite_scale")

      def _forward(self, x):
        return x * self.scale

      def _inverse(self, y):
        return y / self.scale

    exp = tfb.Exp()
    scale = NonCompositeScale(scale=tf.constant(3.))
    chain = tfb.Chain(bijectors=[exp, scale])
    self.assertNotIsInstance(chain, tf.__internal__.CompositeTensor)
    # Chain applies right to left, so forward is exp(scale(x)).
    self.assertAllClose(chain.forward([1.]), exp.forward(scale.forward([1.])))
Example #15
0
def _build_inference_bijector(parameter):
    """Return a scaling-and-support bijector for inference.

    By default, this is just `parameter.bijector`, which transforms a
    real-valued input to the parameter's support.

    For scale parameters (heuristically detected as any param with a Softplus
    support bijector), we also rescale by a prior scale estimate,
    `|prior.mean()| + prior.stddev()`. This is approximately equivalent to
    performing inference on a standardized input
    `observed_time_series/stddev(observed_time_series)`, because:
     a) rescaling all the scale parameters is equivalent (gives equivalent
        forecasts, etc) to rescaling the `observed_time_series`.
     b) the default scale priors in STS components have stddev proportional
        to `stddev(observed_time_series)`.

    Args:
      parameter: `sts.Parameter` named tuple instance.

    Returns:
      bijector: a `tfb.Bijector` instance to use in inference.
    """
    if not isinstance(parameter.bijector, tfb.Softplus):
        return parameter.bijector
    try:
        # Use mean + stddev, rather than just stddev, to ensure a reasonable
        # init if the user passes a crazy custom prior like N(100000, 0.001).
        prior_scale = tf.abs(parameter.prior.mean()) + parameter.prior.stddev()
    except NotImplementedError:
        # Custom prior with no mean and/or stddev: fall back to the plain
        # support bijector.
        return parameter.bijector
    # `tfb.AffineScalar` was deprecated and later removed from TFP;
    # `tfb.Scale` is its scale-only replacement.
    return tfb.Chain([tfb.Scale(scale=prior_scale), parameter.bijector])
Example #16
0
 def testScalarCongruency(self):
     """Exp(Softplus(x)) passes the scalar-congruency check on [1e-3, 1.5]."""
     # Chain applies right to left, so Softplus runs before Exp.
     bijector = tfb.Chain((tfb.Exp(), tfb.Softplus()))
     congruency_kwargs = dict(lower_x=1e-3, upper_x=1.5, rtol=0.05,
                              eval_func=self.evaluate)
     bijector_test_util.assert_scalar_congruency(bijector, **congruency_kwargs)
Example #17
0
    def testMatchWithAffineTransform(self):
        """Tanh agrees with the composition 2 * sigmoid(2 * x) - 1.

        `tfb.AffineScalar` has been removed from TFP, and `tf.to_double` is
        TF1-only. The affine steps are therefore written as Shift/Scale with
        `tf.cast(..., tf.float64)` constants. `AffineScalar(shift=-1, scale=2)`
        (scale, then shift) becomes `Shift(-1)` after `Scale(2)`.
        """
        direct_bj = tfb.Tanh()
        # Chain applies right to left: scale by 2, sigmoid, scale by 2, shift -1.
        indirect_bj = tfb.Chain([
            tfb.Shift(tf.cast(-1.0, dtype=tf.float64)),
            tfb.Scale(tf.cast(2.0, dtype=tf.float64)),
            tfb.Sigmoid(),
            tfb.Scale(tf.cast(2.0, dtype=tf.float64))
        ])

        x = np.linspace(-3.0, 3.0, 100)
        y = np.tanh(x)
        self.assertAllClose(self.evaluate(direct_bj.forward(x)),
                            self.evaluate(indirect_bj.forward(x)))
        self.assertAllClose(self.evaluate(direct_bj.inverse(y)),
                            self.evaluate(indirect_bj.inverse(y)))
        self.assertAllClose(
            self.evaluate(direct_bj.inverse_log_det_jacobian(y,
                                                             event_ndims=0)),
            self.evaluate(
                indirect_bj.inverse_log_det_jacobian(y, event_ndims=0)))
        self.assertAllClose(
            self.evaluate(direct_bj.forward_log_det_jacobian(x,
                                                             event_ndims=0)),
            self.evaluate(
                indirect_bj.forward_log_det_jacobian(x, event_ndims=0)))
Example #18
0
 def testChainIldjWithPlaceholder(self):
   """ILDJ of a Chain builds and evaluates for an input of unknown static shape."""
   chain = tfb.Chain((tfb.Exp(), tfb.Exp()))
   default_value = np.zeros([2, 10], np.float32)
   # shape=None: the placeholder carries no static shape information.
   samples = tf1.placeholder_with_default(default_value, shape=None)
   ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
   self.assertIsNotNone(ildj)
   _ = self.evaluate(ildj)
Example #19
0
 def testNonCompositeTensor(self):
     """A Chain containing a non-CompositeTensor bijector is not one itself."""
     exp = tfb.Exp()
     scale = test_util.NonCompositeTensorScale(scale=tf.constant(3.))
     chain = tfb.Chain(bijectors=[exp, scale])
     self.assertNotIsInstance(chain, tf.__internal__.CompositeTensor)
     point = [1.]
     # Chain applies right to left: forward is exp(scale(x)).
     actual = chain.forward(point)
     expected = exp.forward(scale.forward(point))
     self.assertAllClose(actual, expected)
Example #20
0
 def testInvalidChainNdimsRaisesError(self):
     """Chain raises ValueError when its links' event-ndims differences disagree.

     Uses `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated
     and was removed in Python 3.12.
     """
     with self.assertRaisesRegex(
             ValueError,
             "Differences between `event_ndims` and `min_event_ndims must be equal"
     ):
         tfb.Chain(
             [ShapeChanging([1, 1], [1, 1]),
              ShapeChanging([1, 1], [2, 1])])
Example #21
0
    def test_composition_str_and_repr_match_expected_dynamic_shape(self):
        """Pin the str()/repr() text of Chains built from dynamic-shape tensors.

        Two Chains are checked: a flat Exp/Shift/SoftmaxCentered chain, and a
        structured chain that goes through JointMap/Restructure/Split. Only
        ordered substrings are asserted (`assertContainsInOrder`), so any other
        text in between is allowed.
        """
        # NOTE(review): `self._tensor` presumably yields tensors with unknown
        # static shape, hence the `batch_shape=?` below — TODO confirm.
        bij = tfb.Chain([
            tfb.Exp(),
            tfb.Shift(self._tensor([1., 2.])),
            tfb.SoftmaxCentered()
        ])
        # str(): a single `min_event_ndims` and the bijector class names.
        self.assertContainsInOrder([
            'tfp.bijectors.Chain(',
            ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])')
        ], str(bij))
        # repr(): forward/inverse ndims, dtypes, and nested bijector reprs.
        self.assertContainsInOrder([
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
             'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
            '>, <tfp.bijectors.Shift', '>, <tfp.bijectors.SoftmaxCentered',
            '>]>'
        ], repr(bij))

        # Structured chain: the inverse min_event_ndims and dtype_y are dicts
        # keyed like the JointMap.
        bij = tfb.Chain([
            tfb.JointMap({
                'a': tfb.Exp(),
                'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.]))
            }),
            tfb.Restructure({
                'a': 0,
                'b': 1
            }, [0, 1]),
            tfb.Split(2),
            tfb.Invert(tfb.SoftmaxCentered()),
        ])
        # str(): dict keys are printed unquoted.
        self.assertContainsInOrder([
            'tfp.bijectors.Chain(',
            ('forward_min_event_ndims=1, '
             'inverse_min_event_ndims={a: 1, b: 1}, '
             'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
             'Restructure, Split, Invert(SoftmaxCentered)])')
        ], str(bij))
        # repr(): dict keys are printed quoted.
        self.assertContainsInOrder([
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 '
             "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
             "dtype_y={'a': ?, 'b': float32} "
             "bijectors=[<tfp.bijectors.JointMap "),
            '>, <tfp.bijectors.Restructure', '>, <tfp.bijectors.Split',
            '>, <tfp.bijectors.Invert', '>]>'
        ], repr(bij))
Example #22
0
  def testNestedDtype(self):
    """A float64 Scale between Identity links lets the Chain accept Python ints."""
    doubling = tfb.Scale(tf.constant(2., tf.float64))
    chain = tfb.Chain([tfb.Identity(), doubling, tfb.Identity()])

    expected = tf.constant([2, 4, 6], tf.float64)
    actual = self.evaluate(chain.forward([1, 2, 3]))
    self.assertAllClose(expected, actual)
Example #23
0
 def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
     """Return the affine map taking the unit square to [-2.5, 2.5] x [-1, 2].

     The Chain applies right to left: it first shifts by (-0.5, -1/3), then
     scales by (5, 3), i.e.
     $[[0, 1], [0, 1]] <--> [[-2.5, 2.5], [-1.0, 2.0]]$.

     Args:
       dtype: dtype for the shift and scale tensors; `default_float()` if None.
       **kwargs: accepted for interface compatibility; not used here.
     """
     resolved_dtype = default_float() if dtype is None else dtype
     scale_values = tf.convert_to_tensor([5.0, 3.0], dtype=resolved_dtype)
     shift_values = tf.convert_to_tensor([-0.5, -1 / 3], dtype=resolved_dtype)
     return tfb.Chain([tfb.Scale(scale_values), tfb.Shift(shift_values)])
Example #24
0
    def testMinEventNdimsShapeChangingRemoveDims(self):
        """Chain min_event_ndims when a link removes event dimensions."""
        def diag():
            return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

        # (bijector factory, expected forward ndims, expected inverse ndims)
        cases = (
            (lambda: [ShapeChanging(3, 0)], 3, 0),
            (lambda: [ShapeChanging(3, 0), diag()], 3, 0),
            (lambda: [diag(), ShapeChanging(3, 0)], 4, 1),
            (lambda: [ShapeChanging(3, 0), ShapeChanging(3, 0)], 6, 0),
        )
        for make_links, forward_ndims, inverse_ndims in cases:
            chain = tfb.Chain(make_links())
            self.assertEqual(forward_ndims, chain.forward_min_event_ndims)
            self.assertEqual(inverse_ndims, chain.inverse_min_event_ndims)
Example #25
0
 def testCopyExtraArgs(self):
   """`copy()` keeps constructor parameters, spot-checked on a few bijectors."""
   # Note: we cannot easily test all bijectors since each requires
   # different initialization arguments. We therefore spot test a few.
   def assert_copy_keeps_parameters(bijector):
     self.assertEqual(bijector.parameters, bijector.copy().parameters)

   assert_copy_keeps_parameters(
       tfb.Sigmoid(low=-1., high=2., validate_args=True))

   members = [
       tfb.Softplus(hinge_softness=[1., 2.], validate_args=True),
       tfb.MatrixInverseTriL(validate_args=True),
   ]
   assert_copy_keeps_parameters(tfb.Chain(members, validate_args=True))
Example #26
0
 def test_slice_single_param_bijector_composition(self):
   """Slicing a JointMap-of-Chain bijector slices its batch shape."""
   inner = tfb.Chain([tfb.Invert(tfb.Scale(tf.ones([4, 3, 1])))])
   joint = tfb.JointMap({'a': inner})
   sliced = slicing._slice_single_param(
       joint,
       param_event_ndims={'a': 1},
       slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
       batch_shape=tf.constant([7, 4, 3]))

   # The expected shape comes from applying the same slices to a reference
   # tensor of shape [1, 4, 3].
   reference = tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis]
   self.assertAllEqual(
       list(reference.shape),
       sliced.experimental_batch_shape_tensor(x_event_ndims={'a': 1}))
Example #27
0
 def testBijectorIdentity(self):
   """An empty Chain is named "identity..." and acts as the identity map."""
   chain = tfb.Chain()
   self.assertStartsWith(chain.name, "identity")
   x = np.asarray([[[1., 2.],
                    [2., 3.]]])
   for transform in (chain.forward, chain.inverse):
     self.assertAllClose(x, self.evaluate(transform(x)))
   for log_det in (chain.inverse_log_det_jacobian,
                   chain.forward_log_det_jacobian):
     self.assertAllClose(0., self.evaluate(log_det(x, event_ndims=1)))
Example #28
0
  def testMinEventNdimsChain(self):
    """Chain min_event_ndims is the max needed by any scalar or vector link.

    `tfb.Affine` has been removed from TFP. `tfb.ScaleMatvecDiag` stands in as
    a vector bijector with the same min_event_ndims of 1, so the expected
    values are unchanged.
    """
    def vector_bijector():
      return tfb.ScaleMatvecDiag(scale_diag=[1., 1.])

    chain = tfb.Chain([tfb.Exp(), tfb.Exp(), tfb.Exp()])
    self.assertEqual(0, chain.forward_min_event_ndims)
    self.assertEqual(0, chain.inverse_min_event_ndims)

    chain = tfb.Chain(
        [vector_bijector(), vector_bijector(), vector_bijector()])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.Exp(), vector_bijector()])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([vector_bijector(), tfb.Exp()])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain(
        [vector_bijector(), tfb.Exp(), tfb.Softplus(), vector_bijector()])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)
 def __init__(self, loc, chol_precision_tril, name=None):
   """Build an MVN from `loc` and a lower-triangular precision factor `L`.

   The bijector maps a standard-normal draw `z` (same shape as `loc`) to
   `loc + inv(L^T) @ z`, where `L = chol_precision_tril`.
   """
   standard_normal = tfd.Independent(
       tfd.Normal(tf.zeros_like(loc), scale=tf.ones_like(loc)),
       reinterpreted_batch_ndims=1)
   # Chain applies right to left: solve against L^T, then shift by `loc`.
   shift = tfb.Shift(shift=loc)
   solve_transpose = tfb.Invert(
       tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril, adjoint=True))
   super(MVNCholPrecisionTriL, self).__init__(
       distribution=standard_normal,
       bijector=tfb.Chain([shift, solve_transpose]),
       name=name)
Example #30
0
 def testEventNdimsIsOptional(self):
   """LDJ methods work without `event_ndims` and reduce over the vector event."""
   scale_diag = np.array([1., 2., 3.], dtype=np.float32)
   chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()])
   # y = scale_diag * exp(x): exp(x) = [1, 2, 3], so y = [1, 4, 9].
   x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
   y = [1., 4., 9.]
   # log|det J| = sum(log(scale_diag)) + sum(x) = log(6) + sum(x).
   fldj = np.log(6, dtype=np.float32) + np.sum(x)
   self.assertAllClose(fldj, self.evaluate(chain.forward_log_det_jacobian(x)))
   self.assertAllClose(-fldj, self.evaluate(chain.inverse_log_det_jacobian(y)))