# Example 1
class Train:
    """Build the train/test data loaders, the model and the visualizer, then train.

    All work happens in the constructor: instantiating ``Train`` parses the
    training options and runs the complete training loop.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt, is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` up to ``nepochs_no_decay + nepochs_decay``.

        The model is saved at the end of every epoch. Once more than
        ``nepochs_no_decay`` epochs have run, the learning rate is updated
        after each epoch.
        """
        # The step counter grows by batch_size per batch (see _train_epoch);
        # resuming assumes len(data_loader_train) is the number of training
        # images, as the '#train images' print implies.
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        last_epoch = self._opt.nepochs_no_decay + self._opt.nepochs_decay
        for i_epoch in range(self._opt.load_epoch + 1, last_epoch + 1):
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' % (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                  (i_epoch, last_epoch, time_epoch, time_epoch / 60, time_epoch / 3600))

            # decay phase: only update the learning rate after the no-decay epochs
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Run one pass over the training set.

        Visuals, terminal logging and "latest" checkpoints are throttled by
        wall-clock time (``display_freq_s``, ``print_freq_s``,
        ``save_latest_freq_s``) rather than by iteration count.
        """
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags
            do_visuals = self._last_display_time is None or time.time() - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = time.time() - self._last_print_time > self._opt.print_freq_s or do_visuals

            # train model; train_generator is forced on whenever visuals are due
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) % self._opt.train_G_every_n_iterations == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals, train_generator=train_generator)

            # update step counter
            self._total_steps += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch, do_visuals)
                self._last_print_time = time.time()

            # display visualizer (train visuals, then a short validation run)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save "latest" checkpoint
            if self._last_save_latest_time is None or time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' % (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch, visuals_flag):
        """Print the model's current errors and the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch, self._iters_per_epoch, errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Send current training visuals, errors and scalars to the visualizer."""
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(), total_steps, is_train=True)

    @staticmethod
    def _accumulate_errors(totals, errors):
        """Add every value of ``errors`` into ``totals`` (in place), key by key."""
        for name, value in errors.items():
            if name in totals:
                totals[name] += value
            else:
                totals[name] = value

    @staticmethod
    def _normalize_errors(totals, count):
        """Divide every value of ``totals`` by ``count`` in place and return it.

        With ``count == 0`` (nothing evaluated) ``totals`` is returned untouched.
        """
        if count:
            for name in list(totals):
                totals[name] /= count
        return totals

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate up to ``num_iters_validate`` test batches and log the results.

        Errors are averaged over the batches actually evaluated, so a test
        set with fewer than ``num_iters_validate`` batches is not
        under-reported. ``keep_data_for_visuals`` is requested only for the
        first batch. The model is switched back to train mode afterwards.
        """
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()

        val_errors = OrderedDict()
        num_val_batches = 0
        for i_val_batch, val_batch in enumerate(self._dataset_test):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            self._accumulate_errors(val_errors, self._model.get_current_errors())
            num_val_batches += 1

        # average over the batches that were really evaluated
        self._normalize_errors(val_errors, num_val_batches)

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors, total_steps, is_train=False)
        self._tb_visualizer.display_current_results(self._model.get_current_visuals(), total_steps, is_train=False)

        # set model back to train
        self._model.set_train()
# Example 2
class Train:
    """Build the model, visualizer and train/val data loaders, then train.

    All work happens in the constructor: instantiating ``Train`` parses the
    training options and runs the complete training loop.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        data_loader_train = CustomDatasetDataLoader(self._opt, mode='train')
        data_loader_val = CustomDatasetDataLoader(self._opt, mode='val')

        self._dataset_train = data_loader_train.load_data()
        self._dataset_val = data_loader_val.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_val_size = len(data_loader_val)
        print('#train images = %d' % self._dataset_train_size)
        print('#val images = %d' % self._dataset_val_size)

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` up to ``nepochs_no_decay + nepochs_decay``.

        The current epoch is passed to the model before each epoch, the model
        is saved at the end of every epoch, and after ``nepochs_no_decay``
        epochs the learning rate is updated once per epoch.
        """
        # The step counter grows by batch_size per batch (see _train_epoch);
        # resuming assumes len(data_loader_train) is the number of training
        # images, as the '#train images' print implies.
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        last_epoch = self._opt.nepochs_no_decay + self._opt.nepochs_decay
        for i_epoch in range(self._opt.load_epoch + 1, last_epoch + 1):
            epoch_start_time = time.time()

            # train epoch
            self._model.set_epoch(i_epoch)
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, last_epoch, time_epoch, time_epoch / 60,
                 time_epoch / 3600))

            # decay phase: only update the learning rate after the no-decay epochs
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Run one pass over the training set.

        Visuals, terminal logging and "latest" checkpoints are throttled by
        wall-clock time (``display_freq_s``, ``print_freq_s``,
        ``save_latest_freq_s``) rather than by iteration count.
        """
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags
            do_visuals = (self._last_display_time is None or
                          time.time() - self._last_display_time >
                          self._opt.display_freq_s)
            do_print_terminal = (time.time() - self._last_print_time >
                                 self._opt.print_freq_s or do_visuals)

            # train model
            # NOTE(review): unlike sibling variants, train_generator is NOT
            # forced on when do_visuals is set — confirm this is intended.
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations == 0)
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update step counter
            self._total_steps += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer (train visuals, then a short validation run)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save "latest" checkpoint
            if (self._last_save_latest_time is None or
                    time.time() - self._last_save_latest_time >
                    self._opt.save_latest_freq_s):
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print the model's current errors and the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Send current training visuals, errors and scalars to the visualizer."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            total_steps,
            is_train=True,
            save_visuals=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)

    @staticmethod
    def _accumulate_errors(totals, errors):
        """Add every value of ``errors`` into ``totals`` (in place), key by key."""
        for name, value in errors.items():
            if name in totals:
                totals[name] += value
            else:
                totals[name] = value

    @staticmethod
    def _normalize_errors(totals, count):
        """Divide every value of ``totals`` by ``count`` in place and return it.

        With ``count == 0`` (nothing evaluated) ``totals`` is returned untouched.
        """
        if count:
            for name in list(totals):
                totals[name] /= count
        return totals

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate up to ``num_iters_validate`` val batches and log the results.

        Errors are averaged over the batches actually evaluated, so a val set
        with fewer than ``num_iters_validate`` batches is not under-reported.
        ``keep_data_for_visuals`` is requested only for the first batch. The
        model is switched back to train mode afterwards.
        """
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()

        val_errors = OrderedDict()
        num_val_batches = 0
        for i_val_batch, val_batch in enumerate(self._dataset_val):
            if i_val_batch == self._opt.num_iters_validate:
                break

            # evaluate model
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
            self._accumulate_errors(val_errors,
                                    self._model.get_current_errors())
            num_val_batches += 1

        # average over the batches that were really evaluated
        self._normalize_errors(val_errors, num_val_batches)

        # visualize
        t = (time.time() - val_start_time)
        self._tb_visualizer.print_current_validate_errors(
            i_epoch, val_errors, t)
        self._tb_visualizer.plot_scalars(val_errors,
                                         total_steps,
                                         is_train=False)
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            total_steps,
            is_train=False,
            save_visuals=True)

        # set model back to train
        self._model.set_train()
# Example 3
class Train:
    """Ensure training data exists, build train/val loaders and train the model.

    All work happens in the constructor: options are parsed, the model and
    visualizer are created, and training runs only if ``get_training_data``
    reports that the data is available.
    """

    def __init__(self):
        self.opt = TrainOptions().parse()
        self.model = ModelsFactory.get_by_name(self.opt)
        self.tb_visualizer = TBVisualizer(self.opt)
        if self.get_training_data():
            self.setup_train_test_sets()
            self.train()

    def get_training_data(self):
        """Check for training data and offer to download it if it is missing.

        Returns:
            True if training can proceed (data already present, or the
            download script exited with status 0); False if the user
            declined the download or the download script failed.
        """
        import shlex

        # NOTE(review): these checks use paths relative to the current working
        # directory, while the download targets <this file's dir>/data/; they
        # only agree when training is launched from this file's directory.
        # The meshes check treats a single entry as "empty" — presumably a
        # placeholder file is tracked there; confirm.
        data_missing = (not os.path.isdir("data/train_data/")
                        or len(os.listdir("data/train_data/")) == 0
                        or not os.path.isdir("data/meshes/")
                        or len(os.listdir("data/meshes/")) == 1)
        if not data_missing:
            print("Training data exists. Proceeding with training.")
            return True

        while True:
            answer = input(
                "Training data does not exist. Want to download it (y/n)? "
            ).strip().lower()
            if answer == "y":
                data_dir = os.path.dirname(
                    os.path.realpath(__file__)) + "/data/"
                print(
                    "Downloading training data. This will take some time so just sit back and relax."
                )
                # Quote both paths so directories with spaces or shell
                # metacharacters cannot break (or inject into) the command.
                command = "%s %s" % (shlex.quote(data_dir +
                                                 'download_train_data.sh'),
                                     shlex.quote(data_dir))
                return_code = subprocess.call(command, shell=True)
                if return_code != 0:
                    print(
                        "Downloading training data failed (exit code %d). Terminating training"
                        % return_code)
                    return False
                print(
                    "Done downloading training data. Continuing with training."
                )
                return True
            if answer == "n":
                print(
                    "You chose to not download the data. Terminating training"
                )
                return False
            # any other answer: ask again

    def setup_train_test_sets(self):
        """Create the 'train' and 'test' data loaders and report their sizes."""
        data_loader_train = CustomDatasetDataLoader(self.opt, mode='train')
        self.dataset_train = data_loader_train.load_data()
        self.dataset_train_size = len(data_loader_train)
        print('#train images = %d' % self.dataset_train_size)
        data_loader_val = CustomDatasetDataLoader(self.opt, mode='test')
        self.dataset_val = data_loader_val.load_data()
        self.dataset_val_size = len(data_loader_val)
        print('#val images = %d' % self.dataset_val_size)

    @staticmethod
    def _half_epoch_interval(iters_per_epoch):
        """Return ``iters_per_epoch // 2`` as an int, but never less than 1.

        Used as a modulo divisor, so it must not be 0 for tiny datasets.
        """
        return max(1, int(iters_per_epoch // 2))

    def train(self):
        """Train from the model's stored epoch up to the last configured epoch.

        After each epoch: print timing, update the learning rate (decay phase
        only), save the "latest" checkpoint with ``i_epoch + 1``, log the
        epoch's accumulated training data, and run ``test`` on epochs that are
        multiples of 5.
        """
        # The start epoch is nonzero only when continuing training: it is the
        # epoch saved together with the loaded network.
        start_epoch = self.model.get_epoch()
        self.total_steps = start_epoch * self.dataset_train_size
        self.iters_per_epoch = self.dataset_train_size / self.opt.batch_size
        self.last_display_time = None
        self.last_display_time_val = None
        self.last_save_latest_time = None
        self.last_print_time = time.time()
        # Integer interval >= 1; the old float `iters_per_epoch // 2` became
        # 0.0 for small datasets and made the modulo in train_epoch raise.
        self.visuals_per_batch = self._half_epoch_interval(self.iters_per_epoch)
        last_epoch = self.opt.nepochs_no_decay + self.opt.nepochs_decay
        for i_epoch in range(start_epoch, last_epoch + 1):
            epoch_start_time = time.time()

            # train epoch
            self.train_epoch(i_epoch)
            # print epoch info
            self.print_epoch_info(time.time() - epoch_start_time, i_epoch)
            # update learning rate
            self.update_learning_rate(i_epoch)
            # save model
            self.model.save("latest", i_epoch + 1)
            # NOTE(review): the epoch index is passed as the ``total_steps``
            # argument, so training plots are indexed by epoch.
            self.display_visualizer_train(i_epoch)
            if i_epoch % 5 == 0:  # only test the network every fifth epoch
                self.test(i_epoch, self.total_steps)

    def print_epoch_info(self, time_epoch, epoch_num):
        """Print how long epoch ``epoch_num`` took (seconds, minutes, hours)."""
        print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
              (epoch_num, self.opt.nepochs_no_decay + self.opt.nepochs_decay,
               time_epoch, time_epoch / 60, time_epoch / 3600))

    def update_learning_rate(self, epoch_num):
        """Update the model's learning rate once past ``nepochs_no_decay``."""
        if epoch_num > self.opt.nepochs_no_decay:
            self.model.update_learning_rate()

    def train_epoch(self, i_epoch):
        """Run one pass over the training set, collecting per-epoch statistics.

        Generator/discriminator errors and scalars are stored for averaging in
        ``display_visualizer_train``. Visuals are captured and the terminal
        line is printed every ``visuals_per_batch`` batches.
        """
        self.model.set_train()
        self.epoch_losses_G = []
        self.epoch_losses_D = []
        self.epoch_scalars = []
        self.epoch_visuals = []
        for i_train_batch, train_batch in enumerate(self.dataset_train):
            iter_start_time = time.time()

            self.model.set_input(train_batch)
            train_generator = self.train_generator(i_train_batch)

            self.model.optimize_parameters(train_generator=train_generator)

            self.total_steps += self.opt.batch_size

            self.bookkeep_epoch_data(train_generator)

            if (i_train_batch + 1) % self.visuals_per_batch == 0:
                self.bookkeep_epoch_visualizations()
                self.display_terminal(iter_start_time, i_epoch, i_train_batch,
                                      True)

    def train_generator(self, batch_num):
        """Return True on every ``train_G_every_n_iterations``-th batch (1-based)."""
        return ((batch_num + 1) % self.opt.train_G_every_n_iterations) == 0

    def bookkeep_epoch_visualizations(self):
        """Store the model's current visuals for the epoch summary."""
        self.epoch_visuals.append(self.model.get_current_visuals())

    def bookkeep_epoch_data(self, train_generator):
        """Store this batch's errors (generator ones only if G was trained)."""
        if train_generator:
            self.epoch_losses_G.append(self.model.get_current_errors_G())
            self.epoch_scalars.append(self.model.get_current_scalars())
        self.epoch_losses_D.append(self.model.get_current_errors_D())

    def display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                         visuals_flag):
        """Print the model's current errors and the per-sample iteration time."""
        errors = self.model.get_current_errors()
        t = (time.time() - iter_start_time) / self.opt.batch_size
        self.tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                      self.iters_per_epoch,
                                                      errors, t, visuals_flag)

    def display_visualizer_train(self, total_steps):
        """Log the epoch's concatenated visuals and averaged losses/scalars."""
        self.tb_visualizer.display_current_results(util.concatenate_dictionary(
            self.epoch_visuals),
                                                   total_steps,
                                                   is_train=True,
                                                   save_visuals=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_losses_G),
                                        total_steps,
                                        is_train=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_losses_D),
                                        total_steps,
                                        is_train=True)
        self.tb_visualizer.plot_scalars(util.average_dictionary(
            self.epoch_scalars),
                                        total_steps,
                                        is_train=True)

    def display_visualizer_test(self, test_epoch_visuals, epoch_num,
                                average_test_results, test_time, total_steps):
        """Print and log averaged test results and the collected test visuals.

        Scalars are plotted against ``epoch_num`` while visuals use
        ``total_steps``.
        """
        self.tb_visualizer.print_current_validate_errors(
            epoch_num, average_test_results, test_time)
        self.tb_visualizer.plot_scalars(average_test_results,
                                        epoch_num,
                                        is_train=False)
        self.tb_visualizer.display_current_results(
            util.concatenate_dictionary(test_epoch_visuals),
            total_steps,
            is_train=False,
            save_visuals=True)

    def test(self, i_epoch, total_steps):
        """Evaluate the generator over the whole val set without gradients.

        Visuals are sampled every half val-epoch (at least every batch), the
        generator errors are averaged, and the model is returned to train mode.
        """
        val_start_time = time.time()

        self.model.set_eval()
        test_epoch_visuals = []

        iters_per_epoch_val = self.dataset_val_size / self.opt.batch_size
        visuals_per_val_epoch = self._half_epoch_interval(iters_per_epoch_val)
        errors = []
        with torch.no_grad():
            for i_val_batch, val_batch in enumerate(self.dataset_val):
                self.model.set_input(val_batch)
                # NOTE(review): forward_G(train=True) under no_grad — presumably
                # to compute the training losses for evaluation; confirm.
                self.model.forward_G(train=True)
                errors.append(self.model.get_current_errors_G())
                if (i_val_batch + 1) % visuals_per_val_epoch == 0:
                    test_epoch_visuals.append(self.model.get_current_visuals())

        average_test_results = util.average_dictionary(errors)
        test_time = (time.time() - val_start_time)
        self.display_visualizer_test(test_epoch_visuals, i_epoch,
                                     average_test_results, test_time,
                                     total_steps)
        self.model.set_train()
