Example #1
xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

# sample
x_sample = Gdec(Genc(xa_sample, is_training=False),
                _b_sample,
                is_training=False)

# ==============================================================================
# =                                    test                                    =
# ==============================================================================

# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
try:
    tl.load_checkpoint(ckpt_dir, sess)
except Exception:
    raise Exception(' [*] No checkpoint!')

# sample
try:
    # print(te_data)
    # for idx, batch in enumerate(te_data):
    #     print(idx)
    #     print(batch)
    for idx, batch in enumerate(te_data):
        xa_sample_ipt = batch[0]
        a_sample_ipt = batch[1]
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
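(Example #1 is cut off inside this loop. Judging from the same loop as it appears in full in Examples #4 and #5 on this page, it presumably continues by inverting the i-th attribute, resolving attribute conflicts, and collecting the target label. The continuation below is a sketch based on those examples, not part of this excerpt:)

            tmp[:, i] = 1 - tmp[:, i]  # invert the i-th attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)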
Example #2
# with tf.device('/gpu:%d' % gpu_id):
''' models '''
classifier = models.classifier
''' graph '''
# inputs
x = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])

# classify
logits = classifier(x, reuse=False, training=False)
pred = tf.cast(tf.round(tf.nn.sigmoid(logits)), tf.int64)
""" train """
''' init '''
# session
sess = tl.session()
''' initialization '''
tl.load_checkpoint(ckpt_file, sess)
''' train '''
try:
    img_paths = glob(os.path.join(img_dir, '*.jpg'))
    img_paths.sort()

    cnt = np.zeros([len(att_id)])
    err_cnt = np.zeros([len(att_id)])
    err_each_cnt = np.zeros([len(att_id), len(att_id)])
    for img_path in img_paths:
        imgs = im.imread(img_path)
        # imgs = im.resize(imgs, (128, 128))
        print(imgs.shape)
        # imgs = np.concatenate([imgs[:, :img_size, :], imgs[:, img_size+img_size//10:, :]], axis=1)
        imgs = np.expand_dims(imgs, axis=0)
        # imgs = np.concatenate(np.split(imgs, 15, axis=2))
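(Example #2 is truncated before the classifier is actually evaluated. A minimal sketch of the missing step, assuming im.imread already returns images in the value range the classifier expects, would feed the prepared batch through the pred op defined above:)

        # hypothetical continuation: classify this image and inspect the prediction
        pred_ipt = sess.run(pred, feed_dict={x: imgs})
        print(pred_ipt)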
Example #3
test_summary = tl.summary({mean_acc: 'test_acc'})
""" train """
''' init '''
# session
sess = tf.Session()
# iteration counter
it_cnt, update_cnt = tl.counter()
# saver
saver = tf.train.Saver(max_to_keep=None)
# summary writer
summary_writer = tf.summary.FileWriter('./summaries', sess.graph)
''' initialization '''
ckpt_dir = './checkpoints'
if not os.path.exists(ckpt_dir):
    os.mkdir(ckpt_dir + '/')
if not tl.load_checkpoint(ckpt_dir, sess):
    sess.run(tf.global_variables_initializer())
''' train '''
try:
    batch_epoch = len(train_data_pool) // batch_size
    max_it = epoch_ * batch_epoch
    for it in range(sess.run(it_cnt), max_it):
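        # lr schedule: keep the base lr for the first 8 epochs, then decay it polynomially (exponent 0.75)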
        bth = it // batch_epoch - 8
        lr__ = lr * (1 - max(bth, 0) / epoch_)**0.75
        if it % batch_epoch == 0:
            print('======learning rate:', lr__, '======')
        sess.run(update_cnt)

        # which epoch
        epoch = it // batch_epoch
        it_epoch = it % batch_epoch + 1
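(Example #3 stops here; the fuller version of this routine in Example #7 goes on to fetch a training batch, run the optimization step together with its summary, and write the summary to TensorBoard. The sketch below follows that pattern; the names x, att, lr_, step, and summary come from the graph section not shown in this excerpt and are assumptions:)

        # hypothetical continuation, mirroring Example #7
        x_ipt, att_ipt = train_data_pool.batch(['img', 'attr'])
        summary_opt, _ = sess.run([summary, step],
                                  feed_dict={x: x_ipt, att: att_ipt, lr_: lr__})
        summary_writer.add_summary(summary_opt, it)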
Example #4
    def train(self):

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               self.sess.graph)

        # initialization
        ckpt_dir = self.config["projectCheckpoints"]
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        try:
            assert clear == False
            tl.load_checkpoint(ckpt_dir, self.sess)
        except Exception:
            print('NOTE: Initializing all parameters...')
            self.sess.run(tf.global_variables_initializer())

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(self.data_loader) // (self.config["batchSize"] *
                                                     (n_d + 1))
            max_it = epoch * it_per_epoch

            for it in range(self.sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    self.sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = self.sess.run(
                            [self.d_summary, self.d_step],
                            feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = self.sess.run(
                        [self.g_summary, self.g_step],
                        feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                self.sess.run(self.x_sample,
                                              feed_dict={
                                                  self.xa_sample:
                                                  xa_sample_ipt,
                                                  self._b_sample:
                                                  _b_sample_ipt,
                                                  self.raw_b_sample:
                                                  raw_b_sample_ipt
                                              }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a +/- mark in the upper-left corner to show whether the attribute was added or removed
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % \
                                    (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
        except Exception:
            traceback.print_exc()
        finally:
            save_path = saver.save(
                self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
            self.sess.close()
Example #5
    def train(self):

        ckpt_dir = self.config["projectCheckpoints"]
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        n_att = len(self.config["selectedAttrs"])

        if self.config["threads"] >= 0:
            cpu_config = tf.ConfigProto(
                intra_op_parallelism_threads=self.config["threads"] // 2,
                inter_op_parallelism_threads=self.config["threads"] // 2,
                device_count={'CPU': self.config["threads"]})
            cpu_config.gpu_options.allow_growth = True
            sess = tf.Session(config=cpu_config)
        else:
            sess = tl.session()

        data_loader = Celeba(self.config["dataset_path"],
                             self.config["selectedAttrs"],
                             self.config["imsize"],
                             self.config["batchSize"],
                             part='train',
                             sess=sess,
                             crop=(self.config["imCropSize"] > 0))

        val_loader = Celeba(self.config["dataset_path"],
                            self.config["selectedAttrs"],
                            self.config["imsize"],
                            self.config["sampleNum"],
                            part='val',
                            shuffle=False,
                            sess=sess,
                            crop=(self.config["imCropSize"] > 0))

        package = __import__("components." + self.config["modelScriptName"],
                             fromlist=True)
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        DClass = getattr(package, 'D')
        GP = getattr(package, "gradient_penalty")

        package = __import__("components.STU." + self.config["stuScriptName"],
                             fromlist=True)
        GstuClass = getattr(package, 'Gstu')

        Genc = partial(GencClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       multi_inputs=1)

        Gdec = partial(GdecClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       shortcut_layers=self.config["skipNum"],
                       inject_layers=self.config["injectLayers"],
                       one_more_conv=self.config["oneMoreConv"])

        Gstu = partial(GstuClass,
                       dim=self.config["stuDim"],
                       n_layers=self.config["skipNum"],
                       inject_layers=self.config["skipNum"],
                       kernel_size=self.config["stuKS"],
                       norm=None,
                       pass_state='stu')

        D = partial(DClass,
                    n_att=n_att,
                    dim=self.config["DConvDim"],
                    fc_dim=self.config["DFcDim"],
                    n_layers=self.config["DLayerNum"])

        # inputs

        xa = data_loader.batch_op[0]
        a = data_loader.batch_op[1]
        b = tf.random_shuffle(a)
        _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
        _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

        xa_sample = tf.placeholder(
            tf.float32,
            shape=[None, self.config["imsize"], self.config["imsize"], 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        lr = tf.placeholder(tf.float32, shape=[])

        # generate
        z = Genc(xa)
        zb = Gstu(z, _b - _a)
        xb_ = Gdec(zb, _b - _a)
        with tf.control_dependencies([xb_]):
            za = Gstu(z, _a - _a)
            xa_ = Gdec(za, _a - _a)

        # discriminate
        xa_logit_gan, xa_logit_att = D(xa)
        xb__logit_gan, xb__logit_att = D(xb_)

        wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
        d_loss_gan = -wd
        gp = GP(D, xa, xb_)
        xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
        d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

        xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
        xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
        xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
        g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * self.config[
            "recWeight"]

        d_var = tl.trainable_variables('D')
        d_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
        g_var = tl.trainable_variables('G')
        g_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

        d_summary = tl.summary(
            {
                d_loss_gan: 'd_loss_gan',
                gp: 'gp',
                xa_loss_att: 'xa_loss_att',
            },
            scope='D')

        lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

        g_summary = tl.summary(
            {
                xb__loss_gan: 'xb__loss_gan',
                xb__loss_att: 'xb__loss_att',
                xa__loss_rec: 'xa__loss_rec',
            },
            scope='G')

        d_summary = tf.summary.merge([d_summary, lr_summary])

        # sample
        test_label = _b_sample - raw_b_sample
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               sess.graph)

        # initialization
        if self.config["mode"] == "finetune":
            print("Continute train the model")
            tl.load_checkpoint(ckpt_dir, sess)
            print("Load previous model successfully!")
        else:
            print('Initializing all parameters...')
            sess.run(tf.global_variables_initializer())

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(data_loader) // (self.config["batchSize"] *
                                                (n_d + 1))
            max_it = epoch * it_per_epoch

            print("Start to train the graph!")
            for it in range(sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = sess.run([d_summary, d_step],
                                                    feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = sess.run([g_summary, g_step],
                                                feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                sess.run(x_sample,
                                         feed_dict={
                                             xa_sample: xa_sample_ipt,
                                             _b_sample: _b_sample_ipt,
                                             raw_b_sample: raw_b_sample_ipt
                                         }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a +/- mark in the upper-left corner to show whether the attribute was added or removed
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % \
                                    (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
        except Exception:
            traceback.print_exc()
        finally:
            save_path = saver.save(
                sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
            sess.close()
Example #6
def model_initialization():
    """Model creation and weight load.
    
    Load of several parameters found in the pretrained STGAN model: 
    https://drive.google.com/open?id=1329IbLE6877DcDUut1reKxckijBJye7N.

    Returns:
        sess (tf.Session): Current session for inference.
        x_sample (tf.Tensor): Output tensor of shape (n_img, 128, 128, 3).
        xa_sample (tf.Tensor): Input tensor of shape (n_img, 128, 128, 3).
        _b_sample (tf.Tensor): Label tensor of shape (n_img, 13).
        raw_b_sample (tf.Tensor): Label tensor of shape (n_img, 13).

    """
    with open('./model/setting.txt') as f:
        args = json.load(f)

    atts = args['atts']
    n_atts = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    dis_dim = args['dis_dim']
    dis_fc_dim = args['dis_fc_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']
    dis_layers = args['dis_layers']

    label = args['label']
    use_stu = args['use_stu']
    stu_dim = args['stu_dim']
    stu_layers = args['stu_layers']
    stu_inject_layers = args['stu_inject_layers']
    stu_kernel_size = args['stu_kernel_size']
    stu_norm = args['stu_norm']
    stu_state = args['stu_state']
    multi_inputs = args['multi_inputs']
    rec_loss_weight = args['rec_loss_weight']
    one_more_conv = args['one_more_conv']

    sess = tl.session()
    # Models
    Genc = partial(models.Genc,
                   dim=enc_dim,
                   n_layers=enc_layers,
                   multi_inputs=multi_inputs)
    Gdec = partial(models.Gdec,
                   dim=dec_dim,
                   n_layers=dec_layers,
                   shortcut_layers=shortcut_layers,
                   inject_layers=inject_layers,
                   one_more_conv=one_more_conv)
    Gstu = partial(models.Gstu,
                   dim=stu_dim,
                   n_layers=stu_layers,
                   inject_layers=stu_inject_layers,
                   kernel_size=stu_kernel_size,
                   norm=stu_norm,
                   pass_state=stu_state)

    # Inputs
    xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_atts])
    raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_atts])

    # Sample
    test_label = _b_sample - raw_b_sample if label == 'diff' else _b_sample
    if use_stu:
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)
    else:
        x_sample = Gdec(Genc(xa_sample, is_training=False),
                        test_label,
                        is_training=False)

    # Initialization
    ckpt_dir = './model/checkpoints'
    tl.load_checkpoint(ckpt_dir, sess)

    return sess, x_sample, xa_sample, _b_sample, raw_b_sample
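(A minimal usage sketch for the handles returned above, following the feed_dict pattern used for sampling in Example #5. The names img, target_label, and source_label are hypothetical inputs prepared by the caller: img an array of shape (1, 128, 128, 3) in the model's input range, and the two labels arrays of shape (1, 13).)

sess, x_sample, xa_sample, _b_sample, raw_b_sample = model_initialization()
out = sess.run(x_sample, feed_dict={xa_sample: img,
                                    _b_sample: target_label,
                                    raw_b_sample: source_label})
sess.close()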
Example #7
def train_on_fake(dataset, c_dim, result_dir, gpu_id, use_real=0, epoch_=200):
    """ param """
    batch_size = 64
    batch_size_fake = batch_size
    lr = 0.0002

    ''' data '''
    if use_real == 1:
        print('======Using real data======')
        batch_size_real = batch_size // 2
        batch_size_fake = batch_size - batch_size_real
        train_tfrecord_path_real = './tfrecords/celeba_tfrecord_train'
        train_data_pool_real = tl.TfrecordData(train_tfrecord_path_real, batch_size_real, shuffle=True)
    train_tfrecord_path_fake = os.path.join(result_dir, 'synthetic_tfrecord')
    train_data_pool_fake = tl.TfrecordData(train_tfrecord_path_fake, batch_size_fake, shuffle=True)
    if dataset == 'CelebA':
        test_tfrecord_path = './tfrecords/celeba_tfrecord_test'
    elif dataset == 'RaFD':
        test_tfrecord_path = './tfrecords/rafd_test'
    test_data_pool = tl.TfrecordData(test_tfrecord_path, 120)
    att_dim = c_dim

    """ graphs """
    with tf.device('/gpu:{}'.format(gpu_id)):

        ''' models '''
        classifier = models.classifier

        ''' graph '''
        # inputs
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            att = tf.placeholder(tf.int64, shape=[None, att_dim])
        elif dataset == 'RaFD':
            att = tf.placeholder(tf.float32, shape=[None, att_dim])

        # classify
        logits = classifier(x, att_dim=att_dim, reuse=False)

        # loss
        reg_loss = tf.losses.get_regularization_loss()
        if dataset == 'CelebA':
            loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_multi_binary_label_with_logits(att, logits)
        elif dataset == 'RaFD':
            loss = tf.losses.softmax_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_one_hot_label_with_logits(att, logits)

        lr_ = tf.placeholder(tf.float32, shape=[])

        # optim
        #with tf.variable_scope('Adam', reuse=tf.AUTO_REUSE):
        step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

        # test
        test_logits = classifier(x, att_dim=att_dim, training=False)
        if dataset == 'CelebA':
            test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
        elif dataset == 'RaFD':
            test_acc = mean_accuracy_one_hot_label_with_logits(att, test_logits)
        mean_acc = tf.placeholder(tf.float32, shape=())

    # summary
    summary = tl.summary({loss: 'loss', acc: 'acc'})
    test_summary = tl.summary({mean_acc: 'test_acc'})

    """ train """
    ''' init '''
    # session
    sess = tf.Session()
    # iteration counter
    it_cnt, update_cnt = tl.counter()
    # saver
    saver = tf.train.Saver(max_to_keep=None)
    # summary writer
    sum_dir = os.path.join(result_dir, 'summaries_train_on_fake')
    if use_real == 1:
        sum_dir += '_real'
    summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)

    ''' initialization '''
    ckpt_dir = os.path.join(result_dir, 'checkpoints_train_on_fake')
    if use_real == 1:
        ckpt_dir += '_real'
    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir + '/')
    if not tl.load_checkpoint(ckpt_dir, sess):
        sess.run(tf.global_variables_initializer())

    ''' train '''
    try:
        batch_epoch = len(train_data_pool_fake) // batch_size
        max_it = epoch_ * batch_epoch
        for it in range(sess.run(it_cnt), max_it):
            bth = it//batch_epoch - 8
            lr__ = lr*(1-max(bth, 0)/epoch_)**0.75
            if it % batch_epoch == 0:
                print('======learning rate:', lr__, '======')
            sess.run(update_cnt)

            # which epoch
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            x_255_ipt, att_ipt = train_data_pool_fake.batch(['img', 'attr'])
            if dataset == 'RaFD':
                att_ipt = ToOnehot(att_ipt, att_dim)
            if use_real == 1:
                x_255_ipt_real, att_ipt_real = train_data_pool_real.batch(['img', 'class'])
                x_255_ipt = np.concatenate([x_255_ipt, x_255_ipt_real])
                att_ipt = np.concatenate([att_ipt, att_ipt_real])
            summary_opt, _ = sess.run([summary, step], feed_dict={x_255: x_255_ipt, att: att_ipt, lr_:lr__})
            summary_writer.add_summary(summary_opt, it)

            # display
            if (it + 1) % batch_epoch == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, it_epoch, batch_epoch))

            # save
            if (it + 1) % (batch_epoch * 50) == 0:
                save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_epoch, batch_epoch))
                print('Model saved in file: %s' % save_path)

            # sample
            if it % 100 == 0:
                test_it = 100 if dataset == 'CelebA' else 7
                test_acc_opt_list = []
                for i in range(test_it):
                    key = 'class' if dataset == 'CelebA' else 'attr'
                    x_255_ipt, att_ipt = test_data_pool.batch(['img', key])
                    if dataset == 'RaFD':
                        att_ipt = ToOnehot(att_ipt, att_dim)

                    test_acc_opt = sess.run(test_acc, feed_dict={x_255: x_255_ipt, att: att_ipt})
                    test_acc_opt_list.append(test_acc_opt)
                test_summary_opt = sess.run(test_summary, feed_dict={mean_acc: np.mean(test_acc_opt_list)})
                summary_writer.add_summary(test_summary_opt, it)

    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
Example #8
    def attr_cls(self):
        """Computes the GAN-test attribute classification accuracy."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        data_loader = self.labelled_loader

        attr_list = []
        if self.dataset == 'CelebA':
            ckpt_file = './checkpoints_train_on_real/CelebA/Epoch_(127)_(2543of2543).ckpt'
            attr_list = self.selected_attrs
            n_print = 2000
        elif self.dataset == 'RaFD':
            ckpt_file = './checkpoints_train_on_real/RaFD/Epoch_(199)_(112of112).ckpt'
            attr_list = self.selected_emots
            n_print = 200
        classifier = models.classifier

        # Classifier graph
        x = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        logits = classifier(x,
                            att_dim=len(attr_list),
                            reuse=False,
                            training=False)
        if self.dataset == 'CelebA':
            pred_s = tf.cast(tf.nn.sigmoid(logits), tf.float64)
        elif self.dataset == 'RaFD':
            pred_s = tf.cast(tf.nn.softmax(logits), tf.float64)

        cnt_pos = np.zeros([self.c_dim]).astype(np.int64)
        cnt_neg = np.zeros([self.c_dim]).astype(np.int64)
        cnt_rec = np.zeros([self.c_dim]).astype(np.int64)
        c_pos = np.zeros([self.c_dim])
        c_neg = np.zeros([self.c_dim])
        c_rec = np.zeros([self.c_dim])
        ca_req = np.zeros([self.c_dim]).astype(np.int64)
        cr_req = np.zeros([self.c_dim]).astype(np.int64)
        co_req = np.zeros([self.c_dim]).astype(np.int64)

        with torch.no_grad():
            with tl.session() as sess:
                tl.load_checkpoint(ckpt_file, sess)
                attr_list = ['Reconstruction'] + attr_list
                total_count = 0
                for i, (x_real, c_org) in enumerate(data_loader):

                    if self.dataset == 'RaFD':
                        c_org = self.label2onehot(c_org, self.c_dim)

                    # Prepare input images and target domain labels.
                    x_real = x_real.to(self.device)
                    c_trg_list = self.create_labels(c_org, self.c_dim,
                                                    self.dataset,
                                                    self.selected_attrs)
                    c_trg_batch = torch.cat(
                        [c.unsqueeze(1) for c in c_trg_list],
                        dim=1).cpu().numpy()
                    c_trg_list = [None] + [c_org.to(self.device)] + c_trg_list
                    att_gt_batch = c_org.numpy()

                    # Classify translated images.
                    pred_score_list = []
                    preds_list = []
                    for j, c_trg in enumerate(c_trg_list):
                        if j == 0:
                            feed = np.transpose(x_real.cpu().numpy(),
                                                [0, 2, 3, 1])
                        else:
                            x_fake = self.G(x_real, c_trg)
                            feed = np.transpose(x_fake.cpu().numpy(),
                                                [0, 2, 3, 1])
                        pred_score = sess.run(pred_s, feed_dict={x: feed})
                        pred_score_list.append(
                            np.expand_dims(pred_score, axis=1))
                        if self.dataset == 'CelebA':
                            preds = np.round(pred_score).astype(int)
                        elif self.dataset == 'RaFD':
                            max_id = np.argmax(pred_score, axis=1)
                            preds = np.zeros_like(pred_score).astype(int)
                            preds[np.arange(pred_score.shape[0]), max_id] = 1
                        preds_list.append(np.expand_dims(preds, axis=1))
                    pred_score_batch = np.concatenate(pred_score_list, axis=1)
                    preds_opt_batch = np.concatenate(preds_list, axis=1)

                    # Calculate accuracy.
                    for pred_score, preds_opt, att_gt, c_trg in zip(
                            pred_score_batch, preds_opt_batch, att_gt_batch,
                            c_trg_batch):
                        for k in range(2, len(preds_opt)):
                            if c_trg[k - 2, k - 2] == 1 - att_gt[k - 2]:
                                if att_gt[k - 2] == 0:
                                    ca_req[k - 2] += 1
                                elif att_gt[k - 2] == 1:
                                    cr_req[k - 2] += 1

                                if preds_opt[k, k - 2] == 1 - att_gt[k - 2]:
                                    if preds_opt[k, k - 2] == 1:
                                        cnt_pos[k - 2] += 1
                                        c_pos[k - 2] += pred_score[k, k - 2]
                                    elif preds_opt[k, k - 2] == 0:
                                        cnt_neg[k - 2] += 1
                                        c_neg[k - 2] += 1 - pred_score[k, k - 2]
                            else:
                                co_req[k - 2] += 1
                                if preds_opt[k, k - 2] == att_gt[k - 2]:
                                    cnt_rec[k - 2] += 1
                                    if preds_opt[k, k - 2] == 1:
                                        c_rec[k - 2] += pred_score[k, k - 2]
                                    elif preds_opt[k, k - 2] == 0:
                                        c_rec[k - 2] += 1 - pred_score[k, k - 2]

                    total_count += x_real.shape[0]
                    if total_count % n_print == 0:
                        print('{} images classified.'.format(total_count))
                        print('\tAcc. Addition')
                        print('\t', cnt_pos / ca_req)
                        print('\t', np.mean(cnt_pos / ca_req))

                attr_cls_path = os.path.join(self.result_dir, 'GAN-test.txt')
                with open(attr_cls_path, 'w') as f:
                    f.write('Overall accuracy,{},average,{}\n'.format(
                        arr_2_str((cnt_pos + cnt_neg + cnt_rec) /
                                  (ca_req + cr_req + co_req)),
                        arr_2_str(
                            np.mean((cnt_pos + cnt_neg + cnt_rec) /
                                    (ca_req + cr_req + co_req)))))
        print('GAN-test accuracy: {}'.format(
            arr_2_str(
                np.mean((cnt_pos + cnt_neg + cnt_rec) /
                        (ca_req + cr_req + co_req)))))
Example #9
File: gan_test.py  Project: yyht/MatchGAN
def test_train_on_fake(dataset, c_dim, result_dir, gpu_id, epoch_=200):
    img_size = 128
    ''' data '''
    if dataset == 'CelebA':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(2513of2513).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/celeba_tfrecord_test'
        test_data_pool = tl.TfrecordData(test_tfrecord_path, 18, shuffle=False)
    elif dataset == 'RaFD':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(112of112).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/rafd_test'
        test_data_pool = tl.TfrecordData(test_tfrecord_path,
                                         120,
                                         shuffle=False)
    ckpt_file = os.path.join(result_dir, ckpt_file)
    """ graphs """
    with tf.device('/gpu:{}'.format(gpu_id)):
        ''' models '''
        classifier = models.classifier
        ''' graph '''
        # inputs
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            label = tf.placeholder(tf.int64, shape=[None, c_dim])
        elif dataset == 'RaFD':
            label = tf.placeholder(tf.float32, shape=[None, c_dim])

        # classify
        logits = classifier(x, att_dim=c_dim, reuse=False, training=False)

        if dataset == 'CelebA':
            accuracy = mean_accuracy_multi_binary_label_with_logits(
                label, logits)
        elif dataset == 'RaFD':
            accuracy = mean_accuracy_one_hot_label_with_logits(label, logits)
    """ train """
    ''' init '''
    # session
    sess = tl.session()
    ''' initialization '''
    tl.load_checkpoint(ckpt_file, sess)
    ''' train '''
    try:
        all_accuracies = []
        denom = 18 if dataset == 'CelebA' else 120
        key = 'class' if dataset == 'CelebA' else 'attr'
        test_iter = len(test_data_pool) // denom
        for iter in range(test_iter):
            img, label_gt = test_data_pool.batch(['img', key])
            if dataset == 'RaFD':
                label_gt = ToOnehot(label_gt, c_dim)
            print('Test batch {}'.format(iter), end='\r')
            batch_accuracy = sess.run(accuracy,
                                      feed_dict={
                                          x_255: img,
                                          label: label_gt
                                      })
            all_accuracies.append(batch_accuracy)

        if dataset == 'CelebA':
            mean_accuracies = np.mean(np.concatenate(all_accuracies), axis=0)
            mean_accuracy = np.mean(mean_accuracies)
            print('\nIndividual accuracies: {} Average: {:.4f}'.format(
                mean_accuracies, mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                for attr, acc in zip([
                        'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male',
                        'Young'
                ], mean_accuracies):
                    f.write('{}: {}\n'.format(attr, acc))
                f.write('Average: {}'.format(mean_accuracy))
        elif dataset == 'RaFD':
            mean_accuracy = np.mean(all_accuracies)
            print('\nAverage accuracies: {:.4f}'.format(mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                f.write('Average accuracy: {}'.format(mean_accuracy))

    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
Example #10
def runModel(image_url, file_name, test_att, n_slide, image_labels,
             model_type):
    # ==============================================================================
    # =                                    param                                   =
    # ==============================================================================

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name',
                        dest='experiment_name',
                        default="384_shortcut1_inject1_none_hd",
                        help='experiment_name')
    parser.add_argument('--test_att', dest='test_att', help='test_att')
    parser.add_argument('--test_int_min',
                        dest='test_int_min',
                        type=float,
                        default=-1.0,
                        help='test_int_min')
    parser.add_argument('--test_int_max',
                        dest='test_int_max',
                        type=float,
                        default=1.0,
                        help='test_int_max')
    args_ = parser.parse_args()

    if model_type == 0:
        experiment_name = args_.experiment_name
    else:
        experiment_name = "128_custom"

    print("EXPERIMENT NAME WORKING:" + experiment_name)

    with open('./output/%s/setting.txt' % experiment_name) as f:
        args = json.load(f)

    # model
    atts = args['atts']
    n_att = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    dis_dim = args['dis_dim']
    dis_fc_dim = args['dis_fc_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']
    dis_layers = args['dis_layers']
    # testing
    thres_int = args['thres_int']
    test_int_min = args_.test_int_min
    test_int_max = args_.test_int_max
    # others
    use_cropped_img = args['use_cropped_img']
    n_slide = int(n_slide)

    assert test_att is not None, 'test_att must be chosen from %s' % str(atts)

    # ==============================================================================
    # =                                   graphs                                   =
    # ==============================================================================

    # data
    sess = tl.session()

    # get image
    print(image_url)
    if experiment_name == "128_custom":
        os.system(
            "wget -P /home/tug44606/AttGAN-Tensorflow-master/data/img_align_celeba "
            + image_url)
    else:
        os.system(
            "wget -P /home/tug44606/AttGAN-Tensorflow-master/data/img_crop_celeba "
            + image_url)

    print("Working")

    # pass image with labels to dataset
    te_data = data.Celeba('./data',
                          atts,
                          img_size,
                          1,
                          part='val',
                          sess=sess,
                          crop=not use_cropped_img,
                          image_labels=image_labels,
                          file_name=file_name)

    sample = None

    # models
    Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
    Gdec = partial(models.Gdec,
                   dim=dec_dim,
                   n_layers=dec_layers,
                   shortcut_layers=shortcut_layers,
                   inject_layers=inject_layers)

    # inputs
    xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

    # sample
    x_sample = Gdec(Genc(xa_sample, is_training=False),
                    _b_sample,
                    is_training=False)

    # ==============================================================================
    # =                                    test                                    =
    # ==============================================================================

    # initialization
    ckpt_dir = './output/%s/checkpoints' % experiment_name
    print("CHECKPOINT DIR: " + ckpt_dir)
    try:
        tl.load_checkpoint(ckpt_dir, sess)
    except Exception:
        raise Exception(' [*] No checkpoint!')

    save_location = ""
    # sample
    try:
        for idx, batch in enumerate(te_data):
            xa_sample_ipt = batch[0]
            b_sample_ipt = batch[1]

            x_sample_opt_list = []

            for i in range(n_slide - 1, n_slide):
                test_int = (test_int_max -
                            test_int_min) / (n_slide - 1) * i + test_int_min
                _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                _b_sample_ipt[..., atts.index(test_att)] = test_int
                x_sample_opt_list.append(
                    sess.run(x_sample,
                             feed_dict={
                                 xa_sample: xa_sample_ipt,
                                 _b_sample: _b_sample_ipt
                             }))

            sample = np.concatenate(x_sample_opt_list, 2)
            save_location = '/output/%s/sample_testing_slide_%s/' % (
                experiment_name, test_att)
            save_dir = './output/%s/sample_testing_slide_%s' % (
                experiment_name, test_att)
            pylib.mkdir(save_dir)
            im.imwrite(sample.squeeze(0), '%s/%s' % (save_dir, file_name))

            print('%d.png done!' % idx)

            if (idx + 1 == te_data._img_num):
                break
    except Exception:
        traceback.print_exc()
    finally:
        sess.close()

    if experiment_name == "128_custom":
        os.system("rm ./data/img_align_celeba/" + file_name)
    else:
        os.system("rm ./data/img_crop_celeba/" + file_name)

    return "http://129.32.22.10:7001" + save_location + file_name