def testReverseMask(self, num_masked, fraction_masked, batch_shape):
  """Checks the trailing units pass through a reversed-mask RealNVP unchanged.

  The bijector is configured from either `num_masked` or `fraction_masked`.
  The test asserts that the resolved `_masked_size` matches the expectation.
  It then asserts that the last `abs(_masked_size)` units of the output equal
  the corresponding input units. The parameters presumably supply negative
  (reversed) mask sizes; confirm against the parameterization.
  """
  depth = 8
  x_np = np.random.normal(
      0., 1., batch_shape + (depth,)).astype(np.float32)
  bijector = tfb.RealNVP(
      num_masked=num_masked,
      fraction_masked=fraction_masked,
      validate_args=True,
      **self._real_nvp_kwargs,
  )
  y = bijector.forward(tf.constant(x_np))

  if num_masked is not None:
    expected_masked = num_masked
  else:
    expected_masked = np.floor(depth * fraction_masked)
  self.assertEqual(bijector._masked_size, expected_masked)

  n_passthrough = abs(bijector._masked_size)
  n_transformed = depth - n_passthrough
  _, x_tail = np.split(x_np, [n_transformed], axis=-1)  # pylint: disable=unbalanced-tuple-unpacking
  # The trailing block must be identical after the reversed-mask RealNVP.
  _, y_tail = tf.split(y, [n_transformed, n_passthrough], axis=-1)
  self.evaluate(tf1.global_variables_initializer())
  y_tail_np = self.evaluate(y_tail)
  self.assertAllClose(y_tail_np, x_tail, rtol=1e-4, atol=0.)
def testBijectorWithTrivialTransform(self):
  """Round-trips flat and batched inputs through RealNVP with f(x) = (x, x).

  For each input this checks that forward(inverse(forward(x))) equals
  forward(x) and that inverse(forward(x)) equals x. It also checks that the
  inverse log-det-Jacobian is the negation of the forward one.
  """
  flat_x_ = np.random.normal(0., 1., 8).astype(np.float32)
  batched_x_ = np.random.normal(0., 1., (3, 8)).astype(np.float32)
  for x_ in [flat_x_, batched_x_]:
    nvp = tfb.RealNVP(
        num_masked=4,
        validate_args=True,
        shift_and_log_scale_fn=lambda x, _: (x, x),
        is_constant_jacobian=False)
    x = tf.constant(x_)
    forward_x = nvp.forward(x)
    # Use identity to invalidate cache.
    inverse_y = nvp.inverse(tf.identity(forward_x))
    forward_inverse_y = nvp.forward(inverse_y)
    fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
    # Use identity to invalidate cache.
    ildj = nvp.inverse_log_det_jacobian(
        tf.identity(forward_x), event_ndims=1)
    [
        forward_x_,
        inverse_y_,
        forward_inverse_y_,
        ildj_,
        fldj_,
    ] = self.evaluate([
        forward_x,
        inverse_y,
        forward_inverse_y,
        ildj,
        fldj,
    ])
    # A bijector is built on every loop iteration. Graph-mode name scoping may
    # make its name unique (e.g. add a suffix), so only the prefix is checked,
    # as in the other RealNVP tests.
    self.assertStartsWith(nvp.name, 'real_nvp')
    self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-4, atol=0.)
    self.assertAllClose(x_, inverse_y_, rtol=1e-4, atol=0.)
    self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
