Esempio n. 1
0
    def testPartialStaticPermEventShapes(self):
        """Event shapes are inferred from a partially-static `perm`.

        Builds `perm` tensors that mix constants with
        `placeholder_with_default` entries, so only some entries are
        statically known, and checks the shapes reported by
        `forward_event_shape` and `inverse_event_shape`.
        """
        if tf.executing_eagerly():
            return  # this test is not interesting in eager.
        # Only perm[0] is statically known here; the other two entries are
        # placeholders, so the dims they select come back as `None`.
        perm = tf.convert_to_tensor(value=[
            tf.constant(2),
            tf1.placeholder_with_default(0, []),
            tf1.placeholder_with_default(1, [])
        ])
        self.assertAllEqual([2, None, None],
                            tf.get_static_value(perm, partial=True))
        b = tfb.Transpose(perm)
        self.assertAllEqual([8, 5, None, None],
                            b.forward_event_shape([8, 7, 6, 5]).as_list())
        self.assertAllEqual([8, None, None, 7],
                            b.inverse_event_shape([8, 7, 6, 5]).as_list())

        # Process of elimination should allow us to deduce one non-static perm idx.
        perm = tf.convert_to_tensor(value=[
            tf.constant(2),
            tf1.placeholder_with_default(0, []),
            tf.constant(1)
        ])
        self.assertAllEqual([2, None, 1],
                            tf.get_static_value(perm, partial=True))
        b = tfb.Transpose(perm)
        # Compare plain lists, as above, rather than raw `TensorShape`s.
        self.assertAllEqual([8, 5, 7, 6],
                            b.forward_event_shape([8, 7, 6, 5]).as_list())
        self.assertAllEqual([8, 6, 5, 7],
                            b.inverse_event_shape([8, 7, 6, 5]).as_list())
Esempio n. 2
0
 def testTransformedDist(self):
     """Wrapping a distribution in a Transpose bijector permutes its event shape."""
     base = tfd.Independent(tfd.Normal(tf.zeros([4, 3, 2]), 1), 3)
     base_event_shape = (4, 3, 2)

     # A length-2 perm only swaps the two rightmost event dims.
     swapped_tail = tfb.Transpose([1, 0])(base)
     self.assertEqual(base_event_shape, base.event_shape)
     self.assertEqual((4, 2, 3), swapped_tail.event_shape)

     # A length-3 perm (applied through `Invert`) swaps the two leading dims.
     swapped_head = tfb.Invert(tfb.Transpose([1, 0, 2]))(base)
     self.assertEqual(base_event_shape, base.event_shape)
     self.assertEqual((3, 4, 2), swapped_head.event_shape)
Esempio n. 3
0
 def testInvalidPermException(self):
     """A non-permutation `perm` is rejected when `validate_args=True`.

     In static-shape or eager mode the test expects a `ValueError` at
     construction time. Otherwise it expects a runtime op error when
     `forward` is evaluated.
     """
     msg = '`perm` must be a valid permutation vector.'
     if self.is_static or tf.executing_eagerly():
         # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
         with self.assertRaisesRegex(ValueError, msg):
             tfb.Transpose(perm=[1, 2], validate_args=True)
     else:
         with self.assertRaisesOpError(msg):
             bijector = tfb.Transpose(
                 perm=tf1.placeholder_with_default([1, 2], shape=[2]),
                 validate_args=True)
             self.evaluate(bijector.forward([[0, 1]]))
