Exemplo n.º 1
0
 def __init__(self):
     """Parse training options, build the model and TensorBoard visualizer,
     then run training if training data could be obtained."""
     self.opt = TrainOptions().parse()
     self.model = ModelsFactory.get_by_name(self.opt)
     self.tb_visualizer = TBVisualizer(self.opt)
     # Only set up the datasets and start training when data is available.
     if self.get_training_data():
         self.setup_train_test_sets()
         self.train()
Exemplo n.º 2
0
    def __init__(self):
        """Parse training options, build train/test dataloaders over the
        occluded-face image lists, then build the model/visualizer and train."""
        self._opt = TrainOptions().parse()

        # Identical preprocessing for train and test: resize to 128, convert
        # to tensor, normalize each channel to [-1, 1].
        self.train_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5])
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5])
        ])

        # The loader mixes two image lists (presumably without/with glasses,
        # judging by the wo_g/w_g option names — TODO confirm) plus occlusion
        # assets; training is shuffled and drops the incomplete final batch.
        self._dataset_train = torch_data.DataLoader(dataset=OccFaceImageMixLoaderV2(
                                            self._opt.data_root, 
                                            self._opt.train_list_wo_g_list,
                                            self._opt.occlusions_root,
                                            self._opt.occlusions_list,
                                            self._opt.data_root, 
                                            self._opt.train_list_w_g_list,
                                            transform=self.train_transform),
                                            batch_size=self._opt.batch_size, 
                                            num_workers = self._opt.n_threads_train,
                                            shuffle=True, drop_last=True)

        self._dataset_test = torch_data.DataLoader(dataset=OccFaceImageMixLoaderV2(
                                            self._opt.data_root, 
                                            self._opt.test_list_wo_g_list,
                                            self._opt.occlusions_root,
                                            self._opt.occlusions_list,
                                            self._opt.data_root, 
                                            self._opt.test_list_w_g_list,
                                            transform=self.test_transform),
                                            batch_size=self._opt.batch_size,
                                            drop_last=True)

        # NOTE(review): len(DataLoader) is the number of *batches*, not images,
        # so the prints below under-report the dataset size by a batch factor.
        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)

        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
Exemplo n.º 3
0
    def __init__(self):
        """Build train/test dataloaders, the model and TB visualizer, then train."""
        self._opt = TrainOptions().parse()
        # Loaders are kept on self (unlike sibling variants that use locals).
        self.data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        self.data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)
        self._dataset_train = self.data_loader_train.load_data()
        self._dataset_test = self.data_loader_test.load_data()

        self._dataset_train_size = len(self.data_loader_train)
        self._dataset_test_size = len(self.data_loader_test)
        print('#train images = %s' % self._dataset_train_size)
        print('#test images = %s' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
Exemplo n.º 4
0
    def __init__(self):
        """Parse options, wire up train/val dataloaders and the model, then train."""
        self._opt = TrainOptions().parse()

        # Assign the loaders straight to their attributes; len() of a
        # DataLoader counts batches, which is what the prints report.
        self._dataset_train = get_dataloader(self._opt.data_dir,
                                             img_size=self._opt.image_size,
                                             selected_attrs=self._opt.selected_attrs,
                                             mode='train',
                                             batch_size=self._opt.batch_size)
        self._dataset_test = get_dataloader(self._opt.data_dir,
                                            img_size=self._opt.image_size,
                                            selected_attrs=self._opt.selected_attrs,
                                            mode='val',
                                            batch_size=self._opt.batch_size)

        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)
        print('#train image batches = %d' % self._dataset_train_size)
        print('#test image batches = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
Exemplo n.º 5
0
    def __init__(self):
        """Load train/test video-clip data, build the Impersonator model, train."""
        self._opt = TrainOptions().parse()

        train_loader = CustomDatasetDataLoader(self._opt, is_for_train=True)
        test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = train_loader.load_data()
        self._dataset_test = test_loader.load_data()

        # len() of the custom loader reports how many clips it serves.
        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(test_loader)
        print('#train video clips = %d' % self._dataset_train_size)
        print('#test video clips = %d' % self._dataset_test_size)

        self._model = Impersonator(self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
Exemplo n.º 6
0
    def __init__(self):
        """Build model and visualizer first, then train/val dataloaders, and train."""
        self._opt = TrainOptions().parse()

        # This variant constructs the model before the dataloaders.
        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        data_loader_train = CustomDatasetDataLoader(self._opt, mode='train')
        data_loader_val = CustomDatasetDataLoader(self._opt, mode='val')
        # Debug alternative: point both loaders at the test split.
        #data_loader_train = CustomDatasetDataLoader(self._opt, mode='test')
        #data_loader_val = CustomDatasetDataLoader(self._opt, mode='test')

        self._dataset_train = data_loader_train.load_data()
        self._dataset_val = data_loader_val.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_val_size = len(data_loader_val)
        print('#train images = %d' % self._dataset_train_size)
        print('#val images = %d' % self._dataset_val_size)

        self._train()
Exemplo n.º 7
0
    def __init__(self):
        """Parse options, build dataloaders/model/visualizers plus empty image
        buffers for accumulated visual results, then start training."""
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt,
                                                    is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt,
                                                   is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)
        print('TRAIN IMAGES FOLDER = %s' %
              data_loader_train._dataset._imgs_dir)
        print('TEST IMAGES FOLDER = %s' % data_loader_test._dataset._imgs_dir)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)
        self._writer = SummaryWriter()

        # Empty (0, 3, H, W) image buffers. Built in one loop instead of
        # repeating the identical torch.empty(...) call seven times; the
        # attribute names are unchanged so downstream code is unaffected.
        img_buffer_shape = (0, 3, self._opt.image_size, self._opt.image_size)
        for buf_name in ('_input_imgs', '_fake_imgs', '_rec_real_imgs',
                         '_fake_imgs_unmasked', '_fake_imgs_mask',
                         '_rec_real_imgs_mask', '_cyc_imgs_unmasked'):
            setattr(self, buf_name, torch.empty(*img_buffer_shape))

        self._real_conds = list()
        self._desired_conds = list()

        self._train()
