    def testRaisesBadBijectors(self):
        with self.assertRaisesRegexp(NotImplementedError,
                                     'Only scalar and vector event-shape'):
            tfb.Blockwise(bijectors=[tfb.Reshape(event_shape_out=[1, 1])])

        with self.assertRaisesRegexp(NotImplementedError,
                                     'Only scalar and vector event-shape'):
            tfb.Blockwise(bijectors=[
                tfb.Reshape(event_shape_out=[1], event_shape_in=[])
            ])
    def testRaisesBadBlocksDynamic(self):
        if tf.executing_eagerly(): return
        with self.assertRaises(tf.errors.InvalidArgumentError):
            block_sizes = tf1.placeholder_with_default([1, 2], shape=None)
            blockwise = tfb.Blockwise(bijectors=[tfb.Exp()],
                                      block_sizes=block_sizes,
                                      validate_args=True)
            self.evaluate(blockwise.block_sizes)

        with self.assertRaises(tf.errors.InvalidArgumentError):
            block_sizes = tf1.placeholder_with_default([[1]], shape=None)
            blockwise = tfb.Blockwise(bijectors=[tfb.Exp()],
                                      block_sizes=block_sizes,
                                      validate_args=True)
            self.evaluate(blockwise.block_sizes)
    def testRaisesBadBlocks(self):
        with self.assertRaisesRegexp(
                ValueError,
                r'`block_sizes` must be `None`, or a vector of the same length as '
                r'`bijectors`. Got a `Tensor` with shape \(2L?,\) and `bijectors` of '
                r'length 1'):
            tfb.Blockwise(bijectors=[tfb.Exp()], block_sizes=[1, 2])
    def testCompositeTensor(self):
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.Scale(scale=2.)
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff])
        self.assertIsInstance(blockwise, tf.__internal__.CompositeTensor)

        # Bijector may be flattened into `Tensor` components and rebuilt.
        flat = tf.nest.flatten(blockwise, expand_composites=True)
        unflat = tf.nest.pack_sequence_as(blockwise,
                                          flat,
                                          expand_composites=True)
        self.assertIsInstance(unflat, tfb.Blockwise)

        # Bijector may be input to a `tf.function`-decorated callable.
        @tf.function
        def call_forward(bij, x):
            return bij.forward(x)

        x = tf.ones([2, 3], dtype=tf.float32)
        self.assertAllClose(call_forward(unflat, x), blockwise.forward(x))

        # Type spec can be encoded/decoded.
        enc = tf.__internal__.saved_model.encode_structure(
            blockwise._type_spec)
        dec = tf.__internal__.saved_model.decode_proto(enc)
        self.assertEqual(blockwise._type_spec, dec)
    def testName(self):
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.Affine(scale_diag=[2., 3., 4.])
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff],
                                  block_sizes=[2, 1, 3])
        self.assertStartsWith(blockwise.name,
                              'blockwise_of_exp_and_softplus_and_affine')
    def testNonCompositeTensor(self):
        exp = tfb.Exp()
        scale = test_util.NonCompositeTensorScale(scale=tf.constant(3.))
        blockwise = tfb.Blockwise(bijectors=[exp, scale])
        self.assertNotIsInstance(blockwise, tf.__internal__.CompositeTensor)
        self.assertAllClose(
            blockwise.forward([1., 1.]),
            tf.convert_to_tensor([exp.forward(1.),
                                  scale.forward(1.)]))
    def testBijectiveAndFinite(self):
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.Affine(scale_diag=[2., 3., 4.])
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff],
                                  block_sizes=[2, 1, 3])

        x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
        x = tf1.placeholder_with_default(x, shape=x.shape)
        # Identity to break the caching.
        blockwise_y = tf.identity(blockwise.forward(x))

        bijector_test_util.assert_bijective_and_finite(
            blockwise,
            x=self.evaluate(x),
            y=self.evaluate(blockwise_y),
            eval_func=self.evaluate,
            event_ndims=1)
    def testKwargs(self):
        zeros = tf.zeros(1)

        bijectors = [
            tfb.Inline(  # pylint: disable=g-complex-comprehension
                forward_fn=mock.Mock(return_value=zeros),
                inverse_fn=mock.Mock(return_value=zeros),
                forward_log_det_jacobian_fn=mock.Mock(return_value=zeros),
                inverse_log_det_jacobian_fn=mock.Mock(return_value=zeros),
                forward_min_event_ndims=0,
                name='inner{}'.format(i)) for i in range(2)
        ]

        blockwise = tfb.Blockwise(bijectors)

        x = [1, 2]
        blockwise.forward(x, inner0={'arg': 1}, inner1={'arg': 2})
        blockwise.inverse(x, inner0={'arg': 3}, inner1={'arg': 4})
        blockwise.forward_log_det_jacobian(x,
                                           event_ndims=1,
                                           inner0={'arg': 5},
                                           inner1={'arg': 6})
        blockwise.inverse_log_det_jacobian(x,
                                           event_ndims=1,
                                           inner0={'arg': 7},
                                           inner1={'arg': 8})

        bijectors[0]._forward.assert_any_call(mock.ANY, arg=1)
        bijectors[1]._forward.assert_any_call(mock.ANY, arg=2)
        bijectors[0]._inverse.assert_any_call(mock.ANY, arg=3)
        bijectors[1]._inverse.assert_any_call(mock.ANY, arg=4)
        bijectors[0]._forward_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=5)
        bijectors[1]._forward_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=6)
        bijectors[0]._inverse_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=7)
        bijectors[1]._inverse_log_det_jacobian.assert_called_with(mock.ANY,
                                                                  arg=8)
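
    # Illustrative note (not from the original suite): `Blockwise` routes
    # keyword arguments to inner bijectors by bijector name, so e.g.
    # `blockwise.forward(x, inner0={'arg': 1})` passes `arg=1` to the
    # bijector named 'inner0', which is what the mock assertions above verify.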
# Example #9
    def testNonCompositeTensor(self):
        class NonCompositeScale(tfb.Bijector):
            """Bijector that is not a `CompositeTensor`."""
            def __init__(self, scale):
                parameters = dict(locals())
                self.scale = scale
                super(NonCompositeScale,
                      self).__init__(validate_args=True,
                                     forward_min_event_ndims=0,
                                     parameters=parameters,
                                     name='non_composite_scale')

            def _forward(self, x):
                return x * self.scale

        exp = tfb.Exp()
        scale = NonCompositeScale(scale=tf.constant(3.))
        blockwise = tfb.Blockwise(bijectors=[exp, scale])
        self.assertNotIsInstance(blockwise, tf.__internal__.CompositeTensor)
        self.assertAllClose(
            blockwise.forward([1., 1.]),
            tf.convert_to_tensor([exp.forward(1.),
                                  scale.forward(1.)]))
    # Note: `dynamic_shape` and `batch_shape` are supplied by a
    # parameterized-test decorator in the full suite.
    def testExplicitBlocks(self, dynamic_shape, batch_shape):
        block_sizes = tf.convert_to_tensor(value=[2, 1, 3])
        block_sizes = tf1.placeholder_with_default(
            block_sizes,
            shape=([None] * len(block_sizes.shape)
                   if dynamic_shape else block_sizes.shape))
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.Affine(scale_diag=[2., 3., 4.])
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff],
                                  block_sizes=block_sizes,
                                  maybe_changes_size=False)

        x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
        for s in batch_shape:
            x = tf.expand_dims(x, 0)
            x = tf.tile(x, [s] + [1] * (tensorshape_util.rank(x.shape) - 1))
        x = tf1.placeholder_with_default(
            x, shape=None if dynamic_shape else x.shape)

        # Identity to break the caching.
        blockwise_y = tf.identity(blockwise.forward(x))
        blockwise_fldj = blockwise.forward_log_det_jacobian(x, event_ndims=1)
        blockwise_x = blockwise.inverse(blockwise_y)
        blockwise_ildj = blockwise.inverse_log_det_jacobian(blockwise_y,
                                                            event_ndims=1)

        if not dynamic_shape:
            self.assertEqual(blockwise_y.shape, batch_shape + [6])
            self.assertEqual(blockwise_fldj.shape, batch_shape + [])
            self.assertEqual(blockwise_x.shape, batch_shape + [6])
            self.assertEqual(blockwise_ildj.shape, batch_shape + [])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_y)),
                            batch_shape + [6])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_fldj)),
                            batch_shape + [])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_x)),
                            batch_shape + [6])
        self.assertAllEqual(self.evaluate(tf.shape(blockwise_ildj)),
                            batch_shape + [])

        expl_y = tf.concat([
            exp.forward(x[..., :2]),
            sp.forward(x[..., 2:3]),
            aff.forward(x[..., 3:]),
        ], axis=-1)
        expl_fldj = sum([
            exp.forward_log_det_jacobian(x[..., :2], event_ndims=1),
            sp.forward_log_det_jacobian(x[..., 2:3], event_ndims=1),
            aff.forward_log_det_jacobian(x[..., 3:], event_ndims=1)
        ])
        expl_x = tf.concat([
            exp.inverse(expl_y[..., :2]),
            sp.inverse(expl_y[..., 2:3]),
            aff.inverse(expl_y[..., 3:])
        ], axis=-1)
        expl_ildj = sum([
            exp.inverse_log_det_jacobian(expl_y[..., :2], event_ndims=1),
            sp.inverse_log_det_jacobian(expl_y[..., 2:3], event_ndims=1),
            aff.inverse_log_det_jacobian(expl_y[..., 3:], event_ndims=1)
        ])

        self.assertAllClose(self.evaluate(expl_y), self.evaluate(blockwise_y))
        self.assertAllClose(self.evaluate(expl_fldj),
                            self.evaluate(blockwise_fldj))
        self.assertAllClose(self.evaluate(expl_x), self.evaluate(blockwise_x))
        self.assertAllClose(self.evaluate(expl_ildj),
                            self.evaluate(blockwise_ildj))
    def testRaisesEmptyBijectors(self):
        with self.assertRaisesRegexp(ValueError,
                                     '`bijectors` must not be empty'):
            tfb.Blockwise(bijectors=[])
    def testNameOneBijector(self):
        exp = tfb.Exp()
        blockwise = tfb.Blockwise(bijectors=[exp], block_sizes=[3])
        self.assertStartsWith(blockwise.name, 'blockwise_of_exp')
    def testImplicitBlocks(self):
        exp = tfb.Exp()
        sp = tfb.Softplus()
        aff = tfb.Affine(scale_diag=[2.])
        blockwise = tfb.Blockwise(bijectors=[exp, sp, aff])
        self.assertAllEqual(self.evaluate(blockwise.block_sizes), [1, 1, 1])
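
    # Illustrative sketch (not from the original suite): each bijector in a
    # `Blockwise` acts on its own slice of the last axis, and `block_sizes`
    # defaults to one element per bijector, as `testImplicitBlocks` checks.
    def testBlockSemanticsSketch(self):
        blockwise = tfb.Blockwise(bijectors=[tfb.Exp(), tfb.Softplus()])
        x = tf.constant([0.5, 0.5])
        # Forward maps each block independently: [exp(x[0]), softplus(x[1])].
        self.assertAllClose(
            self.evaluate(blockwise.forward(x)),
            self.evaluate(tf.stack([tf.exp(x[0]), tf.math.softplus(x[1])])))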
# Example #14
    def __init__(self, model):
        """Constructs the adapter.

        Args:
          model: An Inference Gym model.

        Raises:
          TypeError: If `model` has more than one unique Tensor dtype.
        """
        self._model = model
        dtypes = set(
            tf.nest.flatten(
                tf.nest.map_structure(tf.as_dtype, self._model.dtype)))
        if len(dtypes) > 1:
            raise TypeError(
                'Model must have only one Tensor dtype, saw: {}'.format(
                    self._model.dtype))
        dtype = dtypes.pop()

        # TODO(siege): Make this work with multi-part default_event_bijector.
        def _make_reshaped_bijector(b, s):
            return tfb.Chain([
                tfb.Reshape(event_shape_in=s,
                            event_shape_out=[ps.reduce_prod(s)]),
                b,
                tfb.Reshape(event_shape_out=b.inverse_event_shape(s)),
            ])

        reshaped_bijector = tf.nest.map_structure(
            _make_reshaped_bijector, self._model.default_event_space_bijector,
            self._model.event_shape)

        bijector = tfb.Blockwise(
            bijectors=tf.nest.flatten(reshaped_bijector),
            block_sizes=tf.nest.flatten(
                tf.nest.map_structure(
                    lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),  # pylint: disable=g-long-lambda
                    self._model.default_event_space_bijector,
                    self._model.event_shape)))

        event_sizes = tf.nest.map_structure(
            lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),
            self._model.default_event_space_bijector, self._model.event_shape)
        event_shape = tf.TensorShape([sum(tf.nest.flatten(event_sizes))])

        sample_transformations = collections.OrderedDict()

        def make_flattened_transform(transform):
            # We yank this out to avoid capturing the loop variable.
            return transform._replace(
                fn=lambda x: transform(self._split_and_reshape_event(x)))

        for key, transform in self._model.sample_transformations.items():
            sample_transformations[key] = make_flattened_transform(transform)

        super(VectorModel, self).__init__(
            default_event_space_bijector=bijector,
            event_shape=event_shape,
            dtype=dtype,
            name='vector_' + self._model.name,
            pretty_name=str(self._model),
            sample_transformations=sample_transformations,
        )
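
# Illustrative usage sketch (not from the original source): wrapping an
# Inference Gym model so its structured event is exposed as one flat
# vector. Assumes the `inference_gym` package and the surrounding
# `VectorModel` class; `Banana` is an arbitrary built-in target.
from inference_gym import using_tensorflow as gym

model = gym.targets.Banana()
vector_model = VectorModel(model)
print(vector_model.event_shape)  # A rank-1 shape, e.g. [2] for Banana.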