Example #1
pylib.mkdir('./output/%s' % experiment_name)
with open('./output/%s/setting.txt' % experiment_name, 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))

# ==============================================================================
# =                                   graphs                                   =
# ==============================================================================

# data
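# Limit TensorFlow to an explicit CPU thread budget when one is given; otherwise fall back to tflib's default session.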
if threads >= 0:
    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=threads // 2,
                                inter_op_parallelism_threads=threads // 2,
                                device_count={'CPU': threads})
    sess = tf.Session(config=cpu_config)
else:
    sess = tl.session()
crop_ = not use_cropped_img
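# Build the training and validation CelebA pipelines on the shared session; validation uses a fixed, unshuffled batch of n_sample images.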
tr_data = data.Celeba(dataroot,
                      atts,
                      img_size,
                      batch_size,
                      part='train',
                      sess=sess,
                      crop=crop_)
val_data = data.Celeba(dataroot,
                       atts,
                       img_size,
                       n_sample,
                       part='val',
                       shuffle=False,
                       sess=sess,
                       crop=crop_)
Example #2
File: mnist.py Project: zj10/PGA
    # The keyword defaults below are placeholders (assumptions); only the
    # parameter names follow from the parent-class call further down.
    def __init__(self, data_dir, batch_size, split='train', prefetch_batch=2,
                 drop_remainder=True, filter=None, map_func=None,
                 num_threads=16, shuffle=True, buffer_size=4096, repeat=-1,
                 sess=None):
        imgs, lbls, _ = mnist_load(data_dir, split)
        imgs.shape = imgs.shape + (1, )

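        # The whole split stays in memory: the arrays are bound to placeholders and re-fed via feed_dict whenever the iterator is reset.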
        imgs_pl = tf.placeholder(tf.float32, imgs.shape)
        lbls_pl = tf.placeholder(tf.int64, lbls.shape)

        memory_data_dict = {'img': imgs_pl, 'lbl': lbls_pl}

        self.feed_dict = {imgs_pl: imgs, lbls_pl: lbls}
        super(Mnist,
              self).__init__(memory_data_dict, batch_size, prefetch_batch,
                             drop_remainder, filter, map_func, num_threads,
                             shuffle, buffer_size, repeat, sess)

    def reset(self):
        super(Mnist, self).reset(self.feed_dict)


if __name__ == '__main__':
    import imlib as im
    from tflib import session
    sess = session()
    mnist = Mnist('/tmp', 5000, repeat=1, sess=sess)
    print(len(mnist))
    for batch in mnist:
        print(batch['lbl'][-1])
        im.imshow(batch['img'][-1].squeeze())
        im.show()
    sess.close()
Example #3
    def attr_cls(self):
        """Computes the GAN-test attribute classification accuracy."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        data_loader = self.labelled_loader

        attr_list = []
        if self.dataset == 'CelebA':
            ckpt_file = './checkpoints_train_on_real/CelebA/Epoch_(127)_(2543of2543).ckpt'
            attr_list = self.selected_attrs
            n_print = 2000
        elif self.dataset == 'RaFD':
            ckpt_file = './checkpoints_train_on_real/RaFD/Epoch_(199)_(112of112).ckpt'
            attr_list = self.selected_emots
            n_print = 200
        classifier = models.classifier

        # Classifier graph
        x = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        logits = classifier(x,
                            att_dim=len(attr_list),
                            reuse=False,
                            training=False)
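        # CelebA is multi-label, so every attribute gets an independent sigmoid; RaFD is single-label, so the emotions share a softmax.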
        if self.dataset == 'CelebA':
            pred_s = tf.cast(tf.nn.sigmoid(logits), tf.float64)
        elif self.dataset == 'RaFD':
            pred_s = tf.cast(tf.nn.softmax(logits), tf.float64)

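        # ca_req / cr_req / co_req count requested attribute additions, removals and preservations; cnt_* and c_* accumulate the correctly classified cases and their confidence.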
        cnt_pos = np.zeros([self.c_dim]).astype(np.int64)
        cnt_neg = np.zeros([self.c_dim]).astype(np.int64)
        cnt_rec = np.zeros([self.c_dim]).astype(np.int64)
        c_pos = np.zeros([self.c_dim])
        c_neg = np.zeros([self.c_dim])
        c_rec = np.zeros([self.c_dim])
        ca_req = np.zeros([self.c_dim]).astype(np.int64)
        cr_req = np.zeros([self.c_dim]).astype(np.int64)
        co_req = np.zeros([self.c_dim]).astype(np.int64)

        with torch.no_grad():
            with tl.session() as sess:
                tl.load_checkpoint(ckpt_file, sess)
                attr_list = ['Reconstruction'] + attr_list
                total_count = 0
                for i, (x_real, c_org) in enumerate(data_loader):

                    if self.dataset == 'RaFD':
                        c_org = self.label2onehot(c_org, self.c_dim)

                    # Prepare input images and target domain labels.
                    x_real = x_real.to(self.device)
                    c_trg_list = self.create_labels(c_org, self.c_dim,
                                                    self.dataset,
                                                    self.selected_attrs)
                    c_trg_batch = torch.cat(
                        [c.unsqueeze(1) for c in c_trg_list],
                        dim=1).cpu().numpy()
                    c_trg_list = [None] + [c_org.to(self.device)] + c_trg_list
                    att_gt_batch = c_org.numpy()

                    # Classify translate images.
                    pred_score_list = []
                    preds_list = []
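                    # j == 0 classifies the real input, j == 1 the reconstruction (original labels), j >= 2 the single-attribute translations.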
                    for j, c_trg in enumerate(c_trg_list):
                        if j == 0:
                            feed = np.transpose(x_real.cpu().numpy(),
                                                [0, 2, 3, 1])
                        else:
                            x_fake = self.G(x_real, c_trg)
                            feed = np.transpose(x_fake.cpu().numpy(),
                                                [0, 2, 3, 1])
                        pred_score = sess.run(pred_s, feed_dict={x: feed})
                        pred_score_list.append(
                            np.expand_dims(pred_score, axis=1))
                        if self.dataset == 'CelebA':
                            preds = np.round(pred_score).astype(int)
                        elif self.dataset == 'RaFD':
                            max_id = np.argmax(pred_score, axis=1)
                            preds = np.zeros_like(pred_score).astype(int)
                            preds[np.arange(pred_score.shape[0]), max_id] = 1
                        preds_list.append(np.expand_dims(preds, axis=1))
                    pred_score_batch = np.concatenate(pred_score_list, axis=1)
                    preds_opt_batch = np.concatenate(preds_list, axis=1)

                    # Calculate accuracy.
                    for pred_score, preds_opt, att_gt, c_trg in zip(
                            pred_score_batch, preds_opt_batch, att_gt_batch,
                            c_trg_batch):
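                        # Position 0 is the real image, 1 the reconstruction; position k >= 2 is the translation that edits attribute k - 2.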
                        for k in range(2, len(preds_opt)):
                            if c_trg[k - 2, k - 2] == 1 - att_gt[k - 2]:
                                if att_gt[k - 2] == 0:
                                    ca_req[k - 2] += 1
                                elif att_gt[k - 2] == 1:
                                    cr_req[k - 2] += 1

                                if preds_opt[k, k - 2] == 1 - att_gt[k - 2]:
                                    if preds_opt[k, k - 2] == 1:
                                        cnt_pos[k - 2] += 1
                                        c_pos[k - 2] += pred_score[k, k - 2]
                                    elif preds_opt[k, k - 2] == 0:
                                        cnt_neg[k - 2] += 1
                                        c_neg[k -
                                              2] += 1 - pred_score[k, k - 2]
                            else:
                                co_req[k - 2] += 1
                                if preds_opt[k, k - 2] == att_gt[k - 2]:
                                    cnt_rec[k - 2] += 1
                                    if preds_opt[k, k - 2] == 1:
                                        c_rec[k - 2] += pred_score[k, k - 2]
                                    elif preds_opt[k, k - 2] == 0:
                                        c_rec[k -
                                              2] += 1 - pred_score[k, k - 2]

                    total_count += x_real.shape[0]
                    if total_count % n_print == 0:
                        print('{} images classified.'.format(total_count))
                        print('\tAcc. Addition')
                        print('\t', cnt_pos / ca_req)
                        print('\t', np.mean(cnt_pos / ca_req))

                attr_cls_path = os.path.join(self.result_dir, 'GAN-test.txt')
                with open(attr_cls_path, 'w') as f:
                    f.write('Overall accuracy,{},average,{}\n'.format(
                        arr_2_str((cnt_pos + cnt_neg + cnt_rec) /
                                  (ca_req + cr_req + co_req)),
                        arr_2_str(
                            np.mean((cnt_pos + cnt_neg + cnt_rec) /
                                    (ca_req + cr_req + co_req)))))
        print('GAN-test accuracy: {}'.format(
            arr_2_str(
                np.mean((cnt_pos + cnt_neg + cnt_rec) /
                        (ca_req + cr_req + co_req)))))
Example #4
def test_train_on_fake(dataset, c_dim, result_dir, gpu_id, epoch_=200):
    img_size = 128
    ''' data '''
    if dataset == 'CelebA':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(2513of2513).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/celeba_tfrecord_test'
        test_data_pool = tl.TfrecordData(test_tfrecord_path, 18, shuffle=False)
    elif dataset == 'RaFD':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(112of112).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/rafd_test'
        test_data_pool = tl.TfrecordData(test_tfrecord_path,
                                         120,
                                         shuffle=False)
    ckpt_file = os.path.join(result_dir, ckpt_file)
    """ graphs """
    with tf.device('/gpu:{}'.format(gpu_id)):
        ''' models '''
        classifier = models.classifier
        ''' graph '''
        # inputs
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
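        # Rescale pixels from [0, 255] to [-1, 1] before feeding the classifier.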
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            label = tf.placeholder(tf.int64, shape=[None, c_dim])
        elif dataset == 'RaFD':
            label = tf.placeholder(tf.float32, shape=[None, c_dim])

        # classify
        logits = classifier(x, att_dim=c_dim, reuse=False, training=False)

        if dataset == 'CelebA':
            accuracy = mean_accuracy_multi_binary_label_with_logits(
                label, logits)
        elif dataset == 'RaFD':
            accuracy = mean_accuracy_one_hot_label_with_logits(label, logits)
    """ train """
    ''' init '''
    # session
    sess = tl.session()
    ''' initialization '''
    tl.load_checkpoint(ckpt_file, sess)
    ''' train '''
    try:
        all_accuracies = []
        denom = 18 if dataset == 'CelebA' else 120
        key = 'class' if dataset == 'CelebA' else 'attr'
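        # The tfrecords store their labels under 'class' (CelebA, multi-attribute) or 'attr' (RaFD, emotion index).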
        test_iter = len(test_data_pool) // denom
        for iter in range(test_iter):
            img, label_gt = test_data_pool.batch(['img', key])
            if dataset == 'RaFD':
                label_gt = ToOnehot(label_gt, c_dim)
            print('Test batch {}'.format(iter), end='\r')
            batch_accuracy = sess.run(accuracy,
                                      feed_dict={
                                          x_255: img,
                                          label: label_gt
                                      })
            all_accuracies.append(batch_accuracy)

        if dataset == 'CelebA':
            mean_accuracies = np.mean(np.concatenate(all_accuracies), axis=0)
            mean_accuracy = np.mean(mean_accuracies)
            print('\nIndividual accuracies: {} Average: {:.4f}'.format(
                mean_accuracies, mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                for attr, acc in zip([
                        'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male',
                        'Young'
                ], mean_accuracies):
                    f.write('{}: {}\n'.format(attr, acc))
                f.write('Average: {}'.format(mean_accuracy))
        elif dataset == 'RaFD':
            mean_accuracy = np.mean(all_accuracies)
            print('\nAverage accuracies: {:.4f}'.format(mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                f.write('Average accuracy: {}'.format(mean_accuracy))

    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
Example #5
def runModel(image_url, file_name, test_att, n_slide, image_labels,
             model_type):
    # ==============================================================================
    # =                                    param                                   =
    # ==============================================================================

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name',
                        dest='experiment_name',
                        default="384_shortcut1_inject1_none_hd",
                        help='experiment_name')
    parser.add_argument('--test_att', dest='test_att', help='test_att')
    parser.add_argument('--test_int_min',
                        dest='test_int_min',
                        type=float,
                        default=-1.0,
                        help='test_int_min')
    parser.add_argument('--test_int_max',
                        dest='test_int_max',
                        type=float,
                        default=1.0,
                        help='test_int_max')
    args_ = parser.parse_args()

    if model_type == 0:
        experiment_name = args_.experiment_name
    else:
        experiment_name = "128_custom"

    print("EXPERIMENT NAME WORKING:" + experiment_name)

    with open('./output/%s/setting.txt' % experiment_name) as f:
        args = json.load(f)

    # model
    atts = args['atts']
    n_att = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    dis_dim = args['dis_dim']
    dis_fc_dim = args['dis_fc_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']
    dis_layers = args['dis_layers']
    # testing
    thres_int = args['thres_int']
    test_int_min = args_.test_int_min
    test_int_max = args_.test_int_max
    # others
    use_cropped_img = args['use_cropped_img']
    n_slide = int(n_slide)

    assert test_att is not None, 'test_att should be chosen in %s' % (
        str(atts))

    # ==============================================================================
    # =                                   graphs                                   =
    # ==============================================================================

    # data
    sess = tl.session()

    # get image
    print(image_url)
    if experiment_name == "128_custom":
        os.system(
            "wget -P /home/tug44606/AttGAN-Tensorflow-master/data/img_align_celeba "
            + image_url)
    else:
        os.system(
            "wget -P /home/tug44606/AttGAN-Tensorflow-master/data/img_crop_celeba "
            + image_url)

    print("Working")

    # pass image with labels to dataset
    te_data = data.Celeba('./data',
                          atts,
                          img_size,
                          1,
                          part='val',
                          sess=sess,
                          crop=not use_cropped_img,
                          image_labels=image_labels,
                          file_name=file_name)

    sample = None

    # models
    Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
    Gdec = partial(models.Gdec,
                   dim=dec_dim,
                   n_layers=dec_layers,
                   shortcut_layers=shortcut_layers,
                   inject_layers=inject_layers)

    # inputs
    xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

    # sample
    x_sample = Gdec(Genc(xa_sample, is_training=False),
                    _b_sample,
                    is_training=False)

    # ==============================================================================
    # =                                    test                                    =
    # ==============================================================================

    # initialization
    ckpt_dir = './output/%s/checkpoints' % experiment_name
    print("CHECKPOINT DIR: " + ckpt_dir)
    try:
        tl.load_checkpoint(ckpt_dir, sess)
    except Exception:
        raise Exception(' [*] No checkpoint!')

    save_location = ""
    # sample
    try:
        for idx, batch in enumerate(te_data):
            xa_sample_ipt = batch[0]
            b_sample_ipt = batch[1]

            x_sample_opt_list = []

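            # Only the requested slider position is rendered: i = n_slide - 1 maps linearly to an intensity in [test_int_min, test_int_max] that overrides the chosen attribute.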
            for i in range(n_slide - 1, n_slide):
                test_int = (test_int_max -
                            test_int_min) / (n_slide - 1) * i + test_int_min
                _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                _b_sample_ipt[..., atts.index(test_att)] = test_int
                x_sample_opt_list.append(
                    sess.run(x_sample,
                             feed_dict={
                                 xa_sample: xa_sample_ipt,
                                 _b_sample: _b_sample_ipt
                             }))

            sample = np.concatenate(x_sample_opt_list, 2)
            save_location = '/output/%s/sample_testing_slide_%s/' % (
                experiment_name, test_att)
            save_dir = './output/%s/sample_testing_slide_%s' % (
                experiment_name, test_att)
            pylib.mkdir(save_dir)
            im.imwrite(sample.squeeze(0), '%s/%s' % (save_dir, file_name))

            print('%d.png done!' % idx)

            if (idx + 1 == te_data._img_num):
                break
    except Exception:
        traceback.print_exc()
    finally:
        sess.close()

    if experiment_name == "128_custom":
        os.system("rm ./data/img_align_celeba/" + file_name)
    else:
        os.system("rm ./data/img_crop_celeba/" + file_name)

    return "http://129.32.22.10:7001" + save_location + file_name
Example #6
    def test(self):

        img_size = self.config["imsize"]
        n_att = len(self.config["selectedAttrs"])
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        save_dir = self.config["projectSamples"]
        test_int = self.config["sampleThresInt"]
        # data
        sess = tl.session()

        SpecifiedImages = None
        if self.config["useSpecifiedImage"]:
            SpecifiedImages = self.config["specifiedTestImages"]
        te_data = Celeba(self.config["dataset_path"],
                         atts,
                         img_size,
                         1,
                         part='test',
                         sess=sess,
                         crop=(self.config["imCropSize"] > 0),
                         im_no=SpecifiedImages)

        # models
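        # The generator parts (Genc/Gdec) and the Gstu module are imported dynamically from the scripts named in the config.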
        package = __import__(self.config["com_base"] +
                             self.config["modelScriptName"],
                             fromlist=True)
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        package = __import__(self.config["com_base"] +
                             self.config["stuScriptName"],
                             fromlist=True)
        GstuClass = getattr(package, 'Gstu')

        Genc = partial(GencClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       multi_inputs=1)

        Gdec = partial(GdecClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       shortcut_layers=self.config["skipNum"],
                       inject_layers=self.config["injectLayers"],
                       one_more_conv=self.config["oneMoreConv"])

        Gstu = partial(GstuClass,
                       dim=self.config["stuDim"],
                       n_layers=self.config["skipNum"],
                       inject_layers=self.config["stuInjectLayers"],
                       kernel_size=self.config["stuKS"],
                       norm=None,
                       pass_state="stu")

        # inputs
        xa_sample = tf.placeholder(tf.float32,
                                   shape=[None, img_size, img_size, 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

        # sample
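        # The generator is conditioned on the attribute difference (target minus source) rather than the raw target vector.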
        test_label = _b_sample - raw_b_sample
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)
        print("Graph build success!")

        # ==============================================================================
        # =                                    test                                    =
        # ==============================================================================

        # Load pretrained model
        ckpt_dir = os.path.join(self.config["projectCheckpoints"],
                                self.config["ckpt_prefix"])
        restorer = tf.train.Saver()
        restorer.restore(sess, ckpt_dir)
        print("Load pretrained model successfully!")
        # tl.load_checkpoint(ckpt_dir, sess)

        # test
        print("Start to test!")
        test_slide = False
        multi_atts = False

        try:
            # multi_atts = test_atts is not None
            for idx, batch in enumerate(te_data):
                xa_sample_ipt = batch[0]
                a_sample_ipt = batch[1]
                b_sample_ipt_list = [
                    a_sample_ipt.copy()
                    for _ in range(n_slide if test_slide else 1)
                ]
                # if test_slide: # test_slide
                #     for i in range(n_slide):
                #         test_int = (test_int_max - test_int_min) / (n_slide - 1) * i + test_int_min
                #         b_sample_ipt_list[i] = (b_sample_ipt_list[i]*2-1) * thres_int
                #         b_sample_ipt_list[i][..., atts.index(test_att)] = test_int
                # elif multi_atts: # test_multiple_attributes
                #     for a in test_atts:
                #         i = atts.index(a)
                #         b_sample_ipt_list[-1][:, i] = 1 - b_sample_ipt_list[-1][:, i]
                #         b_sample_ipt_list[-1] = Celeba.check_attribute_conflict(b_sample_ipt_list[-1], atts[i], atts)
                # else: # test_single_attributes
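                # Flip each attribute in turn and let check_attribute_conflict fix incompatible attribute combinations.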
                for i in range(len(atts)):
                    tmp = np.array(a_sample_ipt, copy=True)
                    tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                    tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                    b_sample_ipt_list.append(tmp)

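                # The output strip starts with the input image followed by a thin spacer filled with -1 (black in [-1, 1] space).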
                x_sample_opt_list = [
                    xa_sample_ipt,
                    np.full((1, img_size, img_size // 10, 3), -1.0)
                ]
                raw_a_sample_ipt = a_sample_ipt.copy()
                raw_a_sample_ipt = (raw_a_sample_ipt * 2 - 1) * thres_int
                for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                    _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                    if not test_slide:
                        # if multi_atts: # i must be 0
                        #     for t_att, t_int in zip(test_atts, test_ints):
                        #         _b_sample_ipt[..., atts.index(t_att)] = _b_sample_ipt[..., atts.index(t_att)] * t_int
                        if i > 0:  # i == 0 is for reconstruction
                            _b_sample_ipt[...,
                                          i - 1] = _b_sample_ipt[..., i -
                                                                 1] * test_int
                    x_sample_opt_list.append(
                        sess.run(x_sample,
                                 feed_dict={
                                     xa_sample: xa_sample_ipt,
                                     _b_sample: _b_sample_ipt,
                                     raw_b_sample: raw_a_sample_ipt
                                 }))
                sample = np.concatenate(x_sample_opt_list, 2)

                # if test_slide:     save_folder = 'sample_testing_slide'
                # # elif multi_atts:   save_folder = 'sample_testing_multi'
                # else:              save_folder = 'sample_testing'
                # save_dir = './output/%s/%s' % (experiment_name, save_folder)

                # pylib.mkdir(save_dir)
                im.imwrite(
                    sample.squeeze(0), '%s/%06d%s.png' %
                    (save_dir, idx + 182638 if SpecifiedImages is None else
                     SpecifiedImages[idx], '_%s' %
                     (str("jibaride")) if multi_atts else ''))

                print('%06d.png done!' %
                      (idx + 182638
                       if SpecifiedImages is None else SpecifiedImages[idx]))
        except Exception:
            traceback.print_exc()
        finally:
            sess.close()