Example 1

import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader

# MetricCounter, get_nets and get_model are provided by the surrounding project.

class Validator:
    def __init__(self, config, val: DataLoader):
        self.config = config
        self.val_dataset = val
        self.metric_counter = MetricCounter()

    def validate(self):
        self._init_params()
        self.netG.load_state_dict(torch.load('best_G_fpn.h5')['model'])
        # left in train mode as in the source; switch to .eval() if norm
        # layers should use running statistics during validation
        self.netG.train(True)

        self._validate()
        torch.cuda.empty_cache()

        print(self.metric_counter.loss_message())

    def _validate(self):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(
            self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        total_psnr = 0
        total_ssim = 0
        total_samples = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            inputs, targets = inputs.cuda(), targets.cuda()
            with torch.no_grad():  # gradients are not needed for validation
                outputs = self.netG(inputs)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                inputs, outputs, targets)
            total_ssim += curr_ssim * len(inputs)
            total_psnr += curr_psnr * len(inputs)
            total_samples += len(inputs)

            print("Metrcis:", curr_psnr, curr_ssim)
            print("Totals:", total_ssim, total_psnr, total_samples)
            print("nan:", np.isnan(img_for_vis).any())
            print("inf:", np.isinf(img_for_vis).any())
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='val')
            i += 1
            if i > epoch_size:
                break
            self.metric_counter.write_to_tensorboard(i, validation=True)
            del inputs, targets, outputs

        print("PSNR", total_psnr / total_samples)
        print("SSIM", total_ssim / total_samples)

        tq.close()

    def _init_params(self):
        self.netG, _ = get_nets(self.config['model'])  # discriminator is unused at validation time
        self.netG.cuda()
        self.model = get_model(self.config['model'])
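
A minimal sketch of how this Validator might be driven. The config keys
mirror the ones the class actually reads; ValDataset, the batch size and
the key values are illustrative assumptions, not part of the original code.

if __name__ == '__main__':
    # Hypothetical config; only 'model' and 'val_batches_per_epoch' are
    # read by the Validator above.
    config = {
        'model': {'g_name': 'fpn_inception'},   # assumed model spec
        'val_batches_per_epoch': 100,
    }
    # ValDataset is a placeholder for the project's paired dataset class.
    val_loader = DataLoader(ValDataset(), batch_size=1, shuffle=False)

    validator = Validator(config, val=val_loader)
    validator.validate()  # loads 'best_G_fpn.h5' and prints PSNR/SSIM
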
Example 2

import logging

import torch
import tqdm
from torch import optim

# CustomDataLoader, MetricCounter, get_loss, get_nets, get_model,
# AdversarialTrainerFactory, WarmRestart and LinearDecay are provided
# by the surrounding project.

class Trainer(object):
    def __init__(self, config):
        self.config = config
        self.train_dataset = self._get_dataset(config, 'train')
        self.val_dataset = self._get_dataset(config, 'test')
        self.experiment_name = '{}_{}'.format(config['experiment_desc'],
                                              config['model']['g_name'])
        self.experiment_name += '_content{}_feature{}_adv{}'.format(
            config['model']['content_coef'],
            config['model']['feature_coef'],
            config['model']['adv_coef'])
        self.metric_counter = MetricCounter(self.experiment_name)
        self.warmup_epochs = config['warmup_num']

    def train(self):
        self._init_params()
        for epoch in range(self.config['num_epochs']):
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(self.experiment_name))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(self.experiment_name))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" %
                          (self.experiment_name, epoch,
                           self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        lr = self.optimizer_G.param_groups[0]['lr']  # assumes a single param group
        tq = tqdm.tqdm(self.train_dataset.dataloader)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content, loss_vgg = self.criterionG[0](
                outputs, targets), self.criterionG[1](outputs, targets)
            loss_adv = self.adv_trainer.lossG(outputs, targets)
            loss_G = self.config['model']['content_coef'] * loss_content + \
               self.config['model']['adv_coef'] * loss_adv + \
               self.config['model']['feature_coef'] * loss_vgg
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(),
                                           loss_D)
            curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            i += 1
            if i == 3:
                self.metric_counter.images_to_tensorboard(
                    [inputs[0], targets[0], outputs[0]], epoch)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        tq = tqdm.tqdm(self.val_dataset.dataloader)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            i += 1
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_content, loss_vgg = self.criterionG[0](
                outputs, targets), self.criterionG[1](outputs, targets)
            loss_adv = self.adv_trainer.lossG(outputs, targets)
            loss_G = self.config['model']['content_coef'] * loss_content + \
               self.config['model']['adv_coef'] * loss_adv + \
               self.config['model']['feature_coef'] * loss_vgg
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(),
                                           loss_vgg.item())
            curr_psnr, curr_ssim = self.model.get_acc(outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)


            # if i == 3:
            #     self.metric_counter.images_to_tensorboard(
            #         [inputs[0], targets[0], outputs[0]], epoch)
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_dataset(self, config, filename):
        return CustomDataLoader(config, filename)

    def _update_d(self, outputs, targets):
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.config['model']['adv_coef'] * self.adv_trainer.lossD(
            outputs, targets)
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'adadelta':
            optimizer = optim.Adadelta(params,
                                       lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." %
                             self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                patience=self.config['scheduler']['patience'],
                factor=self.config['scheduler']['factor'],
                min_lr=self.config['scheduler']['min_lr'])
        elif self.config['scheduler']['name'] == 'sgdr':
            scheduler = WarmRestart(optimizer)
        elif self.config['scheduler']['name'] == 'linear':
            scheduler = LinearDecay(
                optimizer,
                min_lr=self.config['scheduler']['min_lr'],
                num_epochs=self.config['num_epochs'],
                start_epoch=self.config['scheduler']['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." %
                             self.config['scheduler']['name'])
        return scheduler

    def _get_adversarial_trainer(self, D_name, netD, criterionD):
        if D_name == 'no_gan':
            return AdversarialTrainerFactory.createModel(
                'NoAdversarialTrainer')
        elif D_name == 'patch_gan' or D_name == 'multi_scale':
            return AdversarialTrainerFactory.createModel(
                'SingleAdversarialTrainer', netD, criterionD)
        elif D_name == 'double_gan':
            return AdversarialTrainerFactory.createModel(
                'DoubleAdversarialTrainer', netD, criterionD)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." %
                             D_name)

    def _init_params(self):
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(
            self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
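
For reference, a config sketch covering every key this Trainer reads.
The shape is reconstructed from the lookups in the code above; all values
are illustrative assumptions, not defaults from the original project.

config = {
    'experiment_desc': 'deblur',
    'num_epochs': 200,
    'warmup_num': 3,                    # epochs before unfreezing the generator
    'model': {
        'g_name': 'fpn_inception',
        'd_name': 'patch_gan',          # 'no_gan', 'multi_scale' and 'double_gan' are also handled
        'content_coef': 1.0,
        'feature_coef': 0.01,
        'adv_coef': 0.001,
    },
    'optimizer': {'name': 'adam', 'lr': 1e-4},
    'scheduler': {
        'name': 'linear',               # 'plateau' and 'sgdr' are also handled
        'start_epoch': 50,              # used by 'linear'
        'min_lr': 1e-7,
        'patience': 10,                 # used by 'plateau'
        'factor': 0.5,                  # used by 'plateau'
    },
}
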
Example 3

import logging

import torch
import tqdm
from torch import optim
from torch.utils.data import DataLoader

# MetricCounter, get_loss, get_nets, get_model, GANFactory, WarmRestart
# and LinearDecay are provided by the surrounding project.

class Trainer:
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
        gpu_id = self.config['gpu_id']
        # gpu_id >= 0 so that the first GPU (id 0) is usable
        self.device = torch.device('cuda:{}'.format(gpu_id)
                                   if torch.cuda.is_available() and gpu_id >= 0
                                   else 'cpu')

    def train(self):
        self._init_params()
        for epoch in range(self.config['num_epochs']):
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({
                    'model': self.netG.state_dict()
                }, 'best_{}.h5'.format(self.config['experiment_desc']))
            torch.save({
                'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
                self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        lr = self.optimizer_G.param_groups[0]['lr']  # assumes a single param group

        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            self.metric_counter.add_losses(loss_G.item(), loss_content.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='val')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'adadelta':
            optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             mode='min',
                                                             patience=self.config['scheduler']['patience'],
                                                             factor=self.config['scheduler']['factor'],
                                                             min_lr=self.config['scheduler']['min_lr'])
        elif self.config['scheduler']['name'] == 'sgdr':
            scheduler = WarmRestart(optimizer)
        elif self.config['scheduler']['name'] == 'linear':
            scheduler = LinearDecay(optimizer,
                                    min_lr=self.config['scheduler']['min_lr'],
                                    num_epochs=self.config['num_epochs'],
                                    start_epoch=self.config['scheduler']['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
        return scheduler

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        elif d_name == 'patch_gan' or d_name == 'multi_scale':
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        elif d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.to(self.device)
        self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
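
A minimal end-to-end sketch for this Trainer. Besides the keys from the
Example 2 config, this variant also reads config['gpu_id'] and
config['model']['adv_lambda']; PairedDataset and 'config.yaml' are
hypothetical stand-ins for whatever the surrounding project provides.

import yaml

if __name__ == '__main__':
    with open('config.yaml') as f:
        config = yaml.safe_load(f)

    # PairedDataset is a placeholder for the project's dataset class.
    train_loader = DataLoader(PairedDataset(config, 'train'),
                              batch_size=config.get('batch_size', 1),
                              shuffle=True)
    val_loader = DataLoader(PairedDataset(config, 'val'),
                            batch_size=1, shuffle=False)

    Trainer(config, train=train_loader, val=val_loader).train()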