def testMinEventNdimsWithPartiallyDependentJointMap(self):
  """Wrapping a dependent bijector in a JointMap preserves its metadata."""
  # `dependent` couples its two parts: Invert(Split(2)) concatenates them
  # and Split(2) re-splits, so the parts cannot be treated independently.
  dependent = tfb.Chain([tfb.Split(2), tfb.Invert(tfb.Split(2))])
  # Restructure that wraps the two-element list into a single nested list,
  # so `dependent` can be placed inside a one-element JointMap.
  wrap_in_list = tfb.Restructure(input_structure=[0, 1],
                                 output_structure=[[0, 1]])
  dependent_as_chain = tfb.Chain([
      tfb.Invert(wrap_in_list),
      tfb.JointMap([dependent]),
      wrap_in_list])
  self.assertAllEqualNested(dependent.forward_min_event_ndims,
                            dependent_as_chain.forward_min_event_ndims)
  self.assertAllEqualNested(dependent.inverse_min_event_ndims,
                            dependent_as_chain.inverse_min_event_ndims)
  # Private flag indicating whether the bijector's parts interact; the
  # JointMap wrapping must not lose this property.
  self.assertAllEqualNested(dependent._parts_interact,
                            dependent_as_chain._parts_interact)
def test_nested_transform(self):
  """Nested TransformedTransitionKernels compose like a bijector Chain."""
  target_dist = tfd.Normal(loc=0., scale=1.)
  b1 = tfb.Scale(0.5)
  b2 = tfb.Exp()
  chain = tfb.Chain([b2, b1])  # applies bijectors right to left (b1 then b2).
  # Two nested kernels, each applying one bijector.
  inner_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target_dist.log_prob,
          num_leapfrog_steps=27,
          step_size=10),
      bijector=b1)
  outer_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=inner_kernel,
      bijector=b2)
  # A single kernel applying both bijectors at once via a Chain.
  chain_kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target_dist.log_prob,
          num_leapfrog_steps=27,
          step_size=10),
      bijector=chain)

  outer_pkr_one, outer_pkr_two = self.evaluate([
      outer_kernel.bootstrap_results(2.),
      outer_kernel.bootstrap_results(9.),
  ])

  # the outermost kernel only applies the outermost bijector
  self.assertNear(np.log(2.), outer_pkr_one.transformed_state, err=1e-6)
  self.assertNear(np.log(9.), outer_pkr_two.transformed_state, err=1e-6)

  chain_pkr_one, chain_pkr_two = self.evaluate([
      chain_kernel.bootstrap_results(2.),
      chain_kernel.bootstrap_results(9.),
  ])

  # all bijectors are applied to the inner kernel, from innermost to outermost
  # this behavior is completely analogous to a bijector Chain
  self.assertNear(chain_pkr_one.transformed_state,
                  outer_pkr_one.inner_results.transformed_state,
                  err=1e-6)
  self.assertEqual(
      chain_pkr_one.inner_results.accepted_results,
      outer_pkr_one.inner_results.inner_results.accepted_results)
  self.assertNear(chain_pkr_two.transformed_state,
                  outer_pkr_two.inner_results.transformed_state,
                  err=1e-6)
  self.assertEqual(
      chain_pkr_two.inner_results.accepted_results,
      outer_pkr_two.inner_results.inner_results.accepted_results)

  # Run one step of each kernel under the same seed; the resulting states
  # must agree.
  seed = test_util.test_seed(sampler_type='stateless')
  outer_results_one, outer_results_two = self.evaluate([
      outer_kernel.one_step(2., outer_pkr_one, seed=seed),
      outer_kernel.one_step(9., outer_pkr_two, seed=seed)
  ])
  chain_results_one, chain_results_two = self.evaluate([
      chain_kernel.one_step(2., chain_pkr_one, seed=seed),
      chain_kernel.one_step(9., chain_pkr_two, seed=seed)
  ])
  self.assertNear(chain_results_one[0],
                  outer_results_one[0],
                  err=1e-6)
  self.assertNear(chain_results_two[0],
                  outer_results_two[0],
                  err=1e-6)
