Example #1
def fit(parallel=False, **kwargs):
    with open('config.yaml') as cfg:
        config = yaml.safe_load(cfg)
    update_config(config, kwargs)
    work_dir = config['name']
    os.makedirs(work_dir, exist_ok=True)
    with open(os.path.join(work_dir, 'config.yaml'), 'w') as out:
        yaml.dump(config, out)

    config['train']['salt'] = config['val']['salt'] = config['name']
    config['train']['n_fold'] = config['val']['n_fold'] = config['n_fold']

    train, val = make_dataloaders(config['train'], config['val'], config['batch_size'], multiprocessing=parallel)
    model = DataParallel(get_baseline(config['model']))
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    trainer = Trainer(model=model,
                      train=train,
                      val=val,
                      work_dir=work_dir,
                      loss_fn=None,
                      optimizer=optimizer,
                      scheduler=ReduceLROnPlateau(factor=.2, patience=5, optimizer=optimizer),
                      device='cuda:0',
                      )

    stages = config['stages']
    epochs_completed = 0
    for i, stage in enumerate(stages):
        logger.info(f'Starting stage {i}')
        # ToDo: update train properties: mixup, crop type
        trainer.train.dataset.update_config(stage['train'])
        trainer.epochs = stage['epochs']
        weights = torch.from_numpy(np.array(stage['loss_weights'], dtype='float32')).to('cuda:0')
        trainer.loss_fn = partial(soft_cross_entropy,
                                  weights=weights)
        epochs_completed = trainer.fit(epochs_completed)

    convert_model(model_path=os.path.join(work_dir, 'model.pt'),
                  out_name=os.path.join(work_dir, f'{config["name"]}_{config["n_fold"]}.trcd'),
                  name=config['model']
                  )
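
A minimal way to invoke this function, assuming config.yaml defines the keys it reads (name, n_fold, batch_size, lr, model, stages, and the train/val sections) and that update_config merges the keyword overrides into the loaded config:

# Hypothetical call; lr here overrides the value from config.yaml
fit(parallel=True, lr=3e-4)
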
Example #2
def train():
    working_dir = new_working_dir(FLAGS.working_dir_root, FLAGS.name)

    model = DataParallel(Model()).cuda()
    writer = SummaryWriter(log_dir=working_dir)
    saver = Saver(working_dir=working_dir)
    saver.save_meta(FLAGS.flag_values_dict())

    train_net = TrainNet(
        optimizer=torch.optim.Adam(model.parameters(),
                                   lr=FLAGS.lr,
                                   weight_decay=FLAGS.weight_decay),
        model=model,
        writer=writer,
        data_loader=TrainLoader(),
        lr_decay_epochs=FLAGS.lr_decay_epochs,
        lr_decay_rate=FLAGS.lr_decay_rate,
        summarize_steps=FLAGS.summarize_steps,
        save_steps=FLAGS.save_steps,
        saver=saver,
    )
    test_net = TestNet(
        model=model,
        writer=writer,
        data_loader=TestLoader(),
    )

    if FLAGS.ckpt_dir is not None:
        train_net.load(ckpt_dir=FLAGS.ckpt_dir)

    for num_epoch in range(FLAGS.num_epochs):
        train_net.step_epoch()
        test_net.step_epoch(num_step=train_net.num_step)

    train_net.save()
    writer.close()
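
A likely entry point for this training script, assuming FLAGS comes from absl.flags (the flag_values_dict() call suggests absl):

from absl import app

def main(_argv):
    train()

if __name__ == '__main__':
    app.run(main)
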
Example #3
class FIRNModel(BaseModel):
    def __init__(self, opt):
        super(FIRNModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_back'])

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Unrecognized lr_scheme: only MultiStepLR and '
                    'CosineAnnealingLR_Restart are implemented.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y, z):
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out, y)

        z = z.reshape([out.shape[0], -1])
        l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(
            z**2) / z.shape[0]

        return l_forw_fit, l_forw_ce

    def loss_backward(self, x, y):
        x_samples = self.netG(x=y, rev=True)
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        return l_back_rec

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()

        # forward downscaling
        self.input = self.real_H
        self.output = self.netG(x=self.input)

        zshape = self.output[:, 3:, :, :].shape
        LR_ref = self.ref_L.detach()

        l_forw_fit, l_forw_ce = self.loss_forward(self.output[:, :3, :, :],
                                                  LR_ref,
                                                  self.output[:, 3:, :, :])

        # backward upscaling
        LR = self.Quantization(self.output[:, :3, :, :])
        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        l_back_rec = self.loss_backward(self.real_H, y_)

        # total loss
        loss = l_forw_fit + l_back_rec + l_forw_ce
        loss.backward()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_forw_ce'] = l_forw_ce.item()
        self.log_dict['l_back_rec'] = l_back_rec.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
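
A sketch of the training loop these BaseModel subclasses are typically driven by; opt and train_loader are assumed to be built elsewhere from the same options dict:

model = FIRNModel(opt)
for step, batch in enumerate(train_loader, start=1):
    model.feed_data(batch)    # expects a dict with 'LQ' and 'GT' tensors
    model.optimize_parameters(step)
    if step % 1000 == 0:
        print(model.get_current_log())   # l_forw_fit, l_forw_ce, l_back_rec
        model.save(step)
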
Example #4
class MGANTrainer:
    def __init__(self, args, task, saver, logger, vocab):
        device = torch.device("cuda")
        self.pretrain = False
        self.saver = saver
        self.logger = logger
        self._model = MGANModel.build_model(args, task, pretrain=self.pretrain)
        self.model = DataParallel(self._model)
        self.model = self.model.to(device)
        self.opt = ClippedAdam(self.model.parameters(), lr=1e-3)
        self.opt.set_clip(clip_value=5.0)
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.opt,
                                                                   gamma=0.5)
        self.saver.load("mgan", self.model.module)
        self.step = 0
        self.vocab = vocab
        self.critic_lag_max = 50
        self.critic_lag = self.critic_lag_max

        self.args = args
        self.task = task

    def run(self, epoch, samples):
        self.model.train()
        num_rollouts = 1 if self.pretrain else self.args.num_rollouts
        self.lr_scheduler.step(epoch)
        self.rollout_discriminator(num_rollouts, samples)
        self.rollout_generator(num_rollouts, samples)
        self.rollout_critic(num_rollouts, samples)
        self.saver.checkpoint("mgan", self.model.module)
        self.step += 1

    def rollout_discriminator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        real, fake = AverageMeter(), AverageMeter()
        batch_size, seq_len = samples[0].size()

        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'discriminator-rollout')

        for rollout in pbar:
            real_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   unmasked,
                                   tag="d-step",
                                   real=True)

            real_loss = real_loss.sum() / batch_size

            with torch.no_grad():
                net_output = self.model(masked,
                                        lengths,
                                        mask,
                                        unmasked,
                                        tag="g-step")
                generated = net_output[1]

            fake_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   generated,
                                   tag="d-step",
                                   real=False)

            fake_loss = fake_loss.sum() / batch_size

            loss = (real_loss + fake_loss) / 2
            loss.backward()

            real.update(real_loss.item())
            fake.update(fake_loss.item())

        self.opt.step()
        self.logger.log("discriminator/real", self.step, real.avg)
        self.logger.log("discriminator/fake", self.step, fake.avg)
        self.logger.log("discriminator", self.step, real.avg + fake.avg)

    def rollout_critic(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'critic-rollout')
        for rollout in pbar:
            loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(loss.item())

        self.opt.step()
        self.logger.log("critic/loss", self.step, meter.avg)

    def rollout_generator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        ppl_meter = defaultdict(lambda: AverageMeter())
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'generator-rollout')

        for rollout in pbar:
            loss, generated, ppl = self.model(masked,
                                              lengths,
                                              mask,
                                              unmasked,
                                              tag="g-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(-1 * loss.item())
            # for key in ppl:
            #     ppl[key] = ppl[key].sum() / batch_size
            #     ppl_meter[key].update(ppl[key].item())
        self.opt.step()
        self.logger.log("generator/advantage", self.step, meter.avg)
        # for key in ppl_meter:
        #     self.logger.log("ppl/{}".format(key), ppl_meter[key].avg)

        self.debug('train', samples, generated)

    def debug(self, key, samples, generated):
        masked, unmasked, lengths, mask = samples
        tag = 'generated/{}'.format(key)
        logger = lambda s: self.logger.log(tag, s)
        pretty_print(logger,
                     self.vocab,
                     masked,
                     unmasked,
                     generated,
                     truncate=10)

    def validate_dataset(self, loader):
        self.model.eval()
        _meters = 'generator dfake dreal critic ppl_sampled ppl_truths'
        _n_meters = len(_meters.split())
        Meters = namedtuple('Meters', _meters)
        meters_list = [AverageMeter() for i in range(_n_meters)]
        meters = Meters(*meters_list)
        for sample_batch in loader:
            self._validate(meters, sample_batch)
            for key, value in meters._asdict().items():
                pass
                # print(key, value.avg)

    @property
    def umodel(self):
        if isinstance(self.model, DataParallel):
            return self.model.module
        return self.model

    def aggregate(self, batch_size):
        return lambda tensor: tensor.sum() / batch_size

    def _validate(self, meters, samples):
        with torch.no_grad():
            masked, unmasked, lengths, mask = samples
            batch_size, seq_len = samples[0].size()

            agg = self.aggregate(batch_size)

            real_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   unmasked,
                                   tag="d-step",
                                   real=True)

            real_loss = agg(real_loss)

            generator_loss, generated, ppl = self.model(masked,
                                                        lengths,
                                                        mask,
                                                        unmasked,
                                                        tag="g-step",
                                                        ppl=True)

            generator_loss = agg(generator_loss)

            fake_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   generated,
                                   tag="d-step",
                                   real=False)

            fake_loss = agg(fake_loss)

            loss = (real_loss + fake_loss) / 2

            critic_loss = self.model(masked,
                                     lengths,
                                     mask,
                                     unmasked,
                                     tag="c-step")
            critic_loss = agg(critic_loss)

            meters.dreal.update(real_loss.item())
            meters.dfake.update(fake_loss.item())
            meters.generator.update(generator_loss.item())
            meters.critic.update(critic_loss.item())

            for key in ppl:
                ppl[key] = agg(ppl[key])

            meters.ppl_sampled.update(ppl['sampled'].item())
            meters.ppl_truths.update(ppl['ground-truth'].item())
            self.debug('dev', samples, generated)
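
A hypothetical driver loop for this trainer; args.num_epochs, train_loader and dev_loader are assumed names, and each batch must be the (masked, unmasked, lengths, mask) tuple that run() unpacks:

trainer = MGANTrainer(args, task, saver, logger, vocab)
for epoch in range(args.num_epochs):
    for samples in train_loader:
        trainer.run(epoch, samples)
    trainer.validate_dataset(dev_loader)
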
Example #5
class RRDBM(BaseModel):
    def __init__(self, opt):
        super(RRDBM, self).__init__(opt)

        # define networks and load pretrained models
        train_opt = opt['train']

        self.netG_R = define_SR(opt).to(self.device)

        if opt['dist']:
            self.netG_R = DistributedDataParallel(
                self.netG_R, device_ids=[torch.cuda.current_device()])

        else:
            self.netG_R = DataParallel(self.netG_R)
        # define losses, optimizer and scheduler
        if self.is_train:
            # losses
            # if train_opt['l_pixel_type']=="L1":
            #     self.criterionPixel= torch.nn.L1Loss().to(self.device)
            # elif train_opt['l_pixel_type']=="CR":
            #     self.criterionPixel=CharbonnierLoss().to(self.device)
            #
            # else:
            #     raise NotImplementedError("pixel_type does not implement still")
            self.criterionPixel = SRLoss(
                loss_type=train_opt['l_pixel_type']).to(self.device)
            # optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_R.parameters(),
                                                lr=train_opt['lr'],
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #scheduler
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError("lr_scheme does not implement still")

            self.log_dict = OrderedDict()
            self.train_state()

        self.load()  # load R

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def feed_data(self, data):
        self.LQ = data['LQ'].to(self.device)
        self.HQ = data['HQ'].to(self.device)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        self.fake_HQ = self.netG_R(self.LQ)

    def backward_G(self, step):
        """Calculate the loss for generators G_A and G_B"""

        self.loss_G_pixel = self.criterionPixel(self.fake_HQ, self.HQ)
        if len(self.loss_G_pixel) == 2:
            if self.opt['train']['other_step'] < step:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train']['l_l1_weight']+ \
                                    self.loss_G_pixel[1] * self.opt['train']['l_ssim_weight']
            else:
                self.loss_G_total = self.loss_G_pixel[0] * self.opt['train'][
                    'l_l1_weight']
        else:

            self.loss_G_total = self.loss_G_pixel[0] * self.opt['train'][
                'l_l1_weight']

        self.loss_G_total.backward()

    def optimize_parameters(self, step):
        # G
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G
        self.optimizer_G.zero_grad()  # set G gradients to zero
        self.backward_G(step)  # calculate gradients for G
        self.optimizer_G.step()  # update G's weights

        # set log
        for i in range(len(self.loss_G_pixel)):
            self.log_dict[str(i)] = self.loss_G_pixel[i].item()
        # self.log_dict['loss_l1'] = self.loss_G_pixel.item() if self.opt['train']['l_l1_weight']!=0 else 0

    def train_state(self):
        self.netG_R.train()

    def test_state(self):
        self.netG_R.eval()

    def val(self):
        self.test_state()
        with torch.no_grad():
            self.forward()
        self.train_state()

    def test(self, img):

        self.netG_R.eval()
        with torch.no_grad():

            SR = self.netG_R(img)
        return SR

    def get_network(self):
        return self.netG_R

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals_and_cal_metric(self, opt, current_step):

        visuals = [
            F.interpolate(self.LQ,
                          scale_factor=self.opt['datasets']['train']['scale'],
                          mode='bilinear',
                          align_corners=True), self.fake_HQ, self.HQ
        ]

        util.write_2images(visuals, opt['datasets']['val']['batch_size'],
                           opt['path']['val_images'],
                           'test_%08d' % (current_step))

        # HTML
        util.write_html(opt['path']['experiments_root'] + "/index.html",
                        (current_step), opt['train']['val_freq'],
                        opt['path']['val_images'])

        # src BGR range [0-255] HWC
        srimg = util.tensor2img(self.fake_HQ)
        hrimg = util.tensor2img(self.HQ)

        psnr = calculate_psnr(srimg, hrimg)
        ssim = calculate_ssim(srimg, hrimg)
        return {"psnr": psnr, "ssim": ssim}

    def print_network(self):

        if self.is_train:
            # Generator
            s, n = self.get_network_description(self.netG_R)
            net_struc_str = '{} - {}'.format(
                self.netG_R.__class__.__name__,
                self.netG_R.module.__class__.__name__)
            logger.info(
                'Network G_R structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G_R = self.opt['path']['pretrain_model_G_R']

        if load_path_G_R is not None:
            logger.info(
                'Loading models for G [{:s}] ...'.format(load_path_G_R))
            self.load_network(load_path_G_R, self.netG_R,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG_R, 'G_R', iter_step)
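
A minimal sketch of driving this model; opt and train_loader are assumed to come from the surrounding training script:

model = RRDBM(opt)
for step, batch in enumerate(train_loader, start=1):
    model.feed_data(batch)    # expects a dict with 'LQ' and 'HQ' tensors
    model.optimize_parameters(step)
    if step % 1000 == 0:
        print(step, model.get_current_log())
        model.save(step)
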
Example #6
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt["dist"]:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt["pixel_weight"] > 0:
                l_pix_type = train_opt["pixel_criterion"]
                if l_pix_type == "l1":
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == "l2":
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_pix_type))
                self.l_pix_w = train_opt["pixel_weight"]
            else:
                logger.info("Remove pixel loss.")
                self.cri_pix = None

            # G feature loss
            if train_opt["feature_weight"] > 0:
                l_fea_type = train_opt["feature_criterion"]
                if l_fea_type == "l1":
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == "l2":
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        "Loss type [{:s}] not recognized.".format(l_fea_type))
                self.l_fea_w = train_opt["feature_weight"]
            else:
                logger.info("Remove feature loss.")
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt["dist"]:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt["gan_type"], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt["gan_weight"]
            # D_update_ratio and D_init_iters
            self.D_update_ratio = (train_opt["D_update_ratio"]
                                   if train_opt["D_update_ratio"] else 1)
            self.D_init_iters = (train_opt["D_init_iters"]
                                 if train_opt["D_init_iters"] else 0)

            # optimizers
            # G
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            "Params [{:s}] will not optimize.".format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1_G"], train_opt["beta2_G"]),
            )
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt["weight_decay_D"] if train_opt[
                "weight_decay_D"] else 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=train_opt["lr_D"],
                weight_decay=wd_D,
                betas=(train_opt["beta1_D"], train_opt["beta2_D"]),
            )
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                raise NotImplementedError(
                    "Unrecognized lr_scheme: only MultiStepLR and "
                    "CosineAnnealingLR_Restart are implemented.")

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data["LQ"].to(self.device)  # LQ
        if need_GT:
            self.var_H = data["GT"].to(self.device)  # GT
            input_ref = data["ref"] if "ref" in data else data["GT"]
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.opt["train"]["gan_type"] == "gan":
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt["train"]["gan_type"] == "ragan":
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = (
                    self.l_gan_w *
                    (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                     + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                    True)) / 2)
            l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt["train"]["gan_type"] == "gan":
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt["train"]["gan_type"] == "ragan":
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict["l_g_pix"] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict["l_g_fea"] = l_g_fea.item()
            self.log_dict["l_g_gan"] = l_g_gan.item()

        self.log_dict["l_d_real"] = l_d_real.item()
        self.log_dict["l_d_fake"] = l_d_fake.item()
        self.log_dict["D_real"] = torch.mean(pred_d_real.detach())
        self.log_dict["D_fake"] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["LQ"] = self.var_L.detach()[0].float().cpu()
        out_dict["SR"] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = "{} - {}".format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = "{}".format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = "{} - {}".format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = "{}".format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    "Network D structure: {}, with parameters: {:,d}".format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = "{} - {}".format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__,
                    )
                else:
                    net_struc_str = "{}".format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        "Network F structure: {}, with parameters: {:,d}".
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt["path"]["pretrain_model_G"]
        if load_path_G is not None:
            logger.info("Loading model for G [{:s}] ...".format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt["path"]["strict_load"])
        load_path_D = self.opt["path"]["pretrain_model_D"]
        if self.opt["is_train"] and load_path_D is not None:
            logger.info("Loading model for D [{:s}] ...".format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt["path"]["strict_load"])

    def save(self, iter_step):
        self.save_network(self.netG, "G", iter_step)
        self.save_network(self.netD, "D", iter_step)
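
These BasicSR-style models collect their optimizers and LR schedulers in self.optimizers and self.schedulers; a sketch of the outer loop (opt and train_loader assumed) that steps the schedulers alongside each update:

model = SRGANModel(opt)
for step, batch in enumerate(train_loader, start=1):
    model.feed_data(batch, need_GT=True)
    model.optimize_parameters(step)
    for scheduler in model.schedulers:   # advance the LR schedules once per iteration
        scheduler.step()
    if step % 5000 == 0:
        model.test()
        visuals = model.get_current_visuals()
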
Example #7
def create_and_test_triplet_network(batch_triplet_indices_loader,
                                    experiment_name,
                                    path_to_emb_net,
                                    unseen_triplets,
                                    dataset_name,
                                    model_name,
                                    logger,
                                    test_n,
                                    n,
                                    dim,
                                    layers,
                                    learning_rate=5e-2,
                                    epochs=20,
                                    hl_size=100):
    """
    Description: Constructs the OENN network, defines an optimizer and trains the network on the data w.r.t triplet loss.
    :param model_name:
    :param dataset_name:
    :param test_n:
    :param path_to_emb_net: Data loader object. Gives triplet indices in batches.
    :param n: # points
    :param dim: # features/ dimensions
    :param layers: # layers
    :param learning_rate: learning rate of optimizer.
    :param epochs: # epochs
    :param hl_size: # width of the hidden layer
    :param unseen_triplets: #TODO
    :param logger: # for logging
    :return:
    """

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    digits = int(math.ceil(math.log2(n)))

    #  Define train model
    emb_net_train = define_model(model_name=model_name,
                                 digits=digits,
                                 hl_size=hl_size,
                                 dim=dim,
                                 layers=layers)
    emb_net_train = emb_net_train.to(device)

    for param in emb_net_train.parameters():
        param.requires_grad = False

    if torch.cuda.device_count() > 1:
        emb_net_train = DataParallel(emb_net_train)
        print('multi-gpu')

    checkpoint = torch.load(path_to_emb_net)['model_state_dict']
    key_word = list(checkpoint.keys())[0].split('.')[0]
    if key_word == 'module':
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        emb_net_train.load_state_dict(new_state_dict)
    else:
        emb_net_train.load_state_dict(checkpoint)

    emb_net_train.eval()

    #  Define test model
    emb_net_test = define_model(model_name=model_name,
                                digits=digits,
                                hl_size=hl_size,
                                dim=dim,
                                layers=layers)
    emb_net_test = emb_net_test.to(device)

    if torch.cuda.device_count() > 1:
        emb_net_test = DataParallel(emb_net_test)
        print('multi-gpu')

    # Optimizer
    optimizer = torch.optim.Adam(emb_net_test.parameters(), lr=learning_rate)
    criterion = nn.TripletMarginLoss(margin=1, p=2)
    criterion = criterion.to(device)

    logger.info('#### Dataset Selection #### \n')
    logger.info('dataset: ' + dataset_name + '\n')
    logger.info('#### Network and learning parameters #### \n')
    logger.info('------------------------------------------ \n')
    logger.info('Model Name: ' + model_name + '\n')
    logger.info('Number of hidden layers: ' + str(layers) + '\n')
    logger.info('Hidden layer width: ' + str(hl_size) + '\n')
    logger.info('Embedding dimension: ' + str(dim) + '\n')
    logger.info('Learning rate: ' + str(learning_rate) + '\n')
    logger.info('Number of epochs: ' + str(epochs) + '\n')

    logger.info(' #### Training begins #### \n')
    logger.info('---------------------------\n')

    digits = int(math.ceil(math.log2(n)))
    bin_array = data_utils.get_binary_array(n, digits)

    trip_data = torch.tensor(bin_array[unseen_triplets])
    trip = trip_data.squeeze().to(device).float()

    # Training begins
    train_time = 0
    for ep in range(epochs):
        # Epoch is one pass over the dataset
        epoch_loss = 0

        for batch_ind, trips in enumerate(batch_triplet_indices_loader):
            sys.stdout.flush()
            trip = trips.squeeze().to(device).float()

            # Training time
            begin_train_time = time.time()
            # Forward pass
            embedded_a = emb_net_test(trip[:, :digits])
            embedded_p = emb_net_train(trip[:, digits:2 * digits])
            embedded_n = emb_net_train(trip[:, 2 * digits:])
            # Compute loss
            loss = criterion(embedded_a, embedded_p, embedded_n).to(device)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # End of training
            end_train_time = time.time()
            if batch_ind % 50 == 0:
                logger.info('Epoch: ' + str(ep) + ' Mini batch: ' +
                            str(batch_ind) + '/' +
                            str(len(batch_triplet_indices_loader)) +
                            ' Loss: ' + str(loss.item()))
                sys.stdout.flush()  # Prints faster to the out file
            epoch_loss += loss.item()
            train_time = train_time + end_train_time - begin_train_time

        # Log
        logger.info('Epoch ' + str(ep) + ' - Average Epoch Loss:  ' +
                    str(epoch_loss / len(batch_triplet_indices_loader)) +
                    ' Training time ' + str(train_time))
        sys.stdout.flush()  # Prints faster to the out file

        # Saving the results
        logger.info('Saving the models and the results')
        sys.stdout.flush()  # Prints faster to the out file

        os.makedirs('test_checkpoints', mode=0o777, exist_ok=True)
        model_path = 'test_checkpoints/' + \
                     experiment_name + \
                     '.pt'
        torch.save(
            {
                'epochs': ep,
                'model_state_dict': emb_net_test.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, model_path)

    # Compute the embedding of the data points.
    bin_array_test = data_utils.get_binary_array(test_n, digits)
    test_embeddings = emb_net_test(
        torch.Tensor(bin_array_test).cuda().float()).cpu().detach().numpy()
    train_embeddings = emb_net_train(
        torch.Tensor(bin_array).cuda().float()).cpu().detach().numpy()
    unseen_triplet_error, _ = data_utils.triplet_error_unseen(
        test_embeddings, train_embeddings, unseen_triplets)

    logger.info('Unseen triplet error is ' + str(unseen_triplet_error))
    return unseen_triplet_error
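
The checkpoint handling above strips the 'module.' prefix that DataParallel adds to parameter names before loading into a bare model; the same idea as a small standalone helper (a sketch, not part of the original function):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed from keys."""
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())

# e.g. emb_net.load_state_dict(strip_module_prefix(checkpoint))
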
Example #8
class ESRGAN_EESN_FRCNN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_FRCNN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.config = config
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG)

        # discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD)

        #FRCNN_model
        self.netFRCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True)
        num_classes = 2  # car and background
        in_features = self.netFRCNN.roi_heads.box_predictor.cls_score.in_features
        self.netFRCNN.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes)
        self.netFRCNN.to(self.device)

        self.netG.train()
        self.netD.train()
        self.netFRCNN.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF)
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # FRCNN -- use weight decay
        FRCNN_params = [
            p for p in self.netFRCNN.parameters() if p.requires_grad
        ]
        self.optimizer_FRCNN = torch.optim.SGD(FRCNN_params,
                                               lr=0.005,
                                               momentum=0.9,
                                               weight_decay=0.0005)
        self.optimizers.append(self.optimizer_FRCNN)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'Unrecognized lr_scheduler type: only MultiStepLR and '
                'CosineAnnealingLR_Restart are implemented.')
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    The main repo did not use collate_fn, reads images with different flags,
    and also uses np.ascontiguousarray().
    This code may need to change if that causes problems.
    '''

    def feed_data(self, image, targets):
        self.var_L = image['image_lq'].to(self.device)
        self.var_H = image['image'].to(self.device)
        input_ref = image['ref'] if 'ref' in image else image['image']
        self.var_ref = input_ref.to(self.device)
        '''
        for t in targets:
            for k, v in t.items():
                print(v)
        '''
        self.targets = [{k: v.to(self.device)
                         for k, v in t.items()} for t in targets]

    def optimize_parameters(self, step):
        #Generator
        for p in self.netG.parameters():
            p.requires_grad = True
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
        self.fake_H, self.final_SR, self.x_learned_lap_fake, _ = self.netG(
            self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach(
                )  #don't want to backpropagate this, need proper explanation
                fake_fea = self.netF(
                    self.fake_H)  #In netF normalize=False, check it
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            #EESN calculate loss
            self.lap_HR = kornia.laplacian(self.var_H, 3)
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                l_e_charbonnier = 5 * (
                    self.cri_charbonnier(self.final_SR, self.var_H) +
                    self.cri_charbonnier(self.x_learned_lap_fake, self.lap_HR)
                )  #change the weight to empirically
            l_g_total += l_e_charbonnier

            l_g_total.backward(retain_graph=True)
            # self.optimizer_G.step()

        # discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # thinking of adding final sr d loss

        l_d_total.backward()
        self.optimizer_D.step()
        '''
        Freeze EESRGAN
        '''
        #freeze Generator
        '''
        for p in self.netG.parameters():
            p.requires_grad = False
        '''
        for p in self.netD.parameters():
            p.requires_grad = False
        #Run FRCNN
        self.optimizer_FRCNN.zero_grad()
        self.intermediate_img = self.final_SR
        img_count = self.intermediate_img.size()[0]
        self.intermediate_img = [
            self.intermediate_img[i] for i in range(img_count)
        ]
        loss_dict = self.netFRCNN(self.intermediate_img, self.targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        losses.backward()
        self.optimizer_G.step()
        self.optimizer_FRCNN.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        self.log_dict['FRCNN_loss'] = loss_value

    def test(self, valid_data_loader, train=True, testResult=False):
        self.netG.eval()
        self.netFRCNN.eval()
        self.targets = valid_data_loader
        if not testResult:
            with torch.no_grad():
                self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                    self.var_L)
                self.x_lap_HR = kornia.laplacian(self.var_H, 3)
        if train:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
        if testResult:
            evaluate(self.netG, self.netFRCNN, self.targets, self.device)
            evaluate_save(self.netG, self.netFRCNN, self.targets, self.device,
                          self.config)
        self.netG.train()
        self.netFRCNN.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

        # FRCNN detection network
        s, n = self.get_network_description(self.netFRCNN)
        if isinstance(self.netFRCNN, nn.DataParallel) or isinstance(
                self.netFRCNN, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netFRCNN.__class__.__name__,
                self.netFRCNN.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netFRCNN.__class__.__name__)

        logger.info(
            'Network FRCNN structure: {}, with parameters: {:,d}'.format(
                net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])
        load_path_FRCNN = self.config['path']['pretrain_model_FRCNN']
        if load_path_FRCNN:
            logger.info(
                'Loading model for FRCNN [{:s}] ...'.format(load_path_FRCNN))
            self.load_network(load_path_FRCNN, self.netFRCNN,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netFRCNN, 'FRCNN', iter_step)
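
# Standalone illustration (not part of the class above): minimal sketches of two losses
# referenced in optimize_parameters. They assume raw discriminator logits and image
# tensors; cri_gan / cri_charbonnier in the class are defined elsewhere and may differ.
import torch
import torch.nn.functional as F


def ragan_generator_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    """Relativistic average GAN loss for the generator (the 'ragan' branch above).

    The generator is rewarded when fake logits look more real than the batch-average
    real logit, and real logits look less real than the batch-average fake logit.
    """
    l_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.zeros_like(pred_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.ones_like(pred_fake))
    return (l_real + l_fake) / 2


def charbonnier_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Charbonnier (smooth L1) loss: mean of sqrt((x - y)^2 + eps^2)."""
    return torch.sqrt((pred - target) ** 2 + eps ** 2).mean()
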
Exemplo n.º 9
0
class Trainer(object):
    def __init__(self,
                 batch=8,
                 subdivisions=4,
                 epochs=100,
                 burn_in=1000,
                 steps=[400000, 450000]):

        _model = build_from_dict(model, DETECTORS)
        self.model = DataParallel(_model.cuda(), device_ids=[0])

        self.train_dataset = build_from_dict(data_cfg['train'], DATASET)
        self.val_dataset = build_from_dict(data_cfg['val'], DATASET)

        self.burn_in = burn_in
        self.steps = steps
        self.epochs = epochs

        self.batch = batch
        self.subdivisions = subdivisions

        self.train_size = len(self.train_dataset)
        self.val_size = len(self.val_dataset)

        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=batch // subdivisions,
                                       shuffle=True,
                                       num_workers=1,
                                       pin_memory=True,
                                       drop_last=True,
                                       collate_fn=self.collate)

        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=batch // subdivisions,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True,
                                     drop_last=True,
                                     collate_fn=self.collate)

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=0.001 / batch,
            betas=(0.9, 0.999),
            eps=1e-08,
        )

        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
                                                     self.burnin_schedule)

    def train(self):
        self.model.train()
        global_step = 0
        checkpoints = r'/disk2/project/pytorch-YOLOv4/checkpoints/'
        save_prefix = 'Yolov4_epoch_'
        saved_models = collections.deque()
        for epoch in range(self.epochs):

            epoch_loss = 0
            epoch_step = 0

            for i, batch in enumerate(self.train_loader):
                losses = self.model(**batch)
                loss = self.parse_losses(losses)
                loss.backward()
                epoch_loss += loss.item()
                print('loss: {}'.format(loss.item()))

                global_step += 1
                epoch_step += 1

                if global_step % self.subdivisions == 0:
                    # apply the gradients accumulated over `subdivisions` mini-batches
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.scheduler.step()

            os.makedirs(checkpoints, exist_ok=True)
            save_path = os.path.join(checkpoints,
                                     f'{save_prefix}{epoch + 1}.pth')
            torch.save(self.model.state_dict(), save_path)

            saved_models.append(save_path)
            if len(saved_models) > 5:
                model_to_remove = saved_models.popleft()
                try:
                    os.remove(model_to_remove)
                except OSError:
                    pass

    def burnin_schedule(self, i):
        if i < self.burn_in:
            factor = pow(i / self.burn_in, 4)
        elif i < self.steps[0]:
            factor = 1.0
        elif i < self.steps[1]:
            factor = 0.1
        else:
            factor = 0.01
        return factor
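
    # Note: LambdaLR (used in __init__ above) multiplies the optimizer's base lr by the
    # factor returned here on every scheduler.step(); the lr therefore ramps up as
    # (step / burn_in) ** 4 during burn-in and is cut to 10% and 1% at the two milestones.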

    def collate(self, batch):
        if 'multi_scale' in data_cfg.keys() and len(
                data_cfg['multi_scale']) > 0:
            multi_scale = data_cfg['multi_scale']
            if isinstance(multi_scale, dict) and 'type' in multi_scale.keys():
                randomShape = build_from_dict(multi_scale, TRANSFORMS)
                batch = randomShape(batch)
        collate = default_collate(batch)
        return collate

    def parse_losses(self, losses):
        log_vars = collections.OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    '{} is not a tensor or list of tensors'.format(loss_name))

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        return loss
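
# Standalone sketch (not part of the Trainer above): the batch/subdivisions pattern is
# plain gradient accumulation. A common variant, shown here with a generic model and a
# hypothetical MSE loss, also scales each loss by 1/subdivisions so the effective
# gradient matches a full batch; the Trainer above does not apply that scaling.
import torch
import torch.nn.functional as F


def accumulation_step(model, optimizer, scheduler, batches, subdivisions):
    """Accumulate gradients over `subdivisions` mini-batches before one optimizer step."""
    for i, (x, y) in enumerate(batches, start=1):
        loss = F.mse_loss(model(x), y)
        (loss / subdivisions).backward()  # accumulate scaled gradients
        if i % subdivisions == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
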
Exemplo n.º 10
0
class VRNModel(BaseModel):
    def __init__(self, opt):
        super(VRNModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.gop = opt['gop']
        train_opt = opt['train']
        test_opt = opt['test']
        self.opt = opt
        self.train_opt = train_opt
        self.test_opt = test_opt
        self.opt_net = opt['network_G']
        self.center = self.gop // 2

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])
            self.Reconstruction_center = ReconstructionLoss(losstype="center")

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart are supported.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def init_hidden_state(self, z):
        b, c, h, w = z.shape
        h_t = []
        c_t = []
        for _ in range(self.opt_net['block_num_rbm']):
            h_t.append(torch.zeros([b, c, h, w]).cuda())
            c_t.append(torch.zeros([b, c, h, w]).cuda())
        memory = torch.zeros([b, c, h, w]).cuda()

        return h_t, c_t, memory

    def loss_forward(self, out, y):
        if self.opt['model'] == 'LSTM-VRN':
            l_forw_fit = self.train_opt[
                'lambda_fit_forw'] * self.Reconstruction_forw(out, y)
            return l_forw_fit
        elif self.opt['model'] == 'MIMO-VRN':
            l_forw_fit = 0
            for i in range(out.shape[1]):
                l_forw_fit += self.train_opt[
                    'lambda_fit_forw'] * self.Reconstruction_forw(
                        out[:, i], y[:, i])
            return l_forw_fit

    def loss_back_rec(self, out, x):
        if self.opt['model'] == 'LSTM-VRN':
            l_back_rec = self.train_opt[
                'lambda_rec_back'] * self.Reconstruction_back(out, x)
            return l_back_rec
        elif self.opt['model'] == 'MIMO-VRN':
            l_back_rec = 0
            for i in range(x.shape[1]):
                l_back_rec += self.train_opt[
                    'lambda_rec_back'] * self.Reconstruction_back(
                        out[:, i], x[:, i])
            return l_back_rec

    def loss_center(self, out, x):
        # x.shape: (b, t, c, h, w)
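        # The center loss penalizes how much each frame's MSE deviates from the
        # clip-average MSE, roughly sum_j sqrt((mse_j - mean(mse))^2 + eps), which
        # pushes the reconstruction quality to be uniform across the GoP.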
        b, t = x.shape[:2]
        l_center = 0
        for i in range(b):
            mse_s = self.Reconstruction_center(out[i], x[i])
            mse_mean = torch.mean(mse_s)
            for j in range(t):
                l_center += torch.sqrt((mse_s[j] - mse_mean.detach())**2 +
                                       1e-18)
        l_center = self.train_opt['lambda_center'] * l_center / b

        return l_center

    def optimize_parameters(self):
        self.optimizer_G.zero_grad()

        if self.opt['model'] == 'LSTM-VRN':
            # forward downscaling
            b, t, c, h, w = self.real_H.shape
            self.output = [self.netG(x=self.real_H[:, i]) for i in range(t)]

            # hidden state initialization
            z_p = torch.zeros(self.output[0][:, 3:].shape).to(self.device)
            hs = self.init_hidden_state(z_p)
            z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(self.device)
            hs_back = self.init_hidden_state(z_p_back)

            # LSTM forward
            for i in range(self.center + 1):
                y = self.Quantization(self.output[i][:, :3])
                z_p, hs = self.netG(x=[y, z_p], rev=True, hs=hs, direction='f')
            # LSTM backward
            for j in reversed(range(self.center, t)):
                y = self.Quantization(self.output[j][:, :3])
                z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                              rev=True,
                                              hs=hs_back,
                                              direction='b')

            # backward upscaling
            y = self.Quantization(self.output[self.center][:, :3])
            out_x, out_z = self.netG(x=[y, [z_p, z_p_back]], rev=True)

            l_back_rec = self.loss_back_rec(self.real_H[:, self.center], out_x)
            LR_ref = self.ref_L[:, self.center].detach()
            l_forw_fit = self.loss_forward(self.output[self.center][:, :3],
                                           LR_ref)

            # total loss
            loss = l_forw_fit + l_back_rec
            loss.backward()

        elif self.opt['model'] == 'MIMO-VRN':
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2

            self.input = self.real_H[:, center - intval:center + intval + 1]
            self.output = self.netG(x=self.input.reshape(b, -1, h, w))

            LR_ref = self.ref_L[:,
                                center - intval:center + intval + 1].detach()
            out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                -1, self.gop, 3, h // 4, w // 4)
            l_forw_fit = self.loss_forward(out_lrs, LR_ref)

            y = self.Quantization(self.output[:, :3 * self.gop, :, :])
            out_x, out_z = self.netG(x=[y, None], rev=True)

            l_back_rec = self.loss_back_rec(
                out_x.reshape(-1, self.gop, 3, h, w), self.input)
            l_center_x = self.loss_center(out_x.reshape(-1, self.gop, 3, h, w),
                                          self.input)

            # total loss
            loss = l_forw_fit + l_back_rec + l_center_x
            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()
        else:
            raise Exception('Model should be either LSTM-VRN or MIMO-VRN.')

        # set log
        self.log_dict['l_back_rec'] = l_back_rec.item()
        self.log_dict['l_forw_fit'] = l_forw_fit.item()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(),
                                     self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

    def test(self):
        Lshape = self.ref_L.shape

        self.netG.eval()
        with torch.no_grad():

            if self.opt['model'] == 'LSTM-VRN':

                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape

                # forward downscaling
                self.output = [
                    self.netG(x=self.real_H[:, i]) for i in range(t)
                ]

                for i in range(t):
                    # hidden state initialization
                    z_p = torch.zeros(self.output[0][:,
                                                     3:].shape).to(self.device)
                    hs = self.init_hidden_state(z_p)
                    z_p_back = torch.zeros(self.output[0][:, 3:].shape).to(
                        self.device)
                    hs_back = self.init_hidden_state(z_p_back)

                    # find sequence index
                    if i - self.center < 0:
                        indices_past = [0 for _ in range(self.center - i)]
                        for index in range(i + 1):
                            indices_past.append(index)
                        indices_future = [
                            index for index in range(i, i + self.center + 1)
                        ]
                    elif i > t - self.center - 1:
                        indices_past = [
                            index for index in range(i - self.center, i + 1)
                        ]
                        indices_future = [index for index in range(i, t)]
                        for index in range(self.center - len(indices_future) +
                                           1):
                            indices_future.append(t - 1)
                    else:
                        indices_past = [
                            index for index in range(i - self.center, i + 1)
                        ]
                        indices_future = [
                            index for index in range(i, i + self.center + 1)
                        ]

                    # LSTM forward
                    for j in indices_past:
                        y = self.Quantization(self.output[j][:, :3])
                        z_p, hs = self.netG(x=[y, z_p],
                                            rev=True,
                                            hs=hs,
                                            direction='f')
                    # LSTM backward
                    for k in reversed(indices_future):
                        y = self.Quantization(self.output[k][:, :3])
                        z_p_back, hs_back = self.netG(x=[y, z_p_back],
                                                      rev=True,
                                                      hs=hs_back,
                                                      direction='b')

                    # backward upscaling
                    y = self.Quantization(self.output[i][:, :3])
                    out_x, out_z = self.netG(x=[y, [z_p, z_p_back]], rev=True)

                    forw_L.append(y)
                    fake_H.append(out_x)

            elif self.opt['model'] == 'MIMO-VRN':

                forw_L = []
                fake_H = []
                b, t, c, h, w = self.real_H.shape
                n_gop = t // self.gop

                for i in range(n_gop + 1):
                    if i == n_gop:
                        # calculate indices to pad last frame
                        indices = [
                            i * self.gop + j for j in range(t % self.gop)
                        ]
                        for _ in range(self.gop - t % self.gop):
                            indices.append(t - 1)
                        self.input = self.real_H[:, indices]
                    else:
                        self.input = self.real_H[:, i * self.gop:(i + 1) *
                                                 self.gop]

                    # forward downscaling
                    self.output = self.netG(x=self.input.reshape(b, -1, h, w))
                    out_lrs = self.output[:, :3 * self.gop, :, :].reshape(
                        -1, self.gop, 3, h // 4, w // 4)

                    # backward upscaling
                    y = self.Quantization(self.output[:, :3 * self.gop, :, :])
                    out_x, out_z = self.netG(x=[y, None], rev=True)
                    out_x = out_x.reshape(-1, self.gop, 3, h, w)

                    if i == n_gop:
                        for j in range(t % self.gop):
                            forw_L.append(out_lrs[:, j])
                            fake_H.append(out_x[:, j])
                    else:
                        for j in range(self.gop):
                            forw_L.append(out_lrs[:, j])
                            fake_H.append(out_x[:, j])

            else:
                raise Exception('Model should be either LSTM-VRN or MIMO-VRN.')

            self.fake_H = torch.stack(fake_H, dim=1)
            self.forw_L = torch.stack(forw_L, dim=1)

        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
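
# Standalone sketch (not part of VRNModel): the Quantization() module used above is not
# defined in this snippet. In invertible-rescaling codebases it is commonly a rounding to
# 8-bit levels with a straight-through gradient; a minimal version of that idea:
import torch
import torch.nn as nn


class _RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # simulate 8-bit quantization of an image assumed to lie in [0, 1]
        return torch.round(x * 255.0) / 255.0

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass gradients through unchanged
        return grad_output


class Quantization(nn.Module):
    def forward(self, x):
        return _RoundSTE.apply(torch.clamp(x, 0.0, 1.0))
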
Exemplo n.º 11
0
class RRSNetModel(BaseModel):
    def __init__(self, opt):
        super(RRSNetModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.l1_init = train_opt['l1_init'] if train_opt['l1_init'] else 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD_grad = networks.define_D_grad(opt).to(self.device) # D_grad
            if opt['dist']:
                self.netD = DistributedDataParallel(self.netD,
                                                    device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
                self.netD_grad = DistributedDataParallel(self.netD_grad,
                                                    device_ids=[torch.cuda.current_device()],find_unused_parameters=True)
            else:
                self.netD = DataParallel(self.netD)
                self.netD_grad = DataParallel(self.netD_grad)

            self.netG.train()
            self.netD.train()
            self.netD_grad.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
            # Branch_init_iters
            self.Branch_pretrain = train_opt['Branch_pretrain'] if train_opt['Branch_pretrain'] else 0
            self.Branch_init_iters = train_opt['Branch_init_iters'] if train_opt['Branch_init_iters'] else 1

            # gradient_pixel_loss
            if train_opt['gradient_pixel_weight'] > 0:
                self.cri_pix_grad = nn.MSELoss().to(self.device)
                self.l_pix_grad_w = train_opt['gradient_pixel_weight']
            else:
                self.cri_pix_grad = None

            # gradient_gan_loss
            if train_opt['gradient_gan_weight'] > 0:
                self.cri_grad_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
                self.l_gan_grad_w = train_opt['gradient_gan_weight']
            else:
                self.cri_grad_gan = None


            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'], train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'], train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # D_grad
            wd_D_grad = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
            self.optimizer_D_grad = torch.optim.Adam(self.netD_grad.parameters(), lr=train_opt['lr_D'],
                                                     weight_decay=wd_D_grad,
                                                     betas=(train_opt['beta1_D'], train_opt['beta2_D']))

            self.optimizers.append(self.optimizer_D_grad)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart are supported.')

            self.log_dict = OrderedDict()
            self.get_grad = Get_gradient()
            self.get_grad_nopadding = Get_gradient_nopadding()

        self.print_network()  # print network

    def feed_data(self, data, need_GT=True):
        self.var_LQ = data['LQ'].to(self.device)  # LQ
        self.var_LQ_UX4 = data['LQ_UX4'].to(self.device)  
        self.var_Ref = data['Ref'].to(self.device)
        self.var_Ref_DUX4 = data['Ref_DUX4'].to(self.device)

        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            self.var_ref = data['GT'].clone().to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        for p in self.netD_grad.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4)

        self.fake_H_grad = self.get_grad(self.fake_H)
        self.var_H_grad = self.get_grad(self.var_H)
        self.var_ref_grad = self.get_grad(self.var_ref)
        self.var_H_grad_nopadding = self.get_grad_nopadding(self.var_H)
        self.grad_LR = self.get_grad_nopadding(self.var_LQ)

        l_g_total = 0

        if step < self.l1_init:
          l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
          l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
          l_g_total = l_pix + l_g_pix_grad 
          l_g_total.backward()
          self.optimizer_G.step()
          self.log_dict['l_g_pix'] = l_pix.item()
        else:
          if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            with torch.autograd.set_detect_anomaly(True):
              if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
              if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

              if self.cri_pix_grad: #gradient pixel loss
                l_g_pix_grad = self.l_pix_grad_w * self.cri_pix_grad(self.fake_H_grad, self.var_H_grad)
                l_g_total = l_g_total + l_g_pix_grad

              if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
              elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
              l_g_total += l_g_gan

              # grad G gan + cls loss
              if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                l_g_gan_grad = self.l_gan_grad_w * self.cri_gan(pred_g_fake_grad, True)
              elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real_grad = self.netD_grad(self.var_ref_grad).detach()
                pred_g_fake_grad = self.netD_grad(self.fake_H_grad)
                l_g_gan_grad = self.l_gan_grad_w * (
                    self.cri_gan(pred_d_real_grad - torch.mean(pred_g_fake_grad), False) +
                    self.cri_gan(pred_g_fake_grad - torch.mean(pred_d_real_grad), True)) / 2
              l_g_total = l_g_total + l_g_gan_grad

              l_g_total.backward()
              self.optimizer_G.step()

          # D
          for p in self.netD.parameters():
            p.requires_grad = True

          for p in self.netD_grad.parameters():
            p.requires_grad = True

          with torch.autograd.set_detect_anomaly(True):
            self.optimizer_D.zero_grad()
          # need to forward and backward separately, since batch norm statistics differ
            l_d_total = 0
            if self.opt['train']['gan_type'] == 'gan':
              pred_d_real = self.netD(self.var_ref)
              l_d_real = self.cri_gan(pred_d_real, True)
              l_d_real.backward()
              pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
              l_d_fake = self.cri_gan(pred_d_fake, False)
              l_d_fake.backward()
            elif self.opt['train']['gan_type'] == 'ragan':
              pred_d_real = self.netD(self.var_ref)
              pred_d_fake = self.netD(self.fake_H.detach())
              l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
              l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
              l_d_total = (l_d_real + l_d_fake) / 2
              l_d_total.backward()
            self.optimizer_D.step()


            self.optimizer_D_grad.zero_grad()
            l_d_total_grad = 0


            if self.opt['train']['gan_type'] == 'gan':
              pred_d_real_grad = self.netD_grad(self.var_ref_grad)
              l_d_real_grad = self.cri_grad_gan(pred_d_real_grad, True)
              l_d_real_grad.backward()
              pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())
              l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad, False)
              l_d_fake_grad.backward()
            elif self.opt['train']['gan_type'] == 'ragan':
              pred_d_real_grad = self.netD_grad(self.var_ref_grad)
              pred_d_fake_grad = self.netD_grad(self.fake_H_grad.detach())
              l_d_real_grad = self.cri_grad_gan(pred_d_real_grad - torch.mean(pred_d_fake_grad), True)
              l_d_fake_grad = self.cri_grad_gan(pred_d_fake_grad - torch.mean(pred_d_real_grad), False)
              l_d_total_grad = (l_d_real_grad + l_d_fake_grad) / 2
              l_d_total_grad.backward()

            self.optimizer_D_grad.step()

          # set log
          if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
          # D
          self.log_dict['l_d_real'] = l_d_real.item()
          self.log_dict['l_d_fake'] = l_d_fake.item()
          # D_grad 
          self.log_dict['l_d_real_grad'] = l_d_real_grad.item()
          self.log_dict['l_d_fake_grad'] = l_d_fake_grad.item()

          # D outputs
          self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
          self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

          # D_grad outputs
          self.log_dict['D_real_grad'] = torch.mean(pred_d_real_grad.detach())
          self.log_dict['D_fake_grad'] = torch.mean(pred_d_fake_grad.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_LQ, self.var_LQ_UX4, self.var_Ref, self.var_Ref_DUX4)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
        load_path_D_grad = self.opt['path']['pretrain_model_D_grad']
        if self.opt['is_train'] and load_path_D_grad is not None:
            logger.info('Loading model for D_grad [{:s}] ...'.format(load_path_D_grad))
            self.load_network(load_path_D_grad, self.netD_grad, self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netD_grad, 'D_grad', iter_step)
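
# Standalone sketch (not part of RRSNetModel): Get_gradient() / Get_gradient_nopadding()
# are not defined in this snippet. In gradient-guided SR models they are typically fixed
# Sobel-style filters; a hypothetical per-channel gradient-magnitude module might look like:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SobelGradient(nn.Module):
    """Hypothetical stand-in: per-channel Sobel gradient magnitude of an image tensor."""

    def __init__(self):
        super().__init__()
        kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        self.register_buffer('kx', kx.view(1, 1, 3, 3))
        self.register_buffer('ky', kx.t().contiguous().view(1, 1, 3, 3))

    def forward(self, x):
        b, c, h, w = x.shape
        x_flat = x.reshape(b * c, 1, h, w)
        gx = F.conv2d(x_flat, self.kx, padding=1)
        gy = F.conv2d(x_flat, self.ky, padding=1)
        grad = torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)
        return grad.reshape(b, c, h, w)
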
Exemplo n.º 12
0
class RRDBSRModel(BaseModel):
    def __init__(self, opt):
        super(RRDBSRModel, self).__init__(opt)

        # define networks and load pretrained models
        train_opt = opt['train']

        self.netG_SR = networks.define_SR(opt).to(self.device)

        if self.is_train:
            if not self.opt['full_sr']:

                self.netG_BA = networks.define_G(opt).to(self.device)

        if opt['dist']:
            self.netG_SR = DistributedDataParallel(
                self.netG_SR,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=True)
            if self.is_train:
                if not self.opt['full_sr']:
                    self.netG_BA = DistributedDataParallel(
                        self.netG_BA,
                        device_ids=[torch.cuda.current_device()],
                        find_unused_parameters=True)

        else:

            self.netG_SR = DataParallel(self.netG_SR)
            if self.is_train:
                if not self.opt['full_sr']:
                    self.netG_BA = DataParallel(self.netG_BA)

        # define losses, optimizer and scheduler
        if self.is_train:

            # losses
            self.criterion = SRLoss(train_opt['loss_type']).to(
                self.device)  # SR reconstruction loss

            # optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_SR.parameters(),
                                                lr=train_opt['lr'],
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))

            self.optimizers.append(self.optimizer_G)

            #scheduler

            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            else:
                raise NotImplementedError('Only the MultiStepLR scheme is implemented.')

            self.log_dict = OrderedDict()

        self.load()  # load pre-trained models

        if self.is_train:
            self.train_state()

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def feed_data(self, data):
        self.LR = data['LQ'].to(self.device)
        self.HR = data['HQ'].to(self.device)

    def B2A(self):

        with torch.no_grad():
            fake_LR = self.netG_BA(self.LR)

        return fake_LR

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        if self.opt['full_sr']:
            self.fake_LR = self.LR
        else:
            self.fake_LR = self.B2A()
        self.SR = self.netG_SR(self.fake_LR)

    def backward_G(self, step):
        """Calculate the loss for the SR generator."""

        self.loss_G = self.criterion(self.SR, self.HR)
        if len(self.loss_G) != 1:
            if self.opt['other_step'] > step:
                self.loss_total = self.loss_G[0] + \
                                  self.loss_G[1] * self.opt['l_other_weight']
            else:
                self.loss_total = self.loss_G[0]
        else:
            self.loss_total = self.loss_G[0]
        self.loss_total.backward()

    def optimize_parameters(self, step):
        """Calculate losses, gradients, and update network weights; called in every training iteration."""
        # forward
        self.forward()  # compute the SR output
        self.optimizer_G.zero_grad()
        self.backward_G(step)  # calculate gradients for the SR generator
        self.optimizer_G.step()  # update the SR generator's weights

        # set log
        for i in range(len(self.loss_G)):
            self.log_dict[str(i)] = self.loss_G[i].item()

    def train_state(self):
        self.netG_SR.train()
        if not self.opt['full_sr']:
            self.netG_BA.eval()

    def test_state(self):
        self.netG_SR.eval()
        if not self.opt['full_sr']:
            self.netG_BA.eval()

    def val(self):
        self.test_state()
        with torch.no_grad():
            self.forward()
        self.train_state()

    def test(self):
        self.netG_SR.eval()
        with torch.no_grad():
            SR = self.netG_SR(self.LR)
        return {'SR': SR}

    def get_current_log(self):
        return self.log_dict

    def print_network(self):

        if self.is_train:
            # Generator
            s, n = self.get_network_description(self.netG_SR)
            net_struc_str = '{} - {}'.format(
                self.netG_SR.__class__.__name__,
                self.netG_SR.module.__class__.__name__)
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G_SR = self.opt['path']['pretrain_model_G_SR']
        load_path_G_BA = self.opt['path']['pretrain_model_G_BA']

        if load_path_G_BA is not None:
            logger.info(
                'Loading models for G [{:s}] ...'.format(load_path_G_BA))
            self.load_network(load_path_G_BA, self.netG_BA,
                              self.opt['path']['strict_load'])
        else:
            logger.info('Pretrained model for G_BA does not exist!')
            if self.is_train:
                if not self.opt['full_sr']:
                    exit(1)
        if load_path_G_SR is not None:

            logger.info(
                'Loading model for G_SR [{:s}] ...'.format(load_path_G_SR))
            self.load_network(load_path_G_SR, self.netG_SR,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG_SR, 'G_SR', iter_step)
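
# Standalone sketch (not part of RRDBSRModel): SRLoss is not defined in this snippet.
# backward_G only assumes that the criterion returns a list of loss terms, where the
# second term (if present) is weighted by opt['l_other_weight'] early in training.
# A hypothetical criterion with that interface:
import torch.nn as nn


class TwoTermSRLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.pixel = nn.L1Loss()
        self.aux = nn.MSELoss()

    def forward(self, sr, hr):
        # [primary pixel term, auxiliary term]
        return [self.pixel(sr, hr), self.aux(sr, hr)]
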
Exemplo n.º 13
0
class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a MLP for fine-tuning using the standard self-supervised protocol.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ... # the num of classes in the datamodule
        dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ... # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """
    def __init__(
        self,
        z_dim: int,
        drop_p: float = 0.2,
        hidden_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        dataset: Optional[str] = None,
    ):
        """
        Args:
            z_dim: Representation dimension
            drop_p: Dropout probability
            hidden_dim: Hidden dimension for the fine-tune MLP
        """
        super().__init__()

        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p

        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None
        self.num_classes: Optional[int] = num_classes
        self.dataset: Optional[str] = dataset

        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(self,
              trainer: Trainer,
              pl_module: LightningModule,
              stage: Optional[str] = None) -> None:
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes
        if self.dataset is None:
            self.dataset = trainer.datamodule.name

    def on_pretrain_routine_start(self, trainer: Trainer,
                                  pl_module: LightningModule) -> None:
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons (the attribute name changed across versions)
        accel = (trainer.accelerator_connector if hasattr(
            trainer, "accelerator_connector") else
                 trainer._accelerator_connector)
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator,
                                            device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator,
                                           device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(),
                                          lr=1e-4)

        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"])

    def to_device(self, batch: Sequence,
                  device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
        # get the labeled batch
        if self.dataset == "stl10":
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
    ):
        with torch.no_grad():
            with set_training(pl_module, False):
                x, y = self.to_device(batch, pl_module.device)
                representations = pl_module(x).flatten(start_dim=1)

        # forward pass
        mlp_logits = self.online_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_logits, y)

        acc = accuracy(mlp_logits.softmax(-1), y)

        return acc, mlp_loss

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc",
                      train_acc,
                      on_step=True,
                      on_epoch=False)
        pl_module.log("online_train_loss",
                      mlp_loss,
                      on_step=True,
                      on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        val_acc, mlp_loss = self.shared_step(pl_module, batch)
        pl_module.log("online_val_acc",
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log("online_val_loss",
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)

    def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           checkpoint: Dict[str, Any]) -> dict:
        return {
            "state_dict": self.online_evaluator.state_dict(),
            "optimizer_state": self.optimizer.state_dict()
        }

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           callback_state: Dict[str, Any]) -> None:
        self._recovered_callback_state = callback_state
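
# A minimal, hypothetical sketch of wiring a callback like the one above into a
# pytorch_lightning Trainer; `ssl_model`, `datamodule` and `online_eval_callback`
# are placeholder names, not part of the original example.
import pytorch_lightning as pl

def run_with_online_eval(ssl_model, datamodule, online_eval_callback):
    # the callback is handed to the Trainer and fires on every train/val batch end
    trainer = pl.Trainer(max_epochs=100, callbacks=[online_eval_callback])
    trainer.fit(ssl_model, datamodule=datamodule)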
Exemplo n.º 14
0
class SRGANModel(BaseModel):
    def __init__(self, opt, is_train):
        super(SRGANModel, self).__init__(opt, is_train)
        train_opt = opt
        self.rank = 0

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt.pixel_weight > 0:
                l_pix_type = train_opt.pixel_criterion
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt.pixel_weight
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt.feature_weight > 0:
                l_fea_type = train_opt.feature_criterion
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt.feature_weight
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt.gan_type, 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt.gan_weight
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt.D_update_ratio if train_opt.D_update_ratio else 1
            self.D_init_iters = train_opt.D_init_iters if train_opt.D_init_iters else 0

            # optimizers
            # G
            wd_G = train_opt.weight_decay_G if train_opt.weight_decay_G else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt.lr_G,
                                                weight_decay=wd_G,
                                                betas=(train_opt.beta1_G,
                                                       train_opt.beta2_G))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt.weight_decay_D if train_opt.weight_decay_D else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt.lr_D,
                                                weight_decay=wd_D,
                                                betas=(train_opt.beta1_D,
                                                       train_opt.beta2_D))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt.lr_scheme == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt.lr_steps,
                            restarts=None,
                            weights=None,
                            gamma=train_opt.lr_gamma,
                            clear_state=False))
            elif train_opt.lr_scheme == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt.T_period,
                            eta_min=train_opt.eta_min,
                            restarts=train_opt.restarts,
                            weights=train_opt.restart_weights))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart '
                    'learning rate schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        # Freeze every discriminator layer so no gradients flow into D while the generator is updated
        for p in self.netD.parameters():
            p.requires_grad = False
        # Reset the generator gradients to zero
        self.optimizer_G.zero_grad()
        # ipdb.set_trace()
        self.fake_H = self.netG(self.var_L)
        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.opt.gan_type == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt.gan_type == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            # run the backward pass for the generator losses
            l_g_total.backward()
            # update the generator parameters
            self.optimizer_G.step()

        # D
        # Re-enable gradients for the discriminator
        for p in self.netD.parameters():
            p.requires_grad = True
        # Reset the discriminator gradients to zero
        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt.gan_type == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt.gan_type == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach().float().cpu()
        out_dict['SR'] = self.fake_H.detach().float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach().float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            print('Network G structure: {}, with parameters: {:,d}'.format(
                net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                print('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    print('Network F structure: {}, with parameters: {:,d}'.
                          format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt.pretrain_model_G
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, True)
        load_path_D = self.opt.pretrain_model_D
        if self.is_train and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, True)

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
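
# A short sketch of the relativistic-average ('ragan') discriminator loss that
# optimize_parameters above builds from cri_gan; plain BCE-with-logits stands in
# here for the project's GANLoss helper, so this is only illustrative.
import torch
import torch.nn.functional as F

def ragan_d_loss(pred_real: torch.Tensor, pred_fake: torch.Tensor) -> torch.Tensor:
    # real logits are pushed above the mean fake logit, fake logits below the mean real logit
    l_d_real = F.binary_cross_entropy_with_logits(
        pred_real - pred_fake.mean(), torch.ones_like(pred_real))
    l_d_fake = F.binary_cross_entropy_with_logits(
        pred_fake - pred_real.mean(), torch.zeros_like(pred_fake))
    return (l_d_real + l_d_fake) / 2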
Exemplo n.º 15
0
class CnnModel(BaseModel):
    def __init__(self, opt):
        super().__init__(opt)

        self.encoder_out = 512

        # Models
        self.cnn_encoder = BasicCnnEncoder(opt.nf).to(self.device)
        self.linear_decoder = LinearDecoder(self.encoder_out, opt.output_n).to(self.device)

        if self.gpu_ids:
            self.cnn_encoder = DataParallel(self.cnn_encoder, self.gpu_ids)
            self.linear_decoder = DataParallel(self.linear_decoder, self.gpu_ids)

        self.model_names = ['cnn_encoder', 'linear_decoder']

        if self.isTrain:
            # Loss
            self.criterion = nn.CrossEntropyLoss()

            # Optimizer
            self.optimizer = optim.Adam(
                itertools.chain(self.cnn_encoder.parameters(), self.linear_decoder.parameters()),
                lr=opt.lr)
            # Continue Training
            if self.opt.ct > 0:
                print(f"Continue training from {self.opt.ct}")
                self.load_networks(self.opt.ct, load_optim=True)

        print(self.cnn_encoder)
        print(self.linear_decoder)

    def feed_input(self, x: dict):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        :param x: a dict holding the image tensor, its label, and the image_id metadata.
        """
        self.image = x['image'].to(self.device)
        self.label_original = x['label'].to(self.device)
        self.image_id = x['image_id']

    def optimize_parameters(self):
        """
        Optimizes parameters
        """
        # Forward
        self.forward()
        self.optimizer.zero_grad()
        # print(self.label_pred)
        # print(self.label_original)
        loss = self.criterion(self.label_pred, self.label_original)
        loss.backward()
        train_loss = loss.item()
        self.training_loss += train_loss
        # print("Train Loss", train_loss)
        self.optimizer.step()

    def forward(self):
        """Run forward pass
        Called by both functions <optimize_parameters> and <test>
        """
        feature_vec = self.cnn_encoder(self.image)
        self.label_pred = self.linear_decoder(feature_vec)
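
# A self-contained sketch, with made-up shapes, of the encoder/decoder pattern used
# by CnnModel above: two modules optimized jointly via itertools.chain, mirroring the
# optimizer defined in __init__.
import itertools
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
decoder = nn.Linear(16, 10)
optimizer = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()), lr=1e-3)

images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
logits = decoder(encoder(images))            # forward pass through both modules
loss = nn.CrossEntropyLoss()(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()                             # one joint update of encoder and decoder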
Exemplo n.º 16
0
    torch.load('./models/IDE_market.pth', map_location=torch.device('cpu')))

Da = DataParallel(models.Discriminator(), device_ids=[0, 1]).to(device)
Db = DataParallel(models.Discriminator(), device_ids=[0, 1]).to(device)
Ga = DataParallel(models.Generator(), device_ids=[0, 1]).to(device)
Gb = DataParallel(models.Generator(), device_ids=[0, 1]).to(device)
MSE = nn.MSELoss()
L1 = nn.L1Loss()


def classification_loss(logit, target):
    """Compute softmax cross entropy loss."""
    return F.cross_entropy(logit, target)


da_optimizer = torch.optim.Adam(Da.parameters(), lr=lr, betas=(0.5, 0.999))
db_optimizer = torch.optim.Adam(Db.parameters(), lr=lr, betas=(0.5, 0.999))
ga_optimizer = torch.optim.Adam(Ga.parameters(), lr=lr, betas=(0.5, 0.999))
gb_optimizer = torch.optim.Adam(Gb.parameters(), lr=lr, betas=(0.5, 0.999))

IDE_criterion = torch.nn.CrossEntropyLoss()
ignored_params = list(map(id, IDE.model.fc.parameters())) + list(
    map(id, IDE.classifier.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, IDE.parameters())
# Observe that all parameters are being optimized
IDE_optimizer = torch.optim.SGD([{
    'params': base_params,
    'lr': 0.001
}, {
    'params': IDE.model.fc.parameters(),
Exemplo n.º 17
0
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if train_opt.get("gan_video_weight", 0) > 0:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key][
                    'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D
            if train_opt.get("gan_video_weight", 0) > 0:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'Only MultiStepLR and CosineAnnealingLR_Restart '
                    'learning rate schemes are supported.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.img_path = data['GT_path']
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
        if self.train_opt.get("use_HR_ref"):
            self.var_HR_ref = data['img_reference'].to(self.device)
        if "LQ_next" in data.keys():
            self.var_L_next = data['LQ_next'].to(self.device)
            if "GT_next" in data.keys():
                self.var_H_next = data['GT_next'].to(self.device)
                self.var_video_H = torch.cat(
                    [data['GT'].unsqueeze(2), data['GT_next'].unsqueeze(2)],
                    dim=2).to(self.device)
        else:
            self.var_L_next = None

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        args = [self.var_L]
        if self.train_opt.get('use_HR_ref'):
            args += [self.var_HR_ref]
        if self.var_L_next is not None:
            args += [self.var_L_next]
        self.fake_H, self.binary_mask = self.netG(*args)

        #Video Gan
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            with torch.no_grad():
                args = [self.var_L, self.var_HR_ref, self.var_L_next]
                self.fake_H_next, self.binary_mask_next = self.netG(*args)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_pix_mask:
                l_g_pix_mask = self.l_pix_mask_w * self.cri_pix_mask(
                    self.fake_H, self.var_H, self.var_HR_ref)
                l_g_total += l_g_pix_mask
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            # Image Gan
            if self.opt['network_D'] == "discriminator_vgg_128_mask":
                import torch.nn.functional as F
                from models.modules import psina_seg
                if self.segmentor is None:
                    self.segmentor = psina_seg.base.SegmentationModule(
                        encode='stationary_probs').to(self.device)
                self.segmentor = self.segmentor.eval()
                lr = F.interpolate(self.var_H,
                                   scale_factor=0.25,
                                   mode='nearest')
                with torch.no_grad():
                    binary_mask = (
                        1 - self.segmentor.predict(lr[:, [2, 1, 0], ::]))
                binary_mask = F.interpolate(binary_mask,
                                            scale_factor=4,
                                            mode='nearest')
                pred_g_fake = self.netD(self.fake_H,
                                        self.fake_H * (1 - binary_mask),
                                        self.var_HR_ref,
                                        binary_mask * self.var_HR_ref)
            else:
                pred_g_fake = self.netD(self.fake_H)

            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                if self.opt['network_D'] == "discriminator_vgg_128_mask":
                    pred_d_real = self.netD(self.var_H,
                                            self.var_H * (1 - binary_mask),
                                            self.var_HR_ref,
                                            binary_mask * self.var_HR_ref)
                else:
                    pred_d_real = self.netD(self.var_H)
                pred_d_real = pred_d_real.detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            #Video Gan
            if self.opt['train'].get("gan_video_weight", 0) > 0:
                self.fake_video_H = torch.cat(
                    [self.fake_H.unsqueeze(2),
                     self.fake_H_next.unsqueeze(2)],
                    dim=2)
                pred_g_video_fake = self.net_video_D(self.fake_video_H)
                if self.opt['train']['gan_video_type'] == 'gan':
                    l_g_video_gan = self.l_gan_video_w * self.cri_video_gan(
                        pred_g_video_fake, True)
                elif self.opt['train']['gan_type'] == 'ragan':
                    pred_d_video_real = self.net_video_D(self.var_video_H)
                    pred_d_video_real = pred_d_video_real.detach()
                    l_g_video_gan = self.l_gan_video_w * (self.cri_video_gan(
                        pred_d_video_real - torch.mean(pred_g_video_fake),
                        False) + self.cri_video_gan(
                            pred_g_video_fake - torch.mean(pred_d_video_real),
                            True)) / 2
                l_g_total += l_g_video_gan

            # regularization term on the optical-flow binary mask
            if self.binary_mask is not None:
                l_g_total += 1 * self.binary_mask.mean()

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            for p in self.net_video_D.parameters():
                p.requires_grad = True

        # optimize Image D
        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_H)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.optimizer_D.step()

        # optimize Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.optimizer_video_D.zero_grad()
            l_d_video_total = 0
            pred_d_video_real = self.net_video_D(self.var_video_H)
            pred_d_video_fake = self.net_video_D(
                self.fake_video_H.detach())  # detach to avoid BP to G
            if self.opt['train']['gan_video_type'] == 'gan':
                l_d_video_real = self.cri_video_gan(pred_d_video_real, True)
                l_d_video_fake = self.cri_video_gan(pred_d_video_fake, False)
                l_d_video_total = l_d_video_real + l_d_video_fake
            elif self.opt['train']['gan_video_type'] == 'ragan':
                l_d_video_real = self.cri_video_gan(
                    pred_d_video_real - torch.mean(pred_d_video_fake), True)
                l_d_video_fake = self.cri_video_gan(
                    pred_d_video_fake - torch.mean(pred_d_video_real), False)
                l_d_video_total = (l_d_video_real + l_d_video_fake) / 2
            l_d_video_total.backward()
            self.optimizer_video_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.log_dict['D_video_real'] = torch.mean(
                pred_d_video_real.detach())
            self.log_dict['D_video_fake'] = torch.mean(
                pred_d_video_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            args = [self.var_L]
            if self.train_opt.get('use_HR_ref'):
                args += [self.var_HR_ref]
            if self.var_L_next is not None:
                args += [self.var_L_next]
            self.fake_H, self.binary_mask = self.netG(*args)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.binary_mask is not None:
            out_dict['binary_mask'] = self.binary_mask.detach()[0].float().cpu(
            )
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        # G
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['pretrain_model_G_strict_load'])

        if self.opt['network_G'].get("pretrained_net") is not None:
            self.netG.module.load_pretrained_net_weights(
                self.opt['network_G']['pretrained_net'])

        # D
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['pretrain_model_D_strict_load'])

        # Video D
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            load_path_video_D = self.opt['path'].get("pretrain_model_video_D")
            if self.opt['is_train'] and load_path_video_D is not None:
                self.load_network(
                    load_path_video_D, self.net_video_D,
                    self.opt['path']['pretrain_model_video_D_strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        if self.opt['train'].get("gan_video_weight", 0) > 0:
            self.save_network(self.net_video_D, 'video_D', iter_step)

    @staticmethod
    def _freeze_net(network):
        for p in network.parameters():
            p.requires_grad = False
        return network

    @staticmethod
    def _unfreeze_net(network):
        for p in network.parameters():
            p.requires_grad = True
        return network

    def freeze(self, G, D):
        if G:
            self.netG.module.net = self._freeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._freeze_net(self.netD.module)

    def unfreeze(self, G, D):
        if G:
            self.netG.module.net = self._unfreeze_net(self.netG.module.net)
        if D:
            self.netD.module = self._unfreeze_net(self.netD.module)
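
# A quick sketch of how the two-frame clip for the video discriminator above is
# assembled: each (N, C, H, W) image batch gets a time axis via unsqueeze(2) and the
# two frames are concatenated along it, yielding an (N, C, T=2, H, W) tensor.
import torch

frame_t = torch.randn(2, 3, 64, 64)    # e.g. fake_H at time t
frame_t1 = torch.randn(2, 3, 64, 64)   # e.g. fake_H_next at time t + 1
clip = torch.cat([frame_t.unsqueeze(2), frame_t1.unsqueeze(2)], dim=2)
assert clip.shape == (2, 3, 2, 64, 64)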
Exemplo n.º 18
0
class Model:
    """
    This class implements the basic operations for the model:
    1. Fit the model
    2. Make predictions
    3. Make inference predictions
    4. Save
    5. Load weights
    6. Restore the model
    7. Restore the model with averaged weights
    """
    def __init__(self, hparams, gpu=None, inference=False):

        self.hparams = hparams
        self.gpu = gpu
        self.inference = inference

        self.start_training = time()

        # initialize the model architecture
        self.__setup_model(inference=inference, gpu=gpu)
        self.postprocessing = Post_Processing()

        # define model parameters
        self.__setup_model_hparams()

        # fix random seeds for reproducibility
        self.__seed_everything(42)

    def fit(self, train, valid, pretrain):

        # setup train and val dataloaders
        train_loader = DataLoader(
            train,
            batch_size=self.hparams['batch_size'],
            shuffle=True,
            num_workers=self.hparams['num_workers'],
        )
        valid_loader = DataLoader(
            valid,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.hparams['num_workers'],
        )

        adv_loader = DataLoader(pretrain,
                                batch_size=self.hparams['batch_size'],
                                shuffle=True,
                                num_workers=0)

        # tensorboard
        writer = SummaryWriter(
            f"runs/{self.hparams['model_name']}_{self.start_training}")

        print('Start training the model')
        for epoch in range(self.hparams['n_epochs']):

            # training mode
            self.model.train()
            avg_loss = 0.0
            avg_adv_loss = 0.0

            for X_batch, y_batch, X_batch_adv, y_batch_adv in tqdm(
                    train_loader):

                sample = np.round(np.random.uniform(size=X_batch.shape[0]), 2)
                X_batch_adv_train_val, _, _, _ = next(iter(adv_loader))
                X_batch_adv_train_val = X_batch_adv_train_val[:X_batch.
                                                              shape[0]]
                X_batch_adv[sample >= 0.5] = X_batch_adv_train_val[
                    sample >= 0.5]
                y_batch_adv[sample >= 0.5] = 1
                y_batch_adv[sample < 0.5] = 0

                # push the data into the GPU
                X_batch = X_batch.float().to(self.device)
                y_batch = y_batch.float().to(self.device)
                X_batch_adv = X_batch_adv.float().to(self.device)
                y_batch_adv = y_batch_adv.float().to(self.device)

                # clean gradients from the previous step
                self.optimizer.zero_grad()

                # get model predictions
                pred, pred_adv = self.model(X_batch, X_batch_adv, train=True)

                # process main loss
                pred = pred.reshape(-1)
                y_batch = y_batch.reshape(-1)
                train_loss = self.loss(pred, y_batch)

                # process loss_2
                pred_adv = pred_adv.reshape(-1)
                y_batch_adv = y_batch_adv.reshape(-1)
                adv_loss = self.loss_adv(pred_adv, y_batch_adv)

                # calc loss
                avg_loss += train_loss.item() / len(train_loader)
                avg_adv_loss += adv_loss.item() / len(train_loader)

                train_loss = train_loss + self.hparams['model'][
                    'alpha'] * adv_loss

                # remove data from GPU
                y_batch = y_batch.float().cpu().detach().numpy()
                pred = pred.float().cpu().detach().numpy()
                X_batch = X_batch.float().cpu().detach().numpy()
                X_batch_adv = X_batch_adv.float().cpu().detach().numpy()
                y_batch_adv = y_batch_adv.cpu().detach().numpy()
                pred_adv = pred_adv.cpu().detach().numpy()

                # backprop
                train_loss.backward()

                # gradient clipping (after backward, so the freshly computed gradients are clipped)
                if self.apply_clipping:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                    torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                                    0.5)

                # optimizer step
                self.optimizer.step()

                y_batch = self.postprocessing.run(y_batch)
                pred = self.postprocessing.run(pred)

                # calculate a step for metrics
                self.metric.calc_running_score(labels=y_batch, outputs=pred)

            # calc train metrics
            metric_train = self.metric.compute()

            # evaluate the model
            print('Model evaluation')

            # val mode
            self.model.eval()
            self.optimizer.zero_grad()
            avg_val_loss = 0.0

            with torch.no_grad():

                for X_batch, y_batch, _, _ in tqdm(valid_loader):

                    # push the data into the GPU
                    X_batch = X_batch.float().to(self.device)
                    y_batch = y_batch.float().to(self.device)

                    # get predictions
                    pred = self.model(X_batch)

                    pred = pred.reshape(-1)
                    y_batch = y_batch.reshape(-1)
                    avg_val_loss += self.loss(
                        pred, y_batch).item() / len(valid_loader)

                    # remove data from GPU
                    X_batch = X_batch.float().cpu().detach().numpy()
                    pred = pred.float().cpu().detach().numpy()
                    y_batch = y_batch.float().cpu().detach().numpy()

                    y_batch = self.postprocessing.run(y_batch)
                    pred = self.postprocessing.run(pred)

                    # calculate a step for metrics
                    self.metric.calc_running_score(labels=y_batch,
                                                   outputs=pred)

            # calc val metrics
            metric_val = self.metric.compute()

            # early stopping for scheduler
            if self.hparams['scheduler_name'] == 'ReduceLROnPlateau':
                self.scheduler.step(metric_val)
            else:
                self.scheduler.step()

            es_result = self.early_stopping(score=metric_val,
                                            model=self.model,
                                            threshold=None)

            # print statistics
            if self.hparams['verbose_train']:
                print(
                    '| Epoch: ',
                    epoch + 1,
                    '| Train_loss: ',
                    avg_loss,
                    '| Val_loss: ',
                    avg_val_loss,
                    '| Adv_loss: ',
                    avg_adv_loss,
                    '| Metric_train: ',
                    metric_train,
                    '| Metric_val: ',
                    metric_val,
                    '| Current LR: ',
                    self.__get_lr(self.optimizer),
                )

            # add data to tensorboard
            writer.add_scalars(
                'Loss',
                {
                    'Train_loss': avg_loss,
                    'Val_loss': avg_val_loss
                },
                epoch,
            )
            writer.add_scalars('Metric', {
                'Metric_train': metric_train,
                'Metric_val': metric_val
            }, epoch)

            # early stopping procedure
            if es_result == 2:
                print("Early Stopping")
                print(
                    f'global best val_loss model score {self.early_stopping.best_score}'
                )
                break
            elif es_result == 1:
                print(f'save global val_loss model score {metric_val}')

        writer.close()

        # load the best model trained so far
        self.model = self.early_stopping.load_best_weights()

        return self.start_training

    def predict(self, X_test):
        """
        This function performs:
        1. batch-wise prediction
        2. metric calculation for each sample
        3. metric calculation for the entire dataset

        Parameters
        ----------
        X_test

        Returns
        -------
        fold_score : the metric computed over the entire dataset
        """

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=0,
        )

        self.metric.reset()

        print('Getting predictions')
        with torch.no_grad():
            for i, (X_batch, y_batch, _, _) in enumerate(tqdm(test_loader)):
                X_batch = X_batch.float().to(self.device)
                y_batch = y_batch.float().to(self.device)

                pred = self.model(X_batch)

                pred = pred.reshape(-1)
                y_batch = y_batch.reshape(-1)

                pred = pred.cpu().detach().numpy()
                X_batch = X_batch.cpu().detach().numpy()
                y_batch = y_batch.cpu().detach().numpy()

                y_batch = self.postprocessing.run(y_batch)
                pred = self.postprocessing.run(pred)

                self.metric.calc_running_score(labels=y_batch, outputs=pred)

        fold_score = self.metric.compute()

        return fold_score

    def save(self, model_path):

        print('Saving the model')

        # states (weights + optimizers)
        if self.gpu is not None:
            if len(self.gpu) > 1:
                torch.save(self.model.module.state_dict(), model_path + '.pt')
            else:
                torch.save(self.model.state_dict(), model_path + '.pt')
        else:
            torch.save(self.model.state_dict(), model_path + '.pt')

        # hparams
        with open(f"{model_path}_hparams.yml", 'w') as file:
            yaml.dump(self.hparams, file)

        return True

    def load(self, model_name):
        self.model.load_state_dict(
            torch.load(model_name + '.pt', map_location=self.device))
        self.model.eval()
        return True

    @classmethod
    def restore(cls, model_name: str, gpu: list, inference: bool):

        if gpu is not None:
            assert all([isinstance(i, int)
                        for i in gpu]), "All gpu indexes should be integer"

        # load hparams
        hparams = yaml.load(open(model_name + "_hparams.yml"),
                            Loader=yaml.FullLoader)

        # construct class
        self = cls(hparams, gpu=gpu, inference=inference)

        # load weights + optimizer state
        self.load(model_name=model_name)

        return self

    ################## Utils #####################

    def __get_lr(self, optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def __setup_model(self, inference, gpu):

        # TODO: re-write to pure DDP
        if inference or gpu is None:
            self.device = torch.device('cpu')
            self.model = EfficientNet.from_pretrained(
                self.hparams['model']['pre_trained_model'],
                num_classes=self.hparams['model']['n_classes'])
            self.model.build_adv_model()
            self.model = self.model.to(self.device)
        else:
            if torch.cuda.device_count() > 1:
                if len(gpu) > 1:
                    print("Number of GPUs will be used: ", len(gpu))
                    self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda.
                                               is_available() else "cpu")
                    self.model = EfficientNet.from_pretrained(
                        self.hparams['model']['pre_trained_model'],
                        num_classes=self.hparams['model']['n_classes'],
                    )
                    self.model.build_adv_model()
                    self.model = self.model.to(self.device)
                    self.model = DP(self.model,
                                    device_ids=gpu,
                                    output_device=gpu[0])
                else:
                    print("Only one GPU will be used")
                    self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda.
                                               is_available() else "cpu")
                    self.model = EfficientNet.from_pretrained(
                        self.hparams['model']['pre_trained_model'],
                        num_classes=self.hparams['model']['n_classes'],
                    )
                    self.model.build_adv_model()
                    self.model = self.model.to(self.device)
            else:
                self.device = torch.device(
                    f"cuda:{gpu[0]}" if torch.cuda.is_available() else "cpu")
                self.model = EfficientNet.from_pretrained(
                    self.hparams['model']['pre_trained_model'],
                    num_classes=self.hparams['model']['n_classes'],
                )
                self.model.build_adv_model()
                self.model = self.model.to(self.device)
                print('Only one GPU is available')

        print('Cuda available: ', torch.cuda.is_available())

        return True

    def __setup_model_hparams(self):

        # 1. define losses
        self.loss = nn.MSELoss()
        self.loss_adv = nn.BCELoss()

        # 2. define model metric
        self.metric = Kappa()

        # 3. define optimizer
        self.optimizer = eval(f"torch.optim.{self.hparams['optimizer_name']}")(
            params=self.model.parameters(),
            **self.hparams['optimizer_hparams'])

        # 4. define scheduler
        self.scheduler = eval(
            f"torch.optim.lr_scheduler.{self.hparams['scheduler_name']}")(
                optimizer=self.optimizer, **self.hparams['scheduler_hparams'])

        # 5. define early stopping
        self.early_stopping = EarlyStopping(
            checkpoint_path=self.hparams['checkpoint_path'] +
            f'/checkpoint_{self.start_training}' + '.pt',
            patience=self.hparams['patience'],
            delta=self.hparams['min_delta'],
            is_maximize=True,
        )

        # 6. set gradient clipping
        self.apply_clipping = self.hparams['clipping']  # clipping of gradients

        # 7. Set scaler for optimizer
        self.scaler = torch.cuda.amp.GradScaler()

        return True

    def __seed_everything(self, seed):
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
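
# A hedged sketch of the optimizer/scheduler-by-name pattern used in
# __setup_model_hparams above, written with getattr instead of eval; the hparams
# values here are illustrative, not the example's real configuration.
import torch
import torch.nn as nn

hparams = {
    'optimizer_name': 'Adam',
    'optimizer_hparams': {'lr': 1e-3, 'weight_decay': 1e-5},
    'scheduler_name': 'ReduceLROnPlateau',
    'scheduler_hparams': {'factor': 0.5, 'patience': 3},
}
net = nn.Linear(10, 1)
optimizer = getattr(torch.optim, hparams['optimizer_name'])(
    net.parameters(), **hparams['optimizer_hparams'])
scheduler = getattr(torch.optim.lr_scheduler, hparams['scheduler_name'])(
    optimizer, **hparams['scheduler_hparams'])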
Exemplo n.º 19
0
def run_experiment(config: ExperimentConfig, save_model=True):  # pylint: disable=too-many-statements, too-many-branches, too-many-locals
    # Check Pytorch Version Before Running
    logger.info('Torch Version: %s', torch.__version__)  # type: ignore
    logger.info('Cuda Version: %s', torch.version.cuda)  # type: ignore

    if config.random_seed is not None:
        setup_random_seed(config.random_seed)

    # Initialize Writer
    writer_dir = f"{config.tensorboard_log_root}/{config.cur_time}/"
    writer = SummaryWriter(log_dir=writer_dir)

    # Initialize Device
    if isinstance(config.gpu_device_id, list):
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            map(str, config.gpu_device_id))
    elif config.gpu_device_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_device_id)

    replica = torch.cuda.device_count()
    logger.info('Device Counts: %s', replica)
    wandb.log({"Model batch size": config.batch_size * replica}, 0)
    # device = torch.device(f"cuda:{config.gpu_device_id}")

    # Initialize Dataset and Split into train/valid/test DataSets
    dataset_dict = {
        "output_type": config.output_type,
        "frames_per_clip": config.frames_per_clip,
        "step_between_clips": config.step_between_clips,
        "frame_rate": config.frame_rate,
        "num_sample_per_clip": config.num_sample_per_clip,
    }

    if config.dataset_artifact:
        dataset = MouseClipDataset.from_wandb_artifact(
            config.dataset_artifact,
            split_by=config.split_by,
            mix_clip=config.mix_clip,
            no_valid=config.no_valid,
            extract_groom=config.extract_groom,
            exclude_5min=config.exclude_5min,
            exclude_2_mouse=config.exclude_2_mouse,
            exclude_fpvid=config.exclude_fpvid,
            exclude_2_mouse_valid=config.exclude_2_mouse_valid,
            **dataset_dict)
    elif config.dataset_root in ['./data/breakfast', './data/mpii']:
        dataset = MouseClipDataset.from_annotation_list(
            dataset_root=config.dataset_root, **dataset_dict)
    else:
        metadata_path = (config.metadata_path
                         if config.metadata_path is not None else os.path.join(
                             config.dataset_root, "metadata.pth"))
        dataset = MouseClipDataset.from_ds_folder(
            dataset_root=config.dataset_root,
            metadata_path=metadata_path,
            extract_groom=config.extract_groom,
            **dataset_dict)

    train_set = dataset.get_split("train", config.split_by,
                                  config.transform_size, {})
    valid_set = dataset.get_split("valid", config.split_by,
                                  config.transform_size, config.valid_set_args)
    test_set = dataset.get_split("test", config.split_by,
                                 config.transform_size, config.test_set_args)

    logger.info('Train Transform:\n%s', train_set.transform)
    logger.info('Valid Transform:\n%s', valid_set.transform)
    logger.info('Test Transform:\n%s', test_set.transform)

    dataloaders = {
        "train":
        DataLoader(train_set,
                   config.batch_size * replica,
                   sampler=train_set.get_sampler("train",
                                                 config.train_sampler_config,
                                                 config.samples_per_epoch),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=True),
        "valid":
        DataLoader(valid_set,
                   config.batch_size * replica,
                   sampler=valid_set.get_sampler("valid",
                                                 config.valid_sampler_config),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=False),
        "test":
        DataLoader(test_set,
                   config.batch_size * replica,
                   sampler=test_set.get_sampler("test",
                                                config.test_sampler_config),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=False),
    }

    # initialize model
    if config.model is not None:
        model = config.model(**config.model_args)
        if isinstance(model, ResNet):
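            # replace the classification head so its output size matches the number of dataset labels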
            model.fc = torch.nn.Linear(model.fc.in_features,
                                       len(set(dataset.labels)))

        if config.xavier_init:
            model = init_xavier_weights(model)
        # Have wandb track the model parameters
        wandb.watch(model, "parameters")

        logger.info('Model: %s', model.__class__.__name__)
        # Log total parameters in the model
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.info('Model params: %s', pytorch_total_params)
        pytorch_total_params_trainable = sum(p.numel()
                                             for p in model.parameters()
                                             if p.requires_grad)
        logger.info('Model params trainable: %s',
                    pytorch_total_params_trainable)

        model_structure_str = "Model Structue:\n"
        for name, param in model.named_parameters():
            model_structure_str += f"\t{name}: {param.requires_grad}, {param.numel()}\n"
        # logger.info(model_structure_str)

        model = model.cuda()
        if replica > 1:
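            # DataParallel replicates the model and splits each batch across the visible GPUs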
            model = DataParallel(model)
    else:
        logger.critical("Model not chosen in config!")
        return None

    if isinstance(config.loss_function, torch.nn.Module):
        config.loss_function = config.loss_function.cuda()

    optimizer = config.optimizer(params=model.parameters(),
                                 **config.optimizer_args)
    logger.info("Optimizer: %s\n%s", config.optimizer.__name__,
                config.optimizer_args)

    if config.lr_scheduler is not None:
        lr_scheduler = config.lr_scheduler(optimizer,
                                           **config.lr_scheduler_args)
        logger.info("LR Scheduler: %s\n%s", config.lr_scheduler.__name__,
                    config.lr_scheduler_args)
    else:
        lr_scheduler = None
        logger.info("No LR Scheduler")

    logger.info("Training Started!")
    ckpter = Checkpointer(config.best_metric, save_path=config.save_path)

    training_history, total_steps = train_model(
        model=model,
        optimizer=optimizer,
        dataloaders=dataloaders,
        writer=writer,
        num_epochs=config.num_epochs,
        loss_function=config.loss_function,
        lr_scheduler=lr_scheduler,
        valid_every_epoch=config.valid_every_epoch,
        ckpter=ckpter,
    )
    logger.info("Training Complete!")

    if ckpter is not None:
        ckpter.load_best_model(model)

    logger.info("Testing Started!")
    test_report = evaluate_model(model, dataloaders['test'], "Testing",
                                 total_steps, writer, config.loss_function)
    logger.info("Testing Complete!")

    if save_model:
        train_artifact = wandb.Artifact(f'run_{wandb.run.id}_model', 'model')
        model_tmp_path = os.path.join(wandb.run.dir,
                                      f'best_valid_{ckpter.name}_model.pth')
        torch.save(model.module if isinstance(model, DataParallel) else model,
                   model_tmp_path)
        train_artifact.add_file(model_tmp_path)
        with train_artifact.new_file('split_data.json') as f:
            json.dump(dataset.split_data, f)
        wandb.run.log_artifact(train_artifact)

    return training_history, test_report
Exemplo n.º 20
0
# NOTE: the top of this snippet is missing; the imports and layer definitions
# below are a reconstruction of the standard MNIST-style CNN implied by the
# 320-unit flatten in forward() (the conv layer shapes are assumptions).
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import DataParallel


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = DataParallel(Net())
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss().cuda()

model.train()
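# train_loader is assumed to be a torch.utils.data.DataLoader yielding (data, target) batches; it is not shown in this snippet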
for batch_idx, (data, target) in enumerate(train_loader):
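    # Variable is a legacy (pre-0.4) wrapper; on current PyTorch, data.cuda() alone is sufficient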
    input_var = Variable(data.cuda())
    target_var = Variable(target.cuda())

    print('Getting model output')
    output = model(input_var)
    print('Got model output')

    loss = criterion(output, target_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Exemplo n.º 21
0
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # G Rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']  # load rank-content loss
                self.R_bias = train_opt['R_bias']
                self.netR = networks.define_R(opt).to(self.device)
                if opt['dist']:
                    self.netR = DistributedDataParallel(
                        self.netR, device_ids=[torch.cuda.current_device()])
                else:
                    self.netR = DataParallel(self.netR)
            else:
                logger.info('Remove rank-content loss.')
                self.l_R_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
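        # freeze the discriminator so only the generator receives gradients in this phase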
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
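            # update G only every D_update_ratio steps and only after D_init_iters warm-up iterations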
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.opt['train']['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.l_R_w > 0:  # rank-content loss
                l_g_rank = self.netR(self.fake_H)
                l_g_rank = torch.sigmoid(l_g_rank - self.R_bias)
                l_g_rank = torch.sum(l_g_rank)
                l_g_rank = self.l_R_w * l_g_rank
                l_g_total += l_g_rank

            l_g_total.backward()
            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  # detach to avoid BP to G
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            if self.l_R_w > 0:
                self.log_dict['l_g_rank'] = l_g_rank.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.l_R_w:  # R, Ranker Network
                s, n = self.get_network_description(self.netR)
                if isinstance(self.netR, nn.DataParallel) or isinstance(
                        self.netR, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netR.__class__.__name__,
                        self.netR.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netR.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network Ranker structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
Exemplo n.º 22
0
class MWGANModel(BaseModel):
    def __init__(self, opt):
        super(MWGANModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.train_opt = opt['train']

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        # pretrained_dict = torch.load(opt['path']['pretrain_model_others'])
        # netG_dict = self.netG.state_dict()
        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in netG_dict}
        # netG_dict.update(pretrained_dict)
        # self.netG.load_state_dict(netG_dict)

        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            if not self.train_opt['only_G']:
                self.netD = networks.define_D(opt).to(self.device)
                # init_weights(self.netD)
                if opt['dist']:
                    self.netD = DistributedDataParallel(
                        self.netD, device_ids=[torch.cuda.current_device()])
                else:
                    self.netD = DataParallel(self.netD)

                self.netG.train()
                self.netD.train()
            else:
                self.netG.train()
        else:
            self.netG.train()

        # define losses, optimizer and scheduler
        if self.is_train:

            # G pixel loss
            if self.train_opt['pixel_weight'] > 0:
                l_pix_type = self.train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'cb':
                    self.cri_pix = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = self.train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if self.train_opt['lpips_weight'] > 0:
                l_lpips_type = self.train_opt['lpips_criterion']
                if l_lpips_type == 'lpips':
                    self.cri_lpips = lpips.LPIPS(net='vgg').to(self.device)
                    if opt['dist']:
                        self.cri_lpips = DistributedDataParallel(
                            self.cri_lpips,
                            device_ids=[torch.cuda.current_device()])
                    else:
                        self.cri_lpips = DataParallel(self.cri_lpips)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(
                            l_lpips_type))
                self.l_lpips_w = self.train_opt['lpips_weight']
            else:
                logger.info('Remove lpips loss.')
                self.cri_lpips = None

            # G feature loss
            if self.train_opt['feature_weight'] > 0:
                self.fea_trans = GramMatrix().to(self.device)
                l_fea_type = self.train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                elif l_fea_type == 'cb':
                    self.cri_fea = CharbonnierLoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = self.train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(self.train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = self.train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = self.train_opt[
                'D_update_ratio'] if self.train_opt['D_update_ratio'] else 1
            self.D_init_iters = self.train_opt[
                'D_init_iters'] if self.train_opt['D_init_iters'] else 0

            # optimizers
            # G
            wd_G = self.train_opt['weight_decay_G'] if self.train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=self.train_opt['lr_G'],
                weight_decay=wd_G,
                betas=(self.train_opt['beta1_G'], self.train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)

            if not self.train_opt['only_G']:
                # D
                wd_D = self.train_opt['weight_decay_D'] if self.train_opt[
                    'weight_decay_D'] else 0
                self.optimizer_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=self.train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(self.train_opt['beta1_D'],
                           self.train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if self.train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            self.train_opt['lr_steps'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights'],
                            gamma=self.train_opt['lr_gamma'],
                            clear_state=self.train_opt['clear_state']))
            elif self.train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            self.train_opt['T_period'],
                            eta_min=self.train_opt['eta_min'],
                            restarts=self.train_opt['restarts'],
                            weights=self.train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        if self.is_train:
            if not self.train_opt['only_G']:
                self.print_network()  # print network
        else:
            self.print_network()  # print network

        try:
            self.load()  # load G and D if needed
            print('Pretrained model loaded')
        except Exception as e:
            print('No pretrained model found')

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            # print(self.var_H.size())
            self.var_H = self.var_H.squeeze(1)
            # self.var_H = self.DWT(self.var_H)

            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)
            # print(self.var_ref.size())
            self.var_ref = self.var_ref.squeeze(1)
            # print(s)
            # self.var_ref = self.DWT(self.var_ref)

    def process_list(self, input1, input2):
        result = []
        for index in range(len(input1)):
            result.append(input1[index] - torch.mean(input2[index]))
        return result

    def optimize_parameters(self, step):
        # G
        if not self.train_opt['only_G']:
            for p in self.netD.parameters():
                p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.fake_H = self.netG(self.var_L)

        # self.var_H = self.var_H.squeeze(1)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix

            if self.cri_lpips:  # LPIPS perceptual loss
                l_g_lpips = torch.mean(
                    self.l_lpips_w *
                    self.cri_lpips.forward(self.fake_H, self.var_H))
                l_g_total += l_g_lpips

            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                real_fea_trans = self.fea_trans(real_fea)
                fake_fea_trans = self.fea_trans(fake_fea)
                l_g_fea_trans = self.l_fea_w * self.cri_fea(
                    fake_fea_trans, real_fea_trans) * 10
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea
                l_g_total += l_g_fea_trans

            if not self.train_opt['only_G']:
                pred_g_fake = self.netD(self.fake_H)

                if self.opt['train']['gan_type'] == 'gan':
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                elif self.opt['train']['gan_type'] == 'ragan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan_ra':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    pred_d_real = self.netD(self.var_ref)
                    pred_d_real = [ele.detach() for ele in pred_d_real]
                    # l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                    l_g_gan = self.l_gan_w * (self.cri_gan(
                        self.process_list(pred_d_real, pred_g_fake), False
                    ) + self.cri_gan(
                        self.process_list(pred_g_fake, pred_d_real), True)) / 2
                elif self.opt['train']['gan_type'] == 'lsgan':
                    # self.var_ref = self.var_ref[:,1:,:,:]
                    l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()
        else:
            pass  # no generator update on this step

        if not self.train_opt['only_G']:
            # D
            for p in self.netD.parameters():
                p.requires_grad = True

            self.optimizer_D.zero_grad()
            l_d_total = 0
            pred_d_real = self.netD(self.var_ref)
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G

            if self.opt['train']['gan_type'] == 'gan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += l_d_real + l_d_fake
            elif self.opt['train']['gan_type'] == 'ragan':
                l_d_real = self.cri_gan(
                    self.process_list(pred_d_real, pred_d_fake), True)
                l_d_fake = self.cri_gan(
                    self.process_list(pred_d_fake, pred_d_real), False)
                l_d_total += (l_d_real + l_d_fake) / 2
            elif self.opt['train']['gan_type'] == 'lsgan':
                l_d_real = self.cri_gan(pred_d_real, True)
                l_d_fake = self.cri_gan(pred_d_fake, False)
                l_d_total += (l_d_real + l_d_fake) / 2

            l_d_total.backward()
            self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item() / self.l_pix_w
            if self.cri_lpips:
                self.log_dict['l_g_lpips'] = l_g_lpips.item() / self.l_lpips_w
            if not self.train_opt['only_G']:
                self.log_dict['l_g_gan'] = l_g_gan.item() / self.l_gan_w
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item() / self.l_fea_w
                self.log_dict['l_g_fea_trans'] = l_g_fea_trans.item(
                ) / self.l_fea_w / 10

        if not self.train_opt['only_G']:
            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real[0].detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake[0].detach())

    def test(self, load_path=None, input_u=None, input_v=None):

        if load_path is not None:
            self.load_network(load_path, self.netG,
                              self.opt['path']['strict_load'])
            print(
                '***************************************************************'
            )
            print('Load model successfully')
            print(
                '***************************************************************'
            )

        self.netG.eval()
        # self.var_H = self.var_H.squeeze(1)
        # img_to_write = self.var_L.detach()[0].float().cpu()
        # print(img_to_write.size())
        # cv2.imwrite('./test.png',img_to_write.numpy().transpose(1,2,0)*255)
        with torch.no_grad():
            if self.var_L.size()[-1] > 1280:
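                # wide frames are split into four half-size tiles to limit GPU memory,
                # enhanced separately, then stitched back together below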
                width = self.var_L.size()[-1]
                height = self.var_L.size()[-2]
                fake_list = []
                for height_start in [0, int(height / 2)]:
                    for width_start in [0, int(width / 2)]:
                        self.fake_slice = self.netG(
                            self.var_L[:, :, :, height_start:(height_start +
                                                              int(height / 2)),
                                       width_start:(width_start +
                                                    int(width / 2))])
                        fake_list.append(self.fake_slice)
                enhanced_frame_h1 = torch.cat([fake_list[0], fake_list[2]], 2)
                enhanced_frame_h2 = torch.cat([fake_list[1], fake_list[3]], 2)
                self.fake_H = torch.cat([enhanced_frame_h1, enhanced_frame_h2],
                                        3)
            else:
                self.fake_H = self.netG(self.var_L)
            if input_u is not None and input_v is not None:
                self.var_L_u = input_u.to(self.device)
                self.var_L_v = input_v.to(self.device)
                self.fake_H_u_s = self.netG(self.var_L_u.float())
                self.fake_H_v_s = self.netG(self.var_L_v.float())
                # self.fake_H_u = torch.cat((self.fake_H_u_s[0], self.fake_H_u_s[1]), 1)
                # self.fake_H_v = torch.cat((self.fake_H_v_s[0], self.fake_H_v_s[1]), 1)
                self.fake_H_u = self.fake_H_u_s
                self.fake_H_v = self.fake_H_v_s
                # self.fake_H_u = self.IWT(self.fake_H_u)
                # self.fake_H_v = self.IWT(self.fake_H_v)
            else:
                self.fake_H_u = None
                self.fake_H_v = None
            self.fake_H_all = self.fake_H
            if self.opt['network_G']['out_nc'] == 4:
                self.fake_H_all = self.IWT(self.fake_H_all)
                if input_u is not None and input_v is not None:
                    self.fake_H_u = self.IWT(self.fake_H_u)
                    self.fake_H_v = self.IWT(self.fake_H_v)
        # self.fake_H = self.var_L[:,2,:,:,:]
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0][2].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if self.fake_H_u is not None:
            out_dict['SR_U'] = self.fake_H_u.detach()[0].float().cpu()
            out_dict['SR_V'] = self.fake_H_v.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
            print('G loaded')
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
            print('D loaded')

    def save(self, iter_step):
        if not self.train_opt['only_G']:
            self.save_network(self.netG, 'G', iter_step)
            self.save_network(self.netD, 'D', iter_step)
        else:
            self.save_network(self.netG,
                              self.opt['network_G']['which_model_G'],
                              iter_step, self.opt['path']['pretrain_model_G'])
Exemplo n.º 23
0
class ESRGAN_EESN_Model(BaseModel):
    def __init__(self, config, device):
        super(ESRGAN_EESN_Model, self).__init__(config, device)
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.device = device
        #Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        self.netG = DataParallel(self.netG, device_ids=[1, 0])

        #discriminator
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD, device_ids=[1, 0])

        self.netG.train()
        self.netD.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature loss
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF, device_ids=[1, 0])
            self.netF.eval()

        # GD gan loss
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    '''
    The upstream repo does not use a collate_fn, reads images with different
    flags, and applies np.ascontiguousarray(); revisit this code if data
    loading problems appear.
    '''

    def feed_data(self, data):
        self.var_L = data['image_lq'].to(self.device)
        self.var_H = data['image'].to(self.device)
        input_ref = data['ref'] if 'ref' in data else data['image']
        self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        #Generator
        for p in self.netD.parameters():
            p.requires_grad = False
        self.optimizer_G.zero_grad()
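        # netG returns the intermediate SR, the final SR and two Laplacian outputs (the latter are unused here; see test() below)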
        self.fake_H, self.final_SR, _, _ = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  #pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                # detach target features so no gradient flows back through them
                real_fea = self.netF(self.var_H).detach()
                # note: netF is configured with normalize=False
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            pred_g_fake = self.netD(self.fake_H)
            if self.configT['gan_type'] == 'gan':
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.configT['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
            #EESN calculate loss
            if self.cri_charbonnier:  # charbonnier pixel loss HR and SR
                l_e_charbonnier = 5 * self.cri_charbonnier(
                    self.final_SR,
                    self.var_H)  # weight of 5 chosen empirically
            l_g_total += l_e_charbonnier

            l_g_total.backward()
            self.optimizer_G.step()

        #discriminator
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.var_ref)
        pred_d_fake = self.netD(
            self.fake_H.detach())  #to avoid BP to Generator
        if self.configT['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.configT['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real +
                         l_d_fake) / 2  # thinking of adding final sr d loss

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()
            self.log_dict['l_e_charbonnier'] = l_e_charbonnier.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, self.final_SR, self.x_learned_lap_fake, self.x_lap = self.netG(
                self.var_L)
            _, _, _, self.x_lap_HR = self.netG(self.var_H)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        #out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['lap_learned'] = self.x_learned_lap_fake.detach()[0].float(
        ).cpu()
        out_dict['lap'] = self.x_lap.detach()[0].float().cpu()
        out_dict['lap_HR'] = self.x_lap_HR.detach()[0].float().cpu()
        out_dict['final_SR'] = self.final_SR.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        # Discriminator
        s, n = self.get_network_description(self.netD)
        if isinstance(self.netD, nn.DataParallel) or isinstance(
                self.netD, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netD.__class__.__name__,
                self.netD.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netD.__class__.__name__)

        logger.info('Network D structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)

        if self.cri_fea:  # F, Perceptual Network
            s, n = self.get_network_description(self.netF)
            if isinstance(self.netF, nn.DataParallel) or isinstance(
                    self.netF, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netF.__class__.__name__,
                    self.netF.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netF.__class__.__name__)

            logger.info(
                'Network F structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.config['path']['pretrain_model_G']
        if load_path_G:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.config['path']['strict_load'])
        load_path_D = self.config['path']['pretrain_model_D']
        if load_path_D:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.config['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
Exemplo n.º 24
0
def main(args):
    print("===> Loading datasets")
    data_set = DatasetLoader(args.data_lr,
                             args.data_hr,
                             size_w=args.size_w,
                             size_h=args.size_h,
                             scale=args.scale,
                             n_frames=args.n_frames,
                             interval_list=args.interval_list,
                             border_mode=args.border_mode,
                             random_reverse=args.random_reverse)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=False,
                              drop_last=True)

    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
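    # benchmark mode lets cuDNN autotune kernels for fixed input sizes; disable it for full determinism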
    #cudnn.deterministic = True

    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    criterion = CharbonnierLoss()
    print("===> Setting GPU")
    gpus = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gpus))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    # print(model)

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the latest checkpoint in the directory
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']

            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k  # add the 'module.' prefix expected by DataParallel
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # if the checkpoint stores an lr, use it instead of the command-line value
            args.lr = checkpoint.get('lr', args.lr)

        # if start_epoch was set explicitly, it overrides the epoch stored in the checkpoint
        start_epoch = args.start_epoch if args.start_epoch != 0 else start_epoch

    # if use_current_lr is greater than 0, use it as the lr instead
    args.lr = args.use_current_lr if args.use_current_lr > 0 else args.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 betas=(args.beta1, args.beta2),
                                 eps=1e-8)

    #### training
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        if args.use_tqdm == 1:
            losses, psnrs = one_epoch_train_tqdm(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        else:
            losses, psnrs = one_epoch_train_logger(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])

        # save model
        # if epoch %9 != 0:
        #     continue

        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edvr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
class IRNpModel(BaseModel):
    def __init__(self, opt):
        super(IRNpModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                losstype=self.train_opt['pixel_criterion_back'])

            # feature loss
            if train_opt['feature_weight'] > 0:
                self.Reconstructionf = ReconstructionLoss(
                    losstype=self.train_opt['feature_criterion'])

                self.l_fea_w = train_opt['feature_weight']
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)
            else:
                self.l_fea_w = 0

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)  # LQ
        self.real_H = data['GT'].to(self.device)  # GT

    def gaussian_batch(self, dims):
        return torch.randn(tuple(dims)).to(self.device)

    def loss_forward(self, out, y):
        l_forw_fit = self.train_opt[
            'lambda_fit_forw'] * self.Reconstruction_forw(out[:, :3, :, :], y)

        return l_forw_fit

    def loss_backward(self, x, x_samples):
        x_samples_image = x_samples[:, :3, :, :]
        l_back_rec = self.train_opt[
            'lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)

        # feature loss
        if self.l_fea_w > 0:
            l_back_fea = self.feature_loss(x, x_samples_image)
        else:
            l_back_fea = torch.tensor(0)

        # GAN loss
        pred_g_fake = self.netD(x_samples_image)
        if self.opt['train']['gan_type'] == 'gan':
            l_back_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_real = self.netD(x).detach()
            l_back_gan = self.l_gan_w * (
                self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
                self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

        return l_back_rec, l_back_fea, l_back_gan

    def feature_loss(self, real, fake):
        real_fea = self.netF(real).detach()
        fake_fea = self.netF(fake)
        l_g_fea = self.l_fea_w * self.Reconstructionf(real_fea, fake_fea)

        return l_g_fea
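
    # Overall generator objective (a summary of the loss_* helpers above, not extra
    # repo code):
    #   loss = lambda_fit_forw * L_forw(LR_pred, LR_ref)
    #        + lambda_rec_back * L_back(HR_rec, HR_gt)
    #        + feature_weight  * L_fea(netF(HR_rec), netF(HR_gt))
    #        + gan_weight      * L_GAN(netD(HR_rec))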

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()

        self.input = self.real_H
        print('input shape: ', self.input.shape)
        self.output = self.netG(x=self.input)
        print('output shape: ', self.output.shape)

        loss = 0
        zshape = self.output[:, 3:, :, :].shape
        print('z shape: ', zshape)

        LR = self.Quantization(self.output[:, :3, :, :])

        gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt[
            'gaussian_scale'] is not None else 1
        y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)
        print('y_ shape: ', y_.shape)

        self.fake_H = self.netG(x=y_, rev=True)
        print('fake_H shape: ', self.fake_H.shape)

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            l_forw_fit = self.loss_forward(self.output, self.ref_L)
            l_back_rec, l_back_fea, l_back_gan = self.loss_backward(
                self.real_H, self.fake_H)

            loss += l_forw_fit + l_back_rec + l_back_fea + l_back_gan

            loss.backward()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(),
                                         self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        l_d_total = 0
        pred_d_real = self.netD(self.real_H)
        pred_d_fake = self.netD(self.fake_H.detach())
        if self.opt['train']['gan_type'] == 'gan':
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_total = l_d_real + l_d_fake
        elif self.opt['train']['gan_type'] == 'ragan':
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True)
            l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real),
                                    False)
            l_d_total = (l_d_real + l_d_fake) / 2

        l_d_total.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_back_fea'] = l_back_fea.item()
            self.log_dict['l_back_gan'] = l_back_gan.item()
        self.log_dict['l_d'] = l_d_total.item()

    def test(self):
        Lshape = self.ref_L.shape

        input_dim = Lshape[1]
        self.input = self.real_H

        print('test mode==>input shape: ', self.input.shape)
        zshape = [
            Lshape[0], input_dim * (self.opt['scale']**2) - Lshape[1],
            Lshape[2], Lshape[3]
        ]
        print('test mode==>zshape: ', zshape)

        gaussian_scale = 1
        if self.test_opt and self.test_opt['gaussian_scale'] is not None:
            gaussian_scale = self.test_opt['gaussian_scale']

        self.netG.eval()
        with torch.no_grad():
            self.forw_L = self.netG(x=self.input)[:, :3, :, :]
            self.forw_L = self.Quantization(self.forw_L)
            print('test mode==>forw_L shape: ', self.forw_L.shape)
            y_forw = torch.cat(
                (self.forw_L, gaussian_scale * self.gaussian_batch(zshape)),
                dim=1)
            print('test mode==>y_forw shape: ', y_forw.shape)
            self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
            print('test mode==>fake_H shape: ', self.fake_H.shape)

        self.netG.train()

    def downscale(self, HR_img):
        self.netG.eval()
        with torch.no_grad():
            LR_img = self.netG(x=HR_img)[:, :3, :, :]
            LR_img = self.Quantization(LR_img)
        self.netG.train()

        return LR_img

    def upscale(self, LR_img, scale, gaussian_scale=1):
        Lshape = LR_img.shape
        zshape = [Lshape[0], Lshape[1] * (scale**2 - 1), Lshape[2], Lshape[3]]
        y_ = torch.cat((LR_img, gaussian_scale * self.gaussian_batch(zshape)),
                       dim=1)

        self.netG.eval()
        with torch.no_grad():
            HR_img = self.netG(x=y_, rev=True)[:, :3, :, :]
        self.netG.train()

        return HR_img
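
    # Channel bookkeeping for the invertible rescaling network (a worked example,
    # not extra repo code): for a 3-channel image and scale s, the forward pass
    # yields 3 * s**2 channels, of which 3 are the quantized LR image and the
    # remaining 3 * (s**2 - 1) form the latent z resampled from N(0, 1) on the
    # reverse pass. With scale=4 that is 48 channels in total and 45 latent
    # channels, matching the zshape computations in test() and upscale().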

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['LR_ref'] = self.ref_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_label):
        self.save_network(self.netG, 'G', iter_label)
        self.save_network(self.netD, 'D', iter_label)
Exemplo n.º 26
0
class CLSGAN_Model(BaseModel):
    def __init__(self, opt):
        super(CLSGAN_Model, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        G_opt = opt['network_G']

        # define networks and load pretrained models
        self.netG = RCAN(G_opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = Discriminator_VGG_256(3, G_opt['nf']).to(self.device)
            self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = VGGFeatureExtractor(feature_layer=34,
                                                use_bn=False,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                self.netF = DataParallel(self.netF)

            # G classification loss
            if train_opt['cls_weight'] > 0:
                l_cls_type = train_opt['cls_criterion']
                if l_cls_type == 'CE':
                    self.cri_cls = nn.NLLLoss().to(self.device)
                elif l_cls_type == 'l1':
                    self.cri_cls = nn.L1Loss().to(self.device)
                elif l_cls_type == 'l2':
                    self.cri_cls = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_cls_type))
                self.l_cls_w = train_opt['cls_weight']
            else:
                logger.info('Remove classification loss.')
                self.cri_cls = None
            if self.cri_cls:  # load the VGG-based classification network
                self.netC = VGGFeatureExtractor(feature_layer=49,
                                                use_bn=True,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classification prior."
                self.netC.load_model(load_path_C)
                self.netC = DataParallel(self.netC)

            if train_opt['brc_weight'] > 0:
                self.l_brc_w = train_opt['brc_weight']
                self.netR = VGG_Classifier().to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classification prior."
                self.netR.load_model(load_path_C)
                self.netR = DataParallel(self.netR)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H, self.cls_L = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_cls:  # F-G classification loss
                #print(self.netC(self.var_H).detach().shape)
                #real_cls = self.netC(self.var_H).argmax(1).detach()
                #fake_cls = torch.log( nn.Softmax(dim=1) (self.netC(self.fake_H)) )
                real_cls = self.netC(self.var_H).detach()
                fake_cls = self.netC(self.fake_H)
                l_g_cls = self.l_cls_w * self.cri_cls(fake_cls, real_cls)
                l_g_total += l_g_cls
            if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            if self.opt['train']['br_optimizer'] == 'joint':
                ref = self.netR(self.var_H).argmax(dim=1)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L,
                                                                ref)
                l_g_total += l_branch

            # retain the graph when the branch classifier gets its own update below
            retain_branch = self.opt['train']['br_optimizer'] == 'branch'
            l_g_total.backward(retain_graph=retain_branch)
            self.optimizer_G.step()

            self.optimizer_G.zero_grad()

            # separate branching update
            if self.opt['train']['br_optimizer'] == 'branch':
                ref = self.netR(self.var_H).argmax(dim=1)
                l_branch = self.l_brc_w * nn.CrossEntropyLoss()(self.cls_L,
                                                                ref)
                l_branch.backward()
                self.optimizer_G.step()

        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
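        # Discriminator objectives below (summary only, not extra repo code):
        # 'gan'   -> BCE(D(x_real), 1) + BCE(D(x_fake), 0), real/fake backpropagated separately;
        # 'ragan' -> 0.5 * BCE(D(x_real) - mean D(x_fake), 1)
        #          + 0.5 * BCE(D(x_fake) - mean D(x_real), 0)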
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            # pred_d_real = self.netD(self.var_ref)
            # pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
            # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
            # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
            # l_d_total = (l_d_real + l_d_fake) / 2
            # l_d_total.backward()
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H, _ = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

            if self.cri_cls:  # C, F-G Classification Network
                s, n = self.get_network_description(self.netC)
                if isinstance(self.netC, nn.DataParallel) or isinstance(
                        self.netC, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netC.__class__.__name__,
                        self.netC.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netC.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network C structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)

    def clear_data(self):
        return None
def main():

    global args, best_prec1
    args = parser.parse_args()

    # Read list of training and validation data
    listfiles_train, labels_train = read_lists(TRAIN_OUT)
    listfiles_val, labels_val = read_lists(VAL_OUT)
    listfiles_test, labels_test = read_lists(TEST_OUT)
    dataset_train = Dataset(listfiles_train,
                            labels_train,
                            subtract_mean=False,
                            V=12)
    dataset_val = Dataset(listfiles_val, labels_val, subtract_mean=False, V=12)
    dataset_test = Dataset(listfiles_test,
                           labels_test,
                           subtract_mean=False,
                           V=12)

    # shuffle data
    dataset_train.shuffle()
    dataset_val.shuffle()
    dataset_test.shuffle()
    tra_data_size, val_data_size, test_data_size = dataset_train.size(
    ), dataset_val.size(), dataset_test.size()
    print('training size:', tra_data_size)
    print('validation size:', val_data_size)
    print('testing size:', test_data_size)

    batch_size = args.b
    print("batch_size is :" + str(batch_size))
    learning_rate = args.lr
    print("learning_rate is :" + str(learning_rate))
    num_cuda = cuda.device_count()
    print("number of GPUs have been detected:" + str(num_cuda))

    # create model
    print("model building...")
    mvcnn = DataParallel(modelnet40_Alex(num_cuda, batch_size))
    #mvcnn = modelnet40(num_cuda, batch_size, multi_gpu = False)
    mvcnn.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint'{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            mvcnn.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    #print(mvcnn)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adadelta(mvcnn.parameters(), weight_decay=1e-4)
    # evaluate performance only
    if args.evaluate:
        print('testing mode ------------------')
        validate(dataset_test, mvcnn, criterion, optimizer, batch_size)
        return

    print('training mode ------------------')
    for epoch in range(args.start_epoch, args.epochs):
        print('epoch:', epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(dataset_train, mvcnn, criterion, optimizer, epoch, batch_size)

        # evaluate on validation set
        prec1 = validate(dataset_val, mvcnn, criterion, optimizer, batch_size)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
        elif epoch % 5 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
Exemplo n.º 28
0
class SRGANModel(BaseModel):
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)

        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

        self.netD = networks.define_D(opt).to(self.device)
        self.netD = DataParallel(self.netD)
        if self.is_train:
            self.netG.train()
            self.netD.train()

        if not self.is_train and 'attack' in self.opt:
            # G pixel loss
            if opt['pixel_weight'] > 0:
                l_pix_type = opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if opt['feature_weight'] > 0:
                l_fea_type = opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = opt['gan_weight']

        self.delta = 0

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                self.netF = DataParallel(self.netF)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def attack_fgsm(self, is_collect_data=False):
        # collect_data='collect_data' in self.opt['attack'] and self.opt['attack']['collect_data']

        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = False
        self.var_L.requires_grad_()

        self.fake_H = self.netG(self.var_L)

        # l_g_total, l_g_pix, l_g_fea, l_g_gan=self.loss_for_G(self.fake_H,self.var_H,self.var_ref)
        l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)

        # zero_grad
        if self.var_L.grad is not None:
            self.var_L.grad.zero_()
        # self.netG.zero_grad()

        # l_g_total.backward()
        l_g_pix.backward()

        data_grad = self.var_L.grad.data

        sign_data_grad = data_grad.sign()
        perturbed_data = self.var_L + self.opt['attack']['eps'] * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)

        if is_collect_data:
            init_data = self.var_L.detach()
            self.var_L = perturbed_data.detach()
            perturbed_data = self.var_L.clone().detach()
            return init_data, perturbed_data
        else:
            self.var_L = perturbed_data.detach()
            return
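
    # In short (a summary of attack_fgsm above, not extra repo code): the adversarial
    # LR input is x_adv = clamp(x + eps * sign(grad_x L_pix(G(x), y)), 0, 1), i.e. a
    # single gradient-sign step of size opt['attack']['eps'] on the pixel loss.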

    # TODO test
    def attack_pgd(self, is_collect_data=False):
        eps = self.opt['attack']['eps']

        for p in self.netG.parameters():
            p.requires_grad = False
        orig_input = self.var_L.clone().detach()

        randn = torch.FloatTensor(self.var_L.size()).uniform_(-eps, eps).cuda()
        self.var_L += randn
        self.var_L.clamp_(0, 1.0)

        # self.var_L.requires_grad_()
        # if self.var_L.grad is not None:
        #     self.var_L.grad.zero_()
        self.var_L.detach_()

        for _ in range(self.opt['attack']['step_num']):
            # if self.var_L.grad is not None:
            #     self.var_L.grad.zero_()
            var_L_step = torch.autograd.Variable(self.var_L,
                                                 requires_grad=True)
            self.fake_H = self.netG(var_L_step)
            l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
            l_pix.backward()
            data_grad = var_L_step.grad.data

            pert = self.opt['attack']['step'] * data_grad.sign()
            self.var_L = self.var_L + pert.data
            self.var_L = torch.max(orig_input - eps, self.var_L)
            self.var_L = torch.min(orig_input + eps, self.var_L)
            self.var_L.clamp_(0, 1.0)

        if is_collect_data:
            return orig_input, self.var_L.clone().detach()
        else:
            self.var_L.detach_()
            return
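
    # In short (a summary of attack_pgd above, not extra repo code): PGD is iterated
    # FGSM; each step adds opt['attack']['step'] * sign(grad) to the input, and the
    # element-wise max/min against orig_input -/+ eps projects the perturbation back
    # into the L-infinity ball of radius eps before clamping to [0, 1].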

    def feed_data(self, data, need_GT=True, is_collect_data=False):
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

        # TODO attack code start
        if 'attack' in self.opt and need_GT and not (
                'raw_data' in self.opt['attack']
                and self.opt['attack']['raw_data'] == True):
            if 'type' in self.opt['attack'] and self.opt['attack'][
                    'type'] == 'pgd':
                if not is_collect_data:
                    self.attack_pgd()
                else:
                    return self.attack_pgd(is_collect_data=True)
            else:
                if not is_collect_data:
                    self.attack_fgsm()
                else:
                    return self.attack_fgsm(is_collect_data=True)
        # attack code end

    def loss_for_G(self, fake_H, var_H, var_ref):
        l_g_total = 0
        l_g_pix = torch.tensor(0.0)
        l_g_fea = torch.tensor(0.0)
        if self.cri_pix:  # pixel loss
            l_g_pix = self.l_pix_w * self.cri_pix(fake_H, var_H)
            l_g_total += l_g_pix
        if self.cri_fea:  # feature loss
            real_fea = self.netF(var_H).detach()
            fake_fea = self.netF(fake_H)
            l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
            l_g_total += l_g_fea
        if self.l_gan_w > 0.0:
            if ('train' in self.opt and self.opt['train']['gan_type']
                    == 'gan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'gan'):
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif ('train' in self.opt and self.opt['train']['gan_type']
                  == 'ragan') or ('attack' in self.opt
                                  and self.opt['gan_type'] == 'ragan'):
                pred_d_real = self.netD(var_ref).detach()
                pred_g_fake = self.netD(fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan
        else:
            l_g_gan = torch.tensor(0.0)
        return l_g_total, l_g_pix, l_g_fea, l_g_gan

    def optimize_parameters(self, step):
        # G
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = True
        if 'adv_train' in self.opt:
            self.var_L.requires_grad_()
            if self.var_L.grad is not None:
                self.var_L.grad.data.zero_()

        if 'adv_train' not in self.opt:
            self.fake_H = self.netG(self.var_L)
        else:
            self.fake_H = self.netG(torch.clamp(self.var_L + self.delta, 0, 1))

        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if 'adv_train' not in self.opt:
                l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                    self.fake_H, self.var_H, self.var_ref)

                self.optimizer_G.zero_grad()

                l_g_total.backward()
                self.optimizer_G.step()
            else:
                for _ in range(self.opt['adv_train']['m']):
                    l_g_total, l_g_pix, l_g_fea, l_g_gan = self.loss_for_G(
                        self.fake_H, self.var_H, self.var_ref)

                    self.optimizer_G.zero_grad()
                    if self.var_L.grad is not None:
                        self.var_L.grad.data.zero_()

                    l_g_total.backward()
                    self.optimizer_G.step()

                    self.delta = self.delta + \
                        self.opt['adv_train']['step'] * \
                        self.var_L.grad.data.sign()
                    self.delta.clamp_(-self.opt['attack']['eps'],
                                      self.opt['attack']['eps'])
                    self.fake_H = self.netG(
                        torch.clamp(self.var_L + self.delta, 0, 1))
        # D
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            # detach to avoid BP to G
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            self.log_dict['l_g_total'] = l_g_total.item()
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

            self.log_dict['l_d_real'] = l_d_real.item()
            self.log_dict['l_d_fake'] = l_d_fake.item()
            self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
            self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(
            net_struc_str, n))
        logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            logger.info(
                'Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                logger.info(
                    'Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])
        load_path_F = self.opt['path']['pretrain_model_F']
        if load_path_F is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            network = self.netF.module.features
            if isinstance(network, nn.DataParallel):
                network = network.module
            load_net = torch.load(load_path_F)
            load_net_clean = OrderedDict()  # remove unnecessary 'module.'
            for k, v in load_net.items():
                if k.startswith('module.features.'):
                    load_net_clean[k[16:]] = v
            network.load_state_dict(load_net_clean,
                                    strict=self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
class GenerativeModel(BaseModel):
    def __init__(self, opt):
        super(GenerativeModel, self).__init__(opt)

        # DISTRIBUTED TRAINING OR NOT
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1

        # DEFINE NETWORKS
        self.netE = networks.define_encoder(opt).to(self.device)
        self.netD = networks.define_decoder(opt).to(self.device)
        self.netF, self.nz, self.stop_gradients = networks.define_flow(opt)
        self.netF.to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(self.netE, device_ids=[torch.cuda.current_device()])
            self.netD = DistributedDataParallel(self.netD, device_ids=[torch.cuda.current_device()])
            self.netF = DistributedDataParallel(self.netF, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
            self.netD = DataParallel(self.netD)
            self.netF = DataParallel(self.netF)

        if self.is_train:
            self.netE.train()
            self.netD.train()
            self.netF.train()

        # GET CONFIG PARAMS FOR LOSSES AND LR
        train_opt = opt['train']

        # DEFINE LOSSES, OPTIMIZER AND SCHEDULE
        if self.is_train:
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss(reduction='mean').to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss(reduction='mean').to(self.device)
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']

                if train_opt['add_background_mask']:
                    self.add_mask = True
                else:
                    self.add_mask = False

            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            if train_opt['nll_weight'] is None:
                raise ValueError('NLL loss must always be present in this version')
            self.cri_nll = NLLLoss(reduction='mean').to(self.device)
            self.l_nll_w = train_opt['nll_weight']

            if train_opt['feature_weight'] > 0:
                self.cri_fea = VGGLoss().to(self.device)
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None

            # optimizers
            if train_opt['lr_E'] > 0:
                self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                                    lr=train_opt['lr_E'],
                                                    weight_decay=train_opt['weight_decay_E'] if train_opt[
                                                        'weight_decay_E'] else 0,
                                                    betas=(train_opt['beta1_E'], train_opt['beta2_E']))
                self.optimizers.append(self.optimizer_E)
            else:
                for p in self.netE.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze encoder.')

            if train_opt['lr_D'] > 0:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=train_opt['lr_D'],
                                                    weight_decay=train_opt['weight_decay_D'] if train_opt[
                                                        'weight_decay_D'] else 0,
                                                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_D)
            else:
                for p in self.netD.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze decoder.')

            if train_opt['lr_F'] > 0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(),
                                                    lr=train_opt['lr_F'],
                                                    weight_decay=train_opt['weight_decay_F'] if train_opt[
                                                        'weight_decay_F'] else 0,
                                                    betas=(train_opt['beta1_F'], train_opt['beta2_F']))
                self.optimizers.append(self.optimizer_F)
            else:
                for p in self.netF.parameters():
                    p.requires_grad_(False)
                logger.info('Freeze flow.')

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                logger.info('No learning rate scheme is applied.')

            self.log_dict = OrderedDict()

        self.print_network()  # print networks structure
        self.load()  # load G, D, F if needed
        self.test_flow()

    def feed_data(self, data, need_GT=True):
        self.image = data[0].to(self.device)
        if need_GT:
            self.image_gt = self.image

    def optimize_parameters(self, step):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

        z = self.netE(self.image)
        reconstructed = self.netD(z)

        l_total = 0

        if self.cri_pix:  # pixel loss
            if self.add_mask:
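                # background pixels (first channel == 1) contribute with weight
                # 0.2, all remaining pixels with weight 0.8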
                mask = (self.image_gt[:, 0, :, :] == 1).unsqueeze(1).float()
                inv_mask = 1 - mask
                l_pix = (0.2 * self.cri_pix(reconstructed * mask, self.image_gt * mask) +
                         0.8 * self.cri_pix(reconstructed * inv_mask, self.image_gt * inv_mask))
            else:
                l_pix = self.l_pix_w * self.cri_pix(reconstructed, self.image_gt)
            l_total += l_pix

        if self.cri_fea:  # feature loss
            l_fea = self.l_fea_w * self.cri_fea(reconstructed, self.image_gt)
            l_total += l_fea

        # negative likelihood loss
        if self.stop_gradients:
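            # detach z so the flow NLL does not backpropagate into the encoder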
            noise_out, logdets = self.netF(z.detach())
        else:
            noise_out, logdets = self.netF(z)

        l_nll = self.l_nll_w * self.cri_nll(noise_out, logdets)
        l_total += l_nll

        l_total.backward()
        for optimizer in self.optimizers:
            optimizer.step()

        # set log
        if self.cri_pix:
            self.log_dict['l_pix'] = l_pix.item()
        if self.cri_fea:
            self.log_dict['l_fea'] = l_fea.item()
        if self.cri_nll:
            self.log_dict['l_nll'] = l_nll.item()

    def sample_images(self, n=25):
        self.netF.eval()
        self.netD.eval()
        with torch.no_grad():
            noise = torch.randn(n, self.nz).to(self.device)
            if isinstance(self.netF, (nn.DataParallel, DistributedDataParallel)):
                sample = self.netD(self.netF.module.reverse(noise)).detach().float().cpu()
            else:
                sample = self.netD(self.netF.reverse(noise)).detach().float().cpu()
        self.netF.train()
        self.netD.train()
        return sample

    def get_current_log(self):
        return self.log_dict

    def print_network(self):
        for name, net in [('E', self.netE), ('D', self.netD), ('F', self.netF)]:
            s, n = self.get_network_description(net)
            if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
                net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                                 net.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                logger.info(s)

        if self.is_train and self.cri_fea:
            vgg_net = self.cri_fea.vgg
            s, n = self.get_network_description(vgg_net)
            if isinstance(vgg_net, (nn.DataParallel, DistributedDataParallel)):
                net_struc_str = '{} - {}'.format(vgg_net.__class__.__name__,
                                                 vgg_net.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(vgg_net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network VGG structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

    def load(self):
        load_path_E = self.opt['path']['pretrained_encoder']
        if load_path_E is not None:
            logger.info('Loading model for E [{:s}] ...'.format(load_path_E))
            self.load_network(load_path_E, self.netE, self.opt['path']['strict_load'])

        load_path_D = self.opt['path']['pretrained_decoder']
        if load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

        load_path_F = self.opt['path']['pretrained_flow']
        if load_path_F is not None:
            logger.info('Loading model for F [{:s}] ...'.format(load_path_F))
            self.load_network(load_path_F, self.netF, self.opt['path']['strict_load'])

    def save(self, iter_step):
        self.save_network(self.netE, 'E', iter_step)
        self.save_network(self.netD, 'D', iter_step)
        self.save_network(self.netF, 'F', iter_step)
        
        
    def test_flow(self):
        with torch.no_grad():
            test_input = torch.randn((2, self.nz)).to(self.device)
            test_output, _ = self.netF(test_input)
            if isinstance(self.netF, (nn.DataParallel, DistributedDataParallel)):
                test_input2 = self.netF.module.reverse(test_output)
            else:
                test_input2 = self.netF.reverse(test_output)
            assert torch.allclose(test_input, test_input2), 'Flow model is incorrect'
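
# For reference: the NLLLoss used above is defined elsewhere; for a normalizing
# flow with a standard-Gaussian base distribution it typically implements the
# change-of-variables objective. A minimal sketch under that assumption
# (noise_out of shape (B, nz), logdets being the per-sample summed log-det):
import math
import torch.nn as nn

class GaussianFlowNLL(nn.Module):
    """-log p(x) = -(log N(f(x); 0, I) + log|det df/dx|), averaged over the batch."""
    def forward(self, noise_out, logdets):
        dims = noise_out[0].numel()
        log_base = (-0.5 * (noise_out ** 2).flatten(1).sum(dim=1)
                    - 0.5 * dims * math.log(2 * math.pi))
        return -(log_base + logdets).mean()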
Exemplo n.º 30
0
class Model:
    """
    This class handles basic methods for handling the model:
    1. Fit the model
    2. Make predictions
    3. Save
    4. Load
    """
    def __init__(self, input_size, n_channels, hparams):

        self.hparams = hparams

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # define the models
        self.model = WaveNet(n_channels=n_channels).to(self.device)
        summary(self.model, (input_size, n_channels))
        # self.model.half()

        if torch.cuda.device_count() > 1:
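            # only the first (device_count - 3) GPUs are placed in device_ids,
            # so this branch assumes at least four visible devices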
            print("Number of GPUs will be used: ",
                  torch.cuda.device_count() - 3)
            self.model = DP(self.model,
                            device_ids=list(
                                range(torch.cuda.device_count() - 3)))
        else:
            print('Only one GPU is available')

        self.metric = Metric()
        self.num_workers = 1
        ########################## compile the model ###############################

        # define optimizer
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=self.hparams['lr'],
                                          weight_decay=1e-5)

        # weights = torch.Tensor([0.025,0.033,0.039,0.046,0.069,0.107,0.189,0.134,0.145,0.262,1]).cuda()
        self.loss = nn.BCELoss()  # CompLoss(self.device)

        # define early stopping
        self.early_stopping = EarlyStopping(
            checkpoint_path=self.hparams['checkpoint_path'] + '/checkpoint.pt',
            patience=self.hparams['patience'],
            delta=self.hparams['min_delta'],
        )
        # lr scheduler
        self.scheduler = ReduceLROnPlateau(
            optimizer=self.optimizer,
            mode='max',
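            # NB: mode='max' expects a score where larger is better, but fit()
            # calls step(avg_val_loss), where smaller is better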
            factor=0.2,
            patience=3,
            verbose=True,
            threshold=self.hparams['min_delta'],
            threshold_mode='abs',
            cooldown=0,
            eps=0,
        )

        self.seed_everything(42)
        self.threshold = 0.75
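        # only needed if the commented-out GradScaler/AMP path in fit() is enabled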
        self.scaler = torch.cuda.amp.GradScaler()

    def seed_everything(self, seed):
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)

    def fit(self, train, valid):

        train_loader = DataLoader(
            train,
            batch_size=self.hparams['batch_size'],
            shuffle=True,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate
        valid_loader = DataLoader(
            valid,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        # tensorboard object
        writer = SummaryWriter()

        for epoch in range(self.hparams['n_epochs']):

            # train the model
            self.model.train()
            avg_loss = 0.0

            train_preds, train_true = torch.Tensor([]), torch.Tensor([])

            for (X_batch, y_batch) in tqdm(train_loader):
                y_batch = y_batch.float().to(self.device)
                X_batch = X_batch.float().to(self.device)

                self.optimizer.zero_grad()
                # get model predictions
                pred = self.model(X_batch)
                X_batch = X_batch.cpu().detach()

                # flatten predictions and targets, then compute the loss
                pred = pred.view(-1, pred.shape[-1])
                y_batch = y_batch.view(-1, y_batch.shape[-1])
                train_loss = self.loss(pred, y_batch)
                y_batch = y_batch.float().cpu().detach()
                pred = pred.float().cpu().detach()

                train_loss.backward()  # self.scaler.scale(train_loss).backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                # torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.5)
                self.optimizer.step()  # self.scaler.step(self.optimizer)
                # self.scaler.update()  # only valid together with the scaler calls above

                # calc metric
                avg_loss += train_loss.item() / len(train_loader)

                train_true = torch.cat([train_true, y_batch], 0)
                train_preds = torch.cat([train_preds, pred], 0)

            # calc training metric
            train_preds = train_preds.numpy()
            train_preds[np.where(train_preds >= self.threshold)] = 1
            train_preds[np.where(train_preds < self.threshold)] = 0
            metric_train = self.metric.compute(labels=train_true.numpy(),
                                               outputs=train_preds)

            # evaluate the model
            print('Model evaluation...')
            self.model.zero_grad()
            self.model.eval()
            val_preds, val_true = torch.Tensor([]), torch.Tensor([])
            avg_val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in valid_loader:
                    y_batch = y_batch.float().to(self.device)
                    X_batch = X_batch.float().to(self.device)

                    pred = self.model(X_batch)
                    X_batch = X_batch.float().cpu().detach()

                    pred = pred.reshape(-1, pred.shape[-1])
                    y_batch = y_batch.view(-1, y_batch.shape[-1])

                    avg_val_loss += self.loss(
                        pred, y_batch).item() / len(valid_loader)
                    y_batch = y_batch.float().cpu().detach()
                    pred = pred.float().cpu().detach()

                    val_true = torch.cat([val_true, y_batch], 0)
                    val_preds = torch.cat([val_preds, pred], 0)

            # evaluate metric
            val_preds = val_preds.numpy()
            val_preds[np.where(val_preds >= self.threshold)] = 1
            val_preds[np.where(val_preds < self.threshold)] = 0
            metric_val = self.metric.compute(val_true.numpy(), val_preds)

            self.scheduler.step(avg_val_loss)
            res = self.early_stopping(score=avg_val_loss, model=self.model)

            # print statistics
            if self.hparams['verbose_train']:
                print(
                    '| Epoch: ',
                    epoch + 1,
                    '| Train_loss: ',
                    avg_loss,
                    '| Val_loss: ',
                    avg_val_loss,
                    '| Metric_train: ',
                    metric_train,
                    '| Metric_val: ',
                    metric_val,
                    '| Current LR: ',
                    self.__get_lr(self.optimizer),
                )

            # add history to tensorboard
            writer.add_scalars(
                'Loss',
                {
                    'Train_loss': avg_loss,
                    'Val_loss': avg_val_loss
                },
                epoch,
            )

            writer.add_scalars('Metric', {
                'Metric_train': metric_train,
                'Metric_val': metric_val
            }, epoch)

            if res == 2:
                print("Early Stopping")
                print(
                    f'global best min val_loss model score {self.early_stopping.best_score}'
                )
                break
            elif res == 1:
                print(f'save global val_loss model score {avg_val_loss}')

        writer.close()

        self.model.zero_grad()

        return True

    def predict(self, X_test):

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.hparams['batch_size'],
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        test_preds = torch.Tensor([])
        print('Start generation of predictions')
        with torch.no_grad():
            for i, (X_batch, y_batch) in enumerate(tqdm(test_loader)):
                X_batch = X_batch.float().to(self.device)

                pred = self.model(X_batch)

                X_batch = X_batch.float().cpu().detach()

                test_preds = torch.cat([test_preds, pred.cpu().detach()], 0)

        return test_preds.numpy()

    def get_heatmap(self, X_test):

        # evaluate the model
        self.model.eval()

        test_loader = torch.utils.data.DataLoader(
            X_test,
            batch_size=self.hparams['batch_size'],  # self.batch_size is never defined
            shuffle=False,
            num_workers=self.num_workers)  # ,collate_fn=train.my_collate

        test_preds = torch.Tensor([])
        with torch.no_grad():
            for i, X_batch in enumerate(test_loader):
                X_batch = X_batch.float().to(self.device)

                pred = self.model.activatations(X_batch)
                pred = torch.sigmoid(pred)

                X_batch = X_batch.float().cpu().detach()

                test_preds = torch.cat([test_preds, pred.cpu().detach()], 0)

        return test_preds.numpy()

    def model_save(self, model_path):
        torch.save(self.model, model_path)
        return True

    def model_load(self, model_path):
        self.model = torch.load(model_path)
        return True

    ################## Utils #####################

    def __get_lr(self, optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']
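
# Minimal usage sketch for the wrapper above. `SignalDataset`, the hparams
# values and the file paths are illustrative assumptions, not part of the
# original code; the datasets are expected to yield (X, y) pairs.
hparams = {
    'lr': 1e-3,
    'batch_size': 32,
    'n_epochs': 50,
    'patience': 7,
    'min_delta': 1e-4,
    'checkpoint_path': './checkpoints',
    'verbose_train': True,
}
model = Model(input_size=4096, n_channels=12, hparams=hparams)
model.fit(SignalDataset(split='train'), SignalDataset(split='valid'))

probs = model.predict(SignalDataset(split='test'))  # sigmoid outputs, shape (N, n_classes)
labels = (probs >= model.threshold).astype(int)     # binarize with the fixed 0.75 threshold
model.model_save('./checkpoints/wavenet_final.pt')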