Example 1
    def testReverseMask(self, num_masked, fraction_masked, batch_shape):
        input_depth = 8
        x_ = np.random.normal(0., 1.,
                              batch_shape + (input_depth, )).astype(np.float32)
        flip_nvp = tfb.RealNVP(
            num_masked=num_masked,
            fraction_masked=fraction_masked,
            validate_args=True,
            **self._real_nvp_kwargs,
        )
        x = tf.constant(x_)

        forward_x = flip_nvp.forward(x)

        expected_num_masked = (num_masked if num_masked is not None else
                               np.floor(input_depth * fraction_masked))

        self.assertEqual(flip_nvp._masked_size, expected_num_masked)

        _, x2_ = np.split(x_, [input_depth - abs(flip_nvp._masked_size)],
                          axis=-1)  # pylint: disable=unbalanced-tuple-unpacking

        # Check that the last |masked_size| units are unchanged after passing
        # through the reverse-masked RealNVP.
        _, forward_x2 = tf.split(
            forward_x,
            [input_depth - abs(flip_nvp._masked_size),
             abs(flip_nvp._masked_size)],
            axis=-1)
        self.evaluate(tf1.global_variables_initializer())
        forward_x2_ = self.evaluate(forward_x2)

        self.assertAllClose(forward_x2_, x2_, rtol=1e-4, atol=0.)
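
Several snippets on this page reference `self._real_nvp_kwargs`, which the surrounding test class defines but which was not captured here. A minimal sketch of what it plausibly contains, modeled on the `tfb.real_nvp_default_template` calls in Examples 14 and 15 (the hidden-layer sizes are an assumption):

  # Hypothetical fixture, modeled on the TFP test suite; the real test
  # class supplies its own version.
  class _RealNVPKwargsMixin(object):

      @property
      def _real_nvp_kwargs(self):
          return {
              'shift_and_log_scale_fn': tfb.real_nvp_default_template(
                  hidden_layers=[3], shift_only=False),
              'is_constant_jacobian': False,
          }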
Example 2
  def testBijectorWithTrivialTransform(self):
    flat_x_ = np.random.normal(0., 1., 8).astype(np.float32)
    batched_x_ = np.random.normal(0., 1., (3, 8)).astype(np.float32)
    for x_ in [flat_x_, batched_x_]:
      nvp = tfb.RealNVP(
          num_masked=4,
          validate_args=True,
          shift_and_log_scale_fn=lambda x, _: (x, x),
          is_constant_jacobian=False)
      x = tf.constant(x_)
      forward_x = nvp.forward(x)
      # Use identity to invalidate cache.
      inverse_y = nvp.inverse(tf.identity(forward_x))
      forward_inverse_y = nvp.forward(inverse_y)
      fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
      # Use identity to invalidate cache.
      ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x), event_ndims=1)
      forward_x_ = self.evaluate(forward_x)
      inverse_y_ = self.evaluate(inverse_y)
      forward_inverse_y_ = self.evaluate(forward_inverse_y)
      ildj_ = self.evaluate(ildj)
      fldj_ = self.evaluate(fldj)

      self.assertEqual("real_nvp", nvp.name)
      self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-4, atol=0.)
      self.assertAllClose(x_, inverse_y_, rtol=1e-4, atol=0.)
      self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
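
With the trivial `shift_and_log_scale_fn=lambda x, _: (x, x)` above, the coupling law is easy to check by hand: RealNVP passes the first `num_masked` units (x0) through unchanged and transforms the remaining units (x1) elementwise, using a shift and log-scale computed from x0 alone. A NumPy sketch of the forward pass, reconstructed here for illustration (not the library internals):

  def real_nvp_forward(x, num_masked, shift_and_log_scale_fn):
      # x0 is copied to the output; x1 is transformed using parameters
      # that depend only on x0, which is what makes the inverse cheap.
      x0, x1 = x[..., :num_masked], x[..., num_masked:]
      shift, log_scale = shift_and_log_scale_fn(x0, x1.shape[-1])
      return np.concatenate([x0, x1 * np.exp(log_scale) + shift], axis=-1)

With the trivial transform, shift = log_scale = x0, so the forward log-det-Jacobian is sum(x0) per event and the inverse is x1 = (y1 - x0) * exp(-x0), which is exactly what the fldj/ildj assertions exercise.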
Example 3
 def testBatchedBijectorWithMLPTransform(self):
   x_ = np.random.normal(0., 1., (3, 8)).astype(np.float32)
   with self.cached_session() as sess:
     nvp = tfb.RealNVP(
         num_masked=4, validate_args=True, **self._real_nvp_kwargs)
     x = tf.constant(x_)
     forward_x = nvp.forward(x)
     # Use identity to invalidate cache.
     inverse_y = nvp.inverse(tf.identity(forward_x))
     forward_inverse_y = nvp.forward(inverse_y)
     fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
     # Use identity to invalidate cache.
     ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x), event_ndims=1)
     tf1.global_variables_initializer().run()
     [
         forward_x_,
         inverse_y_,
         forward_inverse_y_,
         ildj_,
         fldj_,
     ] = sess.run([
         forward_x,
         inverse_y,
         forward_inverse_y,
         ildj,
         fldj,
     ])
     self.assertEqual("real_nvp", nvp.name)
     self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-4, atol=0.)
     self.assertAllClose(x_, inverse_y_, rtol=1e-4, atol=0.)
     self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
Example 4
  def testNonBatchedBijectorWithMLPTransform(self):
     x_ = np.random.normal(0., 1., (8, )).astype(np.float32)
     nvp = tfb.RealNVP(num_masked=4,
                       validate_args=True,
                       **self._real_nvp_kwargs)
     x = tf.constant(x_)
     forward_x = nvp.forward(x)
     # Use identity to invalidate cache.
     inverse_y = nvp.inverse(tf.identity(forward_x))
     forward_inverse_y = nvp.forward(inverse_y)
     fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
     # Use identity to invalidate cache.
     ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x),
                                         event_ndims=1)
     self.evaluate(tf1.global_variables_initializer())
     [
         forward_x_,
         inverse_y_,
         forward_inverse_y_,
         ildj_,
         fldj_,
     ] = self.evaluate([
         forward_x,
         inverse_y,
         forward_inverse_y,
         ildj,
         fldj,
     ])
     self.assertStartsWith(nvp.name, 'real_nvp')
     self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-4, atol=0.)
     self.assertAllClose(x_, inverse_y_, rtol=1e-4, atol=0.)
     self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
Example 5
    def testBijectorConditionKwargs(self):
        batch_size = 3
        x_ = np.linspace(-1.0, 1.0,
                         (batch_size * 4 * 2)).astype(np.float32).reshape(
                             (batch_size, 4 * 2))

        conditions = {
            'a': tf.random.normal((batch_size, 4), dtype=tf.float32, seed=584),
            'b': tf.random.normal((batch_size, 2), dtype=tf.float32,
                                  seed=9817),
        }

        def _condition_shift_and_log_scale_fn(x0, output_units, a, b):
            x = tf.concat((x0, a, b), axis=-1)
            out = tf1.layers.dense(inputs=x, units=2 * output_units)
            shift, log_scale = tf.split(out, 2, axis=-1)
            return shift, log_scale

        condition_shift_and_log_scale_fn = tf1.make_template(
            'real_nvp_condition_template', _condition_shift_and_log_scale_fn)

        nvp = tfb.RealNVP(
            num_masked=4,
            validate_args=True,
            is_constant_jacobian=False,
            shift_and_log_scale_fn=condition_shift_and_log_scale_fn)

        x = tf.constant(x_)

        forward_x = nvp.forward(x, **conditions)
        # Use identity to invalidate cache.
        inverse_y = nvp.inverse(tf.identity(forward_x), **conditions)
        forward_inverse_y = nvp.forward(inverse_y, **conditions)
        fldj = nvp.forward_log_det_jacobian(x, event_ndims=1, **conditions)
        # Use identity to invalidate cache.
        ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x),
                                            event_ndims=1,
                                            **conditions)
        self.evaluate(tf1.global_variables_initializer())
        [
            forward_x_,
            inverse_y_,
            forward_inverse_y_,
            ildj_,
            fldj_,
        ] = self.evaluate([
            forward_x,
            inverse_y,
            forward_inverse_y,
            ildj,
            fldj,
        ])
        self.assertStartsWith(nvp.name, 'real_nvp')
        self.assertAllClose(forward_x_,
                            forward_inverse_y_,
                            rtol=1e-5,
                            atol=1e-5)
        self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=1e-5)
        self.assertAllClose(ildj_, -fldj_, rtol=1e-5, atol=1e-5)
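
The condition kwargs are forwarded verbatim from `forward`, `inverse`, and the log-det-Jacobian methods to the `shift_and_log_scale_fn`. When the same bijector is wrapped in a distribution, the conditions travel through `bijector_kwargs` instead; a sketch, assuming a TFP version whose `TransformedDistribution` methods accept that argument:

  # Hypothetical usage, with `nvp`, `x`, and `conditions` as above.
  dist = tfd.TransformedDistribution(
      distribution=tfd.Sample(tfd.Normal(0., 1.), [4 * 2]),
      bijector=nvp)
  lp = dist.log_prob(x, bijector_kwargs=conditions)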
Example 6
 def testBadNumMaskRaises(self, num_masked):
     with self.assertRaisesRegex(
             ValueError,
             'Number of masked units {} must be smaller than the event size 1'
             .format(num_masked)):
         rnvp = tfb.RealNVP(num_masked=num_masked,
                            shift_and_log_scale_fn=lambda x, _: (x, x))
         rnvp.forward(np.zeros(1))
Example 7
  def testRankChangingBijectorRaises(self):
    with self.assertRaisesRegex(
        ValueError, 'Bijectors which alter `event_ndims` are not supported.'):

      def bijector_fn(*args, **kwargs):
        del args, kwargs
        return tfb.Inline(forward_min_event_ndims=1, inverse_min_event_ndims=0)
      rnvp = tfb.RealNVP(1, bijector_fn=bijector_fn, validate_args=True)
      rnvp.forward([1., 2.])
Example 8
  def testMatrixBijectorRaises(self):
    with self.assertRaisesRegex(
        ValueError,
        'Bijectors with `forward_min_event_ndims` > 1 are not supported'):

      def bijector_fn(*args, **kwargs):
        del args, kwargs
        return tfb.Inline(forward_min_event_ndims=2)
      rnvp = tfb.RealNVP(1, bijector_fn=bijector_fn, validate_args=True)
      rnvp.forward([1., 2.])
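
Examples 7 and 8 fail for the same underlying reason: the bijector returned by `bijector_fn` must act elementwise on the unmasked units, i.e. have `forward_min_event_ndims <= 1` and leave `event_ndims` unchanged. For contrast, a minimal `bijector_fn` that passes validation (the mean-shift rule and argument names are illustrative only):

  def bijector_fn(x0, output_units, **condition_kwargs):
      # Shift each unmasked unit by the mean of the masked units: a
      # scalar bijector, so event_ndims is preserved and it is accepted.
      del output_units, condition_kwargs
      return tfb.Shift(tf.reduce_mean(x0, axis=-1, keepdims=True))

  rnvp = tfb.RealNVP(1, bijector_fn=bijector_fn, validate_args=True)
  rnvp.forward([1., 2.])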
Example 9
    def testBijectorConditionKwargs(self):
        batch_size = 3
        x_ = np.linspace(-1.0, 1.0,
                         (batch_size * 4 * 2)).astype(np.float32).reshape(
                             (batch_size, 4 * 2))

        conditions = {
            'a': np.random.normal(size=(batch_size, 4)).astype(np.float32),
            'b': np.random.normal(size=(batch_size, 4)).astype(np.float32),
        }

        def _condition_shift_and_log_scale_fn(x0, output_units, a, b):
            del output_units
            return x0 + a, x0 + b

        nvp = tfb.RealNVP(
            num_masked=4,
            validate_args=True,
            is_constant_jacobian=False,
            shift_and_log_scale_fn=_condition_shift_and_log_scale_fn)

        x = tf.constant(x_)

        forward_x = nvp.forward(x, **conditions)
        # Use identity to invalidate cache.
        inverse_y = nvp.inverse(tf.identity(forward_x), **conditions)
        forward_inverse_y = nvp.forward(inverse_y, **conditions)
        fldj = nvp.forward_log_det_jacobian(x, event_ndims=1, **conditions)
        # Use identity to invalidate cache.
        ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x),
                                            event_ndims=1,
                                            **conditions)
        [
            forward_x_,
            inverse_y_,
            forward_inverse_y_,
            ildj_,
            fldj_,
        ] = self.evaluate([
            forward_x,
            inverse_y,
            forward_inverse_y,
            ildj,
            fldj,
        ])
        self.assertStartsWith(nvp.name, 'real_nvp')
        self.assertAllClose(forward_x_,
                            forward_inverse_y_,
                            rtol=1e-5,
                            atol=1e-5)
        self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=1e-5)
        self.assertAllClose(ildj_, -fldj_, rtol=1e-5, atol=1e-5)
Example 10
 def make_layer(i):
     fn = ShiftAndLogScale(n_units - n_masked)
     chain = [
         tfb.RealNVP(
             num_masked=n_masked,
             shift_and_log_scale_fn=fn,
         ),
         tfb.BatchNormalization(),
     ]
     if i % 2 == 0:
         perm = lambda: tfb.Permute(permutation=[1, 0])
         chain = [perm(), *chain, perm()]
     return tfb.Chain(chain)
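
`ShiftAndLogScale`, `n_units`, and `n_masked` are defined outside this snippet. `ShiftAndLogScale` is presumably a small trainable module mapping the masked units to a `(shift, log_scale)` pair for the remaining `n_units - n_masked` units; a hypothetical Keras-based sketch:

  class ShiftAndLogScale(tf.keras.layers.Layer):
      """Maps the masked units x0 to (shift, log_scale)."""

      def __init__(self, out_units):
          super().__init__()
          self.hidden = tf.keras.layers.Dense(32, activation='relu')
          self.params = tf.keras.layers.Dense(2 * out_units)

      def call(self, x0, output_units=None):
          del output_units  # Fixed at construction time instead.
          shift, log_scale = tf.split(
              self.params(self.hidden(x0)), 2, axis=-1)
          return shift, log_scale

Alternating the `Permute` wrapper every other layer flips which units are masked, so every unit gets transformed somewhere in the chain.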
Example 11
 def testMutuallyConsistent(self):
     dims = 4
     nvp = tfb.RealNVP(num_masked=3,
                       validate_args=True,
                       **self._real_nvp_kwargs)
      dist = tfd.TransformedDistribution(
          distribution=tfd.Sample(tfd.Normal(0., 1.), [dims]),
          bijector=nvp,
          validate_args=True)
     self.run_test_sample_consistent_log_prob(sess_run_fn=self.evaluate,
                                              dist=dist,
                                              num_samples=int(1e6),
                                              seed=54819,
                                              radius=1.,
                                              center=0.,
                                              rtol=0.1)
Example 12
 def testMutuallyConsistent(self):
   dims = 4
   nvp = tfb.RealNVP(
       num_masked=3, validate_args=True, **self._real_nvp_kwargs)
   dist = tfd.TransformedDistribution(
       distribution=tfd.Normal(loc=0., scale=1.),
       bijector=nvp,
       event_shape=[dims],
       validate_args=True)
   self.run_test_sample_consistent_log_prob(
       sess_run_fn=self.evaluate,
       dist=dist,
       num_samples=int(2e5),
       radius=1.,
       center=0.,
       rtol=0.02)
Example 13
 def testInvertMutuallyConsistent(self):
     dims = 4
      with self.cached_session() as sess:
         nvp = tfb.Invert(
             tfb.RealNVP(num_masked=3,
                         validate_args=True,
                         **self._real_nvp_kwargs))
          dist = tfd.TransformedDistribution(
              distribution=tfd.Normal(loc=0., scale=1.),
             bijector=nvp,
             event_shape=[dims],
             validate_args=True)
         self.run_test_sample_consistent_log_prob(sess_run_fn=sess.run,
                                                  dist=dist,
                                                  num_samples=int(1e5),
                                                  radius=1.,
                                                  center=0.,
                                                  rtol=0.02)
Example 14
  def testBijectorWithReverseMask(self):
    flat_x_ = np.random.normal(0., 1., 8).astype(np.float32)
    batched_x_ = np.random.normal(0., 1., (3, 8)).astype(np.float32)
    num_masked = -5
    for x_ in [flat_x_, batched_x_]:
      flip_nvp = tfb.RealNVP(
          num_masked=num_masked,
          validate_args=True,
          shift_and_log_scale_fn=tfb.real_nvp_default_template(
              hidden_layers=[3], shift_only=False),
          is_constant_jacobian=False)

      _, x2_ = np.split(x_, [8 - abs(num_masked)], axis=-1)
      x = tf.constant(x_)

      # Check that the last |num_masked| units are unchanged after passing
      # through the reverse-masked RealNVP.
      forward_x = flip_nvp.forward(x)
      _, forward_x2 = tf.split(forward_x, [8 - abs(num_masked),
                                           abs(num_masked)], axis=-1)
      self.evaluate(tf1.global_variables_initializer())
      forward_x2_ = self.evaluate(forward_x2)

      self.assertAllClose(forward_x2_, x2_, rtol=1e-4, atol=0.)
Example 15
    def testBijectorWithReverseMask(self, num_masked, fraction_masked):
        input_depth = 8
        flat_x_ = np.random.normal(0., 1., input_depth).astype(np.float32)
        batched_x_ = np.random.normal(0., 1.,
                                      (3, input_depth)).astype(np.float32)
        for x_ in [flat_x_, batched_x_]:
            flip_nvp = tfb.RealNVP(
                num_masked=num_masked,
                fraction_masked=fraction_masked,
                validate_args=True,
                shift_and_log_scale_fn=tfb.real_nvp_default_template(
                    hidden_layers=[3], shift_only=False),
                is_constant_jacobian=False)

            x = tf.constant(x_)

            forward_x = flip_nvp.forward(x)

            expected_num_masked = (num_masked if num_masked is not None else
                                   np.floor(input_depth * fraction_masked))

            self.assertEqual(flip_nvp._masked_size, expected_num_masked)

            _, x2_ = np.split(x_, [input_depth - abs(flip_nvp._masked_size)],
                              axis=-1)  # pylint: disable=unbalanced-tuple-unpacking

            # Check that the last |masked_size| units are unchanged after
            # passing through the reverse-masked RealNVP.
            _, forward_x2 = tf.split(
                forward_x,
                [input_depth - abs(flip_nvp._masked_size),
                 abs(flip_nvp._masked_size)],
                axis=-1)
            self.evaluate(tf1.global_variables_initializer())
            forward_x2_ = self.evaluate(forward_x2)

            self.assertAllClose(forward_x2_, x2_, rtol=1e-4, atol=0.)
Example 16
 def spline_flow():
   stack = tfb.Identity()
   for i in range(nsplits):
     stack = tfb.RealNVP(5 * i, bijector_fn=splines[i])(stack)
   return stack
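
`nsplits`, `splines`, and the input tensor come from outside the snippet; the pattern follows the `tfb.RationalQuadraticSpline` docstring, where each `splines[i]` is a `bijector_fn` that builds a spline over the unmasked units. A minimal stand-in with fixed, untrained knots on [-1, 1]:

  nsplits = 3

  def make_spline_fn():
      def bijector_fn(x0, nunits, **kwargs):
          # Input-independent knots for illustration; a trained flow
          # would compute widths/heights/slopes from x0 (the masked units).
          del x0, kwargs
          return tfb.RationalQuadraticSpline(
              bin_widths=tf.fill([nunits, 8], 0.25),
              bin_heights=tf.fill([nunits, 8], 0.25),
              knot_slopes=tf.ones([nunits, 7]))
      return bijector_fn

  splines = [make_spline_fn() for _ in range(nsplits)]

Note that `tfb.RealNVP(...)(stack)` composes bijectors: calling a bijector on another bijector returns their composition, so `spline_flow()` stacks `nsplits` coupling layers.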
Example 17
 def testBadFractionRaises(self, fraction_masked):
     with self.assertRaisesRegex(
             ValueError, '`fraction_masked` must be in'):
         tfb.RealNVP(fraction_masked=fraction_masked,
                     shift_and_log_scale_fn=lambda x, _: (x, x))
Example 18
 def testNonFloatFractionMaskedRaises(self):
     with self.assertRaisesRegex(
             TypeError, '`fraction_masked` must be a float'):
         tfb.RealNVP(fraction_masked=1,
                     shift_and_log_scale_fn=lambda x, _: (x, x))
Example 19
 def testNonIntegerNumMaskedRaises(self):
     with self.assertRaisesRegex(
             TypeError, '`num_masked` must be an integer'):
         tfb.RealNVP(num_masked=0.5,
                     shift_and_log_scale_fn=lambda x, _: (x, x))
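
Taken together, Examples 6, 17, 18, and 19 pin down the masking API: pass exactly one of `num_masked` (a Python integer whose magnitude is smaller than the event size) or `fraction_masked` (a Python float with magnitude below 1). Negative values mask from the right instead of the left, as Examples 14 and 15 show. In sketch form:

  fn = tfb.real_nvp_default_template(hidden_layers=[3])
  tfb.RealNVP(num_masked=4, shift_and_log_scale_fn=fn)    # first 4 units pass through
  tfb.RealNVP(num_masked=-4, shift_and_log_scale_fn=fn)   # last 4 units pass through
  tfb.RealNVP(fraction_masked=0.5, shift_and_log_scale_fn=fn)  # floor(0.5 * event_size) units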