def testMatchWithAffineTransform(self):
  """tanh(x) == 2 * sigmoid(2x) - 1, expressed as a Chain of affine ops."""
  tanh_bijector = tfb.Tanh()
  minus_one = tf.cast(-1.0, dtype=tf.float64)
  double = tf.cast(2.0, dtype=tf.float64)
  chained_bijector = tfb.Chain([
      tfb.Shift(minus_one),
      tfb.Scale(double),
      tfb.Sigmoid(),
      tfb.Scale(tf.cast(2.0, dtype=tf.float64)),
  ])

  xs = np.linspace(-3.0, 3.0, 100)
  ys = np.tanh(xs)
  # Forward and inverse transforms must agree pointwise.
  self.assertAllClose(self.evaluate(tanh_bijector.forward(xs)),
                      self.evaluate(chained_bijector.forward(xs)))
  self.assertAllClose(self.evaluate(tanh_bijector.inverse(ys)),
                      self.evaluate(chained_bijector.inverse(ys)))
  # So must the log-det-Jacobians in both directions.
  self.assertAllClose(
      self.evaluate(
          tanh_bijector.inverse_log_det_jacobian(ys, event_ndims=0)),
      self.evaluate(
          chained_bijector.inverse_log_det_jacobian(ys, event_ndims=0)))
  self.assertAllClose(
      self.evaluate(
          tanh_bijector.forward_log_det_jacobian(xs, event_ndims=0)),
      self.evaluate(
          chained_bijector.forward_log_det_jacobian(xs, event_ndims=0)))
def testBijector(self):
  """Sigmoid(low, high) matches Shift(low) o Scale(high - low) o Sigmoid."""
  low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
  high = 12.
  bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
  equivalent_bijector = tfb.Chain(
      [tfb.Shift(shift=low), tfb.Scale(scale=high - low), tfb.Sigmoid()])
  x = [[[1., 2., -5., -0.3]]]
  y = self.evaluate(equivalent_bijector.forward(x))
  self.assertAllClose(y, self.evaluate(bijector.forward(x)))
  # Broadcasting against `low` expands the batch dimension; comparing only
  # the first row is enough to check the round trip.
  self.assertAllClose(x, self.evaluate(bijector.inverse(y)[..., :1, :]),
                      rtol=1e-5)
  self.assertAllClose(
      self.evaluate(
          equivalent_bijector.inverse_log_det_jacobian(y, event_ndims=1)),
      self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=1)),
      rtol=1e-5)
  self.assertAllClose(
      self.evaluate(
          equivalent_bijector.forward_log_det_jacobian(x, event_ndims=1)),
      self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=1)))
def test_batch_broadcast_vector_to_parts(self):
  """Batch shape from one Split branch broadcasts to all branches."""
  batch_shape = [4, 2]
  true_split_sizes = [1, 3, 2]
  base_event_size = sum(true_split_sizes)
  # Base dist with no batch shape (will require broadcasting).
  base_dist = tfd.MultivariateNormalDiag(
      loc=tf.random.normal([base_event_size], seed=test_util.test_seed()),
      scale_diag=tf.exp(tf.random.normal([base_event_size],
                                         seed=test_util.test_seed())))

  # Bijector with batch shape in one part.
  bijector = tfb.Chain([
      tfb.JointMap([
          tfb.Identity(),
          tfb.Identity(),
          # Only the last part carries the batch shape.
          tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]]))]),
      tfb.Split(true_split_sizes, axis=-1)])
  split_dist = tfd.TransformedDistribution(base_dist, bijector)
  self.assertAllEqual(split_dist.batch_shape, batch_shape)

  # Because one branch of the split has batch shape, TD should feed batches
  # of base samples *into* the split, so the batch shape propagates to all
  # branches.
  xs = split_dist.sample(seed=test_util.test_seed())
  self.assertAllEqualNested(
      tf.nest.map_structure(lambda x: x.shape, xs),
      [batch_shape + [d] for d in true_split_sizes])