Esempio n. 4
0
    def testTransposeFromPerm(self):
        """Transpose built from an explicit `perm` maps x <-> y with zero log-det."""
        perm_ = [2, 0, 1]
        actual_x_ = np.array([[[1, 2], [3, 4]],
                              [[5, 6], [7, 8]]], dtype=np.float32)
        # `actual_x_` with its axes reordered according to `perm_`.
        actual_y_ = np.array([[[1, 3], [5, 7]],
                              [[2, 4], [6, 8]]], dtype=np.float32)

        if not self.is_static:
            # Feed everything through placeholders so values (and, for the
            # data tensors, shapes) are not statically known.
            actual_x = tf1.placeholder_with_default(actual_x_, shape=None)
            actual_y = tf1.placeholder_with_default(actual_y_, shape=None)
            perm = tf1.placeholder_with_default(perm_, shape=[3])
        else:
            actual_x = tf.constant(actual_x_)
            actual_y = tf.constant(actual_y_)
            perm = tf.constant(perm_)

        bijector = tfb.Transpose(perm=perm, validate_args=True)
        y = bijector.forward(actual_x)
        x = bijector.inverse(actual_y)
        fldj = bijector.forward_log_det_jacobian(x, event_ndims=3)
        ildj = bijector.inverse_log_det_jacobian(y, event_ndims=3)

        y_, x_, ildj_, fldj_ = self.evaluate([y, x, ildj, fldj])

        self.assertStartsWith(bijector.name, 'transpose')
        self.assertAllEqual(actual_y, y_)
        self.assertAllEqual(actual_x, x_)
        # Transposition is volume-preserving, so both log-dets are zero.
        self.assertAllEqual(0., ildj_)
        self.assertAllEqual(0., fldj_)
Esempio n. 5
0
 def make_bijector(perm=None, rightmost_transposed_ndims=None):
     """Builds a `tfb.Transpose`, hiding `perm` behind a placeholder if dynamic.

     `self` is not a parameter; it is taken from the enclosing scope to read
     `self.is_static`.
     """
     if perm is None:
         return tfb.Transpose(
             perm, rightmost_transposed_ndims=rightmost_transposed_ndims)
     perm_tensor = tf.convert_to_tensor(value=perm)
     if not self.is_static:
         # Keep the static shape but make the values non-constant.
         perm_tensor = tf1.placeholder_with_default(
             perm_tensor, shape=perm_tensor.shape)
     return tfb.Transpose(
         perm_tensor, rightmost_transposed_ndims=rightmost_transposed_ndims)
Esempio n. 6
0
 def testNonPermutationAssertion(self):
   """`validate_args=True` rejects a `perm` that repeats an index."""
   message = '`perm` must be a valid permutation vector'
   # The failure may surface either when the bijector is built or when
   # `forward` runs, so any `Exception` with a matching message is accepted.
   # (`assertRaisesRegexp` is a deprecated alias removed in Python 3.12.)
   with self.assertRaisesRegex(Exception, message):
     permutation = np.int32([1, 0, 1])
     bijector = tfb.Transpose(perm=permutation, validate_args=True)
     x = np.random.randn(4, 2, 3)
     _ = self.evaluate(bijector.forward(x))
Esempio n. 7
0
 def testNonNegativeAssertion(self):
   """`validate_args=True` rejects a negative `rightmost_transposed_ndims`."""
   message = '`rightmost_transposed_ndims` must be non-negative'
   # The failure may surface either when the bijector is built or when
   # `forward` runs, so any `Exception` with a matching message is accepted.
   # (`assertRaisesRegexp` is a deprecated alias removed in Python 3.12.)
   with self.assertRaisesRegex(Exception, message):
     ndims = np.int32(-3)
     bijector = tfb.Transpose(rightmost_transposed_ndims=ndims,
                              validate_args=True)
     x = np.random.randn(4, 2, 3)
     _ = self.evaluate(bijector.forward(x))
Esempio n. 8
0
 def testModifiedVariableNonPermutationAssertion(self):
   """Validation catches a `tf.Variable` `perm` reassigned to a non-permutation.

   The bijector is built while the variable holds a valid permutation.
   The variable is then overwritten with `[1, 0, 1]`, and evaluating
   `forward` is expected to fail.
   """
   message = '`perm` must be a valid permutation vector'
   permutation = tf.Variable(np.int32([1, 0, 2]))
   self.evaluate(permutation.initializer)
   bijector = tfb.Transpose(perm=permutation, validate_args=True)
   # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
   with self.assertRaisesRegex(Exception, message):
     # Make the invalid assignment happen before `forward` is computed.
     with tf.control_dependencies([permutation.assign([1, 0, 1])]):
       x = np.random.randn(4, 2, 3)
       _ = self.evaluate(bijector.forward(x))
