예제 #1
0
파일: solver.py 프로젝트: afrozalm/dtn_ext
class Solver(object):
    """Drive pretraining, adversarial training and sampling for this project's
    real-face -> caricature model.

    Real faces are the source domain and caricatures the target domain (see
    ``train``). The model object must provide ``build_model()`` plus the
    placeholders/ops referenced by each phase (``model.images``,
    ``model.train_op``, ``model.src_images``, ...). Which attributes exist
    presumably depends on how the model was configured -- TODO confirm
    against the model class.
    """

    def __init__(self,
                 model,
                 batch_size=100,
                 pretrain_iter=20000,
                 train_iter=2000,
                 sample_iter=100,
                 real_dir='real-face',
                 caric_dir='caricature-face',
                 combined_dir='class-combined.pkl',
                 log_dir='logs',
                 n_classes=200,
                 sample_save_path='sample',
                 model_save_path='model',
                 pretrained_model='model/pre_model-4000',
                 test_model='model/dtn_ext-400',
                 src_disc_rep=1,
                 src_gen_rep=1,
                 trg_disc_rep=1,
                 trg_gen_rep=1):
        """Store the configuration; no data is loaded here.

        Args:
            model: graph-building model object (see class docstring).
            batch_size: batch size for the ``DataLoader`` and pair sampling.
            pretrain_iter: number of pretraining steps (the loop runs
                ``pretrain_iter + 1`` steps).
            train_iter: number of training steps (the loop runs
                ``train_iter + 1`` steps).
            sample_iter: number of batches written per sampling round.
            real_dir: directory holding ``train.pkl``/``test.pkl`` of real faces.
            caric_dir: directory holding ``train.pkl``/``test.pkl`` of
                caricatures.
            combined_dir: path of the per-label pickle read by
                ``load_combined``.
            log_dir: summary directory; deleted and recreated by ``pretrain``
                and ``train``.
            n_classes: stored only; not used by this class.
            sample_save_path: directory for sampled PNG grids.
            model_save_path: directory for checkpoints.
            pretrained_model: checkpoint restored by ``pretrain`` (when not
                empty) and by ``train``.
            test_model: checkpoint restored by ``eval``.
            src_disc_rep, src_gen_rep: discriminator / generator updates per
                step on the source domain.
            trg_disc_rep, trg_gen_rep: discriminator / generator updates per
                step on the target domain.
        """
        self.loader = DataLoader(batch_size)
        self.model = model
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.pretrain_iter = pretrain_iter
        self.combined_dir = combined_dir
        self.train_iter = train_iter
        self.sample_iter = sample_iter
        self.real_dir = real_dir
        self.caric_dir = caric_dir
        self.log_dir = log_dir
        self.sample_save_path = sample_save_path
        self.model_save_path = model_save_path
        self.pretrained_model = pretrained_model
        self.test_model = test_model
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.src_disc_rep = src_disc_rep
        self.src_gen_rep = src_gen_rep
        self.trg_disc_rep = trg_disc_rep
        self.trg_gen_rep = trg_gen_rep

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _sample_items(items, k):
        """Return a list of ``k`` distinct elements of ``items``, chosen at random.

        Indices are sampled instead of ``items`` itself. ``random.sample``
        rejects numpy arrays (and, from Python 3.11, sets), so this keeps the
        code working on both Python 2 and 3.

        Raises:
            ValueError: if ``items`` has fewer than ``k`` elements.
        """
        # ``sample`` is assumed to be ``random.sample`` imported at file level.
        return [items[j] for j in sample(range(len(items)), k)]

    @staticmethod
    def _load_split(image_dir, split, name):
        """Load ``train.pkl`` (split == 'train') or ``test.pkl`` (anything else).

        The pickle is expected to contain a dict with an image array under
        ``'X'`` and labels under ``'y'``. Images are divided by their global
        mean and shifted by -1, so the mean value maps to 0.

        NOTE(review): ``pickle.load`` can execute arbitrary code; only load
        trusted files.
        """
        print('loading %s..' % name)
        image_file = 'train.pkl' if split == 'train' else 'test.pkl'
        path = os.path.join(image_dir, image_file)
        with open(path, 'rb') as f:
            data = pickle.load(f)
        images = data['X'] / np.mean(data['X']) - 1
        labels = data['y']
        print('finished loading %s..!' % name)
        return images, labels

    @staticmethod
    def _reset_dir(path):
        """Delete ``path`` if it exists, then recreate it empty."""
        if tf.gfile.Exists(path):
            tf.gfile.DeleteRecursively(path)
        tf.gfile.MakeDirs(path)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (with parents) unless it is empty or already exists."""
        if path and not tf.gfile.Exists(path):
            tf.gfile.MakeDirs(path)

    @staticmethod
    def _restore_scope(sess, scope, checkpoint):
        """Restore only the model variables under ``scope`` from ``checkpoint``."""
        variables = slim.get_model_variables(scope=scope)
        tf.train.Saver(variables).restore(sess, checkpoint)

    def _pair_feed(self, model, combined_images, label_set):
        """Sample a positive and then a negative pair batch; map them to placeholders."""
        pos_ones, pos_twos = self.get_pairs(combined_images, label_set,
                                            set_type='positive')
        neg_ones, neg_twos = self.get_pairs(combined_images, label_set,
                                            set_type='negative')
        return {model.pos_ones: pos_ones,
                model.pos_twos: pos_twos,
                model.neg_ones: neg_ones,
                model.neg_twos: neg_twos}

    def _pretrain_feed(self, model, group, combined_images, label_set):
        """Build a pretraining feed from loader group ``group`` plus fresh pairs."""
        batch_labels, batch_images = self.loader.next_group_batch(group)
        caric_batch = self.loader.next_batch('caric_faces')
        feed_dict = {model.images: batch_images,
                     model.labels: batch_labels,
                     model.caric_images: caric_batch}
        feed_dict.update(self._pair_feed(model, combined_images, label_set))
        return feed_dict

    def _save_samples(self, sess, input_ph, output_op):
        """Translate ``self.sample_iter`` batches of real faces and save PNG grids.

        Each batch comes from the loader's ``'real_images'`` dataset. It is fed
        to ``input_ph``, the result of ``output_op`` is tiled next to it with
        ``merge_images``, and the grid is written under ``sample_save_path``.
        """
        self._ensure_dir(self.sample_save_path)
        for i in range(self.sample_iter):
            batch_images = self.loader.next_batch('real_images')
            sampled_batch_images = sess.run(output_op, {input_ph: batch_images})
            merged = self.merge_images(batch_images, sampled_batch_images)
            path = os.path.join(
                self.sample_save_path, 'sample-%d-to-%d.png' %
                (i * self.batch_size, (i + 1) * self.batch_size))
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
            # requires an old SciPy.
            scipy.misc.imsave(path, merged)
            print('saved %s' % path)

    # ------------------------------------------------------------------
    # data loading
    # ------------------------------------------------------------------
    def load_real(self, image_dir, split='train'):
        """Load and normalize the real-face split; returns ``(images, labels)``."""
        return self._load_split(image_dir, split, 'real faces')

    def load_caric(self, image_dir, split='train'):
        """Load and normalize the caricature split; returns ``(images, labels)``."""
        return self._load_split(image_dir, split, 'caricature faces')

    def load_combined(self):
        """Load the per-label pickle at ``self.combined_dir`` and normalize it.

        The pickle is expected to map each label to a dict with ``'real'``
        and ``'caric'`` image arrays. Each array is divided by its own mean
        and shifted by -1.
        """
        print('loading combined images ...')
        with open(self.combined_dir, 'rb') as f:
            combined_imgs = pickle.load(f)
        for per_label in combined_imgs.values():
            for domain in ('real', 'caric'):
                per_label[domain] = per_label[domain] / np.mean(per_label[domain]) - 1
        print('finished loading combined_imgs')
        return combined_imgs

    # ------------------------------------------------------------------
    # batching utilities
    # ------------------------------------------------------------------
    def merge_images(self, sources, targets, k=10):
        """Tile (source, target) image pairs side by side into one grid image.

        The grid has ``row = int(sqrt(self.batch_size))`` rows of ``row`` pairs.
        Each pair is the source image followed by its target image, so the
        result has shape ``(row * h, row * 2 * w, 3)``. At most ``row * row``
        pairs fit.

        Args:
            sources: 4-D array ``(N, h, w, C)``.
            targets: images matching ``sources`` element-wise.
            k: unused; kept for backward compatibility.
        """
        _, h, w, _ = sources.shape
        row = int(np.sqrt(self.batch_size))
        merged = np.zeros([row * h, row * w * 2, 3])

        for idx, (s, t) in enumerate(zip(sources, targets)):
            i, j = divmod(idx, row)
            # Columns advance by the image *width* so non-square images tile
            # correctly.
            merged[i * h:(i + 1) * h, (2 * j) * w:(2 * j + 1) * w, :] = s
            merged[i * h:(i + 1) * h, (2 * j + 1) * w:(2 * j + 2) * w, :] = t
        return merged

    def get_pairs(self, combined_images, label_set, set_type='positive'):
        """Build a batch of same-label ('positive') or different-label pairs.

        Args:
            combined_images: mapping label -> {'real': images, 'caric': images},
                as returned by ``load_combined``.
            label_set: labels to draw from. ``self.batch_size`` distinct labels
                are drawn, so it must contain at least that many.
            set_type: ``'positive'`` for pairs sharing a label; any other value
                yields negative pairs (different labels).

        Returns:
            ``(firsts, seconds)``: two arrays of ``self.batch_size`` images.
        """
        labels = sorted(label_set)
        pick = self._sample_items

        def get_pos_pair(label):
            imgs = combined_images[label]
            toss = np.random.uniform()
            if toss < 0.5:
                # cross-domain pair: one real face, one caricature
                return [pick(imgs['real'], 1)[0], pick(imgs['caric'], 1)[0]]
            if toss > 0.75:
                # two caricatures; fall back to two real faces if too few
                try:
                    return pick(imgs['caric'], 2)
                except ValueError:
                    return pick(imgs['real'], 2)
            # two real faces; fall back to a cross-domain pair if too few
            try:
                return pick(imgs['real'], 2)
            except ValueError:
                return [pick(imgs['real'], 1)[0], pick(imgs['caric'], 1)[0]]

        def get_neg_pair(label):
            toss = np.random.uniform()
            neg_lbl = sample([lbl for lbl in labels if lbl != label], 1)[0]
            # Same domain mix as get_pos_pair: 50% cross-domain,
            # 25% caricature/caricature, 25% real/real.
            if toss < 0.5:
                img_domain, neg_domain = 'caric', 'real'
            elif toss > 0.75:
                img_domain, neg_domain = 'caric', 'caric'
            else:
                img_domain, neg_domain = 'real', 'real'
            neg_img = pick(combined_images[neg_lbl][neg_domain], 1)[0]
            img = pick(combined_images[label][img_domain], 1)[0]
            return [img, neg_img]

        perm = sample(labels, self.batch_size)
        make_pair = get_pos_pair if set_type == 'positive' else get_neg_pair
        firsts, seconds = zip(*[make_pair(lbl) for lbl in perm])
        return np.asarray(firsts), np.asarray(seconds)

    # ------------------------------------------------------------------
    # phases
    # ------------------------------------------------------------------
    def pretrain(self):
        """Pretrain on real + caricature faces with labels and pair batches.

        Resets ``log_dir`` and optionally restores ``self.pretrained_model``.
        Logs train/test accuracy every 10 steps and saves a ``pre_model``
        checkpoint every 1000 steps.
        """
        self._reset_dir(self.log_dir)

        train_images_r, train_labels_r = self.load_real(self.real_dir,
                                                        split='train')
        test_images_r, test_labels_r = self.load_real(self.real_dir,
                                                      split='test')
        train_images_c, train_labels_c = self.load_caric(self.caric_dir,
                                                         split='train')
        test_images_c, test_labels_c = self.load_caric(self.caric_dir,
                                                       split='test')
        train_images = np.vstack((train_images_r, train_images_c))
        train_labels = np.hstack((train_labels_r, train_labels_c))
        test_images = np.vstack((test_images_r, test_images_c))
        test_labels = np.hstack((test_labels_r, test_labels_c))
        combined_images = self.load_combined()

        label_set = set(np.hstack((train_labels, test_labels)))

        self.loader.add_dataset('caric_faces', train_images_c)
        self.loader.add_dataset('train_images', train_images)
        self.loader.add_dataset('test_images', test_images)
        self.loader.add_dataset('train_labels', train_labels)
        self.loader.add_dataset('test_labels', test_labels)
        self.loader.link_datasets(group_name='train',
                                  members=['train_labels', 'train_images'])
        self.loader.link_datasets(group_name='test',
                                  members=['test_labels', 'test_images'])

        model = self.model
        model.build_model()
        self._ensure_dir(self.model_save_path)

        with tf.Session(config=self.config) as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            if self.pretrained_model != '':
                print('loading pretrained model ..')
                saver.restore(sess, self.pretrained_model)
            summary_writer = tf.summary.FileWriter(
                logdir=self.log_dir, graph=tf.get_default_graph())

            for step in range(self.pretrain_iter + 1):
                feed_dict = self._pretrain_feed(model, 'train',
                                                combined_images, label_set)
                sess.run(model.train_op, feed_dict)

                if (step + 1) % 10 == 0:
                    summary, l, acc = sess.run(
                        [model.summary_op, model.loss, model.accuracy],
                        feed_dict)
                    test_feed = self._pretrain_feed(model, 'test',
                                                    combined_images, label_set)
                    test_acc, _ = sess.run(
                        fetches=[model.accuracy, model.loss],
                        feed_dict=test_feed)
                    summary_writer.add_summary(summary, step)
                    print(
                        'Step: [%d/%d] loss: [%.6f] train acc: [%.2f] test acc [%.2f]'
                        % (step + 1, self.pretrain_iter, l, acc, test_acc))

                if (step + 1) % 1000 == 0:
                    saver.save(sess,
                               os.path.join(self.model_save_path, 'pre_model'),
                               global_step=step + 1)
                    print('pre_model-%d saved..!' % (step + 1))

    def train(self):
        """Adversarially train on real faces (source S) and caricatures (target T).

        Restores the ``content_extractor`` and ``generator`` scopes from
        ``self.pretrained_model``. Each step runs the source discriminator and
        generator updates, an occasional ``f_train_op_src`` update, and the
        target discriminator and generator updates. Logs every 10 steps, saves
        a ``dtn_ext`` checkpoint every 200 steps, and writes sample grids every
        1000 steps.
        """
        real_images, real_labels = self.load_real(self.real_dir, split='train')
        caric_images, caric_labels = self.load_caric(self.caric_dir,
                                                     split='train')
        combined_images = self.load_combined()
        label_set = set(np.hstack((real_labels, caric_labels)))

        self.loader.add_dataset('real_images', real_images)
        self.loader.add_dataset('caric_images', caric_images)
        self.loader.add_dataset('real_labels', real_labels)
        self.loader.add_dataset('caric_labels', caric_labels)

        model = self.model
        model.build_model()

        self._reset_dir(self.log_dir)
        self._ensure_dir(self.model_save_path)

        with tf.Session(config=self.config) as sess:
            # initialize everything, then overwrite F and G from the checkpoint
            tf.global_variables_initializer().run()
            print('loading pretrained model F..')
            self._restore_scope(sess, 'content_extractor', self.pretrained_model)
            print('loading pretrained model G..')
            self._restore_scope(sess, 'generator', self.pretrained_model)

            summary_writer = tf.summary.FileWriter(
                logdir=self.log_dir, graph=tf.get_default_graph())
            saver = tf.train.Saver()

            print('start training..!')
            # Batch position within a pass over the real images; it decides
            # when F is updated. max(1, ...) avoids a modulo by zero when there
            # are fewer images than one batch.
            batches_per_pass = max(1, real_images.shape[0] // self.batch_size)
            f_interval = 15
            for step in range(self.train_iter + 1):
                batch_idx = step % batches_per_pass
                src_images = self.loader.next_batch('real_images')
                # The same pairs feed both the source and the target updates.
                pair_feed = self._pair_feed(model, combined_images, label_set)

                # ---- source domain S ----
                src_feed = {model.src_images: src_images}
                src_feed.update(pair_feed)
                for _ in range(self.src_disc_rep):
                    sess.run(model.d_train_op_src, src_feed)
                for _ in range(self.src_gen_rep):
                    sess.run([model.g_train_op_src], src_feed)

                # F is updated less often later in training.
                if step > 1600:
                    f_interval = 30
                if batch_idx % f_interval == 0:
                    sess.run(model.f_train_op_src, src_feed)

                if (step + 1) % 10 == 0:
                    summary, dl, gl, fl = sess.run([
                        model.summary_op_src, model.d_loss_src,
                        model.g_loss_src, model.f_loss_src
                    ], src_feed)
                    summary_writer.add_summary(summary, step)
                    print(
                        '[Source] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f] f_loss: [%.6f]'
                        % (step + 1, self.train_iter, dl, gl, fl))

                # ---- target domain T ----
                trg_images = self.loader.next_batch('caric_images')
                trg_feed = {model.src_images: src_images,
                            model.trg_images: trg_images}
                trg_feed.update(pair_feed)
                for _ in range(self.trg_disc_rep):
                    sess.run(model.d_train_op_trg, trg_feed)
                for _ in range(self.trg_gen_rep):
                    sess.run(model.g_train_op_trg, trg_feed)

                if (step + 1) % 10 == 0:
                    summary, dl, gl = sess.run([
                        model.summary_op_trg, model.d_loss_trg,
                        model.g_loss_trg
                    ], trg_feed)
                    summary_writer.add_summary(summary, step)
                    print(
                        '[Target] step: [%d/%d] d_loss: [%.6f] g_loss: [%.6f]'
                        % (step + 1, self.train_iter, dl, gl))

                if (step + 1) % 200 == 0:
                    saver.save(sess,
                               os.path.join(self.model_save_path, 'dtn_ext'),
                               global_step=step + 1)
                    print('model/dtn_ext-%d saved' % (step + 1))

                if (step + 1) % 1000 == 0:
                    self._save_samples(sess, model.src_images, model.fake_images)

    def eval(self):
        """Restore F and G from ``self.test_model`` and write sample grids.

        NOTE(review): this feeds ``model.images`` and fetches
        ``model.sampled_images``, while ``train`` uses ``model.src_images`` /
        ``model.fake_images``. Confirm that the model defines these attributes
        in eval mode.
        """
        model = self.model
        model.build_model()

        real_images, _ = self.load_real(self.real_dir)
        self.loader.add_dataset(name='real_images', data_ptr=real_images)

        with tf.Session(config=self.config) as sess:
            tf.global_variables_initializer().run()
            print('loading pretrained model F..')
            self._restore_scope(sess, 'content_extractor', self.test_model)
            print('loading pretrained model G..')
            self._restore_scope(sess, 'generator', self.test_model)

            print('start sampling..!')
            self._save_samples(sess, model.images, model.sampled_images)
예제 #2
0
class Solver(object):
    """Drive pretraining, training and sampling for this project's cycle model
    (real faces <-> caricatures).

    ``train`` restores the ``Gen_Real2Caric``, ``Gen_Caric2Real`` and
    ``classifier`` scopes from a pretrained checkpoint, then alternates
    discriminator and generator updates. The model object must provide
    ``build_model()`` plus the placeholders/ops referenced in each phase
    (``model.real_images``, ``model.c_base``, ``model.disc_op``, ...).
    """

    def __init__(self, model, batch_size=100, pretrain_iter=20000,
                 train_iter=2000, sample_iter=100,
                 real_dir='real-face', caric_dir='caricature-face',
                 combined_dir='class-combined.pkl',
                 log_dir='logs', n_classes=200,
                 sample_save_path='sample',
                 model_save_path='model',
                 pretrained_model='model/pre_model-4000',
                 test_model='model/cycle-400',
                 disc_rep=1,
                 gen_rep=1):
        """Store the configuration; no data is loaded here.

        Args:
            model: graph-building model object (see class docstring).
            batch_size: batch size for the ``DataLoader`` and pair sampling.
            pretrain_iter: pretraining steps (the loop runs ``+ 1`` steps).
            train_iter: training steps (the loop runs ``+ 1`` steps).
            sample_iter: number of batches written per sampling round.
            real_dir: directory holding ``train.pkl``/``test.pkl`` of real faces.
            caric_dir: directory holding ``train.pkl``/``test.pkl`` of
                caricatures.
            combined_dir: path of the per-label pickle read by
                ``load_combined``.
            log_dir: summary directory; deleted and recreated by ``pretrain``
                and ``train``.
            n_classes: stored only; not used by this class.
            sample_save_path: directory for sampled PNG grids.
            model_save_path: directory for checkpoints.
            pretrained_model: checkpoint restored by ``pretrain`` and ``train``
                when not empty.
            test_model: checkpoint restored by ``eval``.
            disc_rep: discriminator updates per training step.
            gen_rep: generator updates per training step.
        """
        self.loader = DataLoader(batch_size)
        self.model = model
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.pretrain_iter = pretrain_iter
        self.combined_dir = combined_dir
        self.train_iter = train_iter
        self.sample_iter = sample_iter
        self.real_dir = real_dir
        self.caric_dir = caric_dir
        self.log_dir = log_dir
        self.sample_save_path = sample_save_path
        self.model_save_path = model_save_path
        self.pretrained_model = pretrained_model
        self.test_model = test_model
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.disc_rep = disc_rep
        self.gen_rep = gen_rep

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _sample_items(items, k):
        """Return a list of ``k`` distinct elements of ``items``, chosen at random.

        Indices are sampled instead of ``items`` itself. ``random.sample``
        rejects numpy arrays (and, from Python 3.11, sets), so this keeps the
        code working on both Python 2 and 3.

        Raises:
            ValueError: if ``items`` has fewer than ``k`` elements.
        """
        # ``sample`` is assumed to be ``random.sample`` imported at file level.
        return [items[j] for j in sample(range(len(items)), k)]

    @staticmethod
    def _load_split(image_dir, split, name):
        """Load ``train.pkl`` (split == 'train') or ``test.pkl`` (anything else).

        The pickle is expected to contain a dict with an image array under
        ``'X'`` and labels under ``'y'``. Images are divided by their global
        mean and shifted by -1, so the mean value maps to 0.

        NOTE(review): ``pickle.load`` can execute arbitrary code; only load
        trusted files.
        """
        print('loading %s..' % name)
        image_file = 'train.pkl' if split == 'train' else 'test.pkl'
        path = os.path.join(image_dir, image_file)
        with open(path, 'rb') as f:
            data = pickle.load(f)
        images = data['X'] / np.mean(data['X']) - 1
        labels = data['y']
        print('finished loading %s..!' % name)
        return images, labels

    @staticmethod
    def _reset_dir(path):
        """Delete ``path`` if it exists, then recreate it empty."""
        if tf.gfile.Exists(path):
            tf.gfile.DeleteRecursively(path)
        tf.gfile.MakeDirs(path)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (with parents) unless it is empty or already exists."""
        if path and not tf.gfile.Exists(path):
            tf.gfile.MakeDirs(path)

    @staticmethod
    def _feed_dict(model, real_labels, real_images, caric_labels,
                   caric_images, pairs):
        """Map one batch of faces plus the six ``get_pairs`` arrays to placeholders."""
        cb, cp, cn, rb, rp, rn = pairs
        return {model.real_images: real_images,
                model.real_labels: real_labels,
                model.caric_images: caric_images,
                model.caric_labels: caric_labels,
                model.c_base: cb,
                model.c_pos: cp,
                model.c_neg: cn,
                model.r_base: rb,
                model.r_pos: rp,
                model.r_neg: rn}

    def _pretrain_feed(self, model, suffix, combined_images, label_set):
        """Build a feed from groups ``caric_<suffix>`` / ``real_<suffix>`` + fresh pairs."""
        caric_labels, caric_images = \
            self.loader.next_group_batch('caric_' + suffix)
        real_labels, real_images = \
            self.loader.next_group_batch('real_' + suffix)
        pairs = self.get_pairs(combined_images, label_set)
        return self._feed_dict(model, real_labels, real_images,
                               caric_labels, caric_images, pairs)

    def _save_samples(self, sess, input_ph, output_op):
        """Translate ``self.sample_iter`` batches of real faces and save PNG grids.

        Each batch comes from the loader's ``'real_images'`` dataset. It is fed
        to ``input_ph``, the result of ``output_op`` is tiled next to it with
        ``merge_images``, and the grid is written under ``sample_save_path``.
        """
        self._ensure_dir(self.sample_save_path)
        for i in range(self.sample_iter):
            batch_images = self.loader.next_batch('real_images')
            sampled_batch_images = sess.run(output_op, {input_ph: batch_images})
            merged = self.merge_images(batch_images, sampled_batch_images)
            path = os.path.join(self.sample_save_path,
                                'sample-%d-to-%d.png' %
                                (i * self.batch_size,
                                 (i + 1) * self.batch_size))
            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
            # requires an old SciPy.
            scipy.misc.imsave(path, merged)
            print('saved %s' % path)

    # ------------------------------------------------------------------
    # data loading
    # ------------------------------------------------------------------
    def load_real(self, image_dir, split='train'):
        """Load and normalize the real-face split; returns ``(images, labels)``."""
        return self._load_split(image_dir, split, 'real faces')

    def load_caric(self, image_dir, split='train'):
        """Load and normalize the caricature split; returns ``(images, labels)``."""
        return self._load_split(image_dir, split, 'caricature faces')

    def load_combined(self):
        """Load the per-label pickle at ``self.combined_dir`` and normalize it.

        The pickle is expected to map each label to a dict with ``'real'``
        and ``'caric'`` image arrays. Each array is divided by its own mean
        and shifted by -1.
        """
        print('loading combined images ...')
        with open(self.combined_dir, 'rb') as f:
            combined_imgs = pickle.load(f)
        for per_label in combined_imgs.values():
            for domain in ('real', 'caric'):
                per_label[domain] = per_label[domain] / np.mean(per_label[domain]) - 1
        print('finished loading combined_imgs')
        return combined_imgs

    # ------------------------------------------------------------------
    # batching utilities
    # ------------------------------------------------------------------
    def get_pairs(self, combined_images, label_set):
        """Sample base/positive/negative triplets for caricatures and real faces.

        For each of ``self.batch_size`` distinct labels, this draws:
        two caricatures of that label (base, positive); one caricature of
        another label (negative); and the same triplet structure for real
        faces. The two negative labels are distinct. Every domain of every
        drawn label needs at least two images, otherwise ``ValueError``.

        Returns:
            List of six arrays ``[c_base, c_pos, c_neg, r_base, r_pos, r_neg]``,
            each holding ``self.batch_size`` images.
        """
        labels = sorted(label_set)
        pick = self._sample_items

        def get_base_pos_neg(label):
            c_base, c_pos = pick(combined_images[label]['caric'], 2)
            r_base, r_pos = pick(combined_images[label]['real'], 2)
            c_neg_lbl, r_neg_lbl = sample(
                [lbl for lbl in labels if lbl != label], 2)
            c_neg = pick(combined_images[c_neg_lbl]['caric'], 1)[0]
            r_neg = pick(combined_images[r_neg_lbl]['real'], 1)[0]
            return [c_base, c_pos, c_neg,
                    r_base, r_pos, r_neg]

        perm = sample(labels, self.batch_size)
        all_pairs = zip(*[get_base_pos_neg(lbl) for lbl in perm])
        # A real list (not a lazy ``map``) so callers may index it on Python 3.
        return [np.asarray(group) for group in all_pairs]

    def merge_images(self, sources, targets, k=10):
        """Tile (source, target) image pairs side by side into one grid image.

        The grid has ``row = int(sqrt(self.batch_size))`` rows of ``row`` pairs.
        Each pair is the source image followed by its target image, so the
        result has shape ``(row * h, row * 2 * w, 3)``. At most ``row * row``
        pairs fit.

        Args:
            sources: 4-D array ``(N, h, w, C)``.
            targets: images matching ``sources`` element-wise.
            k: unused; kept for backward compatibility.
        """
        _, h, w, _ = sources.shape
        row = int(np.sqrt(self.batch_size))
        merged = np.zeros([row * h, row * w * 2, 3])

        for idx, (s, t) in enumerate(zip(sources, targets)):
            i, j = divmod(idx, row)
            # Columns advance by the image *width* so non-square images tile
            # correctly.
            merged[i * h:(i + 1) * h, (2 * j) * w:(2 * j + 1) * w, :] = s
            merged[i * h:(i + 1) * h, (2 * j + 1) * w:(2 * j + 2) * w, :] = t
        return merged

    # ------------------------------------------------------------------
    # phases
    # ------------------------------------------------------------------
    def pretrain(self):
        """Pretrain on labelled real/caricature faces plus triplet batches.

        Resets ``log_dir`` and optionally restores ``self.pretrained_model``.
        Logs train/test accuracy every 10 steps and saves a ``pre_model``
        checkpoint every 1000 steps.
        """
        self._reset_dir(self.log_dir)

        train_images_r, train_labels_r = self.load_real(self.real_dir,
                                                        split='train')
        test_images_r, test_labels_r = self.load_real(self.real_dir,
                                                      split='test')
        train_images_c, train_labels_c = self.load_caric(self.caric_dir,
                                                         split='train')
        test_images_c, test_labels_c = self.load_caric(self.caric_dir,
                                                       split='test')
        train_labels = np.hstack((train_labels_r, train_labels_c))
        test_labels = np.hstack((test_labels_r, test_labels_c))
        combined_images = self.load_combined()
        label_set = set(np.hstack((train_labels, test_labels)))

        # Register linked (labels, images) groups caric_tr, real_tr,
        # caric_te and real_te.
        splits = (('tr', train_images_c, train_labels_c,
                   train_images_r, train_labels_r),
                  ('te', test_images_c, test_labels_c,
                   test_images_r, test_labels_r))
        for suffix, imgs_c, lbls_c, imgs_r, lbls_r in splits:
            self.loader.add_dataset('caric_faces_' + suffix, imgs_c)
            self.loader.add_dataset('caric_labels_' + suffix, lbls_c)
            self.loader.add_dataset('real_faces_' + suffix, imgs_r)
            self.loader.add_dataset('real_labels_' + suffix, lbls_r)
            self.loader.link_datasets('caric_' + suffix,
                                      ['caric_labels_' + suffix,
                                       'caric_faces_' + suffix])
            self.loader.link_datasets('real_' + suffix,
                                      ['real_labels_' + suffix,
                                       'real_faces_' + suffix])

        model = self.model
        model.build_model()
        self._ensure_dir(self.model_save_path)

        with tf.Session(config=self.config) as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            if self.pretrained_model != '':
                print('loading pretrained model ..')
                saver.restore(sess, self.pretrained_model)
            summary_writer = tf.summary.FileWriter(
                logdir=self.log_dir, graph=tf.get_default_graph())

            for step in range(self.pretrain_iter + 1):
                feed_dict = self._pretrain_feed(model, 'tr',
                                                combined_images, label_set)
                sess.run(model.train_op, feed_dict)

                if (step + 1) % 10 == 0:
                    summary, l, acc = sess.run([model.summary_op,
                                                model.loss, model.accuracy],
                                               feed_dict)
                    test_feed = self._pretrain_feed(model, 'te',
                                                    combined_images, label_set)
                    test_acc, _ = \
                        sess.run(fetches=[model.accuracy, model.loss],
                                 feed_dict=test_feed)
                    summary_writer.add_summary(summary, step)
                    print('Step: [%d/%d] loss: [%.6f] train acc: [%.2f] test acc [%.2f]'
                          % (step + 1, self.pretrain_iter, l, acc, test_acc))

                if (step + 1) % 1000 == 0:
                    saver.save(sess, os.path.join(self.model_save_path,
                                                  'pre_model'),
                               global_step=step + 1)
                    print('pre_model-%d saved..!' % (step + 1))

    def train(self):
        """Train with alternating discriminator/generator updates.

        When ``self.pretrained_model`` is not empty, the ``Gen_Real2Caric``,
        ``Gen_Caric2Real`` and ``classifier`` scopes are restored from it.
        Logs every 10 steps, saves a ``cycle`` checkpoint every 200 steps, and
        writes real -> caricature sample grids every 2000 steps.
        """
        self._reset_dir(self.log_dir)

        real_images, real_labels = self.load_real(self.real_dir,
                                                  split='train')
        caric_images, caric_labels = self.load_caric(self.caric_dir,
                                                     split='train')

        combined_images = self.load_combined()
        label_set = set(np.hstack((real_labels, caric_labels)))

        self.loader.add_dataset('real_images', real_images)
        self.loader.add_dataset('caric_images', caric_images)
        self.loader.add_dataset('real_labels', real_labels)
        self.loader.add_dataset('caric_labels', caric_labels)
        self.loader.link_datasets('real', ['real_labels', 'real_images'])
        self.loader.link_datasets('caric', ['caric_labels', 'caric_images'])

        model = self.model
        model.build_model()
        self._ensure_dir(self.model_save_path)

        with tf.Session(config=self.config) as sess:
            tf.global_variables_initializer().run()

            # restore the pretrained sub-networks scope by scope
            if self.pretrained_model != '':
                pretrained_scopes = ['Gen_Real2Caric', 'Gen_Caric2Real',
                                     'classifier']
                print('loading pretrained model ..')
                for scope in pretrained_scopes:
                    restorer = tf.train.Saver(
                        slim.get_model_variables(scope=scope))
                    restorer.restore(sess, self.pretrained_model)

            summary_writer = tf.summary.FileWriter(
                logdir=self.log_dir, graph=tf.get_default_graph())
            saver = tf.train.Saver()

            print('start training..!')

            for step in range(self.train_iter + 1):
                real_labels, real_images = \
                    self.loader.next_group_batch('real')
                caric_labels, caric_images = \
                    self.loader.next_group_batch('caric')
                pairs = self.get_pairs(combined_images, label_set)
                feed_dict = self._feed_dict(model, real_labels, real_images,
                                            caric_labels, caric_images, pairs)

                for _ in range(self.disc_rep):
                    sess.run(model.disc_op, feed_dict)
                for _ in range(self.gen_rep):
                    sess.run(model.gen_op, feed_dict)

                if (step + 1) % 10 == 0:
                    summary, discl, gl = \
                        sess.run([model.summary_op,
                                  model.loss_disc,
                                  model.loss_gen],
                                 feed_dict)
                    summary_writer.add_summary(summary, step)
                    print('[Source] step: [%d/%d] disc_loss: [%.6f] gen_loss: [%.6f]'
                          % (step + 1, self.train_iter, discl, gl))

                if (step + 1) % 200 == 0:
                    saver.save(sess, os.path.join(
                        self.model_save_path, 'cycle'), global_step=step + 1)
                    print('model/cycle-%d saved' % (step + 1))

                if (step + 1) % 2000 == 0:
                    self._save_samples(sess, model.real_images,
                                       model.fake_caric)

    def eval(self):
        """Restore all variables from ``self.test_model`` and write sample grids.

        NOTE(review): this feeds ``model.images`` and fetches
        ``model.sampled_images``, while ``train`` uses ``model.real_images`` /
        ``model.fake_caric``. Confirm that the model defines these attributes.
        """
        model = self.model
        model.build_model()

        real_images, _ = self.load_real(self.real_dir)
        self.loader.add_dataset(name='real_images',
                                data_ptr=real_images)

        with tf.Session(config=self.config) as sess:
            print('loading test model..')
            saver = tf.train.Saver()
            saver.restore(sess, self.test_model)

            print('start sampling..!')
            self._save_samples(sess, model.images, model.sampled_images)