Example #1
    def model(self, latent, depth, scales):
        """

        :param latent:
        :param depth:
        :param scales:
        :return:
        """
        print('self.nclass:', self.nclass)
        # [b, 32, 32, 1]
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        # [b, 10]
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        # [?, 4, 4, 16]
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')
        # [b, 32, 32, 1] => [b, 4, 4, 16]
        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        # [b, 4, 4, 16] => [b, 32, 32, 1]
        decode = layers.decoder(h, scales, depth, self.colors, 'ae_decoder')
        # [b, 4, 4, 16] => [b, 32, 32, 1], auto-reuse
        ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')
        # reconstruction loss: MSE between the input and its reconstruction
        loss = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(loss, 'loss')
        # utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

        # the classifier only probes the representation; its gradients must not
        # flow back into the encoder, hence tf.stop_gradient(encode)
        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        # record classification loss on latent
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # xloss is isolated from loss by stop_gradient, so a single
            # optimizer can minimize the sum of both objectives
            train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss + xloss, tf.train.get_global_step())

        ops = train.AEOps(x, h, l, encode, decode, ae, train_op, classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(ops, interpolation=n_interpolations, height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32]*4)
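        # gen_images runs in Python via tf.py_func; each returned grid is
        # logged as a single image summary (expand_dims adds the batch dim)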
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example #2
    def model(self, latent, depth, scales):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = layers.decoder(h, scales, depth, self.colors, 'ae_decoder')
        ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')
        loss = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(loss, 'loss')
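        # inputs are scaled to [-1, 1], so sqrt(MSE) * 127.5 reports the RMSE
        # in 8-bit pixel units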
        utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(FLAGS.lr)
            train_op = train_op.minimize(loss + xloss,
                                         tf.train.get_global_step())
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example #3
    def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def discriminator(h):
            with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
                h = tf.layers.flatten(h)
                for size in [int(s) for s in disc_layer_sizes.split(',')]:
                    h = tf.layers.dense(h, size, tf.nn.leaky_relu)
                return tf.layers.dense(h, 1)

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
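        # score both the encoder's latents and the prior samples with the same
        # discriminator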
        adversary_logit_latent = discriminator(encode)
        adversary_logit_prior = discriminator(prior_samples)
        adversary_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.zeros_like(adversary_logit_latent)))
        adversary_loss_prior = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_prior,
                labels=tf.ones_like(adversary_logit_prior)))
        autoencoder_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.ones_like(adversary_logit_latent)))

        def _accuracy(logits, label):
            labels = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
            correct = tf.equal(tf.greater(logits, 0), labels)
            return tf.reduce_mean(tf.to_float(correct))

        latent_accuracy = _accuracy(adversary_logit_latent, False)
        prior_accuracy = _accuracy(adversary_logit_prior, True)
        adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(adversary_loss_latents, 'loss_adv_latent')
        utils.HookReport.log_tensor(adversary_loss_prior, 'loss_adv_prior')
        utils.HookReport.log_tensor(autoencoder_loss_latents, 'loss_ae_latent')
        utils.HookReport.log_tensor(adversary_accuracy, 'adversary_accuracy')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + autoencoder_loss_latents, var_list=ae_vars)
            train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
                adversary_loss_prior + adversary_loss_latents,
                var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_disc, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #4
    def model(self, latent, depth, scales, advweight, advdepth, reg, advnoise,
              advfake, wgt_mmd):
        ## define inputs
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def disc(x):
            # return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent, 'disc'), axis=[1, 2, 3])
            y = layers.encoder(x, scales, depth, latent, 'disc')
            return y

        encode = encoder(x)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        decode = decoder(h)

        ## impose regularization on latent space
        encode_flat = tf.reshape(encode, [tf.shape(encode)[0], -1])
        h_flat = tf.reshape(h, [tf.shape(h)[0], -1])
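        # (assumption) mmd2 estimates the squared maximum mean discrepancy
        # between the encoded batch and the prior samples fed through h, as in
        # a WAE; the relu below clips the estimator's occasional negative values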
        loss_mmd = tf.nn.relu(mmd2(encode_flat, h_flat))

        ## interpolate latent codes for the adversarial critic (ACAI-style)
        alpha_mix = tf.random_uniform(tf.shape(encode), 0, 1)
        alpha_mix = 0.5 - tf.abs(alpha_mix - 0.5)  # Make interval [0, 0.5]
        encode_mix = alpha_mix * encode + (1 - alpha_mix) * encode[::-1]
        decode_mix = decoder(encode_mix)

        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix) - alpha_mix))
        loss_ae_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix)))

        alpha_noise = tf.random_uniform(tf.shape(encode), 0, 1)
        encode_mix_noise = alpha_noise * encode + (1 - alpha_noise) * h
        decode_mix_noise = decoder(encode_mix_noise)

        loss_disc_noise = tf.reduce_mean(
            tf.square(disc(decode_mix_noise) - alpha_noise))
        loss_ae_disc_noise = tf.reduce_mean(tf.square(disc(decode_mix_noise)))

        alpha_fake = 0.5  # treat pure prior decodes as maximally mixed; other targets may be worth trying
        loss_disc_fake = tf.reduce_mean(tf.square(disc(decode) - alpha_fake))
        loss_ae_disc_fake = tf.reduce_mean(tf.square(disc(decode)))

        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')
        utils.HookReport.log_tensor(loss_disc_mix, 'loss_disc_mix')
        utils.HookReport.log_tensor(loss_ae_disc_mix, 'loss_ae_disc_mix')
        utils.HookReport.log_tensor(loss_disc_noise, 'loss_disc_noise')
        utils.HookReport.log_tensor(loss_ae_disc_noise, 'loss_ae_disc_noise')
        utils.HookReport.log_tensor(loss_disc_fake, 'loss_disc_fake')
        utils.HookReport.log_tensor(loss_ae_disc_fake, 'loss_ae_disc_fake')
        utils.HookReport.log_tensor(loss_mmd, 'loss_mmd')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')

        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc_mix +
                advnoise * loss_ae_disc_noise + advfake * loss_ae_disc_fake +
                wgt_mmd * loss_mmd,
                var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(
                FLAGS.lr).minimize(loss_disc_real + loss_disc_mix +
                                   loss_disc_noise + loss_disc_fake,
                                   var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss,
                global_step=tf.train.get_global_step(),
                var_list=xl_vars)

        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          train_xl,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #5
    def model(self, latent, depth, scales, z_log_size, beta, num_latents):
        tf.set_random_seed(123)
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def decode_fn(h):
            with tf.variable_scope('vqvae', reuse=tf.AUTO_REUSE):
                h2 = tf.expand_dims(tf.layers.flatten(h), axis=1)
                h2 = tf.layers.dense(h2,
                                     self.hparams.hidden_size * num_latents)
                d = bneck.discrete_bottleneck(h2)
                y = layers.decoder(tf.reshape(d['dense'], tf.shape(h)), scales,
                                   depth, self.colors, 'ae_decoder')
                return y, d

        self.hparams.hidden_size = ((self.height >> scales) *
                                    (self.width >> scales) * latent)
        self.hparams.z_size = z_log_size
        self.hparams.num_residuals = 1
        self.hparams.num_blocks = 1
        self.hparams.beta = beta
        self.hparams.ema = True
        bneck = DiscreteBottleneck(self.hparams)
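        # (assuming the tensor2tensor-style interface) discrete_bottleneck
        # returns a dict: d['dense'] is the straight-through quantized code,
        # d['discrete'] the code assignments, d['loss'] the VQ/commitment loss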
        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = decode_fn(h)[0]
        ae, d = decode_fn(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(d['loss'], 'vqvae_loss')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(d['dense']), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops + [d['discrete']]):
            train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + xloss + d['loss'], tf.train.get_global_step())
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example #6
    def model(self, latent, depth, scales, beta, advweight, advdepth, reg):
        """
        Args:
            latent: number of channels output by the encoder.
            depth: depth (number of channels before applying the first convolution)
                for the encoder
            scales: input width/height to latent width/height ratio, on log base 2 scale
                (how many times the encoder should downsample)
            beta: scale hyperparam >= 1 for the KL term in the ELBO
                (value of 1 equivalent to vanilla VAE)
            advweight: how much the VAE should care about fooling the discriminator
                (value of 0 equivalent to training a VAE alone)
            advdepth: depth for the discriminator
            reg: gamma in the paper
        """
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            """ Outputs latent codes (not mean vectors) """
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            """ Outputs Bernoulli logits """
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def disc(x):
            """ Outputs predicted mixing coefficient alpha """
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        # ENCODE
        encode = encoder(x)
        # get mean and var from the latent code
        with tf.variable_scope('ae_latent'):
            encode_shape = tf.shape(encode)
            encode_flat = tf.layers.flatten(encode)
            latent_dim = encode_flat.get_shape()[-1]
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)
        # sample
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)

        # DECODE
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        vae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1  # [0, 1] -> [-1, 1]
        decode = 2 * tf.nn.sigmoid(decoder(h)) - 1

        # COMPUTE VAE LOSS
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
        loss_kl = 0.5 * tf.reduce_sum(-log_q_sigma_sq - 1 +
                                      tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl / tf.to_float(tf.shape(x)[0])
        x_bernoulli = 0.5 * (x + 1)  # [-1, 1] -> [0, 1]
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll / tf.to_float(tf.shape(x)[0])
        elbo = loss_ll - beta * loss_kl
        loss_vae = -elbo
        utils.HookReport.log_tensor(loss_vae, 'neg elbo')

        # COMPUTE DISCRIMINATOR LOSS
        # interpolate in latent space with a randomly-chosen alpha
        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # [0, 1] -> [0, 0.5]
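        # encode[::-1] reverses the batch, so each example is mixed with a
        # different partner; alpha in [0, 0.5] keeps the target identifiable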
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        decode_mix = 2 * tf.nn.sigmoid(decoder(encode_mix)) - 1

        loss_disc = tf.reduce_mean(
            tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
        loss_disc_real = tf.reduce_mean(tf.square(disc(vae + reg * (x - vae))))
        # vae wants disc to predict 0
        loss_vae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        # CLASSIFY (determine "usefulness" of latent codes)
        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_vae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_vae + advweight * loss_vae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          vae,
                          tf.group(train_vae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #7
    def model(self, latent, depth, scales, beta):
        """

        :param latent: hidden/latent channel number
        :param depth: channel number for factor
        :param scales: factor
        :param beta: beta for KL divergence
        :return:
        """
        # x is rescaled to [-1, 1] in the data augmentation phase
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        # [32>>3, 32>>3, latent_depth]
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'vae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'vae_dec')

        # [b, 4, 4, 16]
        encode = encoder(x)

        with tf.variable_scope('vae_u_std'):
            encode_shape = tf.shape(encode)
            # [b, 16*16]
            encode_flat = tf.layers.flatten(encode)
            # static (not run-time) shape: 16*16
            latent_dim = encode_flat.get_shape()[-1]
            # dense:[16*16, 16*16]
            # mean
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            # dense: [16*16, 16*16]
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)

        # log_q_sigma_sq is the log-variance [b, 16*16]; take exp then sqrt to
        # recover the standard deviation => [b, 4*4*16]
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))

        # q(z|x) = N(mu, sigma^2)
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        # [b, 4*4*16] => [b, 4, 4, 16]
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)
        # [b, 32, 32, 1]
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        # [b, 32, 32, 1]
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)

        # for the VAE, h stands for a value sampled from Gaussian(mu, sigma^2)
        # rescale decoder output to [-1, 1]
        ae = 2*tf.nn.sigmoid(p_x_given_z_logits) - 1
        decode = 2*tf.nn.sigmoid(decoder(h)) - 1

        # compute the KL divergence;
        # there is a closed form for the KL between two Gaussian distributions,
        # see:
        # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
        loss_kl = 0.5*tf.reduce_sum(-log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl/tf.to_float(tf.shape(x)[0])
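        # for q = N(mu, sigma^2) and p = N(0, 1) the closed form is
        #   KL(q || p) = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2)),
        # here averaged over the batch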

        # rescale to [0, 1], convenient for Bernoulli distribution
        x_bernoulli = 0.5*(x + 1)
        # reconstruction term, expressed as a Bernoulli log-likelihood
        # (a density-estimation view of the reconstruction loss)
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll/tf.to_float(tf.shape(x)[0])

        # evidence lower bound (ELBO)
        elbo = loss_ll - beta*loss_kl

        utils.HookReport.log_tensor(loss_kl, 'kl_divergence')
        utils.HookReport.log_tensor(loss_ll, 'log_likelihood')
        utils.HookReport.log_tensor(elbo, 'elbo')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass, scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('vae_enc') + tf.global_variables('vae_dec') + tf.global_variables('vae_u_std')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(- elbo, var_list=ae_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(xloss, tf.train.get_global_step(), var_list=xl_vars)

        ops = train.AEOps(x, h, l, q_z_sample_reshaped, decode, ae, tf.group(train_ae, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example #8
    def model(self, latent, depth, scales, beta):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        encode = encoder(x)
        with tf.variable_scope('ae_latent'):
            encode_shape = tf.shape(encode)
            encode_flat = tf.layers.flatten(encode)
            latent_dim = encode_flat.get_shape()[-1]
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)
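        # recover the standard deviation from the predicted log-variance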
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
        ae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1
        decode = 2 * tf.nn.sigmoid(decoder(h)) - 1
        loss_kl = 0.5 * tf.reduce_sum(-log_q_sigma_sq - 1 +
                                      tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl / tf.to_float(tf.shape(x)[0])
        x_bernoulli = 0.5 * (x + 1)
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll / tf.to_float(tf.shape(x)[0])
        elbo = loss_ll - beta * loss_kl

        utils.HookReport.log_tensor(loss_kl, 'loss_kl')
        utils.HookReport.log_tensor(loss_ll, 'loss_ll')
        utils.HookReport.log_tensor(elbo, 'elbo')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                -elbo, var_list=ae_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          q_z_sample_reshaped,
                          decode,
                          ae,
                          tf.group(train_ae, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example #9
    def model(self, latent, depth, scales, advweight, advdepth):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def disc(x):
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        loss_disc = tf.reduce_mean(
            tf.square(disc(x)) + tf.square(disc(ae) - 1))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(ae)))
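        # (least-squares adversary: the critic above is pushed toward 0 on real
        # x and 1 on reconstructions, while the autoencoder drives disc(ae)
        # back toward 0)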

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #10
    def model(self, latent, depth, scales, advweight, advdepth, reg):
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'acai_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'acai_dec')
            return v

        def disc(x):
            # [b, 32, 32, 1] => [b, 4, 4, adv_c] => [b]
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent, 'acai_disc'), axis=[1, 2, 3])

        # [b, 4, 4, 16]
        encode = encoder(x)
        # [b, 32, 32, 1]
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        # [b, 1, 1, 1], sampled from the uniform distribution on (0, 1)
        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # Make interval [0, 0.5]
        # a * [b, 4, 4, 16] + (1-a)*[reversed(b), 4, 4, 16]
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        # decode the mixed codes: [b, 4, 4, 16] => [b, 32, 32, 1]
        decode_mix = decoder(encode_mix)

        loss_disc = tf.reduce_mean(tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
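        # reg (gamma in the ACAI paper) mixes a bit of the real input back into
        # the reconstruction so the critic sees near-realistic images at alpha=0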
        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

        # utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass, scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('acai_enc') + tf.global_variables('acai_dec')
        disc_vars = tf.global_variables('acai_disc')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x, h, l, encode, decode, ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched), [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #11
    def model(self, latent, depth, scales, advweight, advdepth, reg):
        # scales: the number of downsampling (avg-pool) convolution stages
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        # h is the placeholder for the decoder input (w' x h' x latent); we
        # need it because interpolated codes are fed in from outside the graph

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def disc(x):
            # same architecture as the encoder; the output is reduced to a
            # scalar per sample so it can regress the mixing coefficient alpha
            # FIXME: why not a sigmoid output?
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # Make interval [0, 0.5]
        # FIXME: why not alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 0.5)?
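        # note: folding U(0, 1) at 0.5 is distributionally identical to
        # sampling U(0, 0.5) directly, so either form works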
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        # mix latent codes symmetrically
        # e.g. l1 (+) l3 / l2 (+) l2 / l3 (+) l1
        decode_mix = decoder(encode_mix)

        loss_disc = tf.reduce_mean(
            tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #12
    def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):

        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'aae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'aae_dec')

        def discriminator(h):
            """
            Construct 2 layer MLP: [b, 4, 4, 16]=>MLP(100, 100)=>[b, 1]
            :param h:
            :return:
            """
            with tf.variable_scope('aae_disc', reuse=tf.AUTO_REUSE):
                # [b, 4, 4, 16] => [b, 16*16]
                h = tf.layers.flatten(h)
                for size in [int(s) for s in disc_layer_sizes.split(',')]:
                    # Dense(16*16, 100)
                    # Dense(100, 100)
                    h = tf.layers.dense(h, size, tf.nn.leaky_relu)
                # [b, 100] => [b, 1]
                return tf.layers.dense(h, 1)

        # [b, 4, 4, 16]
        encode = encoder(x)
        # [b, 32, 32, 1]
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        # assume a standard normal prior over h
        prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
        # D(enc(x)): judge whether a generated latent looks like a prior sample
        adversary_logit_latent = discriminator(encode)
        # D(p(h))
        adversary_logit_prior = discriminator(prior_samples)
        # discriminator loss on fake (encoder-generated) h
        adversary_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.zeros_like(adversary_logit_latent)))
        # loss on real prior h
        adversary_loss_prior = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_prior,
                labels=tf.ones_like(adversary_logit_prior)))
        # autoencoder loss: encourage latents that fool the discriminator
        autoencoder_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.ones_like(adversary_logit_latent)))

        # accuracy helper: a positive logit is read as a "prior" prediction
        def _accuracy(logits, label):
            labels = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
            correct = tf.equal(tf.greater(logits, 0), labels)
            return tf.reduce_mean(tf.to_float(correct))

        latent_accuracy = _accuracy(adversary_logit_latent, False)
        prior_accuracy = _accuracy(adversary_logit_prior, True)
        adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

        # reconstruction loss
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        # discriminator should treat all h as fake
        utils.HookReport.log_tensor(adversary_loss_latents, 'loss_adv_latent')
        # discriminator should treat all prior h as real
        utils.HookReport.log_tensor(adversary_loss_prior, 'loss_adv_prior')
        # h generated by the encoder should fool the discriminator
        utils.HookReport.log_tensor(autoencoder_loss_latents, 'loss_ae_latent')
        # average accuracy at distinguishing enc(x) from p(h)
        utils.HookReport.log_tensor(adversary_accuracy, 'adversary_accuracy')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode),
                                                   l,
                                                   self.nclass,
                                                   scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('aae_enc') + tf.global_variables(
            'aae_dec')
        disc_vars = tf.global_variables('aae_disc')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            # train auto-encoder and G/encoder
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + autoencoder_loss_latents, var_list=ae_vars)
            # train discriminator
            train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
                adversary_loss_prior + adversary_loss_latents,
                var_list=disc_vars)
            # train MLP classifier
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_disc, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops