Example #1
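Trainer init for an unpaired image-to-image translation GAN: image pools that buffer previously generated fakes, L1 cycle and identity losses, GAN and CAM losses, one Adam optimizer over both generators and one over both discriminators, and a scheduler per optimizer.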
    def init(self):
        opt = self.args
        if not os.path.exists(opt.saved_dir):
            os.makedirs(opt.saved_dir)
        self.fake_A_pool = ImagePool(
            opt.pool_size
        )  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(opt.pool_size)
        self.crit_cycle = torch.nn.L1Loss()
        self.crit_idt = torch.nn.L1Loss()
        self.crit_gan = GANLoss(opt.gan_mode).cuda()
        self.cam_loss = CAMLoss()
        self.optim_G = torch.optim.Adam(itertools.chain(
            self.model.G_A.parameters(), self.model.G_B.parameters()),
                                        lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
        self.optim_D = torch.optim.Adam(
            itertools.chain(self.model.D_A.parameters(),
                            self.model.D_B.parameters()),
            lr=opt.lr,
            betas=(opt.beta1, 0.999))  # default: 0.5
        self.optimizers = [self.optim_G, self.optim_D]

        self.schedulers = [
            get_scheduler(optimizer, self.args)
            for optimizer in self.optimizers
        ]
Example #2
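define_loss() for a GAN-trained restoration model: a pixel loss (l1 / l2 / l2sum / ssim) gated by G_lossfn_weight, a VGG perceptual loss gated by F_lossfn_weight with separate device placement for distributed vs. single-process training, a GAN loss for the discriminator, and the D_update_ratio / D_init_iters settings.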
    def define_loss(self):
        # ------------------------------------
        # G_loss
        # ------------------------------------
        if self.opt_train['G_lossfn_weight'] > 0:
            G_lossfn_type = self.opt_train['G_lossfn_type']
            if G_lossfn_type == 'l1':
                self.G_lossfn = nn.L1Loss().to(self.device)
            elif G_lossfn_type == 'l2':
                self.G_lossfn = nn.MSELoss().to(self.device)
            elif G_lossfn_type == 'l2sum':
                self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
            elif G_lossfn_type == 'ssim':
                self.G_lossfn = SSIMLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not found.'.format(G_lossfn_type))
            self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
        else:
            print('Do not use pixel loss.')
            self.G_lossfn = None

        # ------------------------------------
        # F_loss
        # ------------------------------------
        if self.opt_train['F_lossfn_weight'] > 0:
            F_lossfn_type = self.opt_train['F_lossfn_type']
            F_use_input_norm = self.opt_train['F_use_input_norm']
            F_feature_layer = self.opt_train['F_feature_layer']
            if self.opt['dist']:
                self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer,
                                               use_input_norm=F_use_input_norm,
                                               lossfn_type=F_lossfn_type).to(
                                                   self.device)
            else:
                self.F_lossfn = PerceptualLoss(feature_layer=F_feature_layer,
                                               use_input_norm=F_use_input_norm,
                                               lossfn_type=F_lossfn_type)
                self.F_lossfn.vgg = self.model_to_device(self.F_lossfn.vgg)
                self.F_lossfn.lossfn = self.F_lossfn.lossfn.to(self.device)
            self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
        else:
            print('Do not use feature loss.')
            self.F_lossfn = None

        # ------------------------------------
        # D_loss
        # ------------------------------------
        self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0,
                                0.0).to(self.device)
        self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

        self.D_update_ratio = self.opt_train[
            'D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
        self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train[
            'D_init_iters'] else 0
Example #3
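A close variant of the previous define_loss(): here the feature loss is a plain L1/L2 criterion (the VGG feature extractor line is left commented out) rather than a PerceptualLoss module.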
    def define_loss(self):
        # ------------------------------------
        # G_loss
        # ------------------------------------
        if self.opt_train['G_lossfn_weight'] > 0:
            G_lossfn_type = self.opt_train['G_lossfn_type']
            if G_lossfn_type == 'l1':
                self.G_lossfn = nn.L1Loss().to(self.device)
            elif G_lossfn_type == 'l2':
                self.G_lossfn = nn.MSELoss().to(self.device)
            elif G_lossfn_type == 'l2sum':
                self.G_lossfn = nn.MSELoss(reduction='sum').to(self.device)
            elif G_lossfn_type == 'ssim':
                self.G_lossfn = SSIMLoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not found.'.format(G_lossfn_type))
            self.G_lossfn_weight = self.opt_train['G_lossfn_weight']
        else:
            print('Do not use pixel loss.')
            self.G_lossfn = None

        # ------------------------------------
        # F_loss
        # ------------------------------------
        if self.opt_train['F_lossfn_weight'] > 0:
            F_lossfn_type = self.opt_train['F_lossfn_type']
            if F_lossfn_type == 'l1':
                self.F_lossfn = nn.L1Loss().to(self.device)
            elif F_lossfn_type == 'l2':
                self.F_lossfn = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(F_lossfn_type))
            self.F_lossfn_weight = self.opt_train['F_lossfn_weight']
            # self.netF = define_F(self.opt, use_bn=False).to(self.device)
        else:
            print('Do not use feature loss.')
            self.F_lossfn = None

        # ------------------------------------
        # D_loss
        # ------------------------------------
        self.D_lossfn = GANLoss(self.opt_train['gan_type'], 1.0,
                                0.0).to(self.device)
        self.D_lossfn_weight = self.opt_train['D_lossfn_weight']

        self.D_update_ratio = self.opt_train[
            'D_update_ratio'] if self.opt_train['D_update_ratio'] else 1
        self.D_init_iters = self.opt_train['D_init_iters'] if self.opt_train[
            'D_init_iters'] else 0
Example #4
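A small factory that returns a (criterion, weight) pair: L1, L2, or Charbonnier loss for mode 'pix', and a GANLoss for mode 'gan'.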
    def get_criterion(self, mode, opt):
        if mode == 'pix':
            loss_type = opt['pixel_criterion']
            if loss_type == 'l1':
                criterion = nn.L1Loss(reduction=opt['reduction']).to(
                    self.device)
            elif loss_type == 'l2':
                criterion = nn.MSELoss(reduction=opt['reduction']).to(
                    self.device)
            elif loss_type == 'cb':
                criterion = CharbonnierLoss(reduction=opt['reduction']).to(
                    self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] is not recognized for pixel'.format(
                        loss_type))
            weight = opt['pixel_weight']
        elif mode == 'gan':
            criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(self.device)
            weight = opt['gan_weight']
        else:
            raise TypeError('Unknown type: {} for criterion'.format(mode))
        return criterion, weight
Example #5
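PPONModel init for staged training: pixel (L1/L2), structure (MS-SSIM plus multiscale L1), VGG feature, and GAN losses are enabled by their weights; earlier sub-networks (CFEM/CRM, then also SFEM/SRM) are frozen depending on which_model; Adam optimizers with MultiStepLR schedulers are built for G and, in the perceptual stage, D.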
    def __init__(self, args):
        super(PPONModel, self).__init__(args)

        # define networks and load pre-trained models
        self.netG = networks.define_G(args).cuda()
        if self.is_train:
            if args.which_model == 'perceptual':
                self.netD = networks.define_D().cuda()
                self.netD.train()
            self.netG.train()

        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if args.pixel_weight > 0:
                l_pix_type = args.pixel_criterion
                if l_pix_type == 'l1':  # loss pixel type
                    self.cri_pix = nn.L1Loss().cuda()
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().cuda()
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = args.pixel_weight
            else:
                print('Remove pixel loss.')
                self.cri_pix = None  # no pixel criterion

            # G structure loss
            if args.structure_weight > 0:
                self.cri_msssim = pytorch_msssim.MS_SSIM(data_range=args.rgb_range).cuda()
                self.cri_ml1 = MultiscaleL1Loss().cuda()
            else:
                print('Remove structure loss.')
                self.cri_msssim = None
                self.cri_ml1 = None

            # G feature loss
            if args.feature_weight > 0:
                l_fea_type = args.feature_criterion
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().cuda()
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().cuda()
                else:
                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = args.feature_weight
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.vgg = networks.define_F().cuda()

            if args.gan_weight > 0:
                # gan loss
                self.cri_gan = GANLoss(args.gan_type, 1.0, 0.0).cuda()
                self.l_gan_w = args.gan_weight
            else:
                self.cri_gan = None

            # optimizers
            # G
            if args.which_model == 'structure':
                for param in self.netG.CFEM.parameters():
                    param.requires_grad = False
                for param in self.netG.CRM.parameters():
                    param.requires_grad = False

            if args.which_model == 'perceptual':
                for param in self.netG.CFEM.parameters():
                    param.requires_grad = False
                for param in self.netG.CRM.parameters():
                    param.requires_grad = False
                for param in self.netG.SFEM.parameters():
                    param.requires_grad = False
                for param in self.netG.SRM.parameters():
                    param.requires_grad = False
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('Warning: params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=args.lr_G)
            self.optimizers.append(self.optimizer_G)

            # D
            if args.which_model == 'perceptual':
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=args.lr_D)
                self.optimizers.append(self.optimizer_D)

            # schedulers
            if args.lr_scheme == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                    args.lr_steps, args.lr_gamma))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        print('------------- Model initialized -------------')
        self.print_network()
        print('---------------------------------------------')
Example #6
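SRGANModel init with optional distributed training: G and D wrapped in DataParallel or DistributedDataParallel, pixel, VGG feature, rank-content (netR), and GAN losses, Adam optimizers with per-network weight decay and betas, and MultiStepLR_Restart or CosineAnnealingLR_Restart schedulers.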
    def __init__(self, opt):
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # G Rank-content loss
            if train_opt['R_weight'] > 0:
                self.l_R_w = train_opt['R_weight']  # load rank-content loss
                self.R_bias = train_opt['R_bias']
                self.netR = networks.define_R(opt).to(self.device)
                if opt['dist']:
                    self.netR = DistributedDataParallel(
                        self.netR, device_ids=[torch.cuda.current_device()])
                else:
                    self.netR = DataParallel(self.netR)
            else:
                logger.info('Remove rank-content loss.')

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Example #7
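Fragment of a script-style setup: DataParallel wrapping when several GPUs are configured, fixed L1 pixel and feature losses plus a vanilla GAN loss, MultiStepLR schedulers, and scheduler fast-forwarding when training resumes from a later epoch.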
    optimizers.append(optimizer_D)

    # Handle multi-gpu if desired
    if (device1.type == 'cuda') and (config['ngpu'] > 1):
        netG = nn.DataParallel(netG, list(range(config['ngpu'])))
        netD = nn.DataParallel(netD, list(range(config['ngpu'])))
        netF = nn.DataParallel(netF, list(range(config['ngpu'])))
    # summary(netG, input_size=(3, input_shape, input_shape), device="cuda")
    # summary(netD, input_size=(3, output_shape, output_shape), device="cuda")

    # G pixel loss
    cri_pix = nn.L1Loss().to(device1)
    # G feature loss
    cri_fea = nn.L1Loss().to(device1)
    # GD gan loss
    cri_gan = GANLoss("vanilla", 1.0, 0.0).to(device1)

    # schedulers
    schedulers = list()
    for optimizer in optimizers:
        schedulers.append(
            lr_scheduler.MultiStepLR(optimizer, [50, 75, 100, 200], 0.5))

    log_dict = OrderedDict()

    global_step = config['n_epoch_start'] * len(train_loader)
    for i in range(config['n_epoch_start']):
        for scheduler in schedulers:
            scheduler.step()

    for epoch in trange(config['n_epoch_start'], config['n_epoch_end']):
Example #8
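Distributed GAN trainer init: generator, discriminator, and an EMA copy of the generator, DDP wrapping, DiffAugment and a LiDAR projection module, train/validation loaders, GAN, gradient-penalty ('gp'), and path-length ('pl') criteria selected by their loss weights, Adam optimizers, an AMP GradScaler, and checkpoint resume.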
    def __init__(self, cfg, local_cfg):
        self.cfg = cfg
        self.local_cfg = local_cfg
        self.device = torch.device(self.local_cfg.gpu)

        # setup models
        self.cfg.model.gen.shape = self.cfg.dataset.shape
        self.cfg.model.dis.shape = self.cfg.dataset.shape
        self.G = define_G(self.cfg)
        self.D = define_D(self.cfg)
        self.G_ema = define_G(self.cfg)
        self.G_ema.eval()
        ema_inplace(self.G_ema, self.G, 0.0)
        self.A = DiffAugment(policy=self.cfg.solver.augment)
        self.lidar = LiDAR(
            num_ring=cfg.dataset.shape[0],
            num_points=cfg.dataset.shape[1],
            min_depth=cfg.dataset.min_depth,
            max_depth=cfg.dataset.max_depth,
            angle_file=osp.join(cfg.dataset.root, "angles.pt"),
        )
        self.lidar.eval()

        self.G.to(self.device)
        self.D.to(self.device)
        self.G_ema.to(self.device)
        self.A.to(self.device)
        self.lidar.to(self.device)

        self.G = DDP(self.G,
                     device_ids=[self.local_cfg.gpu],
                     broadcast_buffers=False)
        self.D = DDP(self.D,
                     device_ids=[self.local_cfg.gpu],
                     broadcast_buffers=False)

        if dist.get_rank() == 0:
            print("minibatch size per gpu:", self.local_cfg.batch_size)
            print("number of gradient accumulation:",
                  self.cfg.solver.num_accumulation)

        self.ema_decay = 0.5**(self.cfg.solver.batch_size /
                               (self.cfg.solver.smoothing_kimg * 1000))

        # training dataset
        self.dataset = define_dataset(self.cfg.dataset, phase="train")
        self.loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.local_cfg.batch_size,
            shuffle=False,
            num_workers=self.local_cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            sampler=torch.utils.data.distributed.DistributedSampler(
                self.dataset),
            drop_last=True,
        )
        self.loader = cycle(self.loader)

        # validation dataset
        self.val_dataset = define_dataset(self.cfg.dataset, phase="val")
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.local_cfg.batch_size,
            shuffle=True,
            num_workers=self.local_cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            drop_last=False,
        )

        # loss criterion
        self.loss_weight = dict(self.cfg.solver.loss)
        self.criterion = {}
        self.criterion["gan"] = GANLoss(self.cfg.solver.gan_mode).to(
            self.device)
        if "gp" in self.loss_weight and self.loss_weight["gp"] > 0.0:
            self.criterion["gp"] = True
        if "pl" in self.loss_weight and self.loss_weight["pl"] > 0.0:
            self.criterion["pl"] = True
            self.pl_ema = torch.tensor(0.0).to(self.device)
        if dist.get_rank() == 0:
            print("loss: {}".format(tuple(self.criterion.keys())))

        # optimizer
        self.optim_G = optim.Adam(
            params=self.G.parameters(),
            lr=self.cfg.solver.lr.alpha.gen,
            betas=(self.cfg.solver.lr.beta1, self.cfg.solver.lr.beta2),
        )
        self.optim_D = optim.Adam(
            params=self.D.parameters(),
            lr=self.cfg.solver.lr.alpha.dis,
            betas=(self.cfg.solver.lr.beta1, self.cfg.solver.lr.beta2),
        )

        # automatic mixed precision
        self.enable_amp = cfg.enable_amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)
        if dist.get_rank() == 0 and self.enable_amp:
            print("amp enabled")

        # resume from checkpoints
        self.start_iteration = 0
        if self.cfg.resume is not None:
            state_dict = torch.load(self.cfg.resume, map_location="cpu")
            self.start_iteration = state_dict[
                "step"] // self.cfg.solver.batch_size
            self.G.module.load_state_dict(state_dict["G"])
            self.D.module.load_state_dict(state_dict["D"])
            self.G_ema.load_state_dict(state_dict["G_ema"])
            self.optim_G.load_state_dict(state_dict["optim_G"])
            self.optim_D.load_state_dict(state_dict["optim_D"])
            if "pl" in self.criterion:
                self.criterion["pl"].pl_ema = state_dict["pl_ema"].to(
                    self.device)

        # for visual validation
        self.fixed_noise = torch.randn(self.local_cfg.batch_size,
                                       cfg.model.gen.in_ch,
                                       device=self.device)
Example #9
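InpaintingModel init: pixel (L1 / L2 / multiscale L1), VGG feature with per-layer weights, discriminator-feature, center, and GAN losses, plus Adam optimizers and MultiStepLR schedulers for G and D.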
    def __init__(self, opt):
        super(InpaintingModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained model
        self.netG = networks.define_G(opt).to(self.device)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            self.netG.train()
            self.netD.train()

        self.load()  # load G and D

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                elif l_pix_type == 'ml1':
                    self.cri_pix = MultiscaleL1Loss().to(self.device)
                else:
                    raise NotImplementedError('Unsupported loss type: {}'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Unsupported loss type: {}'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
                self.guided_cri_fea = MaskedL1Loss().to(self.device)
            else:
                self.cri_fea = None
            if self.cri_fea:  # load VGG model
                # self.vgg = Vgg19()
                # self.vgg.load_state_dict(torch.load(vgg_model))
                # for param in self.vgg.parameters():
                #     param.requires_grad = False
                self.vgg = networks.define_F(opt)
                self.vgg.to(self.device)
                self.vgg_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
                self.vgg_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
                self.vgg_fns = [self.cri_fea] * len(self.vgg_layers)

            ## discriminator features
            if train_opt['dis_feature_weight'] > 0:
                l_dis_fea_type = train_opt['dis_feature_criterion']
                if l_dis_fea_type == 'l1':
                    self.cri_dis_fea = nn.L1Loss().to(self.device)
                elif l_dis_fea_type == 'l2':
                    self.cri_dis_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError('Unsupported loss type: {}'.format(l_dis_fea_type))
                self.l_dis_fea_w = train_opt['dis_feature_weight']
            else:
                self.cri_dis_fea = None
            if self.cri_dis_fea:
                self.dis_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
                self.dis_fns = [self.cri_dis_fea] * len(self.dis_weights)

            ## center loss weight
            if train_opt['center_weight'] > 0:
                self.l_center_w = train_opt['center_weight']
            else:
                self.l_center_w = 0

            # G & D gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']

            # optimizers
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], betas=(0.5, 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], betas=(0.5, 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_policy'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
                                                                    train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('Unsupported learning scheme: {}'.format(train_opt['lr_policy']))

            self.log_dict = OrderedDict()
            # print network
            self.print_network()
Example #10
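A complete pix2pix-style training script: network construction with optional resume, Adam optimizers and schedulers, alternating D and G updates with a GAN plus L1 (weighted by 100) generator objective, and periodic checkpointing and logging.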
def main(args):
    opt = BaseOptions.parse(args)
    make_dir(opt.checkpoints_dir)
    BaseOptions.print_options(opt)

    torch.backends.cudnn.benchmark = True
    device = torch.device(
        "cuda:{}".format(opt.gpu_ids[0]) if opt.gpu_ids else "cpu")

    net_D = Discriminator(opt.input_nc, opt.conv_dim_d, opt.n_layers_d,
                          opt.use_sigmoid).to(device)
    net_G = Generator(opt.input_nc, opt.conv_dim_g, opt.n_blocks_g,
                      opt.use_bias).to(device)

    if opt.resume_iters:
        load_net(net_D, opt.resume_iters, "D", device)
        load_net(net_G, opt.resume_iters, "G", device)
    else:
        init_weights(net_D, opt.init_type)
        init_weights(net_G, opt.init_type)

    if len(opt.gpu_ids) > 1:
        net_D = torch.nn.DataParallel(net_D, device_ids=opt.gpu_ids)
        net_G = torch.nn.DataParallel(net_G, device_ids=opt.gpu_ids)

    print_network(net_D, "net_D")
    print_network(net_G, "net_G")

    optimizer_D = torch.optim.Adam(net_D.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_G = torch.optim.Adam(net_G.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizers = [optimizer_D, optimizer_G]

    schedulers = [get_scheduler(optimizer, opt) for optimizer in optimizers]

    criterionGAN = GANLoss(no_lsgan=True).to(device)
    criterionL1 = torch.nn.L1Loss()

    dataset = Dataset(opt.json_file, opt.aligned)
    print("#training images = {}".format(len(dataset)))
    data_loader = get_loader(dataset, opt.batch_size, True, opt.workers)

    data_time = 0.0
    total_time = 0.0
    data_iter = iter(data_loader)
    logger = Logger(opt.checkpoints_dir)
    for curr_iters in range(opt.start_iters,
                            opt.start_iters + opt.train_iters):
        start_time = time.time()

        try:
            real_A, real_B = next(data_iter)
        except Exception:  # loader exhausted; restart from a fresh pass
            data_iter = iter(data_loader)
            real_A, real_B = next(data_iter)

        real_A = real_A.to(device)
        real_B = real_B.to(device)

        data_time += time.time() - start_time

        fake_B = net_G(real_A)

        # update D
        set_requires_grad(net_D, True)
        optimizer_D.zero_grad()

        pred_fake = net_D(fake_B.detach())
        loss_D_fake = criterionGAN(pred_fake, False)

        pred_real = net_D(real_B)
        loss_D_real = criterionGAN(pred_real, True)

        loss_D = loss_D_fake + loss_D_real
        loss_D.backward()
        optimizer_D.step()

        # update G
        set_requires_grad(net_D, False)
        optimizer_G.zero_grad()

        pred_fake = net_D(fake_B)
        loss_G_GAN = criterionGAN(pred_fake, True)

        loss_G_L1 = criterionL1(fake_B, real_B)

        loss_G = loss_G_GAN + loss_G_L1 * 100
        loss_G.backward()
        optimizer_G.step()

        total_time += time.time() - start_time

        logger.add(loss_D_fake=loss_D_fake.mean().item(),
                   loss_D_real=loss_D_real.mean().item(),
                   loss_G_GAN=loss_G_GAN.mean().item(),
                   loss_G_L1=loss_G_L1.mean().item())

        for scheduler in schedulers:
            scheduler.step()

        if curr_iters % opt.model_save == 0:
            print("saving the model: iters = {}, lr = {}".format(
                curr_iters, optimizers[0].param_groups[0]["lr"]))
            save_net(net_D, curr_iters, "D", opt)
            save_net(net_G, curr_iters, "G", opt)

        if curr_iters % opt.display_freq == 0:
            print("#iters[{}]: data time {}, total time {}".format(
                curr_iters, data_time, total_time))
            data_time = 0.0
            total_time = 0.0
            logger.save(curr_iters)
Example #11
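SRGANModel variant with an optional text branch: when a dataset is passed, a frozen ASTER recognizer (netT) is loaded from a checkpoint in addition to the usual pixel, VGG feature, and GAN losses.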
    def __init__(self, opt, dataset=None):
        super(SRGANModel, self).__init__(opt)

        # define the flag unconditionally; it is checked later even when no dataset is passed
        self.cri_text = bool(dataset)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)
            if self.cri_text:
                from lib.models.model_builder import ModelBuilder
                self.netT = ModelBuilder(
                    arch="ResNet_ASTER",
                    rec_num_classes=dataset.rec_num_classes,
                    sDim=512,
                    attDim=512,
                    max_len_labels=100,
                    eos=dataset.char2id[dataset.EOS],
                    STN_ON=True).to(self.device)

                self.netT = DataParallel(self.netT)
                self.netT.eval()
                from lib.util.serialization import load_checkpoint
                checkpoint = load_checkpoint(train_opt['text_model'])
                self.netT.load_state_dict(checkpoint['state_dict'])

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            # can optimize for a part of the model
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
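
All of the examples above stop at defining criteria, optimizers, and schedulers. Below is a minimal, self-contained sketch of how such pieces are typically combined in one training step; the toy modules, tensor shapes, loss weights, and the D_update_ratio / D_init_iters gating are illustrative assumptions, not code taken from any of the examples.

import torch
import torch.nn as nn

# Toy stand-ins so the sketch runs on its own; real code would use
# networks.define_G/define_D/define_F and GANLoss as in the examples above.
netG = nn.Conv2d(3, 3, 3, padding=1)
netD = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1),
                     nn.AdaptiveAvgPool2d(1), nn.Flatten())
netF = nn.Conv2d(3, 8, 3, padding=1)  # frozen "feature extractor"
for p in netF.parameters():
    p.requires_grad = False

cri_pix, l_pix_w = nn.L1Loss(), 1e-2
cri_fea, l_fea_w = nn.L1Loss(), 1.0
cri_gan, l_gan_w = nn.BCEWithLogitsLoss(), 5e-3  # stands in for GANLoss
optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.9, 0.99))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.9, 0.99))
D_update_ratio, D_init_iters = 1, 0

var_L = torch.rand(2, 3, 16, 16)  # input batch
var_H = torch.rand(2, 3, 16, 16)  # target batch (toy: same spatial size)
real_label, fake_label = torch.ones(2, 1), torch.zeros(2, 1)

for step in range(1, 3):
    fake_H = netG(var_L)

    # G update: only every D_update_ratio steps and after D_init_iters warm-up
    # (the real repos also freeze D's gradients here; omitted for brevity).
    if step % D_update_ratio == 0 and step > D_init_iters:
        optimizer_G.zero_grad()
        l_g_total = l_pix_w * cri_pix(fake_H, var_H)
        l_g_total = l_g_total + l_fea_w * cri_fea(netF(fake_H),
                                                  netF(var_H).detach())
        l_g_total = l_g_total + l_gan_w * cri_gan(netD(fake_H), real_label)
        l_g_total.backward()
        optimizer_G.step()

    # D update on the real batch and on the detached fake
    optimizer_D.zero_grad()
    l_d_total = cri_gan(netD(var_H), real_label) + \
                cri_gan(netD(fake_H.detach()), fake_label)
    l_d_total.backward()
    optimizer_D.step()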