def testRaisesBadBijectors(self):
  """Blockwise rejects component bijectors that change event rank.

  Uses `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated and
  was removed from `unittest.TestCase` in Python 3.12.
  """
  # A Reshape that produces a rank-2 event ([1, 1]) must raise.
  with self.assertRaisesRegex(NotImplementedError,
                              'Only scalar and vector event-shape'):
    tfb.Blockwise(bijectors=[tfb.Reshape(event_shape_out=[1, 1])])

  # A Reshape from a scalar event ([]) to a length-1 vector ([1]) must also
  # raise the same error.
  with self.assertRaisesRegex(NotImplementedError,
                              'Only scalar and vector event-shape'):
    tfb.Blockwise(bijectors=[
        tfb.Reshape(event_shape_out=[1], event_shape_in=[])
    ])
def testRaisesBadBlocksDynamic(self):
  """Invalid dynamically-shaped `block_sizes` fail when evaluated.

  Returns early (does nothing) under eager execution. Each bad value is fed
  through a placeholder with unknown shape and `validate_args=True`, and
  evaluating `block_sizes` is expected to raise `InvalidArgumentError`.
  """
  if tf.executing_eagerly():
    return

  # [1, 2] has two entries for a single bijector; [[1]] is rank 2.
  for bad_sizes in ([1, 2], [[1]]):
    with self.assertRaises(tf.errors.InvalidArgumentError):
      dynamic_sizes = tf1.placeholder_with_default(bad_sizes, shape=None)
      bij = tfb.Blockwise(
          bijectors=[tfb.Exp()],
          block_sizes=dynamic_sizes,
          validate_args=True)
      self.evaluate(bij.block_sizes)
def testRaisesBadBlocks(self):
  """A static `block_sizes` whose length differs from `bijectors` is rejected.

  Uses `assertRaisesRegex` (the `...Regexp` alias is deprecated and removed in
  Python 3.12). The `.` after `bijectors` is escaped so it matches only a
  literal period; `L?` allows a Python-2-style `2L` repr in the shape.
  """
  with self.assertRaisesRegex(
      ValueError,
      r'`block_sizes` must be `None`, or a vector of the same length as '
      r'`bijectors`\. Got a `Tensor` with shape \(2L?,\) and `bijectors` of '
      r'length 1'):
    tfb.Blockwise(bijectors=[tfb.Exp()], block_sizes=[1, 2])
def testCompositeTensor(self):
  """Blockwise is a `CompositeTensor` that survives nest, `tf.function` and
  type-spec encode/decode round trips."""
  blockwise = tfb.Blockwise(
      bijectors=[tfb.Exp(), tfb.Softplus(), tfb.Scale(scale=2.)])
  self.assertIsInstance(blockwise, tf.__internal__.CompositeTensor)

  # Decompose into `Tensor` components and rebuild the bijector from them.
  components = tf.nest.flatten(blockwise, expand_composites=True)
  rebuilt = tf.nest.pack_sequence_as(
      blockwise, components, expand_composites=True)
  self.assertIsInstance(rebuilt, tfb.Blockwise)

  # The rebuilt bijector can be passed into a `tf.function`.
  @tf.function
  def apply_forward(bijector, value):
    return bijector.forward(value)

  ones = tf.ones([2, 3], dtype=tf.float32)
  self.assertAllClose(apply_forward(rebuilt, ones), blockwise.forward(ones))

  # The type spec round-trips through saved-model structure encoding.
  spec = blockwise._type_spec
  encoded = tf.__internal__.saved_model.encode_structure(spec)
  decoded = tf.__internal__.saved_model.decode_proto(encoded)
  self.assertEqual(spec, decoded)
def testName(self):
  """The name of Exp/Softplus/Affine Blockwise starts with the joined names."""
  components = [
      tfb.Exp(),
      tfb.Softplus(),
      tfb.Affine(scale_diag=[2., 3., 4.]),
  ]
  bij = tfb.Blockwise(bijectors=components, block_sizes=[2, 1, 3])
  expected_prefix = 'blockwise_of_exp_and_softplus_and_affine'
  self.assertStartsWith(bij.name, expected_prefix)