Exemplo n.º 8
0
    def __init__(self):
        """Standard training entry point: dataloaders, model, visualizer, train."""
        self._opt = TrainOptions().parse()

        train_loader = CustomDatasetDataLoader(self._opt, is_for_train=True)
        test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False)
        self._dataset_train = train_loader.load_data()
        self._dataset_test = test_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(test_loader)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()
Exemplo n.º 9
0
    def __init__(self):
        """Set up per-finger joint-angle limits and vertex-index lists for the
        hand mesh, load the test dataset and model, then run a single
        visualization pass over the test set."""

        # Reference snippet for regenerating/inspecting the pose clusters:
        # clusters_pose_map, clusters_rot_map, clusters_root_rot = self.get_rot_map(self._model.clusters_tensor, torch.zeros((25, 3)).cuda())
        # for i in range(25):
        #     import matplotlib.pyplot
        #     from mpl_toolkits.mplot3d import Axes3D
        #     ax = matplotlib.pyplot.figure().add_subplot(111, projection='3d')
        #     add_group_meshs(ax, cluster_verts[i].cpu().data.numpy(), hand_faces, c='b')
        #     cam_equal_aspect_3d(ax, cluster_verts[i].cpu().data.numpy())
        #     print(i)
        #     matplotlib.pyplot.pause(1)
        #     matplotlib.pyplot.close()

        # Finger limit angles; the trailing slice comments give the pose-vector
        # indices each limit applies to. Earlier candidate values for the thumb,
        # kept for reference:
        #   [1.0222, 0.0996, 0.7302], [1.2030, 0.12, 0.25], [1.2, -0.4, 0.25]
        self.limit_bigfinger = torch.FloatTensor([1.2, -0.6, 0.25])  # 36:39
        self.limit_index = torch.FloatTensor([-0.0827, -0.4389, 1.5193])  # 0:3
        self.limit_middlefinger = torch.FloatTensor(
            [-2.9802e-08, -7.4506e-09, 1.4932e+00])  # 9:12
        self.limit_fourth = torch.FloatTensor([0.1505, 0.3769,
                                               1.5090])  # 27:30
        self.limit_small = torch.FloatTensor([-0.6235, 0.0275,
                                              1.0519])  # 18:21
        if torch.cuda.is_available():
            self.limit_bigfinger = self.limit_bigfinger.cuda()
            self.limit_index = self.limit_index.cuda()
            self.limit_middlefinger = self.limit_middlefinger.cuda()
            self.limit_fourth = self.limit_fourth.cuda()
            self.limit_small = self.limit_small.cuda()

        # The thumb vertices form one contiguous index range (697..768), so
        # build the list with range() instead of spelling out all 72 indices.
        self._bigfinger_vertices = list(range(697, 769))

        # The remaining fingers' vertex sets are NOT contiguous, so they stay
        # as explicit lists.
        self._indexfinger_vertices = [
            46, 47, 48, 49, 56, 57, 58, 59, 86, 87, 133, 134, 155, 156, 164,
            165, 166, 167, 174, 175, 189, 194, 195, 212, 213, 221, 222, 223,
            224, 225, 226, 237, 238, 272, 273, 280, 281, 282, 283, 294, 295,
            296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308,
            309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321,
            322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334,
            335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347,
            348, 349, 350, 351, 352, 353, 354, 355
        ]

        self._middlefinger_vertices = [
            356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 372,
            373, 374, 375, 376, 377, 381, 382, 385, 386, 387, 388, 389, 390,
            391, 392, 393, 394, 395, 396, 397, 398, 400, 401, 402, 403, 404,
            405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417,
            418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430,
            431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
            444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456,
            457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467
        ]

        self._fourthfinger_vertices = [
            468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 482,
            483, 484, 485, 486, 487, 491, 492, 495, 496, 497, 498, 499, 500,
            501, 502, 503, 504, 505, 506, 507, 508, 511, 512, 513, 514, 515,
            516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528,
            529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541,
            542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554,
            555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567,
            568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578
        ]

        self._smallfinger_vertices = [
            580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 598,
            599, 600, 601, 602, 603, 609, 610, 613, 614, 615, 616, 617, 618,
            619, 620, 621, 622, 623, 624, 625, 626, 628, 629, 630, 631, 632,
            633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645,
            646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658,
            659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671,
            672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684,
            685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695
        ]

        self._opt = TestOptions().parse()
        #assert self._opt.load_epoch > 0, 'Use command --load_epoch to indicate the epoch you want to load - and choose a trained model'

        # Process a single image per batch. (The previous comment claimed
        # batch size 2, contradicting the code; the code's value of 1 stands.)
        self._opt.batch_size = 1

        self._opt.n_threads_train = self._opt.n_threads_test
        data_loader_test = CustomDatasetDataLoader(self._opt, mode='test')
        self._dataset_test = data_loader_test.load_data()
        self._dataset_test_size = len(data_loader_test)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        # Single visualization pass over the whole test set.
        self._total_steps = self._dataset_test_size
        self._display_visualizer_test(20, self._total_steps)