def testBatchedBijectorWithMLPTransform(self):
  """Round-trips a (3, 8) batch through RealNVP with the configured MLP.

  This checks forward/inverse consistency and that the inverse
  log-det-Jacobian is the negation of the forward one.
  """
  x_ = np.random.normal(0., 1., (3, 8)).astype(np.float32)
  nvp = tfb.RealNVP(
      num_masked=4, validate_args=True, **self._real_nvp_kwargs)
  x = tf.constant(x_)
  forward_x = nvp.forward(x)
  # Use identity to invalidate cache.
  inverse_y = nvp.inverse(tf.identity(forward_x))
  forward_inverse_y = nvp.forward(inverse_y)
  fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
  # Use identity to invalidate cache.
  ildj = nvp.inverse_log_det_jacobian(tf.identity(forward_x), event_ndims=1)
  # `tf.global_variables_initializer` and `sess.run` are TF1/graph-only.
  # `self.evaluate` with the `tf1` initializer works in both modes and matches
  # the sibling tests.
  self.evaluate(tf1.global_variables_initializer())
  [
      forward_x_,
      inverse_y_,
      forward_inverse_y_,
      ildj_,
      fldj_,
  ] = self.evaluate([
      forward_x,
      inverse_y,
      forward_inverse_y,
      ildj,
      fldj,
  ])
  self.assertStartsWith(nvp.name, 'real_nvp')
  self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-4, atol=0.)
  self.assertAllClose(x_, inverse_y_, rtol=1e-4, atol=0.)
  self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
def testNonBatchedBijectorWithMLPTransform(self):
  """Checks invertibility and log-det consistency for one unbatched event."""
  x_np = np.random.normal(0., 1., (8,)).astype(np.float32)
  bijector = tfb.RealNVP(
      num_masked=4, validate_args=True, **self._real_nvp_kwargs)
  x = tf.constant(x_np)
  y = bijector.forward(x)
  # Passing through tf.identity stops the cached result from being reused.
  x_rec = bijector.inverse(tf.identity(y))
  y_rec = bijector.forward(x_rec)
  fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
  # Passing through tf.identity stops the cached result from being reused.
  ildj = bijector.inverse_log_det_jacobian(tf.identity(y), event_ndims=1)
  self.evaluate(tf1.global_variables_initializer())
  y_np, x_rec_np, y_rec_np, ildj_np, fldj_np = self.evaluate(
      [y, x_rec, y_rec, ildj, fldj])
  self.assertStartsWith(bijector.name, 'real_nvp')
  self.assertAllClose(y_np, y_rec_np, rtol=1e-4, atol=0.)
  self.assertAllClose(x_np, x_rec_np, rtol=1e-4, atol=0.)
  self.assertAllClose(ildj_np, -fldj_np, rtol=1e-6, atol=0.)
def testBijectorConditionKwargs(self):
  """Conditioning kwargs reach a templated dense shift/log-scale network.

  The keyword tensors `a` and `b` are forwarded through every RealNVP call.
  The test checks round-trip consistency and the relation between the forward
  and inverse log-det-Jacobians.
  """
  batch_size = 3
  event_size = 4 * 2
  x_np = np.linspace(-1.0, 1.0, batch_size * event_size).astype(
      np.float32).reshape((batch_size, event_size))
  cond = {
      'a': tf.random.normal((batch_size, 4), dtype=tf.float32, seed=584),
      'b': tf.random.normal((batch_size, 2), dtype=tf.float32, seed=9817),
  }

  def _condition_shift_and_log_scale_fn(x0, output_units, a, b):
    # Concatenate the masked input with both conditions and map the result
    # through one dense layer, then split into shift and log-scale halves.
    stacked = tf.concat((x0, a, b), axis=-1)
    dense_out = tf1.layers.dense(inputs=stacked, units=2 * output_units)
    return tuple(tf.split(dense_out, 2, axis=-1))

  templated_fn = tf1.make_template(
      'real_nvp_condition_template', _condition_shift_and_log_scale_fn)
  bijector = tfb.RealNVP(
      num_masked=4,
      validate_args=True,
      is_constant_jacobian=False,
      shift_and_log_scale_fn=templated_fn)

  x = tf.constant(x_np)
  y = bijector.forward(x, **cond)
  # Passing through tf.identity stops the cached result from being reused.
  x_rec = bijector.inverse(tf.identity(y), **cond)
  y_rec = bijector.forward(x_rec, **cond)
  fldj = bijector.forward_log_det_jacobian(x, event_ndims=1, **cond)
  # Passing through tf.identity stops the cached result from being reused.
  ildj = bijector.inverse_log_det_jacobian(
      tf.identity(y), event_ndims=1, **cond)
  self.evaluate(tf1.global_variables_initializer())
  y_np, x_rec_np, y_rec_np, ildj_np, fldj_np = self.evaluate(
      [y, x_rec, y_rec, ildj, fldj])
  self.assertStartsWith(bijector.name, 'real_nvp')
  self.assertAllClose(y_np, y_rec_np, rtol=1e-5, atol=1e-5)
  self.assertAllClose(x_np, x_rec_np, rtol=1e-5, atol=1e-5)
  self.assertAllClose(ildj_np, -fldj_np, rtol=1e-5, atol=1e-5)