def testNonCompositeTensor(self):
  """Blockwise over a non-`CompositeTensor` bijector is not one itself.

  Also checks that `forward` applies each component to its own element.
  """
  exp = tfb.Exp()
  scale = test_util.NonCompositeTensorScale(scale=tf.constant(3.))
  blockwise = tfb.Blockwise(bijectors=[exp, scale])
  self.assertNotIsInstance(blockwise, tf.__internal__.CompositeTensor)

  actual = blockwise.forward([1., 1.])
  expected = tf.convert_to_tensor([exp.forward(1.), scale.forward(1.)])
  self.assertAllClose(actual, expected)
def testBijectiveAndFinite(self):
  """Exp/Softplus/Affine Blockwise passes `assert_bijective_and_finite`."""
  blockwise = tfb.Blockwise(
      bijectors=[tfb.Exp(), tfb.Softplus(),
                 tfb.Affine(scale_diag=[2., 3., 4.])],
      block_sizes=[2, 1, 3])

  x_values = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
  x = tf1.placeholder_with_default(x_values, shape=x_values.shape)
  # Identity to break the bijector's forward/inverse caching.
  y = tf.identity(blockwise.forward(x))

  bijector_test_util.assert_bijective_and_finite(
      blockwise,
      x=self.evaluate(x),
      y=self.evaluate(y),
      eval_func=self.evaluate,
      event_ndims=1)
def testKwargs(self):
  """Kwargs keyed by a component's name are forwarded to that component.

  The components are `Inline` bijectors named `inner0` and `inner1` whose
  hooks are mocks; each method call passes a distinct `arg` per component,
  and the mocks are checked for the matching value.
  """
  zeros = tf.zeros(1)

  def _make_inner(index):
    # Every hook is a mock returning `zeros`; only the recorded calls matter.
    return tfb.Inline(
        forward_fn=mock.Mock(return_value=zeros),
        inverse_fn=mock.Mock(return_value=zeros),
        forward_log_det_jacobian_fn=mock.Mock(return_value=zeros),
        inverse_log_det_jacobian_fn=mock.Mock(return_value=zeros),
        forward_min_event_ndims=0,
        name='inner{}'.format(index))

  bijectors = [_make_inner(i) for i in range(2)]
  blockwise = tfb.Blockwise(bijectors)
  x = [1, 2]

  blockwise.forward(x, inner0={'arg': 1}, inner1={'arg': 2})
  blockwise.inverse(x, inner0={'arg': 3}, inner1={'arg': 4})
  blockwise.forward_log_det_jacobian(
      x, event_ndims=1, inner0={'arg': 5}, inner1={'arg': 6})
  blockwise.inverse_log_det_jacobian(
      x, event_ndims=1, inner0={'arg': 7}, inner1={'arg': 8})

  # `assert_any_call` accepts a matching call anywhere in the call history;
  # `assert_called_with` checks only the most recent call.
  for offset, inner in enumerate(bijectors):
    inner._forward.assert_any_call(mock.ANY, arg=1 + offset)
  for offset, inner in enumerate(bijectors):
    inner._inverse.assert_any_call(mock.ANY, arg=3 + offset)
  for offset, inner in enumerate(bijectors):
    inner._forward_log_det_jacobian.assert_called_with(
        mock.ANY, arg=5 + offset)
  for offset, inner in enumerate(bijectors):
    inner._inverse_log_det_jacobian.assert_called_with(
        mock.ANY, arg=7 + offset)
def testNonCompositeTensor(self):
  """Blockwise containing a non-`CompositeTensor` bijector is not one either.

  Also checks that `forward` applies each component to its own element.
  """

  class NonCompositeScale(tfb.Bijector):
    """Bijector that is not a `CompositeTensor`."""

    def __init__(self, scale):
      parameters = dict(locals())
      self.scale = scale
      super(NonCompositeScale, self).__init__(
          validate_args=True,
          # An event rank is an integer count of dimensions: pass `0`, not the
          # float `0.`.
          forward_min_event_ndims=0,
          parameters=parameters,
          name='non_composite_scale')

    def _forward(self, x):
      return x * self.scale

  exp = tfb.Exp()
  scale = NonCompositeScale(scale=tf.constant(3.))
  blockwise = tfb.Blockwise(bijectors=[exp, scale])
  self.assertNotIsInstance(blockwise, tf.__internal__.CompositeTensor)
  self.assertAllClose(
      blockwise.forward([1., 1.]),
      tf.convert_to_tensor([exp.forward(1.), scale.forward(1.)]))