Esempio n. 9
0
    def testTransposeFromEventNdim(self):
        """Transpose built from `rightmost_transposed_ndims` maps x <-> y.

        With `rightmost_transposed_ndims=2`, each 2x2 slice of the input is
        transposed. Both log-det-Jacobians are expected to be zero.
        """
        rightmost_transposed_ndims_ = np.array(2, dtype=np.int32)
        actual_x_ = np.array([
            [[1, 2], [3, 4]],
            [[5, 6], [7, 8]],
        ], dtype=np.float32)
        # Each 2x2 slice of `actual_x_`, transposed.
        actual_y_ = np.array([
            [[1, 3], [2, 4]],
            [[5, 7], [6, 8]],
        ], dtype=np.float32)

        # `rightmost_transposed_ndims` is a constant in both modes. Only the
        # data tensors are hidden behind placeholders in the dynamic case.
        rightmost_transposed_ndims = tf.constant(rightmost_transposed_ndims_)
        if self.is_static:
            actual_x = tf.constant(actual_x_)
            actual_y = tf.constant(actual_y_)
        else:
            actual_x = tf.compat.v1.placeholder_with_default(actual_x_,
                                                             shape=None)
            actual_y = tf.compat.v1.placeholder_with_default(actual_y_,
                                                             shape=None)

        bijector = tfb.Transpose(
            rightmost_transposed_ndims=rightmost_transposed_ndims,
            validate_args=True)
        y = bijector.forward(actual_x)
        x = bijector.inverse(actual_y)
        fldj = bijector.forward_log_det_jacobian(x, event_ndims=2)
        ildj = bijector.inverse_log_det_jacobian(y, event_ndims=2)

        [y_, x_, ildj_, fldj_] = self.evaluate([y, x, ildj, fldj])

        # Use a prefix check, consistent with the perm-based test, so the
        # assertion does not depend on any suffix added to the name.
        self.assertStartsWith(bijector.name, 'transpose')
        self.assertAllEqual(actual_y, y_)
        self.assertAllEqual(actual_x, x_)
        self.assertAllEqual(0., ildj_)
        self.assertAllEqual(0., fldj_)
Esempio n. 10
0
 def testInvalidEventNdimsException(self):
     """A negative Python-int `rightmost_transposed_ndims` raises `ValueError`."""
     msg = '`rightmost_transposed_ndims` must be non-negative.'
     # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(ValueError, msg):
         tfb.Transpose(rightmost_transposed_ndims=-1, validate_args=True)
