def testTheoreticalFldj(self):
        nbatch = 5
        channels = 10
        x = np.random.uniform(size=[nbatch, channels]).astype(np.float32)

        bijector = tfb.BatchNormalization(training=False)
        bijector.batchnorm.build(x.shape)
        self.evaluate([v.initializer for v in bijector.variables])

        y = self.evaluate(bijector.forward(x))
        bijector_test_util.assert_bijective_and_finite(bijector,
                                                       x,
                                                       y,
                                                       eval_func=self.evaluate,
                                                       event_ndims=1,
                                                       inverse_event_ndims=1,
                                                       rtol=1e-5)
        fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
        # The log-det-Jacobian is constant in x, so it has not been broadcast
        # yet; add zeros to give it the batch shape.
        fldj = fldj + tf.zeros(tf.shape(x)[:-1], dtype=x.dtype)
        fldj_theoretical = bijector_test_util.get_fldj_theoretical(
            bijector, x, event_ndims=1)
        self.assertAllClose(self.evaluate(fldj_theoretical),
                            self.evaluate(fldj),
                            atol=1e-5,
                            rtol=1e-5)
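
# A standalone sketch (not part of the test above; names and sizes are
# illustrative) of why the broadcasting step matters: in inference mode the
# BatchNormalization log-det-Jacobian is constant in x, so it may come back
# without a batch dimension, and its analytic value follows directly from the
# layer's moving statistics.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

x = np.random.uniform(size=[5, 10]).astype(np.float32)
bijector = tfb.BatchNormalization(training=False)
bijector.batchnorm.build(x.shape)

fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
# Broadcast the constant value to one entry per batch member.
fldj = fldj + tf.zeros(tf.shape(x)[:-1], dtype=x.dtype)

# forward() de-normalizes with the moving statistics, so analytically
# fldj = sum(0.5 * log(moving_var + eps) - log(gamma)) over the event dim.
layer = bijector.batchnorm
analytic_fldj = tf.reduce_sum(
    0.5 * tf.math.log(layer.moving_variance + layer.epsilon)
    - tf.math.log(layer.gamma))
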
 def testLogProb(self, event_shape, event_dims, training, layer_cls):
     training = tf.compat.v1.placeholder_with_default(
         training, (), "training")
     layer = layer_cls(axis=event_dims, epsilon=0.)
     batch_norm = tfb.BatchNormalization(batchnorm_layer=layer,
                                         training=training)
     base_dist = distributions.MultivariateNormalDiag(
         loc=np.zeros(np.prod(event_shape), dtype=np.float32))
     # Reshape the events.
     if isinstance(event_shape, int):
         event_shape = [event_shape]
     base_dist = distributions.TransformedDistribution(
         distribution=base_dist,
         bijector=tfb.Reshape(event_shape_out=event_shape))
     dist = distributions.TransformedDistribution(distribution=base_dist,
                                                  bijector=batch_norm,
                                                  validate_args=True)
     samples = dist.sample(int(1e5))
      # No volume distortion since training=False and the bijector is
      # initialized to the identity transformation.
     base_log_prob = base_dist.log_prob(samples)
     dist_log_prob = dist.log_prob(samples)
     self.evaluate(tf.compat.v1.global_variables_initializer())
     base_log_prob_, dist_log_prob_ = self.evaluate(
         [base_log_prob, dist_log_prob])
     self.assertAllClose(base_log_prob_, dist_log_prob_)
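
# Hedged standalone sketch of the identity-at-initialization property the test
# above exercises: with training=False and a freshly initialized batchnorm
# layer (epsilon=0.), the bijector is the identity, so log_prob is unchanged.
# The concrete event_shape=[2, 3] and axis=[1, 2] below are illustrative.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

event_shape = [2, 3]
base = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=tf.zeros(np.prod(event_shape))),
    bijector=tfb.Reshape(event_shape_out=event_shape))
layer = tf.keras.layers.BatchNormalization(axis=[1, 2], epsilon=0.)
dist = tfd.TransformedDistribution(
    distribution=base,
    bijector=tfb.BatchNormalization(batchnorm_layer=layer, training=False))

x = base.sample(4)
# Both log-probs agree because the normalization is a no-op here.
tf.debugging.assert_near(base.log_prob(x), dist.log_prob(x))
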
 def testMaximumLikelihoodTraining(self):
      # Test maximum-likelihood training with the default batchnorm layer.
     training = tf.placeholder_with_default(True, (), "training")
     base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
     batch_norm = tfb.BatchNormalization(training=training)
     dist = distributions.TransformedDistribution(distribution=base_dist,
                                                  bijector=batch_norm)
     target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.])
     target_samples = target_dist.sample(200)
     dist_samples = dist.sample(3000)
     loss = -tf.reduce_mean(dist.log_prob(target_samples))
     with tf.control_dependencies(batch_norm.batchnorm.updates):
         train_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
         moving_mean = tf.identity(batch_norm.batchnorm.moving_mean)
         moving_var = tf.identity(batch_norm.batchnorm.moving_variance)
     self.evaluate(tf.global_variables_initializer())
     for _ in range(3000):
         self.evaluate(train_op)
     [dist_samples_, moving_mean_,
      moving_var_] = self.evaluate([dist_samples, moving_mean, moving_var])
     self.assertAllClose([1., 2.],
                         np.mean(dist_samples_, axis=0),
                         atol=5e-2)
     self.assertAllClose([1., 2.], moving_mean_, atol=5e-2)
     self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
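
# A hedged TF2/eager sketch of the same maximum-likelihood fit as the
# graph-mode test above (Adam on the default BatchNormalization bijector).
# The learning rate and step count mirror the test but are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

base_dist = tfd.MultivariateNormalDiag(loc=[0., 0.])
batch_norm = tfb.BatchNormalization()  # training=True by default
dist = tfd.TransformedDistribution(distribution=base_dist, bijector=batch_norm)
target_samples = tfd.MultivariateNormalDiag(loc=[1., 2.]).sample(200)

optimizer = tf.keras.optimizers.Adam(1e-2)
for _ in range(3000):
    with tf.GradientTape() as tape:
        loss = -tf.reduce_mean(dist.log_prob(target_samples))
    # gamma and beta exist once the layer has been built by the first call.
    variables = batch_norm.batchnorm.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
# After training, dist.sample() should be centered near the target mean and
# the layer's moving statistics should track the batch statistics.
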
 def make_layer(i):
     fn = ShiftAndLogScale(n_units - n_masked)
     chain = [
         tfb.RealNVP(
             num_masked=n_masked,
             shift_and_log_scale_fn=fn,
         ),
         tfb.BatchNormalization(),
     ]
     if i % 2 == 0:
         perm = lambda: tfb.Permute(permutation=[1, 0])
         chain = [perm(), *chain, perm()]
     return tfb.Chain(chain)
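
# Hedged usage sketch for make_layer above. ShiftAndLogScale, n_units and
# n_masked are defined in surrounding code that is not shown here, so this
# only illustrates how such blocks are typically stacked into a flow; the
# layer count and the 2-D base distribution (implied by Permute([1, 0])) are
# illustrative choices.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

n_layers = 4
# tfb.Chain applies bijectors right-to-left; the alternating Permute inside
# make_layer ensures both coordinates get transformed by some RealNVP block.
flow = tfb.Chain([make_layer(i) for i in range(n_layers)])
flow_dist = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=tf.zeros(2)),
    bijector=flow)
samples = flow_dist.sample(16)
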
 def testLogProb(self):
   with self.test_session() as sess:
     layer = normalization.BatchNormalization(epsilon=0.)
     batch_norm = tfb.BatchNormalization(batchnorm_layer=layer, training=False)
     base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
     dist = transformed_distribution_lib.TransformedDistribution(
         distribution=base_dist,
         bijector=batch_norm,
         validate_args=True)
     samples = dist.sample(int(1e5))
      # No volume distortion since training=False and the bijector is
      # initialized to the identity transformation.
     base_log_prob = base_dist.log_prob(samples)
     dist_log_prob = dist.log_prob(samples)
     tf.global_variables_initializer().run()
     base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob])
     self.assertAllClose(base_log_prob_, dist_log_prob_)
 def testInvertMutuallyConsistent(self, layer_cls):
     # BatchNorm bijector is only mutually consistent when training=False.
     dims = 4
     training = tf.compat.v1.placeholder_with_default(False, (), "training")
     layer = layer_cls(epsilon=0.)
     batch_norm = tfb.Invert(
         tfb.BatchNormalization(batchnorm_layer=layer, training=training))
     dist = distributions.TransformedDistribution(
         distribution=distributions.Normal(loc=0., scale=1.),
         bijector=batch_norm,
         event_shape=[dims],
         validate_args=True)
     self.run_test_sample_consistent_log_prob(sess_run_fn=self.evaluate,
                                              dist=dist,
                                              num_samples=int(1e5),
                                              radius=2.,
                                              center=0.,
                                              rtol=0.02)
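
# Hedged note on the tfb.Invert wrapper above: for the plain bijector,
# forward() de-normalizes, so normalization happens on the inverse/log_prob
# path; wrapping it in tfb.Invert swaps the two directions, so samples drawn
# from the transformed distribution are the ones pushed through batch
# normalization. A minimal illustration:
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

inverted = tfb.Invert(tfb.BatchNormalization(training=False))
z = tf.random.normal([8, 4])
y = inverted.forward(z)        # applies batch normalization to z
z_back = inverted.inverse(y)   # de-normalizes, recovering z up to numerics
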
 def testMutuallyConsistent(self):
   # BatchNorm bijector is only mutually consistent when training=False.
   dims = 4
   with self.test_session() as sess:
     layer = normalization.BatchNormalization(epsilon=0.)
     batch_norm = tfb.BatchNormalization(batchnorm_layer=layer, training=False)
     dist = transformed_distribution_lib.TransformedDistribution(
         distribution=tf.distributions.Normal(loc=0., scale=1.),
         bijector=batch_norm,
         event_shape=[dims],
         validate_args=True)
     self.run_test_sample_consistent_log_prob(
         sess_run_fn=sess.run,
         dist=dist,
         num_samples=int(1e5),
         radius=2.,
         center=0.,
         rtol=0.02)
    def testWithKeras(self):
        # NOTE: Keras throws an error below if we use
        # tf.compat.v1.layers.BatchNormalization() here.
        layer = None

        dist = distributions.TransformedDistribution(
            distribution=distributions.Normal(loc=0., scale=1.),
            bijector=tfb.BatchNormalization(batchnorm_layer=layer),
            event_shape=[1],
            validate_args=True)

        x_ = tf.keras.Input(shape=(1, ))
        log_prob_ = dist.log_prob(x_)
        model = tf.keras.Model(x_, log_prob_)

        model.compile(optimizer="adam", loss=lambda _, log_prob: -log_prob)

        model.fit(x=np.random.normal(size=(32, 1)).astype(np.float32),
                  y=np.zeros((32, 0)),
                  batch_size=16,
                  epochs=1,
                  steps_per_epoch=2)
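
# Hedged note on the Keras wiring in testWithKeras: the model's output is
# dist.log_prob(x) itself, the dummy y target is ignored, and the loss lambda
# just negates that output, so model.fit() performs maximum-likelihood
# training of the batchnorm variables. An illustrative continuation (not part
# of the test) reusing the model name from above:
import numpy as np

test_x = np.random.normal(size=(4, 1)).astype(np.float32)
log_probs = model.predict(test_x)  # per-example log-likelihood under dist
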
    def testForwardInverse(self, input_shape, event_dims, training):
        """Tests forward and backward passes with different event shapes and axes.

    Args:
      input_shape: Tuple of shapes for input tensor.
      event_dims: Tuple of dimension indices that will be normalized.
      training: Boolean of whether bijector runs in training or inference mode.
    """
        x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape)
        x = tf.compat.v1.placeholder_with_default(
            x_, input_shape if 0 in event_dims else (None, ) + input_shape[1:])
        # With momentum=0., the layer memorizes the exact statistics of the
        # last minibatch it normalized instead of an exponential moving average.
        layer = tf.keras.layers.BatchNormalization(axis=event_dims,
                                                   momentum=0.,
                                                   epsilon=0.)
        batch_norm = tfb.BatchNormalization(batchnorm_layer=layer,
                                            training=training)
        # Minibatch statistics are saved only after norm_x has been computed.
        norm_x = batch_norm.inverse(x)
        with tf.control_dependencies(batch_norm.batchnorm.updates):
            moving_mean = tf.identity(batch_norm.batchnorm.moving_mean)
            moving_var = tf.identity(batch_norm.batchnorm.moving_variance)
            denorm_x = batch_norm.forward(tf.identity(norm_x))
            fldj = batch_norm.forward_log_det_jacobian(
                x, event_ndims=len(event_dims))
            # Use identity to invalidate cache.
            ildj = batch_norm.inverse_log_det_jacobian(
                tf.identity(denorm_x), event_ndims=len(event_dims))
        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Update variables.
        norm_x_ = self.evaluate(norm_x)
        [
            norm_x_,
            moving_mean_,
            moving_var_,
            denorm_x_,
            ildj_,
            fldj_,
        ] = self.evaluate([
            norm_x,
            moving_mean,
            moving_var,
            denorm_x,
            ildj,
            fldj,
        ])
        self.assertStartsWith(batch_norm.name, "batch_normalization")

        reduction_axes = self._reduction_axes(input_shape, event_dims)
        keepdims = len(event_dims) > 1

        expected_batch_mean = np.mean(x_,
                                      axis=reduction_axes,
                                      keepdims=keepdims)
        expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims)

        if training:
            # When training=True, values become normalized across batch dim and
            # original values are recovered after de-normalizing.
            zeros = np.zeros_like(norm_x_)
            self.assertAllClose(np.mean(zeros, axis=reduction_axes),
                                np.mean(norm_x_, axis=reduction_axes))

            self.assertAllClose(expected_batch_mean, moving_mean_)
            self.assertAllClose(expected_batch_var, moving_var_)
            self.assertAllClose(x_, denorm_x_, atol=1e-5)
            # Since moving statistics are set to batch statistics after
            # normalization, ildj and -fldj should match.
            self.assertAllClose(ildj_, -fldj_)
            # ildj is computed with minibatch statistics.
            expected_ildj = np.sum(
                np.log(1.) -
                .5 * np.log(expected_batch_var + batch_norm.batchnorm.epsilon))
            self.assertAllClose(expected_ildj, np.squeeze(ildj_))
        else:
            # When training=False, moving_mean and moving_var remain at their
            # initialized values (0., 1.), so there is no scale or shift (a
            # small shift would occur if epsilon > 0.).
            self.assertAllClose(x_, norm_x_)
            self.assertAllClose(x_, denorm_x_, atol=1e-5)
            # ildj is computed with saved statistics.
            expected_ildj = np.sum(
                np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon))
            self.assertAllClose(expected_ildj, np.squeeze(ildj_))
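
# Hedged numeric sketch of the expected_ildj formula used above: batch
# normalization maps x -> gamma * (x - mean) / sqrt(var + eps) + beta, so the
# inverse (normalizing) log-det-Jacobian per event is
# sum(log(gamma) - 0.5 * log(var + eps)); the forward (de-normalizing)
# direction is its negation. The numbers below are illustrative.
import numpy as np

gamma = np.ones(4, dtype=np.float32)                    # default scale init
var = np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)  # per-feature variance
eps = 1e-3
ildj = np.sum(np.log(gamma) - 0.5 * np.log(var + eps))
fldj = -ildj  # matches the assertAllClose(ildj_, -fldj_) check above
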