def testExplicitBlocks(self, dynamic_shape, batch_shape):
  """Blockwise with explicit `block_sizes` matches per-slice application.

  Args:
    dynamic_shape: If True, `block_sizes` and `x` are fed through placeholders
      with unknown shape and the static-shape assertions are skipped.
    batch_shape: Sequence of ints; batch dimensions prepended to the length-6
      event.
  """
  block_sizes = tf.convert_to_tensor(value=[2, 1, 3])
  block_sizes = tf1.placeholder_with_default(
      block_sizes,
      shape=([None] * len(block_sizes.shape) if dynamic_shape
             else block_sizes.shape))

  exp = tfb.Exp()
  sp = tfb.Softplus()
  aff = tfb.Affine(scale_diag=[2., 3., 4.])
  blockwise = tfb.Blockwise(
      bijectors=[exp, sp, aff],
      block_sizes=block_sizes,
      maybe_changes_size=False)

  x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32)
  # Each iteration prepends a new leading dimension, so iterate innermost
  # first to end with shape `batch_shape + [6]`. Iterating in forward order
  # would reverse the batch dims (e.g. [3, 2, 6] for batch_shape=[2, 3]).
  for s in reversed(batch_shape):
    x = tf.expand_dims(x, 0)
    x = tf.tile(x, [s] + [1] * (tensorshape_util.rank(x.shape) - 1))
  x = tf1.placeholder_with_default(
      x, shape=None if dynamic_shape else x.shape)

  # Identity to break the caching.
  blockwise_y = tf.identity(blockwise.forward(x))
  blockwise_fldj = blockwise.forward_log_det_jacobian(x, event_ndims=1)
  blockwise_x = blockwise.inverse(blockwise_y)
  blockwise_ildj = blockwise.inverse_log_det_jacobian(
      blockwise_y, event_ndims=1)

  y_shape = list(batch_shape) + [6]
  ldj_shape = list(batch_shape)
  if not dynamic_shape:
    self.assertEqual(blockwise_y.shape, y_shape)
    self.assertEqual(blockwise_fldj.shape, ldj_shape)
    self.assertEqual(blockwise_x.shape, y_shape)
    self.assertEqual(blockwise_ildj.shape, ldj_shape)
  self.assertAllEqual(self.evaluate(tf.shape(blockwise_y)), y_shape)
  self.assertAllEqual(self.evaluate(tf.shape(blockwise_fldj)), ldj_shape)
  self.assertAllEqual(self.evaluate(tf.shape(blockwise_x)), y_shape)
  self.assertAllEqual(self.evaluate(tf.shape(blockwise_ildj)), ldj_shape)

  # Reference values: apply each bijector directly to its slice of the event
  # ([0:2], [2:3], [3:6]) and combine.
  expl_y = tf.concat([
      exp.forward(x[..., :2]),
      sp.forward(x[..., 2:3]),
      aff.forward(x[..., 3:]),
  ], axis=-1)
  expl_fldj = sum([
      exp.forward_log_det_jacobian(x[..., :2], event_ndims=1),
      sp.forward_log_det_jacobian(x[..., 2:3], event_ndims=1),
      aff.forward_log_det_jacobian(x[..., 3:], event_ndims=1)
  ])
  expl_x = tf.concat([
      exp.inverse(expl_y[..., :2]),
      sp.inverse(expl_y[..., 2:3]),
      aff.inverse(expl_y[..., 3:])
  ], axis=-1)
  expl_ildj = sum([
      exp.inverse_log_det_jacobian(expl_y[..., :2], event_ndims=1),
      sp.inverse_log_det_jacobian(expl_y[..., 2:3], event_ndims=1),
      aff.inverse_log_det_jacobian(expl_y[..., 3:], event_ndims=1)
  ])

  self.assertAllClose(self.evaluate(expl_y), self.evaluate(blockwise_y))
  self.assertAllClose(self.evaluate(expl_fldj), self.evaluate(blockwise_fldj))
  self.assertAllClose(self.evaluate(expl_x), self.evaluate(blockwise_x))
  self.assertAllClose(self.evaluate(expl_ildj), self.evaluate(blockwise_ildj))
