Exemple #1
0
# Generator-side losses.
# Absolute-difference loss between `xa` and `xa_`.
xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
# Sigmoid cross-entropy of the attribute logits `xb__logit_att` against `b`.
xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
# Negated mean of the `xb__logit_wgan` logits.
xb__loss_wgan = -tf.reduce_mean(xb__logit_wgan)
# Weighted sum: attribute term x10, reconstruction term x100.
g_loss = xb__loss_wgan + xb__loss_att * 10.0 + xa__loss_rec * 100.0

# optim
# Separate Adam optimizers (beta1=0.5, shared `lr`) for the 'D' and 'G'
# variable sets; each step only updates its own `var_list`.
# NOTE(review): `tl.trainable_variables(name)` presumably filters trainable
# variables by scope name -- confirm in `tl`.
d_var = tl.trainable_variables('D')
d_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss, var_list=d_var)

g_var = tl.trainable_variables('G')
g_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss, var_list=g_var)

# summary
# `tl.summary` receives a {tensor: name} dict and a scope; presumably it
# creates one summary per entry -- confirm in `tl`.
d_summary = tl.summary({
    x_gp: 'x_gp',
    x_wd: 'x_wd',
    xa_loss_att: 'xa_loss_att',
}, scope='D')

lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

g_summary = tl.summary({
    xb__loss_wgan: 'xb__loss_wgan',
    xb__loss_att: 'xb__loss_att',
    xa__loss_rec: 'xa__loss_rec',
}, scope='G')

# Fold the learning-rate summary into the D summary so both are produced by
# evaluating `d_summary`.
d_summary = tf.summary.merge([d_summary, lr_summary])

# sample
# Encoder -> decoder graph built with is_training=False, driven by
# `xa_sample` and `_b_sample`.
x_sample = Gdec(Genc(xa_sample, is_training=False), _b_sample, is_training=False)
Exemple #2
0
# Partition the trainable variables by name prefix: names starting with 'Enc'
# go to `enc_vars`, names starting with 'Dec' go to `dec_vars`; anything else
# is left out of both lists.
for var in tf.trainable_variables():
    if var.name.startswith('Enc'):
        enc_vars.append(var)
    elif var.name.startswith('Dec'):
        dec_vars.append(var)
# One momentum optimizer is shared by both losses: gradients are computed
# separately (each loss only w.r.t. its own variable list) and then applied
# together in a single op that also increments `global_step`.
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
enc_gvs = optimizer.compute_gradients(enc_loss, enc_vars)
dec_gvs = optimizer.compute_gradients(dec_loss, dec_vars)
train_op = optimizer.apply_gradients(enc_gvs + dec_gvs,
                                     global_step=global_step)

# summary
# `tl.summary` receives a {tensor: name} dict; presumably one summary per
# entry -- confirm in `tl`.
summary = tl.summary({
    img_rec_loss: 'img_rec_loss',
    zn_rec_loss: 'zn_rec_loss',
    zh_rec_loss: 'zh_rec_loss',
    nll_enc_loss: 'nll_enc_loss',
    nll_dec_loss: 'nll_dec_loss'
})

# sample
# TODO: compute running averages for different input batches respectively
# Note that every sampling graph below is built with is_training=True.
z_intp_sample, img_rec_sample = enc_dec(img, is_training=True)
# Decode 100 / 1000 draws from `normal_dist`.
img_sample = Dec(normal_dist.sample([100]), is_training=True)
fid_sample = Dec(normal_dist.sample([1000]), is_training=True)
if dataset_name == 'mnist':
    # Convert the grayscale decoder output to RGB for the 'mnist' dataset.
    fid_sample = tf.image.grayscale_to_rgb(fid_sample)