def testNameScopeRefersToInitialScope(self):
  """Op names record the name_scope a bijector was constructed in."""
  # Op-name checks require graph mode; eager tensors carry no scoped names.
  if tf.executing_eagerly():
    self.skipTest('Eager mode.')
  outer_bijector = tfb.Exp(name='Exponential')
  self.assertStartsWith(outer_bijector.name, 'Exponential')

  with tf.name_scope('inside'):
    inner_bijector = tfb.Exp(name='Exponential')
    self.assertStartsWith(inner_bijector.name, 'Exponential')
    # Used in the same scope it was built in: plain nested name.
    self.assertStartsWith(inner_bijector.forward(0., name='x').name,
                          'inside/Exponential/x')
    # Built at top level but used here: annotated with construction scope.
    self.assertStartsWith(outer_bijector.forward(0., name='x').name,
                          'inside/Exponential_CONSTRUCTED_AT_top_level/x')

    meta_bijector = tfb.Chain([inner_bijector], name='meta_bijector')
    # Check for spurious `_CONSTRUCTED_AT_`.
    self.assertStartsWith(
        meta_bijector.forward(0., name='x').name,
        'inside/meta_bijector/x/Exponential/forward')

  # Outside the scope.
  self.assertStartsWith(inner_bijector.forward(0., name='x').name,
                        'Exponential_CONSTRUCTED_AT_inside/x')
  self.assertStartsWith(outer_bijector.forward(0., name='x').name,
                        'Exponential/x')

  # Check that init scope is annotated only for the toplevel bijector.
  self.assertStartsWith(
      meta_bijector.forward(0., name='x').name,
      'meta_bijector_CONSTRUCTED_AT_inside/x/Exponential/forward')
def _make_reshaped_bijector(b, s):
  """Wraps `b` so it maps flat vectors to flat vectors via shape `s`."""
  # Chain applies right-to-left: first reshape the input to `b`'s expected
  # event shape, then apply `b`, then flatten the shape-`s` output.
  reshape_to_input = tfb.Reshape(event_shape_out=b.inverse_event_shape(s))
  flatten_output = tfb.Reshape(
      event_shape_in=s, event_shape_out=[ps.reduce_prod(s)])
  return tfb.Chain([flatten_output, b, reshape_to_input])
def testChainIldjWithPlaceholder(self):
  """ILDJ graph builds when the batch dimension is statically unknown."""
  chain = tfb.Chain((tfb.Exp(), tfb.Exp()))
  # TF1 graph-mode placeholder with an unknown leading (batch) dimension.
  samples = tf.placeholder(dtype=np.float32, shape=[None, 10], name="samples")
  ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
  self.assertTrue(ildj is not None)
  # Feeding concrete values must evaluate without error.
  with self.cached_session():
    ildj.eval({samples: np.zeros([2, 10], np.float32)})
def testScalarCongruency(self):
  """Checks forward/inverse consistency of Exp o Softplus on a scalar grid."""
  with self.test_session():
    exp_of_softplus = tfb.Chain((tfb.Exp(), tfb.Softplus()))
    assert_scalar_congruency(
        exp_of_softplus, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testMinEventNdimsShapeChangingAddRemoveDims(self):
  # Mix of rank-raising and rank-lowering bijectors: the chain must
  # accumulate the net rank changes when propagating min_event_ndims.
  bij = tfb.Chain(
      [ShapeChanging(2, 1), ShapeChanging(3, 0), ShapeChanging(1, 2)])
  self.assertEqual(bij.forward_min_event_ndims, 4)
  self.assertEqual(bij.inverse_min_event_ndims, 1)
def testStddev(self):
  """stddev/variance under scalar affine, Split, and multivariate bijectors."""
  base_stddev = 2.
  shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
  scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
  # A shift leaves stddev unchanged; scaling multiplies it by |scale|.
  expected_stddev = tf.abs(base_stddev * scale)
  normal = self._cls()(
      distribution=tfd.Normal(loc=tf.zeros_like(shift),
                              scale=base_stddev * tf.ones_like(scale),
                              validate_args=True),
      bijector=tfb.Chain([tfb.Shift(shift=shift),
                          tfb.Scale(scale=scale)],
                         validate_args=True),
      validate_args=True)
  self.assertAllClose(expected_stddev, normal.stddev())
  self.assertAllClose(expected_stddev**2, normal.variance())

  # Splitting the event distributes the per-component stddev across parts.
  split_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.Split(3),
      validate_args=True)
  self.assertAllCloseNested(
      tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
      split_normal.stddev())

  # A correlating (multivariate) transformation has no elementwise stddev.
  scaled_normal = self._cls()(
      distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
      bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
      validate_args=True)
  with self.assertRaisesRegex(NotImplementedError,
                              'is a multivariate transformation'):
    scaled_normal.stddev()
