  def testTransformerAutoencoder(self):
    hparams = imagetransformer_latent_tiny()
    hparams.mode = tf.estimator.ModeKeys.TRAIN
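    # The hidden dimension is split evenly across blocks, and the bottleneck
    # bits evenly across residuals and blocks, so each block gets a codebook
    # of 2**(bits per block) entries of size block_dim.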
    block_dim = hparams.hidden_size // hparams.num_blocks
    block_v_size = int(2**(hparams.bottleneck_bits /
                           (hparams.num_residuals * hparams.num_blocks)))
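    # Codebook of cluster means: one table of block_v_size centroids of
    # dimension block_dim for every (residual, block) pair.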
    means = tf.get_variable(
        name="means",
        shape=[hparams.num_residuals,
               hparams.num_blocks,
               block_v_size,
               block_dim],
        initializer=tf.uniform_unit_scaling_initializer())
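    # Pre-bind the discretization bottleneck with the VQ hyperparameters so
    # the model can call it as a plain function of its inputs.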
    hparams.bottleneck = functools.partial(
        discretization.discrete_bottleneck,
        hidden_size=hparams.hidden_size,
        z_size=hparams.bottleneck_bits,
        filter_size=hparams.filter_size,
        startup_steps=hparams.startup_steps,
        bottleneck_kind=hparams.bottleneck_kind,
        num_blocks=hparams.num_blocks,
        num_residuals=hparams.num_residuals,
        reshape_method=hparams.reshape_method,
        beta=hparams.vq_beta,
        decay=hparams.vq_decay,
        soft_em=hparams.soft_em,
        num_samples=hparams.num_samples,
        epsilon=hparams.vq_epsilon,
        ema=hparams.ema,
        means=means)

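    # Exercise the unconditional path: no encoder inputs, random
    # image-shaped targets in [-1, 1).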
    inputs = None
    batch_size = hparams.batch_size
    targets = tf.random_uniform([batch_size,
                                 hparams.img_len,
                                 hparams.img_len,
                                 hparams.hidden_size],
                                minval=-1., maxval=1.)
    target_space_id = None

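    # A global step must exist in the graph, e.g. for the bottleneck's
    # startup_steps schedule.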
    tf.train.create_global_step()
    decoder_output, losses, cache = latent_layers.transformer_autoencoder(
        inputs, targets, target_space_id, hparams)

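    # The autoencoder should report exactly these three losses.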
    self.assertEqual(set(six.iterkeys(losses)),
                     {"extra", "extra_loss", "latent_pred"})

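    # Run the graph: the decoder output keeps the target shape, and both
    # per-example losses are non-negative.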
    self.evaluate(tf.global_variables_initializer())
    decoder_output_, extra_loss_, latent_pred_ = self.evaluate(
        [decoder_output, losses["extra_loss"], losses["latent_pred"]])
    self.assertEqual(decoder_output_.shape, (batch_size,
                                             hparams.img_len,
                                             hparams.img_len,
                                             hparams.hidden_size))
    self.assertEqual(extra_loss_.shape, (batch_size,))
    self.assertEqual(latent_pred_.shape, (batch_size,))
    self.assertAllGreaterEqual(extra_loss_, 0.)
    self.assertAllGreaterEqual(latent_pred_, 0.)
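    # No decoding cache is expected from a training-mode forward pass.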
    self.assertIsNone(cache)