Esempio n. 11
0
def bijectors(draw,
              bijector_name=None,
              batch_shape=None,
              event_dim=None,
              enable_vars=False):
    """Strategy for drawing Bijectors.

    The emitted bijector may be:
      - a basic bijector;
      - an `Invert` of a basic bijector;
      - a `TransformDiagonal` wrapping one.
    It is never a compound like `Chain`.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      bijector_name: Optional Python `str`.  If given, the produced bijectors
        will all have this type.  If omitted, Hypothesis chooses one from
        the whitelist `TF2_FRIENDLY_BIJECTORS`.
      batch_shape: An optional `TensorShape`.  The batch shape of the
        resulting bijector.  Hypothesis will pick one if omitted.  Note that
        the `TransformDiagonal`, `Inline`, `DiscreteCosineTransform`,
        `PowerTransform`, `Permute`, `Reshape` and `Transpose` branches do
        not use it.
      event_dim: Optional Python int giving the size of each of the underlying
        distribution's parameters' event dimensions.  This is shared across all
        parameters, permitting square event matrices, compatible location and
        scale Tensors, etc. If omitted, Hypothesis will choose one.
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  If `False`, the returned parameters
        are all `tf.Tensor`s and not {`tf.Variable`,
        `tfp.util.DeferredTensor`, `tfp.util.TransformedVariable`}.

    Returns:
      bijectors: A strategy for drawing bijectors with the specified
        `batch_shape` (or an arbitrary one if omitted).
    """
    if bijector_name is None:
        bijector_name = draw(hps.sampled_from(TF2_FRIENDLY_BIJECTORS))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if bijector_name == 'Invert':
        # Exclude 'Invert' itself so we never produce a nested Invert(Invert).
        underlying_name = draw(
            hps.sampled_from(sorted(set(TF2_FRIENDLY_BIJECTORS) - {'Invert'})))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars))
        return tfb.Invert(underlying, validate_args=True)
    if bijector_name == 'TransformDiagonal':
        # The wrapped bijector is drawn with an empty batch shape.
        underlying_name = draw(
            hps.sampled_from(sorted(TRANSFORM_DIAGONAL_WHITELIST)))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=(),
                      event_dim=event_dim,
                      enable_vars=enable_vars))
        return tfb.TransformDiagonal(underlying, validate_args=True)
    if bijector_name == 'Inline':
        if enable_vars:
            scale = tf.Variable(1., name='scale')
        else:
            scale = 2.
        b = tfb.AffineScalar(scale=scale)

        inline = tfb.Inline(
            forward_fn=b.forward,
            inverse_fn=b.inverse,
            forward_log_det_jacobian_fn=lambda x: b.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
                x,
                event_ndims=b.forward_min_event_ndims),
            forward_min_event_ndims=b.forward_min_event_ndims,
            is_constant_jacobian=b.is_constant_jacobian,
        )
        # Keep a handle on the wrapped AffineScalar on the returned object.
        inline.b = b
        return inline
    if bijector_name == 'DiscreteCosineTransform':
        dct_type = draw(hps.integers(min_value=2, max_value=3))
        return tfb.DiscreteCosineTransform(validate_args=True,
                                           dct_type=dct_type)
    if bijector_name == 'PowerTransform':
        power = draw(hps.floats(min_value=0., max_value=10.))
        return tfb.PowerTransform(validate_args=True, power=power)
    if bijector_name == 'Permute':
        event_ndims = draw(hps.integers(min_value=1, max_value=2))
        axis = draw(hps.integers(min_value=-event_ndims, max_value=-1))
        # This is a permutation of dimensions within an axis.
        # (Contrast with `Transpose` below.)
        permutation = draw(hps.permutations(np.arange(event_dim)))
        # Validate args like every other bijector this strategy builds.
        return tfb.Permute(permutation, axis=axis, validate_args=True)
    if bijector_name == 'Reshape':
        event_shape_out = draw(tfp_hps.shapes(min_ndims=1))
        # TODO(b/142135119): Wanted to draw general input and output shapes like the
        # following, but Hypothesis complained about filtering out too many things.
        # event_shape_in = draw(tfp_hps.shapes(min_ndims=1))
        # hp.assume(event_shape_out.num_elements() == event_shape_in.num_elements())
        event_shape_in = [event_shape_out.num_elements()]
        return tfb.Reshape(event_shape_out=event_shape_out,
                           event_shape_in=event_shape_in,
                           validate_args=True)
    if bijector_name == 'Transpose':
        event_ndims = draw(hps.integers(min_value=0, max_value=2))
        # This is a permutation of axes.
        # (Contrast with `Permute` above.)
        permutation = draw(hps.permutations(np.arange(event_ndims)))
        # Validate args like every other bijector this strategy builds.
        return tfb.Transpose(perm=permutation, validate_args=True)

    # Remaining bijectors are built generically from drawn parameters.
    bijector_params = draw(
        broadcasting_params(bijector_name,
                            batch_shape,
                            event_dim=event_dim,
                            enable_vars=enable_vars))
    ctor = getattr(tfb, bijector_name)
    return ctor(validate_args=True, **bijector_params)