# Split the codes and the images into two equal halves along axis 0
# (tf.split's default axis).
z_intp_split, img_split = tf.split(z_intp_sample, 2), tf.split(img, 2)
img_intp_sample = [
    Dec((1 - i) * z_intp_split[0] + i * z_intp_split[1], is_training=True)
Exemple #3
0
''' graph '''
# inputs
# Image placeholder of shape [None, 128, 128, 3]; the affine map below sends
# 0 -> -1 and 255 -> 1.
x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
x = x_255 / 127.5 - 1
# Integer label placeholder of shape [None, att_dim].
att = tf.placeholder(tf.int64, shape=[None, att_dim])

# classify
logits = classifier(x, reuse=False)

# loss
# Sigmoid cross-entropy plus whatever regularization losses are registered
# in the graph.
reg_loss = tf.losses.get_regularization_loss()
loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
acc = mean_accuracy_multi_binary_label_with_logits(att, logits)

# summary
summary = tl.summary({loss: 'loss', acc: 'acc'})

# Scalar placeholder so the learning rate can be supplied per run.
lr_ = tf.placeholder(tf.float32, shape=[])

# optim
#with tf.variable_scope('Adam', reuse=tf.AUTO_REUSE):
step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

# test
# Second call to `classifier`, this time with training=False.
# NOTE(review): whether this call shares variables with the training call
# depends on `classifier`'s default for `reuse` -- confirm.
test_logits = classifier(x, training=False)
test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
# Scalar placeholder that is summarised as 'test_acc'; presumably fed with an
# accuracy averaged outside the graph -- confirm against the training loop.
mean_acc = tf.placeholder(tf.float32, shape=())
test_summary = tl.summary({mean_acc: 'test_acc'})
""" train """
''' init '''
# session
Exemple #4
0
# Placeholder for latent codes of shape [None, z_dim], used by the sampling
# graph below.
z_sample = tf.placeholder(tf.float32, [None, z_dim])

# encode & decode
z_mu, z_log_sigma_sq, img_rec = enc_dec(img)

# loss
# Mean-squared reconstruction error between the input and its reconstruction.
rec_loss = tf.losses.mean_squared_error(img, img_rec)
# -mean(0.5 * (1 + log_sigma_sq - mu^2 - exp(log_sigma_sq))).
# NOTE(review): this is the KL(N(mu, sigma^2) || N(0, 1)) expression provided
# `z_log_sigma_sq` really is the log-variance -- confirm in `enc_dec`.
kld_loss = -tf.reduce_mean(
    0.5 * (1 + z_log_sigma_sq - z_mu**2 - tf.exp(z_log_sigma_sq)))
# `beta` weights the KL term against the reconstruction term.
loss = rec_loss + kld_loss * beta

# optim
step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(loss)

# summary
summary = tl.summary({rec_loss: 'rec_loss', kld_loss: 'kld_loss'})

# sample
# Inference-mode graphs (is_training=False): reconstruction of `img` and
# decoding of externally supplied codes `z_sample`.
_, _, img_rec_sample = enc_dec(img, is_training=False)
img_sample = Dec(z_sample, is_training=False)

# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# session
sess = tl.session()

# saver
# max_to_keep=1: only the most recent checkpoint is retained.
saver = tf.train.Saver(max_to_keep=1)
# optim
# Three Adam optimizers (beta1=0.5, shared `lr`), each restricted to its own
# variable list.
# NOTE(review): `tl.trainable_variables(name)` presumably filters trainable
# variables by scope name -- confirm in `tl`.
d_var = tl.trainable_variables('D')
d_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss, var_list=d_var)

g_var = tl.trainable_variables('G')
g_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss, var_list=g_var)

# Extra step that minimises `g_b_att_loss` over the 'Gdec' variables only.
g_dec_var = tl.trainable_variables('Gdec')
g_dec_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_b_att_loss,
                                                            var_list=g_dec_var)

# summary
# Each `tl.summary` call receives a {tensor: name} dict and a scope.
d_summary = tl.summary(
    {
        d_loss_gan: 'd_loss_gan',
        gp: 'gp',
        xa_loss_att: 'xa_loss_att',
    },
    scope='D')

lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

g_summary = tl.summary(
    {
        xb__loss_gan: 'xb__loss_gan',
        xb__loss_att: 'xb__loss_att',
        xa__loss_rec: 'xa__loss_rec',
    },
    scope='G')