def testRaisesEmptyBijectors(self):
  """Constructing a Blockwise with no bijectors raises `ValueError`.

  Uses `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated and
  was removed from `unittest.TestCase` in Python 3.12.
  """
  with self.assertRaisesRegex(ValueError, '`bijectors` must not be empty'):
    tfb.Blockwise(bijectors=[])
def testNameOneBijector(self):
  """A Blockwise wrapping only `Exp` has a name starting `blockwise_of_exp`."""
  blockwise = tfb.Blockwise(bijectors=[tfb.Exp()], block_sizes=[3])
  self.assertStartsWith(blockwise.name, 'blockwise_of_exp')
def testImplicitBlocks(self):
  """Omitting `block_sizes` for three bijectors yields sizes [1, 1, 1]."""
  bijectors = [tfb.Exp(), tfb.Softplus(), tfb.Affine(scale_diag=[2.])]
  blockwise = tfb.Blockwise(bijectors=bijectors)
  expected_sizes = [1, 1, 1]
  self.assertAllEqual(self.evaluate(blockwise.block_sizes), expected_sizes)
def __init__(self, model):
  """Constructs the adapter.

  Args:
    model: An Inference Gym model.

  Raises:
    TypeError: If `model` has more than one unique Tensor dtype.
  """
  self._model = model

  dtypes = set(
      tf.nest.flatten(
          tf.nest.map_structure(tf.as_dtype, self._model.dtype)))
  if len(dtypes) > 1:
    raise TypeError(
        'Model must have only one Tensor dtype, saw: {}'.format(
            self._model.dtype))
  dtype = dtypes.pop()

  # TODO(siege): Make this work with multi-part default_event_bijector.
  def _make_reshaped_bijector(b, s):
    # `tfb.Chain` applies its list right-to-left. In the forward direction
    # this reshapes a flat vector to `b`'s unconstrained event shape, applies
    # `b`, and flattens the resulting event of shape `s` back to a vector.
    return tfb.Chain([
        tfb.Reshape(event_shape_in=s, event_shape_out=[ps.reduce_prod(s)]),
        b,
        tfb.Reshape(event_shape_out=b.inverse_event_shape(s)),
    ])

  reshaped_bijector = tf.nest.map_structure(
      _make_reshaped_bijector,
      self._model.default_event_space_bijector,
      self._model.event_shape)

  # Number of unconstrained elements each model part contributes to the flat
  # vector. Computed once and used both as the Blockwise block sizes and for
  # the total event size.
  event_sizes = tf.nest.map_structure(
      lambda b, s: ps.reduce_prod(b.inverse_event_shape(s)),
      self._model.default_event_space_bijector,
      self._model.event_shape)
  flat_event_sizes = tf.nest.flatten(event_sizes)

  bijector = tfb.Blockwise(
      bijectors=tf.nest.flatten(reshaped_bijector),
      block_sizes=flat_event_sizes)
  event_shape = tf.TensorShape([sum(flat_event_sizes)])

  def make_flattened_transform(transform):
    # We yank this out to avoid capturing the loop variable.
    return transform._replace(
        fn=lambda x: transform(self._split_and_reshape_event(x)))

  sample_transformations = collections.OrderedDict()
  for key, transform in self._model.sample_transformations.items():
    sample_transformations[key] = make_flattened_transform(transform)

  super(VectorModel, self).__init__(
      default_event_space_bijector=bijector,
      event_shape=event_shape,
      dtype=dtype,
      name='vector_' + self._model.name,
      pretty_name=str(self._model),
      sample_transformations=sample_transformations,
  )
def testNameOneBijector(self):
  """A Blockwise wrapping only `Exp` is named exactly `blockwise_of_exp`."""
  expected_name = 'blockwise_of_exp'
  blockwise = tfb.Blockwise(bijectors=[tfb.Exp()], block_sizes=[3])
  self.assertEqual(expected_name, blockwise.name)