Exemplo n.º 10
0
class Train:
    """Training driver: runs epochs over a train dataloader, periodically
    printing/plotting progress, validating on the test set, and saving the
    model both per-epoch and on a wall-clock interval."""

    def __init__(self):
        """Parse options, build dataloaders/model/visualizer, and start training."""
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Run the full epoch loop; save after every epoch and decay the
        learning rate once past the no-decay phase."""
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        # True division: iters per epoch may be fractional; used for display only.
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(self._opt.load_epoch + 1, self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' % (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                  (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay, time_epoch,
                   time_epoch / 60, time_epoch / 3600))

            # update learning rate once the fixed-LR phase is over
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Train one epoch, interleaving terminal prints, TB visuals, and
        wall-clock-interval checkpointing."""
        epoch_iter = 0
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags: visuals on a wall-clock cadence, and terminal
            # prints on their own cadence or whenever visuals fire
            do_visuals = self._last_display_time is None or time.time() - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time() - self._last_print_time > self._opt.print_freq_s or do_visuals

            # train model; generator trains every N iterations (or when visuals
            # are needed so its outputs are fresh)
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch+1) % self._opt.train_G_every_n_iterations == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals, train_generator=train_generator)

            # update epoch info
            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch, do_visuals)
                self._last_print_time = time.time()

            # display visualizer (train curves + a validation pass)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save "latest" checkpoint on a wall-clock interval
            if self._last_save_latest_time is None or time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' % (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch, visuals_flag):
        """Print current training errors with per-sample timing."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch, self._iters_per_epoch, errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Push current training visuals, errors, and scalars to TensorBoard."""
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(), total_steps, is_train=True)

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate up to num_iters_validate batches of the test set, average
        the errors, and push them (plus visuals of the first batch) to TB."""
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()

        # accumulate errors over at most num_iters_validate batches
        val_errors = OrderedDict()
        for i_val_batch, val_batch in enumerate(self._dataset_test):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model (keep visuals only for the first batch)
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            errors = self._model.get_current_errors()

            # store current batch errors
            # (fixed: dict.iteritems() is Python 2 only; .items() works everywhere)
            for k, v in errors.items():
                if k in val_errors:
                    val_errors[k] += v
                else:
                    val_errors[k] = v

        # normalize errors
        # (fixed: dict.iterkeys() is Python 2 only; iterate the dict directly)
        for k in val_errors:
            val_errors[k] /= self._opt.num_iters_validate

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors, total_steps, is_train=False)
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=False)

        # set model back to train
        self._model.set_train()
Exemplo n.º 11
0
class Train:
    """Training driver variant with validation disabled (the test-loader code
    is commented out): trains, periodically saves, and plots train-side
    visuals only."""

    def __init__(self):
        """Parse options, build the train dataloader, model and visualizer,
        then start training."""
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt,
                                                    is_for_train=True)
        #data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        #self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        #self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        #print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Run the full epoch loop; save after every epoch and decay the
        learning rate once past the no-decay phase."""
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        # True division: may be fractional; used only for progress display.
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train one epoch
            self._train_epoch(i_epoch)

            # save model at the end of the epoch
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60, time_epoch / 3600))

            # update learning rate once the fixed-LR phase is over
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Train one epoch, interleaving terminal prints, TB visuals, and
        wall-clock-interval checkpointing."""
        epoch_iter = 0
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags: visuals on a wall-clock cadence, terminal prints
            # on their own cadence or whenever visuals fire
            do_visuals = self._last_display_time is None or time.time(
            ) - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time(
            ) - self._last_print_time > self._opt.print_freq_s or do_visuals

            # train model; generator trains every N iterations (or when
            # visuals are needed so its outputs are fresh)
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update epoch info
            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer (validation pass is disabled in this variant)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                #self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save "latest" checkpoint on a wall-clock interval
            if self._last_save_latest_time is None or time.time(
            ) - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print current training errors with per-sample timing."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Push current training visuals, errors, and scalars to TensorBoard."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)
Exemplo n.º 12
0
    def __init__(self):
        """Parse options, build face-parsing/landmark train and val
        dataloaders, build the model and visualizer, then start training."""
        self._opt = TrainOptions().parse()
        self.face_train_list = self._opt.face_train_list
        self.face_test_list = self._opt.face_test_list

        self.face_img_root = self._opt.face_img_root
        self.face_parsing_root = self._opt.face_parsing_root
        self.face_landmark_train = self._opt.face_landmark_train
        self.face_landmark_test = self._opt.face_landmark_test

        # Fixed: img_size and scale_factor were each assigned twice with the
        # identical value; a single assignment of each suffices.
        self.img_size = self._opt.img_size
        self.scale_factor = self._opt.scale_factor
        self.heatmap_size = self._opt.heatmap_size

        # ToTensor only; normalization is deliberately left disabled.
        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self._dataset_train = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_train,
                                    self.face_train_list,
                                    transform=self.train_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='train',
                                    upsample=self._opt.upsample),
            batch_size=self._opt.batch_size,
            num_workers=self._opt.n_threads_train,
            shuffle=True,
            drop_last=True)

        self._dataset_test = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_test,
                                    self.face_test_list,
                                    transform=self.test_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='val',
                                    upsample=self._opt.upsample),
            batch_size=10,  # fixed evaluation batch size
            num_workers=self._opt.n_threads_test,
            shuffle=False,
            drop_last=True)

        # NOTE(review): len(DataLoader) counts batches, not images, so the
        # prints below under-report the dataset sizes by a batch factor.
        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)

        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        # Validation log lives next to the checkpoints.
        self._save_path = os.path.join(self._opt.checkpoints_dir,
                                       self._opt.name)
        self._val_path = os.path.join(self._save_path, 'val_log.txt')

        self._train()
Exemplo n.º 13
0
class Train:
    """Training driver for a face super-resolution / parsing model.

    Builds train/val DataLoaders over ``FaceImageLoader``, instantiates the
    model and TensorBoard visualizer, then runs the entire training loop
    directly from the constructor.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()

        self.face_train_list = self._opt.face_train_list
        self.face_test_list = self._opt.face_test_list

        self.face_img_root = self._opt.face_img_root
        self.face_parsing_root = self._opt.face_parsing_root
        self.face_landmark_train = self._opt.face_landmark_train
        self.face_landmark_test = self._opt.face_landmark_test

        # NOTE(review): img_size and scale_factor were each assigned twice
        # in the original; a single assignment is equivalent.
        self.img_size = self._opt.img_size
        self.scale_factor = self._opt.scale_factor
        self.heatmap_size = self._opt.heatmap_size

        # Normalization is intentionally disabled; kept for reference.
        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5],
            #                      std=[0.5, 0.5, 0.5])
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.5, 0.5, 0.5],
            #                      std=[0.5, 0.5, 0.5])
        ])

        self._dataset_train = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_train,
                                    self.face_train_list,
                                    transform=self.train_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='train',
                                    upsample=self._opt.upsample),
            batch_size=self._opt.batch_size,
            num_workers=self._opt.n_threads_train,
            shuffle=True,
            drop_last=True)

        # Validation uses a fixed batch size of 10 and no shuffling so the
        # visualized batches stay comparable across epochs.
        self._dataset_test = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_test,
                                    self.face_test_list,
                                    transform=self.test_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='val',
                                    upsample=self._opt.upsample),
            batch_size=10,
            num_workers=self._opt.n_threads_test,
            shuffle=False,
            drop_last=True)

        # NOTE(review): len(DataLoader) counts batches, not images, so the
        # prints below actually report batch counts — confirm intent.
        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)

        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._save_path = os.path.join(self._opt.checkpoints_dir,
                                       self._opt.name)
        self._val_path = os.path.join(self._save_path, 'val_log.txt')

        self._train()

    def _train(self):
        """Run the epoch loop: validate (from epoch 2 on), train, save, decay LR."""
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):

            # Validate before every epoch except the very first one.
            if i_epoch > 1:
                epoch_val_start_time = time.time()
                self._val_epoch(i_epoch)
                epoch_val_end_time = time.time()
                time_epoch = epoch_val_end_time - epoch_val_start_time
                print(
                    'End of epoch %d / %d \t Val Time Taken: %d sec (%d min or %d h)'
                    % (i_epoch,
                       self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                       time_epoch, time_epoch / 60.0, time_epoch / 3600.0))

            epoch_start_time = time.time()
            self._train_epoch(i_epoch)
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)
            time_epoch = time.time() - epoch_start_time

            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60.0, time_epoch / 3600.0))

            # Learning-rate decay kicks in only after the no-decay phase.
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _val_epoch(self, i_epoch):
        """Run one pass over the validation set, pushing visuals to TensorBoard.

        The original also extracted ``point``, ``batch_img_fine`` and
        ``batch_img_SR`` into unused locals; that dead code is removed.
        """
        self._model.set_eval()

        for i_val_batch, val_batch in enumerate(self._dataset_test):
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=True)
            visuals = self._model.get_current_visuals()

            self._tb_visualizer.display_current_results(visuals,
                                                        i_epoch,
                                                        i_val_batch,
                                                        is_train=False,
                                                        save_visuals=True)
        self._model.set_train()

    def _train_epoch(self, i_epoch):
        """Train for one epoch, periodically printing, visualizing and saving."""
        epoch_iter = 0
        self._model.set_train()

        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # Time-based (not iteration-based) display/print throttling.
            do_visuals = (self._last_display_time is None) or \
                         (time.time() - self._last_display_time > self._opt.display_freq_s)
            do_print_terminal = time.time(
            ) - self._last_print_time > self._opt.print_freq_s or do_visuals

            self._model.set_input(train_batch)
            # G is updated every N iterations; also forced on when visuals
            # are produced so generator outputs are fresh.
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(train_generator,
                                            keep_data_for_visuals=do_visuals)

            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            if do_visuals:
                self._display_visualizer_train(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # Periodic "latest" checkpoint, throttled by wall-clock time.
            if self._last_save_latest_time is None or \
                    time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print current training errors plus the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, i_epoch, total_steps):
        """Send current visuals, errors and scalars to TensorBoard."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            i_epoch,
            total_steps,
            is_train=True,
            save_visuals=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)
Exemplo n.º 14
0
class Train:
    """GAN training driver.

    Builds train/test data loaders, the model and visualizers, accumulates
    image batches intended for embedding visualization, and runs the whole
    training loop from the constructor.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt,
                                                    is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt,
                                                   is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)
        print('TRAIN IMAGES FOLDER = %s' %
              data_loader_train._dataset._imgs_dir)
        print('TEST IMAGES FOLDER = %s' % data_loader_test._dataset._imgs_dir)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)
        self._writer = SummaryWriter()

        # Growing (N, 3, H, W) buffers, one per visual category; batches are
        # appended in _display_visualizer_train.
        self._input_imgs = torch.empty(0, 3, self._opt.image_size,
                                       self._opt.image_size)
        self._fake_imgs = torch.empty(0, 3, self._opt.image_size,
                                      self._opt.image_size)
        self._rec_real_imgs = torch.empty(0, 3, self._opt.image_size,
                                          self._opt.image_size)
        self._fake_imgs_unmasked = torch.empty(0, 3, self._opt.image_size,
                                               self._opt.image_size)
        self._fake_imgs_mask = torch.empty(0, 3, self._opt.image_size,
                                           self._opt.image_size)
        self._rec_real_imgs_mask = torch.empty(0, 3, self._opt.image_size,
                                               self._opt.image_size)
        self._cyc_imgs_unmasked = torch.empty(0, 3, self._opt.image_size,
                                              self._opt.image_size)
        self._real_conds = list()
        self._desired_conds = list()

        self._train()

    def _train(self):
        """Epoch loop: train, checkpoint, report timing, decay learning rate."""
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60, time_epoch / 3600))

            # update learning rate after the no-decay phase
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Train for one epoch with time-throttled printing/visualization."""
        epoch_iter = 0
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags (wall-clock throttled)
            do_visuals = self._last_display_time is None or time.time(
            ) - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time(
            ) - self._last_print_time > self._opt.print_freq_s or do_visuals

            # train model; G updates every N iterations, forced when visuals
            # are produced so generator outputs are fresh
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update epoch info
            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save model (wall-clock throttled)
            if self._last_save_latest_time is None or time.time(
            ) - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print current training errors plus the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    @staticmethod
    def _append_visual(buffer, img_hwc):
        """Append one HWC image (as float32 CHW) to a (N, 3, H, W) buffer.

        Returns the new, longer tensor; torch.cat does not mutate in place.
        """
        chw = np.transpose(img_hwc, (2, 0, 1)).astype(np.float32)
        return torch.cat((buffer, torch.from_numpy(chw).unsqueeze(0)), dim=0)

    def _display_visualizer_train(self, total_steps):
        """Accumulate the current visuals into the embedding buffers.

        BUG FIX: the original called torch.cat but discarded its return
        value, so the buffers never grew; the results are now assigned back.
        """
        visuals = self._model.get_current_visuals()

        self._input_imgs = self._append_visual(self._input_imgs,
                                               visuals['1_input_img'])
        self._fake_imgs = self._append_visual(self._fake_imgs,
                                              visuals['2_fake_img'])
        self._rec_real_imgs = self._append_visual(self._rec_real_imgs,
                                                  visuals['3_rec_real_img'])
        self._fake_imgs_unmasked = self._append_visual(
            self._fake_imgs_unmasked, visuals['4_fake_img_unmasked'])
        self._fake_imgs_mask = self._append_visual(self._fake_imgs_mask,
                                                   visuals['5_fake_img_mask'])
        self._rec_real_imgs_mask = self._append_visual(
            self._rec_real_imgs_mask, visuals['6_rec_real_img_mask'])
        self._cyc_imgs_unmasked = self._append_visual(
            self._cyc_imgs_unmasked, visuals['7_cyc_img_unmasked'])

        self._real_conds.append(visuals['8_real_cond'].tolist())
        self._desired_conds.append(visuals['9_desired_cond'].tolist())

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate a few validation batches and push averaged errors/visuals.

        BUG FIX: the original used dict.iteritems()/iterkeys(), which do not
        exist in Python 3 and raised AttributeError at runtime.
        """
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()

        # evaluate self._opt.num_iters_validate batches
        val_errors = OrderedDict()
        for i_val_batch, val_batch in enumerate(self._dataset_test):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model; keep visuals only for the first batch
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            errors = self._model.get_current_errors()

            # store current batch errors
            for k, v in errors.items():
                if k in val_errors:
                    val_errors[k] += v
                else:
                    val_errors[k] = v

        # normalize errors
        for k in val_errors.keys():
            val_errors[k] /= self._opt.num_iters_validate

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(
            i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors,
                                         total_steps,
                                         is_train=False)
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(), total_steps, is_train=False)

        # set model back to train
        self._model.set_train()