g_dec_summary = tl.summary({
    def build_graph(self):
        """Build the training and sampling graph and store its handles on self.

        Reads model hyper-parameters from ``self.config`` and the input batch
        from ``self.data_loader.batch_op``. Sets the following attributes:
        ``Genc``, ``Gdec``, ``Gstu``, ``D`` (partially-applied network
        constructors), the placeholders ``xa_sample``, ``_b_sample``,
        ``raw_b_sample`` and ``lr``, the train ops ``d_step`` / ``g_step``,
        the summaries ``d_summary`` / ``g_summary``, and the sampling output
        ``x_sample``.
        """
        # Import the model module named in the config; a truthy `fromlist`
        # makes __import__ return the named submodule itself.
        package = __import__("components." + self.config["modelScriptName"],
                             fromlist=True)
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        GstuClass = getattr(package, 'Gstu')
        # NOTE(review): `DClass` is looked up but never used below -- the
        # discriminator is built from `network_models.D` instead. Confirm
        # which one is intended.
        DClass = getattr(package, 'D')

        self.Genc = partial(GencClass,
                            dim=self.config["GConvDim"],
                            n_layers=self.config["GLayerNum"],
                            multi_inputs=1)

        # NOTE(review): the decoder is sized from the "DConvDim"/"DLayerNum"
        # keys rather than the G* keys used for the encoder -- confirm this
        # is intended.
        self.Gdec = partial(GdecClass,
                            dim=self.config["DConvDim"],
                            n_layers=self.config["DLayerNum"],
                            shortcut_layers=self.config["skipNum"],
                            inject_layers=self.config["injectLayers"],
                            one_more_conv=self.config["oneMoreConv"])

        self.Gstu = partial(GstuClass,
                            dim=self.config["stuDim"],
                            n_layers=self.config["skipNum"],
                            inject_layers=self.config["skipNum"],
                            kernel_size=3,
                            pass_state='stu')

        n_att = len(self.config["selectedAttrs"])
        self.D = partial(network_models.D,
                         n_att=n_att,
                         dim=self.config["DConvDim"],
                         fc_dim=self.config["DFcDim"],
                         n_layers=self.config["DLayerNum"])

        # inputs

        # `xa` / `a`: first and second element of the loader's batch op.
        xa = self.data_loader.batch_op[0]
        a = self.data_loader.batch_op[1]
        # Target labels: a shuffled copy of the batch's own labels.
        b = tf.random_shuffle(a)
        # Map labels x -> (2x - 1) * thresInt (0/1 labels become -/+thresInt).
        _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
        _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

        # Placeholders used only by the sampling graph, plus the learning rate.
        self.xa_sample = tf.compat.v1.placeholder(
            tf.float32,
            shape=[None, self.config["imsize"], self.config["imsize"], 3])
        self._b_sample = tf.compat.v1.placeholder(tf.float32,
                                                  shape=[None, n_att])
        self.raw_b_sample = tf.compat.v1.placeholder(tf.float32,
                                                     shape=[None, n_att])
        self.lr = tf.compat.v1.placeholder(dtype=tf.float32, shape=[])

        # generate
        # Both paths are conditioned on a label *difference*: `_b - _a` for
        # the edited output, and `_a - _a` (all zeros) for the reconstruction.
        z = self.Genc(xa)
        zb = self.Gstu(z, _b - _a)
        xb_ = self.Gdec(zb, _b - _a)
        # Ops created inside this context only run after `xb_` is computed.
        with tf.control_dependencies([xb_]):
            za = self.Gstu(z, _a - _a)
            xa_ = self.Gdec(za, _a - _a)

        # discriminate
        xa_logit_gan, xa_logit_att = self.D(xa)
        xb__logit_gan, xb__logit_att = self.D(xb_)

        # D loss: negated difference of mean logits on `xa` vs `xb_`, plus
        # the gradient penalty (x10) and the attribute loss on `xa`.
        wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
        d_loss_gan = -wd
        gp = network_models.gradient_penalty(self.D, xa, xb_)
        xa_loss_att = tf.compat.v1.losses.sigmoid_cross_entropy(
            a, xa_logit_att)
        d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

        # G loss: adversarial term, attribute term on `xb_` against `b` (x10)
        # and absolute-difference reconstruction term weighted by "recWeight".
        xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
        xb__loss_att = tf.compat.v1.losses.sigmoid_cross_entropy(
            b, xb__logit_att)
        xa__loss_rec = tf.compat.v1.losses.absolute_difference(xa, xa_)
        g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * self.config[
            "recWeight"]

        # Separate Adam optimizers (beta1=0.5), each restricted to its own
        # variable list.
        d_var = tl.trainable_variables('D')
        self.d_step = tf.compat.v1.train.AdamOptimizer(
            self.lr, beta1=0.5).minimize(d_loss, var_list=d_var)
        g_var = tl.trainable_variables('G')
        self.g_step = tf.compat.v1.train.AdamOptimizer(
            self.lr, beta1=0.5).minimize(g_loss, var_list=g_var)

        d_summary = tl.summary(
            {
                d_loss_gan: 'd_loss_gan',
                gp: 'gp',
                xa_loss_att: 'xa_loss_att',
            },
            scope='D')

        lr_summary = tl.summary({self.lr: 'lr'}, scope='Learning_Rate')

        self.g_summary = tl.summary(
            {
                xb__loss_gan: 'xb__loss_gan',
                xb__loss_att: 'xb__loss_att',
                xa__loss_rec: 'xa__loss_rec',
            },
            scope='G')

        # The learning-rate summary is merged into the D summary.
        self.d_summary = tf.summary.merge([d_summary, lr_summary])

        # sample
        # Inference-mode graph, conditioned on the difference between the
        # requested labels and the raw labels.
        test_label = self._b_sample - self.raw_b_sample
        self.x_sample = self.Gdec(self.Gstu(self.Genc(self.xa_sample,
                                                      is_training=False),
                                            test_label,
                                            is_training=False),
                                  test_label,
                                  is_training=False)
Exemple #7
0
    def train(self):
        """Build the training graph and run the adversarial training loop.

        Everything is driven by ``self.config``: data location, network
        sizes, optimizer settings, and the checkpoint / summary / sample
        output locations. Each iteration runs ``dStep`` discriminator updates
        followed by one generator update, and periodically writes a
        checkpoint and a grid of sample images.

        A final checkpoint is always written and the session is always
        closed when the method exits, whether training finished, failed or
        was interrupted. Errors derived from ``Exception`` are printed and
        swallowed; ``KeyboardInterrupt`` / ``SystemExit`` propagate after the
        final checkpoint has been written.
        """

        ckpt_dir = self.config["projectCheckpoints"]
        epoch = self.config["totalEpoch"]  # total number of epochs to run
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        n_att = len(self.config["selectedAttrs"])

        # A non-negative "threads" value selects an explicitly configured
        # session; otherwise the project default from `tl.session()` is used.
        if self.config["threads"] >= 0:
            cpu_config = tf.ConfigProto(
                intra_op_parallelism_threads=self.config["threads"] // 2,
                inter_op_parallelism_threads=self.config["threads"] // 2,
                device_count={'CPU': self.config["threads"]})
            cpu_config.gpu_options.allow_growth = True
            sess = tf.Session(config=cpu_config)
        else:
            sess = tl.session()

        data_loader = Celeba(self.config["dataset_path"],
                             self.config["selectedAttrs"],
                             self.config["imsize"],
                             self.config["batchSize"],
                             part='train',
                             sess=sess,
                             crop=(self.config["imCropSize"] > 0))

        # Un-shuffled validation loader; one batch of `sampleNum` images is
        # drawn from it once and reused for every sample grid.
        val_loader = Celeba(self.config["dataset_path"],
                            self.config["selectedAttrs"],
                            self.config["imsize"],
                            self.config["sampleNum"],
                            part='val',
                            shuffle=False,
                            sess=sess,
                            crop=(self.config["imCropSize"] > 0))

        # Network constructors are resolved by name from the configured
        # scripts; a truthy `fromlist` makes __import__ return the submodule.
        package = __import__("components." + self.config["modelScriptName"],
                             fromlist=True)
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        DClass = getattr(package, 'D')
        GP = getattr(package, "gradient_penalty")

        package = __import__("components.STU." + self.config["stuScriptName"],
                             fromlist=True)
        GstuClass = getattr(package, 'Gstu')

        Genc = partial(GencClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       multi_inputs=1)

        Gdec = partial(GdecClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       shortcut_layers=self.config["skipNum"],
                       inject_layers=self.config["injectLayers"],
                       one_more_conv=self.config["oneMoreConv"])

        Gstu = partial(GstuClass,
                       dim=self.config["stuDim"],
                       n_layers=self.config["skipNum"],
                       inject_layers=self.config["skipNum"],
                       kernel_size=self.config["stuKS"],
                       norm=None,
                       pass_state='stu')

        D = partial(DClass,
                    n_att=n_att,
                    dim=self.config["DConvDim"],
                    fc_dim=self.config["DFcDim"],
                    n_layers=self.config["DLayerNum"])

        # inputs

        xa = data_loader.batch_op[0]
        a = data_loader.batch_op[1]
        # Target labels: a shuffled copy of the batch's own labels.
        b = tf.random_shuffle(a)
        # Map labels x -> (2x - 1) * thresInt.
        _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
        _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

        xa_sample = tf.placeholder(
            tf.float32,
            shape=[None, self.config["imsize"], self.config["imsize"], 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        lr = tf.placeholder(tf.float32, shape=[])

        # generate
        # Both paths are conditioned on a label difference: `_b - _a` for the
        # edited output, `_a - _a` (all zeros) for the reconstruction.
        z = Genc(xa)
        zb = Gstu(z, _b - _a)
        xb_ = Gdec(zb, _b - _a)
        # Ops created inside this context only run after `xb_` is computed.
        with tf.control_dependencies([xb_]):
            za = Gstu(z, _a - _a)
            xa_ = Gdec(za, _a - _a)

        # discriminate
        xa_logit_gan, xa_logit_att = D(xa)
        xb__logit_gan, xb__logit_att = D(xb_)

        wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
        d_loss_gan = -wd
        gp = GP(D, xa, xb_)
        xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
        d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

        xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
        xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
        xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
        g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * self.config[
            "recWeight"]

        d_var = tl.trainable_variables('D')
        d_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
        g_var = tl.trainable_variables('G')
        g_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

        d_summary = tl.summary(
            {
                d_loss_gan: 'd_loss_gan',
                gp: 'gp',
                xa_loss_att: 'xa_loss_att',
            },
            scope='D')

        lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

        g_summary = tl.summary(
            {
                xb__loss_gan: 'xb__loss_gan',
                xb__loss_att: 'xb__loss_att',
                xa__loss_rec: 'xa__loss_rec',
            },
            scope='G')

        # The learning-rate summary is merged into the D summary.
        d_summary = tf.summary.merge([d_summary, lr_summary])

        # sample
        test_label = _b_sample - raw_b_sample
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               sess.graph)

        # initialization
        if self.config["mode"] == "finetune":
            print("Continute train the model")
            tl.load_checkpoint(ckpt_dir, sess)
            print("Load previous model successfully!")
        else:
            print('Initializing all parameters...')
            sess.run(tf.global_variables_initializer())

        # Progress markers read by the ``finally`` block below. They are
        # bound before the ``try`` so the final checkpoint can still be
        # written (and the session closed) when the loop body never runs,
        # e.g. when resuming an already finished run or when set-up fails.
        cur_epoch, it_in_epoch, it_per_epoch = 0, 0, 0

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt
                                 ]  # the first is for reconstruction
            # One extra label set per attribute, with that attribute flipped.
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            # One iteration consumes (n_d + 1) batches: n_d for D, one for G.
            it_per_epoch = len(data_loader) // (self.config["batchSize"] *
                                                (n_d + 1))
            max_it = epoch * it_per_epoch

            print("Start to train the graph!")
            for it in range(sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    sess.run(update_cnt)

                    # which epoch
                    cur_epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate: divided by 10 every `lrDecayEpoch` epochs
                    lr_ipt = lr_base / (10**(cur_epoch // lrDecayEpoch))

                    # train D
                    d_summary_opt = None
                    for d_it in range(n_d):
                        d_summary_opt, _ = sess.run([d_summary, d_step],
                                                    feed_dict={lr: lr_ipt})
                    # Only the last D update of the iteration is logged;
                    # nothing to log when no D update ran (dStep == 0).
                    if d_summary_opt is not None:
                        summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = sess.run([g_summary, g_step],
                                                feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (cur_epoch, it_in_epoch, it_per_epoch, t))

                    # save (every `save_freq` iterations, or once per epoch
                    # when `save_freq` is falsy)
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            sess,
                            '%s/Epoch_(%d).ckpt' % (ckpt_dir, cur_epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample (same frequency rule as saving)
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        # Grid columns: input images, a narrow separator
                        # filled with -1.0, then one column per label set.
                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                # Rescale the flipped attribute from
                                # `thres_int` to `test_int` magnitude.
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                sess.run(x_sample,
                                         feed_dict={
                                             xa_sample: xa_sample_ipt,
                                             _b_sample: _b_sample_ipt,
                                             raw_b_sample: raw_b_sample_ipt
                                         }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                                for nnn in range(last_images.shape[0]):
                                    # Horizontal bar (always drawn); a
                                    # vertical bar is added only when the
                                    # attribute value is positive.
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        # Concatenate all columns along axis 2.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(
                            im.immerge(sample, n_sample, 1),
                            '%s/Epoch_(%d)_(%dof%d).jpg' %
                            (self.config["projectSamples"], cur_epoch,
                             it_in_epoch, it_per_epoch))
        except Exception:
            # Report the failure but still fall through to the final save.
            traceback.print_exc()
        finally:
            # Always write a last checkpoint; close the session even if that
            # save itself fails.
            try:
                save_path = saver.save(
                    sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                    (ckpt_dir, cur_epoch, it_in_epoch, it_per_epoch))
                print('Model is saved at %s!' % save_path)
            finally:
                sess.close()
# Discriminator learning rate: exponential decay of `lr` driven by
# `D_global_step`; staircase=True makes the decay happen in discrete steps.
D_learning_rate = tf.train.exponential_decay(lr, D_global_step, decay_steps, 
                                             decay_rate, staircase=True)

# Ops registered in the UPDATE_OPS collection.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

G_opt = tf.train.AdamOptimizer(learning_rate=G_learning_rate, beta1=0.5)
D_opt = tf.train.AdamOptimizer(learning_rate=D_learning_rate, beta1=0.5)
# Both train steps are created under a control dependency on `update_ops`,
# so those ops run before either step; each step also increments its own
# global-step counter.
with tf.control_dependencies(update_ops):
    G_step = G_opt.minimize(G_loss, var_list=g_vars, global_step=G_global_step)
    D_step = D_opt.minimize(D_loss, var_list=d_vars, global_step=D_global_step)

# summary
# Each `tl.summary` call receives a {tensor: name} dict.
G_summary = tl.summary({rec_loss_a: 'rec_loss_a',
                      rec_loss_b: 'rec_loss_b',
                      rec_loss: 'rec_loss',
                      kld_loss_a: 'kld_loss_a',
                      kld_loss_b: 'kld_loss_b',
                      kld_loss: 'kld_loss',
                      adv_loss: 'adv_loss',
                      G_loss: 'G_loss'})
D_summary = tl.summary({wd_loss: 'wd_loss',
                      gp_loss: 'gp_loss',
                      D_loss: 'D_loss'})


# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# number of points to plot during training (capped at 10000 per domain)
n_s = np.min([len(source_train_data),10000]) 
n_t = np.min([len(target_train_data),10000])    
Exemple #9
0
# optim
# Partition the trainable variables by name prefix: 'Enc...' -> `enc_vars`,
# 'Dec...' -> `dec_vars`; anything else is left out of both lists.
enc_vars = []
dec_vars = []
for var in tf.trainable_variables():
    if var.name.startswith('Enc'):
        enc_vars.append(var)
    elif var.name.startswith('Dec'):
        dec_vars.append(var)
# One momentum optimizer is shared by both losses: gradients are computed
# separately (each loss only w.r.t. its own variable list) and applied
# together in a single op that also increments `global_step`.
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)
enc_gvs = optimizer.compute_gradients(enc_loss, enc_vars)
dec_gvs = optimizer.compute_gradients(dec_loss, dec_vars)
train_op = optimizer.apply_gradients(enc_gvs + dec_gvs, global_step=global_step)

# summary
# `tl.summary` receives a {tensor: name} dict.
summary = tl.summary({img_rec_loss: 'img_rec_loss',
                      zn_rec_loss: 'zn_rec_loss', zh_rec_loss: 'zh_rec_loss',
                      vrec_loss: 'vrec_loss', vkld_loss: 'vkld_loss', nll_enc_loss: 'nll_enc_loss'})

# sample
# TODO: compute running averages for different input batches respectively
# Note that every sampling graph below is built with is_training=True.
z_intp_sample, _, img_rec_sample = enc_dec(img, is_training=True)
# Decode 100 / 1000 draws from `normal_dist`.
img_sample = Dec(normal_dist.sample([100]), is_training=True)
fid_sample = Dec(normal_dist.sample([1000]), is_training=True)
if dataset_name == 'mnist':
    # Convert the grayscale decoder output to RGB for the 'mnist' dataset.
    fid_sample = tf.image.grayscale_to_rgb(fid_sample)

# Split codes and images into two equal halves along axis 0 (tf.split's
# default axis), decode 9 evenly spaced linear blends between the two code
# halves, then put the two image halves at either end and concatenate
# everything along axis 2.
z_intp_split, img_split = tf.split(z_intp_sample, 2), tf.split(img, 2)
img_intp_sample = [Dec((1 - i) * z_intp_split[0] + i * z_intp_split[1], is_training=True) for i in np.linspace(0, 1, 9)]
img_intp_sample = [img_split[0]] + img_intp_sample + [img_split[1]]
img_intp_sample = tf.concat(img_intp_sample, 2)
Exemple #10
0
def train_on_fake(dataset, c_dim, result_dir, gpu_id, use_real=0, epoch_=200):
    """Train a classifier on synthetic images and track its test accuracy.

    Synthetic batches are read from ``<result_dir>/synthetic_tfrecord``.
    When ``use_real == 1`` every batch is half synthetic and half real CelebA
    data. Summaries and checkpoints are written below ``result_dir`` (with a
    ``_real`` suffix when real data is mixed in). Training resumes from the
    latest checkpoint if one can be loaded.

    Args:
        dataset: ``'CelebA'`` (multi-label, sigmoid loss) or ``'RaFD'``
            (one-hot, softmax loss).
        c_dim: Number of classifier outputs / label dimension.
        result_dir: Directory holding the synthetic tfrecord; also receives
            the summaries and checkpoints.
        gpu_id: Index of the GPU the graph is placed on.
        use_real: ``1`` to mix real CelebA batches into training.
        epoch_: Number of training epochs.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    # Fail fast with a clear message; an unsupported name would otherwise
    # surface later as an unbound-variable error.
    if dataset not in ('CelebA', 'RaFD'):
        raise ValueError(
            "Unsupported dataset %r (expected 'CelebA' or 'RaFD')" %
            (dataset,))

    # ---- param ----
    batch_size = 64
    batch_size_fake = batch_size
    lr = 0.0002

    # ---- data ----
    if use_real == 1:
        print('======Using real data======')
        # Split each batch evenly between real and synthetic samples.
        batch_size_real = batch_size // 2
        batch_size_fake = batch_size - batch_size_real
        train_tfrecord_path_real = './tfrecords/celeba_tfrecord_train'
        train_data_pool_real = tl.TfrecordData(train_tfrecord_path_real, batch_size_real, shuffle=True)
    train_tfrecord_path_fake = os.path.join(result_dir, 'synthetic_tfrecord')
    train_data_pool_fake = tl.TfrecordData(train_tfrecord_path_fake, batch_size_fake, shuffle=True)
    if dataset == 'CelebA':
        test_tfrecord_path = './tfrecords/celeba_tfrecord_test'
    else:  # 'RaFD'
        test_tfrecord_path = './tfrecords/rafd_test'
    test_data_pool = tl.TfrecordData(test_tfrecord_path, 120)
    att_dim = c_dim

    # ---- graphs ----
    with tf.device('/gpu:{}'.format(gpu_id)):

        # models
        classifier = models.classifier

        # inputs: the affine map below sends 0 -> -1 and 255 -> 1
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            att = tf.placeholder(tf.int64, shape=[None, att_dim])
        else:  # 'RaFD'
            att = tf.placeholder(tf.float32, shape=[None, att_dim])

        # classify
        logits = classifier(x, att_dim=att_dim, reuse=False)

        # loss
        reg_loss = tf.losses.get_regularization_loss()
        if dataset == 'CelebA':
            loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_multi_binary_label_with_logits(att, logits)
        else:  # 'RaFD'
            loss = tf.losses.softmax_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_one_hot_label_with_logits(att, logits)

        # Scalar placeholder so the decayed learning rate can be fed per step.
        lr_ = tf.placeholder(tf.float32, shape=[])

        # optim
        step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

        # test: second classifier call with training=False
        test_logits = classifier(x, att_dim=att_dim, training=False)
        if dataset == 'CelebA':
            test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
        else:  # 'RaFD'
            test_acc = mean_accuracy_one_hot_label_with_logits(att, test_logits)
        # Fed with the accuracy averaged over the test batches in Python.
        mean_acc = tf.placeholder(tf.float32, shape=())

    # summary
    summary = tl.summary({loss: 'loss', acc: 'acc'})
    test_summary = tl.summary({mean_acc: 'test_acc'})

    # ---- train: init ----
    # session
    sess = tf.Session()
    # iteration counter
    it_cnt, update_cnt = tl.counter()
    # saver (max_to_keep=None keeps every checkpoint)
    saver = tf.train.Saver(max_to_keep=None)
    # summary writer
    sum_dir = os.path.join(result_dir, 'summaries_train_on_fake')
    if use_real == 1:
        sum_dir += '_real'
    summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)

    # initialization
    ckpt_dir = os.path.join(result_dir, 'checkpoints_train_on_fake')
    if use_real == 1:
        ckpt_dir += '_real'
    if not os.path.exists(ckpt_dir):
        # makedirs also creates missing parent directories.
        os.makedirs(ckpt_dir)
    if not tl.load_checkpoint(ckpt_dir, sess):
        sess.run(tf.global_variables_initializer())

    # Evaluation settings depend only on the dataset, so resolve them once.
    test_it = 100 if dataset == 'CelebA' else 7
    test_key = 'class' if dataset == 'CelebA' else 'attr'

    # ---- train: loop ----
    try:
        batch_epoch = len(train_data_pool_fake) // batch_size
        max_it = epoch_ * batch_epoch
        for it in range(sess.run(it_cnt), max_it):
            # Polynomial decay (power 0.75) that starts after the 8th epoch.
            bth = it//batch_epoch - 8
            lr__ = lr*(1-max(bth, 0)/epoch_)**0.75
            if it % batch_epoch == 0:
                print('======learning rate:', lr__, '======')
            sess.run(update_cnt)

            # which epoch
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            x_255_ipt, att_ipt = train_data_pool_fake.batch(['img', 'attr'])
            if dataset == 'RaFD':
                att_ipt = ToOnehot(att_ipt, att_dim)
            if use_real == 1:
                # Append the real half of the batch to the synthetic half.
                x_255_ipt_real, att_ipt_real = train_data_pool_real.batch(['img', 'class'])
                x_255_ipt = np.concatenate([x_255_ipt, x_255_ipt_real])
                att_ipt = np.concatenate([att_ipt, att_ipt_real])
            summary_opt, _ = sess.run([summary, step], feed_dict={x_255: x_255_ipt, att: att_ipt, lr_:lr__})
            summary_writer.add_summary(summary_opt, it)

            # display (once per epoch)
            if (it + 1) % batch_epoch == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, it_epoch, batch_epoch))

            # save (every 50 epochs)
            if (it + 1) % (batch_epoch * 50) == 0:
                save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_epoch, batch_epoch))
                print('Model saved in file: % s' % save_path)

            # evaluate on the test pool every 100 iterations
            if it % 100 == 0:
                test_acc_opt_list = []
                for _ in range(test_it):
                    x_255_ipt, att_ipt = test_data_pool.batch(['img', test_key])
                    if dataset == 'RaFD':
                        att_ipt = ToOnehot(att_ipt, att_dim)

                    test_acc_opt = sess.run(test_acc, feed_dict={x_255: x_255_ipt, att: att_ipt})
                    test_acc_opt_list.append(test_acc_opt)
                test_summary_opt = sess.run(test_summary, feed_dict={mean_acc: np.mean(test_acc_opt_list)})
                summary_writer.add_summary(test_summary_opt, it)

    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        # Flush and close the event file; close the session even if that
        # fails.
        try:
            summary_writer.close()
        finally:
            sess.close()
Exemple #11
0
    g_f_tree_losses[i] * lambdas[i] * layer_mask[i]
    for i in range(len(lambdas))
])
# Total generator loss: base generator loss plus the tree loss.
g_loss = g_f_loss + g_tree_loss

# optims
# `optim` is an optimizer constructor defined elsewhere; D and G use separate
# learning rates and separate variable lists.
# NOTE(review): `tl.trainable_variables(includes=...)` presumably filters
# variables by name -- confirm in `tl`.
d_step = optim(learning_rate=lr_d).minimize(
    d_loss, var_list=tl.trainable_variables(includes='D'))
g_step = optim(learning_rate=lr_g).minimize(
    g_loss, var_list=tl.trainable_variables(includes='G'))

# summaries
# 'd_loss' is logged as the sum of the real and fake terms.
d_summary = tl.summary(
    {
        d_r_loss: 'd_r_loss',
        d_f_loss: 'd_f_loss',
        d_r_loss + d_f_loss: 'd_loss',
        gp: 'gp'
    },
    scope='D')
# One summary per entry of `d_f_tree_losses`, named by its index.
tmp = {l: 'd_f_tree_loss_%d' % i for i, l in enumerate(d_f_tree_losses)}
# Only when `att` is non-empty is the first real tree loss logged as well.
if att != '':
    tmp.update({d_r_tree_losses[0]: 'd_r_tree_loss_0'})
d_tree_summary = tl.summary(tmp, scope='D_Tree')
d_summary = tf.summary.merge([d_summary, d_tree_summary])

g_summary = tl.summary({g_f_loss: 'g_f_loss'}, scope='G')
# One summary per entry of `g_f_tree_losses`, named by its index.
g_tree_summary = tl.summary(
    {l: 'g_f_tree_loss_%d' % i
     for i, l in enumerate(g_f_tree_losses)},
    scope='G_Tree')
g_summary = tf.summary.merge([g_summary, g_tree_summary])
Exemple #12
0
        d_step = tf.group(*(tf.assign(var, tf.clip_by_value(var, -0.01, 0.01))
                            for var in tl.trainable_variables(includes='D')))

# g step
# `optim` is an optimizer constructor defined elsewhere.
# NOTE(review): `tl.trainable_variables(includes=...)` presumably filters
# variables by name -- confirm in `tl`.
g_step = optim(learning_rate=lr_g).minimize(
    g_loss, var_list=tl.trainable_variables(includes='G'))

if vgan:
    # When `vgan` is set, `g_step` becomes an assign op that copies `fake`
    # into `old_fake`; the control dependency makes the optimizer update run
    # first.
    with tf.control_dependencies([g_step]):
        g_step = tf.assign(old_fake, fake)

# summaries
# 'd_loss' is logged as the sum of the real and fake terms.
d_summary = tl.summary(
    {
        d_r_loss: 'd_r_loss',
        d_f_loss: 'd_f_loss',
        d_r_loss + d_f_loss: 'd_loss',
        gp: 'gp'
    },
    scope='D')
g_summary = tl.summary({
    g_f_loss: 'g_f_loss',
    vgan_loss: 'vgan_loss'
},
                       scope='G')

# sample
# Inference-mode generator driven by a latent placeholder of shape
# [None, z_dim].
z_sample = tf.placeholder(tf.float32, [None, z_dim])
f_sample = G(z_sample, is_training=False)

# ==============================================================================
# =                                    train                                   =
Exemple #13
0
# Fake-sample term of the discriminator loss: mean of
# conjugate_fn(activation_fn(f_output)).
d_f_loss = tf.reduce_mean(conjugate_fn(activation_fn(f_output)))
d_loss = d_r_loss + d_f_loss
# Two alternative generator objectives: with `tricky_G` the generator
# minimises the negated mean activation of the fake outputs; otherwise it
# minimises the negated fake term of the discriminator loss.
if tricky_G:
    g_loss = -tf.reduce_mean(activation_fn(f_output))
else:
    g_loss = -d_f_loss

# optims
# Separate Adam optimizers (beta1=0.5, shared `lr`), each restricted to its
# own variable list.
d_var = tl.trainable_variables('D')
g_var = tl.trainable_variables('G')
d_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(d_loss, var_list=d_var)
g_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(g_loss, var_list=g_var)

# summaries
# The negated discriminator loss is logged under a name built from
# `divergence`.
d_summary = tl.summary({d_r_loss: 'd_r_loss',
                        d_f_loss: 'd_f_loss',
                        -d_loss: '%s_diverngence' % divergence}, scope='D')
g_summary = tl.summary({g_loss: 'g_loss'}, scope='G')

# sample
# Inference-mode generator output for `z`.
f_sample = G(z, is_training=False)


# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# session
sess = tl.session()

# saver
Exemple #14
0
# Discriminator step: Adam (beta1=0.5) minimising `discriminator_loss` over
# `discriminator_var` only.
discriminator_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
    discriminator_loss, var_list=discriminator_var)

# NOTE(review): `tl.trainable_variables('G')` presumably returns the
# trainable variables under scope 'G' -- confirm in `tl`.
generator_var = tl.trainable_variables('G')
# Generator step: Adam (beta1=0.5) minimising `generator_loss` over
# `generator_var` only.
generator_step = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
    generator_loss, var_list=generator_var)

# Summaries for each network's loss terms; each `tl.summary` call receives a
# {tensor: name} dict and a scope.

discriminator_summary = tl.summary(
    {
        discriminator_loss_gan: 'discriminator_loss_gan',
        gradient_loss: 'gradient_loss',
        image_attr_loss: 'image_attr_loss',
    },
    scope='D')

lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

generator_summary = tl.summary(
    {
        discriminator_label_loss: 'discriminator_label_loss',
        label_decoder_loss: 'label_decoder_loss',
        image_decoder_loss: 'image_decode_loss',
    },
    scope='G')

# The learning-rate summary is merged into the discriminator summary.
discriminator_summary = tf.summary.merge([discriminator_summary, lr_summary])