Example #1
# Imports needed by this example; the project-specific helpers
# (Visualizer, create_dataloader, create_model) are assumed to come from the
# surrounding repository and are not shown here. imageio is imported lazily
# inside test_save_imgs when GIF output is requested.
import os
import time

import numpy as np
import torch
from PIL import Image


class Solver(object):
    """Drives training and AU-interpolation testing of the face editing model."""
    def __init__(self):
        super(Solver, self).__init__()

    def initialize(self, opt):
        self.opt = opt
        self.visual = Visualizer()
        self.visual.initialize(self.opt)

    def run_solver(self):
        if self.opt.mode == "train":
            self.train_networks()
        else:
            self.test_networks(self.opt)

    def train_networks(self):
        # init train setting
        self.init_train_setting()

        # for every epoch
        for epoch in range(self.opt.epoch_count, self.epoch_len + 1):
            # train network
            self.train_epoch(epoch)
            # update learning rate
            self.cur_lr = self.train_model.update_learning_rate()
            # save checkpoint if needed
            if epoch % self.opt.save_epoch_freq == 0:
                self.train_model.save_ckpt(epoch)

        # save the last epoch
        self.train_model.save_ckpt(self.epoch_len)

    def init_train_setting(self):
        self.train_dataset = create_dataloader(self.opt)
        self.train_model = create_model(self.opt)

        self.train_total_steps = 0
        self.epoch_len = self.opt.niter + self.opt.niter_decay
        self.cur_lr = self.opt.lr

    def train_epoch(self, epoch):
        epoch_start_time = time.time()
        epoch_steps = 0

        last_print_step_t = time.time()
        for idx, batch in enumerate(self.train_dataset):

            self.train_total_steps += self.opt.batch_size
            epoch_steps += self.opt.batch_size
            # train network
            self.train_model.feed_batch(batch)
            self.train_model.optimize_paras(
                train_gen=(idx % self.opt.train_gen_iter == 0))
            # print losses
            if self.train_total_steps % self.opt.print_losses_freq == 0:
                cur_losses = self.train_model.get_latest_losses()
                avg_step_t = (time.time() -
                              last_print_step_t) / self.opt.print_losses_freq
                last_print_step_t = time.time()
                # print loss info to command line
                info_dict = {
                    'epoch': epoch,
                    'epoch_len': self.epoch_len,
                    'epoch_steps': idx * self.opt.batch_size,
                    'epoch_steps_len': len(self.train_dataset),
                    'step_time': avg_step_t,
                    'cur_lr': self.cur_lr,
                    'log_path': os.path.join(self.opt.ckpt_dir,
                                             self.opt.log_file),
                    'losses': cur_losses
                }
                self.visual.print_losses_info(info_dict)

            # plot loss map to visdom
            if self.train_total_steps % self.opt.plot_losses_freq == 0 and self.visual.display_id > 0:
                cur_losses = self.train_model.get_latest_losses()
                epoch_steps = idx * self.opt.batch_size
                self.visual.display_current_losses(
                    epoch - 1, epoch_steps / len(self.train_dataset),
                    cur_losses)

            # display image on visdom
            if self.train_total_steps % self.opt.sample_img_freq == 0 and self.visual.display_id > 0:
                cur_vis = self.train_model.get_latest_visuals()
                self.visual.display_online_results(cur_vis, epoch)
                # latest_aus = model.get_latest_aus()
                # visual.log_aus(epoch, epoch_steps, latest_aus, opt.ckpt_dir)

    def test_networks(self, opt):
        self.init_test_setting(opt)
        self.test_ops()

    def init_test_setting(self, opt):
        self.test_dataset = create_dataloader(opt)
        self.test_model = create_model(opt)

    def test_ops(self):
        for batch_idx, batch in enumerate(self.test_dataset):
            with torch.no_grad():
                # interpolate several times
                faces_list = [batch['src_img'].float().numpy()]
                paths_list = [batch['src_path'], batch['tar_path']]
                for idx in range(self.opt.interpolate_len):
                    cur_alpha = (idx + 1.) / float(self.opt.interpolate_len)
                    cur_tar_aus = cur_alpha * batch['tar_aus'] + (
                        1 - cur_alpha) * batch['src_aus']
                    # print(batch['src_aus'])
                    # print(cur_tar_aus)
                    test_batch = {
                        'src_img': batch['src_img'],
                        'tar_aus': cur_tar_aus,
                        'src_aus': batch['src_aus'],
                        'tar_img': batch['tar_img']
                    }

                    self.test_model.feed_batch(test_batch)
                    self.test_model.forward()

                    cur_gen_faces = self.test_model.fake_img.cpu().float().numpy()
                    faces_list.append(cur_gen_faces)
                faces_list.append(batch['tar_img'].float().numpy())
            self.test_save_imgs(faces_list, paths_list)

    def test_save_imgs(self, faces_list, paths_list):
        for idx in range(len(paths_list[0])):
            src_name = os.path.splitext(os.path.basename(
                paths_list[0][idx]))[0]
            tar_name = os.path.splitext(os.path.basename(
                paths_list[1][idx]))[0]

            if self.opt.save_test_gif:
                import imageio
                imgs_numpy_list = []
                # skip the last element of faces_list (the target image)
                for face_idx in range(len(faces_list) - 1):
                    cur_numpy = np.array(
                        self.visual.numpy2im(faces_list[face_idx][idx]))
                    # each frame is repeated three times in the output GIF
                    imgs_numpy_list.extend([cur_numpy for _ in range(3)])
                saved_path = os.path.join(self.opt.results,
                                          "%s_%s.gif" % (src_name, tar_name))
                imageio.mimsave(saved_path, imgs_numpy_list)
            else:
                # concatenate source, interpolated, and target faces into one strip
                concat_img = np.array(self.visual.numpy2im(faces_list[0][idx]))
                for face_idx in range(1, len(faces_list)):
                    concat_img = np.concatenate(
                        (concat_img,
                         np.array(self.visual.numpy2im(faces_list[face_idx][idx]))),
                        axis=1)
                concat_img = Image.fromarray(concat_img)
                # save the concatenated image
                saved_path = os.path.join(self.opt.results,
                                          "%s_%s.jpg" % (src_name, tar_name))
                concat_img.save(saved_path)

            print("[Success] Saved images to %s" % saved_path)
Example #2
# Imports needed by this example; f1_score is assumed to be
# sklearn.metrics.f1_score, and the project-specific helpers
# (Visualizer, create_dataloader, create_model) are assumed to come from the
# surrounding repository and are not shown here.
import os
import time

import numpy as np
import torch
from sklearn.metrics import f1_score


class Solver(object):
    """Drives training and per-AU F1-score evaluation of an AU prediction model."""
    def __init__(self):
        super(Solver, self).__init__()

    def initialize(self, opt):
        self.opt = opt
        self.visual = Visualizer()
        self.visual.initialize(self.opt)

    def run_solver(self):
        if self.opt.mode == "train":
            self.train_networks()
        else:
            self.test_networks()

    def train_networks(self):
        # init train setting
        self.init_train_setting()

        # for every epoch
        for epoch in range(self.opt.epoch_count, self.epoch_len + 1):
            # train network
            self.train_epoch(epoch)
            # update learning rate
            self.cur_lr = self.train_model.update_learning_rate()
            # save checkpoint if needed
            if epoch % self.opt.save_epoch_freq == 0:
                self.train_model.save_ckpt(epoch)

        # save the last epoch
        self.train_model.save_ckpt(self.epoch_len)

    def init_train_setting(self):
        self.train_dataset = create_dataloader(self.opt)
        self.train_model = create_model(self.opt)

        self.train_total_steps = 0
        self.epoch_len = self.opt.niter + self.opt.niter_decay
        self.cur_lr = self.opt.lr

    def train_epoch(self, epoch):
        epoch_start_time = time.time()
        epoch_steps = 0

        last_print_step_t = time.time()
        for idx, batch in enumerate(self.train_dataset):

            self.train_total_steps += self.opt.batch_size
            epoch_steps += self.opt.batch_size
            # train network
            self.train_model.feed_batch(batch)
            self.train_model.optimize_paras(
                train_recog=(idx % self.opt.train_recog_iter == 0),
                train_dis=(idx % self.opt.train_dis_iter == 0))
            # print losses
            if self.train_total_steps % self.opt.print_losses_freq == 0:
                cur_losses = self.train_model.get_latest_losses()
                avg_step_t = (time.time() -
                              last_print_step_t) / self.opt.print_losses_freq
                last_print_step_t = time.time()
                # print loss info to command line
                info_dict = {
                    'epoch': epoch,
                    'epoch_len': self.epoch_len,
                    'epoch_steps': idx * self.opt.batch_size,
                    'epoch_steps_len': len(self.train_dataset),
                    'step_time': avg_step_t,
                    'cur_lr': self.cur_lr,
                    'log_path': os.path.join(self.opt.ckpt_dir,
                                             self.opt.log_file),
                    'losses': cur_losses
                }
                self.visual.print_losses_info(info_dict)

            # plot loss map to visdom
            if self.train_total_steps % self.opt.plot_losses_freq == 0 and self.visual.display_id > 0:
                cur_losses = self.train_model.get_latest_losses()
                epoch_steps = idx * self.opt.batch_size
                self.visual.display_current_losses(
                    epoch - 1, epoch_steps / len(self.train_dataset),
                    cur_losses)

            # display image on visdom
            if self.train_total_steps % self.opt.sample_img_freq == 0 and self.visual.display_id > 0:
                cur_vis = self.train_model.get_latest_visuals()
                self.visual.display_online_results(cur_vis, epoch)

    def test_networks(self):
        self.init_test_setting()
        self.test_ops()
        self.cal_f1_scores()

    def init_test_setting(self):
        self.test_dataset = create_dataloader(self.opt)
        self.test_model = create_model(self.opt)
        self.aus_id_list = [
            "AU%02d" % int(x) for x in self.opt.aus_id.split(',')
        ]

    def test_ops(self):
        real_aus_list = []
        pred_aus_list = []
        for batch_idx, batch in enumerate(self.test_dataset):
            with torch.no_grad():

                self.test_model.feed_batch(batch)
                self.test_model.forward()

                real_aus_list.extend(
                    list(self.test_model.img_aus.cpu().float().numpy()))
                pred_aus_list.extend(
                    list(self.test_model.gen_aus.cpu().float().numpy()))

                print(
                    ">>> %d/%d" %
                    (batch_idx * self.opt.batch_size, len(self.test_dataset)))

        self.real_aus = np.array(real_aus_list)
        self.pred_aus = np.array(pred_aus_list)

    def cal_f1_scores(self):
        f1_scores_list = []
        for idx in range(self.opt.aus_nc):
            cur_real_aus = self.real_aus[:, idx].flatten().astype(int)
            cur_pred_aus_raw = self.pred_aus[:, idx].flatten()
            # binarize the sigmoid outputs with a 0.5 threshold
            cur_pred_aus = (cur_pred_aus_raw > 0.5).astype(int)

            print(">>> %s" % self.aus_id_list[idx])
            print("cur_real_aus", cur_real_aus.tolist())
            # print("cur_pred_aus_raw", cur_pred_aus_raw)
            print("cur_pred_aus", cur_pred_aus.tolist(), "\n")

            f1_scores_list.append(
                f1_score(cur_real_aus, cur_pred_aus, average='micro'))

        for k, v in zip(self.aus_id_list, f1_scores_list):
            print("%s: %f" % (k, v))

        print("Avg : %f" % (sum(f1_scores_list) / len(f1_scores_list)))

        # log the result to files
        with open(os.path.join(self.opt.result_dir, "results.csv"), 'a+') as f:
            f.write(
                "%s, %s\n" %
                (self.opt.ckpt_dir, ", ".join([str(x)
                                               for x in f1_scores_list])))
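
A minimal driver sketch for the evaluation path of this example, assuming the same hypothetical parse_opts() helper; in test mode the solver prints per-AU F1 scores and appends them to results.csv under opt.result_dir.

# Hypothetical entry point; parse_opts() is a placeholder, not part of the code above.
if __name__ == '__main__':
    opt = parse_opts()
    opt.mode = "test"       # any value other than "train" runs the test path
    solver = Solver()
    solver.initialize(opt)
    solver.run_solver()     # test_ops() collects AU predictions, cal_f1_scores() reports F1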