예제 #1
0
 def __init__(self, opt):
     """Create the SPADE network bundle and, when training, the evaluation helpers.

     Parameters:
         opt -- experiment options. This method reads gpu_ids and isTrain, and
                in training mode also no_fid, no_mIoU, dataroot, drn_path and
                real_stat_path.
     """
     super(SPADEModel, self).__init__(opt)
     self.model_names = ['G']
     self.visual_names = ['labels', 'fake_B', 'real_B']

     # Build the bundle on the target device, then wrap it when GPUs are listed.
     self.modules = SPADEModelModules(opt).to(self.device)
     has_gpus = len(opt.gpu_ids) > 0
     if has_gpus:
         self.modules = DataParallelWithCallback(self.modules,
                                                 device_ids=opt.gpu_ids)
     # Handle to the unwrapped bundle, used to reach its own methods directly.
     self.modules_on_one_gpu = self.modules.module if has_gpus else self.modules

     # Inference only: switch to eval mode and skip all training state.
     if not opt.isTrain:
         self.modules.eval()
         return

     self.model_names.append('D')
     self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
     optimizer_pair = self.modules_on_one_gpu.create_optimizers()
     self.optimizer_G, self.optimizer_D = optimizer_pair
     self.optimizers = [self.optimizer_G, self.optimizer_D]

     # Inception network used for FID, unless FID evaluation is disabled.
     if not opt.no_fid:
         pool_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
         self.inception_model = InceptionV3([pool_block])
         self.inception_model.to(self.device)
         self.inception_model.eval()

     # DRN segmentation network used for mIoU; only built for cityscapes data roots.
     wants_mIoU = 'cityscapes' in opt.dataroot and not opt.no_mIoU
     if wants_mIoU:
         self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
         util.load_network(self.drn_model, opt.drn_path, verbose=False)
         self.drn_model.to(self.device)
         self.drn_model.eval()

     # Evaluation bookkeeping: data source, best-so-far metrics and metric history.
     self.eval_dataloader = create_eval_dataloader(self.opt)
     self.best_fid = 1e9
     self.best_mIoU = -1e9
     self.fids, self.mIoUs = [], []
     self.is_best = False
     self.npz = np.load(opt.real_stat_path)
예제 #2
0
    def __init__(self, opt):
        """Set up the SPADE distiller: networks, optimizers, evaluation helpers and profiling.

        Parameters:
            opt -- experiment options. This method reads gpu_ids, no_fid,
                   no_mIoU, dataroot, drn_path, real_stat_path, data_height,
                   data_width and data_channel, and checks whether a
                   'distiller' attribute exists.
        """
        super(SPADEModel, self).__init__(opt)
        # Each network name is listed exactly once.
        self.model_names = ['G_student', 'G_teacher', 'D']
        self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
        self.loss_names = [
            'G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake'
        ]
        # NOTE(review): the network bundle is only built here when opt has a
        # 'distiller' attribute; otherwise self.modules_on_one_gpu must already
        # exist before the code below runs -- confirm against subclasses/callers.
        if hasattr(opt, 'distiller'):
            self.modules = SPADEDistillerModules(opt).to(self.device)
            if len(opt.gpu_ids) > 0:
                self.modules = DataParallelWithCallback(self.modules,
                                                        device_ids=opt.gpu_ids)
                self.modules_on_one_gpu = self.modules.module
            else:
                self.modules_on_one_gpu = self.modules
        # One extra distillation loss name per mapping layer.
        num_mappings = len(self.modules_on_one_gpu.mapping_layers)
        self.loss_names.extend('G_distill%d' % i for i in range(num_mappings))
        self.optimizer_G, self.optimizer_D = self.modules_on_one_gpu.create_optimizers(
        )
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        # Inception network used for FID, unless FID evaluation is disabled.
        if not opt.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()
        # DRN segmentation network used for mIoU; only built for cityscapes data roots.
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()
        # Evaluation bookkeeping: data source, best-so-far metrics and metric history.
        self.eval_dataloader = create_eval_dataloader(self.opt)
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

        # Profile both generators so that their n_macs attributes can be reported.
        for generator in (self.modules_on_one_gpu.netG_teacher,
                          self.modules_on_one_gpu.netG_student):
            model_profiling(generator,
                            self.opt.data_height,
                            self.opt.data_width,
                            channel=self.opt.data_channel,
                            num_forwards=0,
                            verbose=False)
        print(
            f'Teacher FLOPs: {self.modules_on_one_gpu.netG_teacher.n_macs}, Student FLOPs: {self.modules_on_one_gpu.netG_student.n_macs}.'
        )
예제 #3
0
파일: test.py 프로젝트: 25thengineer/DMAD
def test_pix2pix_mIoU(model, opt):
    """Run the model over the validation data, save its outputs and return their mIoU.

    Parameters:
        model -- object providing model_eval, set_input, forward,
                 get_current_visuals, image_paths and device.
        opt   -- experiment options; several fields are overwritten in place
                 (phase, num_threads, batch_size, serial_batches, no_flip,
                 load_size, display_id) before the dataset is created.

    Returns:
        the value returned by get_mIoU for the generated images.
    """
    # Force the evaluation settings on the shared options object.
    forced_settings = (
        ('phase', 'val'),
        ('num_threads', 0),
        ('batch_size', 1),
        ('serial_batches', True),
        ('no_flip', True),
        ('load_size', 256),
        ('display_id', -1),
    )
    for field, value in forced_settings:
        setattr(opt, field, value)
    dataset = create_dataset(opt)
    model.model_eval()

    result_dir = os.path.join(opt.checkpoints_dir, opt.name, 'test_results')
    util.mkdirs(result_dir)

    generated = {}  # first 'A' path of each batch -> generated image
    names = []  # unique file stems, in first-seen order
    for data in dataset:
        model.set_input(data)

        with torch.no_grad():
            model.forward()

        visuals = model.get_current_visuals()
        generated[data['A_paths'][0]] = visuals['fake_B']

        # NOTE(review): every pass reads model.image_paths[0][0] rather than an
        # entry selected by the loop counter -- confirm this matches the
        # container shape of model.image_paths.
        for _ in range(len(model.image_paths)):
            file_name = ntpath.basename(model.image_paths[0][0])
            stem = os.path.splitext(file_name)[0]
            if stem not in names:
                names.append(stem)
        util.save_images(visuals,
                         model.image_paths,
                         result_dir,
                         direction=opt.direction,
                         aspect_ratio=opt.aspect_ratio)

    # Segmentation network that scores the generated images.
    drn_model = DRNSeg('drn_d_105', 19, pretrained=False).to(model.device)
    util.load_network(drn_model, opt.drn_path, verbose=False)
    drn_model.eval()

    table_path = os.path.join(opt.dataroot, 'table.txt')
    return get_mIoU(list(generated.values()),
                    names,
                    drn_model,
                    model.device,
                    table_path=table_path,
                    data_dir=opt.dataroot,
                    batch_size=opt.batch_size,
                    num_workers=opt.num_threads)
