Example #1
    def __init__(self, hparams: Namespace) -> None:
        super(SR3Experiment, self).__init__()

        self.hparams = hparams
        self.val_folder = Path('output')
        self.val_folder.mkdir(exist_ok=True, parents=True)

        self.create_model()
        self.scale = self.hparams.scale

        # loss
        loss_type = self.hparams.train.pixel_criterion
        self.loss_type = loss_type
        if loss_type == 'l1':
            self.cri_pix = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            self.cri_pix = nn.MSELoss(reduction='none')
        elif loss_type == 'cb':
            self.cri_pix = CharbonnierLoss()
        elif loss_type == 'l1l2':

            class L1L2Loss(nn.Module):
                def __init__(self, reduction='mean'):
                    super().__init__()
                    self.l1 = nn.L1Loss(reduction=reduction)
                    self.l2 = nn.MSELoss(reduction=reduction)

                def forward(self, x, y):
                    return self.l1(x, y) + self.l2(x, y)

            self.cri_pix = L1L2Loss(reduction='none')
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not recognized.'.format(loss_type))
        self.l_pix_w = self.hparams.train.pixel_weight
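The 'cb' branch above relies on CharbonnierLoss, which is not defined in the snippet. A minimal sketch of one common definition follows; the eps value and the mean reduction are assumptions, not taken from the original code:

import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    """Charbonnier penalty sqrt((x - y)^2 + eps^2), a differentiable L1 variant."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # mean reduction is an assumption; some variants sum instead
        return torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))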
Example #2
    def get_criterion(self, mode, opt):
        if mode == 'pix':
            loss_type = opt['pixel_criterion']
            if loss_type == 'l1':
                criterion = nn.L1Loss(reduction=opt['reduction']).to(self.device)
            elif loss_type == 'l2':
                criterion = nn.MSELoss(reduction=opt['reduction']).to(self.device)
            elif loss_type == 'cb':
                criterion = CharbonnierLoss(reduction=opt['reduction']).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized for pixel.'.format(loss_type))
            weight = opt['pixel_weight']
        elif mode == 'gan':
            criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
            weight = opt['gan_weight']
        else:
            raise TypeError('Unknown type: {} for criterion'.format(mode))
        return criterion, weight
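A usage sketch for get_criterion: the option keys mirror those read above, but the values, the `model` instance, and the sr / hr tensors are placeholders:

import torch

pix_opt = {'pixel_criterion': 'l1', 'reduction': 'mean', 'pixel_weight': 1.0}
cri_pix, l_pix_w = model.get_criterion('pix', pix_opt)  # `model` is any object exposing get_criterion
sr = torch.rand(1, 3, 64, 64)                            # network output (placeholder)
hr = torch.rand(1, 3, 64, 64)                            # ground truth (placeholder)
loss = l_pix_w * cri_pix(sr, hr)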
Example #3
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

            self.l_pix_w = train_opt['pixel_weight']
            self.grad_w = train_opt['grad_weight']
            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()
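The constructor above only reads keys from train_opt; a hypothetical dict shows the expected shape. Only the key names come from the code, every value is a placeholder:

train_opt = {
    'pixel_criterion': 'cb',     # 'l1' | 'l2' | 'cb'
    'pixel_weight': 1.0,
    'grad_weight': 0.1,
    'ft_tsa_only': False,
    'weight_decay_G': 0,
    'lr_G': 4e-4,
    'beta1': 0.9,
    'beta2': 0.99,
    'lr_scheme': 'CosineAnnealingLR_Restart',
    'T_period': [150000, 150000],
    'eta_min': 1e-7,
    'restarts': [150000],
    'restart_weights': [1],
    'lr_steps': None,            # only read by the MultiStepLR branch
    'lr_gamma': None,
    'clear_state': None,
}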
Example #4
    def __init__(self, opt):
        super(SRVmafModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.use_gpu = opt['network_G']['use_gpu']
        self.use_gpu = True  # NOTE: hard-coded override of the config value above
        self.real_IQA_only = train_opt['IQA_only']

        # define network and load pretrained models
        if self.use_gpu:
            self.netG = networks.define_G(opt).to(self.device)
            if opt['dist']:
                self.netG = DistributedDataParallel(
                    self.netG, device_ids=[torch.cuda.current_device()])
            else:
                self.netG = DataParallel(self.netG)
        else:
            self.netG = networks.define_G(opt)

        if self.is_train:
            if train_opt['IQA_weight']:
                if train_opt['IQA_criterion'] == 'vmaf':
                    self.cri_IQA = nn.MSELoss()
                self.l_IQA_w = train_opt['IQA_weight']

                self.netI = networks.define_I(opt)
                if opt['dist']:
                    pass
                else:
                    self.netI = DataParallel(self.netI)
            else:
                logger.info('Remove IQA loss.')
                self.cri_IQA = None

        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # pixel loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # CX loss
            if train_opt['CX_weight']:
                l_CX_type = train_opt['CX_criterion']
                if l_CX_type == 'contextual_loss':
                    self.cri_CX = ContextualLoss()
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_CX_type))
                self.l_CX_w = train_opt['CX_weight']
            else:
                logger.info('Remove CX loss.')
                self.cri_CX = None

            # ssim loss
            if train_opt['ssim_weight']:
                self.cri_ssim = train_opt['ssim_criterion']
                self.l_ssim_w = train_opt['ssim_weight']
                self.ssim_window = train_opt['ssim_window']
            else:
                logger.info('Remove ssim loss.')
                self.cri_ssim = None

            # load VGG perceptual loss if use CX loss
            # if train_opt['CX_weight']:
            #     self.netF = networks.define_F(opt, use_bn=False).to(self.device)
            #     if opt['dist']:
            #         pass  # do not need to use DistributedDataParallel for netF
            #     else:
            #         self.netF = DataParallel(self.netF)

            # optimizers of netG
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # optimizers of netI
            if train_opt['IQA_weight']:
                wd_I = train_opt['weight_decay_I'] if train_opt['weight_decay_I'] else 0
                optim_params = []
                for k, v in self.netI.named_parameters():  # can optimize for a part of the model
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                self.optimizer_I = torch.optim.Adam(optim_params,
                                                    lr=train_opt['lr_I'],
                                                    weight_decay=wd_I,
                                                    betas=(train_opt['beta1'],
                                                           train_opt['beta2']))
                self.optimizers.append(self.optimizer_I)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
            self.set_requires_grad(self.netG, False)
            self.set_requires_grad(self.netI, False)
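The constructor ends by calling set_requires_grad, which is not shown in this snippet. A minimal sketch of what such a base-class helper typically looks like; the exact signature is an assumption:

    def set_requires_grad(self, nets, requires_grad=False):
        """Enable or disable gradients for a network or a list of networks."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad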
Example #5
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # Gradient Loss
            if train_opt['gradient_weight'] > 0:
                self.cri_gradient = GradientLoss(train_opt['gradient_type'], train_opt['gradient_grid_horizontal'],
                                                 train_opt['gradient_grid_vertical'], train_opt['gradient_criterion'],
                                                 self.device)
                self.l_gradient_w = train_opt['gradient_weight']
            else:
                logger.info('Remove gradient loss')
                self.cri_gradient = None

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
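The attributes built above (cri_pix, cri_gradient, optimizer_G, log_dict) are typically consumed in a per-iteration update such as the sketch below; the method name and the var_L / real_H / fake_H tensor names are assumptions, not taken from the snippet:

    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)  # var_L: low-resolution input batch (assumed name)
        l_total = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        if self.cri_gradient is not None:
            l_total += self.l_gradient_w * self.cri_gradient(self.fake_H, self.real_H)
        l_total.backward()
        self.optimizer_G.step()
        self.log_dict['l_total'] = l_total.item()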
Example #6
def main(args):
    print("===> Loading datasets")
    data_set = DatasetLoader(args.data_lr,
                             args.data_hr,
                             size_w=args.size_w,
                             size_h=args.size_h,
                             scale=args.scale,
                             n_frames=args.n_frames,
                             interval_list=args.interval_list,
                             border_mode=args.border_mode,
                             random_reverse=args.random_reverse)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=False,
                              drop_last=True)

    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True

    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    criterion = CharbonnierLoss()
    print("===> Setting GPU")
    gpus = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gpus))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    # print(model)

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the most recent checkpoint in the directory
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']

            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k  # prepend 'module.' to match the DataParallel-wrapped model
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # if the checkpoint stores an lr, prefer it over the command-line value
            args.lr = checkpoint.get('lr', args.lr)

        # if start_epoch was set explicitly, it overrides the epoch stored in the checkpoint
        start_epoch = args.start_epoch if args.start_epoch != 0 else start_epoch

    # if use_current_lr is positive, use it as the lr instead
    args.lr = args.use_current_lr if args.use_current_lr > 0 else args.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 betas=(args.beta1, args.beta2),
                                 eps=1e-8)

    #### training
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        if args.use_tqdm == 1:
            losses, psnrs = one_epoch_train_tqdm(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        else:
            losses, psnrs = one_epoch_train_logger(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])

        # save model
        # if epoch %9 != 0:
        #     continue

        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edvr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
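adjust_lr is called once per epoch above but is not defined in the snippet; a minimal step-decay sketch with that signature (the base_lr, step, and gamma values are assumptions):

def adjust_lr(optimizer, epoch, base_lr=4e-4, step=20, gamma=0.5):
    """Halve the learning rate every `step` epochs (placeholder schedule)."""
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr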
Example #7
    def __init__(self, opt):
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        # self.print_network()
        self.load()
        '''
        # Freeze front when indicated
        if opt['train']['freeze_front']:
            print('Freeze the front part of the model')
            if opt['network_G']['which_model_G'] == 'EDVR':
                for name, child in self.netG.module.named_children():
                    if name not in ('tsa_fusion', 'recon_trunk', 'upconv1', 'upconv2', 'HRconv', 'conv_last'):
                        for params in child.parameters():
                            params.requires_grad = False
            elif opt['network_G']['which_model_G'] == 'DUF':
                for name, child in self.netG.module.named_children():
                    if name in ('conv3d_1', 'dense_block_1', 'dense_block_2', 'dense_block_3'):
                        for params in child.parameters():
                            params.requires_grad = False
            else:
                raise NotImplementedError()
        '''
        if self.is_train:
            self.netG.train()

            #### loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='mean').to(
                    self.device)  # Change from sum to mean
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='mean').to(
                    self.device)  # Change from sum to mean
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif loss_type == 'huber':
                self.cri_pix = HuberLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            elif opt['train']['freeze_front']:  # keep the branches exclusive so the ft_tsa_only groups above are not overwritten
                normal_params = []
                freeze_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'module.conv3d_1' in k or 'module.dense_block_1' in k or 'module.dense_block_2' in k or 'module.dense_block_3' in k:
                            freeze_params.append(v)
                        else:
                            normal_params.append(v)
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': freeze_params,
                        'lr': 0
                    },
                ]
            elif train_opt['small_offset_lr']:
                normal_params = []
                conv_offset_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'pcd_align' in k or 'fea_L' in k or 'feature_extraction' in k or 'conv_first' in k:
                            conv_offset_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': conv_offset_params,
                        'lr': train_opt['lr_G'] * 0.1
                    },
                ]
            else:
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))

            if train_opt['optim'] == 'SGD':
                self.optimizer_G = torch.optim.SGD(optim_params,
                                                   lr=train_opt['lr_G'],
                                                   weight_decay=wd_G)
            else:
                self.optimizer_G = torch.optim.Adam(optim_params,
                                                    lr=train_opt['lr_G'],
                                                    weight_decay=wd_G,
                                                    betas=(train_opt['beta1'],
                                                           train_opt['beta2']))

            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            self.log_dict = OrderedDict()
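The 'huber' branch above uses a HuberLoss that is not defined in the snippet. A minimal sketch of a standard Huber loss; the delta value and mean reduction are assumptions (with delta=1 it matches torch.nn.SmoothL1Loss):

import torch
import torch.nn as nn

class HuberLoss(nn.Module):
    """Quadratic below `delta`, linear above it."""

    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta

    def forward(self, x, y):
        diff = (x - y).abs()
        quad = torch.clamp(diff, max=self.delta)
        # 0.5*d^2 inside the delta band, delta*(|d| - 0.5*delta) outside
        return (0.5 * quad ** 2 + self.delta * (diff - quad)).mean()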
Example #8
    def __init__(self, opt):
        super(ClassSR_Model, self).__init__(opt)

        self.patch_size = int(opt["patch_size"])
        self.step = int(opt["step"])
        self.scale = int(opt["scale"])
        self.name = opt['name']
        self.which_model = opt['network_G']['which_model_G']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)

        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.l1w = float(opt["train"]["l1w"])
            self.class_loss_w = float(opt["train"]["class_loss_w"])
            self.average_loss_w = float(opt["train"]["average_loss_w"])
            self.pf = opt['logger']['print_freq']
            self.batch_size = int(opt['datasets']['train']['batch_size'])
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif loss_type == 'ClassSR_loss':
                self.cri_pix = nn.L1Loss().to(self.device)
                self.class_loss = class_loss_3class().to(self.device)
                self.average_loss = average_loss_3class().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized.'.format(loss_type))

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            if opt['fix_SR_module']:
                for k, v in self.netG.named_parameters():  # freeze everything outside the class branch
                    if v.requires_grad and "class" not in k:
                        v.requires_grad = False

            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
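A hypothetical opt dict showing the keys this constructor reads; only the key names come from the code above, every value (including the network name) is a placeholder:

opt = {
    'name': 'ClassSR_example',
    'patch_size': 32,
    'step': 28,
    'scale': 4,
    'dist': False,
    'fix_SR_module': False,
    'network_G': {'which_model_G': 'ClassSR_RCAN'},  # placeholder architecture name
    'logger': {'print_freq': 100},
    'datasets': {'train': {'batch_size': 16}},
    'train': {
        'l1w': 1.0,
        'class_loss_w': 1.0,
        'average_loss_w': 1.0,
        'pixel_criterion': 'ClassSR_loss',
        'weight_decay_G': 0,
        'lr_G': 1e-4,
        'beta1': 0.9,
        'beta2': 0.99,
        'lr_scheme': 'MultiStepLR',
        'lr_steps': [200000, 400000],
        'restarts': None,
        'restart_weights': None,
        'lr_gamma': 0.5,
        'clear_state': None,
    },
}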