def testBadNumMaskRaises(self, num_masked):
  """An invalid `num_masked` for a size-1 event raises ValueError.

  The error may surface at construction or on the first `forward` call, so
  both are inside the assertion context.
  """
  # `assertRaisesRegexp` is a deprecated alias that was removed from unittest
  # in Python 3.12. `assertRaisesRegex` is the supported spelling.
  with self.assertRaisesRegex(
      ValueError,
      'Number of masked units {} must be smaller than the event size 1'
      .format(num_masked)):
    rnvp = tfb.RealNVP(
        num_masked=num_masked, shift_and_log_scale_fn=lambda x, _: (x, x))
    rnvp.forward(np.zeros(1))
def testRankChangingBijectorRaises(self):
  """RealNVP rejects a `bijector_fn` whose bijector alters `event_ndims`."""

  def bijector_fn(*args, **kwargs):
    del args, kwargs
    # Forward and inverse min event ndims differ, so the rank changes.
    return tfb.Inline(forward_min_event_ndims=1, inverse_min_event_ndims=0)

  # Only the calls expected to raise live inside the context manager.
  with self.assertRaisesRegex(
      ValueError, 'Bijectors which alter `event_ndims` are not supported.'):
    rnvp = tfb.RealNVP(1, bijector_fn=bijector_fn, validate_args=True)
    rnvp.forward([1., 2.])
def testMatrixBijectorRaises(self):
  """RealNVP rejects a `bijector_fn` returning a bijector with min ndims 2."""

  def bijector_fn(*args, **kwargs):
    del args, kwargs
    return tfb.Inline(forward_min_event_ndims=2)

  # Only the calls expected to raise live inside the context manager.
  with self.assertRaisesRegex(
      ValueError,
      'Bijectors with `forward_min_event_ndims` > 1 are not supported'):
    rnvp = tfb.RealNVP(1, bijector_fn=bijector_fn, validate_args=True)
    rnvp.forward([1., 2.])
def testBijectorConditionKwargs(self):
  """Conditioning kwargs reach a plain-Python shift/log-scale function.

  The function returns (x0 + a, x0 + b). The test checks round-trip
  consistency and the relation between the forward and inverse
  log-det-Jacobians.
  """
  batch_size = 3
  event_size = 4 * 2
  x_np = np.linspace(-1.0, 1.0, batch_size * event_size).astype(
      np.float32).reshape((batch_size, event_size))
  cond = {
      'a': np.random.normal(size=(batch_size, 4)).astype(np.float32),
      'b': np.random.normal(size=(batch_size, 4)).astype(np.float32),
  }

  def _condition_shift_and_log_scale_fn(x0, output_units, a, b):
    del output_units
    shift = x0 + a
    log_scale = x0 + b
    return shift, log_scale

  bijector = tfb.RealNVP(
      num_masked=4,
      validate_args=True,
      is_constant_jacobian=False,
      shift_and_log_scale_fn=_condition_shift_and_log_scale_fn)

  x = tf.constant(x_np)
  y = bijector.forward(x, **cond)
  # Passing through tf.identity stops the cached result from being reused.
  x_rec = bijector.inverse(tf.identity(y), **cond)
  y_rec = bijector.forward(x_rec, **cond)
  fldj = bijector.forward_log_det_jacobian(x, event_ndims=1, **cond)
  # Passing through tf.identity stops the cached result from being reused.
  ildj = bijector.inverse_log_det_jacobian(
      tf.identity(y), event_ndims=1, **cond)
  y_np, x_rec_np, y_rec_np, ildj_np, fldj_np = self.evaluate(
      [y, x_rec, y_rec, ildj, fldj])
  self.assertStartsWith(bijector.name, 'real_nvp')
  self.assertAllClose(y_np, y_rec_np, rtol=1e-5, atol=1e-5)
  self.assertAllClose(x_np, x_rec_np, rtol=1e-5, atol=1e-5)
  self.assertAllClose(ildj_np, -fldj_np, rtol=1e-5, atol=1e-5)
