def testPowerNegativeInputRaisesError(self):
  """Non-integer powers reject negative inputs; odd integer powers accept them.

  With `validate_args=True`, `Power(power=2.5).inverse(-1.)` must raise an op
  error containing 'must be non-negative'. For an odd integer power (3.) the
  same call must succeed and return the real cube root of -1, which is -1.
  """
  with self.assertRaisesOpError('must be non-negative'):
    b = tfb.Power(power=2.5, validate_args=True)
    self.evaluate(b.inverse(-1.))

  # Deliberately outside the assertRaises context: any statement placed after
  # the raising call inside that context would never execute.
  b = tfb.Power(power=3., validate_args=True)
  self.assertAllClose(-1., self.evaluate(b.inverse(-1.)))
def testScalarCongruency(self):
  """Run the scalar-congruency helper on `Power(power=2.6)` over [1e-3, 1.5]."""
  bijector = tfb.Power(power=2.6, validate_args=True)
  bijector_test_util.assert_scalar_congruency(
      bijector,
      lower_x=1e-3,
      upper_x=1.5,
      eval_func=self.evaluate,
      # `n` is presumably a sample count; pass an int (same value as 1e5)
      # rather than a float so it is safe for shape/num arguments.
      n=int(1e5),
      rtol=0.08)
def testSpecialCases(self, power, cls):
  """`tfb.Power(power)` must agree with the dedicated bijector `cls()`.

  Forward, inverse and inverse log-det-Jacobian (event_ndims=0) are compared
  on a small nested batch of positive inputs.
  """
  power_bij = tfb.Power(power=power)
  reference_bij = cls()
  inputs = [[[1., 5], [2, 1]],
            [[np.sqrt(2.), 3], [np.sqrt(8.), 1]]]

  fwd, fwd_ref = self.evaluate(
      (power_bij.forward(inputs), reference_bij.forward(inputs)))
  self.assertAllClose(fwd, fwd_ref)

  inv, inv_ref = self.evaluate(
      (power_bij.inverse(fwd), reference_bij.inverse(fwd)))
  self.assertAllClose(inv, inv_ref)

  ildj_pair = tuple(
      bij.inverse_log_det_jacobian(fwd, event_ndims=0)
      for bij in (power_bij, reference_bij))
  ildj, ildj_ref = self.evaluate(ildj_pair)
  self.assertAllClose(ildj, ildj_ref)
def testBijectorScalar(self):
  """Check a vector of powers elementwise against closed-form NumPy values."""
  exponents = np.array([2.6, 0.3, -1.1])
  bij = tfb.Power(power=exponents, validate_args=True)
  self.assertStartsWith(bij.name, 'power')

  x = np.array([[[1., 5., 3.], [2., 1., 7.]],
                [[np.sqrt(2.), 3., 1.], [np.sqrt(8.), 1., 0.4]]])
  # `exponents` broadcasts against the trailing axis of `x`.
  y = x ** exponents
  # x = y**(1/p)  =>  log|dx/dy| = -log|p| - (p - 1) * log(x).
  expected_ildj = -np.log(np.abs(exponents)) - (exponents - 1.) * np.log(x)

  self.assertAllClose(y, self.evaluate(bij.forward(x)))
  self.assertAllClose(x, self.evaluate(bij.inverse(y)))

  ildj = self.evaluate(bij.inverse_log_det_jacobian(y, event_ndims=0))
  self.assertAllClose(expected_ildj, ildj, atol=0., rtol=1e-6)

  # The forward log-det-Jacobian must be the negated inverse one.
  neg_ildj = self.evaluate(-bij.inverse_log_det_jacobian(y, event_ndims=0))
  fldj = self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0))
  self.assertAllClose(neg_ildj, fldj, atol=0., rtol=1e-7)
def testPowerOddInteger(self):
  """Odd integer powers (of either sign) are checked on negative inputs too."""
  exponents = np.reshape(np.array([3., -5., 5., -7.]), (4, 1))
  bij = tfb.Power(power=exponents, validate_args=True)
  self.assertStartsWith(bij.name, 'power')

  # 20 evenly spaced points in [-10, 10]; zero is not among them.
  x = np.linspace(-10., 10., 20)
  # (4, 1) exponents against (20,) inputs broadcast to (4, 20).
  y = x ** exponents
  expected_ildj = (-np.log(np.abs(exponents))
                   - (exponents - 1.) * np.log(np.abs(x)))

  self.assertAllClose(y, self.evaluate(bij.forward(x)))
  # The inverse keeps the broadcast (4, 20) shape, so tile `x` to match.
  self.assertAllClose(x * np.ones((4, 1)), self.evaluate(bij.inverse(y)))

  ildj = self.evaluate(bij.inverse_log_det_jacobian(y, event_ndims=0))
  self.assertAllClose(expected_ildj, ildj, atol=0., rtol=1e-6)

  neg_ildj = self.evaluate(-bij.inverse_log_det_jacobian(y, event_ndims=0))
  fldj = self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0))
  self.assertAllClose(neg_ildj, fldj, atol=0., rtol=1e-7)