Exemplo n.º 15
0
class Train:
    """Training driver that downloads its training data on demand.

    If the expected data directories are missing (or nearly empty), the user
    is prompted to download them; only then are the datasets set up and the
    training loop started.
    """

    def __init__(self):
        self.opt = TrainOptions().parse()
        self.model = ModelsFactory.get_by_name(self.opt)
        self.tb_visualizer = TBVisualizer(self.opt)
        if self.get_training_data():
            self.setup_train_test_sets()
            self.train()

    def get_training_data(self):
        """Ensure training data exists locally, offering to download it.

        Returns True when data is present (or was downloaded), False when
        the user declines the download. Any answer other than 'y'/'n'
        re-prompts.
        """
        if not os.path.isdir("data/train_data/") or len(
                os.listdir("data/train_data/")) == 0 or not os.path.isdir(
                    "data/meshes/") or len(os.listdir("data/meshes/")) == 1:
            while True:
                download_data = input(
                    "Training data does not exist. Want to download it (y/n)? "
                )
                if download_data == "y":
                    # Renamed from 'dir', which shadowed the builtin.
                    data_dir = os.path.dirname(
                        os.path.realpath(__file__)) + "/data/"
                    print(
                        "Downloading training data. This will take some time so just sit back and relax."
                    )
                    # NOTE(review): shell=True with a single command string;
                    # the script path is derived from __file__, not user input.
                    subprocess.Popen(
                        [data_dir + 'download_train_data.sh %s' % data_dir],
                        shell=True).wait()
                    print(
                        "Done downloading training data. Continuing with training."
                    )
                    return True
                elif download_data == "n":
                    print(
                        "You chose to not download the data. Terminating training"
                    )
                    return False
        else:
            print("Training data exists. Proceeding with training.")
            return True

    def setup_train_test_sets(self):
        """Create train and validation data loaders and report their sizes."""
        data_loader_train = CustomDatasetDataLoader(self.opt, mode='train')
        self.dataset_train = data_loader_train.load_data()
        self.dataset_train_size = len(data_loader_train)
        print('#train images = %d' % self.dataset_train_size)
        data_loader_val = CustomDatasetDataLoader(self.opt, mode='test')
        self.dataset_val = data_loader_val.load_data()
        self.dataset_val_size = len(data_loader_val)
        print('#val images = %d' % self.dataset_val_size)

    def train(self):
        """Run the epoch loop: train, report, decay LR, checkpoint, visualize."""
        # Here we set the start epoch. It is nonzero only if we continue train or test as the epoch saved for the network
        # we load is used
        start_epoch = self.model.get_epoch()
        self.total_steps = start_epoch * self.dataset_train_size
        self.iters_per_epoch = self.dataset_train_size / self.opt.batch_size
        self.last_display_time = None
        self.last_display_time_val = None
        self.last_save_latest_time = None
        self.last_print_time = time.time()
        # Visuals are collected twice per epoch (see train_epoch).
        self.visuals_per_batch = self.iters_per_epoch // 2
        for i_epoch in range(
                start_epoch,
                self.opt.nepochs_no_decay + self.opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train epoch
            self.train_epoch(i_epoch)
            # print epoch info
            self.print_epoch_info(time.time() - epoch_start_time, i_epoch)
            # update learning rate
            self.update_learning_rate(i_epoch)
            # save model
            self.model.save("latest", i_epoch + 1)
            # NOTE(review): this passes the epoch number where the parameter
            # is named total_steps — confirm against the visualizer's x-axis.
            self.display_visualizer_train(i_epoch)
            if (i_epoch) % 5 == 0:  # Only test the network every fifth epoch
                self.test(i_epoch, self.total_steps)

    def print_epoch_info(self, time_epoch, epoch_num):
        """Print elapsed wall-clock time for one finished epoch."""
        print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
              (epoch_num, self.opt.nepochs_no_decay + self.opt.nepochs_decay,
               time_epoch, time_epoch / 60, time_epoch / 3600))

    def update_learning_rate(self, epoch_num):
        """Decay the learning rate once the no-decay phase is over."""
        if epoch_num > self.opt.nepochs_no_decay:
            self.model.update_learning_rate()

    def train_epoch(self, i_epoch):
        """Train for one epoch, bookkeeping losses/scalars and periodic visuals."""
        self.model.set_train()
        self.epoch_losses_G = []
        self.epoch_losses_D = []
        self.epoch_scalars = []
        self.epoch_visuals = []
        for i_train_batch, train_batch in enumerate(self.dataset_train):
            iter_start_time = time.time()

            self.model.set_input(train_batch)
            train_generator = self.train_generator(i_train_batch)

            self.model.optimize_parameters(train_generator=train_generator)

            self.total_steps += self.opt.batch_size

            self.bookkeep_epoch_data(train_generator)

            if ((i_train_batch + 1) % self.visuals_per_batch == 0):
                self.bookkeep_epoch_visualizations()
                self.display_terminal(iter_start_time, i_epoch, i_train_batch,
                                      True)

    def train_generator(self, batch_num):
        """Return True when the generator should be updated on this batch."""
        return ((batch_num + 1) % self.opt.train_G_every_n_iterations) == 0

    def bookkeep_epoch_visualizations(self):
        """Snapshot the model's current visuals for end-of-epoch display."""
        self.epoch_visuals.append(self.model.get_current_visuals())

    def bookkeep_epoch_data(self, train_generator):
        """Record D losses every batch; G losses/scalars only when G trained."""
        if train_generator:
            self.epoch_losses_G.append(self.model.get_current_errors_G())
            self.epoch_scalars.append(self.model.get_current_scalars())
        self.epoch_losses_D.append(self.model.get_current_errors_D())

    def display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                         visuals_flag):
        """Print current training errors plus the per-sample iteration time."""
        errors = self.model.get_current_errors()
        t = (time.time() - iter_start_time) / self.opt.batch_size
        self.tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                      self.iters_per_epoch,
                                                      errors, t, visuals_flag)

    def display_visualizer_train(self, total_steps):
        """Push the epoch's concatenated visuals and averaged scalars to TB."""
        self.tb_visualizer.display_current_results(util.concatenate_dictionary(
            self.epoch_visuals),
                                                   total_steps,
                                                   is_train=True,
                                                   save_visuals=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_losses_G),
                                        total_steps,
                                        is_train=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_losses_D),
                                        total_steps,
                                        is_train=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_scalars),
                                        total_steps,
                                        is_train=True)

    def display_visualizer_test(self, test_epoch_visuals, epoch_num,
                                average_test_results, test_time, total_steps):
        """Push averaged validation errors and concatenated visuals to TB."""
        self.tb_visualizer.print_current_validate_errors(
            epoch_num, average_test_results, test_time)
        self.tb_visualizer.plot_scalars(average_test_results,
                                        epoch_num,
                                        is_train=False)
        self.tb_visualizer.display_current_results(
            util.concatenate_dictionary(test_epoch_visuals),
            total_steps,
            is_train=False,
            save_visuals=True)

    def test(self, i_epoch, total_steps):
        """Run the full validation set under no_grad and report averages."""
        val_start_time = time.time()

        self.model.set_eval()
        test_epoch_visuals = []

        iters_per_epoch_val = self.dataset_val_size / self.opt.batch_size
        visuals_per_val_epoch = max(1, round(iters_per_epoch_val // 2))
        errors = []
        with torch.no_grad():
            for i_val_batch, val_batch in enumerate(self.dataset_val):
                self.model.set_input(val_batch)
                self.model.forward_G(train=True)
                errors.append(self.model.get_current_errors_G())
                if (i_val_batch + 1) % visuals_per_val_epoch == 0:
                    test_epoch_visuals.append(self.model.get_current_visuals())

        average_test_results = util.average_dictionary(errors)
        test_time = (time.time() - val_start_time)
        self.display_visualizer_test(test_epoch_visuals, i_epoch,
                                     average_test_results, test_time,
                                     total_steps)
        self.model.set_train()
Exemplo n.º 16
0
class Train:
    """GAN training driver.

    Builds train/val loaders, the model and the TensorBoard visualizer,
    then runs the full training loop (with time-throttled printing,
    visualization, validation and checkpointing) directly from the
    constructor.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        data_loader_train = CustomDatasetDataLoader(self._opt, mode='train')
        data_loader_val = CustomDatasetDataLoader(self._opt, mode='val')
        #data_loader_train = CustomDatasetDataLoader(self._opt, mode='test')
        #data_loader_val = CustomDatasetDataLoader(self._opt, mode='test')

        self._dataset_train = data_loader_train.load_data()
        self._dataset_val = data_loader_val.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_val_size = len(data_loader_val)
        print('#train images = %d' % self._dataset_train_size)
        print('#val images = %d' % self._dataset_val_size)

        # Training starts immediately on construction.
        self._train()

    def _train(self):
        """Run the epoch loop: train, checkpoint, report timing, decay LR.

        load_epoch > 0 resumes a previous run, so total_steps starts
        non-zero and the epoch range picks up where the run left off.
        """
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train epoch
            self._model.set_epoch(i_epoch)
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60, time_epoch / 3600))

            # update learning rate only after the no-decay phase is over
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Train for one epoch with wall-clock-throttled display and saving."""
        epoch_iter = 0
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags: throttled by elapsed seconds, not iterations
            do_visuals = self._last_display_time is None or time.time(
            ) - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time(
            ) - self._last_print_time > self._opt.print_freq_s or do_visuals
            # NOTE: visuals is false

            # train model; the generator only updates every
            # train_G_every_n_iterations batches (unlike sibling variants in
            # this file, do_visuals does NOT force a G update here)
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations == 0)
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update epoch info (steps counted in samples, not batches)
            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer (train visuals plus a validation pass)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save model ("latest" checkpoint, wall-clock throttled)
            if self._last_save_latest_time is None or time.time(
            ) - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print current training errors plus the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Send current training visuals, errors and scalars to TensorBoard."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            total_steps,
            is_train=True,
            save_visuals=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate up to num_iters_validate batches and push averages to TB."""
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()

        # evaluate self._opt.num_iters_validate epochs
        val_errors = OrderedDict()
        for i_val_batch, val_batch in enumerate(self._dataset_val):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model; keep visuals only for the first batch
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            errors = self._model.get_current_errors()

            # store current batch errors
            for k, v in errors.items():
                if k in val_errors:
                    val_errors[k] += v
                else:
                    val_errors[k] = v

        # normalize errors (divides by num_iters_validate even if the
        # dataset yielded fewer batches)
        for k in val_errors.keys():
            val_errors[k] /= self._opt.num_iters_validate

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(
            i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors,
                                         total_steps,
                                         is_train=False)
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            total_steps,
            is_train=False,
            save_visuals=True)

        # set model back to train
        self._model.set_train()
Exemplo n.º 17
0
class Train:
    """Training driver for the occluded-face de-occlusion model.

    On construction it parses the training options, builds the train/test
    data loaders (mixing faces with real and synthetic occlusions),
    instantiates the model and the TensorBoard visualizer, and immediately
    runs the full training loop.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()

        # Both splits use the same preprocessing: resize to 128, convert
        # to tensor, normalize each channel to [-1, 1].
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
        self.train_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            normalize,
        ])

        self._dataset_train = torch_data.DataLoader(
            dataset=OccFaceImageMixLoaderV2(
                self._opt.data_root,
                self._opt.train_list_wo_g_list,
                self._opt.occlusions_root,
                self._opt.occlusions_list,
                self._opt.data_root,
                self._opt.train_list_w_g_list,
                transform=self.train_transform),
            batch_size=self._opt.batch_size,
            num_workers=self._opt.n_threads_train,
            shuffle=True,
            drop_last=True)

        self._dataset_test = torch_data.DataLoader(
            dataset=OccFaceImageMixLoaderV2(
                self._opt.data_root,
                self._opt.test_list_wo_g_list,
                self._opt.occlusions_root,
                self._opt.occlusions_list,
                self._opt.data_root,
                self._opt.test_list_w_g_list,
                transform=self.test_transform),
            batch_size=self._opt.batch_size,
            drop_last=True)

        # NOTE: len(DataLoader) is the number of batches, not images —
        # the print labels below follow the original wording.
        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)

        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Run the outer epoch loop: train, checkpoint, decay the LR."""
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        n_epochs = self._opt.nepochs_no_decay + self._opt.nepochs_decay
        for i_epoch in range(self._opt.load_epoch + 1, n_epochs + 1):
            epoch_start_time = time.time()

            self._train_epoch(i_epoch)

            print('saving the model at the end of epoch %d, iters %d'
                  % (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            time_epoch = time.time() - epoch_start_time
            print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)'
                  % (i_epoch, n_epochs, time_epoch,
                     time_epoch / 60, time_epoch / 3600))

            # Linear LR decay kicks in after the no-decay phase.
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    @staticmethod
    def _split_batch(batch, use_real_occ):
        """Re-key a loader batch into the dict the model expects.

        Args:
            batch: raw batch dict from OccFaceImageMixLoaderV2.
            use_real_occ: if True, feed faces with *real* occlusions (no
                ground-truth clean face available); otherwise feed faces
                with *synthetic* occlusions, for which GT exists.
        """
        occ_prefix = 'face_w_occ' if use_real_occ else 'face_w_syn_occ'
        return {
            'none_occ_img': batch['face_w_syn_occ_img_GT'],
            'none_occ_attr': batch['face_w_syn_occ_attr'],
            'occ_img': batch[occ_prefix + '_img'],
            'occ_attr': batch[occ_prefix + '_attr'],
            'none_occ_img_adv': batch['face_wo_occ_img_adv'],
            'none_occ_attr_adv': batch['face_wo_occ_attr_adv'],
        }

    @staticmethod
    def _visuals_flag(has_GT, has_attr):
        """Return the name suffix used when saving visualizer results.

        BUG FIX: the original left ``flag`` unbound (NameError at the
        visualizer call) for any (has_GT, has_attr) combination other than
        the two handled ones; this raises a clear error instead.
        """
        if has_attr:
            return '_w_GT_w_attr' if has_GT else '_wo_GT_w_attr'
        raise ValueError(
            'unsupported visuals flag combination: has_GT=%s, has_attr=%s'
            % (has_GT, has_attr))

    def _train_epoch(self, i_epoch):
        """Train for one epoch over ``self._dataset_train``."""
        epoch_iter = 0
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            now = time.time()
            do_visuals = (self._last_display_time is None
                          or now - self._last_display_time
                          > self._opt.display_freq_s)
            do_print_terminal = (do_visuals
                                 or now - self._last_print_time
                                 > self._opt.print_freq_s)

            # The generator is trained every n-th iteration (and whenever
            # visuals are produced, so they reflect a G update).
            train_generator = (
                do_visuals
                or (i_train_batch + 1)
                % self._opt.train_G_every_n_iterations == 0)

            # Every fake_real_rate-th batch trains on real occlusions,
            # which have no ground-truth clean face; all other batches use
            # synthetic occlusions with GT available.
            use_real_occ = \
                (i_train_batch + 1) % self._opt.fake_real_rate == 0
            train_batch_split = self._split_batch(train_batch, use_real_occ)
            has_GT_flag = not use_real_occ
            has_attr_flag = True

            self._model.set_input(train_batch_split)
            self._model.optimize_parameters(
                keep_data_for_visuals=do_visuals,
                train_generator=train_generator,
                has_GT=has_GT_flag,
                has_attr=has_attr_flag)

            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch,
                                       i_train_batch, do_visuals,
                                       has_GT=has_GT_flag,
                                       has_attr=has_attr_flag)
                self._last_print_time = time.time()

            if do_visuals:
                self._display_visualizer_train(i_epoch,
                                               self._total_steps,
                                               has_GT=has_GT_flag,
                                               has_attr=has_attr_flag)
                # Validation always runs in the no-GT configuration.
                self._display_visualizer_val(i_epoch,
                                             self._total_steps,
                                             has_GT=False,
                                             has_attr=True)
                self._last_display_time = time.time()

            self._total_steps += self._opt.batch_size
            epoch_iter += self._opt.batch_size

            # Periodic "latest" checkpoint, independent of epoch boundary.
            if (self._last_save_latest_time is None
                    or time.time() - self._last_save_latest_time
                    > self._opt.save_latest_freq_s):
                print('saving the latest model (epoch %d, total_steps %d)'
                      % (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag, has_GT, has_attr):
        """Print the current training errors to the terminal.

        ``t`` is the per-sample time of the current iteration.
        """
        errors = self._model.get_current_errors(has_GT, has_attr)
        t = (time.time() - iter_start_time) / self._opt.batch_size

        self._tb_visualizer.print_current_train_errors(
            i_epoch, i_train_batch, self._iters_per_epoch,
            errors, t, visuals_flag)

    def _display_visualizer_train(self, i_epoch, total_steps, has_GT,
                                  has_attr):
        """Plot current training visuals, errors and scalars."""
        flag = self._visuals_flag(has_GT, has_attr)

        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            i_epoch, total_steps, is_train=True,
            save_visuals=True, flag=flag)
        self._tb_visualizer.plot_scalars(
            self._model.get_current_errors(has_GT, has_attr),
            total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(
            self._model.get_current_scalars(),
            total_steps, is_train=True)

    def _display_visualizer_val(self, i_epoch, total_steps, has_GT,
                                has_attr):
        """Run a short validation pass and plot the resulting visuals.

        At most ``self._opt.num_iters_validate`` test batches are forwarded
        (visuals kept for every batch; the last forward wins). The model is
        restored to train mode afterwards.
        """
        flag = self._visuals_flag(has_GT, has_attr)

        self._model.set_eval()

        for i_val_batch, val_batch in enumerate(self._dataset_test):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # Validation always uses the real-occlusion keys of the batch.
            val_batch_split = {
                'none_occ_img': val_batch['face_w_syn_occ_img_GT'],
                'none_occ_attr': val_batch['face_w_syn_occ_attr'],
                'occ_img': val_batch['face_w_occ_img'],
                'occ_attr': val_batch['face_w_occ_attr'],
                'none_occ_img_adv': val_batch['face_wo_occ_img_adv'],
                'none_occ_attr_adv': val_batch['face_wo_occ_attr_adv'],
            }

            self._model.set_input(val_batch_split)
            self._model.forward(keep_data_for_visuals=True)

        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            i_epoch, total_steps, is_train=False,
            save_visuals=True, flag=flag)

        self._model.set_train()