예제 #4
0
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Raises:
            NotImplementedError -- if opt.recon_loss_type is not 'l1', 'l2' or 'smooth_l1'.
        """
        assert opt.isTrain
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['G', 'D']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.norm,
                                      opt.dropout_rate,
                                      opt.init_type,
                                      opt.init_gain,
                                      self.gpu_ids,
                                      opt=opt)

        # the discriminator receives input and output images concatenated along the channel axis
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.netD, opt.n_layers_D, opt.norm,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids)

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        else:
            # report the option that was actually checked above
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' %
                opt.recon_loss_type)
        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        self.eval_dataloader = create_eval_dataloader(self.opt)

        # Inception network used when computing FID
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        # DRN segmentation network, only built for cityscapes data roots
        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # evaluation bookkeeping: best-so-far metrics and metric history
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.Tacts, self.Sacts = {}, {}
        self.npz = np.load(opt.real_stat_path)
예제 #5
0
class Pix2PixModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        assert is_train
        parser = super(Pix2PixModel, Pix2PixModel).modify_commandline_options(
            parser, is_train)
        parser.add_argument('--restore_G_path',
                            type=str,
                            default=None,
                            help='the path to restore the generator')
        parser.add_argument('--restore_D_path',
                            type=str,
                            default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--recon_loss_type',
                            type=str,
                            default='l1',
                            choices=['l1', 'l2', 'smooth_l1'],
                            help='the type of the reconstruction loss')
        parser.add_argument('--lambda_recon',
                            type=float,
                            default=100,
                            help='weight for reconstruction loss')
        parser.add_argument('--lambda_gan',
                            type=float,
                            default=1,
                            help='weight for gan loss')
        parser.add_argument(
            '--real_stat_path',
            type=str,
            required=True,
            help=
            'the path to load the groud-truth images information to compute FID.'
        )
        return parser

    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_gan', 'G_recon', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['G', 'D']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.norm,
                                      opt.dropout_rate,
                                      opt.init_type,
                                      opt.init_gain,
                                      self.gpu_ids,
                                      opt=opt)

        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                      opt.netD, opt.n_layers_D, opt.norm,
                                      opt.init_type, opt.init_gain,
                                      self.gpu_ids)

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        else:
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' % opt.loss_type)
        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt)

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.Tacts, self.Sacts = {}, {}
        self.npz = np.load(opt.real_stat_path)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        fake_AB = torch.cat((self.real_A, self.fake_B), 1).detach()
        real_AB = torch.cat((self.real_A, self.real_B), 1).detach()
        pred_fake = self.netD(fake_AB)
        self.loss_D_fake = self.criterionGAN(pred_fake,
                                             False,
                                             for_discriminator=True)

        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real,
                                             True,
                                             for_discriminator=True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_gan = self.criterionGAN(
            pred_fake, True, for_discriminator=False) * self.opt.lambda_gan
        # Second, G(A) = B
        self.loss_G_recon = self.criterionRecon(
            self.fake_B, self.real_B) * self.opt.lambda_recon
        # combine loss and calculate gradients

        self.loss_G = self.loss_G_gan + self.loss_G_recon
        self.loss_G.backward()

    def optimize_parameters(self, steps):
        self.forward()  # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        self.backward_D()  # calculate gradients for D
        self.optimizer_D.step()  # update D's weights
        # update G
        self.set_requires_grad(
            self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()  # set G's gradients to zero
        self.backward_G()  # calculate graidents for G
        self.optimizer_G.step()  # udpate G's weights

    def evaluate_model(self, step):
        self.is_best = False

        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG.eval()

        fakes, names = [], []
        cnt = 0
        for i, data_i in enumerate(tqdm(self.eval_dataloader)):
            self.set_input(data_i)
            self.test()
            fakes.append(self.fake_B.cpu())
            for j in range(len(self.image_paths)):
                short_path = ntpath.basename(self.image_paths[j])
                name = os.path.splitext(short_path)[0]
                names.append(name)
                if cnt < 10:
                    input_im = util.tensor2im(self.real_A[j])
                    real_im = util.tensor2im(self.real_B[j])
                    fake_im = util.tensor2im(self.fake_B[j])
                    util.save_image(input_im,
                                    os.path.join(save_dir, 'input',
                                                 '%s.png' % name),
                                    create_dir=True)
                    util.save_image(real_im,
                                    os.path.join(save_dir, 'real',
                                                 '%s.png' % name),
                                    create_dir=True)
                    util.save_image(fake_im,
                                    os.path.join(save_dir, 'fake',
                                                 '%s.png' % name),
                                    create_dir=True)
                cnt += 1

        fid = get_fid(fakes,
                      self.inception_model,
                      self.npz,
                      device=self.device,
                      batch_size=self.opt.eval_batch_size)
        if fid < self.best_fid:
            self.is_best = True
            self.best_fid = fid
        self.fids.append(fid)
        if len(self.fids) > 3:
            self.fids.pop(0)

        ret = {
            'metric/fid': fid,
            'metric/fid-mean': sum(self.fids) / len(self.fids),
            'metric/fid-best': self.best_fid
        }
        if 'cityscapes' in self.opt.dataroot:
            mIoU = get_mIoU(fakes,
                            names,
                            self.drn_model,
                            self.device,
                            table_path=self.opt.table_path,
                            data_dir=self.opt.cityscapes_path,
                            batch_size=self.opt.eval_batch_size,
                            num_workers=self.opt.num_threads)
            if mIoU > self.best_mIoU:
                self.is_best = True
                self.best_mIoU = mIoU
            self.mIoUs.append(mIoU)
            if len(self.mIoUs) > 3:
                self.mIoUs = self.mIoUs[1:]
            ret['metric/mIoU'] = mIoU
            ret['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
            ret['metric/mIoU-best'] = self.best_mIoU

        self.netG.train()
        return ret
예제 #6
0
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the CycleGAN training options and override several defaults.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- must be True; only training options are registered here.

        Returns:
            the modified parser.

        Naming used by the options:
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        lambda_A weights the forward cycle loss ||G_B(G_A(A)) - A||,
        lambda_B weights the backward cycle loss ||G_A(G_B(B)) - B||,
        lambda_identity scales the optional identity losses.
        """
        assert is_train
        parser = super(CycleGANModel,
                       CycleGANModel).modify_commandline_options(
                           parser, is_train)

        # Checkpoint-restore paths, registered in the order G_A, D_A, G_B, D_B.
        restorable = (('G_A', 'generator'), ('D_A', 'discriminator'),
                      ('G_B', 'generator'), ('D_B', 'discriminator'))
        for net_name, role in restorable:
            parser.add_argument('--restore_%s_path' % net_name,
                                type=str,
                                default=None,
                                help='the path to restore the %s %s' %
                                (role, net_name))

        # Cycle-consistency weights for both directions.
        for domain, cycle in (('A', 'A -> B -> A'), ('B', 'B -> A -> B')):
            parser.add_argument('--lambda_%s' % domain,
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (%s)' % cycle)

        parser.add_argument(
            '--lambda_identity',
            type=float,
            default=0.5,
            help='use identity mapping. '
            'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
            'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
        )

        # Statistics files of the real images, one per domain, used for FID.
        for domain in ('A', 'B'):
            parser.add_argument(
                '--real_stat_%s_path' % domain,
                type=str,
                required=True,
                help=
                'the path to load the ground-truth %s images information to compute FID.'
                % domain)

        overridden_defaults = dict(norm='instance',
                                   dataset_mode='unaligned',
                                   batch_size=1,
                                   ndf=64,
                                   gan_mode='lsgan',
                                   nepochs=100,
                                   nepochs_decay=100,
                                   save_epoch_freq=20)
        parser.set_defaults(**overridden_defaults)
        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Only training with direction 'AtoB' and dataset_mode 'unaligned' is
        accepted (checked with assertions).
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
            assert (opt.input_nc == opt.output_nc)
        # image buffers that store previously generated images
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # define loss functions
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(
            self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        # one optimizer drives both generators, another drives both discriminators
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = [self.optimizer_G, self.optimizer_D]

        # one evaluation dataloader per translation direction
        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        # Inception network used when computing FID
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        # DRN segmentation network, only built for cityscapes data roots
        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                # DataParallel expects the wrapped module's parameters on
                # device_ids[0], so move the network before wrapping it.
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # evaluation bookkeeping: best-so-far metrics and metric history
        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)

    def set_input(self, input):
        """Move one batch of both domains onto the model device.

        Parameters:
            input (dict): must provide the entries 'A' and 'B'; each is moved to
                self.device and stored as self.real_A / self.real_B.

        The 'direction' option is not consulted here: both domains are always
        loaded because the model is a cycle.
        """
        for domain in ('A', 'B'):
            setattr(self, 'real_' + domain, input[domain].to(self.device))

    def set_single_input(self, input):
        """Load only the domain-A part of a batch.

        Parameters:
            input (dict): must provide 'A' (moved to self.device and stored as
                self.real_A) and 'A_paths' (stored as self.image_paths).
        """
        source_batch = input['A']
        self.real_A = source_batch.to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Translate both domains and cycle each result back.

        Stores fake_B = G_A(A), rec_A = G_B(G_A(A)), fake_A = G_B(B) and
        rec_B = G_A(G_B(B)) on the instance.
        """
        a_to_b, b_to_a = self.netG_A, self.netG_B
        # A -> B -> A
        self.fake_B = a_to_b(self.real_A)
        self.rec_A = b_to_a(self.fake_B)
        # B -> A -> B
        self.fake_A = b_to_a(self.real_B)
        self.rec_B = a_to_b(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Compute one discriminator's GAN loss and backpropagate it.

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Returns the discriminator loss: half the sum of the loss on real
        images and the loss on fake images. backward() is called on it before
        returning.
        """
        real_term = self.criterionGAN(netD(real), True)
        # the fake images are detached so that no gradient reaches the generator
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        total = (real_term + fake_term) * 0.5
        total.backward()
        return total

    def backward_D_A(self):
        """Update loss_D_A: D_A judges real_B against a pooled fake_B."""
        # the image pool decides which generated image the discriminator sees
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B,
                                              pooled_fake)

    def backward_D_B(self):
        """Update loss_D_B: D_B judges real_A against a pooled fake_A."""
        # the image pool decides which generated image the discriminator sees
        pooled_fake = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A,
                                              pooled_fake)

    def backward_G(self):
        """Build the joint objective of G_A and G_B and back-propagate it.

        The objective adds adversarial, cycle-consistency and (optional) identity
        terms. Every term is also stored on the instance as ``loss_G_*``, and the
        total as ``loss_G``.
        """
        opt = self.opt
        w_idt = opt.lambda_identity
        w_a = opt.lambda_A
        w_b = opt.lambda_B

        # Identity terms are evaluated only when their weight is positive.
        if w_idt > 0:
            # ||G_A(B) - B||: G_A should leave real_B unchanged.
            self.idt_A = self.netG_A(self.real_B)
            idt_err_a = self.criterionIdt(self.idt_A, self.real_B)
            self.loss_G_idt_A = idt_err_a * w_b * w_idt
            # ||G_B(A) - A||: G_B should leave real_A unchanged.
            self.idt_B = self.netG_B(self.real_A)
            idt_err_b = self.criterionIdt(self.idt_B, self.real_A)
            self.loss_G_idt_B = idt_err_b * w_a * w_idt
        else:
            self.loss_G_idt_A = 0
            self.loss_G_idt_B = 0

        # Adversarial terms: D_A(G_A(A)) and D_B(G_B(B)) should look real.
        score_fake_b = self.netD_A(self.fake_B)
        self.loss_G_A = self.criterionGAN(score_fake_b, True)
        score_fake_a = self.netD_B(self.fake_A)
        self.loss_G_B = self.criterionGAN(score_fake_a, True)

        # Cycle terms: ||G_B(G_A(A)) - A|| and ||G_A(G_B(B)) - B||.
        cycle_err_a = self.criterionCycle(self.rec_A, self.real_A)
        self.loss_G_cycle_A = cycle_err_a * w_a
        cycle_err_b = self.criterionCycle(self.rec_B, self.real_B)
        self.loss_G_cycle_B = cycle_err_b * w_b

        # Total objective, followed by gradient computation.
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_G_cycle_A + self.loss_G_cycle_B
                       + self.loss_G_idt_A + self.loss_G_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: forward pass, generator update, then discriminator update."""
        # Produce the fake and reconstructed images for both domains.
        self.forward()

        # --- generator step (G_A and G_B) ---
        # The discriminators are frozen so no gradients are computed for them here.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        g_optim = self.optimizer_G
        g_optim.zero_grad()
        self.backward_G()
        g_optim.step()

        # --- discriminator step (D_A and D_B) ---
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        d_optim = self.optimizer_D
        d_optim.zero_grad()
        self.backward_D_A()  # gradients for D_A
        self.backward_D_B()  # gradients for D_B
        d_optim.step()

    def test_single_side(self, direction):
        generator = getattr(self, 'netG_%s' % direction[0])
        with torch.no_grad():
            self.fake_B = generator(self.real_A)

    def evaluate_model(self, step):
        """Evaluate both translation directions and return the metrics as a dict.

        For each of 'AtoB' and 'BtoA' the matching evaluation dataloader is run
        through the corresponding generator, the first ten input/fake pairs are
        written under ``<log_dir>/eval/<step>/<direction>/``, and the FID is
        computed against the stored statistics of the target domain. For
        Cityscapes data the 'BtoA' direction additionally reports mIoU.
        ``self.is_best`` is set when any tracked metric improves. Both generators
        are put in eval mode for the duration and switched back to train mode
        before returning.

        Parameters:
            step: identifier used for the output sub-directory name.
        """
        metrics = {}
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG_A.eval()
        self.netG_B.eval()
        for direction in ['AtoB', 'BtoA']:
            eval_dataloader = getattr(self, 'eval_dataloader_' + direction)
            fakes, names = [], []
            n_seen = 0  # images processed so far in this direction
            for data_i in tqdm(eval_dataloader):
                self.set_single_input(data_i)
                self.test_single_side(direction)
                fakes.append(self.fake_B.cpu())
                for j, image_path in enumerate(self.image_paths):
                    name = os.path.splitext(ntpath.basename(image_path))[0]
                    names.append(name)
                    # Only the first ten samples of each direction are dumped to disk.
                    if n_seen < 10:
                        input_im = util.tensor2im(self.real_A[j])
                        fake_im = util.tensor2im(self.fake_B[j])
                        filename = '%s.png' % name
                        util.save_image(
                            input_im,
                            os.path.join(save_dir, direction, 'input', filename),
                            create_dir=True)
                        util.save_image(
                            fake_im,
                            os.path.join(save_dir, direction, 'fake', filename),
                            create_dir=True)
                    n_seen += 1

            # Per-domain state (npz_X, best_fid_X, fids_X) is keyed by the target letter.
            suffix = direction[-1]
            best_attr = 'best_fid_%s' % suffix
            fid = get_fid(fakes,
                          self.inception_model,
                          getattr(self, 'npz_%s' % suffix),
                          device=self.device,
                          batch_size=self.opt.eval_batch_size)
            if fid < getattr(self, best_attr):
                self.is_best = True
                setattr(self, best_attr, fid)
            # Keep a sliding window of the three most recent FIDs.
            recent_fids = getattr(self, 'fids_%s' % suffix)
            recent_fids.append(fid)
            if len(recent_fids) > 3:
                del recent_fids[0]
            metrics['metric/fid_%s' % suffix] = fid
            metrics['metric/fid_%s-mean' % suffix] = sum(recent_fids) / len(recent_fids)
            metrics['metric/fid_%s-best' % suffix] = getattr(self, best_attr)

            if 'cityscapes' in self.opt.dataroot and direction == 'BtoA':
                mIoU = get_mIoU(fakes,
                                names,
                                self.drn_model,
                                self.device,
                                table_path=self.opt.table_path,
                                data_dir=self.opt.cityscapes_path,
                                batch_size=self.opt.eval_batch_size,
                                num_workers=self.opt.num_threads)
                if mIoU > self.best_mIoU:
                    self.is_best = True
                    self.best_mIoU = mIoU
                # Sliding window of the three most recent mIoU values.
                self.mIoUs.append(mIoU)
                if len(self.mIoUs) > 3:
                    self.mIoUs = self.mIoUs[1:]
                metrics['metric/mIoU'] = mIoU
                metrics['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
                metrics['metric/mIoU-best'] = self.best_mIoU

        self.netG_A.train()
        self.netG_B.train()
        return metrics
예제 #7
0
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Builds both generators and both discriminators, the image pools, the loss
        criteria and the optimizers, followed by the evaluation state: one
        evaluation dataloader per direction, the Inception network used for FID,
        the DRN segmenter used for Cityscapes mIoU, the best/recent metric
        trackers and the real-image statistics of both domains.
        """
        # This model is training-only and expects unaligned data in the AtoB direction.
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(
            opt.pool_size
        )  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(
            opt.pool_size
        )  # create image buffer to store previously generated images

        # define loss functions
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(
            self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # One evaluation dataloader per translation direction.
        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        # Inception network (2048-dim block) used to compute FID during evaluation.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            # DRN segmenter used to compute mIoU during evaluation.
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                # nn.DataParallel requires the wrapped module's parameters and
                # buffers to be on device_ids[0] before it is run, so move the
                # segmenter first (self.device is presumably that device -- it
                # is set by BaseModel; confirm there).
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # Metric trackers: best values so far and lists of recent values.
        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best = False
        # Pre-computed real-image statistics for FID, one file per domain.
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)
예제 #8
0
        raise NotImplementedError
    configs = get_configs(config_name=opt.config_set)
    configs = list(configs.all_configs())

    dataloader = create_dataloader(opt)
    model = create_model(opt)
    model.setup(opt)
    device = model.device

    if not opt.no_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx])
        inception_model.to(device)
        inception_model.eval()
    if 'cityscapes' in opt.dataroot and opt.direction == 'BtoA':
        drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
        drn_model.eval()

    npz = np.load(opt.real_stat_path)
    results = []

    for data_i in dataloader:
        model.set_input(data_i)
        break

    for config in tqdm.tqdm(configs):
        qualified = True
        macs, _ = model.profile(config)