def testZeroPowerRaisesError(self):
  """`Power(power=0.)` with `validate_args=True` must raise 'must be non-zero'."""
  # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed from
  # `unittest.TestCase`; `assertRaisesRegex` is the supported spelling.
  with self.assertRaisesRegex(Exception, 'must be non-zero'):
    b = tfb.Power(power=0., validate_args=True)
    b.forward(1.)
class BijectorBatchShapesTest(test_util.TestCase):
  """Consistency checks for bijector `experimental_batch_shape[_tensor]`."""

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
      ('scale_matvec', lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
      ('reshape', lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    """Static/dynamic batch shapes agree with each other and with outputs."""
    bijector = bijector_fn()
    if override_x_event_ndims is not None:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)
    else:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims

    # Computing the batch shape from the x side or the y side, statically or
    # as a Tensor, must give the same answer.
    static_from_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    static_from_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(static_from_x, static_from_y)

    dynamic_from_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    dynamic_from_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(dynamic_from_x, dynamic_from_y)
    self.assertAllEqual(dynamic_from_x, static_from_x)

    # The same must hold when the event ndims are given as int64.
    to_int64 = lambda nd: tf.nest.map_structure(np.int64, nd)
    dynamic64_from_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=to_int64(x_event_ndims))
    dynamic64_from_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=to_int64(y_event_ndims))
    self.assertAllEqual(dynamic64_from_x, dynamic64_from_y)
    self.assertAllEqual(dynamic64_from_x, static_from_x)

    def assert_parts_broadcast_into(batch_shape, parts, event_ndims):
      # Each part's leading (batch) dims must broadcast into `batch_shape`.
      for part, nd in zip(tf.nest.flatten(parts),
                          tf.nest.flatten(event_ndims)):
        part_batch_shape = ps.shape(part)[:ps.rank(part) - nd]
        self.assertAllEqual(
            batch_shape, ps.broadcast_shape(batch_shape, part_batch_shape))

    # Forward direction: outputs fit the batch shape, and the fldj has
    # exactly that shape.
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    assert_parts_broadcast_into(static_from_y, ys, y_event_ndims)
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(static_from_y, ps.shape(fldj))

    # Inverse direction. `tf.identity` hands `inverse` fresh tensors
    # (presumably so the result is recomputed rather than served from a cache
    # keyed on the forward outputs).
    xs_back = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    assert_parts_broadcast_into(static_from_x, xs_back, x_event_ndims)
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(static_from_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Default ndims fall back to the minimum; giving both ndims is an error."""
    bijector = bijector_fn()
    # With no `event_ndims` argument, min_event_ndims should be assumed.
    self.assertAllEqual(bijector.experimental_batch_shape(), [1])
    self.assertAllEqual(bijector.experimental_batch_shape_tensor(), [1])

    for shape_fn in (bijector.experimental_batch_shape,
                     bijector.experimental_batch_shape_tensor):
      with self.assertRaisesRegex(
          ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
        shape_fn(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
           [tfb.Scale([5, 3]), tfb.Scale([1, 4])]),
       [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Broadcasting parameters to a new batch shape updates every parameter."""
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
    batch_shape = bijector.experimental_batch_shape(x_event_ndims=x_event_ndims)
    param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    self.assertAllEqual(broadcast_batch_shape,
                        ps.broadcast_shape(batch_shape, new_batch_shape))

    # Every parameter should now carry the broadcast batch shape.
    broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def expected_param_batch_shape(param, shape):
      # An Invert wrapping a bijector with no params of its own cannot be
      # broadcast, so its batch shape is left unchanged.
      wraps_paramless = (isinstance(param, tfb.Invert) and
                         not param.bijector._params_event_ndims())
      if wraps_paramless:
        return shape
      return ps.broadcast_shape(shape, new_batch_shape)

    params = {name: getattr(bijector, name) for name in param_batch_shapes}
    expected = tf.nest.map_structure(
        expected_param_batch_shape, params, param_batch_shapes)
    self.assertAllEqualNested(broadcast_param_batch_shapes, expected)