def testCompositeTensor(self):
  """A Chain of CompositeTensor bijectors is itself a CompositeTensor."""
  exp = tfb.Exp()
  sp = tfb.Softplus()
  aff = tfb.Scale(scale=2.)
  chain = tfb.Chain(bijectors=[exp, sp, aff])
  self.assertIsInstance(chain, tf.__internal__.CompositeTensor)

  # Bijector may be flattened into `Tensor` components and rebuilt.
  flat = tf.nest.flatten(chain, expand_composites=True)
  unflat = tf.nest.pack_sequence_as(chain, flat, expand_composites=True)
  self.assertIsInstance(unflat, tfb.Chain)

  # Bijector may be input to a `tf.function`-decorated callable.
  @tf.function
  def call_forward(bij, x):
    return bij.forward(x)

  x = tf.ones([2, 3], dtype=tf.float32)
  self.assertAllClose(call_forward(unflat, x), chain.forward(x))

  # TypeSpec can be encoded/decoded.
  struct_coder = tf.__internal__.saved_model.StructureCoder()
  enc = struct_coder.encode_structure(chain._type_spec)
  dec = struct_coder.decode_proto(enc)
  self.assertEqual(chain._type_spec, dec)
def testMinEventNdimsShapeChangingRemoveDims(self):
  """min_event_ndims propagation through rank-reducing bijectors.

  Uses `ScaleMatvecDiag` in place of the deprecated `tfb.Affine` (removed
  from recent TFP releases); both require forward/inverse min_event_ndims
  of 1, so the expected values are unchanged.
  """
  # A single rank-3 -> rank-0 reduction.
  chain = tfb.Chain([ShapeChanging(3, 0)])
  self.assertEqual(3, chain.forward_min_event_ndims)
  self.assertEqual(0, chain.inverse_min_event_ndims)

  # Rank-1 bijector applied before the reduction: no extra rank needed.
  chain = tfb.Chain(
      [ShapeChanging(3, 0), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
  self.assertEqual(3, chain.forward_min_event_ndims)
  self.assertEqual(0, chain.inverse_min_event_ndims)

  # Rank-1 bijector applied after the reduction: forward requirement grows.
  chain = tfb.Chain(
      [tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), ShapeChanging(3, 0)])
  self.assertEqual(4, chain.forward_min_event_ndims)
  self.assertEqual(1, chain.inverse_min_event_ndims)

  # Two stacked reductions accumulate.
  chain = tfb.Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)])
  self.assertEqual(6, chain.forward_min_event_ndims)
  self.assertEqual(0, chain.inverse_min_event_ndims)
def testNonCompositeTensor(self):
  """A Chain containing a non-CompositeTensor bijector is not one either."""

  class NonCompositeScale(tfb.Bijector):
    """Bijector that is not a `CompositeTensor`."""

    def __init__(self, scale):
      parameters = dict(locals())
      self.scale = scale
      super(NonCompositeScale, self).__init__(
          validate_args=True,
          # NOTE(review): float literal `0.`; presumably should be the
          # int 0 -- confirm whether Bijector accepts a float here.
          forward_min_event_ndims=0.,
          parameters=parameters,
          name="non_composite_scale")

    def _forward(self, x):
      return x * self.scale

    def _inverse(self, y):
      return y / self.scale

  exp = tfb.Exp()
  scale = NonCompositeScale(scale=tf.constant(3.))
  chain = tfb.Chain(bijectors=[exp, scale])
  self.assertNotIsInstance(chain, tf.__internal__.CompositeTensor)
  # Composition must still work: Chain applies right-to-left.
  self.assertAllClose(chain.forward([1.]),
                      exp.forward(scale.forward([1.])))
def _build_inference_bijector(parameter):
  """Return a scaling-and-support bijector for inference.

  By default, this is just `param.bijector`, which transforms a real-value
  input to the parameter's support. For scale parameters (heuristically
  detected as any param with a Softplus support bijector), we also rescale
  by the prior stddev. This is approximately equivalent to performing
  inference on a standardized input
  `observed_time_series/stddev(observed_time_series)`, because:

  a) rescaling all the scale parameters is equivalent (gives equivalent
     forecasts, etc) to rescaling the `observed_time_series`.
  b) the default scale priors in STS components have stddev proportional
     to `stddev(observed_time_series)`.

  Args:
    parameter: `sts.Parameter` named tuple instance.
  Returns:
    bijector: a `tfb.Bijector` instance to use in inference.
  """
  if isinstance(parameter.bijector, tfb.Softplus):
    try:
      # Use mean + stddev, rather than just stddev, to ensure a reasonable
      # init if the user passes a crazy custom prior like N(100000, 0.001).
      prior_scale = tf.abs(parameter.prior.mean()) + parameter.prior.stddev()
      # `tfb.Scale` replaces the deprecated `tfb.AffineScalar(scale=...)`;
      # with no shift the two are equivalent.
      return tfb.Chain([tfb.Scale(prior_scale), parameter.bijector])
    except NotImplementedError:
      # Custom prior with no mean and/or stddev.
      pass
  return parameter.bijector
