Example #1
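All of the examples on this page are test methods excerpted from TensorFlow Probability's bijector test suite, so they are not self-contained. A sketch of the imports they rely on (the exact module paths are an assumption and may vary across TFP versions):

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python.bijectors import bijector_test_util
from tensorflow_probability.python.internal import batch_shape_lib
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import test_util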
  def testPowerNegativeInputRaisesError(self):
    # Non-integer powers are only defined on the non-negative reals, so
    # inverting a negative value must fail when validate_args=True.
    with self.assertRaisesOpError('must be non-negative'):
      b = tfb.Power(power=2.5, validate_args=True)
      self.evaluate(b.inverse(-1.))

    # Odd integer powers are bijections over all of R, so the same call
    # succeeds without raising.
    b = tfb.Power(power=3., validate_args=True)
    self.evaluate(b.inverse(-1.))
Example #2
  def testScalarCongruency(self):
    bijector = tfb.Power(power=2.6, validate_args=True)
    bijector_test_util.assert_scalar_congruency(
        bijector, lower_x=1e-3, upper_x=1.5, eval_func=self.evaluate,
        n=1e5, rtol=0.08)
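Roughly speaking, assert_scalar_congruency draws on the order of n points in [lower_x, upper_x] and checks, to tolerance rtol, that the bijector really is a bijection on that interval and that its forward/inverse Jacobians are consistent with the change-of-variables formula.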
Example #3
  def testSpecialCases(self, power, cls):
    # `cls` is a dedicated bijector implementing the same map as
    # Power(power) (e.g. tfb.Square for power=2.); the two must agree on
    # forward, inverse, and log-det-Jacobian values.
    b = tfb.Power(power=power)
    b_other = cls()
    x = [[[1., 5], [2, 1]], [[np.sqrt(2.), 3], [np.sqrt(8.), 1]]]
    y, y_other = self.evaluate((b.forward(x), b_other.forward(x)))
    self.assertAllClose(y, y_other)

    x, x_other = self.evaluate((b.inverse(y), b_other.inverse(y)))
    self.assertAllClose(x, x_other)

    ildj, ildj_other = self.evaluate(
        (b.inverse_log_det_jacobian(y, event_ndims=0),
         b_other.inverse_log_det_jacobian(y, event_ndims=0)))
    self.assertAllClose(ildj, ildj_other)
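The parameterized decorator that drives testSpecialCases is not shown above. A plausible sketch, with hypothetical case names (tfb.Square and tfb.Identity are real TFP bijectors that coincide with Power(2.) and Power(1.) on positive inputs):

  @parameterized.named_parameters(
      ('square', 2., tfb.Square),      # assumed pairing: Power(2.) vs. Square
      ('identity', 1., tfb.Identity))  # assumed pairing: Power(1.) vs. Identity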
Example #4
  def testBijectorScalar(self):
    power = np.array([2.6, 0.3, -1.1])
    bijector = tfb.Power(power=power, validate_args=True)
    self.assertStartsWith(bijector.name, 'power')
    x = np.array([[[1., 5., 3.], [2., 1., 7.]],
                  [[np.sqrt(2.), 3., 1.], [np.sqrt(8.), 1., 0.4]]])
    y = np.power(x, power)
    # Closed form: for y = x**p, log|dx/dy| = -log|p| - (p - 1) * log(x).
    ildj = -np.log(np.abs(power)) - (power - 1.) * np.log(x)
    self.assertAllClose(y, self.evaluate(bijector.forward(x)))
    self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
    self.assertAllClose(
        ildj,
        self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=0)),
        atol=0.,
        rtol=1e-6)
    # FLDJ and ILDJ must be negatives of each other at corresponding points.
    self.assertAllClose(
        self.evaluate(-bijector.inverse_log_det_jacobian(y, event_ndims=0)),
        self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=0)),
        atol=0.,
        rtol=1e-7)
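The hard-coded ildj above follows from the change of variables: with y = x**p, the inverse is x = y**(1/p), so log|dx/dy| = -log|p| + (1/p - 1) * log(y), which simplifies to -log|p| - (p - 1) * log(x). A minimal NumPy check of that identity (standalone, no TFP required):

import numpy as np

p = 2.6
x = np.linspace(0.5, 2.0, 5)
y = x ** p
ildj_from_x = -np.log(np.abs(p)) - (p - 1.) * np.log(x)
ildj_from_y = -np.log(np.abs(p)) + (1. / p - 1.) * np.log(y)
np.testing.assert_allclose(ildj_from_x, ildj_from_y)  # equal up to rounding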
Example #5
  def testPowerOddInteger(self):
    power = np.array([3., -5., 5., -7.]).reshape((4, 1))
    bijector = tfb.Power(power=power, validate_args=True)
    self.assertStartsWith(bijector.name, 'power')
    # Odd integer powers are bijections over all of R, so negative x is
    # valid here; the ILDJ therefore uses log|x|.
    x = np.linspace(-10., 10., 20)
    y = np.power(x, power)
    ildj = -np.log(np.abs(power)) - (power - 1.) * np.log(np.abs(x))
    self.assertAllClose(y, self.evaluate(bijector.forward(x)))
    self.assertAllClose(x * np.ones((4, 1)),
                        self.evaluate(bijector.inverse(y)))
    self.assertAllClose(
        ildj,
        self.evaluate(bijector.inverse_log_det_jacobian(y, event_ndims=0)),
        atol=0.,
        rtol=1e-6)
    self.assertAllClose(
        self.evaluate(-bijector.inverse_log_det_jacobian(y, event_ndims=0)),
        self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=0)),
        atol=0.,
        rtol=1e-7)
Example #6
  def testZeroPowerRaisesError(self):
    # power=0 maps every input to 1 and is therefore not invertible.
    with self.assertRaisesRegex(Exception, 'must be non-zero'):
      b = tfb.Power(power=0., validate_args=True)
      b.forward(1.)
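For readers new to the class under test, a minimal usage sketch of tfb.Power outside the test harness (assumes tensorflow_probability is installed):

import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Power(power=2., validate_args=True)
b.forward(3.)   # => 9.0
b.inverse(9.)   # => 3.0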
Example #7
class BijectorBatchShapesTest(test_util.TestCase):

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale',
       lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
      ('scale_matvec',
       lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
      ('reshape',
       lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    bijector = bijector_fn()
    if override_x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims
    else:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)

    # All ways of calculating the batch shape should yield the same result.
    batch_shape_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    batch_shape_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(batch_shape_x, batch_shape_y)

    batch_shape_tensor_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    batch_shape_tensor_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_tensor_y)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_x)

    # Check that we're robust to integer type.
    batch_shape_tensor_x64 = bijector.experimental_batch_shape_tensor(
        x_event_ndims=tf.nest.map_structure(np.int64, x_event_ndims))
    batch_shape_tensor_y64 = bijector.experimental_batch_shape_tensor(
        y_event_ndims=tf.nest.map_structure(np.int64, y_event_ndims))
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_tensor_y64)
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_x)

    # Pushing a value through the bijector should return Tensor(s) with
    # the expected batch shape...
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    for y_part, nd in zip(tf.nest.flatten(ys), tf.nest.flatten(y_event_ndims)):
      part_batch_shape = ps.shape(y_part)[:ps.rank(y_part) - nd]
      self.assertAllEqual(batch_shape_y,
                          ps.broadcast_shape(batch_shape_y, part_batch_shape))

    # ... which should also be the shape of the fldj.
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(batch_shape_y, ps.shape(fldj))

    # Also check the inverse case.
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    for x_part, nd in zip(tf.nest.flatten(xs), tf.nest.flatten(x_event_ndims)):
      part_batch_shape = ps.shape(x_part)[:ps.rank(x_part) - nd]
      self.assertAllEqual(batch_shape_x,
                          ps.broadcast_shape(batch_shape_x, part_batch_shape))
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    bijector = bijector_fn()

    # If no `event_ndims` is passed, should assume min_event_ndims.
    self.assertAllEqual(bijector.experimental_batch_shape(), [1])
    self.assertAllEqual(bijector.experimental_batch_shape_tensor(), [1])

    with self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape(x_event_ndims=0, y_event_ndims=0)

    with self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape_tensor(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid', lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert', lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
    batch_shape = bijector.experimental_batch_shape(x_event_ndims=x_event_ndims)
    param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    self.assertAllEqual(broadcast_batch_shape,
                        ps.broadcast_shape(batch_shape, new_batch_shape))

    # Check that all params have the expected batch shape.
    broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    def _maybe_broadcast_param_batch_shape(p, s):
      if isinstance(p, tfb.Invert) and not p.bijector._params_event_ndims():
        return s  # Can't broadcast a bijector that doesn't itself have params.
      return ps.broadcast_shape(s, new_batch_shape)
    expected_broadcast_param_batch_shapes = tf.nest.map_structure(
        _maybe_broadcast_param_batch_shape,
        {param: getattr(bijector, param) for param in param_batch_shapes},
        param_batch_shapes)
    self.assertAllEqualNested(broadcast_param_batch_shapes,
                              expected_broadcast_param_batch_shapes)
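A standalone sketch of the batch-shape API exercised by this test class, reusing the parameter shapes from the 'sigmoid' case above (assumes tensorflow and tensorflow_probability are installed):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1]))
# The bijector's batch shape is the broadcast of its parameters'
# batch shapes, here [3] and [4, 1]:
print(bij.experimental_batch_shape())         # TensorShape([4, 3])
print(bij.experimental_batch_shape_tensor())  # tf.Tensor([4 3], ...)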