예제 #9
0
class SPADEModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the SPADE-specific command-line options on ``parser``.

        Parameters:
            parser (argparse.ArgumentParser): parser to extend in place.
            is_train (bool): when true, the training-only options (restore paths,
                loss weights, evaluation switches) and training defaults are added.

        Returns the parser after ``networks.modify_commandline_options`` has
        processed it as well.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.set_defaults(netG='inception_spade')
        parser.add_argument(
            '--norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help='instance normalization or batch normalization')
        parser.add_argument(
            '--num_upsampling_layers',
            choices=('normal', 'more', 'most'),
            default='more',
            help=
            "If 'more', adds upsampling layer between the two middle resnet blocks. "
            "If 'most', also add one more upsampling + resnet layer at the end of the generator"
        )
        if is_train:
            parser.add_argument('--restore_G_path',
                                type=str,
                                default=None,
                                help='the path to restore the generator')
            parser.add_argument('--restore_D_path',
                                type=str,
                                default=None,
                                help='the path to restore the discriminator')
            parser.add_argument(
                '--real_stat_path',
                type=str,
                required=True,
                help=
                'the path to load the ground-truth images information to compute FID.'
            )
            parser.add_argument('--lambda_gan',
                                type=float,
                                default=1,
                                help='weight for gan loss')
            parser.add_argument('--lambda_feat',
                                type=float,
                                default=10,
                                help='weight for gan feature loss')
            parser.add_argument('--lambda_vgg',
                                type=float,
                                default=10,
                                help='weight for vgg loss')
            parser.add_argument('--beta2',
                                type=float,
                                default=0.999,
                                help='momentum term of adam')
            # The flag *disables* TTUR: when it is set, load_networks gives G and D
            # the same learning rate instead of lr / 2 and lr * 2.
            parser.add_argument('--no_TTUR',
                                action='store_true',
                                help='Do not use the TTUR training scheme '
                                '(G and D share the same learning rate)')
            parser.add_argument('--no_fid',
                                action='store_true',
                                help='No FID evaluation during training')
            parser.add_argument('--no_mIoU',
                                action='store_true',
                                help='No mIoU evaluation during training '
                                '(sometimes needed because of CUDA memory limits)')
            parser.set_defaults(netD='multi_scale',
                                ndf=64,
                                dataset_mode='cityscapes',
                                batch_size=16,
                                print_freq=50,
                                save_latest_freq=10000000000,
                                save_epoch_freq=10,
                                nepochs=100,
                                nepochs_decay=100,
                                init_type='xavier',
                                active_fn='nn.LeakyReLU')
        parser = networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        """Build the SPADE modules and, when training, the optimizers and evaluation state.

        Parameters:
            opt: parsed options; read here for ``gpu_ids``, ``isTrain``, ``no_fid``,
                ``no_mIoU``, ``dataroot``, ``drn_path`` and ``real_stat_path``.
        """
        super(SPADEModel, self).__init__(opt)
        self.model_names = ['G']
        self.visual_names = ['labels', 'fake_B', 'real_B']
        self.modules = SPADEModelModules(opt).to(self.device)

        # Keep a handle on the unwrapped modules next to the (possibly) parallel wrapper.
        use_gpus = len(opt.gpu_ids) > 0
        if use_gpus:
            self.modules = DataParallelWithCallback(self.modules,
                                                    device_ids=opt.gpu_ids)
            self.modules_on_one_gpu = self.modules.module
        else:
            self.modules_on_one_gpu = self.modules

        # Inference only: switch to eval mode and stop here.
        if not opt.isTrain:
            self.modules.eval()
            return

        self.model_names.append('D')
        self.loss_names = ['G_gan', 'G_feat', 'G_vgg', 'D_real', 'D_fake']
        optimizer_pair = self.modules_on_one_gpu.create_optimizers()
        self.optimizer_G, self.optimizer_D = optimizer_pair
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        if not opt.no_fid:
            # Inception network (2048-dim block) used for FID.
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()

        wants_mIoU = 'cityscapes' in opt.dataroot and not opt.no_mIoU
        if wants_mIoU:
            # DRN segmenter used for mIoU.
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()

        self.eval_dataloader = create_eval_dataloader(self.opt)
        # Metric trackers: best values so far and lists of recent values.
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.npz = np.load(opt.real_stat_path)

    def set_input(self, input):
        """Store a batch and derive the generator input from it.

        Parameters:
            input (dict): read here for 'path' and 'label'; also passed on to
                ``preprocess_input``.
        """
        self.data = input
        self.image_paths = input['path']
        self.labels = input['label'].to(self.device)
        semantics, real_image = self.preprocess_input(input)
        self.input_semantics = semantics
        self.real_B = real_image

    def test(self):
        """Run the forward pass on the unwrapped modules with gradient tracking disabled."""
        grad_off = torch.no_grad()
        with grad_off:
            self.forward(on_one_gpu=True)

    def preprocess_input(self, data):
        """Move a batch to ``self.device`` and build the one-hot semantic input.

        ``data`` is modified in place: 'label' is cast to long, and 'label',
        'instance' and 'image' are moved to the device. The label map is scattered
        along dim 1 into a zero tensor with ``input_nc`` channels (one more when
        ``contain_dontcare_label`` is set); unless ``no_instance`` is set, the
        instance edge map is concatenated as additional channels.

        Returns:
            (input_semantics, image): the generator input and ``data['image']``.
        """
        device = self.device
        data['label'] = data['label'].long().to(device)
        for key in ('instance', 'image'):
            data[key] = data[key].to(device)

        label_map = data['label']
        batch, _, height, width = label_map.size()
        n_channels = self.opt.input_nc
        if self.opt.contain_dontcare_label:
            n_channels += 1

        # One-hot encode: write 1.0 at each pixel's label index along the channel dim.
        one_hot = torch.zeros([batch, n_channels, height, width], device=device)
        input_semantics = one_hot.scatter_(1, label_map, 1.0)

        if not self.opt.no_instance:
            edge_map = self.get_edges(data['instance'])
            input_semantics = torch.cat((input_semantics, edge_map), dim=1)

        return input_semantics, data['image']

    def forward(self, on_one_gpu=False):
        """Generate ``self.fake_B`` from ``self.input_semantics``.

        Parameters:
            on_one_gpu (bool): use the unwrapped modules instead of ``self.modules``.
        """
        runner = self.modules_on_one_gpu if on_one_gpu else self.modules
        self.fake_B = runner(self.input_semantics)

    def get_edges(self, t):
        """Return a float map that is 1 where ``t`` differs from a neighbour along the last two dims.

        Both elements of every differing horizontal or vertical pair are marked.
        The indexing requires ``t`` to be 4-dimensional.
        """
        # Differences between neighbours along the last dim / the second-to-last dim.
        horizontal = (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
        vertical = (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()

        edge = torch.zeros(t.size(), dtype=torch.uint8, device=self.device)
        # Mark both members of each differing pair.
        edge[:, :, :, 1:] |= horizontal
        edge[:, :, :, :-1] |= horizontal
        edge[:, :, 1:, :] |= vertical
        edge[:, :, :-1, :] |= vertical
        return edge.float()

    def profile(self, verbose=True):
        """Profile the unwrapped modules on the first sample of the current input.

        Parameters:
            verbose (bool): print the MACs (in G) and parameter count (in M).

        Returns:
            (macs, params) as reported by ``modules_on_one_gpu.profile``.
        """
        first_sample = self.input_semantics[:1]
        macs, params = self.modules_on_one_gpu.profile(first_sample)
        if verbose:
            summary = 'MACs: %.3fG\tParams: %.3fM' % (macs / 1e9, params / 1e6)
            print(summary, flush=True)
        return macs, params

    def backward_G(self):
        """Compute the generator losses through ``self.modules`` and back-propagate the total.

        Every entry of ``loss_names`` starting with 'G' is stored, detached and
        averaged, as ``self.loss_<name>``.
        """
        outputs = self.modules(self.input_semantics, self.real_B, mode='G_loss')
        total = outputs['loss_G'].mean()
        generator_terms = [n for n in self.loss_names if n.startswith('G')]
        for term in generator_terms:
            setattr(self, 'loss_' + term, outputs[term].detach().mean())
        total.backward()

    def backward_D(self):
        """Compute the discriminator losses through ``self.modules`` and back-propagate the total.

        Every entry of ``loss_names`` starting with 'D' is stored, detached and
        averaged, as ``self.loss_<name>``.
        """
        outputs = self.modules(self.input_semantics, self.real_B, mode='D_loss')
        total = outputs['loss_D'].mean()
        discriminator_terms = [n for n in self.loss_names if n.startswith('D')]
        for term in discriminator_terms:
            setattr(self, 'loss_' + term, outputs[term].detach().mean())
        total.backward()

    def optimize_parameters(self, steps):
        """Run one training iteration: generator update first, then discriminator update.

        Parameters:
            steps: accepted for interface compatibility; not used in this method.
        """
        # --- generator step: the discriminator is frozen meanwhile ---
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        g_optim = self.optimizer_G
        g_optim.zero_grad()
        self.backward_G()
        g_optim.step()

        # --- discriminator step ---
        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        d_optim = self.optimizer_D
        d_optim.zero_grad()
        self.backward_D()
        d_optim.step()

    def evaluate_model(self, step, save_image=False):
        """Run the generator over the evaluation set and return the metrics as a dict.

        Input/real/fake images are written under ``<log_dir>/eval/<step>/`` for
        the first ten samples, or for every sample when ``save_image`` is set.
        FID is reported unless ``opt.no_fid``; mIoU is reported for Cityscapes
        data unless ``opt.no_mIoU``. ``self.is_best`` is set when any reported
        metric improves. The generator is put in eval mode for the duration and
        switched back to train mode before returning.

        Parameters:
            step: identifier used for the output sub-directory name.
            save_image (bool): dump every sample instead of only the first ten.
        """
        self.is_best = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.modules_on_one_gpu.netG.eval()
        torch.cuda.empty_cache()

        fakes, names = [], []
        metrics = {}
        n_seen = 0  # images processed so far
        for data_i in tqdm(self.eval_dataloader):
            self.set_input(data_i)
            self.test()
            fakes.append(self.fake_B.cpu())
            for j, image_path in enumerate(self.image_paths):
                name = os.path.splitext(ntpath.basename(image_path))[0]
                names.append(name)
                if n_seen < 10 or save_image:
                    filename = '%s.png' % name
                    visuals = (
                        ('input', util.tensor2label(self.input_semantics[j],
                                                    self.opt.input_nc + 2)),
                        ('real', util.tensor2im(self.real_B[j])),
                        ('fake', util.tensor2im(self.fake_B[j])),
                    )
                    for folder, image in visuals:
                        util.save_image(image,
                                        os.path.join(save_dir, folder, filename),
                                        create_dir=True)
                n_seen += 1

        if not self.opt.no_fid:
            fid = get_fid(fakes,
                          self.inception_model,
                          self.npz,
                          device=self.device,
                          batch_size=self.opt.eval_batch_size)
            if fid < self.best_fid:
                self.is_best = True
                self.best_fid = fid
            # Sliding window of the three most recent FIDs.
            self.fids.append(fid)
            if len(self.fids) > 3:
                del self.fids[0]
            metrics['metric/fid'] = fid
            metrics['metric/fid-mean'] = sum(self.fids) / len(self.fids)
            metrics['metric/fid-best'] = self.best_fid

        if 'cityscapes' in self.opt.dataroot and not self.opt.no_mIoU:
            mIoU = get_mIoU(fakes,
                            names,
                            self.drn_model,
                            self.device,
                            table_path=self.opt.table_path,
                            data_dir=self.opt.cityscapes_path,
                            batch_size=self.opt.eval_batch_size,
                            num_workers=self.opt.num_threads)
            if mIoU > self.best_mIoU:
                self.is_best = True
                self.best_mIoU = mIoU
            # Sliding window of the three most recent mIoU values.
            self.mIoUs.append(mIoU)
            if len(self.mIoUs) > 3:
                self.mIoUs = self.mIoUs[1:]
            metrics['metric/mIoU'] = mIoU
            metrics['metric/mIoU-mean'] = sum(self.mIoUs) / len(self.mIoUs)
            metrics['metric/mIoU-best'] = self.best_mIoU

        self.modules_on_one_gpu.netG.train()
        torch.cuda.empty_cache()
        return metrics

    def print_networks(self):
        """Print every network in ``model_names`` together with its parameter count.

        When the options carry a ``log_dir``, the same text is also written to
        ``<log_dir>/net<name>.txt``.
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self.modules_on_one_gpu, 'net' + name)
            num_params = sum(param.numel() for param in net.parameters())
            summary = ('[Network %s] Total number of parameters : %.3f M' %
                       (name, num_params / 1e6))
            print(net)
            print(summary)
            if hasattr(self.opt, 'log_dir'):
                log_path = os.path.join(self.opt.log_dir, 'net' + name + '.txt')
                with open(log_path, 'w') as f:
                    f.write(str(net) + '\n')
                    f.write(summary + '\n')
        print('-----------------------------------------------')

    def load_networks(self,
                      verbose=True,
                      teacher_only=False,
                      restore_pretrain=True):
        """Load the network weights and, when requested, the optimizer states.

        Optimizer states are read from ``<restore_O_path>-<index>.pth`` only in
        training mode with ``opt.restore_O_path`` set. The learning rates are then
        reset from ``opt.lr``: equal for G and D when ``opt.no_TTUR`` is set,
        otherwise ``lr / 2`` for G and ``lr * 2`` for D.

        Parameters:
            verbose (bool): forwarded to the loaders.
            teacher_only, restore_pretrain: accepted for interface compatibility;
                not used in this method.
        """
        self.modules_on_one_gpu.load_networks(verbose)
        if not (self.isTrain and self.opt.restore_O_path is not None):
            return

        for i, optimizer in enumerate(self.optimizers):
            state_path = '%s-%d.pth' % (self.opt.restore_O_path, i)
            util.load_optimizer(optimizer, state_path, verbose)

        # Reset the learning rates after the optimizer states have been loaded.
        if self.opt.no_TTUR:
            G_lr, D_lr = self.opt.lr, self.opt.lr
        else:
            G_lr, D_lr = self.opt.lr / 2, self.opt.lr * 2
        for optimizer, lr in ((self.optimizer_G, G_lr), (self.optimizer_D, D_lr)):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def get_current_visuals(self):
        """Return an OrderedDict of the entries of ``visual_names`` currently set on the instance.

        Names that are not strings, or that have no matching attribute, are skipped.
        """
        return OrderedDict(
            (name, getattr(self, name))
            for name in self.visual_names
            if isinstance(name, str) and hasattr(self, name))

    def save_networks(self, epoch):
        """Save the networks and every optimizer state for ``epoch`` into ``self.save_dir``.

        Optimizer states are written to ``<epoch>_optim-<index>.pth``.
        """
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for index, optimizer in enumerate(self.optimizers):
            target = os.path.join(self.save_dir,
                                  '%s_optim-%d.pth' % (epoch, index))
            torch.save(optimizer.state_dict(), target)
예제 #10
0
    def __init__(self, opt):
        """Build the teacher/student generators, the discriminator and the distillation state.

        Parameters:
            opt: parsed training options (``opt.isTrain`` must be true).

        Raises:
            NotImplementedError: for an unknown ``opt.dataset_mode`` or
                ``opt.recon_loss_type``.
        """
        assert opt.isTrain
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        self.netG_teacher = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf,
                                              opt.teacher_netG, opt.norm, opt.teacher_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_student = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                              opt.student_netG, opt.norm, opt.student_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        # Extra student copy without the 'super_' prefix, built only when channel sorting
        # is requested and a student checkpoint is given.
        if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
            self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                                      opt.student_netG.replace('super_', ''), opt.norm,
                                                      opt.student_dropout_rate, opt.init_type, opt.init_gain,
                                                      self.gpu_ids, opt=opt)
        if hasattr(opt, 'distiller'):
            self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
                                                     opt.pretrained_netG, opt.norm, 0,
                                                     opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        # In aligned mode the discriminator is built for input_nc + output_nc channels,
        # in unaligned mode for output_nc channels only.
        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' % opt.dataset_mode)

        self.netG_teacher.eval()
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            # Report the option that was actually tested above (recon_loss_type), so this
            # path raises the intended NotImplementedError with the offending value.
            raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        # Names of the layers used for distillation; prefixed with 'module.' when the
        # teacher is wrapped in nn.DataParallel.
        if isinstance(self.netG_teacher, nn.DataParallel):
            self.mapping_layers = ['module.model.%d' % i for i in range(9, 21, 3)]
        else:
            self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]

        self.netAs = []
        self.Tacts, self.Sacts = {}, {}

        # One 1x1 adaptor per mapping layer (student_ngf * 4 -> teacher_ngf * 4 channels);
        # the adaptors are optimized together with the student generator.
        G_params = [self.netG_student.parameters()]
        for i, n in enumerate(self.mapping_layers):
            ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
            if hasattr(opt, 'distiller'):
                netA = nn.Conv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            else:
                netA = SuperConv2d(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1). \
                    to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)

        # Inception network (2048-dim block) used to compute FID during evaluation.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            # DRN segmenter for the Cityscapes evaluation.
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # Pre-computed real-image statistics for FID.
        self.npz = np.load(opt.real_stat_path)
        self.is_best = False
예제 #11
0
class BaseResnetDistiller(BaseModel):
    """Common training scaffold for distilling a teacher generator into a student.

    The constructor builds the teacher and student generators, the
    discriminator, one 1x1 adaptor convolution per distilled layer, the loss
    criteria, the optimizers and the evaluation helpers.  ``forward``,
    ``calc_distill_loss``, ``backward_G``, ``optimize_parameters`` and
    ``evaluate_model`` raise ``NotImplementedError`` here and have to be
    provided by subclasses.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the distillation-specific command line options.

        Args:
            parser: argument parser to extend; it is first passed through the
                parent class's ``modify_commandline_options``.
            is_train: must be truthy (asserted), the distiller is training-only.

        Returns:
            The extended parser.
        """
        assert is_train
        parser = super(BaseResnetDistiller, BaseResnetDistiller).modify_commandline_options(parser, is_train)
        generator_choices = ['resnet_9blocks', 'mobile_resnet_9blocks',
                             'super_mobile_resnet_9blocks', 'sub_mobile_resnet_9blocks']
        parser.add_argument('--teacher_netG', type=str, default='mobile_resnet_9blocks',
                            help='specify teacher generator architecture',
                            choices=generator_choices)
        parser.add_argument('--student_netG', type=str, default='mobile_resnet_9blocks',
                            help='specify student generator architecture',
                            choices=generator_choices)
        parser.add_argument('--teacher_ngf', type=int, default=64,
                            help='the base number of filters of the teacher generator')
        parser.add_argument('--student_ngf', type=int, default=48,
                            help='the base number of filters of the student generator')
        parser.add_argument('--restore_teacher_G_path', type=str, required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path', type=str, default=None,
                            help='the path to restore the student generator')
        parser.add_argument('--restore_A_path', type=str, default=None,
                            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path', type=str, default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path', type=str, default=None,
                            help='the path to restore the optimizer')
        parser.add_argument('--recon_loss_type', type=str, default='l1',
                            choices=['l1', 'l2', 'smooth_l1', 'vgg'],
                            help='the type of the reconstruction loss')
        parser.add_argument('--lambda_distill', type=float, default=1,
                            help='weights for the intermediate activation distillation loss')
        parser.add_argument('--lambda_recon', type=float, default=100,
                            help='weights for the reconstruction loss.')
        parser.add_argument('--lambda_gan', type=float, default=1,
                            help='weight for gan loss')
        parser.add_argument('--teacher_dropout_rate', type=float, default=0)
        parser.add_argument('--student_dropout_rate', type=float, default=0)
        return parser

    def __init__(self, opt):
        """Build networks, adaptors, criteria, optimizers and evaluation helpers.

        Args:
            opt: parsed options; ``opt.isTrain`` must be truthy (asserted).

        Raises:
            NotImplementedError: if ``opt.dataset_mode`` is neither ``'aligned'``
                nor ``'unaligned'``, or ``opt.recon_loss_type`` is not one of
                ``'l1'``, ``'l2'``, ``'smooth_l1'``, ``'vgg'``.
        """
        assert opt.isTrain
        super(BaseResnetDistiller, self).__init__(opt)
        self.loss_names = ['G_gan', 'G_distill', 'G_recon', 'D_fake', 'D_real']
        self.optimizers = []
        self.image_paths = []
        self.visual_names = ['real_A', 'Sfake_B', 'Tfake_B', 'real_B']
        self.model_names = ['netG_student', 'netG_teacher', 'netD']
        self.netG_teacher = networks.define_G(opt.input_nc, opt.output_nc, opt.teacher_ngf,
                                              opt.teacher_netG, opt.norm, opt.teacher_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
        self.netG_student = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                              opt.student_netG, opt.norm, opt.student_dropout_rate,
                                              opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        # Optional extra copy of the student built with the 'super_' prefix
        # stripped from the architecture name; load_networks() restores it from
        # the same checkpoint as the student.
        if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
            self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
                                                      opt.student_netG.replace('super_', ''), opt.norm,
                                                      opt.student_dropout_rate, opt.init_type, opt.init_gain,
                                                      self.gpu_ids, opt=opt)
        # The mere presence of a 'distiller' attribute on opt switches on the
        # additional pretrained generator (dropout rate fixed to 0).
        use_plain_distiller = hasattr(opt, 'distiller')
        if use_plain_distiller:
            self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
                                                     opt.pretrained_netG, opt.norm, 0,
                                                     opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

        # For aligned data the discriminator receives input and output
        # concatenated on dim 1 (see backward_D), hence the summed channel count.
        if opt.dataset_mode == 'aligned':
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        elif opt.dataset_mode == 'unaligned':
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            raise NotImplementedError('Unknown dataset mode [%s]!!!' % opt.dataset_mode)

        self.netG_teacher.eval()
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        elif opt.recon_loss_type == 'vgg':
            self.criterionRecon = models.modules.loss.VGGLoss(self.device)
        else:
            # Report the option that was actually tested above; the previous
            # code formatted opt.loss_type, which is not the attribute read here.
            raise NotImplementedError('Unknown reconstruction loss type [%s]!' % opt.recon_loss_type)

        # Names of the generator sub-modules whose outputs are distilled.  A
        # DataParallel wrapper adds a 'module.' prefix to every sub-module name.
        prefix = 'module.' if isinstance(self.netG_teacher, nn.DataParallel) else ''
        self.mapping_layers = ['%smodel.%d' % (prefix, i) for i in range(9, 21, 3)]

        self.netAs = []
        # Filled by the forward hooks registered in setup().
        self.Tacts, self.Sacts = {}, {}

        # One 1x1 adaptor per mapped layer, mapping student_ngf * 4 channels to
        # teacher_ngf * 4 channels; the adaptors are optimized together with
        # the student generator.
        ft, fs = self.opt.teacher_ngf, self.opt.student_ngf
        adaptor_cls = nn.Conv2d if use_plain_distiller else SuperConv2d
        G_params = [self.netG_student.parameters()]
        for i in range(len(self.mapping_layers)):
            netA = adaptor_cls(in_channels=fs * 4, out_channels=ft * 4, kernel_size=1).to(self.device)
            networks.init_net(netA)
            G_params.append(netA.parameters())
            self.netAs.append(netA)
            self.loss_names.append('G_distill%d' % i)

        self.optimizer_G = torch.optim.Adam(itertools.chain(*G_params), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader = create_eval_dataloader(self.opt, direction=opt.direction)

        # Inception network kept in eval mode for use by evaluation code.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        # Segmentation network, only built when the data root mentions cityscapes.
        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # Presumably precomputed statistics of the real images -- TODO confirm.
        self.npz = np.load(opt.real_stat_path)
        self.is_best = False

    def setup(self, opt, verbose=True):
        """Create LR schedulers, restore checkpoints and register activation hooks.

        Args:
            opt: options forwarded to ``networks.get_scheduler``.
            verbose: forwarded to ``load_networks``; also enables
                ``print_networks``.
        """
        self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        self.load_networks(verbose)
        if verbose:
            self.print_networks()
        if self.opt.lambda_distill > 0:
            def get_activation(mem, name):
                # Forward hook that stores the module output in mem[name].
                def get_output_hook(module, input, output):
                    mem[name] = output

                return get_output_hook

            def add_hook(net, mem, mapping_layers):
                # Hook every sub-module whose qualified name is a mapped layer.
                for n, m in net.named_modules():
                    if n in mapping_layers:
                        m.register_forward_hook(get_activation(mem, n))

            add_hook(self.netG_teacher, self.Tacts, self.mapping_layers)
            add_hook(self.netG_student, self.Sacts, self.mapping_layers)

    def set_input(self, input):
        """Unpack a paired batch according to ``opt.direction``.

        Args:
            input: mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` or
                ``'B_paths'``.  With direction ``'AtoB'``, ``'A'`` becomes
                ``real_A`` and ``'B'`` becomes ``real_B``; otherwise swapped.
        """
        AtoB = self.opt.direction == 'AtoB'
        src_key, dst_key = ('A', 'B') if AtoB else ('B', 'A')
        self.real_A = input[src_key].to(self.device)
        self.real_B = input[dst_key].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def set_single_input(self, input):
        """Unpack a batch that only carries ``'A'`` and ``'A_paths'``."""
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run the generators; must be implemented by subclasses."""
        raise NotImplementedError

    def backward_D(self):
        """Compute the discriminator loss on student fakes and backpropagate it.

        Reads ``real_A``, ``real_B`` and ``Sfake_B``; writes ``loss_D_fake``,
        ``loss_D_real`` and ``loss_D``.
        """
        # Inputs are detached so no gradient reaches the generator.
        if self.opt.dataset_mode == 'aligned':
            fake = torch.cat((self.real_A, self.Sfake_B), 1).detach()
            real = torch.cat((self.real_A, self.real_B), 1).detach()
        else:
            fake = self.Sfake_B.detach()
            real = self.real_B.detach()

        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True)

        pred_real = self.netD(real)
        self.loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def calc_distill_loss(self):
        """Compute the distillation loss; must be implemented by subclasses."""
        raise NotImplementedError

    def backward_G(self):
        """Compute and backpropagate the generator loss; subclass hook."""
        raise NotImplementedError

    def optimize_parameters(self, steps):
        """Run one optimization step; must be implemented by subclasses."""
        raise NotImplementedError

    def print_networks(self):
        """Print every network in ``model_names`` and dump it to ``<log_dir>/<name>.txt``."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not hasattr(self, name):
                continue
            net = getattr(self, name)
            num_params = sum(param.numel() for param in net.parameters())
            summary = '[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)
            print(net)
            print(summary)
            with open(os.path.join(self.opt.log_dir, name + '.txt'), 'w') as f:
                f.write(str(net) + '\n')
                f.write(summary + '\n')
        print('-----------------------------------------------')

    def load_networks(self, verbose=True):
        """Restore the teacher and any optionally configured checkpoints.

        The teacher is always loaded from ``opt.restore_teacher_G_path``.
        Student, discriminator, adaptors and optimizers are loaded only when
        their ``restore_*_path`` option is set.  Adaptor and optimizer files are
        expected at ``'<path>-<index>.pth'``.

        Args:
            verbose: forwarded to the ``util`` loading helpers.
        """
        util.load_network(self.netG_teacher, self.opt.restore_teacher_G_path, verbose)
        if self.opt.restore_student_G_path is not None:
            util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
            if hasattr(self, 'netG_student_tmp'):
                util.load_network(self.netG_student_tmp, self.opt.restore_student_G_path, verbose)
        if self.opt.restore_D_path is not None:
            util.load_network(self.netD, self.opt.restore_D_path, verbose)
        if self.opt.restore_A_path is not None:
            for i, netA in enumerate(self.netAs):
                path = '%s-%d.pth' % (self.opt.restore_A_path, i)
                util.load_network(netA, path, verbose)
        if self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
                # The learning rate from the current options overrides the
                # one stored in the restored optimizer state.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.opt.lr

    def save_networks(self, epoch):
        """Save student generator, discriminator, adaptors and optimizer states.

        Files are written to ``self.save_dir`` as ``<epoch>_net_G.pth``,
        ``<epoch>_net_D.pth``, ``<epoch>_net_A-<i>.pth`` and
        ``<epoch>_optim-<i>.pth``.

        Args:
            epoch: label used as the file name prefix.
        """

        def save_net(net, save_path):
            # The state dict is saved from CPU; when running on GPU the
            # network is moved back to the first configured GPU afterwards.
            on_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
            # Unwrap DataParallel so the stored keys carry no 'module.' prefix.
            target = net.module if on_gpu and isinstance(net, DataParallel) else net
            torch.save(target.cpu().state_dict(), save_path)
            if on_gpu:
                net.cuda(self.gpu_ids[0])

        save_net(self.netG_student, os.path.join(self.save_dir, '%s_net_G.pth' % epoch))
        save_net(self.netD, os.path.join(self.save_dir, '%s_net_D.pth' % epoch))

        for i, net in enumerate(self.netAs):
            save_filename = '%s_net_A-%d.pth' % (epoch, i)
            save_net(net, os.path.join(self.save_dir, save_filename))

        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            torch.save(optimizer.state_dict(), os.path.join(self.save_dir, save_filename))

    def evaluate_model(self, step):
        """Evaluate the current model; must be implemented by subclasses."""
        raise NotImplementedError

    def test(self):
        """Run ``forward`` with gradient tracking disabled."""
        with torch.no_grad():
            self.forward()
예제 #12
0
파일: pix2pix_model.py 프로젝트: deJQK/CAT
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Builds the generator, the discriminator, the loss criteria, the two
        Adam optimizers and the evaluation helpers.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Raises:
            NotImplementedError -- if opt.recon_loss_type is not one of
                'l1', 'l2' or 'smooth_l1'
        """
        assert opt.isTrain
        BaseModel.__init__(self, opt)
        self.loss_names = [
            'G_gan', 'G_recon', 'D_real', 'D_fake', 'G_comp_cost'
        ]
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.model_names = ['G', 'D']
        self.netG = networks.define_G(opt.input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.norm,
                                      opt.dropout_rate,
                                      opt.init_type,
                                      opt.init_gain,
                                      self.gpu_ids,
                                      opt=opt)

        # The discriminator is sized for input and output channels together.
        self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                      opt.ndf,
                                      opt.netD,
                                      opt.n_layers_D,
                                      opt.norm,
                                      opt.init_type,
                                      opt.init_gain,
                                      self.gpu_ids,
                                      opt=opt)

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        if opt.recon_loss_type == 'l1':
            self.criterionRecon = torch.nn.L1Loss()
        elif opt.recon_loss_type == 'l2':
            self.criterionRecon = torch.nn.MSELoss()
        elif opt.recon_loss_type == 'smooth_l1':
            self.criterionRecon = torch.nn.SmoothL1Loss()
        else:
            # Report the option that was actually tested above; the previous
            # code formatted opt.loss_type, which is not the attribute read here.
            raise NotImplementedError(
                'Unknown reconstruction loss type [%s]!' %
                opt.recon_loss_type)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        self.eval_dataloader = create_eval_dataloader(self.opt)

        # Inception network kept in eval mode for use by evaluation code.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        # Segmentation network, only built when the data root mentions cityscapes.
        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model.to(self.device)
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        # Best-so-far trackers start at their worst possible values.
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        self.Tacts, self.Sacts = {}, {}
        # Presumably precomputed statistics of the real images -- TODO confirm.
        self.npz = np.load(opt.real_stat_path)
예제 #13
0
class BaseSPADEDistiller(SPADEModel):
    """Common training scaffold for distilling a SPADE teacher into a student.

    The networks live inside a ``SPADEDistillerModules`` container
    (``self.modules``); ``self.modules_on_one_gpu`` is the same container
    without the data-parallel wrapper.  ``evaluate_model`` raises
    ``NotImplementedError`` here and has to be provided by subclasses.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the SPADE distillation options and defaults on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` (asserted).
            is_train: unused here.

        Returns:
            The extended parser.
        """
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument(
            '--num_upsampling_layers',
            choices=('normal', 'more', 'most'),
            default='more',
            help=
            "If 'more', adds upsampling layer between the two middle resnet blocks. "
            "If 'most', also add one more upsampling + resnet layer at the end of the generator"
        )
        parser.add_argument('--teacher_netG',
                            type=str,
                            default='inception_spade',
                            help='specify teacher generator architecture',
                            choices=['inception_spade'])
        parser.add_argument('--student_netG',
                            type=str,
                            default='inception_spade',
                            help='specify student generator architecture',
                            choices=['inception_spade'])
        parser.add_argument(
            '--teacher_ngf',
            type=int,
            default=64,
            help='the base number of filters of the teacher generator')
        parser.add_argument(
            '--student_ngf',
            type=int,
            default=48,
            help='the base number of filters of the student generator')
        parser.add_argument(
            '--teacher_norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help=
            'instance normalization or batch normalization of the teacher model'
        )
        parser.add_argument(
            '--student_norm_G',
            type=str,
            default='spadesyncbatch3x3',
            help=
            'instance normalization or batch normalization of the student model'
        )
        parser.add_argument('--restore_teacher_G_path',
                            type=str,
                            required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path',
                            type=str,
                            default=None,
                            help='the path to restore the student generator')
        parser.add_argument(
            '--restore_A_path',
            type=str,
            default=None,
            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path',
                            type=str,
                            default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path',
                            type=str,
                            default=None,
                            help='the path to restore the optimizer')
        parser.add_argument('--lambda_gan',
                            type=float,
                            default=1,
                            help='weight for gan loss')
        parser.add_argument('--lambda_feat',
                            type=float,
                            default=10,
                            help='weight for gan feature loss')
        parser.add_argument('--lambda_vgg',
                            type=float,
                            default=10,
                            help='weight for vgg loss')
        # The help text used to be a copy of the --lambda_vgg one.
        parser.add_argument('--lambda_distill',
                            type=float,
                            default=10,
                            help='weight for distillation loss')
        parser.add_argument('--distill_G_loss_type',
                            type=str,
                            default='mse',
                            choices=['mse', 'ka'],
                            help='the type of the G distillation loss')
        parser.add_argument('--beta2',
                            type=float,
                            default=0.999,
                            help='momentum term of adam')
        parser.add_argument('--no_TTUR',
                            action='store_true',
                            help='Use TTUR training scheme')
        parser.add_argument('--no_fid',
                            action='store_true',
                            help='No FID evaluation during training')
        parser.add_argument('--no_mIoU',
                            action='store_true',
                            help='No mIoU evaluation during training '
                            '(sometimes because there are CUDA memory)')
        parser.set_defaults(netD='multi_scale',
                            ndf=64,
                            dataset_mode='cityscapes',
                            batch_size=16,
                            print_freq=50,
                            save_latest_freq=10000000000,
                            save_epoch_freq=10,
                            nepochs=100,
                            nepochs_decay=100,
                            init_type='xavier')
        return parser

    def __init__(self, opt):
        """Build the module container, optimizers and evaluation helpers.

        Args:
            opt: parsed options.
        """
        # NOTE: super(SPADEModel, self) starts the lookup *after* SPADEModel,
        # so SPADEModel.__init__ itself does not run here.
        super(SPADEModel, self).__init__(opt)
        # 'D' used to be appended a second time right after this literal,
        # leaving a duplicate entry; each network is now listed once.
        self.model_names = ['G_student', 'G_teacher', 'D']
        self.visual_names = ['labels', 'Tfake_B', 'Sfake_B', 'real_B']
        self.loss_names = [
            'G_gan', 'G_feat', 'G_vgg', 'G_distill', 'D_real', 'D_fake'
        ]
        # NOTE(review): self.modules is only created when opt carries a
        # 'distiller' attribute; without it the accesses to
        # self.modules_on_one_gpu below raise AttributeError unless the
        # attribute is provided some other way -- confirm this is intended.
        if hasattr(opt, 'distiller'):
            self.modules = SPADEDistillerModules(opt).to(self.device)
            if len(opt.gpu_ids) > 0:
                self.modules = DataParallelWithCallback(self.modules,
                                                        device_ids=opt.gpu_ids)
                self.modules_on_one_gpu = self.modules.module
            else:
                self.modules_on_one_gpu = self.modules
        # One extra loss entry per distilled layer.
        for i in range(len(self.modules_on_one_gpu.mapping_layers)):
            self.loss_names.append('G_distill%d' % i)
        self.optimizer_G, self.optimizer_D = self.modules_on_one_gpu.create_optimizers(
        )
        self.optimizers = [self.optimizer_G, self.optimizer_D]
        if not opt.no_fid:
            # Inception network used for FID evaluation, kept in eval mode.
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_model = InceptionV3([block_idx])
            self.inception_model.to(self.device)
            self.inception_model.eval()
        if 'cityscapes' in opt.dataroot and not opt.no_mIoU:
            # Segmentation network used for mIoU evaluation, kept in eval mode.
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            self.drn_model.to(self.device)
            self.drn_model.eval()
        self.eval_dataloader = create_eval_dataloader(self.opt)
        # Best-so-far trackers start at their worst possible values.
        self.best_fid = 1e9
        self.best_mIoU = -1e9
        self.fids, self.mIoUs = [], []
        self.is_best = False
        # Presumably precomputed statistics of the real images -- TODO confirm.
        self.npz = np.load(opt.real_stat_path)

        # Profile both generators; the n_macs attribute read below is
        # presumably populated by model_profiling -- TODO confirm.
        model_profiling(self.modules_on_one_gpu.netG_teacher,
                        self.opt.data_height,
                        self.opt.data_width,
                        channel=self.opt.data_channel,
                        num_forwards=0,
                        verbose=False)
        model_profiling(self.modules_on_one_gpu.netG_student,
                        self.opt.data_height,
                        self.opt.data_width,
                        channel=self.opt.data_channel,
                        num_forwards=0,
                        verbose=False)
        print(
            f'Teacher FLOPs: {self.modules_on_one_gpu.netG_teacher.n_macs}, Student FLOPs: {self.modules_on_one_gpu.netG_student.n_macs}.'
        )

    def forward(self, on_one_gpu=False):
        """Run teacher and student on ``self.input_semantics``.

        Args:
            on_one_gpu: if true, call the unwrapped module container instead
                of the (possibly data-parallel) one.

        Stores the two outputs in ``self.Tfake_B`` and ``self.Sfake_B``.
        """
        runner = self.modules_on_one_gpu if on_one_gpu else self.modules
        self.Tfake_B, self.Sfake_B = runner(self.input_semantics)

    def load_networks(self,
                      verbose=True,
                      teacher_only=False,
                      restore_pretrain=True):
        """Restore network weights and, if configured, optimizer states.

        Args:
            verbose: forwarded to the loading helpers.
            teacher_only: forwarded to the module container's ``load_networks``.
            restore_pretrain: forwarded to the module container's
                ``load_networks``.
        """
        self.modules_on_one_gpu.load_networks(
            verbose,
            teacher_only=teacher_only,
            restore_pretrain=restore_pretrain)
        if self.opt.restore_O_path is not None:
            for i, optimizer in enumerate(self.optimizers):
                path = '%s-%d.pth' % (self.opt.restore_O_path, i)
                util.load_optimizer(optimizer, path, verbose)
                # The learning rate from the current options overrides the
                # one stored in the restored optimizer state.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.opt.lr

    def save_networks(self, epoch):
        """Save the networks and the optimizer states to ``self.save_dir``.

        Optimizer states are written as ``<epoch>_optim-<i>.pth``.

        Args:
            epoch: label used as the file name prefix.
        """
        self.modules_on_one_gpu.save_networks(epoch, self.save_dir)
        for i, optimizer in enumerate(self.optimizers):
            save_filename = '%s_optim-%d.pth' % (epoch, i)
            save_path = os.path.join(self.save_dir, save_filename)
            torch.save(optimizer.state_dict(), save_path)

    def evaluate_model(self, step):
        """Evaluate the current model; must be implemented by subclasses."""
        raise NotImplementedError

    def optimize_parameters(self, steps):
        """Run one generator update followed by one discriminator update.

        Args:
            steps: unused here.
        """
        # Generator step: the discriminator's requires_grad is switched off.
        self.set_requires_grad(self.modules_on_one_gpu.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # Discriminator step: switch its requires_grad back on.
        self.set_requires_grad(self.modules_on_one_gpu.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
예제 #14
0
def main(configs, opt, gpu_id, queue, verbose):
    """Evaluate a list of configurations on one GPU and report via ``queue``.

    For every configuration the model is profiled; configurations whose MACs
    exceed ``opt.budget`` are not run and receive worst-case metric values.
    The others are run over the whole dataloader and scored with FID (unless
    ``opt.no_fid``) and, for cityscapes with direction ``'BtoA'``, mIoU.

    Args:
        configs: iterable of configurations accepted by ``model.profile`` and
            ``model.test``.
        opt: parsed options; ``opt.gpu_ids`` is overwritten with ``[gpu_id]``.
        gpu_id: id of the GPU this worker uses.
        queue: object with a ``put`` method; receives the list of result dicts
            (keys ``'config_str'``, ``'macs'`` and optionally ``'fid'`` /
            ``'mIoU'``).
        verbose: forwarded to the dataloader/model factories and ``setup``.
    """
    # Metric values recorded for configurations that exceed the budget: FID is
    # lower-is-better, mIoU is higher-is-better.
    worst_fid, worst_mIoU = 1e9, -1e9

    opt.gpu_ids = [gpu_id]
    dataloader = create_dataloader(opt, verbose)
    model = create_model(opt, verbose)
    model.setup(opt, verbose)
    device = model.device
    eval_mIoU = 'cityscapes' in opt.dataroot and opt.direction == 'BtoA'

    if not opt.no_fid:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx])
        inception_model.to(device)
        inception_model.eval()
    if eval_mIoU:
        drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
        util.load_network(drn_model, opt.drn_path, verbose=False)
        if len(opt.gpu_ids) > 0:
            # nn.DataParallel expects the wrapped module to already live on
            # its first device, so move it there before wrapping.
            drn_model.to(device)
            drn_model = nn.DataParallel(drn_model, opt.gpu_ids)
        drn_model.eval()

    npz = np.load(opt.real_stat_path)
    results = []

    # Give the model one batch up front: model.profile(config) is called
    # without data, so it presumably relies on an input being set already.
    for data_i in dataloader:
        model.set_input(data_i)
        break

    for config in tqdm.tqdm(configs):
        macs, _ = model.profile(config)
        qualified = not (macs > opt.budget)

        fakes, names = [], []

        if qualified:
            for data_i in dataloader:
                model.set_input(data_i)

                model.test(config)
                fakes.append(model.fake_B.cpu())
                for path in model.get_image_paths():
                    short_path = ntpath.basename(path)
                    names.append(os.path.splitext(short_path)[0])

        result = {'config_str': encode_config(config), 'macs': macs}
        if not opt.no_fid:
            if qualified:
                result['fid'] = get_fid(fakes,
                                        inception_model,
                                        npz,
                                        device,
                                        opt.batch_size,
                                        use_tqdm=False)
            else:
                result['fid'] = worst_fid
        if eval_mIoU:
            if qualified:
                result['mIoU'] = get_mIoU(fakes,
                                          names,
                                          drn_model,
                                          device,
                                          data_dir=opt.cityscapes_path,
                                          batch_size=opt.batch_size,
                                          num_workers=opt.num_threads,
                                          use_tqdm=False)
            else:
                # Previously this branch read an mIoU variable that was either
                # unbound (first config over budget) or left over from an
                # earlier config; record an explicit worst-case value instead.
                result['mIoU'] = worst_mIoU
        print(result, flush=True)
        results.append(result)
    queue.put(results)