def testScalarCongruency(self):
  """Forward/inverse round-trip consistency of Exp o Softplus on scalars."""
  exp_of_softplus = tfb.Chain((tfb.Exp(), tfb.Softplus()))
  bijector_test_util.assert_scalar_congruency(
      exp_of_softplus,
      lower_x=1e-3,
      upper_x=1.5,
      rtol=0.05,
      eval_func=self.evaluate)
def testMatchWithAffineTransform(self):
  """tanh(x) == 2 * sigmoid(2x) - 1, verified against a chained affine form.

  `tf.to_double` was removed in TF2 and `tfb.AffineScalar` is deprecated;
  this uses `tf.cast` with `Shift`/`Scale`, which compute the identical
  transformation (scale then shift == AffineScalar(shift, scale)).
  """
  direct_bj = tfb.Tanh()
  indirect_bj = tfb.Chain([
      tfb.Shift(tf.cast(-1.0, dtype=tf.float64)),
      tfb.Scale(tf.cast(2.0, dtype=tf.float64)),
      tfb.Sigmoid(),
      tfb.Scale(tf.cast(2.0, dtype=tf.float64))
  ])

  x = np.linspace(-3.0, 3.0, 100)
  y = np.tanh(x)
  self.assertAllClose(self.evaluate(direct_bj.forward(x)),
                      self.evaluate(indirect_bj.forward(x)))
  self.assertAllClose(self.evaluate(direct_bj.inverse(y)),
                      self.evaluate(indirect_bj.inverse(y)))
  self.assertAllClose(
      self.evaluate(direct_bj.inverse_log_det_jacobian(y, event_ndims=0)),
      self.evaluate(
          indirect_bj.inverse_log_det_jacobian(y, event_ndims=0)))
  self.assertAllClose(
      self.evaluate(direct_bj.forward_log_det_jacobian(x, event_ndims=0)),
      self.evaluate(
          indirect_bj.forward_log_det_jacobian(x, event_ndims=0)))
def testChainIldjWithPlaceholder(self):
  """ILDJ computes even when the input's static shape is fully unknown."""
  double_exp = tfb.Chain((tfb.Exp(), tfb.Exp()))
  # Placeholder with default value and completely unknown static shape.
  samples = tf1.placeholder_with_default(
      np.zeros([2, 10], np.float32), shape=None)
  ildj = double_exp.inverse_log_det_jacobian(samples, event_ndims=0)
  self.assertIsNotNone(ildj)
  self.evaluate(ildj)
def testNonCompositeTensor(self):
  # A chain containing any non-CompositeTensor bijector is itself not a
  # CompositeTensor, but must still compose correctly (right-to-left).
  exp_bij = tfb.Exp()
  scale_bij = test_util.NonCompositeTensorScale(scale=tf.constant(3.))
  composed = tfb.Chain(bijectors=[exp_bij, scale_bij])
  self.assertNotIsInstance(composed, tf.__internal__.CompositeTensor)
  self.assertAllClose(composed.forward([1.]),
                      exp_bij.forward(scale_bij.forward([1.])))
def testInvalidChainNdimsRaisesError(self):
  """Chain rejects multipart bijectors whose per-part rank changes disagree."""
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias
  # (removed in Python 3.12); the expected message is unchanged.
  with self.assertRaisesRegex(
      ValueError,
      "Differences between `event_ndims` and `min_event_ndims must be equal"
  ):
    tfb.Chain(
        [ShapeChanging([1, 1], [1, 1]), ShapeChanging([1, 1], [2, 1])])