def make_layer(i):
  """Builds flow layer `i`: a RealNVP coupling followed by BatchNormalization.

  For even `i`, the pair is sandwiched between two `Permute([1, 0])`
  bijectors. The permutation presumably targets a 2-unit event; confirm
  against `n_units`. `n_units`, `n_masked` and `ShiftAndLogScale` come from
  the enclosing scope.
  """
  layers = [
      tfb.RealNVP(
          num_masked=n_masked,
          shift_and_log_scale_fn=ShiftAndLogScale(n_units - n_masked),
      ),
      tfb.BatchNormalization(),
  ]
  if i % 2 != 0:
    return tfb.Chain(layers)
  # Two separate Permute instances, created before and after the pair.
  return tfb.Chain(
      [tfb.Permute(permutation=[1, 0])]
      + layers
      + [tfb.Permute(permutation=[1, 0])])
def testMutuallyConsistent(self):
  """Sampling and log_prob agree for RealNVP applied to a 4-d iid normal."""
  dims = 4
  bijector = tfb.RealNVP(
      num_masked=3, validate_args=True, **self._real_nvp_kwargs)
  base = tfd.Sample(tfd.Normal(0., 1.), [dims])
  dist = tfd.TransformedDistribution(
      distribution=base, bijector=bijector, validate_args=True)
  self.run_test_sample_consistent_log_prob(
      sess_run_fn=self.evaluate,
      dist=dist,
      num_samples=int(1e6),
      seed=54819,
      radius=1.,
      center=0.,
      rtol=0.1)
def testMutuallyConsistent(self):
  """Sampling and log_prob agree for RealNVP over a normal with event [4]."""
  dims = 4
  bijector = tfb.RealNVP(
      num_masked=3, validate_args=True, **self._real_nvp_kwargs)
  base = tfd.Normal(loc=0., scale=1.)
  dist = tfd.TransformedDistribution(
      distribution=base,
      bijector=bijector,
      event_shape=[dims],
      validate_args=True)
  self.run_test_sample_consistent_log_prob(
      sess_run_fn=self.evaluate,
      dist=dist,
      num_samples=int(2e5),
      radius=1.,
      center=0.,
      rtol=0.02)
def testInvertMutuallyConsistent(self):
  """Sampling and log_prob agree for an inverted RealNVP over a normal.

  This mirrors `testMutuallyConsistent`. It now uses `tfd` and
  `self.evaluate` instead of the deprecated `self.test_session()`, the
  TF1-only `tf.distributions`, and a separate `transformed_distribution_lib`
  import.
  """
  dims = 4
  nvp = tfb.Invert(
      tfb.RealNVP(num_masked=3, validate_args=True, **self._real_nvp_kwargs))
  dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      bijector=nvp,
      event_shape=[dims],
      validate_args=True)
  self.run_test_sample_consistent_log_prob(
      sess_run_fn=self.evaluate,
      dist=dist,
      num_samples=int(1e5),
      radius=1.,
      center=0.,
      rtol=0.02)
