Example No. 1
    def model(self, smoothing):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label_onehot')

        ops = classifiers.single_layer_classifier(x,
                                                  l,
                                                  self.nclass,
                                                  smoothing=smoothing)
        ops.x = x
        ops.label = l
        loss = tf.reduce_mean(ops.loss)
        halfway = ((FLAGS.total_kimg << 10) // FLAGS.batch) // 2
        lr = tf.train.exponential_decay(FLAGS.lr,
                                        tf.train.get_global_step(),
                                        decay_steps=halfway,
                                        decay_rate=0.1)

        utils.HookReport.log_tensor(loss, 'xe')
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = tf.train.AdamOptimizer(lr)
            ops.train_op = opt.minimize(loss, tf.train.get_global_step())

        return ops
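
A note on the schedule above: FLAGS.total_kimg << 10 converts thousands of images (kimg) into an image count, dividing by FLAGS.batch converts images into optimizer steps, and halving that places the single 10x decay at the training midpoint. A minimal sketch of the same schedule in plain Python (the default values below are illustrative assumptions, not the codebase's flags):

def decayed_lr(step, base_lr=0.0001, total_kimg=1 << 14, batch=64):
    """Mirrors tf.train.exponential_decay with decay_rate=0.1, staircase=False."""
    total_steps = (total_kimg << 10) // batch  # kimg -> images -> steps
    halfway = total_steps // 2                 # midpoint of training
    return base_lr * 0.1 ** (step / halfway)   # 10x smaller by the midpoint

print(decayed_lr(0))       # base_lr at step 0
print(decayed_lr(131072))  # base_lr / 10 at the halfway step (with these defaults)
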
Example No. 2
    def model(self, latent, depth, scales):
        """

        :param latent:
        :param depth:
        :param scales:
        :return:
        """
        print('self.nclass:', self.nclass)
        # [b, 32, 32, 1]
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        # [b, 10]
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        # [?, 4, 4, 16]
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')
        # [b, 32, 32, 1] => [b, 4, 4, 16]
        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        # [b, 4, 4, 16] => [b, 32, 32, 1]
        decode = layers.decoder(h, scales, depth, self.colors, 'ae_decoder')
        # [b, 4, 4, 16] => [b, 32, 32, 1], auto-reuse
        ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')
        loss = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(loss, 'loss')
        # utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

        # we only use encode to obtain a representation; the classification loss
        # must not backprop into the encoder, hence stop_gradient(encode)
        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        # record classification loss on latent
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # since xloss is isolated from loss by stop_gradient, the two
            # objectives can safely share a single optimizer
            train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss + xloss, tf.train.get_global_step())

        ops = train.AEOps(x, h, l, encode, decode, ae, train_op, classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(ops, interpolation=n_interpolations, height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
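
These model methods only build the graph; a conventional TF1 loop then drives ops.train_op by feeding the placeholders. A minimal sketch, assuming train.AEOps exposes its constructor arguments as attributes (x, label, train_op) and using a hypothetical next_batch helper:

import tensorflow as tf

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        batch_x, batch_l = next_batch()  # hypothetical data-loading helper
        sess.run(ops.train_op, feed_dict={ops.x: batch_x, ops.label: batch_l})
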
Example No. 3
    def model(self, latent, depth, scales):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = layers.decoder(h, scales, depth, self.colors, 'ae_decoder')
        ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')
        loss = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(loss, 'loss')
        utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(FLAGS.lr)
            train_op = train_op.minimize(loss + xloss,
                                         tf.train.get_global_step())
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
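
The 'slerp' summary refers to spherical linear interpolation between latent codes; the implementation lives in make_sample_grid_and_save, which is not shown in this listing. For reference, the standard formula is

\mathrm{slerp}(z_1, z_2; t) = \frac{\sin((1-t)\,\Omega)}{\sin\Omega}\, z_1 + \frac{\sin(t\,\Omega)}{\sin\Omega}\, z_2, \qquad \Omega = \arccos\left(\frac{z_1 \cdot z_2}{\lVert z_1 \rVert \, \lVert z_2 \rVert}\right)

which, unlike linear interpolation, keeps the interpolant's norm close to that of the endpoints.
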
Example No. 4
    def model(self, latent, depth, scales):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = layers.decoder(h, scales, depth, self.colors, 'ae_decoder')
        ae = layers.decoder(encode, scales, depth, self.colors, 'ae_decoder')
        loss = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(loss, 'loss')
        utils.HookReport.log_tensor(tf.sqrt(loss) * 127.5, 'rmse')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(encode), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(FLAGS.lr)
            train_op = train_op.minimize(loss + xloss,
                                         tf.train.get_global_step())
        ops = train.AEOps(x, h, l, encode, decode, ae, train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(
            gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example No. 5
    def model(self, smoothing):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label_onehot')

        ops = classifiers.single_layer_classifier(x, l, self.nclass,
                                                  smoothing=smoothing)
        ops.x = x
        ops.label = l
        loss = tf.reduce_mean(ops.loss)
        halfway = ((FLAGS.total_kimg << 10) // FLAGS.batch) // 2
        lr = tf.train.exponential_decay(FLAGS.lr, tf.train.get_global_step(),
                                        decay_steps=halfway,
                                        decay_rate=0.1)

        utils.HookReport.log_tensor(loss, 'xe')
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = tf.train.AdamOptimizer(lr)
            ops.train_op = opt.minimize(loss, tf.train.get_global_step())

        return ops
Example No. 6
        def record_test_ops(vars_k, k):
            """
            record intermediate test status for update_num pretrain
            :param vars_k:
            :param k: just for log_tensor
            :return:
            """
            # add 0 updated representation
            encoder_op = self.forward_encoder(test_qry_x, vars_k[:self.encoder_var_num])
            decoder_op = self.forward_decoder(h, vars_k[self.encoder_var_num:])  # reuse
            ae_op = self.forward_ae(test_qry_x, vars_k)
            encoder_ops.append(encoder_op)
            decoder_ops.append(decoder_op)
            ae_ops.append(ae_op)

            # this op only optimizes the classifier, hence the stop_gradient after encoder_op
            # classify_op is not a single op; it bundles prediction and loss
            classify_op = classifiers.single_layer_classifier(tf.stop_gradient(encoder_op), test_qry_y, self.nclass,
                                                              scope='classifier_%d'%k, reuse=False)
            classify_ops.append(classify_op)
            # record classification loss on latent
            # utils.HookReport.log_tensor(tf.reduce_mean(classify_op.loss), 'test_classify_h_loss_update_%d'%k)

            return
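
The stop_gradient pattern used throughout these examples is worth a note: it lets the probe classifier train on the latent code while guaranteeing that no classification gradient reaches the encoder. A self-contained illustration:

import tensorflow as tf

z = tf.Variable([1.0, 2.0])  # stands in for the encoder output
w = tf.Variable([0.5, 0.5])  # stands in for the classifier weights
loss = tf.reduce_sum(tf.stop_gradient(z) * w)
grads = tf.gradients(loss, [z, w])
# grads[0] is None: stop_gradient blocks the path back to z (the encoder side)
# grads[1] equals the value of z: the classifier side still trains normally
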
Example No. 7
    def model(self, latent, depth, scales, beta):
        """

        :param latent: hidden/latent channel number
        :param depth: channel number for factor
        :param scales: factor
        :param beta: beta for KL divergence
        :return:
        """
        # x is rescaled to [-1, 1] in the data augmentation phase
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        # [32>>3, 32>>3, latent_depth]
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'vae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'vae_dec')

        # [b, 4, 4, 16]
        encode = encoder(x)

        with tf.variable_scope('vae_u_std'):
            encode_shape = tf.shape(encode)
            # [b, 16*16]
            encode_flat = tf.layers.flatten(encode)
            # static (not run-time) shape: 16*16
            latent_dim = encode_flat.get_shape()[-1]
            # dense:[16*16, 16*16]
            # mean
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            # dense: [16*16, 16*16]
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)

        # std dev recovered from the predicted log-variance, [b, 4*4*16]
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))

        # q(z|x) = N(mu, std^2)
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        # [b, 4*4*16] => [b, 4, 4, 16]
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)
        # [b, 32, 32, 1]
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        # [b, 32, 32, 1]
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)

        # for the VAE, h stands for a value sampled from Gaussian(mu, std^2)
        # map logits to [-1, 1]
        ae = 2*tf.nn.sigmoid(p_x_given_z_logits) - 1
        decode = 2*tf.nn.sigmoid(decoder(h)) - 1

        # compute the KL divergence
        # there is a closed form for the KL between two Gaussian distributions;
        # see: https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
        loss_kl = 0.5*tf.reduce_sum(-log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl/tf.to_float(tf.shape(x)[0])

        # rescale to [0, 1], convenient for Bernoulli distribution
        x_bernoulli = 0.5*(x + 1)
        # log-likelihood under the Bernoulli decoder (density estimation
        # rather than a plain reconstruction loss)
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll/tf.to_float(tf.shape(x)[0])

        # ELBO = log-likelihood - beta * KL
        elbo = loss_ll - beta*loss_kl

        utils.HookReport.log_tensor(loss_kl, 'kl_divergence')
        utils.HookReport.log_tensor(loss_ll, 'log_likelihood')
        utils.HookReport.log_tensor(elbo, 'elbo')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass, scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('vae_enc') + tf.global_variables('vae_dec') + tf.global_variables('vae_u_std')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(- elbo, var_list=ae_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(xloss, tf.train.get_global_step(), var_list=xl_vars)

        ops = train.AEOps(x, h, l, q_z_sample_reshaped, decode, ae, tf.group(train_ae, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
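
For reference, the closed form used in loss_kl above: with q(z|x) = N(mu, sigma^2) and the standard-normal prior p(z) = N(0, 1), the per-sample KL term is

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \frac{1}{2} \sum_i \left( \sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2 \right)

which, with log_q_sigma_sq standing for log sigma^2, is exactly 0.5 * sum(-log_q_sigma_sq - 1 + exp(log_q_sigma_sq) + q_mu ** 2).
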
Example No. 8
    def model(self, latent, depth, scales, advweight, advdepth, reg,
              int_hidden_layers, int_hidden_units, disc_hidden_layers, beta,
              wgt_mmd, use_ema, n_scale, wgt_noise, wgt_fake):
        """
            :param latent: The number of channels in latent space.
            :param depth: The number of channels in the first conv operation.
            :param scales: The number of scales.
            :param advweight: The weight for disc.
            :param advdepth: The number of channels in the first conv operation in disc.
            :param reg: The ratio for combine reconstruction and the original images.
            :param int_hidden_layers: The number of layers in the interpolation module.
            :param int_hidden_units: The number of units in the hidden layer of interpolation module.
            :param disc_hidden_layers: The number of layers in the int_disc module.
            :param beta: The momentum of ema.
            :param wgt_mmd: weight for mmd regularization.
            :param use_ema: option to toggle using ema.
            :param n_scale: the scale of training interpolation.
            :param wgt_noise: weight for noise interpolation.
            :param wgt_fake: weight for fake interpolation.
            :return: The operation for manipulating the model.
        """

        dim = latent * FLAGS.latent_width**2
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')
        h_b = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h_b')

        self.use_ema = (use_ema == 1)
        self.n_scale = n_scale

        alpha_h = tf.placeholder(tf.float32, [None, 1, 1, 1], 'alpha_h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def interpolate(h_a, h_b, alpha):
            alpha_dim = 1

            alpha_reshape = tf.reshape(alpha,
                                       shape=[tf.shape(alpha)[0], alpha_dim])

            h_reshape_a = tf.concat((tf.reshape(
                h_a, shape=[tf.shape(h_a)[0], dim]), alpha_reshape),
                                    axis=1)
            h_reshape_b = tf.concat((tf.reshape(
                h_b, shape=[tf.shape(h_b)[0], dim]), 1 - alpha_reshape),
                                    axis=1)
            enc_layer_list = [1200 for i in range(int_hidden_layers)
                              ] + [int_hidden_units]

            h_encode = layers.fully_connected(h_reshape_a,
                                              'int_enc',
                                              hidden_units=enc_layer_list)
            h_encode2 = layers.fully_connected(h_reshape_b,
                                               'int_enc',
                                               hidden_units=enc_layer_list)
            h_mix = h_encode + h_encode2

            alpha_layer_list = [1200 for i in range(int_hidden_layers)
                                ] + [alpha_dim]

            alpha_encode = layers.fully_connected(
                alpha_reshape, 'int_alpha', hidden_units=alpha_layer_list)
            alpha_encode2 = layers.fully_connected(
                1 - alpha_reshape, 'int_alpha', hidden_units=alpha_layer_list)

            alpha_encode = tf.reshape(alpha_encode,
                                      [tf.shape(alpha_encode)[0], 1, 1, 1])
            alpha_encode2 = tf.reshape(alpha_encode2,
                                       [tf.shape(alpha_encode2)[0], 1, 1, 1])

            dec_layer_list = [1200 for i in range(int_hidden_layers)] + [dim]
            h_bias = layers.fully_connected(h_mix,
                                            'int_dec',
                                            hidden_units=dec_layer_list)

            h_bias = tf.reshape(h_bias, [
                tf.shape(h_bias)[0], FLAGS.latent_width, FLAGS.latent_width,
                latent
            ])

            return (alpha + alpha * (1 - alpha) * alpha_encode) * h_a + \
                   ((1 - alpha) + alpha * (1 - alpha) * alpha_encode2) * h_b + \
                   alpha * (1 - alpha) * h_bias

        def disc(x):
            return tf.reduce_mean(layers.encoder(x, scales, depth, latent,
                                                 'disc_img'),
                                  axis=[1, 2, 3])

        def disc_interpolate(z):
            z = tf.reshape(z, shape=[tf.shape(z)[0], dim])
            disc_layer_list = [1200 for i in range(disc_hidden_layers)] + [dim]
            predicted_alpha = layers.fully_connected(
                z, 'disc_int', hidden_units=disc_layer_list)

            predicted_alpha = tf.reduce_mean(predicted_alpha, axis=[1])
            return predicted_alpha

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_rec = tf.losses.mean_squared_error(x, ae)

        encode_mix = interpolate(encode * self.n_scale,
                                 encode[::-1] * self.n_scale, alpha_h)
        encode_mix = encode_mix / self.n_scale
        my_encode_mix = interpolate(h * self.n_scale, h_b * self.n_scale,
                                    alpha_h)
        my_encode_mix = my_encode_mix / self.n_scale
        h_encode_mix = interpolate(h, h[::-1], alpha_h)
        decode_mix = decoder(encode_mix)

        loss_disc = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae)))) + \
                    tf.reduce_mean(tf.square(disc(decode_mix) - alpha_h))

        alpha_noise = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        encode_mix_noise = interpolate(h * self.n_scale, encode * self.n_scale,
                                       alpha_noise)
        encode_mix_noise = encode_mix_noise / self.n_scale
        decode_mix_noise = decoder(encode_mix_noise)

        loss_disc_noise = tf.reduce_mean(
            tf.square(disc(decode_mix_noise) - alpha_noise))
        loss_ae_disc_noise = tf.reduce_mean(tf.square(disc(decode_mix_noise)))

        alpha_fake = 0.5
        loss_disc_fake = tf.reduce_mean(tf.square(disc(decode) - alpha_fake))
        loss_ae_disc_fake = tf.reduce_mean(tf.square(disc(decode)))

        # use h as anchor.
        loss_disc_interpolate = tf.reduce_mean(tf.square(disc_interpolate(h))) + \
                                tf.reduce_mean(tf.square(disc_interpolate(h_encode_mix) - alpha_h))
        loss_int_disc = tf.reduce_mean(
            tf.square(disc_interpolate(h_encode_mix)))

        loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

        encode_flat = tf.reshape(encode, [tf.shape(encode)[0], -1])
        h_flat = tf.reshape(h, [tf.shape(h)[0], -1])
        loss_mmd = tf.nn.relu(mmd2(encode_flat, h_flat))

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)

        utils.HookReport.log_tensor(loss_mmd, 'loss_mmd')
        utils.HookReport.log_tensor(loss_rec, 'loss_rec')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_disc_interpolate,
                                    'loss_disc_interpolate')
        utils.HookReport.log_tensor(loss_int_disc, 'loss_int_disc')
        utils.HookReport.log_tensor(loss_disc_noise, 'loss_disc_noise')
        utils.HookReport.log_tensor(loss_ae_disc_noise, 'loss_ae_disc_noise')
        utils.HookReport.log_tensor(loss_disc_fake, 'loss_disc_fake')
        utils.HookReport.log_tensor(loss_ae_disc_fake, 'loss_ae_disc_fake')
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_img_vars = tf.global_variables('disc_img')
        disc_int_vars = tf.global_variables('disc_int')
        int_vars = tf.global_variables('int')
        xl_vars = tf.global_variables('single_layer_classifier')

        ema = EMA(beta=beta)
        ema.apply(int_vars)

        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_rec + advweight * loss_ae_disc + wgt_mmd * loss_mmd +
                wgt_noise * loss_ae_disc_noise + wgt_fake * loss_ae_disc_fake,
                var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc + loss_disc_noise + loss_disc_fake,
                var_list=disc_img_vars)
            train_d_int = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc_interpolate, var_list=disc_int_vars)
            train_int = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_int_disc, tf.train.get_global_step(), var_list=int_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)

        ops = AEOps(x,
                    h,
                    h_b,
                    l,
                    encode,
                    decode,
                    ae,
                    tf.group(train_ae, train_d, train_xl),
                    train_xl,
                    alpha_h,
                    my_encode_mix,
                    tf.group(train_d_int, train_int),
                    ema,
                    classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, free_int, samples = tf.py_func(gen_images, [],
                                                     [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('free_int', tf.expand_dims(free_int, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
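
One detail worth calling out in interpolate above: every learned term is gated by alpha * (1 - alpha), so the mix reduces exactly to h_b at alpha = 0 and to h_a at alpha = 1 regardless of what the networks output. A stripped-down NumPy sketch of just that blending rule, with arbitrary stand-ins for the learned terms:

import numpy as np

def blend(h_a, h_b, alpha, a_enc, a_enc2, h_bias):
    gate = alpha * (1 - alpha)  # vanishes at alpha in {0, 1}
    return ((alpha + gate * a_enc) * h_a
            + ((1 - alpha) + gate * a_enc2) * h_b
            + gate * h_bias)

h_a, h_b = np.ones(4), np.zeros(4)
rng = np.random.RandomState(0)
assert np.allclose(blend(h_a, h_b, 0.0, *rng.randn(3)), h_b)  # endpoint b
assert np.allclose(blend(h_a, h_b, 1.0, *rng.randn(3)), h_a)  # endpoint a
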
Example No. 9
    def model(self, latent, depth, scales, beta):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        encode = encoder(x)
        with tf.variable_scope('ae_latent'):
            encode_shape = tf.shape(encode)
            encode_flat = tf.layers.flatten(encode)
            latent_dim = encode_flat.get_shape()[-1]
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
        ae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1
        decode = 2 * tf.nn.sigmoid(decoder(h)) - 1
        loss_kl = 0.5 * tf.reduce_sum(-log_q_sigma_sq - 1 +
                                      tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl / tf.to_float(tf.shape(x)[0])
        x_bernoulli = 0.5 * (x + 1)
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll / tf.to_float(tf.shape(x)[0])
        elbo = loss_ll - beta * loss_kl

        utils.HookReport.log_tensor(loss_kl, 'loss_kl')
        utils.HookReport.log_tensor(loss_ll, 'loss_ll')
        utils.HookReport.log_tensor(elbo, 'elbo')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                -elbo, var_list=ae_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          q_z_sample_reshaped,
                          decode,
                          ae,
                          tf.group(train_ae, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
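
The objective maximized above is the beta-weighted evidence lower bound; in the code's notation, with Gaussian posterior q(z|x) and Bernoulli decoder p(x|z):

\mathrm{ELBO} = \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] - \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big)

so train_ae minimizing -elbo maximizes the bound, and beta = 1 recovers the vanilla VAE.
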
Example No. 10
    def model(self, latent, depth, scales, advweight, advdepth, reg, advnoise,
              advfake, wgt_mmd):
        ## define inputs
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def disc(x):
            # return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent, 'disc'), axis=[1, 2, 3])
            y = layers.encoder(x, scales, depth, latent, 'disc')
            return y

        encode = encoder(x)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        decode = decoder(h)

        ## impose regularization on latent space
        encode_flat = tf.reshape(encode, [tf.shape(encode)[0], -1])
        h_flat = tf.reshape(h, [tf.shape(h)[0], -1])
        loss_mmd = tf.nn.relu(mmd2(encode_flat, h_flat))

        ## interpolate latent codes within the batch
        alpha_mix = tf.random_uniform(tf.shape(encode), 0, 1)
        alpha_mix = 0.5 - tf.abs(alpha_mix - 0.5)  # Make interval [0, 0.5]
        encode_mix = alpha_mix * encode + (1 - alpha_mix) * encode[::-1]
        decode_mix = decoder(encode_mix)

        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix) - alpha_mix))
        loss_ae_disc_mix = tf.reduce_mean(tf.square(disc(decode_mix)))

        alpha_noise = tf.random_uniform(tf.shape(encode), 0, 1)
        encode_mix_noise = alpha_noise * encode + (1 - alpha_noise) * h
        decode_mix_noise = decoder(encode_mix_noise)

        loss_disc_noise = tf.reduce_mean(
            tf.square(disc(decode_mix_noise) - alpha_noise))
        loss_ae_disc_noise = tf.reduce_mean(tf.square(disc(decode_mix_noise)))

        alpha_fake = 0.5  # TODO: other values may be worth trying here
        loss_disc_fake = tf.reduce_mean(tf.square(disc(decode) - alpha_fake))
        loss_ae_disc_fake = tf.reduce_mean(tf.square(disc(decode)))

        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')
        utils.HookReport.log_tensor(loss_disc_mix, 'loss_disc_mix')
        utils.HookReport.log_tensor(loss_ae_disc_mix, 'loss_ae_disc_mix')
        utils.HookReport.log_tensor(loss_disc_noise, 'loss_disc_noise')
        utils.HookReport.log_tensor(loss_ae_disc_noise, 'loss_ae_disc_noise')
        utils.HookReport.log_tensor(loss_disc_fake, 'loss_disc_fake')
        utils.HookReport.log_tensor(loss_ae_disc_fake, 'loss_ae_disc_fake')
        utils.HookReport.log_tensor(loss_mmd, 'loss_mmd')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')

        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc_mix +
                advnoise * loss_ae_disc_noise + advfake * loss_ae_disc_fake +
                wgt_mmd * loss_mmd,
                var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(
                FLAGS.lr).minimize(loss_disc_real + loss_disc_mix +
                                   loss_disc_noise + loss_disc_fake,
                                   var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss,
                global_step=tf.train.get_global_step(),
                var_list=xl_vars)

        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          train_xl,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
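
mmd2 (used here and in Example No. 8) is not defined in this listing. For reference only, a common biased MMD^2 estimator with an RBF kernel looks like the sketch below; the actual mmd2 in this codebase may differ in kernel choice and bias correction:

import tensorflow as tf

def mmd2_rbf(x, y, sigma=1.0):
    """Biased MMD^2 estimate between two batches of flat vectors (a sketch)."""
    def k(a, b):
        # pairwise squared distances -> RBF kernel matrix
        d = tf.reduce_sum(tf.square(a[:, None, :] - b[None, :, :]), axis=-1)
        return tf.exp(-d / (2.0 * sigma ** 2))
    return (tf.reduce_mean(k(x, x)) + tf.reduce_mean(k(y, y))
            - 2.0 * tf.reduce_mean(k(x, y)))
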
Example No. 11
    def model(self, latent, depth, scales, advweight, advdepth, reg):
        x = tf.placeholder(tf.float32, [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(tf.float32, [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'acai_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'acai_dec')
            return v

        def disc(x):
            # [b, 32 ,32, 1] => [b, 4, 4, adv_c] => [b]
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent, 'acai_disc'), axis=[1, 2, 3])

        # [b, 4, 4, 16]
        encode = encoder(x)
        # [b, 32, 32, 1]
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        # [b, 1, 1, 1] ~ Uniform(0, 1)
        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # Make interval [0, 0.5]
        # a * [b, 4, 4, 16] + (1-a)*[reversed(b), 4, 4, 16]
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        # [b, 4, 4, 16] => [b, 32, 32, 1]
        decode_mix = decoder(encode_mix)

        loss_disc = tf.reduce_mean(tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

        # utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l, self.nclass, scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('acai_enc') + tf.global_variables('acai_dec')
        disc_vars = tf.global_variables('acai_disc')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x, h, l, encode, decode, ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [], [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched), [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
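
The three adversarial losses above match the ACAI objective: writing x_hat for the reconstruction, x_hat_alpha for the decoding of the mixed code, d for the critic, gamma for reg and lambda for advweight,

L_{\mathrm{critic}} = \lVert d(\hat{x}_\alpha) - \alpha \rVert^2 + \lVert d(\gamma x + (1 - \gamma)\hat{x}) \rVert^2, \qquad L_{\mathrm{ae}} = \lVert x - \hat{x} \rVert^2 + \lambda \, \lVert d(\hat{x}_\alpha) \rVert^2

so the critic regresses the mixing coefficient while the autoencoder tries to push the critic's prediction toward 0 on interpolants.
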
Example No. 12
    def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def discriminator(h):
            with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
                h = tf.layers.flatten(h)
                for size in [int(s) for s in disc_layer_sizes.split(',')]:
                    h = tf.layers.dense(h, size, tf.nn.leaky_relu)
                return tf.layers.dense(h, 1)

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
        adversary_logit_latent = discriminator(encode)
        adversary_logit_prior = discriminator(prior_samples)
        adversary_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.zeros_like(adversary_logit_latent)))
        adversary_loss_prior = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_prior,
                labels=tf.ones_like(adversary_logit_prior)))
        autoencoder_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.ones_like(adversary_logit_latent)))

        def _accuracy(logits, label):
            labels = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
            correct = tf.equal(tf.greater(logits, 0), labels)
            return tf.reduce_mean(tf.to_float(correct))
        latent_accuracy = _accuracy(adversary_logit_latent, False)
        prior_accuracy = _accuracy(adversary_logit_prior, True)
        adversary_accuracy = (latent_accuracy + prior_accuracy)/2

        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(adversary_loss_latents, 'loss_adv_latent')
        utils.HookReport.log_tensor(adversary_loss_prior, 'loss_adv_prior')
        utils.HookReport.log_tensor(autoencoder_loss_latents, 'loss_ae_latent')
        utils.HookReport.log_tensor(adversary_accuracy, 'adversary_accuracy')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(encode), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + autoencoder_loss_latents, var_list=ae_vars)
            train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
                adversary_loss_prior + adversary_loss_latents,
                var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x, h, l, encode, decode, ae,
                          tf.group(train_ae, train_disc, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(
            gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(
                tf.reshape(inter, batched), [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
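
The _accuracy helper above is terse: a logit greater than 0 means the discriminator says "prior sample", and the Python bool passed as label is broadcast across the batch. A tiny check of the convention:

import tensorflow as tf

logits = tf.constant([-1.0, 2.0, 0.5])
# latent codes should be classified as *not* prior, i.e. logits <= 0 is correct
correct = tf.equal(tf.greater(logits, 0), False)
with tf.Session() as sess:
    print(sess.run(tf.reduce_mean(tf.to_float(correct))))  # 1/3
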
Example No. 13
    def model(self, latent, depth, scales, advweight, advdepth, reg):
        # scales: the number of downscaling (avgpool) convolution layers
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        # placeholder for the decoder input (w' x h' x latent);
        # needed because interpolated codes are fed in from outside

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            v = layers.decoder(h, scales, depth, self.colors, 'ae_dec')
            return v

        def disc(x):
            # same architecture as the encoder
            # the output is reduced to a scalar (mean over the hidden layer) so it can regress alpha
            # FIXME: why not a sigmoid output?
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # Make interval [0, 0.5]
        # FIXME: why not alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 0.5)?
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        # mix latent codes symmetrically
        # e.g. l1 (+) l3 / l2 (+) l2 / l3 (+) l1
        decode_mix = decoder(encode_mix)

        loss_disc = tf.reduce_mean(
            tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
        loss_disc_real = tf.reduce_mean(tf.square(disc(ae + reg * (x - ae))))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
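
On the FIXME above: drawing alpha ~ Uniform(0, 1) and folding it with 0.5 - |alpha - 0.5| is distributionally identical to drawing Uniform(0, 0.5) directly, since folding a uniform sample about its midpoint keeps it uniform on the half interval. A quick NumPy check:

import numpy as np

a = np.random.uniform(0.0, 1.0, size=1_000_000)
folded = 0.5 - np.abs(a - 0.5)
assert 0.0 <= folded.min() and folded.max() <= 0.5
hist, _ = np.histogram(folded, bins=10, range=(0.0, 0.5))
print(hist / hist.sum())  # roughly flat: ~0.1 per bin
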
Example No. 14
    def model(self, latent, depth, scales, beta):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        encode = encoder(x)
        with tf.variable_scope('ae_latent'):
            encode_shape = tf.shape(encode)
            encode_flat = tf.layers.flatten(encode)
            latent_dim = encode_flat.get_shape()[-1]
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
        ae = 2*tf.nn.sigmoid(p_x_given_z_logits) - 1
        decode = 2*tf.nn.sigmoid(decoder(h)) - 1
        loss_kl = 0.5*tf.reduce_sum(
            -log_q_sigma_sq - 1 + tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl/tf.to_float(tf.shape(x)[0])
        x_bernoulli = 0.5*(x + 1)
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll/tf.to_float(tf.shape(x)[0])
        elbo = loss_ll - beta*loss_kl

        utils.HookReport.log_tensor(loss_kl, 'loss_kl')
        utils.HookReport.log_tensor(loss_ll, 'loss_ll')
        utils.HookReport.log_tensor(elbo, 'elbo')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(encode), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                -elbo, var_list=ae_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x, h, l, q_z_sample_reshaped, decode, ae,
                          tf.group(train_ae, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(
            gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
Example No. 15
    def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):

        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'aae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'aae_dec')

        def discriminator(h):
            """
            Construct 2 layer MLP: [b, 4, 4, 16]=>MLP(100, 100)=>[b, 1]
            :param h:
            :return:
            """
            with tf.variable_scope('aae_disc', reuse=tf.AUTO_REUSE):
                # [b, 4, 4, 16] => [b, 16*16]
                h = tf.layers.flatten(h)
                for size in [int(s) for s in disc_layer_sizes.split(',')]:
                    # Dense(16*16, 100)
                    # Dense(100, 100)
                    h = tf.layers.dense(h, size, tf.nn.leaky_relu)
                # [b, 100] => [b, 1]
                return tf.layers.dense(h, 1)

        # [b, 4, 4, 16]
        encode = encoder(x)
        # [b, 32, 32, 1]
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        # assume the prior distribution of h is a standard normal
        prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
        # D(enc(x)): judge whether the encoded latent looks like a prior sample
        adversary_logit_latent = discriminator(encode)
        # D(h) on real prior samples
        adversary_logit_prior = discriminator(prior_samples)
        # loss on fake h
        adversary_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.zeros_like(adversary_logit_latent)))
        # loss on real prior h
        adversary_loss_prior = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_prior,
                labels=tf.ones_like(adversary_logit_prior)))
        # loss on auto-encoder to fool discriminator
        autoencoder_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.ones_like(adversary_logit_latent)))

        def _accuracy(logits, label):
            labels = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
            correct = tf.equal(tf.greater(logits, 0), labels)
            return tf.reduce_mean(tf.to_float(correct))

        latent_accuracy = _accuracy(adversary_logit_latent, False)
        prior_accuracy = _accuracy(adversary_logit_prior, True)
        adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

        # reconstruction loss
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        # discriminator should treat all h as fake
        utils.HookReport.log_tensor(adversary_loss_latents, 'loss_adv_latent')
        # discriminator should treat all prior h as real
        utils.HookReport.log_tensor(adversary_loss_prior, 'loss_adv_prior')
        # h generated by the encoder should fool the discriminator
        utils.HookReport.log_tensor(autoencoder_loss_latents, 'loss_ae_latent')
        # average accuracy at distinguishing enc(x) from p(h)
        utils.HookReport.log_tensor(adversary_accuracy, 'adversary_accuracy')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode),
                                                   l,
                                                   self.nclass,
                                                   scope='classifier')
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_loss_on_h')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('aae_enc') + tf.global_variables(
            'aae_dec')
        disc_vars = tf.global_variables('aae_disc')
        xl_vars = tf.global_variables('classifier')
        with tf.control_dependencies(update_ops):
            # train auto-encoder and G/encoder
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + autoencoder_loss_latents, var_list=ae_vars)
            # train discriminator
            train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
                adversary_loss_prior + adversary_loss_latents,
                var_list=disc_vars)
            # train MLP classifier
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_disc, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
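
All three adversarial losses in the example above reduce to the same primitive: sigmoid cross-entropy between a logit and a constant label (0 for fake, 1 for real). A minimal NumPy sketch of the numerically stable formula that tf.nn.sigmoid_cross_entropy_with_logits computes; the sample logits are made up for illustration:

import numpy as np

def sigmoid_xent(logits, labels):
    # stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

logits = np.array([-2.0, 0.0, 3.0])
print(sigmoid_xent(logits, np.zeros_like(logits)))  # "fake" targets (zeros)
print(sigmoid_xent(logits, np.ones_like(logits)))   # "real" targets (ones)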
Example #16
    def model(self, latent, depth, scales, beta, advweight, advdepth, reg):
        """
        Args:
            latent: number of channels output by the encoder.
            depth: base channel count for the encoder (the number of channels
                produced by its first convolution).
            scales: log2 ratio of input width/height to latent width/height
                (how many times the encoder downsamples).
            beta: scale hyperparameter >= 1 on the KL term of the ELBO
                (a value of 1 recovers the vanilla VAE).
            advweight: weight on fooling the discriminator
                (a value of 0 is equivalent to training the VAE alone).
            advdepth: base channel count for the discriminator.
            reg: gamma in the paper.
        """
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            """ Outputs latent codes (not mean vectors) """
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            """ Outputs Bernoulli logits """
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def disc(x):
            """ Outputs predicted mixing coefficient alpha """
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        # ENCODE
        encode = encoder(x)
        # project the latent code to a mean and log-variance
        with tf.variable_scope('ae_latent'):
            encode_shape = tf.shape(encode)
            encode_flat = tf.layers.flatten(encode)
            latent_dim = encode_flat.get_shape()[-1]
            q_mu = tf.layers.dense(encode_flat, latent_dim)
            log_q_sigma_sq = tf.layers.dense(encode_flat, latent_dim)
        # sample z ~ q(z|x)
        q_sigma = tf.sqrt(tf.exp(log_q_sigma_sq))
        q_z = tf.distributions.Normal(loc=q_mu, scale=q_sigma)
        q_z_sample = q_z.sample()
        q_z_sample_reshaped = tf.reshape(q_z_sample, encode_shape)

        # DECODE
        p_x_given_z_logits = decoder(q_z_sample_reshaped)
        vae = 2 * tf.nn.sigmoid(p_x_given_z_logits) - 1  # [0, 1] -> [-1, 1]
        decode = 2 * tf.nn.sigmoid(decoder(h)) - 1

        # COMPUTE VAE LOSS
        p_x_given_z = tf.distributions.Bernoulli(logits=p_x_given_z_logits)
        loss_kl = 0.5 * tf.reduce_sum(-log_q_sigma_sq - 1 +
                                      tf.exp(log_q_sigma_sq) + q_mu**2)
        loss_kl = loss_kl / tf.to_float(tf.shape(x)[0])
        x_bernoulli = 0.5 * (x + 1)  # [-1, 1] -> [0, 1]
        loss_ll = tf.reduce_sum(p_x_given_z.log_prob(x_bernoulli))
        loss_ll = loss_ll / tf.to_float(tf.shape(x)[0])
        elbo = loss_ll - beta * loss_kl
        loss_vae = -elbo
        utils.HookReport.log_tensor(loss_vae, 'neg elbo')

        # COMPUTE DISCRIMINATOR LOSS
        # interpolate in latent space with a randomly-chosen alpha
        alpha = tf.random_uniform([tf.shape(encode)[0], 1, 1, 1], 0, 1)
        alpha = 0.5 - tf.abs(alpha - 0.5)  # [0, 1] -> [0, 0.5]
        encode_mix = alpha * encode + (1 - alpha) * encode[::-1]
        decode_mix = 2 * tf.nn.sigmoid(decoder(encode_mix)) - 1
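        # sanity check on the folding above: 0.5 - |a - 0.5| maps
        # 0.0 -> 0.0, 0.25 -> 0.25, 0.5 -> 0.5, 0.75 -> 0.25, 1.0 -> 0.0,
        # so alpha is always the smaller mixing weight and the target the
        # discriminator regresses stays in [0, 0.5].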

        loss_disc = tf.reduce_mean(
            tf.square(disc(decode_mix) - alpha[:, 0, 0, 0]))
        loss_disc_real = tf.reduce_mean(tf.square(disc(vae + reg * (x - vae))))
        # the vae wants the disc to predict alpha = 0 on its interpolations
        loss_vae_disc = tf.reduce_mean(tf.square(disc(decode_mix)))
        utils.HookReport.log_tensor(loss_disc_real, 'loss_disc_real')

        # CLASSIFY (determine "usefulness" of latent codes)
        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_vae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_vae + advweight * loss_vae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc + loss_disc_real, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          vae,
                          tf.group(train_vae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
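
The loss_kl term in the example above is the closed-form KL divergence KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions and averaged over the batch. A small NumPy sketch (with arbitrary made-up values) that checks the closed form against a Monte-Carlo estimate built from the same z = mu + sigma * eps reparameterization that q_z.sample() performs:

import numpy as np

np.random.seed(0)
mu, log_sigma_sq = 0.7, -0.4
sigma = np.sqrt(np.exp(log_sigma_sq))

# closed form, matching loss_kl: 0.5 * (-log s^2 - 1 + s^2 + mu^2)
kl_exact = 0.5 * (-log_sigma_sq - 1 + np.exp(log_sigma_sq) + mu ** 2)

# Monte-Carlo estimate of E_q[log q(z) - log p(z)]
z = mu + sigma * np.random.randn(1000000)
log_q = -0.5 * (np.log(2 * np.pi) + log_sigma_sq
                + (z - mu) ** 2 / np.exp(log_sigma_sq))
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
kl_mc = np.mean(log_q - log_p)

print(kl_exact, kl_mc)  # the two should agree to a few decimal places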
Example #17
    def model(self, latent, depth, scales, z_log_size, beta, num_latents):
        tf.set_random_seed(123)
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def decode_fn(h):
            with tf.variable_scope('vqvae', reuse=tf.AUTO_REUSE):
                h2 = tf.expand_dims(tf.layers.flatten(h), axis=1)
                h2 = tf.layers.dense(h2,
                                     self.hparams.hidden_size * num_latents)
                d = bneck.discrete_bottleneck(h2)
                y = layers.decoder(tf.reshape(d['dense'], tf.shape(h)), scales,
                                   depth, self.colors, 'ae_decoder')
                return y, d

        self.hparams.hidden_size = ((self.height >> scales) *
                                    (self.width >> scales) * latent)
        self.hparams.z_size = z_log_size
        self.hparams.num_residuals = 1
        self.hparams.num_blocks = 1
        self.hparams.beta = beta
        self.hparams.ema = True
        # decode_fn above closes over bneck; it is not called until after this line
        bneck = DiscreteBottleneck(self.hparams)
        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = decode_fn(h)[0]
        ae, d = decode_fn(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(d['loss'], 'vqvae_loss')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(d['dense']), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops + [d['discrete']]):
            train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + xloss + d['loss'], tf.train.get_global_step())
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops
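
The hidden_size set in the example above is just the flattened latent volume. A quick worked example, assuming a 32x32 single-channel setup with scales=3 and latent=16 (these particular values are illustrative):

height = width = 32  # assumed input resolution
scales, latent = 3, 16
hidden_size = (height >> scales) * (width >> scales) * latent
print(hidden_size)  # (32 >> 3) * (32 >> 3) * 16 = 4 * 4 * 16 = 256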
Example #18
    def model(self, latent, depth, scales, advweight, advdepth):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def disc(x):
            return tf.reduce_mean(layers.encoder(x, scales, advdepth, latent,
                                                 'disc'),
                                  axis=[1, 2, 3])

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        loss_disc = tf.reduce_mean(
            tf.square(disc(x)) + tf.square(disc(ae) - 1))
        loss_ae_disc = tf.reduce_mean(tf.square(disc(ae)))

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(loss_disc, 'loss_disc')
        utils.HookReport.log_tensor(loss_ae_disc, 'loss_ae_disc')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + advweight * loss_ae_disc, var_list=ae_vars)
            train_d = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_disc, var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_d, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
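
Unlike the sigmoid cross-entropy losses in the adversarial-autoencoder examples, loss_disc and loss_ae_disc above are squared-error objectives: the discriminator is pushed toward 0 on real images and toward 1 on reconstructions, while the auto-encoder pushes disc(ae) back toward 0. A toy NumPy sketch with made-up discriminator scores:

import numpy as np

d_x = np.array([0.1, -0.05])  # disc on real images x (target 0)
d_ae = np.array([0.9, 0.8])   # disc on reconstructions ae (target 1)

loss_disc = np.mean(np.square(d_x) + np.square(d_ae - 1))
loss_ae_disc = np.mean(np.square(d_ae))  # what the auto-encoder minimizes
print(loss_disc, loss_ae_disc)  # low loss_disc: the disc currently wins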
Example #19
    def model(self, latent, depth, scales, adversary_lr, disc_layer_sizes):
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def encoder(x):
            return layers.encoder(x, scales, depth, latent, 'ae_enc')

        def decoder(h):
            return layers.decoder(h, scales, depth, self.colors, 'ae_dec')

        def discriminator(h):
            with tf.variable_scope('disc', reuse=tf.AUTO_REUSE):
                h = tf.layers.flatten(h)
                for size in [int(s) for s in disc_layer_sizes.split(',')]:
                    h = tf.layers.dense(h, size, tf.nn.leaky_relu)
                return tf.layers.dense(h, 1)

        encode = encoder(x)
        decode = decoder(h)
        ae = decoder(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        prior_samples = tf.random_normal(tf.shape(encode), dtype=encode.dtype)
        adversary_logit_latent = discriminator(encode)
        adversary_logit_prior = discriminator(prior_samples)
        adversary_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.zeros_like(adversary_logit_latent)))
        adversary_loss_prior = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_prior,
                labels=tf.ones_like(adversary_logit_prior)))
        autoencoder_loss_latents = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=adversary_logit_latent,
                labels=tf.ones_like(adversary_logit_latent)))

        def _accuracy(logits, label):
            # broadcast the Python bool `label` to a boolean tensor shaped like logits
            labels = tf.logical_and(label, tf.ones_like(logits, dtype=bool))
            correct = tf.equal(tf.greater(logits, 0), labels)
            return tf.reduce_mean(tf.to_float(correct))

        latent_accuracy = _accuracy(adversary_logit_latent, False)
        prior_accuracy = _accuracy(adversary_logit_prior, True)
        adversary_accuracy = (latent_accuracy + prior_accuracy) / 2

        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(adversary_loss_latents, 'loss_adv_latent')
        utils.HookReport.log_tensor(adversary_loss_prior, 'loss_adv_prior')
        utils.HookReport.log_tensor(autoencoder_loss_latents, 'loss_ae_latent')
        utils.HookReport.log_tensor(adversary_accuracy, 'adversary_accuracy')

        xops = classifiers.single_layer_classifier(tf.stop_gradient(encode), l,
                                                   self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        ae_vars = tf.global_variables('ae_')
        disc_vars = tf.global_variables('disc')
        xl_vars = tf.global_variables('single_layer_classifier')
        with tf.control_dependencies(update_ops):
            train_ae = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + autoencoder_loss_latents, var_list=ae_vars)
            train_disc = tf.train.AdamOptimizer(adversary_lr).minimize(
                adversary_loss_prior + adversary_loss_latents,
                var_list=disc_vars)
            train_xl = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                xloss, tf.train.get_global_step(), var_list=xl_vars)
        ops = train.AEOps(x,
                          h,
                          l,
                          encode,
                          decode,
                          ae,
                          tf.group(train_ae, train_disc, train_xl),
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops,
                interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(gen_images, [],
                                                  [tf.float32] * 4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        if FLAGS.dataset == 'lines32':
            batched = (n_interpolations, 32, n_images_per_interpolation, 32, 1)
            batched_interp = tf.transpose(tf.reshape(inter, batched),
                                          [0, 2, 1, 3, 4])
            mean_distance, mean_smoothness = tf.py_func(
                eval.line_eval, [batched_interp], [tf.float32, tf.float32])
            tf.summary.scalar('mean_distance', mean_distance)
            tf.summary.scalar('mean_smoothness', mean_smoothness)

        return ops
Example #20
    def model(self, latent, depth, scales, z_log_size, beta, num_latents):
        tf.set_random_seed(123)
        x = tf.placeholder(tf.float32,
                           [None, self.height, self.width, self.colors], 'x')
        l = tf.placeholder(tf.float32, [None, self.nclass], 'label')
        h = tf.placeholder(
            tf.float32,
            [None, self.height >> scales, self.width >> scales, latent], 'h')

        def decode_fn(h):
            with tf.variable_scope('vqvae', reuse=tf.AUTO_REUSE):
                h2 = tf.expand_dims(tf.layers.flatten(h), axis=1)
                h2 = tf.layers.dense(h2, self.hparams.hidden_size * num_latents)
                d = bneck.discrete_bottleneck(h2)
                y = layers.decoder(tf.reshape(d['dense'], tf.shape(h)),
                                   scales, depth, self.colors, 'ae_decoder')
                return y, d

        self.hparams.hidden_size = ((self.height >> scales) *
                                    (self.width >> scales) * latent)
        self.hparams.z_size = z_log_size
        self.hparams.num_residuals = 1
        self.hparams.num_blocks = 1
        self.hparams.beta = beta
        self.hparams.ema = True
        bneck = DiscreteBottleneck(self.hparams)
        encode = layers.encoder(x, scales, depth, latent, 'ae_encoder')
        decode = decode_fn(h)[0]
        ae, d = decode_fn(encode)
        loss_ae = tf.losses.mean_squared_error(x, ae)

        utils.HookReport.log_tensor(tf.sqrt(loss_ae) * 127.5, 'rmse')
        utils.HookReport.log_tensor(loss_ae, 'loss_ae')
        utils.HookReport.log_tensor(d['loss'], 'vqvae_loss')

        xops = classifiers.single_layer_classifier(
            tf.stop_gradient(d['dense']), l, self.nclass)
        xloss = tf.reduce_mean(xops.loss)
        utils.HookReport.log_tensor(xloss, 'classify_latent')

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops + [d['discrete']]):
            train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(
                loss_ae + xloss + d['loss'], tf.train.get_global_step())
        ops = train.AEOps(x, h, l, encode, decode, ae, train_op,
                          classify_latent=xops.output)

        n_interpolations = 16
        n_images_per_interpolation = 16

        def gen_images():
            return self.make_sample_grid_and_save(
                ops, interpolation=n_interpolations,
                height=n_images_per_interpolation)

        recon, inter, slerp, samples = tf.py_func(
            gen_images, [], [tf.float32]*4)
        tf.summary.image('reconstruction', tf.expand_dims(recon, 0))
        tf.summary.image('interpolation', tf.expand_dims(inter, 0))
        tf.summary.image('slerp', tf.expand_dims(slerp, 0))
        tf.summary.image('samples', tf.expand_dims(samples, 0))

        return ops