def test_composition_str_and_repr_match_expected_dynamic_shape(self):
  """`str`/`repr` of a Chain summarize ndims, dtypes, and components."""
  bij = tfb.Chain([
      tfb.Exp(),
      tfb.Shift(self._tensor([1., 2.])),
      tfb.SoftmaxCentered()
  ])
  self.assertContainsInOrder([
      'tfp.bijectors.Chain(',
      ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])')
  ], str(bij))
  self.assertContainsInOrder([
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
       'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
      '>, <tfp.bijectors.Shift', '>, <tfp.bijectors.SoftmaxCentered', '>]>'
  ], repr(bij))

  # A multipart chain must render per-part structures (dict-valued
  # min_event_ndims and dtypes).
  bij = tfb.Chain([
      tfb.JointMap({
          'a': tfb.Exp(),
          'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.]))
      }),
      tfb.Restructure({
          'a': 0,
          'b': 1
      }, [0, 1]),
      tfb.Split(2),
      tfb.Invert(tfb.SoftmaxCentered()),
  ])
  self.assertContainsInOrder([
      'tfp.bijectors.Chain(',
      ('forward_min_event_ndims=1, '
       'inverse_min_event_ndims={a: 1, b: 1}, '
       'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
       'Restructure, Split, Invert(SoftmaxCentered)])')
  ], str(bij))
  self.assertContainsInOrder([
      '<tfp.bijectors.Chain ',
      ('batch_shape=? forward_min_event_ndims=1 '
       "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
       "dtype_y={'a': ?, 'b': float32} "
       "bijectors=[<tfp.bijectors.JointMap "),
      '>, <tfp.bijectors.Restructure', '>, <tfp.bijectors.Split',
      '>, <tfp.bijectors.Invert', '>]>'
  ], repr(bij))
def testNestedDtype(self):
  # Identity bijectors carry no dtype; the chain should take float64 from
  # the Scale parameter and handle the integer input list accordingly.
  bij = tfb.Chain([
      tfb.Identity(),
      tfb.Scale(tf.constant(2., tf.float64)),
      tfb.Identity()
  ])
  expected = tf.constant([2, 4, 6], tf.float64)
  self.assertAllClose(expected, self.evaluate(bij.forward([1, 2, 3])))
def default_bijector(cls, dtype: Any = None, **kwargs) -> tfb.Bijector:
  """
  Affine bijection between $[[0, 1], [0, 1]] <--> [[-2.5, 2.5], [-1.0, 2.0]]$
  """
  if dtype is None:
    dtype = default_float()
  # Chain applies right-to-left: shift first, then scale, i.e.
  # y = scale * (x + shift) elementwise over the two dimensions.
  shift_bij = tfb.Shift(tf.convert_to_tensor([-0.5, -1 / 3], dtype=dtype))
  scale_bij = tfb.Scale(tf.convert_to_tensor([5.0, 3.0], dtype=dtype))
  return tfb.Chain([scale_bij, shift_bij])
def testMinEventNdimsShapeChangingRemoveDims(self):
  # A single rank-3 -> rank-0 reduction.
  bij = tfb.Chain([ShapeChanging(3, 0)])
  self.assertEqual(bij.forward_min_event_ndims, 3)
  self.assertEqual(bij.inverse_min_event_ndims, 0)

  # Rank-1 bijector applied before the reduction: no extra rank needed.
  bij = tfb.Chain(
      [ShapeChanging(3, 0), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
  self.assertEqual(bij.forward_min_event_ndims, 3)
  self.assertEqual(bij.inverse_min_event_ndims, 0)

  # Rank-1 bijector applied after the reduction: the forward side must
  # supply one extra dimension for the matvec to consume.
  bij = tfb.Chain(
      [tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), ShapeChanging(3, 0)])
  self.assertEqual(bij.forward_min_event_ndims, 4)
  self.assertEqual(bij.inverse_min_event_ndims, 1)

  # Two stacked reductions accumulate.
  bij = tfb.Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)])
  self.assertEqual(bij.forward_min_event_ndims, 6)
  self.assertEqual(bij.inverse_min_event_ndims, 0)
def testCopyExtraArgs(self):
  """`copy()` must preserve constructor parameters, incl. nested bijectors."""
  # Note: we cannot easily test all bijectors since each requires
  # different initialization arguments. We therefore spot test a few.
  sigmoid = tfb.Sigmoid(low=-1., high=2., validate_args=True)
  self.assertEqual(sigmoid.parameters, sigmoid.copy().parameters)
  chain = tfb.Chain(
      [
          tfb.Softplus(hinge_softness=[1., 2.], validate_args=True),
          tfb.MatrixInverseTriL(validate_args=True)
      ],
      validate_args=True)
  self.assertEqual(chain.parameters, chain.copy().parameters)
def test_slice_single_param_bijector_composition(self):
  """Batch slicing recurses through JointMap/Chain/Invert to parameters."""
  sliced = slicing._slice_single_param(
      tfb.JointMap({'a': tfb.Chain([
          tfb.Invert(tfb.Scale(tf.ones([4, 3, 1])))
      ])}),
      param_event_ndims={'a': 1},
      slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
      batch_shape=tf.constant([7, 4, 3]))
  # The sliced bijector's batch shape should match applying the same slice
  # expression to an analogous dense array.
  self.assertAllEqual(
      list(tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape),
      sliced.experimental_batch_shape_tensor(x_event_ndims={'a': 1}))
def testBijectorIdentity(self):
  """An empty Chain behaves as the identity bijector."""
  empty_chain = tfb.Chain()
  self.assertStartsWith(empty_chain.name, "identity")
  x = np.asarray([[[1., 2.], [2., 3.]]])
  # Forward and inverse are both no-ops.
  for transform in (empty_chain.forward, empty_chain.inverse):
    self.assertAllClose(x, self.evaluate(transform(x)))
  # Both log-det-Jacobians vanish identically.
  for log_det in (empty_chain.inverse_log_det_jacobian,
                  empty_chain.forward_log_det_jacobian):
    self.assertAllClose(0., self.evaluate(log_det(x, event_ndims=1)))
def testMinEventNdimsChain(self):
  """Chain's min_event_ndims is the max over its components' requirements."""
  # All-scalar chain stays scalar.
  chain = tfb.Chain([tfb.Exp(), tfb.Exp(), tfb.Exp()])
  self.assertEqual(0, chain.forward_min_event_ndims)
  self.assertEqual(0, chain.inverse_min_event_ndims)

  # NOTE(review): `tfb.Affine` is deprecated in newer TFP releases
  # (`ScaleMatvecDiag` is the usual replacement, cf. the sibling test) --
  # confirm against the TFP version this file targets.
  chain = tfb.Chain([tfb.Affine(), tfb.Affine(), tfb.Affine()])
  self.assertEqual(1, chain.forward_min_event_ndims)
  self.assertEqual(1, chain.inverse_min_event_ndims)

  # A single vector-event bijector lifts the whole chain to ndims 1,
  # regardless of its position in the chain.
  chain = tfb.Chain([tfb.Exp(), tfb.Affine()])
  self.assertEqual(1, chain.forward_min_event_ndims)
  self.assertEqual(1, chain.inverse_min_event_ndims)

  chain = tfb.Chain([tfb.Affine(), tfb.Exp()])
  self.assertEqual(1, chain.forward_min_event_ndims)
  self.assertEqual(1, chain.inverse_min_event_ndims)

  chain = tfb.Chain([tfb.Affine(), tfb.Exp(), tfb.Softplus(), tfb.Affine()])
  self.assertEqual(1, chain.forward_min_event_ndims)
  self.assertEqual(1, chain.inverse_min_event_ndims)
def __init__(self, loc, chol_precision_tril, name=None):
  """Builds an MVN from the Cholesky factor of its *precision* matrix.

  Transforms an iid standard normal: first applies the inverse of
  x -> chol_precision_tril^T @ x (via the inverted adjoint
  `ScaleMatvecTriL`), then shifts by `loc`.

  Args:
    loc: location (mean) vector.
    chol_precision_tril: lower-triangular Cholesky factor of the
      precision matrix.
    name: Python `str` name for the distribution.
  """
  super(MVNCholPrecisionTriL, self).__init__(
      distribution=tfd.Independent(tfd.Normal(tf.zeros_like(loc),
                                              scale=tf.ones_like(loc)),
                                   reinterpreted_batch_ndims=1),
      bijector=tfb.Chain([
          tfb.Shift(shift=loc),
          tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,
                                         adjoint=True)),
      ]),
      name=name)
def testEventNdimsIsOptional(self):
  """Log-det-Jacobians may be called without an explicit `event_ndims`."""
  scale_diag = np.array([1., 2., 3.], dtype=np.float32)
  chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()])
  x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
  y = [1., 4., 9.]
  # FLDJ = log|det diag(1, 2, 3)| + sum(x) = log(6) + sum(x); the ILDJ is
  # its negation evaluated at y = forward(x).
  self.assertAllClose(
      np.log(6, dtype=np.float32) + np.sum(x),
      self.evaluate(chain.forward_log_det_jacobian(x)))
  self.assertAllClose(
      -np.log(6, dtype=np.float32) - np.sum(x),
      self.evaluate(chain.inverse_log_det_jacobian(y)))