示例#1
0
  def testBijector(self):
    """Round-trip a [3, 2] -> [6] Reshape and check its log-det-Jacobians.

    Runs forward and inverse on a batch of 4 random events and compares the
    results against plain ``np.reshape``. Also checks that both
    log-det-Jacobians (evaluated with ``event_ndims=2``) are 0, because
    reshaping only rearranges entries. Finally checks the bijector's default
    name.
    """
    x = np.random.randn(4, 3, 2)
    y = np.reshape(x, [4, 6])

    with self.cached_session() as sess:
      # build_shapes (defined elsewhere on the test class) supplies the
      # in/out event-shape values plus whatever feed_dict they need.
      shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)
      fetches = (
          bijector.inverse(y),
          bijector.forward(x),
          bijector.forward_log_det_jacobian(x, event_ndims=2),
          bijector.inverse_log_det_jacobian(y, event_ndims=2),
      )
      x_hat, y_hat, fldj, ildj = sess.run(fetches, feed_dict=feed_dict)

      self.assertEqual("reshape", bijector.name)
      checks = ((y, y_hat), (x, x_hat), (0., fldj), (0., ildj))
      for want, got in checks:
        self.assertAllClose(want, got, rtol=1e-6, atol=0)

  def testBijector(self):
    """Round-trip a [3, 2] -> [6] Reshape and check its log-det-Jacobians.

    A batch of 4 random events is pushed through forward and inverse, and the
    results are compared with plain ``np.reshape``. Both log-det-Jacobians
    must come out as 0, and the bijector must report the name "reshape".
    """
    # NOTE(review): this variant calls test_session() and the log-det-Jacobian
    # methods without event_ndims, which suggests an older TensorFlow API.
    # Newer releases prefer cached_session() and require event_ndims; migrate
    # both together and confirm against the TF version in use.
    x = np.random.randn(4, 3, 2)
    y = np.reshape(x, [4, 6])

    with self.test_session() as sess:
      # build_shapes (defined elsewhere on the test class) supplies the
      # in/out event-shape values plus whatever feed_dict they need.
      shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
      bijector = Reshape(
          event_shape_in=shape_in,
          event_shape_out=shape_out,
          validate_args=True)
      fetches = (
          bijector.inverse(y),
          bijector.forward(x),
          bijector.forward_log_det_jacobian(x),
          bijector.inverse_log_det_jacobian(y),
      )
      x_hat, y_hat, fldj, ildj = sess.run(fetches, feed_dict=feed_dict)

      self.assertEqual("reshape", bijector.name)
      checks = ((y, y_hat), (x, x_hat), (0., fldj), (0., ildj))
      for want, got in checks:
        self.assertAllClose(want, got, rtol=1e-6, atol=0)
示例#3
0
    def testRaisesOpError(self):
        """Feed invalid shapes to a placeholder-driven Reshape and expect op errors.

        The in/out event shapes are int32 placeholders, so their values are
        only supplied at ``sess.run`` time. Each case feeds a bad combination
        and expects ``assertRaisesOpError`` to see the given message.
        """
        x1 = np.random.randn(4, 2, 3)
        x2 = np.random.randn(4, 3, 2)
        x3 = np.random.randn(4, 5, 1, 1)

        with self.test_session() as sess:
            # Creation order is kept (in, then out) so graph op names match.
            shape_in_ph = array_ops.placeholder(shape=[2], dtype=dtypes.int32)
            shape_out_ph = array_ops.placeholder(shape=[3], dtype=dtypes.int32)
            bijector = Reshape(event_shape_out=shape_out_ph,
                               event_shape_in=shape_in_ph,
                               validate_args=True)

            # x2's trailing dims (3, 2) differ from the fed event_shape_in [2, 3].
            with self.assertRaisesOpError(
                    "Input `event_shape` does not match `event_shape_in`."):
                sess.run(bijector.forward(x2),
                         feed_dict={shape_out_ph: [1, 6, 1],
                                    shape_in_ph: [2, 3]})

            # An event_shape_out of [-1, -1, 6] must be rejected.
            with self.assertRaisesOpError(
                    "event_shape_out entries must be positive."):
                sess.run(bijector.forward(x1),
                         feed_dict={shape_out_ph: [-1, -1, 6],
                                    shape_in_ph: [2, 3]})

            # Every method must run the basic checks: the fed shapes give
            # event sizes 2*3 = 6 (in) vs 1*1*5 = 5 (out).
            fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]}
            size_mismatch_cases = (
                (bijector.forward, x1),
                (bijector.inverse, x3),
                (bijector.inverse_log_det_jacobian, x3),
                (bijector.forward_log_det_jacobian, x1),
            )
            for method, value in size_mismatch_cases:
                # The op is built inside the context, as each check requires.
                with self.assertRaisesOpError(
                        "Input/output `event_size`s do not match."):
                    sess.run(method(value), feed_dict=fd_mismatched)