# Example 4
class Train:
    """Build the training data loader, the model and the visualizer, then train.

    This variant uses no test/validation set. All work happens in the
    constructor: instantiating ``Train`` parses the training options and runs
    the complete training loop.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt,
                                                    is_for_train=True)

        self._dataset_train = data_loader_train.load_data()

        self._dataset_train_size = len(data_loader_train)
        print('#train images = %d' % self._dataset_train_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` up to ``nepochs_no_decay + nepochs_decay``.

        The model is saved at the end of every epoch. After
        ``nepochs_no_decay`` epochs, the learning rate is updated once per
        epoch.
        """
        # The step counter grows by batch_size per batch (see _train_epoch);
        # resuming assumes len(data_loader_train) is the number of training
        # images, as the '#train images' print implies.
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        last_epoch = self._opt.nepochs_no_decay + self._opt.nepochs_decay
        for i_epoch in range(self._opt.load_epoch + 1, last_epoch + 1):
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, last_epoch, time_epoch, time_epoch / 60,
                 time_epoch / 3600))

            # decay phase: only update the learning rate after the no-decay epochs
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _train_epoch(self, i_epoch):
        """Run one pass over the training set.

        Visuals, terminal logging and "latest" checkpoints are throttled by
        wall-clock time (``display_freq_s``, ``print_freq_s``,
        ``save_latest_freq_s``) rather than by iteration count.
        """
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags
            do_visuals = (self._last_display_time is None or
                          time.time() - self._last_display_time >
                          self._opt.display_freq_s)
            do_print_terminal = (time.time() - self._last_print_time >
                                 self._opt.print_freq_s or do_visuals)

            # train model; train_generator is forced on whenever visuals are due
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update step counter
            self._total_steps += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer (training only; this variant has no val set)
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._last_display_time = time.time()

            # save "latest" checkpoint
            if (self._last_save_latest_time is None or
                    time.time() - self._last_save_latest_time >
                    self._opt.save_latest_freq_s):
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print the model's current errors and the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, total_steps):
        """Send current training visuals, errors and scalars to the visualizer."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(), total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)
# Example 5
class Train:
    """Train a face model on FaceImageLoader data, validating each epoch.

    All work happens in the constructor: options are parsed, the train and
    val ``DataLoader``s, the model and the visualizer are built, and the
    complete training loop runs.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()
        self.face_train_list = self._opt.face_train_list
        self.face_test_list = self._opt.face_test_list

        self.face_img_root = self._opt.face_img_root
        self.face_parsing_root = self._opt.face_parsing_root
        self.face_landmark_train = self._opt.face_landmark_train
        self.face_landmark_test = self._opt.face_landmark_test

        self.img_size = self._opt.img_size
        self.scale_factor = self._opt.scale_factor
        self.heatmap_size = self._opt.heatmap_size

        # Both splits only convert to tensors; no mean/std normalization is applied.
        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self._dataset_train = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_train,
                                    self.face_train_list,
                                    transform=self.train_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='train',
                                    upsample=self._opt.upsample),
            batch_size=self._opt.batch_size,
            num_workers=self._opt.n_threads_train,
            shuffle=True,
            drop_last=True)

        self._dataset_test = torch_data.DataLoader(
            dataset=FaceImageLoader(self._opt.face_img_root,
                                    self._opt.face_parsing_root,
                                    self._opt.face_landmark_test,
                                    self.face_test_list,
                                    transform=self.test_transform,
                                    scale_factor=self.scale_factor,
                                    img_size=self.img_size,
                                    heatmap_size=self.heatmap_size,
                                    mode='val',
                                    upsample=self._opt.upsample),
            batch_size=10,
            num_workers=self._opt.n_threads_test,
            shuffle=False,
            drop_last=True)

        # len() of a torch DataLoader is the number of batches, not images.
        self._dataset_train_size = len(self._dataset_train)
        self._dataset_test_size = len(self._dataset_test)

        # Report image counts from the underlying datasets, matching the labels.
        print('#train images = %d' % len(self._dataset_train.dataset))
        print('#test images = %d' % len(self._dataset_test.dataset))

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._save_path = os.path.join(self._opt.checkpoints_dir,
                                       self._opt.name)
        self._val_path = os.path.join(self._save_path, 'val_log.txt')

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` up to ``nepochs_no_decay + nepochs_decay``.

        Every epoch after the first starts with a validation pass. The model
        is saved after each training epoch, and after ``nepochs_no_decay``
        epochs the learning rate is updated once per epoch.
        """
        # _train_epoch adds batch_size per batch while _dataset_train_size
        # counts batches, so scale by batch_size to resume in the same units.
        self._total_steps = (self._opt.load_epoch * self._dataset_train_size *
                             self._opt.batch_size)
        self._iters_per_epoch = self._dataset_train_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        last_epoch = self._opt.nepochs_no_decay + self._opt.nepochs_decay
        for i_epoch in range(self._opt.load_epoch + 1, last_epoch + 1):

            # NOTE(review): validation at the start of epoch i evaluates the
            # weights left by epoch i - 1, and the weights from the final
            # epoch are never validated here — confirm this is intended.
            if i_epoch > 1:
                epoch_val_start_time = time.time()
                self._val_epoch(i_epoch)
                time_epoch = time.time() - epoch_val_start_time
                print(
                    'End of epoch %d / %d \t Val Time Taken: %d sec (%d min or %d h)'
                    % (i_epoch, last_epoch, time_epoch, time_epoch / 60.0,
                       time_epoch / 3600.0))

            epoch_start_time = time.time()
            self._train_epoch(i_epoch)
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)
            time_epoch = time.time() - epoch_start_time

            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, last_epoch, time_epoch, time_epoch / 60.0,
                 time_epoch / 3600.0))

            # decay phase: only update the learning rate after the no-decay epochs
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    def _val_epoch(self, i_epoch):
        """Run the model over the whole val loader and save every batch's visuals.

        The model is put in eval mode for the pass and back in train mode
        afterwards.
        """
        self._model.set_eval()

        for i_val_batch, val_batch in enumerate(self._dataset_test):
            self._model.set_input(val_batch)
            self._model.forward(keep_data_for_visuals=True)
            visuals = self._model.get_current_visuals()

            self._tb_visualizer.display_current_results(visuals,
                                                        i_epoch,
                                                        i_val_batch,
                                                        is_train=False,
                                                        save_visuals=True)
        self._model.set_train()

    def _train_epoch(self, i_epoch):
        """Run one pass over the training loader.

        Visuals, terminal logging and "latest" checkpoints are throttled by
        wall-clock time (``display_freq_s``, ``print_freq_s``,
        ``save_latest_freq_s``) rather than by iteration count.
        """
        self._model.set_train()

        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            do_visuals = (self._last_display_time is None) or \
                         (time.time() - self._last_display_time > self._opt.display_freq_s)
            do_print_terminal = (time.time() - self._last_print_time >
                                 self._opt.print_freq_s or do_visuals)

            # train_generator is forced on whenever visuals are due
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(train_generator,
                                            keep_data_for_visuals=do_visuals)

            self._total_steps += self._opt.batch_size

            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            if do_visuals:
                self._display_visualizer_train(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            if self._last_save_latest_time is None or \
                    time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print the model's current errors and the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, i_epoch, total_steps):
        """Save current training visuals and plot current errors and scalars."""
        self._tb_visualizer.display_current_results(
            self._model.get_current_visuals(),
            i_epoch,
            total_steps,
            is_train=True,
            save_visuals=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(),
                                         total_steps,
                                         is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps,
                                         is_train=True)
# Example 6
class Train:
    """Training driver: builds loaders, model and visualizers, then trains.

    Constructing ``Train()`` parses the options and immediately runs the whole
    training loop (see ``_train``).
    """

    def __init__(self):
        self._opt = TrainOptions().parse()
        data_loader_train = CustomDatasetDataLoader(self._opt,
                                                    is_for_train=True)
        data_loader_test = CustomDatasetDataLoader(self._opt,
                                                   is_for_train=False)

        self._dataset_train = data_loader_train.load_data()
        self._dataset_test = data_loader_test.load_data()

        self._dataset_train_size = len(data_loader_train)
        self._dataset_test_size = len(data_loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)
        print('TRAIN IMAGES FOLDER = %s' %
              data_loader_train._dataset._imgs_dir)
        print('TEST IMAGES FOLDER = %s' % data_loader_test._dataset._imgs_dir)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)
        self._writer = SummaryWriter()

        # Accumulators filled by _display_visualizer_train. Each starts as an
        # empty (0, 3, image_size, image_size) tensor and grows along dim 0.
        size = self._opt.image_size
        self._input_imgs = torch.empty(0, 3, size, size)
        self._fake_imgs = torch.empty(0, 3, size, size)
        self._rec_real_imgs = torch.empty(0, 3, size, size)
        self._fake_imgs_unmasked = torch.empty(0, 3, size, size)
        self._fake_imgs_mask = torch.empty(0, 3, size, size)
        self._rec_real_imgs_mask = torch.empty(0, 3, size, size)
        self._cyc_imgs_unmasked = torch.empty(0, 3, size, size)
        self._real_conds = list()
        self._desired_conds = list()

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` through the last epoch.

        The model is saved after every epoch. The learning rate is updated
        once per epoch after ``nepochs_no_decay`` epochs. The SummaryWriter is
        closed when the loop finishes.
        """
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model
            print('saving the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60, time_epoch / 3600))

            # update learning rate
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

        # Disabled embedding export of the accumulated visuals:
        # self._writer.add_embedding(self._fake_imgs, metadata=self._desired_conds, label_img=self._input_imgs, tag='desired_conds_fake')
        # self._writer.add_embedding(self._rec_real_imgs, metadata=self._real_conds, label_img=self._fake_imgs, tag='real_conds_rec_real')
        # self._writer.add_embedding(self._rec_real_imgs, metadata=self._desired_conds, label_img=self._input_imgs, tag='desired_conds_rec_real')
        # self._writer.add_embedding(self._rec_real_imgs, metadata=self._real_conds, label_img=self._input_imgs, tag='reconstruction_with_real_conds')
        # Release the writer's file handle once training is over.
        self._writer.close()

    def _train_epoch(self, i_epoch):
        """Run one training epoch over ``self._dataset_train``.

        Printing, visualization and "latest" checkpointing are throttled by
        wall-clock time (``print_freq_s``, ``display_freq_s``,
        ``save_latest_freq_s``). The ``train_generator`` flag passed to the
        model is set on every ``train_G_every_n_iterations``-th batch and on
        every visualized batch.
        """
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):
            iter_start_time = time.time()

            # display flags
            do_visuals = (self._last_display_time is None or
                          time.time() - self._last_display_time >
                          self._opt.display_freq_s)
            do_print_terminal = (time.time() - self._last_print_time >
                                 self._opt.print_freq_s or do_visuals)

            # train model
            self._model.set_input(train_batch)
            train_generator = ((i_train_batch + 1) %
                               self._opt.train_G_every_n_iterations
                               == 0) or do_visuals
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator)

            # update epoch info
            self._total_steps += self._opt.batch_size

            # display terminal
            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch, i_train_batch,
                                       do_visuals)
                self._last_print_time = time.time()

            # display visualizer
            if do_visuals:
                self._display_visualizer_train(self._total_steps)
                self._display_visualizer_val(i_epoch, self._total_steps)
                self._last_display_time = time.time()

            # save model
            if self._last_save_latest_time is None or time.time(
            ) - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch,
                          visuals_flag):
        """Print current training errors and the per-sample iteration time."""
        errors = self._model.get_current_errors()
        t = (time.time() - iter_start_time) / self._opt.batch_size
        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    @staticmethod
    def _append_visual(stack, image):
        """Return ``stack`` with ``image`` appended along dim 0.

        ``image`` is transposed with axes (2, 0, 1) (HWC -> CHW), cast to
        float32 and given a leading batch axis before concatenation.
        """
        chw = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return torch.cat((stack, torch.from_numpy(chw).unsqueeze(0)), dim=0)

    def _display_visualizer_train(self, total_steps):
        """Accumulate the current training visuals and conditions.

        ``torch.cat`` is not in-place, so every accumulator is rebound to the
        concatenated result.

        NOTE(review): the accumulators grow by one entry per call for the whole
        run, and only the disabled add_embedding export in ``_train`` reads
        them. ``total_steps`` is unused because TensorBoard logging of training
        visuals here is disabled.
        """
        visuals = self._model.get_current_visuals()

        self._input_imgs = self._append_visual(
            self._input_imgs, visuals['1_input_img'])
        self._fake_imgs = self._append_visual(
            self._fake_imgs, visuals['2_fake_img'])
        self._rec_real_imgs = self._append_visual(
            self._rec_real_imgs, visuals['3_rec_real_img'])
        self._fake_imgs_unmasked = self._append_visual(
            self._fake_imgs_unmasked, visuals['4_fake_img_unmasked'])
        self._fake_imgs_mask = self._append_visual(
            self._fake_imgs_mask, visuals['5_fake_img_mask'])
        self._rec_real_imgs_mask = self._append_visual(
            self._rec_real_imgs_mask, visuals['6_rec_real_img_mask'])
        self._cyc_imgs_unmasked = self._append_visual(
            self._cyc_imgs_unmasked, visuals['7_cyc_img_unmasked'])

        self._real_conds.append(visuals['8_real_cond'].tolist())
        self._desired_conds.append(visuals['9_desired_cond'].tolist())

    def _display_visualizer_val(self, i_epoch, total_steps):
        """Evaluate up to ``num_iters_validate`` test batches and log errors.

        Errors are averaged over the batches actually evaluated, which can be
        fewer than ``num_iters_validate`` when the test set is small. Only the
        first batch keeps data for visuals. The model is switched back to train
        mode even if evaluation raises.
        """
        val_start_time = time.time()

        # set model to eval
        self._model.set_eval()
        try:
            val_errors = OrderedDict()
            num_batches = 0
            for i_val_batch, val_batch in enumerate(self._dataset_test):
                if i_val_batch == self._opt.num_iters_validate:
                    break

                # evaluate model
                self._model.set_input(val_batch)
                self._model.forward(keep_data_for_visuals=(i_val_batch == 0))
                errors = self._model.get_current_errors()
                num_batches += 1

                # store current batch errors
                for k, v in errors.items():
                    if k in val_errors:
                        val_errors[k] += v
                    else:
                        val_errors[k] = v

            # normalize errors (any key present implies num_batches >= 1)
            for k in val_errors:
                val_errors[k] /= num_batches

            # visualize
            t = (time.time() - val_start_time)
            self._tb_visualizer.print_current_validate_errors(
                i_epoch, val_errors, t)
            self._tb_visualizer.plot_scalars(val_errors,
                                             total_steps,
                                             is_train=False)
            self._tb_visualizer.display_current_results(
                self._model.get_current_visuals(), total_steps, is_train=False)
        finally:
            # set model back to train
            self._model.set_train()
Exemplo n.º 7
0
class Train:
    """Training driver mixing synthetic-occlusion and real-occlusion batches.

    Constructing ``Train()`` parses the options, builds the train/test
    DataLoaders and the model, and immediately runs the training loop.
    """

    def __init__(self):
        self._opt = TrainOptions().parse()

        self.train_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

        self._dataset_train = torch_data.DataLoader(
            dataset=OccFaceImageMixLoaderV2(
                self._opt.data_root,
                self._opt.train_list_wo_g_list,
                self._opt.occlusions_root,
                self._opt.occlusions_list,
                self._opt.data_root,
                self._opt.train_list_w_g_list,
                transform=self.train_transform),
            batch_size=self._opt.batch_size,
            num_workers=self._opt.n_threads_train,
            shuffle=True, drop_last=True)

        self._dataset_test = torch_data.DataLoader(
            dataset=OccFaceImageMixLoaderV2(
                self._opt.data_root,
                self._opt.test_list_wo_g_list,
                self._opt.occlusions_root,
                self._opt.occlusions_list,
                self._opt.data_root,
                self._opt.test_list_w_g_list,
                transform=self.test_transform),
            batch_size=self._opt.batch_size,
            drop_last=True)

        # len(DataLoader) counts batches, while len(DataLoader.dataset) counts
        # samples. The messages below report samples.
        self._dataset_train_size = len(self._dataset_train.dataset)
        self._dataset_test_size = len(self._dataset_test.dataset)

        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._train()

    def _train(self):
        """Train from epoch ``load_epoch + 1`` through the last epoch.

        The model is saved after every epoch. The learning rate is updated
        once per epoch after ``nepochs_no_decay`` epochs.
        """
        # Batches per epoch (drop_last=True, so every batch is full).
        self._iters_per_epoch = len(self._dataset_train)
        # _total_steps advances by batch_size per batch (see _train_epoch), so
        # seed it in the same unit when resuming from load_epoch.
        self._total_steps = (self._opt.load_epoch * self._iters_per_epoch *
                             self._opt.batch_size)
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        for i_epoch in range(self._opt.load_epoch + 1,
                             self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            epoch_start_time = time.time()

            self._train_epoch(i_epoch)

            print('saving the model at the end of epoch %d, iters %d' % (i_epoch, self._total_steps))
            self._model.save(i_epoch)

            time_epoch = time.time() - epoch_start_time
            print('End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)' %
                  (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay, time_epoch,
                   time_epoch / 60, time_epoch / 3600))

            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()

    @staticmethod
    def _split_batch(batch, real_occlusion):
        """Map a raw loader batch onto the keys the model's ``set_input`` uses.

        With ``real_occlusion`` the occluded input is the real-occlusion image
        (``face_w_occ_*``). Otherwise it is the synthetic one
        (``face_w_syn_occ_*``). The other entries are the same in both cases.
        """
        if real_occlusion:
            occ_img_key, occ_attr_key = 'face_w_occ_img', 'face_w_occ_attr'
        else:
            occ_img_key, occ_attr_key = 'face_w_syn_occ_img', 'face_w_syn_occ_attr'
        return {
            'none_occ_img': batch['face_w_syn_occ_img_GT'],
            'none_occ_attr': batch['face_w_syn_occ_attr'],
            'occ_img': batch[occ_img_key],
            'occ_attr': batch[occ_attr_key],
            'none_occ_img_adv': batch['face_wo_occ_img_adv'],
            'none_occ_attr_adv': batch['face_wo_occ_attr_adv'],
        }

    @staticmethod
    def _visuals_flag(has_GT, has_attr):
        """Return the tag suffix, e.g. '_w_GT_w_attr' or '_wo_GT_w_attr'."""
        return '_%s_GT_%s_attr' % ('w' if has_GT else 'wo',
                                   'w' if has_attr else 'wo')

    def _train_epoch(self, i_epoch):
        """Run one training epoch.

        Every ``fake_real_rate``-th batch uses the real-occlusion images with
        ``has_GT=False``. All other batches use synthetic occlusions with
        ``has_GT=True``. Printing, visualization and "latest" checkpointing
        are throttled by wall-clock time.
        """
        self._model.set_train()
        for i_train_batch, train_batch in enumerate(self._dataset_train):

            iter_start_time = time.time()

            do_visuals = \
                self._last_display_time is None or \
                time.time() - self._last_display_time > self._opt.display_freq_s
            do_print_terminal = \
                do_visuals or time.time() - self._last_print_time > self._opt.print_freq_s

            train_generator = \
                do_visuals or ((i_train_batch + 1) % self._opt.train_G_every_n_iterations == 0)

            use_real_occ = (i_train_batch + 1) % self._opt.fake_real_rate == 0
            train_batch_split = self._split_batch(train_batch, use_real_occ)
            has_GT_flag = not use_real_occ
            has_attr_flag = True

            self._model.set_input(train_batch_split)
            self._model.optimize_parameters(keep_data_for_visuals=do_visuals,
                                            train_generator=train_generator,
                                            has_GT=has_GT_flag, has_attr=has_attr_flag)

            if do_print_terminal:
                self._display_terminal(iter_start_time, i_epoch,
                                       i_train_batch, do_visuals,
                                       has_GT=has_GT_flag, has_attr=has_attr_flag)
                self._last_print_time = time.time()

            if do_visuals:
                self._display_visualizer_train(i_epoch,
                                               self._total_steps,
                                               has_GT=has_GT_flag,
                                               has_attr=has_attr_flag)
                self._display_visualizer_val(i_epoch,
                                             self._total_steps,
                                             has_GT=False,
                                             has_attr=True)
                self._last_display_time = time.time()

            self._total_steps += self._opt.batch_size

            if self._last_save_latest_time is None or \
                    time.time() - self._last_save_latest_time > self._opt.save_latest_freq_s:
                print('saving the latest model (epoch %d, total_steps %d)' % (i_epoch, self._total_steps))
                self._model.save(i_epoch)
                self._last_save_latest_time = time.time()

    def _display_terminal(self, iter_start_time, i_epoch, i_train_batch, visuals_flag, has_GT, has_attr):
        """Print current training errors and the per-sample iteration time."""
        errors = self._model.get_current_errors(has_GT, has_attr)
        t = (time.time() - iter_start_time) / self._opt.batch_size

        self._tb_visualizer.print_current_train_errors(i_epoch, i_train_batch,
                                                       self._iters_per_epoch,
                                                       errors, t, visuals_flag)

    def _display_visualizer_train(self, i_epoch, total_steps, has_GT, has_attr):
        """Display/save training visuals and plot training errors and scalars."""
        flag = self._visuals_flag(has_GT, has_attr)

        self._tb_visualizer.display_current_results(self._model.get_current_visuals(),
                                                    i_epoch, total_steps, is_train=True,
                                                    save_visuals=True, flag=flag)
        self._tb_visualizer.plot_scalars(self._model.get_current_errors(has_GT, has_attr),
                                         total_steps, is_train=True)
        self._tb_visualizer.plot_scalars(self._model.get_current_scalars(),
                                         total_steps, is_train=True)

    def _display_visualizer_val(self, i_epoch, total_steps, has_GT, has_attr):
        """Run up to ``num_iters_validate`` test batches and display visuals.

        Every test batch uses the real-occlusion split. Only the visuals left
        by the last evaluated batch are displayed. The model is switched back
        to train mode even if evaluation raises.
        """
        flag = self._visuals_flag(has_GT, has_attr)

        self._model.set_eval()
        try:
            for i_val_batch, val_batch in enumerate(self._dataset_test):
                if i_val_batch == self._opt.num_iters_validate:
                    break

                self._model.set_input(self._split_batch(val_batch, real_occlusion=True))
                self._model.forward(keep_data_for_visuals=True)

            self._tb_visualizer.display_current_results(self._model.get_current_visuals(),
                                                        i_epoch, total_steps, is_train=False,
                                                        save_visuals=True, flag=flag)
        finally:
            self._model.set_train()