def _real_nvp_kwargs(self):
    """Return the RealNVP keyword arguments used by this test case.

    The coupling network is a default template with one hidden layer of
    3 units and `shift_only=False`. That means it also emits a log-scale
    term, so the Jacobian is flagged as non-constant.

    Returns:
      dict with keys "shift_and_log_scale_fn" and "is_constant_jacobian".
    """
    coupling_fn = tfb.real_nvp_default_template(
        hidden_layers=[3], shift_only=False)
    return dict(shift_and_log_scale_fn=coupling_fn,
                is_constant_jacobian=False)
def _real_nvp_kwargs(self):
    """Return the RealNVP keyword arguments used by this test case.

    The coupling network is a default template with one hidden layer of
    2 units and `shift_only=True`. A shift-only coupling is declared to
    the bijector as having a constant Jacobian.

    Returns:
      dict with keys 'shift_and_log_scale_fn' and 'is_constant_jacobian'.
    """
    coupling_fn = tfb.real_nvp_default_template(
        hidden_layers=[2], shift_only=True)
    return dict(shift_and_log_scale_fn=coupling_fn,
                is_constant_jacobian=True)
def testBijectorWithReverseMask(self):
    """A negative `num_masked` leaves the trailing |num_masked| units unchanged.

    Runs once on a rank-1 input and once on a batched (3, depth) input.
    """
    depth = 8
    flat_x_ = np.random.normal(0., 1., depth).astype(np.float32)
    batched_x_ = np.random.normal(0., 1., (3, depth)).astype(np.float32)
    num_masked = -5
    # Split point between the transformed head and the untouched tail.
    passthrough = abs(num_masked)
    split_at = depth - passthrough

    for x_ in (flat_x_, batched_x_):
        flip_nvp = tfb.RealNVP(
            num_masked=num_masked,
            validate_args=True,
            shift_and_log_scale_fn=tfb.real_nvp_default_template(
                hidden_layers=[3], shift_only=False),
            is_constant_jacobian=False)
        # The last `passthrough` input units are what the reversed mask
        # should leave unchanged.
        expected_tail = np.split(x_, [split_at], axis=-1)[1]

        forward_x = flip_nvp.forward(tf.constant(x_))
        forward_tail = tf.split(
            forward_x, [split_at, passthrough], axis=-1)[1]

        # Runs after forward(); presumably the template's variables only
        # exist once the bijector has been called.
        self.evaluate(tf1.global_variables_initializer())
        self.assertAllClose(
            self.evaluate(forward_tail), expected_tail, rtol=1e-4, atol=0.)
def testBijectorWithReverseMask(self, num_masked, fraction_masked):
    """Reversed-mask RealNVP passes its trailing masked units through unchanged.

    The test is parameterized by either `num_masked` or `fraction_masked`.
    When `num_masked` is None, the expected masked size is
    `np.floor(depth * fraction_masked)`. The test checks that
    `_masked_size` matches that value and that the last
    `abs(_masked_size)` units of forward() equal the input's.
    """
    depth = 8
    flat_x_ = np.random.normal(0., 1., depth).astype(np.float32)
    batched_x_ = np.random.normal(0., 1., (3, depth)).astype(np.float32)

    for x_ in (flat_x_, batched_x_):
        flip_nvp = tfb.RealNVP(
            num_masked=num_masked,
            fraction_masked=fraction_masked,
            validate_args=True,
            shift_and_log_scale_fn=tfb.real_nvp_default_template(
                hidden_layers=[3], shift_only=False),
            is_constant_jacobian=False)
        forward_x = flip_nvp.forward(tf.constant(x_))

        if num_masked is not None:
            expected_num_masked = num_masked
        else:
            expected_num_masked = np.floor(depth * fraction_masked)
        self.assertEqual(flip_nvp._masked_size, expected_num_masked)

        # Width of the untouched tail, and where the head ends.
        passthrough = abs(flip_nvp._masked_size)
        split_at = depth - passthrough
        expected_tail = np.split(x_, [split_at], axis=-1)[1]
        forward_tail = tf.split(
            forward_x, [split_at, passthrough], axis=-1)[1]

        # Runs after forward(); presumably the template's variables only
        # exist once the bijector has been called.
        self.evaluate(tf1.global_variables_initializer())
        self.assertAllClose(
            self.evaluate(forward_tail), expected_tail, rtol=1e-4, atol=0.)