Example #1
    def initialize(self, opt, log):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor

        nb = opt.cycle_batchSize
        crop_height, crop_width = opt.crop_height, opt.crop_width
        self.input_A = self.Tensor(nb, 3, crop_height, crop_width)
        self.input_B = self.Tensor(nb, 3, crop_height, crop_width)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = define_G(gpu_ids=self.gpu_ids)
        self.netG_B = define_G(gpu_ids=self.gpu_ids)

        self.netD_A = define_D(gpu_ids=self.gpu_ids)
        self.netD_B = define_D(gpu_ids=self.gpu_ids)

        # for training
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.cycle_lr,
                                            betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, opt))

        utils.print_log('------------ Networks initialized -------------', log)
        print_network(self.netG_A, 'netG_A', log)
        print_network(self.netG_B, 'netG_B', log)
        print_network(self.netD_A, 'netD_A', log)
        print_network(self.netD_B, 'netD_B', log)
        utils.print_log('-----------------------------------------------', log)
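Every example in this listing queries an ImagePool buffer, but none of them defines it. Below is a minimal sketch of the history buffer these snippets appear to assume (modeled on the CycleGAN-style pool: it stores up to pool_size previously generated images and, once full, returns either the incoming image or a randomly swapped-out older one with probability 0.5); the exact implementation in each repository may differ.

import random
import torch

class ImagePool:
    """Minimal sketch (assumption) of the generated-image history buffer
    used throughout these examples."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass images straight through
            return images
        out = []
        for image in images.detach():
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)  # still filling the buffer
                out.append(image)
            elif random.random() > 0.5:  # swap the new image for a stored one
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:  # otherwise return the fresh image unchanged
                out.append(image)
        return torch.cat(out, dim=0)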
Example #2
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from that used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert opt.input_nc == opt.output_nc
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
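Examples #1 and #2 both delegate learning-rate scheduling to a get_scheduler(optimizer, opt) helper that is not shown. A minimal sketch under the common assumption of a linear-decay policy (constant lr for opt.niter epochs, then a linear ramp to zero over opt.niter_decay epochs):

from torch.optim import lr_scheduler

def get_scheduler(optimizer, opt):
    # Sketch (assumption): keep the initial lr for opt.niter epochs, then
    # decay it linearly to zero over the following opt.niter_decay epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)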
Example #3
    def __init__(self, opt):
        super(MaskMobileCycleGANModel, self).__init__()
        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(
            opt.gpu_ids) > 0 else torch.device('cpu')
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'mask_weight'
        ]
        visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
        visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
        self.visual_names = visual_names_A + visual_names_B

        self.netG_A = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)
        self.netG_B = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)

        self.netD_A = NLayerDiscriminator(ndf=self.opt.ndf)
        self.netD_B = NLayerDiscriminator(ndf=self.opt.ndf)
        self.init_net()

        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

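        # Collect the names of sub-modules whose masks are regularized as a
        # group (assumption: the indices follow the MaskMobileResnetGenerator
        # layout, i.e. 'model.11' plus the ninth conv inside each residual
        # block 13-21).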
        self.group_mask_weight_names = []
        self.group_mask_weight_names.append('model.11')
        for i in range(13, 22):
            self.group_mask_weight_names.append('model.%d.conv_block.9' % i)

        self.stop_AtoB_mask = False
        self.stop_BtoA_mask = False

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        # define optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [
            util.get_scheduler(optimizer, opt) for optimizer in self.optimizers
        ]
Example #4
    def __init__(self, args):
        super().__init__(args)

        if args.mode == 'train':
            self.D = define_D(args)
            self.D = self.D.to(self.device)

            self.fake_right_pool = ImagePool(50)

            self.criterionMonoDepth = define_generator_loss(args)
            self.criterionMonoDepth = self.criterionMonoDepth.to(self.device)

            self.criterionGAN = define_discriminator_loss(args)
            self.criterionGAN = self.criterionGAN.to(self.device)

        # Load the correct networks, depending on which mode we are in.
        if args.mode == 'train':
            self.model_names = ['G', 'D']
            self.optimizer_names = ['G', 'D']
        else:
            self.model_names = ['G']

        self.loss_names = ['G', 'D']

        # We do Resume Training for this architecture.
        if args.resume == '':
            pass
        else:
            self.load_checkpoint(load_optim=False)

        if args.mode == 'train':
            # After resuming, set new optimizers.
            self.optimizer_G = optim.SGD(self.G.parameters(),
                                         lr=args.learning_rate)
            self.optimizer_D = optim.SGD(self.D.parameters(),
                                         lr=args.learning_rate)

            # Reset epoch.
            self.start_epoch = 0

        self.trainG = True
        self.count_trained_G = 0
        self.count_trained_D = 0
        self.regime = args.resume_regime

        if 'cuda' in self.device:
            torch.cuda.synchronize()
Example #5
    def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
        super(train_style_translator_T, self).__init__(args)
        self._initialize_training()

        self.dataloaders_single = dataloaders_single
        self.dataloaders_xLabels_joint = dataloaders_xLabels_joint

        # define loss weights
        self.lambda_identity = 0.5  # coefficient of identity mapping score
        self.lambda_real = 10.0
        self.lambda_synthetic = 10.0
        self.lambda_GAN = 1.0

        # define pool size in adversarial loss
        self.pool_size = 50
        self.generated_syn_pool = ImagePool(self.pool_size)
        self.generated_real_pool = ImagePool(self.pool_size)

        self.netD_s = Discriminator80x80InstNorm(input_nc=3)
        self.netD_r = Discriminator80x80InstNorm(input_nc=3)
        self.netG_s2r = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.netG_r2s = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
        self.L1loss = nn.L1Loss()

        if self.isTrain:
            self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) +
                                             list(self.netD_r.parameters()),
                                             lr=self.D_lr,
                                             betas=(0.5, 0.999))
            self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) +
                                             list(self.netG_s2r.parameters()),
                                             lr=self.G_lr,
                                             betas=(0.5, 0.999))
            self.optim_name = ['netD_optimizer', 'netG_optimizer']
            self._get_scheduler()
            self.loss_BCE = nn.BCEWithLogitsLoss()
            self._initialize_networks()

            # apex can only be applied to CUDA models
            if self.use_apex:
                self._init_apex(Num_losses=3)

        self._check_parallel()
Example #6
    def __init__(self, opt, G_A, G_B, D_A, D_B, optimizer_G, optimizer_D,
                 summary_writer):
        self.opt = opt
        self.device = th.device('cuda:{}'.format(
            self.opt.gpu_ids[0])) if self.opt.gpu_ids else th.device('cpu')
        self.G_A = G_A
        self.G_B = G_B
        self.D_A = D_A
        self.D_B = D_B
        # define optimizer G and D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D

        self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = th.nn.L1Loss()
        self.criterionIdt = th.nn.L1Loss()
        self.summary_writer = summary_writer
        self.fake_B_pool = ImagePool(self.opt.pool_size)
        self.fake_A_pool = ImagePool(self.opt.pool_size)
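Nearly every example wraps its adversarial objective in a GANLoss helper, constructed either as GANLoss(use_lsgan=...) or GANLoss(opt.gan_mode). A minimal sketch of the LSGAN variant, which broadcasts a scalar real/fake label to match the discriminator output:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Minimal LSGAN-style sketch (assumption): compares discriminator
    output against a broadcast real/fake label tensor."""

    def __init__(self, use_lsgan=True, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        target = self.real_label if target_is_real else self.fake_label
        return self.loss(prediction, target.expand_as(prediction))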
Example #7
    def __init__(self, opt, cfg_AtoB=None, cfg_BtoA=None):
        super(MobileCycleGANModel, self).__init__()
        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else torch.device('cpu')
        self.cfg_AtoB = cfg_AtoB
        self.cfg_BtoA = cfg_BtoA
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
        visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
        self.visual_names = visual_names_A + visual_names_B

        self.netG_A = MobileResnetGenerator(opt=self.opt, cfg=cfg_AtoB)
        self.netG_B = MobileResnetGenerator(opt=self.opt, cfg=cfg_BtoA)

        self.netD_A = NLayerDiscriminator()
        self.netD_B = NLayerDiscriminator()
        self.init_net()

        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        self.teacher_model = None
        if self.opt.lambda_attention_distill > 0:
            print('init attention distill')
            self.init_attention_distill()
        if self.opt.lambda_discriminator_distill > 0:
            print('init discriminator distill')
            self.init_discriminator_distill()

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        # define optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr, betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
Example #8
    def initialize(self, opt):
        super(CrossModelV, self).initialize(opt)
        self.netG = GModel()
        self.netD = DModel()
        self.netG.initialize(opt)
        self.netD.initialize(opt)

        self.criterionGAN = GANLoss(opt.use_lsgan)
        self.optimizer_G = torch.optim.Adam(
            self.netG.parameters(),
            lr=opt.learn_rate,
            #betas=(.5, 0.9)
            betas=(.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(.5, 0.999))
        self.pool = ImagePool(160)

        init_net(self)
        print(self)
Example #9
    def __init__(self, params):
        super(network, self).__init__()
        self.Tensor = torch.cuda.FloatTensor
        self.configurate(params['net'])

        self.fake_pool_x = ImagePool(self.pool_size)
        self.fake_pool_y = ImagePool(self.pool_size)
        self.input_x = self.Tensor(self.batch_size, 3, 256, 256)
        self.input_y = self.Tensor(self.batch_size, 3, 256, 256)
        self.target_x = self.Tensor(self.batch_size, 1, 256, 256)
        self.target_y = self.Tensor(self.batch_size, 1, 256, 256)

        self.tf_summary = Logger('./logs', self.name)

        self.enc_x = Encoder(**params['enc_x']).cuda()
        self.enc_y = Encoder(**params['enc_y']).cuda()

        self.mul_gen_x = Multitask_Generator(**params['gen_x']).cuda()
        self.mul_gen_y = Multitask_Generator(**params['gen_y']).cuda()

        self.dis_x = NLayerDiscriminator(**params['dis']).cuda()
        self.dis_y = NLayerDiscriminator(**params['dis']).cuda()

        self.criterionGAN = GANLoss()
        self.criterionCyC = torch.nn.L1Loss()
        self.criterionSeg = Segmentation_Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.enc_x.parameters(), self.mul_gen_x.parameters(),
            self.enc_y.parameters(), self.mul_gen_y.parameters()),
                                            lr=self.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.dis_x.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.dis_y.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))
Example #10
def get_discriminator_input_fn(conf, disc_conf, no_pool=False):
  if disc_conf.get_attr('use_image_pool', default=False) and not no_pool:
    pool_size = disc_conf.get_attr('image_pool_size',
                                   default=5 * conf.batch_size)
    sample_prob = disc_conf.get_attr('image_pool_sample_prob', default=0.5)
    image_pool = ImagePool(pool_size, sample_prob)
  else:
    image_pool = None

  pool_label_swapping = disc_conf.get_attr('image_pool_label_swapping',
                                           default=False)

  input_method = disc_conf.get_attr('input_method',
                                    default=DEFAULT_INPUT_METHOD)
  normalize_input = disc_conf.get_attr('normalize_input', default=False)
  scale_input = disc_conf.get_attr('scale_input_zero_one', default=False)

  strip_bg_class = disc_conf.get_attr('strip_bg_class', default=False)

  cond_input_src = disc_conf.get_attr('conditional_input_source',
                                      default='input')
  if cond_input_src == 'input':
    cond_input_src = CondInputSource.INPUT
  elif cond_input_src == 'generator':
    cond_input_src = CondInputSource.OUT_GEN
  else:
    raise ValueError(('Unknown conditional '
                     'input source {}').format(cond_input_src))

  cond_input_gen_key = disc_conf.get_attr('conditional_input_generator_key')

  disc_input_fn = _build_input_fn(input_method,
                                  normalize_input,
                                  image_pool,
                                  cond_input_src,
                                  cond_input_gen_key,
                                  strip_bg_class,
                                  scale_input,
                                  pool_label_swapping)

  return disc_input_fn
Example #11
class SSRGAN(BaseModel):
    def name(self):
        return 'SSRGAN'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [
                l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real,
                                     d_fake), flags) if f
            ]

        return loss_filter

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.input_nc
        self.para = opt.trade_off

        # define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG = networks.define_G(netG_input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local,
                                      opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # netD_input_nc = input_nc + opt.output_nc
            netD_input_nc = opt.output_nc
            self.netD = networks.define_D(netD_input_nc,
                                          opt.ndf,
                                          opt.n_layers_D,
                                          opt.norm,
                                          use_sigmoid,
                                          opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc,
                                          opt.feat_num,
                                          opt.nef,
                                          'encoder',
                                          opt.n_downsample_E,
                                          norm=opt.norm,
                                          gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and len(self.gpu_ids) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()

            # AWAN
            self.criterionCSS = networks.CSS()

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_CSS',
                                               'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

    def encode_input(self, rgb, hyper, infer=False):
        # RGB for training
        if rgb is not None:
            rgb = Variable(rgb.data.cuda())
        # hyper for training
        if hyper is not None:
            hyper = Variable(hyper.data.cuda())

        return rgb, hyper

    def discriminate(self, rgb, hyper, use_pool=False):
        # input_concat = torch.cat((rgb, hyper.detach()), dim=1)
        input_concat = hyper.detach()
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, rgb, hyper, infer=False):
        # Encode Inputs
        rgb, real_hyper = self.encode_input(rgb, hyper)

        # Fake Generation
        input_concat = rgb
        fake_hyper = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(rgb, fake_hyper, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(rgb, real_hyper)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        # pred_fake = self.netD.forward(torch.cat((rgb, fake_hyper), dim=1))
        pred_fake = self.netD.forward(fake_hyper)
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        lrm, lrm_rgb = self.criterionCSS(fake_hyper, real_hyper, rgb)
        loss_G_GAN += lrm + self.para * lrm_rgb  #  default 10

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        # loss_G_VGG = 0

        loss_G_CSS = lrm + self.para * lrm_rgb

        # Only return the fake_B image if necessary to save BW
        # return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake), None if not infer else fake_hyper]
        return [
            self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_CSS,
                             loss_D_real, loss_D_fake),
            None if not infer else fake_hyper
        ]

    def inference(self, rgb, hyper, image=None):
        # Encode Inputs
        rgb, real_hyper = self.encode_input(Variable(rgb),
                                            Variable(hyper),
                                            infer=True)

        # Fake Generation
        input_concat = rgb

        with torch.no_grad():
            fake_hyper = self.netG.forward(input_concat)
        return fake_hyper

    def sample_features(self, inst):
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name,
                                    self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num,
                               inst.size()[2],
                               inst.size()[3])
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2],
                             idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2],
                                     idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :,
             1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] !=
                                                   t[:, :, :, :-1])
        edge[:, :,
             1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] !=
                                                   t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=self.opt.lr,
                                            betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print(
                '------------ Now also finetuning global generator -----------'
            )

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
Example #12
class train_style_translator_T(base_model):
    def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
        super(train_style_translator_T, self).__init__(args)
        self._initialize_training()

        self.dataloaders_single = dataloaders_single
        self.dataloaders_xLabels_joint = dataloaders_xLabels_joint

        # define loss weights
        self.lambda_identity = 0.5  # coefficient of identity mapping score
        self.lambda_real = 10.0
        self.lambda_synthetic = 10.0
        self.lambda_GAN = 1.0

        # define pool size in adversarial loss
        self.pool_size = 50
        self.generated_syn_pool = ImagePool(self.pool_size)
        self.generated_real_pool = ImagePool(self.pool_size)

        self.netD_s = Discriminator80x80InstNorm(input_nc=3)
        self.netD_r = Discriminator80x80InstNorm(input_nc=3)
        self.netG_s2r = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.netG_r2s = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
        self.L1loss = nn.L1Loss()

        if self.isTrain:
            self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) +
                                             list(self.netD_r.parameters()),
                                             lr=self.D_lr,
                                             betas=(0.5, 0.999))
            self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) +
                                             list(self.netG_s2r.parameters()),
                                             lr=self.G_lr,
                                             betas=(0.5, 0.999))
            self.optim_name = ['netD_optimizer', 'netG_optimizer']
            self._get_scheduler()
            self.loss_BCE = nn.BCEWithLogitsLoss()
            self._initialize_networks()

            # apex can only be applied to CUDA models
            if self.use_apex:
                self._init_apex(Num_losses=3)

        self._check_parallel()

    def _get_project_name(self):
        return 'train_style_translator_T'

    def _initialize_networks(self):
        for name in self.model_name:
            getattr(self, name).train().to(self.device)
            init_weights(getattr(self, name),
                         net_name=name,
                         init_type='normal',
                         gain=0.02)

    def compute_D_loss(self, real_sample, fake_sample, netD):
        loss = 0
        syn_acc = 0
        real_acc = 0

        output = netD(fake_sample)
        label = torch.full(output.size(), self.syn_label, device=self.device)

        predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
        total_num = torch.numel(output)
        syn_acc += (predSyn == label).type(
            torch.float32).sum().item() / total_num
        loss += self.loss_BCE(output, label)

        output = netD(real_sample)
        label = torch.full(output.size(),
                           self.real_label,
                           device=self.device)

        predReal = (output > 0.5).to(self.device, dtype=torch.float32)
        real_acc += (predReal == label).type(
            torch.float32).sum().item() / total_num
        loss += self.loss_BCE(output, label)

        return loss, syn_acc, real_acc

    def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb,
                       reconstruct_real, reconstruct_syn):
        '''
		real_sample: [batch_size, 4, 240, 320] real rgb
		synthetic_sample: [batch_size, 4, 240, 320] synthetic rgb
		r2s_rgb: netG_r2s(real)
		s2r_rgb: netG_s2r(synthetic)
		'''
        loss = 0

        # identity loss if applicable
        if self.lambda_identity > 0:
            idt_real = self.netG_s2r(real_sample)[-1]
            idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
            idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
                        self.L1loss(idt_synthetic, synthetic_sample) *
                        self.lambda_synthetic) * self.lambda_identity
        else:
            idt_loss = 0

        # GAN loss
        real_pred = self.netD_r(s2r_rgb)
        real_label = torch.full(real_pred.size(),
                                self.real_label,
                                device=self.device)
        GAN_loss_real = self.loss_BCE(real_pred, real_label)

        syn_pred = self.netD_s(r2s_rgb)
        syn_label = torch.full(syn_pred.size(),
                               self.real_label,
                               device=self.device)
        GAN_loss_syn = self.loss_BCE(syn_pred, syn_label)

        GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN

        # cycle consistency loss
        rec_real_loss = self.L1loss(reconstruct_real,
                                    real_sample) * self.lambda_real
        rec_syn_loss = self.L1loss(reconstruct_syn,
                                   synthetic_sample) * self.lambda_synthetic
        rec_loss = rec_real_loss + rec_syn_loss

        loss += (idt_loss + GAN_loss + rec_loss)

        return loss, idt_loss, GAN_loss, rec_loss

    def train(self):
        phase = 'train'
        since = time.time()
        best_loss = float('inf')

        tensorboardX_iter_count = 0
        for epoch in range(self.total_epoch_num):
            print('\nEpoch {}/{}'.format(epoch + 1, self.total_epoch_num))
            print('-' * 10)
            fn = open(self.train_log, 'a')
            fn.write('\nEpoch {}/{}\n'.format(epoch + 1, self.total_epoch_num))
            fn.write('--' * 5 + '\n')
            fn.close()

            iterCount = 0

            for sample_dict in self.dataloaders_xLabels_joint:
                imageListReal, depthListReal = sample_dict['real']
                imageListSyn, depthListSyn = sample_dict['syn']

                imageListSyn = imageListSyn.to(self.device)
                depthListSyn = depthListSyn.to(self.device)
                imageListReal = imageListReal.to(self.device)
                depthListReal = depthListReal.to(self.device)

                with torch.set_grad_enabled(phase == 'train'):
                    s2r_rgb = self.netG_s2r(imageListSyn)[-1]
                    reconstruct_syn = self.netG_r2s(s2r_rgb)[-1]

                    r2s_rgb = self.netG_r2s(imageListReal)[-1]
                    reconstruct_real = self.netG_s2r(r2s_rgb)[-1]

                    #############  update generator
                    set_requires_grad([self.netD_r, self.netD_s], False)

                    netG_loss = 0.
                    self.netG_optimizer.zero_grad()
                    netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(
                        imageListReal, imageListSyn, r2s_rgb, s2r_rgb,
                        reconstruct_real, reconstruct_syn)

                    if self.use_apex:
                        with amp.scale_loss(netG_loss,
                                            self.netG_optimizer,
                                            loss_id=0) as netG_loss_scaled:
                            netG_loss_scaled.backward()
                    else:
                        netG_loss.backward()

                    self.netG_optimizer.step()

                    #############  update discriminator
                    set_requires_grad([self.netD_r, self.netD_s], True)

                    self.netD_optimizer.zero_grad()
                    # query the image pools and feed the pooled (historical)
                    # fakes to the discriminators instead of the fresh ones
                    r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb)
                    netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(
                        imageListSyn, r2s_rgb_pool.detach(), self.netD_s)
                    s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb)
                    netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(
                        imageListReal, s2r_rgb_pool.detach(), self.netD_r)

                    netD_loss = netD_s_loss + netD_r_loss

                    if self.use_apex:
                        with amp.scale_loss(netD_loss,
                                            self.netD_optimizer,
                                            loss_id=1) as netD_loss_scaled:
                            netD_loss_scaled.backward()
                    else:
                        netD_loss.backward()
                    self.netD_optimizer.step()

                iterCount += 1

                if self.use_tensorboardX:
                    self.train_display_freq = len(
                        self.dataloaders_xLabels_joint
                    )  # feel free to adjust the display frequency
                    nrow = imageListReal.size()[0]
                    if tensorboardX_iter_count % self.train_display_freq == 0:
                        s2r_rgb_concat = torch.cat(
                            (imageListSyn, s2r_rgb, imageListReal,
                             reconstruct_syn),
                            dim=0)
                        self.write_2_tensorboardX(
                            self.train_SummaryWriter,
                            s2r_rgb_concat,
                            name='RGB: syn, s2r, real, reconstruct syn',
                            mode='image',
                            count=tensorboardX_iter_count,
                            nrow=nrow)

                        r2s_rgb_concat = torch.cat(
                            (imageListReal, r2s_rgb, imageListSyn,
                             reconstruct_real),
                            dim=0)
                        self.write_2_tensorboardX(
                            self.train_SummaryWriter,
                            r2s_rgb_concat,
                            name='RGB: real, r2s, synthetic, reconstruct real',
                            mode='image',
                            count=tensorboardX_iter_count,
                            nrow=nrow)

                    loss_val_list = [netD_loss, netG_loss]
                    loss_name_list = ['netD_loss', 'netG_loss']
                    self.write_2_tensorboardX(self.train_SummaryWriter,
                                              loss_val_list,
                                              name=loss_name_list,
                                              mode='scalar',
                                              count=tensorboardX_iter_count)

                    tensorboardX_iter_count += 1

                if iterCount % 20 == 0:
                    loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(
                        iterCount, len(self.dataloaders_xLabels_joint),
                        netD_loss, netG_loss)
                    G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}'.format(
                        netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss)

                    print(loss_summary)
                    print(G_loss_summary)

                    fn = open(self.train_log, 'a')
                    fn.write(loss_summary + '\n')
                    fn.write(G_loss_summary + '\n')
                    fn.close()

            if (epoch + 1) % self.save_steps == 0:
                self.save_models(['netG_r2s'],
                                 mode=epoch + 1,
                                 save_list=['styleTranslator'])

            # take step in optimizer
            for scheduler in self.scheduler_list:
                scheduler.step()
                for name in self.optim_name:
                    lr = getattr(self, name).param_groups[0]['lr']
                    lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(
                        epoch + 1, self.total_epoch_num, name, lr)
                    print(lr_update)

                    fn = open(self.train_log, 'a')
                    fn.write(lr_update + '\n')
                    fn.close()

        time_elapsed = time.time() - since
        print('\nTraining complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        fn = open(self.train_log, 'a')
        fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(
            time_elapsed // 60, time_elapsed % 60))
        fn.close()

    def evaluate(self, mode):
        pass
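The train() loop above freezes and unfreezes the discriminators through a set_requires_grad helper that the snippet does not define. A one-function sketch of what it presumably does:

def set_requires_grad(nets, requires_grad=False):
    # Sketch (assumption): toggle gradient tracking for all parameters of
    # one net or a list of nets, so discriminator gradients are skipped
    # during the generator update and vice versa.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad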
Example #13
class ITN():
    def __repr__(self):
        return '{name}'.format(name=self.__class__.__name__)

    def initialize(self, opt, log):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor

        nb = opt.cycle_batchSize
        crop_height, crop_width = opt.crop_height, opt.crop_width
        self.input_A = self.Tensor(nb, 3, crop_height, crop_width)
        self.input_B = self.Tensor(nb, 3, crop_height, crop_width)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = define_G(gpu_ids=self.gpu_ids)
        self.netG_B = define_G(gpu_ids=self.gpu_ids)

        self.netD_A = define_D(gpu_ids=self.gpu_ids)
        self.netD_B = define_D(gpu_ids=self.gpu_ids)

        # for training
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = GANLoss(use_lsgan=True, tensor=self.Tensor)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.cycle_lr,
                                            betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=opt.cycle_lr,
                                              betas=(opt.cycle_beta1, 0.999))
        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D_A)
        self.optimizers.append(self.optimizer_D_B)
        for optimizer in self.optimizers:
            self.schedulers.append(get_scheduler(optimizer, opt))

        utils.print_log('------------ Networks initialized -------------', log)
        print_network(self.netG_A, 'netG_A', log)
        print_network(self.netG_B, 'netG_B', log)
        print_network(self.netD_A, 'netD_A', log)
        print_network(self.netD_B, 'netD_B', log)
        utils.print_log('-----------------------------------------------', log)

    def set_mode(self, mode):
        if mode.lower() == 'train':
            self.netG_A.train()
            self.netG_B.train()
            self.netD_A.train()
            self.netD_B.train()
            self.criterionGAN.train()
            self.criterionCycle.train()
            self.criterionIdt.train()
        elif mode.lower() == 'eval':
            self.netG_A.eval()
            self.netG_B.eval()
            self.netD_A.eval()
            self.netD_B.eval()
        else:
            raise ValueError('Invalid mode: {}'.format(mode))

    def set_input(self, input):
        input_A = input['A']
        input_B = input['B']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)

    def prepare_input(self):
        self.real_A = torch.autograd.Variable(self.input_A)
        self.real_B = torch.autograd.Variable(self.input_B)

    def num_parameters(self):
        params = count_parameters_in_MB(self.netG_A)
        params += count_parameters_in_MB(self.netG_B)
        params += count_parameters_in_MB(self.netD_A)
        params += count_parameters_in_MB(self.netD_B)
        return params

    def num_flops(self):
        self.prepare_input()
        flops1, params1 = get_model_infos(self.netG_A.model, None, self.real_A)
        fake_B = self.netG_A(self.real_A)
        flops2, params2 = get_model_infos(self.netD_A.model, None, fake_B)
        return flops1 + flops2

    def test(self):
        self.real_A = torch.autograd.Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = torch.autograd.Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        self.loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.prepare_input()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A.item()
        G_A = self.loss_G_A.item()
        Cyc_A = self.loss_cycle_A.item()
        D_B = self.loss_D_B.item()
        G_B = self.loss_G_B.item()
        Cyc_B = self.loss_cycle_B.item()
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.item()
            idt_B = self.loss_idt_B.item()
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('idt_A', idt_A), ('D_B', D_B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_visuals(self, isTrain):
        real_A = tensor2im(self.real_A.data)
        rec_A = tensor2im(self.rec_A.data)
        fake_A = tensor2im(self.fake_A.data)

        real_B = tensor2im(self.real_B.data)
        rec_B = tensor2im(self.rec_B.data)
        fake_B = tensor2im(self.fake_B.data)

        if isTrain and self.opt.identity > 0.0:
            idt_A = tensor2im(self.idt_A.data)
            idt_B = tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B)])

    def save(self, save_dir, log):
        save_network(save_dir, 'G_A', self.netG_A, self.gpu_ids)
        save_network(save_dir, 'D_A', self.netD_A, self.gpu_ids)
        save_network(save_dir, 'G_B', self.netG_B, self.gpu_ids)
        save_network(save_dir, 'D_B', self.netD_B, self.gpu_ids)
        utils.print_log('save the model into {}'.format(save_dir), log)

    def load(self, save_dir, log):
        load_network(save_dir, 'G_A', self.netG_A)
        load_network(save_dir, 'D_A', self.netD_A)
        load_network(save_dir, 'G_B', self.netG_B)
        load_network(save_dir, 'D_B', self.netD_B)
        utils.print_log('load the model from {}'.format(save_dir), log)

    # update learning rate (called once every epoch)
    def update_learning_rate(self, log):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        utils.print_log('learning rate = {:.7f}'.format(lr), log)
Example #14
class pix2pixGAN(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options():
        parser = two_domain_parser_options()
        return add_lambda_L1(parser)

    def __init__(self, args, logger):
        super().__init__(args, logger)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['loss_G', 'loss_D']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        self.model_names = ['G', 'D']

        self.sample_names = ['fake_B', 'real_A', 'real_B']
        # load/define networks
        self.G = networks.define_G(args.input_nc, args.output_nc, args.ngf,
                                   args.which_model_netG, args.norm,
                                   not args.no_dropout, args.init_type,
                                   args.init_gain, self.gpu_ids)

        if 'continue_train' not in args:
            use_sigmoid = args.no_lsgan
            self.D = networks.define_D(args.input_nc + args.output_nc,
                                       args.ndf, args.which_model_netD,
                                       args.n_layers_D, args.norm, use_sigmoid,
                                       args.init_type, args.init_gain,
                                       self.gpu_ids)

            self.fake_AB_pool = ImagePool(args.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not args.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=args.g_lr, betas=(args.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=args.d_lr, betas=(args.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input, args):
        AtoB = self.args.which_direction == 'AtoB'
        self.real_A = input[args.A_label if AtoB else args.B_label].to(self.device)
        self.real_B = input[args.B_label if AtoB else args.A_label].to(self.device)

    def forward(self):
        self.fake_B = self.G(self.real_A)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.D(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.D(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.D(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.args.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self, num_steps, overwrite_gen):
        self.forward()
        # update D
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
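The criterionGAN calls above map a boolean target onto a tensor matching the prediction and apply MSE (LSGAN) or BCE (vanilla GAN). A minimal sketch of such a loss under that assumption; the real networks.GANLoss registers target buffers and supports more modes:

import torch

class GANLoss(torch.nn.Module):
    # hedged sketch of the criterionGAN pattern used above
    def __init__(self, use_lsgan=True):
        super().__init__()
        # LSGAN regresses logits toward 0/1 with MSE; vanilla GAN uses BCE on logits
        self.loss = torch.nn.MSELoss() if use_lsgan else torch.nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        target = (torch.ones_like(prediction) if target_is_real
                  else torch.zeros_like(prediction))
        return self.loss(prediction, target)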
Beispiel #15
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan and not opt.no_sigmoid
            self.netD_1A = networks.define_D(opt.output_nc,
                                             opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D,
                                             opt.norm,
                                             use_sigmoid,
                                             opt.init_type,
                                             self.gpu_ids,
                                             one_out=True,
                                             opt=opt)
            self.netD_1B = networks.define_D(opt.input_nc,
                                             opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D,
                                             opt.norm,
                                             use_sigmoid,
                                             opt.init_type,
                                             self.gpu_ids,
                                             one_out=True,
                                             opt=opt)
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            one_out=False,
                                            opt=opt)
            self.netD_B = networks.define_D(opt.input_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            one_out=False,
                                            opt=opt)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                # each discriminator loads its own checkpoint
                self.load_network(self.netD_1A, 'D_1A', which_epoch)
                self.load_network(self.netD_1B, 'D_1B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_1A = torch.optim.Adam(self.netD_1A.parameters(),
                                                   lr=opt.lr,
                                                   betas=(opt.beta1, 0.999))
            self.optimizer_D_1B = torch.optim.Adam(self.netD_1B.parameters(),
                                                   lr=opt.lr,
                                                   betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            self.optimizers.append(self.optimizer_D_1A)
            self.optimizers.append(self.optimizer_D_1B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A, opt)
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
            networks.print_network(self.netD_1A, opt)
        print('-----------------------------------------------')
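Every training example here queries an ImagePool so the discriminators also see images from earlier generator states (the history trick from Shrivastava et al.). A minimal sketch with the usual 50/50 replacement rule; the reference implementation also clones tensors, which is omitted here:

import random
import torch

class ImagePool:
    # hedged sketch of the buffer queried above
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass fakes straight through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)  # fill the pool first
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx])  # return a stale fake ...
                self.images[idx] = image      # ... and store the new one
            else:
                out.append(image)
        return torch.cat(out, 0)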
Beispiel #16
class CycleGANModel(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG inception_9blocks' InceptionNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        assert is_train
        parser = super(CycleGANModel,
                       CycleGANModel).modify_commandline_options(
                           parser, is_train)
        parser.add_argument('--restore_G_A_path',
                            type=str,
                            default=None,
                            help='the path to restore the generator G_A')
        parser.add_argument('--restore_D_A_path',
                            type=str,
                            default=None,
                            help='the path to restore the discriminator D_A')
        parser.add_argument('--restore_G_B_path',
                            type=str,
                            default=None,
                            help='the path to restore the generator G_B')
        parser.add_argument('--restore_D_B_path',
                            type=str,
                            default=None,
                            help='the path to restore the discriminator D_B')
        parser.add_argument('--lambda_A',
                            type=float,
                            default=10.0,
                            help='weight for cycle loss (A -> B -> A)')
        parser.add_argument('--lambda_B',
                            type=float,
                            default=10.0,
                            help='weight for cycle loss (B -> A -> B)')
        parser.add_argument(
            '--lambda_identity',
            type=float,
            default=0.5,
            help='use identity mapping. '
            'Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. '
            'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
        )
        parser.add_argument(
            '--real_stat_A_path',
            type=str,
            required=True,
            help=
            'the path to load the ground-truth A images information to compute FID.'
        )
        parser.add_argument(
            '--real_stat_B_path',
            type=str,
            required=True,
            help=
            'the path to load the ground-truth B images information to compute FID.'
        )
        parser.set_defaults(norm='instance',
                            dataset_mode='unaligned',
                            batch_size=1,
                            ndf=64,
                            gan_mode='lsgan',
                            nepochs=100,
                            nepochs_decay=100,
                            save_epoch_freq=20)
        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        self.netD_A = networks.define_D(opt.output_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netD_B = networks.define_D(opt.input_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        if opt.lambda_identity > 0.0:
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best_A = False
        self.is_best_B = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)

    def set_single_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        if lambda_idt > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.loss_G_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.netG_B(self.real_A)
            self.loss_G_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_G_idt_A = 0
            self.loss_G_idt_B = 0

        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        self.loss_G_cycle_A = self.criterionCycle(self.rec_A,
                                                  self.real_A) * lambda_A
        self.loss_G_cycle_B = self.criterionCycle(self.rec_B,
                                                  self.real_B) * lambda_B
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_G_cycle_A + self.loss_G_cycle_B + self.loss_G_idt_A + self.loss_G_idt_B
        self.loss_G.backward()

    def optimize_parameters(self, steps):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.forward()
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

    def test_single_side(self, direction):
        generator = getattr(self, 'netG_%s' % direction[0])
        with torch.no_grad():
            self.fake_B = generator(self.real_A)

    def evaluate_model(self, step, save_image=False):
        ret = {}
        self.is_best_A = False
        self.is_best_B = False
        save_dir = os.path.join(self.opt.log_dir, 'eval', str(step))
        os.makedirs(save_dir, exist_ok=True)
        self.netG_A.eval()
        self.netG_B.eval()
        for direction in ['AtoB', 'BtoA']:
            eval_dataloader = getattr(self, 'eval_dataloader_' + direction)
            fakes, names = [], []
            cnt = 0
            for i, data_i in enumerate(tqdm(eval_dataloader)):
                self.set_single_input(data_i)
                self.test_single_side(direction)
                fakes.append(self.fake_B.cpu())
                for j in range(len(self.image_paths)):
                    short_path = ntpath.basename(self.image_paths[j])
                    name = os.path.splitext(short_path)[0]
                    names.append(name)
                    if cnt < 10 or save_image:
                        input_im = util.tensor2im(self.real_A[j])
                        fake_im = util.tensor2im(self.fake_B[j])
                        util.save_image(input_im,
                                        os.path.join(save_dir, direction,
                                                     'input', '%s.png' % name),
                                        create_dir=True)
                        util.save_image(fake_im,
                                        os.path.join(save_dir, direction,
                                                     'fake', '%s.png' % name),
                                        create_dir=True)
                    cnt += 1

            suffix = direction[-1]
            fid = get_fid(fakes,
                          self.inception_model,
                          getattr(self, 'npz_%s' % direction[-1]),
                          device=self.device,
                          batch_size=self.opt.eval_batch_size)
            if fid < getattr(self, 'best_fid_%s' % suffix):
                setattr(self, 'is_best_%s' % direction[0], True)
                setattr(self, 'best_fid_%s' % suffix, fid)
            fids = getattr(self, 'fids_%s' % suffix)
            fids.append(fid)
            if len(fids) > 3:
                fids.pop(0)
            ret['metric/fid_%s' % suffix] = fid
            ret['metric/fid_%s-mean' % suffix] = sum(fids) / len(fids)
            ret['metric/fid_%s-best' % suffix] = getattr(
                self, 'best_fid_%s' % suffix)

        self.netG_A.train()
        self.netG_B.train()
        return ret
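get_fid above compares InceptionV3 activation statistics of the generated images against the precomputed npz stats (arrays mu and sigma). The distance itself has a closed form over Gaussian moments; a hedged numpy sketch of that final step, following the usual pytorch-fid formulation:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # stabilize the matrix square root with a small diagonal offset
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop numerical imaginary parts
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)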
Beispiel #17
    def __init__(self, opt):
        # Initialize the Models

        # Global Variables
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain

        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        self.device = torch.device(
            f'cuda:{self.gpu_ids[0]}') if self.gpu_ids else torch.device('cpu')
        self.metric = 0  # used for learning rate policy 'plateau'

        self.G_AtoB = build_G(input_nc=opt.input_nc,
                              output_nc=opt.output_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.G_BtoA = build_G(input_nc=opt.output_nc,
                              output_nc=opt.input_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.net_names = ['G_AtoB', 'G_BtoA']

        if self.isTrain:
            self.D_A = build_D(input_nc=opt.output_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)
            self.D_B = build_D(input_nc=opt.input_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)

            self.net_names.append('D_A')
            self.net_names.append('D_B')

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.D_A.parameters(), self.D_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # lr Scheduler
            self.schedulers = [
                get_scheduler(optimizer,
                              lr_policy=opt.lr_policy,
                              n_epochs=opt.n_epochs,
                              lr_decay_iters=opt.lr_decay_iters,
                              epoch_count=opt.epoch_count,
                              n_epochs_decay=opt.n_epochs_decay)
                for optimizer in self.optimizers
            ]

        # Internal Variables
        self.real_A = None
        self.real_B = None
        self.image_paths = None
        self.fake_A = None
        self.fake_B = None
        self.rec_A = None
        self.rec_B = None
        self.idt_A = None
        self.idt_B = None
        self.loss_idt_A = None
        self.loss_idt_B = None
        self.loss_G_AtoB = None
        self.loss_G_BtoA = None
        self.cycle_loss_A = None
        self.cycle_loss_B = None
        self.loss_G = None
        self.loss_D_A = None
        self.loss_D_B = None

        # Printing the Networks
        for net_name in self.net_names:
            print(net_name, "\n", getattr(self, net_name))

        # Continue training, if isTrain
        if self.isTrain:
            if self.opt.ct > 0:
                print(f"Continue training from {self.opt.ct}")
                self.load_train_model(str(self.opt.ct))
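self.metric above exists only to feed a 'plateau' learning-rate policy. A hedged sketch of how an update step would consume it, assuming the schedulers may be ReduceLROnPlateau instances (the other policies ignore the metric):

import torch

def update_learning_rate(model):
    # hedged sketch: ReduceLROnPlateau needs the tracked metric, others do not
    for scheduler in model.schedulers:
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(model.metric)
        else:
            scheduler.step()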
Beispiel #18
class VanillaGanSingleArchitecture(BaseArchitecture):
    def __init__(self, args):
        super().__init__(args)

        if args.mode == 'train':
            self.D = define_D(args)
            self.D = self.D.to(self.device)

            self.fake_right_pool = ImagePool(50)

            self.criterion = define_generator_loss(args)
            self.criterion = self.criterion.to(self.device)
            self.criterionGAN = define_discriminator_loss(args)
            self.criterionGAN = self.criterionGAN.to(self.device)

            self.optimizer_G = optim.Adam(self.G.parameters(),
                                          lr=args.learning_rate)
            self.optimizer_D = optim.SGD(self.D.parameters(),
                                         lr=args.learning_rate)

        # Load the correct networks, depending on which mode we are in.
        if args.mode == 'train':
            self.model_names = ['G', 'D']
            self.optimizer_names = ['G', 'D']
        else:
            self.model_names = ['G']

        self.loss_names = ['G', 'G_MonoDepth', 'G_GAN', 'D']
        self.losses = {}

        if self.args.resume:
            self.load_checkpoint()

        if 'cuda' in self.device:
            torch.cuda.synchronize()

    def set_input(self, data):
        self.data = to_device(data, self.device)
        self.left = self.data['left_image']
        self.right = self.data['right_image']

    def forward(self):
        self.disps = self.G(self.left)

        # Prepare disparities
        disp_right_est = [d[:, 1, :, :].unsqueeze(1) for d in self.disps]
        self.disp_right_est = disp_right_est[0]

        self.fake_right = self.criterion.generate_image_right(
            self.left, self.disp_right_est)

    def backward_D(self):
        # Fake
        fake_pool = self.fake_right_pool.query(self.fake_right)
        pred_fake = self.D(fake_pool.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        # Real
        pred_real = self.D(self.right)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        # G should fake D
        pred_fake = self.D(self.fake_right)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        self.loss_G_MonoDepth = self.criterion(self.disps,
                                               [self.left, self.right])

        self.loss_G = self.loss_G_GAN * self.args.discriminator_w + self.loss_G_MonoDepth
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        # Update D.
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # Update G.
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def update_learning_rate(self, epoch, learning_rate):
        """ Sets the learning rate to the initial LR
            halved from epoch 30 and quartered from epoch 40.
        """
        if self.args.adjust_lr:
            if 30 <= epoch < 40:
                lr = learning_rate / 2
            elif epoch >= 40:
                lr = learning_rate / 4
            else:
                lr = learning_rate
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = lr
            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = lr

    def get_untrained_loss(self):
        # -- Generator
        loss_G_MonoDepth = self.criterion(self.disps, [self.left, self.right])
        fake_G_right = self.D(self.fake_right)
        loss_G_GAN = self.criterionGAN(fake_G_right, True)
        loss_G = loss_G_GAN * self.args.discriminator_w + loss_G_MonoDepth

        # -- Discriminator
        loss_D_fake = self.criterionGAN(self.D(self.fake_right), False)
        loss_D_real = self.criterionGAN(self.D(self.right), True)
        loss_D = (loss_D_fake + loss_D_real) * 0.5

        return {
            'G': loss_G.item(),
            'G_MonoDepth': loss_G_MonoDepth.item(),
            'G_GAN': loss_G_GAN.item(),
            'D': loss_D.item()
        }

    @property
    def architecture(self):
        return 'Single GAN Architecture'
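criterion.generate_image_right above reconstructs the right view by warping the left image with the estimated right disparity. A hedged sketch of such a warp with grid_sample, following monodepth-style conventions; the normalization (disparity measured in image widths) and the shift sign are assumptions, not taken from this example:

import torch
import torch.nn.functional as F

def generate_image_right(img_left, disp_right):
    # hedged sketch: horizontally resample the left image by the disparity;
    # torch.meshgrid(indexing=...) needs PyTorch >= 1.10
    b, _, h, w = img_left.shape
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w),
                            indexing='ij')
    grid = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(b, h, w, 2)
    grid = grid.to(img_left).clone()
    # the [-1, 1] sampling grid spans 2 units, hence the factor of 2
    grid[..., 0] += 2 * disp_right.squeeze(1)
    return F.grid_sample(img_left, grid, padding_mode='border', align_corners=True)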
Beispiel #19
class CrossModel(BaseModel):
    def __init__(self):
        super(CrossModel, self).__init__()
        self.model_names = 'cross_model'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        if is_train:
            parser.add_argument('--style_dropout',
                                type=float,
                                default=.5,
                                help='dropout ratio of style feature vector')
            parser.add_argument('--style_channels',
                                type=int,
                                default=32,
                                help='size of style channels')
            parser.add_argument(
                '--pool_size',
                type=int,
                default=150,
                help=
                'size of image pool, which is used to prevent model collapse')
            parser.add_argument('--lambda_E',
                                type=float,
                                default=0.0,
                                help='lambda of extra loss')
            # note: argparse's type=bool converts any non-empty string
            # (including 'False') to True, so these boolean flags are fragile
            parser.add_argument('--fast_forward',
                                type=bool,
                                default=False,
                                help='do not train the selector')
            parser.add_argument('--opt_betas1', type=float, default=.5)
            parser.add_argument('--opt_betas2', type=float, default=.999)
            parser.add_argument('--g_model_transnet',
                                type=str,
                                default='resnet')
            parser.add_argument('--g_model_transnet_n_blocks',
                                type=int,
                                default=8)
            parser.add_argument('--d_model_n_blocks', type=int, default=1)
            parser.add_argument('--d_model_use_dropout',
                                type=bool,
                                default=False)
            parser.add_argument('--selector_criterion_method',
                                type=str,
                                default='l1')
        return parser

    def init_vistool(self, opt):
        self.vistool = vistool.VisTool(env=opt.name + '_model')
        self.vistool.register_data('fake_imgs', 'images')
        self.vistool.register_data('styles', 'images')
        self.vistool.register_data('texts', 'images')
        self.vistool.register_data('diff_with_average', 'images')
        self.vistool.register_data('gmodel_sorted', 'images')
        self.vistool.register_data('dmodel_sorted', 'images')
        self.vistool.register_data('scores', 'array')
        self.vistool.register_data('dis_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('sel_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('rad_preds_L1_loss', 'scalar_ma')
        self.vistool.register_data('mod_preds_L1_loss', 'scalar_ma')
        self.vistool.register_window('dmodel_sorted',
                                     'images',
                                     source='dmodel_sorted')
        self.vistool.register_window('gmodel_sorted',
                                     'images',
                                     source='gmodel_sorted')
        if not opt.fast_forward:
            self.vistool.register_window('scores', 'bar', source='scores')
        self.vistool.register_window('preds_L1_loss',
                                     'lines',
                                     sources=[
                                         'dis_preds_L1_loss',
                                         'sel_preds_L1_loss',
                                         'rad_preds_L1_loss',
                                         'mod_preds_L1_loss'
                                     ])

    def initialize(self, opt):
        super(CrossModel, self).initialize(opt)
        self.fastForward = opt.fast_forward
        self.netG = GModel()
        self.netD = DModel()
        self.netG.initialize(opt)
        self.netD.initialize(opt)
        self.criterionGAN = GANLoss(False)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(opt.opt_betas1,
                                                   opt.opt_betas2))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.learn_rate,
                                            betas=(opt.opt_betas1,
                                                   opt.opt_betas2))
        self.pool = ImagePool(opt.pool_size)
        self.lambda_E = opt.lambda_E
        self.criterionSelector = find_criterion_using_name(
            opt.selector_criterion_method)()

        init_net(self)
        path = opt.checkpoints_dir + '/' + self.model_names + '.txt'
        with open(path, 'w') as f:
            f.write(str(self))
        logger.info("Model Structure has been written into %s" % path)

        self.init_vistool(opt)

    def set_input(self, texts, styles, target):
        self.texts = texts
        self.styles = styles
        self.real_img = target.unsqueeze(1)

    def forward(self):
        self.netG(self.texts, self.styles)
        self.fake_imgs = self.netG.basic_preds

    def backward_D(self):
        fake_all = self.fake_imgs
        real_all = self.real_img
        texts = self.texts
        styles = self.styles

        # A trick to prevent mode collapse: pool whole (fake, real, text, style)
        # tuples so D also trains on stale generator outputs
        img = torch.cat((fake_all, real_all, texts, styles), 1).detach()
        img = self.pool.query(img)
        # channel layout: tot fakes + 1 real + tot texts + tot styles = 3*tot + 1
        tot = (img.size(1) - 1) // 3
        fake_all, real_all, texts, styles = torch.split(
            img, [tot, 1, tot, tot], 1)
        fake_all = fake_all.contiguous()
        real_all = real_all.contiguous()

        pred_fake = self.netD(fake_all.detach(), texts, styles)
        pred_real = self.netD(real_all.detach(), texts, styles)

        self.loss_fake = self.criterionGAN(pred_fake, False)
        self.loss_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_fake + self.loss_real) * .5
        self.loss_D.backward()

    def backward_G(self):
        fake_all = self.fake_imgs
        pred_fake = self.netD(fake_all, self.texts, self.styles)
        self.loss_G = self.criterionGAN(pred_fake, True)  # GAN loss
        self.loss_GSE = self.loss_G
        if not self.fastForward:
            pred_result = pred_fake.detach()
            self.loss_S = (pred_result -
                           self.netG.basic_score).abs().mean()  # selector loss
            self.loss_GSE += self.loss_S
            self.vistool.update(
                'scores',
                torch.stack((pred_result[0], self.netG.basic_score[0]), 1))
        self.loss_E = self.netG.extra_loss  # Extra loss
        self.loss_GSE += self.loss_E * self.lambda_E
        self.loss_GSE.backward()

    def optimize_parameters(self):
        self.forward()
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        if self.optm_d:
            self.optimizer_D.step()

        self.set_requires_grad(self.netD, False)
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        if self.optm_g:
            self.optimizer_G.step()

        bs, tot, W, H = self.texts.shape
        score = self.netG.basic_score + self.netD.basic_score * .5
        rank = torch.sort(score, 1, descending=True)[1]
        model_preds = torch.gather(
            self.netG.basic_preds, 1,
            rank.view(bs, tot, 1, 1).expand(bs, tot, W, H))

        self.vistool.update('gmodel_sorted', self.netG.best_preds[0] * .5 + .5)
        self.vistool.update('dmodel_sorted', self.netD.dis_preds[0] * .5 + .5)
        self.vistool.update('diff_with_average', self.netG.diff_with_average)
        self.vistool.update(
            'mod_preds_L1_loss',
            self.criterionSelector(model_preds[:, 0, :, :],
                                   self.real_img[:, 0, :, :]).mean())
        self.vistool.update(
            'dis_preds_L1_loss',
            self.criterionSelector(self.netD.dis_preds[:, 0, :, :],
                                   self.real_img[:, 0, :, :]).mean())
        self.vistool.update(
            'sel_preds_L1_loss',
            self.criterionSelector(self.netG.best_preds[:, 0, :, :],
                                   self.real_img[:, 0, :, :]).mean())
        idx = random.randint(0, self.netG.best_preds.size(1) - 1)
        self.vistool.update(
            'rad_preds_L1_loss',
            self.criterionSelector(self.netG.best_preds[:, idx, :, :],
                                   self.real_img[:, 0, :, :]).mean())
        self.vistool.update('fake_imgs', self.fake_imgs[0] * .5 + .5)
        self.vistool.update('styles', self.styles[0] * .5 + .5)
        self.vistool.update('texts', self.texts[0] * .5 + .5)
        self.vistool.sync()
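Several optimize_parameters methods above toggle set_requires_grad on the discriminator so its weights are frozen while the generator updates. A minimal sketch of that helper as it usually appears in these repos:

def set_requires_grad(nets, requires_grad=False):
    # hedged sketch: accept a single net or a list, then flip the flag so
    # no gradients are computed for the frozen networks
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad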
Beispiel #20
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.input_nc
        self.para = opt.trade_off

        # define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG = networks.define_G(netG_input_nc,
                                      opt.output_nc,
                                      opt.ngf,
                                      opt.netG,
                                      opt.n_downsample_global,
                                      opt.n_blocks_global,
                                      opt.n_local_enhancers,
                                      opt.n_blocks_local,
                                      opt.norm,
                                      gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # netD_input_nc = input_nc + opt.output_nc
            netD_input_nc = opt.output_nc
            self.netD = networks.define_D(netD_input_nc,
                                          opt.ndf,
                                          opt.n_layers_D,
                                          opt.norm,
                                          use_sigmoid,
                                          opt.num_D,
                                          not opt.no_ganFeat_loss,
                                          gpu_ids=self.gpu_ids)

        # Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc,
                                          opt.feat_num,
                                          opt.nef,
                                          'encoder',
                                          opt.n_downsample_E,
                                          norm=opt.norm,
                                          gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch,
                                  pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch,
                                  pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and len(self.gpu_ids) > 1:
                raise NotImplementedError(
                    "Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss,
                                                     not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()

            # AWAN
            self.criterionCSS = networks.CSS()

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_CSS',
                                               'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print(
                    '------------- Only training the local enhancer network (for %d epochs) ------------'
                    % opt.niter_fix_global)
                print('The layers that are finetuned are ',
                      sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params,
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
Beispiel #21
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, opt.dropout_rate,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                        opt.n_layers_D, opt.norm,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(
            opt.pool_size
        )  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(
            opt.pool_size
        )  # create image buffer to store previously generated images

        # define loss functions
        self.criterionGAN = models.modules.loss.GANLoss(opt.gan_mode).to(
            self.device)  # define GAN loss.
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        if 'cityscapes' in opt.dataroot:
            self.drn_model = DRNSeg('drn_d_105', 19, pretrained=False)
            util.load_network(self.drn_model, opt.drn_path, verbose=False)
            if len(opt.gpu_ids) > 0:
                self.drn_model = nn.DataParallel(self.drn_model, opt.gpu_ids)
            self.drn_model.eval()

        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)
Beispiel #22
class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # define tensors
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize,
                                   opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize,
                                   opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B,
                                          self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)
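
The pix2pix-style model above (and most of the later examples) routes generated images through an ImagePool before they reach the discriminator. The helper itself is not shown in these snippets; the following is a minimal sketch of the usual history buffer (Shrivastava et al.), assuming the standard query() semantics from the CycleGAN reference code.

import random

import torch


class ImagePool:
    """History buffer of previously generated images.

    query() returns a mix of the current fakes and stored ones, so the
    discriminator also sees older generator outputs, not only the newest.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass images through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)  # fill the buffer first
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())  # return a stored fake
                self.images[idx] = image  # and stash the current one
            else:
                out.append(image)  # pass the current fake through unchanged
        return torch.cat(out, 0)
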
Example #23
class CycleGan:
    def __init__(self, opt):
        # Initialize the Models

        # Global Variables
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain

        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        self.device = torch.device(
            f'cuda:{self.gpu_ids[0]}') if self.gpu_ids else torch.device('cpu')
        self.metric = 0  # used for learning rate policy 'plateau'

        self.G_AtoB = build_G(input_nc=opt.input_nc,
                              output_nc=opt.output_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.G_BtoA = build_G(input_nc=opt.output_nc,
                              output_nc=opt.input_nc,
                              ngf=opt.ngf,
                              norm=opt.norm,
                              padding_type=opt.padding_type,
                              use_dropout=not opt.no_dropout,
                              n_blocks=opt.n_blocks_G,
                              init_type=opt.init_type,
                              init_gain=opt.init_gain,
                              gpu_ids=opt.gpu_ids)

        self.net_names = ['G_AtoB', 'G_BtoA']

        if self.isTrain:
            self.D_A = build_D(input_nc=opt.output_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)
            self.D_B = build_D(input_nc=opt.input_nc,
                               ndf=opt.ndf,
                               n_layers=opt.n_layers_D,
                               norm=opt.norm,
                               init_type=opt.init_type,
                               init_gain=opt.init_gain,
                               gpu_ids=opt.gpu_ids)

            self.net_names.append('D_A')
            self.net_names.append('D_B')

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.D_A.parameters(), self.D_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            # lr Scheduler
            self.schedulers = [
                get_scheduler(optimizer,
                              lr_policy=opt.lr_policy,
                              n_epochs=opt.n_epochs,
                              lr_decay_iters=opt.lr_decay_iters,
                              epoch_count=opt.epoch_count,
                              n_epochs_decay=opt.n_epochs_decay)
                for optimizer in self.optimizers
            ]

        # Internal Variables
        self.real_A = None
        self.real_B = None
        self.image_paths = None
        self.fake_A = None
        self.fake_B = None
        self.rec_A = None
        self.rec_B = None
        self.idt_A = None
        self.idt_B = None
        self.loss_idt_A = None
        self.loss_idt_B = None
        self.loss_G_AtoB = None
        self.loss_G_BtoA = None
        self.cycle_loss_A = None
        self.cycle_loss_B = None
        self.loss_G = None
        self.loss_D_A = None
        self.loss_D_B = None

        # Printing the Networks
        for net_name in self.net_names:
            print(net_name, "\n", getattr(self, net_name))

        # Continue training, if isTrain
        if self.isTrain:
            if self.opt.ct > 0:
                print(f"Continue training from {self.opt.ct}")
                self.load_train_model(str(self.opt.ct))

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def feed_input(self, x):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        :type x: dict
        :param x: include the data itself and its metadata information.
        x should have the structure {'A': Tensor Images, 'B': Tensor Images,
        'A_paths': paths of the A Images, 'B_paths': paths of the B Images}

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = x['A' if AtoB else 'B'].to(self.device)
        self.real_B = x['B' if AtoB else 'A'].to(self.device)
        self.image_paths = x['A_paths' if AtoB else 'B_paths']

    def optimize_parameters(self):
        # Forward
        self.forward()

        # Train Generators
        self._set_requires_grad(
            [self.D_A, self.D_B],
            False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights

        # Train Discriminators
        self._set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

    def forward(self):
        """Run forward pass
        Called by both functions <optimize_parameters> and <test>
        """
        self.fake_B = self.G_AtoB(self.real_A)  # G_A(A)
        self.rec_A = self.G_BtoA(self.fake_B)  # G_B(G_A(A))
        self.fake_A = self.G_BtoA(self.real_B)  # G_B(B)
        self.rec_B = self.G_AtoB(self.fake_A)  # G_A(G_B(B))

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity  # assumes the standard CycleGAN identity flag
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss: compute it here so loss_idt_A/B are never None
        # when they are summed into loss_G below.
        if lambda_idt > 0:
            self.idt_A = self.G_AtoB(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            self.idt_B = self.G_BtoA(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_AtoB(A))
        self.loss_G_AtoB = self.criterionGAN(self.D_A(self.fake_B), True)

        # GAN loss D_B(G_BtoA(B))
        self.loss_G_BtoA = self.criterionGAN(self.D_B(self.fake_A), True)

        # Forward cycle loss || G_B(G_A(A)) - A||
        self.cycle_loss_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A

        # Backward cycle loss || G_A(G_B(B)) - B||
        self.cycle_loss_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B

        # combined loss and calculate gradients
        self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.cycle_loss_A + self.cycle_loss_B
        self.loss_G += self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        :param netD: the discriminator D
        :param real: real images
        :param fake: images generated by a generator
        :return: Loss
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.D_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.D_B, self.real_A, fake_A)

    def _set_requires_grad(self,
                           nets: List[nn.Module],
                           requires_grad: bool = False) -> None:
        """
        Set requires_grad for every network in the list; pass False to avoid unnecessary computations on frozen discriminators
        :param nets: a list of networks
        :param requires_grad: whether the networks require gradients or not
        """
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def train(self):
        """Make models train mode during test time"""
        self.G_AtoB.train()
        self.G_BtoA.train()

        if self.isTrain:
            self.D_A.train()
            self.D_B.train()

    def eval(self):
        """Make models eval mode during test time"""
        self.G_AtoB.eval()
        self.G_BtoA.eval()

        if self.isTrain:
            self.D_A.eval()
            self.D_B.eval()

    def compute_visuals(self, bidirectional=False):
        """ Computes the Visual output data from the model
        :type bidirectional: bool
        :param bidirectional: if true, Calculate both AtoB and BtoA, else calculate AtoB
        """
        self.eval()
        with torch.no_grad():
            self.fake_B = self.G_AtoB(self.real_A)
            if bidirectional:
                self.fake_A = self.G_BtoA(self.real_B)

    def _load_objects(self, file_names: List[str], object_names: List[str]):
        """Load objects from file

        :param file_names: names of the files to load
        :param object_names: names of the attributes the loaded state dicts are assigned to

        file_names and object_names must be in the same order
        """
        for file_name, object_name in zip(file_names, object_names):
            model_name = os.path.join(self.save_dir, file_name)
            print(f"Loading {object_name} from {model_name}")
            state_dict = torch.load(model_name, map_location=self.device)

            net = getattr(self, object_name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            net.load_state_dict(state_dict)

    def load_networks(self, initials, load_D=False):
        """ Loading Models
        Loads from /checkpoint_dir/name/{initials}_net_G_AtoB.pt
        :type initials: str
        :param initials: The initials of the model
        :type load_D: bool
        :param load_D: Is loading D or not
        """
        file_names = [f"{initials}_net_G_AtoB.pt", f"{initials}_net_G_BtoA.pt"]
        if load_D:
            file_names.append(f"{initials}_net_D_A.pt")
            file_names.append(f"{initials}_net_D_B.pt")

        object_names = ['G_AtoB', 'G_BtoA'] if not load_D else [
            'G_AtoB', 'G_BtoA', 'D_A', 'D_B'
        ]

        self._load_objects(file_names, object_names)

    def load_lr_schedulers(self, initials):
        s_file_name_0 = os.path.join(self.save_dir,
                                     f"{initials}_scheduler_0.pt")
        s_file_name_1 = os.path.join(self.save_dir,
                                     f"{initials}_scheduler_1.pt")

        print(f"Loading scheduler-0 from {s_file_name_0}")
        self.schedulers[0].load_state_dict(
            torch.load(s_file_name_0, map_location=self.device))
        print(f"Loading scheduler-1 from {s_file_name_1}")
        self.schedulers[1].load_state_dict(
            torch.load(s_file_name_1, map_location=self.device))

    def load_train_model(self, initials):
        """ Loading Models for training purpose

        :type initials: str
        :param initials: Initials of the object names
        """
        self.load_networks(initials, load_D=True)

        optim_file_names = [f"{initials}_optim_G.pt", f"{initials}_optim_D.pt"]
        optim_object_names = ['optimizer_G', 'optimizer_D']

        self._load_objects(optim_file_names, optim_object_names)

        self.load_lr_schedulers(initials)

    def save_networks(self, epoch):
        """Save models

        :type epoch: str
        :param epoch: Current Epoch (prefix for the name)
        """
        for net_name in self.net_names:
            net = getattr(self, net_name)
            self.save_network(net, net_name, epoch)

    def save_optimizers_and_scheduler(self, epoch):
        """Save optimizers

        :type epoch: str
        :param epoch: Current Epoch (prefix for the name)
        """
        # Saving Optimizers
        self.save_optimizer_scheduler(self.optimizer_G, f"{epoch}_optim_G.pt")
        self.save_optimizer_scheduler(self.optimizer_D, f"{epoch}_optim_D.pt")

        # Saving Schedulers
        self.save_optimizer_scheduler(self.schedulers[0],
                                      f"{epoch}_scheduler_0.pt")
        self.save_optimizer_scheduler(self.schedulers[1],
                                      f"{epoch}_scheduler_1.pt")

    def save_optimizer_scheduler(self, optim_or_scheduler, name):
        """Save a single optimizer

        :param optim_or_scheduler: The optimizer object
        :type name: str
        :param name: Name of the optimizer
        """
        save_path = os.path.join(self.save_dir, name)

        torch.save(optim_or_scheduler.state_dict(), save_path)

    def save_network(self, net, net_name, epoch):
        save_filename = '%s_net_%s.pt' % (epoch, net_name)
        if self.opt.isCloud:
            save_path = save_filename
        else:
            save_path = os.path.join(self.save_dir, save_filename)

        if len(self.gpu_ids) > 0 and torch.cuda.is_available():
            torch.save(net.module.cpu().state_dict(), save_path)
            net.cuda(self.gpu_ids[0])
        else:
            torch.save(net.cpu().state_dict(), save_path)

    def get_current_losses(self) -> dict:
        """Get the Current Losses

        :return: Losses
        """
        if isinstance(self.loss_idt_A, (int, float)):
            idt_loss_A = self.loss_idt_A
        else:
            idt_loss_A = self.loss_idt_A.item()

        if isinstance(self.loss_idt_B, (int, float)):
            idt_loss_B = self.loss_idt_B
        else:
            idt_loss_B = self.loss_idt_B.item()
        return collections.OrderedDict({
            'loss_idt_A': idt_loss_A,
            'loss_idt_B': idt_loss_B,
            'loss_D_A': self.loss_D_A.item(),
            'loss_D_B': self.loss_D_B.item(),
            'loss_G_AtoB': self.loss_G_AtoB.item(),
            'loss_G_BtoA': self.loss_G_BtoA.item(),
            'cycle_loss_A': self.cycle_loss_A.item(),
            'cycle_loss_B': self.cycle_loss_B.item()
        })

    def get_current_image_path(self):
        """
        :return: The current image path
        """
        return self.image_paths

    def get_current_visuals(self):
        """Get the Current Produced Images

        :return: Images {real_A, real_B, fake_A, fake_B, rec_A, rec_B}
        :rtype: dict
        """
        r = collections.OrderedDict({
            'real_A': self.real_A,
            'real_B': self.real_B
        })

        if self.fake_A is not None:
            r['fake_A'] = self.fake_A
        if self.fake_B is not None:
            r['fake_B'] = self.fake_B
        if self.rec_A is not None:
            r['rec_A'] = self.rec_A
        if self.rec_B is not None:
            r['rec_B'] = self.rec_B
        return r
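
For orientation, here is a hypothetical driver loop for the CycleGan class above. The dataloader and the exact opt fields (epoch_count, n_epochs, n_epochs_decay) are assumptions inferred from how the constructor reads them, not code from the source.

# Hypothetical usage sketch (not part of the source repository).
model = CycleGan(opt)
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    model.train()
    for batch in dataloader:  # assumed to yield {'A', 'B', 'A_paths', 'B_paths'}
        model.feed_input(batch)
        model.optimize_parameters()
    print(model.get_current_losses())
    model.update_learning_rate()  # step the lr schedulers once per epoch
    model.save_networks(str(epoch))
    model.save_optimizers_and_scheduler(str(epoch))
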
Example #24
def create_image_pools(data_pool_size):
    fake_A_pool = ImagePool(pool_size=data_pool_size)
    fake_B_pool = ImagePool(pool_size=data_pool_size)

    return fake_A_pool, fake_B_pool
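
A one-line usage note: the helper simply pairs one pool per translation direction.

fake_A_pool, fake_B_pool = create_image_pools(data_pool_size=50)
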
Example #25
    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert opt.isTrain
        assert opt.direction == 'AtoB'
        assert opt.dataset_mode == 'unaligned'
        BaseModel.__init__(self, opt)
        self.loss_names = [
            'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
            'G_idt_B'
        ]
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.netG,
                                        opt.norm,
                                        opt.dropout_rate,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        self.netD_A = networks.define_D(opt.output_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)
        self.netD_B = networks.define_D(opt.input_nc,
                                        opt.ndf,
                                        opt.netD,
                                        opt.n_layers_D,
                                        opt.norm,
                                        opt.init_type,
                                        opt.init_gain,
                                        self.gpu_ids,
                                        opt=opt)

        if opt.lambda_identity > 0.0:
            assert (opt.input_nc == opt.output_nc)
        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        self.eval_dataloader_AtoB = create_eval_dataloader(self.opt,
                                                           direction='AtoB')
        self.eval_dataloader_BtoA = create_eval_dataloader(self.opt,
                                                           direction='BtoA')

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model = InceptionV3([block_idx])
        self.inception_model.to(self.device)
        self.inception_model.eval()

        self.best_fid_A, self.best_fid_B = 1e9, 1e9
        self.best_mIoU = -1e9
        self.fids_A, self.fids_B = [], []
        self.mIoUs = []
        self.is_best_A = False
        self.is_best_B = False
        self.npz_A = np.load(opt.real_stat_A_path)
        self.npz_B = np.load(opt.real_stat_B_path)
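
The example above tracks best_fid_A/best_fid_B against Inception statistics loaded from real_stat_*_path. As a reminder of what that metric is, here is a sketch of the Fréchet distance between the two activation Gaussians; it assumes the .npz files hold mu and sigma of real-image pool3 activations, as in the common pytorch-fid setup.

import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """FID between N(mu1, sigma1) and N(mu2, sigma2) of 2048-d activations."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # add a small jitter when the product is near-singular
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2)
            - 2 * np.trace(covmean))
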
Example #26
    def train(self):
        """
        Train the MaskShadowGAN model by starting from a saved checkpoint or from
        the beginning.
        """
        if self.opt.load_model is not None:
            checkpoint = 'checkpoints/' + self.opt.load_model
        else:
            checkpoint_name = datetime.now().strftime("%d%m%Y-%H%M")
            checkpoint = 'checkpoints/{}'.format(checkpoint_name)

            try:
                os.makedirs(checkpoint)
            except os.error:
                print("Failed to make new checkpoint directory.")
                sys.exit(1)

        # build the Mask-ShadowGAN graph
        graph = tf.Graph()
        with graph.as_default():
            maskshadowgan = MaskShadowGANModel(self.opt, training=True)
            dataA_iter, dataB_iter, realA, realB = maskshadowgan.generate_dataset(
            )
            fakeA, fakeB, optimizers, Gen_loss, D_A_loss, D_B_loss = maskshadowgan.build(
            )
            saver = tf.train.Saver(max_to_keep=2)
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(checkpoint, graph)

        # create image pools for holding previously generated images
        fakeA_pool = ImagePool(self.opt.pool_size)
        fakeB_pool = ImagePool(self.opt.pool_size)

        # create queue to hold generated shadow masks
        mask_queue = MaskQueue(self.opt.queue_size)

        with tf.Session(graph=graph) as sess:
            if self.opt.load_model is not None:  # restore graph and variables
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint))
                ckpt = tf.train.get_checkpoint_state(checkpoint)
                step = int(
                    os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
            else:
                sess.run(tf.global_variables_initializer())
                step = 0

            max_steps = self.opt.niter + self.opt.niter_decay

            # initialize data iterators
            sess.run([dataA_iter.initializer, dataB_iter.initializer])

            try:
                while step < max_steps:
                    try:
                        realA_img, realB_img = sess.run([realA, realB
                                                         ])  # fetch inputs

                        # generate shadow free image from shadow image
                        fakeB_img = sess.run(
                            fakeB, feed_dict={maskshadowgan.realA: realA_img})

                        # generate shadow mask and add to mask queue
                        mask_queue.insert(mask_generator(realA_img, fakeB_img))
                        rand_mask = mask_queue.rand_item()

                        # generate shadow image from shadow free image and shadow mask
                        fakeA_img = sess.run(fakeA,
                                             feed_dict={
                                                 maskshadowgan.realB:
                                                 realB_img,
                                                 maskshadowgan.rand_mask:
                                                 rand_mask
                                             })

                        # calculate losses for the generators and discriminators and minimize them
                        (_, Gen_loss_val, D_B_loss_val, D_A_loss_val,
                         summary_str) = sess.run(
                             [optimizers, Gen_loss, D_B_loss, D_A_loss, summary],
                             feed_dict={maskshadowgan.realA: realA_img,
                                        maskshadowgan.realB: realB_img,
                                        maskshadowgan.rand_mask: rand_mask,
                                        maskshadowgan.last_mask: mask_queue.last_item(),
                                        maskshadowgan.fakeA: fakeA_pool.query(fakeA_img),
                                        maskshadowgan.fakeB: fakeB_pool.query(fakeB_img)})

                        writer.add_summary(summary_str, step)
                        writer.flush()

                        # display the losses of the Generators and Discriminators
                        if step % self.opt.display_frequency == 0:
                            print('Step {}:'.format(step))
                            print('Gen_loss: {}'.format(Gen_loss_val))
                            print('D_B_loss: {}'.format(D_B_loss_val))
                            print('D_A_loss: {}'.format(D_A_loss_val))

                        # save a checkpoint of the model to the `checkpoints` directory
                        if step % self.opt.checkpoint_frequency == 0:
                            save_path = saver.save(sess,
                                                   checkpoint + '/model.ckpt',
                                                   global_step=step)
                            print("Model saved as {}".format(save_path))

                        step += 1
                    except tf.errors.OutOfRangeError:  # reinitialize iterators after every full pass through the dataset
                        sess.run(
                            [dataA_iter.initializer, dataB_iter.initializer])
            except KeyboardInterrupt:  # save training before exiting
                print(
                    "Saving models training progress to the `checkpoints` directory..."
                )
                save_path = saver.save(sess,
                                       checkpoint + '/model.ckpt',
                                       global_step=step)
                print("Model saved as {}".format(save_path))
                sys.exit(0)
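
MaskQueue is not defined in this snippet; a minimal sketch of the assumed semantics (a bounded buffer of generated shadow masks with random and last-item access) could look like this.

import random
from collections import deque


class MaskQueue:
    """Bounded buffer of shadow masks (assumed semantics, not the original)."""

    def __init__(self, queue_size):
        self.queue = deque(maxlen=queue_size)  # oldest masks drop off the front

    def insert(self, mask):
        self.queue.append(mask)

    def rand_item(self):
        return random.choice(self.queue)

    def last_item(self):
        return self.queue[-1]
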
Example #27
class CycleMultiDModel(CycleGANModel):
    def name(self):
        return 'CycleMultiDModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        opt=opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan and not opt.no_sigmoid
            self.netD_1A = networks.define_D(opt.output_nc,
                                             opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D,
                                             opt.norm,
                                             use_sigmoid,
                                             opt.init_type,
                                             self.gpu_ids,
                                             one_out=True,
                                             opt=opt)
            self.netD_1B = networks.define_D(opt.input_nc,
                                             opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D,
                                             opt.norm,
                                             use_sigmoid,
                                             opt.init_type,
                                             self.gpu_ids,
                                             one_out=True,
                                             opt=opt)
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            one_out=False,
                                            opt=opt)
            self.netD_B = networks.define_D(opt.input_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            one_out=False,
                                            opt=opt)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_1A, 'D_1A', which_epoch)
                self.load_network(self.netD_1B, 'D_1B', which_epoch)
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_1A = torch.optim.Adam(self.netD_1A.parameters(),
                                                   lr=opt.lr,
                                                   betas=(opt.beta1, 0.999))
            self.optimizer_D_1B = torch.optim.Adam(self.netD_1B.parameters(),
                                                   lr=opt.lr,
                                                   betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            self.optimizers.append(self.optimizer_D_1A)
            self.optimizers.append(self.optimizer_D_1B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A, opt)
        if self.isTrain:
            networks.print_network(self.netD_A, opt)
            networks.print_network(self.netD_1A, opt)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG_A.forward(self.real_A)
        self.rec_A = self.netG_B.forward(self.fake_B)

        self.real_B = Variable(self.input_B, volatile=True)
        self.fake_A = self.netG_B.forward(self.real_B)
        self.rec_B = self.netG_A.forward(self.fake_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD.forward(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD.forward(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D_real, loss_D_fake

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A_real, self.loss_D_A_fake = self.backward_D_basic(
            self.netD_A, self.real_B, fake_B)
        self.loss_D_1A_real, self.loss_D_1A_fake = self.backward_D_basic(
            self.netD_1A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B_real, self.loss_D_B_fake = self.backward_D_basic(
            self.netD_B, self.real_A, fake_A)
        self.loss_D_1B_real, self.loss_D_1B_fake = self.backward_D_basic(
            self.netD_1B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.identity
        lambda_rec = self.opt.lambda_rec
        lambda_adv = self.opt.lambda_adv

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A.forward(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_rec * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B.forward(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_rec * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss
        # D_A(G_A(A))
        self.fake_B = self.netG_A.forward(self.real_A)
        pred_fake = self.netD_A.forward(self.fake_B)
        pred_1fake = self.netD_1A.forward(self.fake_B)
        self.loss_G_A = (self.criterionGAN(pred_fake, True) +
                         self.criterionGAN(pred_1fake, True)) * lambda_adv
        # D_B(G_B(B))
        self.fake_A = self.netG_B.forward(self.real_B)
        pred_fake = self.netD_B.forward(self.fake_A)
        pred_1fake = self.netD_1B.forward(self.fake_A)
        self.loss_G_B = (self.criterionGAN(pred_fake, True) +
                         self.criterionGAN(pred_1fake, True)) * lambda_adv
        # Forward cycle loss
        self.rec_A = self.netG_B.forward(self.fake_B)
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_rec
        # Backward cycle loss
        self.rec_B = self.netG_A.forward(self.fake_A)
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_rec
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        D_A = self.loss_D_A_real.data[0] + self.loss_D_A_fake.data[0]
        D_1A = self.loss_D_1A_real.data[0] + self.loss_D_1A_fake.data[0]
        G_A = self.loss_G_A.data[0]
        Cyc_A = self.loss_cycle_A.data[0]
        D_B = self.loss_D_B_real.data[0] + self.loss_D_B_fake.data[0]
        D_1B = self.loss_D_1B_real.data[0] + self.loss_D_1B_fake.data[0]
        G_B = self.loss_G_B.data[0]
        Cyc_B = self.loss_cycle_B.data[0]
        if self.opt.identity > 0.0:
            idt_A = self.loss_idt_A.data[0]
            idt_B = self.loss_idt_B.data[0]
            return OrderedDict([('D_A', D_A), ('D_1A', D_1A), ('G_A', G_A),
                                ('Cyc_A', Cyc_A), ('idt_A', idt_A),
                                ('D_B', D_B), ('D_1B', D_1B), ('G_B', G_B),
                                ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
        else:
            return OrderedDict([('D_A', D_A), ('D_1A', D_1A), ('G_A', G_A),
                                ('Cyc_A', Cyc_A), ('D_B', D_B), ('D_1B', D_1B),
                                ('G_B', G_B), ('Cyc_B', Cyc_B)])

    def get_current_lr(self):
        lr_A = self.optimizer_D_A.param_groups[0]['lr']
        lr_B = self.optimizer_D_B.param_groups[0]['lr']
        lr_G = self.optimizer_G.param_groups[0]['lr']
        return OrderedDict([('D_A', lr_A), ('D_B', lr_B), ('G', lr_G)])

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        rec_A = util.tensor2im(self.rec_A.data)
        real_B = util.tensor2im(self.real_B.data)
        fake_A = util.tensor2im(self.fake_A.data)
        rec_B = util.tensor2im(self.rec_B.data)
        if self.opt.isTrain and self.opt.identity > 0.0:
            idt_A = util.tensor2im(self.idt_A.data)
            idt_B = util.tensor2im(self.idt_B.data)
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('idt_B', idt_B),
                                ('real_B', real_B), ('fake_A', fake_A),
                                ('rec_B', rec_B), ('idt_A', idt_A)])
        else:
            return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                ('rec_A', rec_A), ('real_B', real_B),
                                ('fake_A', fake_A), ('rec_B', rec_B)])

    def get_network_params(self):
        return [('G_A', util.get_params(self.netG_A)),
                ('G_B', util.get_params(self.netG_B)),
                ('D_A', util.get_params(self.netD_A)),
                ('D_B', util.get_params(self.netD_B)),
                ('D_1A', util.get_params(self.netD_1A)),
                ('D_1B', util.get_params(self.netD_1B))]

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
        self.save_network(self.netD_1A, 'D_1A', label, self.gpu_ids)
        self.save_network(self.netD_1B, 'D_1B', label, self.gpu_ids)
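
Nearly every example above calls criterionGAN(pred, True/False). The helper lives in the repos' networks modules and is not reproduced here; the following is a minimal sketch of the common behavior, assuming the LSGAN/vanilla split (the originals also accept a tensor factory and custom target labels).

import torch
import torch.nn as nn


class GANLoss(nn.Module):
    """Expand a boolean target to the prediction's shape and apply
    MSE (LSGAN) or BCE-with-logits (vanilla GAN)."""

    def __init__(self, use_lsgan=True):
        super().__init__()
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    def forward(self, pred, target_is_real):
        target = torch.full_like(pred, 1.0 if target_is_real else 0.0)
        return self.loss(pred, target)
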
Example #28
def do_train(Cfg, model_G, model_Dip, model_Dii, model_D_reid, train_loader,
             val_loader, optimizerG, optimizerDip, optimizerDii, GAN_loss,
             L1_loss, ReID_loss, schedulerG, schedulerDip, schedulerDii):
    log_period = Cfg.SOLVER.LOG_PERIOD
    checkpoint_period = Cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = Cfg.SOLVER.EVAL_PERIOD
    output_dir = Cfg.DATALOADER.LOG_DIR
    # TODO: move the following constants into cfg
    epsilon = 0.00001
    margin = 0.4
    ####################################
    device = "cuda"
    epochs = Cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger('pose-transfer-gan.train')
    logger.info('Start training')

    if device:
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for training'.format(
                torch.cuda.device_count()))
            model_G = nn.DataParallel(model_G)
            model_Dii = nn.DataParallel(model_Dii)
            model_Dip = nn.DataParallel(model_Dip)
        model_G.to(device)
        model_Dip.to(device)
        model_Dii.to(device)
        model_D_reid.to(device)
    lossG_meter = AverageMeter()
    lossDip_meter = AverageMeter()
    lossDii_meter = AverageMeter()
    distDreid_meter = AverageMeter()
    fake_ii_pool = ImagePool(50)
    fake_ip_pool = ImagePool(50)

    #evaluator = R1_mAP(num_query, max_rank=50, feat_norm=Cfg.TEST.FEAT_NORM)
    #train
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        lossG_meter.reset()
        lossDip_meter.reset()
        lossDii_meter.reset()
        distDreid_meter.reset()
        schedulerG.step()
        schedulerDip.step()
        schedulerDii.step()

        model_G.train()
        model_Dip.train()
        model_Dii.train()
        model_D_reid.eval()
        for iter, batch in enumerate(train_loader):
            img1 = batch['img1'].to(device)
            pose1 = batch['pose1'].to(device)
            img2 = batch['img2'].to(device)
            pose2 = batch['pose2'].to(device)
            input_G = (img1, pose2)

            #forward
            fake_img2 = model_G(input_G)
            optimizerG.zero_grad()

            #train G
            input_Dip = torch.cat((fake_img2, pose2), 1)
            pred_fake_ip = model_Dip(input_Dip)
            loss_G_ip = GAN_loss(pred_fake_ip, True)
            input_Dii = torch.cat((fake_img2, img1), 1)
            pred_fake_ii = model_Dii(input_Dii)
            loss_G_ii = GAN_loss(pred_fake_ii, True)

            loss_L1, _, _ = L1_loss(fake_img2, img2)

            feats_real = model_D_reid(img2)
            feats_fake = model_D_reid(fake_img2)

            dist_cos = torch.acos(
                torch.clamp(torch.sum(feats_real * feats_fake, 1),
                            -1 + epsilon, 1 - epsilon))

            same_id_tensor = torch.FloatTensor(
                dist_cos.size()).fill_(1).to('cuda')
            dist_cos_margin = torch.max(dist_cos - margin,
                                        torch.zeros_like(dist_cos))
            loss_reid = ReID_loss(dist_cos_margin, same_id_tensor)
            factor = loss_reid_factor(epoch)
            loss_G = 0.5 * loss_G_ii * Cfg.LOSS.GAN_WEIGHT + 0.5 * loss_G_ip * Cfg.LOSS.GAN_WEIGHT + loss_L1 + loss_reid * Cfg.LOSS.REID_WEIGHT * factor
            loss_G.backward()
            optimizerG.step()

            #train Dip
            for i in range(Cfg.SOLVER.DG_RATIO):
                optimizerDip.zero_grad()
                real_input_ip = torch.cat((img2, pose2), 1)
                fake_input_ip = fake_ip_pool.query(
                    torch.cat((fake_img2, pose2), 1).data)
                pred_real_ip = model_Dip(real_input_ip)
                loss_Dip_real = GAN_loss(pred_real_ip, True)
                pred_fake_ip = model_Dip(fake_input_ip)
                loss_Dip_fake = GAN_loss(pred_fake_ip, False)
                loss_Dip = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dip_real +
                                                        loss_Dip_fake)
                loss_Dip.backward()
                optimizerDip.step()
            #train Dii
            for i in range(Cfg.SOLVER.DG_RATIO):
                optimizerDii.zero_grad()
                real_input_ii = torch.cat((img2, img1), 1)
                fake_input_ii = fake_ii_pool.query(
                    torch.cat((fake_img2, img1), 1).data)
                pred_real_ii = model_Dii(real_input_ii)
                loss_Dii_real = GAN_loss(pred_real_ii, True)
                pred_fake_ii = model_Dii(fake_input_ii)
                loss_Dii_fake = GAN_loss(pred_fake_ii, False)
                loss_Dii = 0.5 * Cfg.LOSS.GAN_WEIGHT * (loss_Dii_real +
                                                        loss_Dii_fake)
                loss_Dii.backward()
                optimizerDii.step()

            lossG_meter.update(loss_G.item(), 1)
            lossDip_meter.update(loss_Dip.item(), 1)
            lossDii_meter.update(loss_Dii.item(), 1)
            distDreid_meter.update(dist_cos.mean().item(), 1)
            if (iter + 1) % log_period == 0:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] G Loss: {:.3f}, Dip Loss: {:.3f}, Dii Loss: {:.3f}, Base G_Lr: {:.2e}, Base Dip_Lr: {:.2e}, Base Dii_Lr: {:.2e}"
                    .format(epoch, (iter + 1), len(train_loader),
                            lossG_meter.avg, lossDip_meter.avg,
                            lossDii_meter.avg,
                            schedulerG.get_lr()[0],
                            schedulerDip.get_lr()[0],
                            schedulerDii.get_lr()[0]))  #scheduler.get_lr()[0]
                logger.info("ReID Cos Distance: {:.3f}".format(
                    distDreid_meter.avg))
        end_time = time.time()
        time_per_batch = (end_time - start_time) / (iter + 1)
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(epoch, time_per_batch,
                    train_loader.batch_size / time_per_batch))

        if epoch % checkpoint_period == 0:
            torch.save(model_G.state_dict(),
                       output_dir + 'model_G_{}.pth'.format(epoch))
            torch.save(model_Dip.state_dict(),
                       output_dir + 'model_Dip_{}.pth'.format(epoch))
            torch.save(model_Dii.state_dict(),
                       output_dir + 'model_Dii_{}.pth'.format(epoch))
        #
        if epoch % eval_period == 0:
            np.save(output_dir + 'train_Bx6x128x64_epoch{}.npy'.format(epoch),
                    fake_ii_pool.images[0].cpu().numpy())
            logger.info('Entering Evaluation...')
            tmp_results = []
            model_G.eval()
            for iter, batch in enumerate(val_loader):
                with torch.no_grad():
                    img1 = batch['img1'].to(device)
                    pose1 = batch['pose1'].to(device)
                    img2 = batch['img2'].to(device)
                    pose2 = batch['pose2'].to(device)
                    input_G = (img1, pose2)
                    fake_img2 = model_G(input_G)
                    tmp_result = torch.cat((img1, img2, fake_img2),
                                           1).cpu().numpy()
                    tmp_results.append(tmp_result)

            np.save(output_dir + 'test_Bx6x128x64_epoch{}.npy'.format(epoch),
                    tmp_results[0])
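
The training loop above scales the ReID loss by loss_reid_factor(epoch), which is not shown in this snippet. A plausible warm-up schedule (purely an assumption) would ramp the identity term in over the first epochs.

def loss_reid_factor(epoch, warmup_epochs=10):
    # Hypothetical warm-up: 0 -> 1 linearly over the first `warmup_epochs`,
    # so the generator learns coarse appearance before identity preservation.
    return min(1.0, epoch / float(warmup_epochs))
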
Example #29
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        if isinstance(opt.weight_adv, list):
            # materialize the map() so the weights survive repeated use
            self.weight_adv = list(map(float, opt.weight_adv))
        else:
            self.weight_adv = None
        if isinstance(opt.weight_rec, list):
            self.weight_rec = list(map(float, opt.weight_rec))
        else:
            self.weight_rec = None
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # The naming convention is different from the one used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        if not opt.idt:
            self.down_2 = torch.nn.AvgPool2d(2)
            self.up_2 = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.down_2 = torch.nn.AvgPool2d(1)
            self.up_2 = torch.nn.AvgPool2d(1)

        self.netG_A = networks.define_G(opt.input_nc,
                                        opt.output_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        n_upsampling=opt.n_upsample,
                                        n_downsampling=opt.n_downsample,
                                        side='A',
                                        opt=opt)
        self.netG_B = networks.define_G(opt.output_nc,
                                        opt.input_nc,
                                        opt.ngf,
                                        opt.which_model_netG,
                                        opt.norm,
                                        not opt.no_dropout,
                                        opt.init_type,
                                        self.gpu_ids,
                                        n_upsampling=opt.n_upsample,
                                        n_downsampling=opt.n_downsample,
                                        side='B',
                                        opt=opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
            self.netD_B = networks.define_D(opt.input_nc,
                                            opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D,
                                            opt.norm,
                                            use_sigmoid,
                                            opt.init_type,
                                            self.gpu_ids,
                                            opt=opt)
        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A,
                               opt,
                               input_shape=(opt.input_nc, opt.fineSize,
                                            opt.fineSize))
        if self.isTrain:
            networks.print_network(self.netD_A,
                                   opt,
                                   input_shape=(3, opt.fineSize, opt.fineSize))
        print('-----------------------------------------------')

        if not self.isTrain or opt.continue_train:
            print('Continue from', opt.which_epoch)
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain and not opt.test:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor,
                                                 target_weight=self.weight_adv)
            self.criterionCycle = networks.RECLoss(
                target_weight=self.weight_rec)
            # initialize optimizers
            self.optimizer_G_A = torch.optim.Adam(self.netG_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_G_B = torch.optim.Adam(self.netG_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))

            if opt.d_lr2:
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=(opt.lr / 2.0),
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=(opt.lr / 2.0),
                                                      betas=(opt.beta1, 0.999))
            else:
                self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))
                self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                      lr=opt.lr,
                                                      betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G_A)
            self.optimizers.append(self.optimizer_G_B)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))
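
Several of these examples build their lr schedulers through networks.get_scheduler(optimizer, opt). A sketch of the usual 'linear' policy behind that call, assuming the conventional niter/niter_decay options (the real helper also supports 'step', 'plateau', and 'cosine' policies).

from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    # Keep lr constant for opt.niter epochs, then decay linearly to zero
    # over the following opt.niter_decay epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + 1 - opt.niter) / float(opt.niter_decay + 1)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
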
Example #30
class Trainer(object):

    def __init__(self, cuda, model, optimizer, loss_fun,
                 train_loader, test_loader, lmk_num, view, crossentropy_weight,
                 out, max_epoch, network_num, batch_size, GAN,
                 do_classification=True, do_landmarkdetect=True,
                 size_average=False, interval_validate=None,
                 compete=False, onlyEval=False):
        self.cuda = cuda

        self.model = model
        self.optim = optimizer

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.interval_validate = interval_validate
        self.network_num = network_num

        self.do_classification = do_classification
        self.do_landmarkdetect = do_landmarkdetect
        self.crossentropy_weight = crossentropy_weight


        self.timestamp_start = \
            datetime.datetime.now(pytz.timezone('Asia/Tokyo'))
        self.size_average = size_average

        self.out = out
        if not osp.exists(self.out):
            os.makedirs(self.out)

        self.lmk_num = lmk_num
        self.GAN = GAN
        self.onlyEval = onlyEval
        if self.GAN:
            # Hyper-parameters for the auxiliary adversarial branch.
            GAN_lr = 0.0002
            input_nc = 1
            output_nc = self.lmk_num
            ndf = 64
            norm_layer = torchsrc.models.get_norm_layer(norm_type='batch')
            gpu_ids = [0]
            # The discriminator sees the input image and the prediction
            # concatenated along the channel axis (a conditional-GAN setup),
            # hence input_nc + output_nc input channels.
            self.netD = torchsrc.models.NLayerDiscriminator(
                input_nc + output_nc, ndf, n_layers=3,
                norm_layer=norm_layer, use_sigmoid=True, gpu_ids=gpu_ids)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=GAN_lr, betas=(0.5, 0.999))
            self.netD.cuda()
            self.netD.apply(torchsrc.models.weights_init)
            # Buffer of previously generated predictions; querying it mixes
            # current and historical fakes to stabilise discriminator training.
            pool_size = 10
            self.fake_AB_pool = ImagePool(pool_size)
            # use_sigmoid=True above pairs with a vanilla (BCE) GAN loss,
            # so LSGAN is disabled here.
            no_lsgan = True
            self.Tensor = torch.cuda.FloatTensor if gpu_ids else torch.Tensor
            self.criterionGAN = torchsrc.models.GANLoss(
                use_lsgan=not no_lsgan, tensor=self.Tensor)


        self.max_epoch = max_epoch
        self.epoch = 0
        self.iteration = 0
        self.best_mean_iu = 0

        self.compete = compete
        self.batch_size = batch_size
        self.view = view
        self.loss_fun = loss_fun


    def forward_step(self, data, category_name):
        # All supported categories route through the same multi-task model;
        # an unknown category previously fell through the if/elif chain and
        # raised an UnboundLocalError, so fail fast with a clear message.
        known = ('KidneyLong', 'KidneyTrans', 'LiverLong',
                 'SpleenLong', 'SpleenTrans')
        if category_name not in known:
            raise ValueError('unknown category: %s' % category_name)
        return self.model(data, category_name)

    def backward_D(self, real_A, real_B, fake_B):
        # Fake: stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((real_A, fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, real_B), 1)
        pred_real = self.netD(real_AB)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss, averaged so D learns at half the rate of G
        self.loss_D = (loss_D_fake + loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self, real_A, fake_B):
        # First, G(A) should fool the discriminator
        fake_AB = torch.cat((real_A, fake_B), 1)
        pred_fake = self.netD(fake_AB)
        loss_G_GAN = self.criterionGAN(pred_fake, True)
        return loss_G_GAN
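
    # Added sketch (assumption, not in the original code): the full generator
    # objective mirrors the weighting used in train() below,
    # loss = loss_G_GAN + 100 * reconstruction term. An L1 term stands in
    # here for the cross-entropy/Dice losses that train() actually uses.
    def backward_G_full(self, real_A, real_B, fake_B, rec_weight=100.0):
        loss_G_GAN = self.backward_G(real_A, fake_B)
        loss_G_rec = torch.nn.functional.l1_loss(fake_B, real_B)
        loss_G = loss_G_GAN + rec_weight * loss_G_rec
        loss_G.backward()
        return loss_G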




    def validate(self):
        self.model.eval()  # evaluation pass: freeze dropout/batch-norm statistics
        out = osp.join(self.out, 'seg_output')
        out_vis = osp.join(self.out, 'visualization')
        results_epoch_dir = osp.join(out,'epoch_%04d' % self.epoch)
        mkdir(results_epoch_dir)
        results_vis_epoch_dir = osp.join(out_vis, 'epoch_%04d' % self.epoch)
        mkdir(results_vis_epoch_dir)

        prev_sub_name = 'start'
        prev_view_name = 'start'

        for batch_idx, (data, target, target2ch, sub_name, view, img_name) in tqdm.tqdm(
                enumerate(self.test_loader), total=len(self.test_loader),
                desc='Valid epoch=%d' % self.epoch, ncols=80,
                leave=False):

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target, volatile=True)

            pred = self.model(data)

            lbl_pred = pred.data.max(1)[1].cpu().numpy()

            batch_num = lbl_pred.shape[0]
            for si in range(batch_num):
                curr_sub_name = sub_name[si]
                curr_view_name = view[si]
                curr_img_name = img_name[si]

                if prev_sub_name == 'start':
                    if self.view == 'viewall':
                        seg = np.zeros([512, 512, 512], np.uint8)
                    else:
                        seg = np.zeros([512, 512, 1000], np.uint8)
                    slice_num = 0
                elif not (prev_sub_name == curr_sub_name and prev_view_name == curr_view_name):
                    # A new subject/view started: flush the finished volume to disk.
                    out_img_dir = os.path.join(results_epoch_dir, prev_sub_name)
                    mkdir(out_img_dir)
                    out_nii_file = os.path.join(out_img_dir, '%s_%s.nii.gz' % (prev_sub_name, prev_view_name))
                    seg_img = nib.Nifti1Image(seg, affine=np.eye(4))
                    nib.save(seg_img, out_nii_file)
                    if self.view == 'viewall':
                        seg = np.zeros([512, 512, 512], np.uint8)
                    else:
                        seg = np.zeros([512, 512, 1000], np.uint8)
                    slice_num = 0

                test_slice_name = 'slice_%04d.png' % (slice_num + 1)
                assert test_slice_name == curr_img_name
                seg_slice = lbl_pred[si, :, :].astype(np.uint8)
                seg_slice = scipy.misc.imresize(seg_slice, (512, 512), interp='nearest')
                # Write the 2-D prediction into the 3-D volume along the axis
                # that matches its acquisition view.
                if curr_view_name == 'view1':
                    seg[slice_num, :, :] = seg_slice
                elif curr_view_name == 'view2':
                    seg[:, slice_num, :] = seg_slice
                elif curr_view_name == 'view3':
                    seg[:, :, slice_num] = seg_slice

                slice_num += 1
                prev_sub_name = curr_sub_name
                prev_view_name = curr_view_name


        # Flush the volume of the final subject/view after the loop ends.
        out_img_dir = os.path.join(results_epoch_dir, curr_sub_name)
        mkdir(out_img_dir)
        out_nii_file = os.path.join(out_img_dir, '%s_%s.nii.gz' % (curr_sub_name, curr_view_name))
        seg_img = nib.Nifti1Image(seg, affine=np.eye(4))
        nib.save(seg_img, out_nii_file)


    def train(self):
        self.model.train()
        out = osp.join(self.out, 'visualization')
        mkdir(out)
        log_file = osp.join(out, 'training_loss.txt')
        fv = open(log_file, 'a')

        for batch_idx, (data, target, target2ch, sub_name, view, img_name) in tqdm.tqdm(
                enumerate(self.train_loader), total=len(self.train_loader),
                desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):

            if self.cuda:
                data, target, target2ch = data.cuda(), target.cuda(), target2ch.cuda()
            data, target, target2ch = Variable(data), Variable(target), Variable(target2ch)

            pred = self.model(data)
            self.optim.zero_grad()
            if self.GAN:
                # Update D on (image, label) pairs, then get G's adversarial term.
                self.optimizer_D.zero_grad()
                self.backward_D(data, target2ch, pred)
                self.optimizer_D.step()
                loss_G_GAN = self.backward_G(data, pred)
                if self.loss_fun == 'cross_entropy':
                    arr = np.array(self.crossentropy_weight)
                    weight = torch.from_numpy(arr).cuda().float()
                    loss_G_L2 = cross_entropy2d(pred, target.long(), weight=weight)
                elif self.loss_fun == 'Dice':
                    loss_G_L2 = dice_loss(pred, target2ch)
                elif self.loss_fun == 'Dice_norm':
                    loss_G_L2 = dice_loss_norm(pred, target2ch)
                loss = loss_G_GAN + loss_G_L2 * 100

                fv.write('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss=%.4f \n' % (
                    self.epoch, batch_idx, self.loss_D.data[0], loss_G_GAN.data[0], loss_G_L2.data[0]))

                if batch_idx % 10 == 0:
                    print('--- epoch=%d, batch_idx=%d, D_loss=%.4f, G_loss=%.4f, L2_loss=%.4f \n' % (
                        self.epoch, batch_idx, self.loss_D.data[0], loss_G_GAN.data[0], loss_G_L2.data[0]))
            else:
                if self.loss_fun == 'cross_entropy':
                    arr = np.array(self.crossentropy_weight)
                    weight = torch.from_numpy(arr).cuda().float()
                    loss = cross_entropy2d(pred, target.long(), weight=weight)
                elif self.loss_fun == 'Dice':
                    loss = dice_loss(pred, target2ch)
                elif self.loss_fun == 'Dice_norm':
                    loss = dice_loss_norm(pred, target2ch)
            loss.backward()
            self.optim.step()
            if batch_idx % 10 == 0:
                print('epoch=%d, batch_idx=%d, loss=%.4f \n' % (self.epoch, batch_idx, loss.data[0]))
                fv.write('epoch=%d, batch_idx=%d, loss=%.4f \n' % (self.epoch, batch_idx, loss.data[0]))


        fv.close()

    def train_epoch(self):
        for epoch in tqdm.trange(self.epoch, self.max_epoch,
                                 desc='Train', ncols=80):
            self.epoch = epoch
            out = osp.join(self.out, 'models', self.view)
            mkdir(out)

            model_pth = '%s/model_epoch_%04d.pth' % (out, epoch)
            gan_model_pth = '%s/GAN_D_epoch_%04d.pth' % (out, epoch)

            if os.path.exists(model_pth):
                self.model.load_state_dict(torch.load(model_pth))
                if self.GAN and os.path.exists(gan_model_pth):
                    self.netD.load_state_dict(torch.load(gan_model_pth))
            else:
                if not self.onlyEval:
                    self.train()
                    self.validate()
                    torch.save(self.model.state_dict(), model_pth)
                    if self.GAN:
                        torch.save(self.netD.state_dict(), gan_model_pth)
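

# Added sketch (assumption): a minimal image buffer in the spirit of the
# ImagePool used above. query() returns incoming fakes directly until the
# pool fills; afterwards each image is either swapped with a stored one
# (50% chance) or passed through, so the discriminator also sees historical
# generator outputs rather than only the latest ones.
import random

import torch


class ImagePoolSketch(object):

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        # images: a batch tensor of shape (N, C, H, W)
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Pool not yet full: store the current image and return it.
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # Swap: return a stored image, keep the current one instead.
                idx = random.randint(0, self.pool_size - 1)
                tmp = self.images[idx].clone()
                self.images[idx] = image
                out.append(tmp)
            else:
                out.append(image)
        return torch.cat(out, 0)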