def testBijectorWithReverseMask(self):
  """With num_masked=-5, the last 5 units pass through RealNVP unchanged."""
  depth = 8
  flat_x_ = np.random.normal(0., 1., depth).astype(np.float32)
  batched_x_ = np.random.normal(0., 1., (3, depth)).astype(np.float32)
  num_masked = -5
  n_passthrough = abs(num_masked)
  n_transformed = depth - n_passthrough
  for x_np in (flat_x_, batched_x_):
    bijector = tfb.RealNVP(
        num_masked=num_masked,
        validate_args=True,
        shift_and_log_scale_fn=tfb.real_nvp_default_template(
            hidden_layers=[3], shift_only=False),
        is_constant_jacobian=False)
    _, x_tail = np.split(x_np, [n_transformed], axis=-1)  # pylint: disable=unbalanced-tuple-unpacking
    # The trailing block must be identical after the reversed-mask RealNVP.
    y = bijector.forward(tf.constant(x_np))
    _, y_tail = tf.split(y, [n_transformed, n_passthrough], axis=-1)
    self.evaluate(tf1.global_variables_initializer())
    y_tail_np = self.evaluate(y_tail)
    self.assertAllClose(y_tail_np, x_tail, rtol=1e-4, atol=0.)
def testBijectorWithReverseMask(self, num_masked, fraction_masked):
  """The trailing units pass through a reversed-mask RealNVP unchanged.

  This runs for both flat (8,) and batched (3, 8) inputs. The bijector is
  configured from either `num_masked` or `fraction_masked`. The parameters
  presumably give a negative mask size; confirm against the
  parameterization.
  """
  depth = 8
  flat_x_ = np.random.normal(0., 1., depth).astype(np.float32)
  batched_x_ = np.random.normal(0., 1., (3, depth)).astype(np.float32)
  for x_np in (flat_x_, batched_x_):
    bijector = tfb.RealNVP(
        num_masked=num_masked,
        fraction_masked=fraction_masked,
        validate_args=True,
        shift_and_log_scale_fn=tfb.real_nvp_default_template(
            hidden_layers=[3], shift_only=False),
        is_constant_jacobian=False)
    y = bijector.forward(tf.constant(x_np))

    if num_masked is not None:
      expected_masked = num_masked
    else:
      expected_masked = np.floor(depth * fraction_masked)
    self.assertEqual(bijector._masked_size, expected_masked)

    n_passthrough = abs(bijector._masked_size)
    n_transformed = depth - n_passthrough
    _, x_tail = np.split(x_np, [n_transformed], axis=-1)  # pylint: disable=unbalanced-tuple-unpacking
    # The trailing block must be identical after the reversed-mask RealNVP.
    _, y_tail = tf.split(y, [n_transformed, n_passthrough], axis=-1)
    self.evaluate(tf1.global_variables_initializer())
    y_tail_np = self.evaluate(y_tail)
    self.assertAllClose(y_tail_np, x_tail, rtol=1e-4, atol=0.)
def spline_flow():
  """Builds a stack of `nsplits` RealNVP layers, one per entry of `splines`.

  Layer `k` masks `5 * k` units and uses `splines[k]` as its `bijector_fn`.
  Each new layer is applied (called) on the stack accumulated so far, starting
  from `tfb.Identity()`. `nsplits` and `splines` come from the enclosing scope.
  """
  flow = tfb.Identity()
  for layer_idx in range(nsplits):
    coupling = tfb.RealNVP(5 * layer_idx, bijector_fn=splines[layer_idx])
    flow = coupling(flow)
  return flow
def testBadFractionRaises(self, fraction_masked):
  """An out-of-range `fraction_masked` is rejected at construction."""
  # `assertRaisesRegexp` was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, '`fraction_masked` must be in'):
    tfb.RealNVP(
        fraction_masked=fraction_masked,
        shift_and_log_scale_fn=lambda x, _: (x, x))
def testNonFloatFractionMaskedRaises(self):
  """A non-float `fraction_masked` (here the int 1) raises TypeError."""
  # `assertRaisesRegexp` was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(TypeError, '`fraction_masked` must be a float'):
    tfb.RealNVP(
        fraction_masked=1, shift_and_log_scale_fn=lambda x, _: (x, x))
def testNonIntegerNumMaskedRaises(self):
  """A non-integer `num_masked` (here 0.5) raises TypeError."""
  # `assertRaisesRegexp` was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(TypeError, '`num_masked` must be an integer'):
    tfb.RealNVP(num_masked=0.5, shift_and_log_scale_fn=lambda x, _: (x, x))