Example #1
def D_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    xb, _, ms, _ = G(xa, b_ - a_)

    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb_logit_gan, xb_logit_att = D(xb)

    # discriminator losses
    xa_loss_gan, xb_loss_gan = d_loss_fn(xa_logit_gan, xb_logit_gan)
    gp = tfprob.gradient_penalty(lambda x: D(x)[0], xa, xb,
                                 args.gradient_penalty_mode,
                                 args.gradient_penalty_sample_mode)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    reg_loss = tf.reduce_sum(D.func.reg_losses)

    loss = (xa_loss_gan + xb_loss_gan + gp * args.d_gradient_penalty_weight +
            xa_loss_att * args.d_attribute_loss_weight + reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=D.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/D' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = [
            tl.summary_v2(
                {
                    'loss_gan': xa_loss_gan + xb_loss_gan,
                    'gp': gp,
                    'xa_loss_att': xa_loss_att,
                    'reg_loss': reg_loss
                },
                step=step_cnt,
                name='D'),
            tl.summary_v2({'lr': lr}, step=step_cnt, name='learning_rate')
        ]

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
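
Note: d_loss_fn and g_loss_fn are defined elsewhere in this codebase and are not shown in these snippets. A minimal WGAN-style sketch of what they might look like (an assumption, not the repository's actual implementation):

import tensorflow as tf  # TF 1.x

def wgan_d_loss_fn(real_logit, fake_logit):
    # Illustrative sketch: Wasserstein critic losses, push real logits up
    # and fake logits down.
    return -tf.reduce_mean(real_logit), tf.reduce_mean(fake_logit)

def wgan_g_loss_fn(fake_logit):
    # Generator tries to raise the critic score of generated samples.
    return -tf.reduce_mean(fake_logit)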
Example #2
    xb__loss_att: 'xb__loss_att',
    xa__loss_rec: 'xa__loss_rec',
}, scope='G')

d_summary = tf.summary.merge([d_summary, lr_summary])

# sample
x_sample = Gdec(Genc(xa_sample, is_training=False), _b_sample, is_training=False)


# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# iteration counter
it_cnt, update_cnt = tl.counter()

# saver
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
    tl.load_checkpoint(ckpt_dir, sess)
except:
    sess.run(tf.global_variables_initializer())
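
tl.counter(), used here and in the other examples, comes from the same helper library; a rough sketch, assuming it simply wraps a non-trainable step variable and an increment op:

import tensorflow as tf  # TF 1.x

def counter(start=0, scope='counter'):
    # Illustrative sketch: non-trainable int64 step variable plus an op
    # that increments it by one each time it is run.
    with tf.variable_scope(scope):
        cnt = tf.get_variable('cnt', shape=[], dtype=tf.int64,
                              initializer=tf.constant_initializer(start),
                              trainable=False)
        update_cnt = tf.assign_add(cnt, 1)
    return cnt, update_cnt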
Example #3
def G_train_graph():
    # ======================================
    # =                 graph              =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    xb, _, ms, ms_multi = G(xa, b_ - a_)

    # discriminate
    xb_logit_gan, xb_logit_att = D(xb)

    # generator losses
    xb_loss_gan = g_loss_fn(xb_logit_gan)
    xb_loss_att = tf.losses.sigmoid_cross_entropy(b, xb_logit_att)
    spasity_loss = tf.reduce_sum([
        tf.reduce_mean(m) * w
        for m, w in zip(ms, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    ])
    full_overlap_mask_pair_loss, non_overlap_mask_pair_loss = module.overlap_loss_fn(
        ms_multi, args.att_names)
    reg_loss = tf.reduce_sum(G.func.reg_losses)

    loss = (
        xb_loss_gan + xb_loss_att * args.g_attribute_loss_weight +
        spasity_loss * args.g_spasity_loss_weight +
        full_overlap_mask_pair_loss * args.g_full_overlap_mask_pair_loss_weight +
        non_overlap_mask_pair_loss * args.g_non_overlap_mask_pair_loss_weight +
        reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=G.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(
            {
                'xb_loss_gan': xb_loss_gan,
                'xb_loss_att': xb_loss_att,
                'spasity_loss': spasity_loss,
                'full_overlap_mask_pair_loss': full_overlap_mask_pair_loss,
                'non_overlap_mask_pair_loss': non_overlap_mask_pair_loss,
                'reg_loss': reg_loss
            },
            step=step_cnt,
            name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' %
          (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
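
For reference, the conditioning signal b_ - a_ fed to G takes values in {-2, 0, +2}: 0 where an attribute is unchanged, +2 where it is added, -2 where it is removed. A toy NumPy illustration:

import numpy as np

a = np.array([1, 0, 1, 0])          # source attributes
b = np.array([1, 1, 0, 0])          # shuffled target attributes
print((b * 2 - 1) - (a * 2 - 1))    # -> [ 0  2 -2  0]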
Example #4
def G_train_graph():
    # ======================================
    # =                 graph              =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    z = Genc(xa)
    xa_ = Gdec(z, a_)
    xb_ = Gdec(z, b_)

    # discriminate
    xb__logit_gan, xb__logit_att = D(xb_)

    # generator losses
    xb__loss_gan = g_loss_fn(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)

    loss = (xb__loss_gan +
            xb__loss_att * args.g_attribute_loss_weight +
            xa__loss_rec * args.g_reconstruction_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt,
        var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
            tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2({
            'xb__loss_gan': xb__loss_gan,
            'xb__loss_att': xb__loss_att,
            'xa__loss_rec': xa__loss_rec,
            'reg_loss': reg_loss
        }, step=step_cnt, name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(Genc.func.variables + Gdec.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
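
tl.count_parameters is another helper from the same library; a plausible sketch, assuming float32 parameters (4 bytes each):

import numpy as np

def count_parameters(variables):
    # Illustrative sketch: total scalar parameter count and the bytes they
    # occupy, assuming float32 variables.
    n_params = sum(int(np.prod(v.get_shape().as_list())) for v in variables)
    n_bytes = n_params * 4
    return n_params, n_bytes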
Example #5
    def train(self):

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               self.sess.graph)

        # initialization
        ckpt_dir = self.config["projectCheckpoints"]
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        try:
            assert clear == False
            tl.load_checkpoint(ckpt_dir, self.sess)
        except:
            print('NOTE: Initializing all parameters...')
            self.sess.run(tf.global_variables_initializer())

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(self.data_loader) // (self.config["batchSize"] *
                                                     (n_d + 1))
            max_it = epoch * it_per_epoch

            for it in range(self.sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    self.sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = self.sess.run(
                            [self.d_summary, self.d_step],
                            feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = self.sess.run(
                        [self.g_summary, self.g_step],
                        feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                self.sess.run(self.x_sample,
                                              feed_dict={
                                                  self.xa_sample:
                                                  xa_sample_ipt,
                                                  self._b_sample:
                                                  _b_sample_ipt,
                                                  self.raw_b_sample:
                                                  raw_b_sample_ipt
                                              }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1),
                                   '%s/Epoch_(%d)_(%dof%d).jpg' %
                                   (self.config["projectSamples"], epoch,
                                    it_in_epoch, it_per_epoch))
        except:
            traceback.print_exc()
        finally:
            save_path = saver.save(
                self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
            self.sess.close()
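
The learning-rate schedule in the loop above, lr_base / (10 ** (epoch // lrDecayEpoch)), drops the rate by a factor of 10 every lrDecayEpoch epochs. A small worked example with illustrative values:

lr_base = 2e-4        # illustrative value for self.config["gLr"]
lrDecayEpoch = 100    # illustrative value for self.config["lrDecayEpoch"]
for epoch in (0, 99, 100, 199, 200):
    print(epoch, lr_base / (10 ** (epoch // lrDecayEpoch)))
# -> 0.0002 for epochs 0-99, 2e-05 for epochs 100-199, 2e-06 from epoch 200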
Example #6
    def train(self):

        ckpt_dir = self.config["projectCheckpoints"]
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        n_att = len(self.config["selectedAttrs"])

        if self.config["threads"] >= 0:
            cpu_config = tf.ConfigProto(
                intra_op_parallelism_threads=self.config["threads"] // 2,
                inter_op_parallelism_threads=self.config["threads"] // 2,
                device_count={'CPU': self.config["threads"]})
            cpu_config.gpu_options.allow_growth = True
            sess = tf.Session(config=cpu_config)
        else:
            sess = tl.session()

        data_loader = Celeba(self.config["dataset_path"],
                             self.config["selectedAttrs"],
                             self.config["imsize"],
                             self.config["batchSize"],
                             part='train',
                             sess=sess,
                             crop=(self.config["imCropSize"] > 0))

        val_loader = Celeba(self.config["dataset_path"],
                            self.config["selectedAttrs"],
                            self.config["imsize"],
                            self.config["sampleNum"],
                            part='val',
                            shuffle=False,
                            sess=sess,
                            crop=(self.config["imCropSize"] > 0))

        package = __import__("components." + self.config["modelScriptName"],
                             fromlist=True)
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        DClass = getattr(package, 'D')
        GP = getattr(package, "gradient_penalty")

        package = __import__("components.STU." + self.config["stuScriptName"],
                             fromlist=True)
        GstuClass = getattr(package, 'Gstu')

        Genc = partial(GencClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       multi_inputs=1)

        Gdec = partial(GdecClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       shortcut_layers=self.config["skipNum"],
                       inject_layers=self.config["injectLayers"],
                       one_more_conv=self.config["oneMoreConv"])

        Gstu = partial(GstuClass,
                       dim=self.config["stuDim"],
                       n_layers=self.config["skipNum"],
                       inject_layers=self.config["skipNum"],
                       kernel_size=self.config["stuKS"],
                       norm=None,
                       pass_state='stu')

        D = partial(DClass,
                    n_att=n_att,
                    dim=self.config["DConvDim"],
                    fc_dim=self.config["DFcDim"],
                    n_layers=self.config["DLayerNum"])

        # inputs

        xa = data_loader.batch_op[0]
        a = data_loader.batch_op[1]
        b = tf.random_shuffle(a)
        _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
        _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

        xa_sample = tf.placeholder(
            tf.float32,
            shape=[None, self.config["imsize"], self.config["imsize"], 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        lr = tf.placeholder(tf.float32, shape=[])

        # generate
        z = Genc(xa)
        zb = Gstu(z, _b - _a)
        xb_ = Gdec(zb, _b - _a)
        with tf.control_dependencies([xb_]):
            za = Gstu(z, _a - _a)
            xa_ = Gdec(za, _a - _a)

        # discriminate
        xa_logit_gan, xa_logit_att = D(xa)
        xb__logit_gan, xb__logit_att = D(xb_)

        wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
        d_loss_gan = -wd
        gp = GP(D, xa, xb_)
        xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
        d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

        xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
        xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
        xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
        g_loss = (xb__loss_gan + xb__loss_att * 10.0 +
                  xa__loss_rec * self.config["recWeight"])

        d_var = tl.trainable_variables('D')
        d_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
        g_var = tl.trainable_variables('G')
        g_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

        d_summary = tl.summary(
            {
                d_loss_gan: 'd_loss_gan',
                gp: 'gp',
                xa_loss_att: 'xa_loss_att',
            },
            scope='D')

        lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

        g_summary = tl.summary(
            {
                xb__loss_gan: 'xb__loss_gan',
                xb__loss_att: 'xb__loss_att',
                xa__loss_rec: 'xa__loss_rec',
            },
            scope='G')

        d_summary = tf.summary.merge([d_summary, lr_summary])

        # sample
        test_label = _b_sample - raw_b_sample
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               sess.graph)

        # initialization
        if self.config["mode"] == "finetune":
            print("Continute train the model")
            tl.load_checkpoint(ckpt_dir, sess)
            print("Load previous model successfully!")
        else:
            print('Initializing all parameters...')
            sess.run(tf.global_variables_initializer())

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(data_loader) // (self.config["batchSize"] *
                                                (n_d + 1))
            max_it = epoch * it_per_epoch

            print("Start to train the graph!")
            for it in range(sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = sess.run([d_summary, d_step],
                                                    feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = sess.run([g_summary, g_step],
                                                feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                sess.run(x_sample,
                                         feed_dict={
                                             xa_sample: xa_sample_ipt,
                                             _b_sample: _b_sample_ipt,
                                             raw_b_sample: raw_b_sample_ipt
                                         }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1),
                                   '%s/Epoch_(%d)_(%dof%d).jpg' %
                                   (self.config["projectSamples"], epoch,
                                    it_in_epoch, it_per_epoch))
        except:
            traceback.print_exc()
        finally:
            save_path = saver.save(
                sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
            sess.close()
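
The gradient_penalty function retrieved via getattr(package, "gradient_penalty") is not shown here; a sketch of a standard WGAN-GP penalty that would match the GP(D, xa, xb_) call above (the repository's own implementation may differ):

import tensorflow as tf  # TF 1.x

def gradient_penalty(D, real, fake):
    # Illustrative sketch: penalize the gradient norm of D at random
    # interpolations between real and fake samples (WGAN-GP).
    shape = [tf.shape(real)[0]] + [1] * (real.shape.ndims - 1)
    alpha = tf.random_uniform(shape, minval=0., maxval=1.)
    interp = real + alpha * (fake - real)
    logit_gan, _ = D(interp)  # D returns (gan_logit, att_logit) in this codebase
    grad = tf.gradients(logit_gan, interp)[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grad),
                                 axis=list(range(1, grad.shape.ndims))) + 1e-10)
    return tf.reduce_mean((norm - 1.) ** 2)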
Example #7
def train_on_fake(dataset, c_dim, result_dir, gpu_id, use_real=0, epoch_=200):
    """ param """
    batch_size = 64
    batch_size_fake = batch_size
    lr = 0.0002

    ''' data '''
    if use_real == 1:
        print('======Using real data======')
        batch_size_real = batch_size // 2
        batch_size_fake = batch_size - batch_size_real
        train_tfrecord_path_real = './tfrecords/celeba_tfrecord_train'
        train_data_pool_real = tl.TfrecordData(train_tfrecord_path_real, batch_size_real, shuffle=True)
    train_tfrecord_path_fake = os.path.join(result_dir, 'synthetic_tfrecord')
    train_data_pool_fake = tl.TfrecordData(train_tfrecord_path_fake, batch_size_fake, shuffle=True)
    if dataset == 'CelebA':
        test_tfrecord_path = './tfrecords/celeba_tfrecord_test'
    elif dataset == 'RaFD':
        test_tfrecord_path = './tfrecords/rafd_test'
    test_data_pool = tl.TfrecordData(test_tfrecord_path, 120)
    att_dim = c_dim

    """ graphs """
    with tf.device('/gpu:{}'.format(gpu_id)):

        ''' models '''
        classifier = models.classifier

        ''' graph '''
        # inputs
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            att = tf.placeholder(tf.int64, shape=[None, att_dim])
        elif dataset == 'RaFD':
            att = tf.placeholder(tf.float32, shape=[None, att_dim])

        # classify
        logits = classifier(x, att_dim=att_dim, reuse=False)

        # loss
        reg_loss = tf.losses.get_regularization_loss()
        if dataset == 'CelebA':
            loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_multi_binary_label_with_logits(att, logits)
        elif dataset == 'RaFD':
            loss = tf.losses.softmax_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_one_hot_label_with_logits(att, logits)

        lr_ = tf.placeholder(tf.float32, shape=[])

        # optim
        #with tf.variable_scope('Adam', reuse=tf.AUTO_REUSE):
        step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

        # test
        test_logits = classifier(x, att_dim=att_dim, training=False)
        if dataset == 'CelebA':
            test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
        elif dataset == 'RaFD':
            test_acc = mean_accuracy_one_hot_label_with_logits(att, test_logits)
        mean_acc = tf.placeholder(tf.float32, shape=())

    # summary
    summary = tl.summary({loss: 'loss', acc: 'acc'})
    test_summary = tl.summary({mean_acc: 'test_acc'})

    """ train """
    ''' init '''
    # session
    sess = tf.Session()
    # iteration counter
    it_cnt, update_cnt = tl.counter()
    # saver
    saver = tf.train.Saver(max_to_keep=None)
    # summary writer
    sum_dir = os.path.join(result_dir, 'summaries_train_on_fake')
    if use_real == 1:
        sum_dir += '_real'
    summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)

    ''' initialization '''
    ckpt_dir = os.path.join(result_dir, 'checkpoints_train_on_fake')
    if use_real == 1:
        ckpt_dir += '_real'
    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir + '/')
    if not tl.load_checkpoint(ckpt_dir, sess):
        sess.run(tf.global_variables_initializer())

    ''' train '''
    try:
        batch_epoch = len(train_data_pool_fake) // batch_size
        max_it = epoch_ * batch_epoch
        for it in range(sess.run(it_cnt), max_it):
            bth = it//batch_epoch - 8
            lr__ = lr*(1-max(bth, 0)/epoch_)**0.75
            if it % batch_epoch == 0:
                print('======learning rate:', lr__, '======')
            sess.run(update_cnt)

            # which epoch
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            x_255_ipt, att_ipt = train_data_pool_fake.batch(['img', 'attr'])
            if dataset == 'RaFD':
                att_ipt = ToOnehot(att_ipt, att_dim)
            if use_real == 1:
                x_255_ipt_real, att_ipt_real = train_data_pool_real.batch(['img', 'class'])
                x_255_ipt = np.concatenate([x_255_ipt, x_255_ipt_real])
                att_ipt = np.concatenate([att_ipt, att_ipt_real])
            summary_opt, _ = sess.run([summary, step], feed_dict={x_255: x_255_ipt, att: att_ipt, lr_:lr__})
            summary_writer.add_summary(summary_opt, it)

            # display
            if (it + 1) % batch_epoch == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, it_epoch, batch_epoch))

            # save
            if (it + 1) % (batch_epoch * 50) == 0:
                save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_epoch, batch_epoch))
                print('Model saved in file: %s' % save_path)

            # sample
            if it % 100 == 0:
                test_it = 100 if dataset == 'CelebA' else 7
                test_acc_opt_list = []
                for i in range(test_it):
                    key = 'class' if dataset == 'CelebA' else 'attr'
                    x_255_ipt, att_ipt = test_data_pool.batch(['img', key])
                    if dataset == 'RaFD':
                        att_ipt = ToOnehot(att_ipt, att_dim)

                    test_acc_opt = sess.run(test_acc, feed_dict={x_255: x_255_ipt, att: att_ipt})
                    test_acc_opt_list.append(test_acc_opt)
                test_summary_opt = sess.run(test_summary, feed_dict={mean_acc: np.mean(test_acc_opt_list)})
                summary_writer.add_summary(test_summary_opt, it)

    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
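
The accuracy helpers above are assumed to work roughly as follows: multi-label accuracy thresholds each logit at 0, one-hot accuracy compares argmaxes (a sketch, not the actual implementation):

import tensorflow as tf  # TF 1.x

def mean_accuracy_multi_binary_label_with_logits(att, logits):
    # Illustrative sketch: per-attribute binary accuracy, averaged.
    pred = tf.cast(logits > 0, att.dtype)
    return tf.reduce_mean(tf.cast(tf.equal(pred, att), tf.float32))

def mean_accuracy_one_hot_label_with_logits(att, logits):
    # Illustrative sketch: one-hot classification accuracy.
    correct = tf.equal(tf.argmax(att, axis=1), tf.argmax(logits, axis=1))
    return tf.reduce_mean(tf.cast(correct, tf.float32))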
Example #8
    {l: 'g_f_tree_loss_%d' % i
     for i, l in enumerate(g_f_tree_losses)},
    scope='G_Tree')
g_summary = tf.summary.merge([g_summary, g_tree_summary])

# sample
z_sample = tf.placeholder(tf.float32, [None, z_dim])
c_sample = tf.placeholder(tf.float32, [None, c_dim])
f_sample = G(z_sample, c_sample, is_training=False)

# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# epoch counter
ep_cnt, update_cnt = tl.counter(start=1)

# session
sess = tl.session()

# saver
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter(
    './output/%s/summaries' % experiment_name, sess.graph)

# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
Example #9
    scope='D')

lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

generator_summary = tl.summary(
    {
        discriminator_label_loss: 'discriminator_label_loss',
        label_decoder_loss: 'label_decoder_loss',
        image_decoder_loss: 'image_decode_loss',
    },
    scope='G')

discriminator_summary = tf.summary.merge([discriminator_summary, lr_summary])

# Iteration counter
iteration_counter, update_counter = tl.counter()

# Saver for model
saver = tf.train.Saver(max_to_keep=1)

# Logging information
summary_writer = tf.summary.FileWriter(
    './output/%s/summaries' % experiment_name, sess.graph)

# Check if training was already begun
checkpoint_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(checkpoint_dir)
try:
    tl.load_checkpoint(checkpoint_dir, sess)
except:
    sess.run(tf.global_variables_initializer())
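
tl.summary appears to take a {tensor: name} dict plus a scope and return a merged summary op; a minimal sketch under that assumption:

import tensorflow as tf  # TF 1.x

def summary(tensor_name_dict, scope='summary'):
    # Illustrative sketch: one scalar summary per entry, merged into one op.
    with tf.name_scope(scope):
        ops = [tf.summary.scalar(name, tensor)
               for tensor, name in tensor_name_dict.items()]
    return tf.summary.merge(ops)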
Example #10
def G_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(zs, eps):
        # generate
        x_f = G(zs, eps)

        # discriminate
        x_f_logit = D(x_f)

        # loss
        x_f_loss = g_loss_fn(x_f_logit)
        orth_loss = tf.reduce_sum(
            tl.tensors_filter(G.func.reg_losses, 'orthogonal_regularizer'))
        reg_loss = tf.reduce_sum(
            tl.tensors_filter(G.func.reg_losses, 'l2_regularizer'))

        loss = (x_f_loss * args.g_loss_weight_x_gan +
                orth_loss * args.g_loss_weight_orth_loss +
                reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(
            loss, var_list=G.func.trainable_variables)

        return grads, x_f_loss, orth_loss, reg_loss

    split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu,
                         tl.split_nest((zs, eps), len(tl.gpus()))))
    # split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(*tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_f_loss, orth_loss, reg_loss = [
        tf.reduce_mean(t)
        for t in [split_x_f_loss, split_orth_loss, split_reg_loss]
    ]

    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # moving average
    with tf.control_dependencies([step]):
        step = G_ema.apply(G.func.trainable_variables)

    # summary
    summary_dict = {
        'x_f_loss': x_f_loss,
        'orth_loss': orth_loss,
        'reg_loss': reg_loss
    }
    summary_dict.update({
        'L_%d' % i: t
        for i, t in enumerate(tl.tensors_filter(G.func.variables, 'L'))
    })
    summary_loss = tl.create_summary_statistic_v2(summary_dict,
                                                  './output/%s/summaries/G' %
                                                  args.experiment_name,
                                                  step=step_cnt,
                                                  n_steps_per_record=10,
                                                  name='G_loss')

    summary_image = tl.create_summary_image_v2(
        {
            'orth_U_%d' % i: t[None, :, :, None]
            for i, t in enumerate(tf.get_collection('orth', G.func.scope +
                                                    '/'))
        },
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='G_image')

    # ======================================
    # =             model size             =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Model Size: n_parameters = %d = %.2fMB' %
          (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary_loss, summary_image],
                 feed_dict={lr: pl_ipts['lr']})

    return run
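
tl.average_gradients is used above to merge the per-GPU gradients returned by optimizer.compute_gradients; a sketch of the usual tower-averaging pattern (assumed behaviour, not the library's actual code):

import tensorflow as tf  # TF 1.x

def average_gradients(tower_grads):
    # Illustrative sketch. tower_grads: one list of (grad, var) pairs per
    # GPU, with the variables in the same order on every tower.
    averaged = []
    for grads_and_var in zip(*tower_grads):
        grads = [g for g, _ in grads_and_var if g is not None]
        var = grads_and_var[0][1]
        averaged.append((tf.reduce_mean(tf.stack(grads, axis=0), axis=0), var))
    return averaged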
Example #11
def D_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    x_r = train_iter.get_next()
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(x_r, zs, eps):

        # generate
        x_f = G(zs, eps)

        # discriminate
        x_r_logit = D(x_r)
        x_f_logit = D(x_f)

        # loss
        x_r_loss, x_f_loss = d_loss_fn(x_r_logit, x_f_logit)
        x_gp = tf.cond(
            tf.equal(step_cnt % args.d_lazy_reg_period, 0),
            lambda: tfprob.gradient_penalty(
                D, x_r, x_f, args.gradient_penalty_mode, args.
                gradient_penalty_sample_mode) * args.d_lazy_reg_period,
            lambda: tf.constant(0.0))
        if args.d_loss_weight_x_gp == 0:
            x_gp = tf.constant(0.0)

        reg_loss = tf.reduce_sum(D.func.reg_losses)

        loss = ((x_r_loss + x_f_loss) * args.d_loss_weight_x_gan +
                x_gp * args.d_loss_weight_x_gp + reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(
            loss, var_list=D.func.trainable_variables)

        return grads, x_r_loss, x_f_loss, x_gp, reg_loss

    split_grads, split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu,
                         tl.split_nest((x_r, zs, eps), len(tl.gpus()))))
    # split_grads, split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss = zip(*tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((x_r, zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_r_loss, x_f_loss, x_gp, reg_loss = [
        tf.reduce_mean(t)
        for t in [split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss]
    ]

    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # summary
    summary = tl.create_summary_statistic_v2(
        {
            'x_gan_loss': x_r_loss + x_f_loss,
            'x_gp': x_gp,
            'reg_loss': reg_loss,
            'lr': lr
        },
        './output/%s/summaries/D' % args.experiment_name,
        step=step_cnt,
        n_steps_per_record=10,
        name='D')

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        for _ in range(args.n_d):
            sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
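
For completeness, the run() closures returned by D_train_graph() and G_train_graph() would typically be driven by an outer loop along these lines (an illustration only; n_iters and the flat learning rate are assumptions, not the repository's actual training script):

# Illustrative driver loop, assuming the surrounding script's sess and args.
D_train = D_train_graph()
G_train = G_train_graph()
sess.run(tf.global_variables_initializer())
for it in range(n_iters):
    lr_ipt = args.lr           # optionally decayed over training
    D_train(lr=lr_ipt)         # internally runs args.n_d critic updates
    G_train(lr=lr_ipt)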