Example #1
    def __init__(self, config):
        self.config = config

        self.device = config.device
        self.max_itr = config.max_itr
        self.batch_size = config.batch_size
        self.img_size = config.img_size
        self.dim_z = config.dim_z
        self.dim_c = config.dim_c
        self.scale = config.scale
        self.n_gen = config.n_gen

        self.start_itr = 1

        dataloader = DataLoader(
            config.data_root, config.dataset_name, config.img_size, config.batch_size, config.with_label
            )
        train_loader, _ = dataloader.get_loader(only_train=True)
        self.dataloader = endless_dataloader(train_loader)

        self.generator = Generator(config).to(config.device)
        self.discriminator = Discriminator(config).to(config.device)

        self.optim_g = torch.optim.Adam(self.generator.parameters(), lr=config.lr_g, betas=(config.beta1, config.beta2))
        self.optim_d = torch.optim.Adam(self.discriminator.parameters(), lr=config.lr_d, betas=(config.beta1, config.beta2))
        self.criterion = GANLoss()

        if self.config.checkpoint_path != '':
            self._load_models(self.config.checkpoint_path)

        self.x, self.y, self.r = get_coordinates(self.img_size, self.img_size, self.scale, self.batch_size)
        self.x, self.y, self.r = self.x.to(self.device), self.y.to(self.device), self.r.to(self.device)

        self.writer = SummaryWriter(log_dir=config.log_dir)
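
Note: endless_dataloader is used above but not defined in this example. A minimal sketch of such a helper, assuming it simply cycles the underlying DataLoader forever for iteration-based training, could look like this (hypothetical implementation):

def endless_dataloader(loader):
    # Restart the loader whenever it is exhausted so callers can pull an
    # unbounded stream of batches with next(...).
    while True:
        for batch in loader:
            yield batch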
Example #2
    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.start_iter = 1
        self.train_iters = args.train_iters
        # coeffs
        self.lambda_A = args.lambda_A
        self.lambda_B = args.lambda_B
        self.lambda_idt = args.lambda_idt

        self.dataloader_A, self.dataloader_B = get_dataloader(args)

        self.D_B, self.G_AB = get_model(args)
        self.D_A, self.G_BA = get_model(args)

        self.criterion_GAN = GANLoss(use_lsgan=args.use_lsgan).to(args.device)
        self.criterion_cycle = nn.L1Loss()
        self.criterion_idt = nn.L1Loss()

        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.D_B.parameters(), self.D_A.parameters()),
            lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
            lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)

        self.logger = self.get_logger(args)
        self.writer = SummaryWriter(args.log_dir)

        save_args(args.log_dir, args)
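
Note: these examples all construct a GANLoss criterion without showing its definition, and some also pass a tensor argument used to build the target, which is omitted here. A minimal sketch in the spirit of the pix2pix/CycleGAN-style criterion, assuming use_lsgan switches between MSE on raw scores and BCE on sigmoid outputs (matching the use_sigmoid = not use_lsgan pattern in examples #3 and #6), might be:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    # Hypothetical sketch: compare a discriminator prediction against an
    # all-ones (real) or all-zeros (fake) target tensor.
    def __init__(self, use_lsgan=True):
        super().__init__()
        # LSGAN uses MSE on raw scores; vanilla GAN uses BCE on sigmoid outputs.
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def forward(self, prediction, target_is_real):
        target = torch.ones_like(prediction) if target_is_real \
            else torch.zeros_like(prediction)
        return self.loss(prediction, target)

Calls such as self.criterion_GAN(pred_fake, True) in example #2 match this interface.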
Example #3
    def __init__(self, **kwargs):
        self.netG_in_channels = kwargs['netG_in_channels']
        self.netG_out_channels = kwargs['netG_out_channels']
        self.phase = kwargs['phase']
        self.device = kwargs['device']
        self.gpus = [int(x) for x in list(kwargs['gpu'])]

        if self.phase == 'train':
            use_sigmoid = not kwargs['use_lsgan']
            self.netG = resnet152_fpn(self.netG_in_channels,
                                      self.netG_out_channels,
                                      pretrained=False)
            self.netD = NLayerDiscriminator(self.netG_in_channels +
                                            self.netG_out_channels,
                                            64,
                                            use_sigmoid=use_sigmoid,
                                            init_type='normal')
            if len(kwargs['gpu']) > 1:
                self.netG = nn.DataParallel(self.netG, device_ids=self.gpus)
                self.netD = nn.DataParallel(self.netD, device_ids=self.gpus)
            self.netG.to(self.device)
            self.netD.to(self.device)
        else:
            self.netG = resnet152_fpn(self.netG_in_channels,
                                      self.netG_out_channels,
                                      pretrained=False)
            print('Loading model from {}.'.format(kwargs['model_file']))
            self.netG.load_state_dict(torch.load(kwargs['model_file']))
            self.netG.to(self.device)
            self.netG.eval()

        if self.phase == 'train':
            # self.fake_AB_pool = ImagePool(kwargs['poolsize'])

            self.GANloss = GANLoss(self.device, use_lsgan=kwargs['use_lsgan'])
            self.L1loss = nn.L1Loss()
            self.lambda_L1 = kwargs['lambda_L1']
            self.CEloss = nn.CrossEntropyLoss()

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=kwargs['lr'],
                                                betas=(0.5, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=kwargs['lr'],
                                                betas=(0.5, 0.999))

            def lambda_rule(epoch):
                lr_l = 1.0 - max(0, epoch - kwargs['niter']) / float(
                    kwargs['niter_decay'] + 1)
                return lr_l

            self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_G, lr_lambda=lambda_rule)
            self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_D, lr_lambda=lambda_rule)
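
Note: lambda_rule in example #3 keeps the learning rate constant for the first niter epochs and then decays it linearly toward zero over niter_decay epochs. A small standalone sketch (with hypothetical values for niter and niter_decay) shows how LambdaLR applies that multiplier once per epoch:

import torch

niter, niter_decay = 100, 100  # hypothetical values
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 1.0 - max(0, epoch - niter) / float(niter_decay + 1))

for epoch in range(niter + niter_decay):
    # ... one epoch of training with optimizer.step() calls ...
    scheduler.step()  # base lr is multiplied by lr_lambda(epoch)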
Example #4
    def initialize(self, n_input_channels, n_output_channels, n_blocks,
                   initial_filters, dropout_value, lr, batch_size, image_width,
                   image_height, gpu_ids, gan, pool_size, n_blocks_discr):

        self.input_img = self.tensor(batch_size, n_input_channels,
                                     image_height, image_width)
        self.input_gt = self.tensor(batch_size, n_output_channels,
                                    image_height, image_width)

        self.generator = UNetV2(n_input_channels,
                                n_output_channels,
                                n_blocks,
                                initial_filters,
                                gpu_ids=gpu_ids)

        if gan:
            self.discriminator = ImageDiscriminatorConv(
                n_output_channels,
                initial_filters,
                dropout_value,
                gpu_ids=gpu_ids,
                n_blocks=n_blocks_discr)
            self.criterion_gan = GANLoss(tensor=self.tensor)
            self.optimizer_dis = torch.optim.Adam(
                self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
            self.fake_mask_pool = ImagePool(pool_size)

        if self.load_network:
            self._load_network(self.generator, 'Model', self.load_epoch)
            if gan:
                self._load_network(self.discriminator, 'Discriminator',
                                   self.load_epoch)

        self.criterion_seg = BCELoss2d()
        self.optimizer_seg = torch.optim.Adam(self.generator.parameters(),
                                              lr=lr,
                                              betas=(0.5, 0.999))

        print('---------- Network initialized -------------')
        self.print_network(self.generator)
        if gan:
            self.print_network(self.discriminator)
        print('-----------------------------------------------')
Example #5
File: experiments.py Project: kzky/works
    def __init__(self,
                 decoder,
                 device=None,
                 model=None,
                 dim_rand=30,
                 n_cls=10,
                 learning_rate=1e-3,
                 act=F.relu):

        # Settings
        self.device = device
        self.dim_rand = dim_rand
        self.n_cls = n_cls
        self.act = act
        self.learning_rate = learning_rate

        # Model
        generator, discriminator = create_gan_experiment(model=model,
                                                         act=act,
                                                         dim_rand=dim_rand)
        self.generator = generator
        self.discriminator = discriminator
        if self.device is not None:
            self.generator.to_gpu(self.device)
            self.discriminator.to_gpu(self.device)
        self.decoder = decoder

        # Optimizer
        self.optimizer_gen = optimizers.Adam(learning_rate)
        self.optimizer_gen.setup(self.generator)
        self.optimizer_gen.use_cleargrads()

        self.optimizer_dis = optimizers.Adam(learning_rate)
        self.optimizer_dis.setup(self.discriminator)
        self.optimizer_dis.use_cleargrads()

        # Losses
        self.gan_loss = GANLoss()
Example #6
    def __init__(self,
                 checkpoints_dir,
                 lr=0.0002,
                 niter_decay=45,
                 batch_size=4,
                 gpu_ids=[0, 1],
                 isTrain=True):

        #Hyperparams
        self.lr = lr
        self.beta1 = 0.5
        self.niter_decay = niter_decay

        self.input_nc = 11  #number of input channels
        self.output_nc = 3  #number of output channels
        self.label_nc = 11  #number of mask channels
        self.isTrain = isTrain  #Whether to train
        self.dis_net_input_nc = self.input_nc + self.output_nc
        self.dis_n_layers = 3
        self.num_D = 2
        self.lambda_feat = 10.0
        self.z_dim = 512
        self.batch_size = batch_size

        self.gpu_ids = gpu_ids
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor

        #Loss function parameters - used in init_loss_function
        self.use_gan_feat_loss = True
        self.no_vgg_loss = True
        self.no_l2_loss = True
        self.checkpoints_dir = checkpoints_dir

        #Optimization Parameters
        self.use_lsgan = False

        self.no_ganFeat_loss = True

        self.gen_net = GeneratorNetwork(self.input_nc, self.output_nc)
        if len(gpu_ids) > 0:
            self.gen_net.cuda(gpu_ids[0])
        self.gen_net.apply(weights_init)

        # use_sigmoid is needed for the discriminators below, which are built
        # regardless of isTrain, so define it unconditionally
        use_sigmoid = True

        self.dis_net = DiscriminatorNetwork(self.dis_net_input_nc,
                                            self.dis_n_layers, self.num_D,
                                            use_sigmoid)
        if len(gpu_ids) > 0:
            self.dis_net.cuda(gpu_ids[0])
        self.dis_net.apply(weights_init)

        #Don't know why we need this???
        self.dis_net2 = DiscriminatorNetwork(self.dis_net_input_nc,
                                             self.dis_n_layers, self.num_D,
                                             use_sigmoid)
        if len(gpu_ids) > 0:
            self.dis_net2.cuda(gpu_ids[0])
        self.dis_net2.apply(weights_init)

        #         self.p_net = PNetwork(self.label_nc, self.output_nc)
        #         self.p_net.apply(weights_init)

        self.p_net = PNetwork(self.batch_size, self.checkpoints_dir)

        #TODO
        longSize = 256
        n_downsample_global = 2
        embed_feature_size = longSize // 2**n_downsample_global

        self.encoder_skin_net = EncoderGenerator_mask_skin(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.encoder_skin_net.cuda(gpu_ids[0])
        self.encoder_skin_net.apply(weights_init)

        self.encoder_hair_net = EncoderGenerator_mask_skin(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.encoder_hair_net.cuda(gpu_ids[0])
        self.encoder_hair_net.apply(weights_init)

        self.encoder_left_eye_net = EncoderGenerator_mask_eye(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.encoder_left_eye_net.cuda(gpu_ids[0])
        self.encoder_left_eye_net.apply(weights_init)

        self.encoder_right_eye_net = EncoderGenerator_mask_eye(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.encoder_right_eye_net.cuda(gpu_ids[0])
        self.encoder_right_eye_net.apply(weights_init)

        self.encoder_mouth_net = EncoderGenerator_mask_mouth(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.encoder_mouth_net.cuda(gpu_ids[0])
        self.encoder_mouth_net.apply(weights_init)

        self.decoder_skin_net = DecoderGenerator_mask_skin(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_skin_net.cuda(gpu_ids[0])
        self.decoder_skin_net.apply(weights_init)

        self.decoder_hair_net = DecoderGenerator_mask_skin(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_hair_net.cuda(gpu_ids[0])
        self.decoder_hair_net.apply(weights_init)

        self.decoder_left_eye_net = DecoderGenerator_mask_eye(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_left_eye_net.cuda(gpu_ids[0])
        self.decoder_left_eye_net.apply(weights_init)

        self.decoder_right_eye_net = DecoderGenerator_mask_eye(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_right_eye_net.cuda(gpu_ids[0])
        self.decoder_right_eye_net.apply(weights_init)

        self.decoder_mouth_net = DecoderGenerator_mask_mouth(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_mouth_net.cuda(gpu_ids[0])
        self.decoder_mouth_net.apply(weights_init)

        self.decoder_skin_image_net = DecoderGenerator_mask_skin_image(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_skin_image_net.cuda(gpu_ids[0])
        self.decoder_skin_image_net.apply(weights_init)

        self.decoder_hair_image_net = DecoderGenerator_mask_skin_image(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_hair_image_net.cuda(gpu_ids[0])
        self.decoder_hair_image_net.apply(weights_init)

        self.decoder_left_eye_image_net = DecoderGenerator_mask_eye_image(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_left_eye_image_net.cuda(gpu_ids[0])
        self.decoder_left_eye_image_net.apply(weights_init)

        self.decoder_right_eye_image_net = DecoderGenerator_mask_eye_image(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_right_eye_image_net.cuda(gpu_ids[0])
        self.decoder_right_eye_image_net.apply(weights_init)

        self.decoder_mouth_image_net = DecoderGenerator_mask_mouth_image(
            functools.partial(nn.BatchNorm2d, affine=True))
        if len(gpu_ids) > 0:
            self.decoder_mouth_image_net.cuda(gpu_ids[0])
        self.decoder_mouth_image_net.apply(weights_init)

        if self.isTrain:
            self.loss_filter = self.init_loss_filter(self.no_ganFeat_loss,
                                                     self.no_vgg_loss,
                                                     self.no_l2_loss)
            self.old_lr = self.lr

            self.criterionGAN = GANLoss(use_lsgan=self.use_lsgan,
                                        tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            self.criterionL2 = torch.nn.MSELoss()
            self.criterionL1 = torch.nn.L1Loss()
            #             self.criterionMFM = MFMLoss()

            weight_list = [0.2, 1, 5, 5, 5, 5, 3, 8, 8, 8, 1]
            self.criterionCrossEntropy = torch.nn.CrossEntropyLoss(
                weight=torch.FloatTensor(weight_list))

            #             if self.no_vgg_loss:
            #                 self.criterionVGG = VGGLoss(weights=None)

            #             self.criterionGM = GramMatrixLoss()
            print(self.loss_filter)
            self.loss_names = self.loss_filter('KL_embed', 'L2_mask_image',
                                               'G_GAN', 'G_GAN_Feat', 'D_real',
                                               'D_fake', 'L2_image', 'G2_GAN',
                                               'D2_real', 'D2_fake')

            params_decoder = list(self.decoder_skin_net.parameters()) + list(
                self.decoder_hair_net.parameters()) + list(
                    self.decoder_left_eye_net.parameters()) + list(
                        self.decoder_right_eye_net.parameters()) + list(
                            self.decoder_mouth_net.parameters())
            params_image_decoder = list(self.decoder_skin_image_net.parameters(
            )) + list(self.decoder_hair_image_net.parameters()) + list(
                self.decoder_left_eye_image_net.parameters()) + list(
                    self.decoder_right_eye_image_net.parameters()) + list(
                        self.decoder_mouth_image_net.parameters())
            params_encoder = list(self.encoder_skin_net.parameters()) + list(
                self.encoder_hair_net.parameters()) + list(
                    self.encoder_left_eye_net.parameters()) + list(
                        self.encoder_right_eye_net.parameters()) + list(
                            self.encoder_mouth_net.parameters())

            params_together = list(self.gen_net.parameters(
            )) + params_decoder + params_encoder + params_image_decoder
            self.optimizer_G_together = torch.optim.Adam(params_together,
                                                         lr=self.lr,
                                                         betas=(self.beta1,
                                                                0.999))

            params = list(self.dis_net.parameters())
            #            self.optimizer_D = torch.optim.Adam(params, lr=self.lr, betas=(self.beta1, 0.999))
            self.optimizer_D = torch.optim.RMSprop(params, lr=self.lr)
            # optimizer D2
            params = list(self.dis_net2.parameters())
            #           self.optimizer_D2 = torch.optim.Adam(params, lr=self.lr, betas=(self.beta1, 0.999))
            self.optimizer_D2 = torch.optim.RMSprop(params, lr=self.lr)
Example #7
def main():

    args = args_initialize()

    save_freq = args.save_freq
    epochs = args.num_epoch
    cuda = args.cuda

    train_dataset = UnalignedDataset(is_train=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )

    net_G_A = ResNetGenerator(input_nc=3, output_nc=3)
    net_G_B = ResNetGenerator(input_nc=3, output_nc=3)
    net_D_A = Discriminator()
    net_D_B = Discriminator()

    if args.cuda:
        net_G_A = net_G_A.cuda()
        net_G_B = net_G_B.cuda()
        net_D_A = net_D_A.cuda()
        net_D_B = net_D_B.cuda()

    fake_A_pool = ImagePool(50)
    fake_B_pool = ImagePool(50)

    criterionGAN = GANLoss(cuda=cuda)
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(
        itertools.chain(net_G_A.parameters(), net_G_B.parameters()),
        lr=args.lr,
        betas=(args.beta1, 0.999)
    )
    optimizer_D_A = torch.optim.Adam(net_D_A.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(net_D_B.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    log_dir = './logs'
    checkpoints_dir = './checkpoints'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    writer = SummaryWriter(log_dir)

    for epoch in range(epochs):

        running_loss = np.zeros((8))
        for batch_idx, data in enumerate(train_loader):

            input_A = data['A']
            input_B = data['B']

            if cuda:
                input_A = input_A.cuda()
                input_B = input_B.cuda()

            real_A = Variable(input_A)
            real_B = Variable(input_B)


            """
            Backward net_G
            """
            optimizer_G.zero_grad()
            lambda_idt = 0.5
            lambda_A = 10.0
            lambda_B = 10.0

            # Feed each generator an image that is already in its target domain;
            # ideally it should return it unchanged (identity mapping)
            idt_B = net_G_A(real_B)
            loss_idt_A = criterionIdt(idt_B, real_B) * lambda_B * lambda_idt

            idt_A = net_G_B(real_A)
            loss_idt_B = criterionIdt(idt_A, real_A) * lambda_A * lambda_idt

            # GAN loss = D_A(G_A(A))
            # G_A wants the discriminator to judge its generated fakes as real (True)
            fake_B = net_G_A(real_A)
            pred_fake = net_D_A(fake_B)
            loss_G_A = criterionGAN(pred_fake, True)

            fake_A = net_G_B(real_B)
            pred_fake = net_D_B(fake_A)
            loss_G_B = criterionGAN(pred_fake, True)

            rec_A = net_G_B(fake_B)
            loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A

            rec_B = net_G_A(fake_A)
            loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B

            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G.backward()

            optimizer_G.step()

            """
            update D_A
            """
            optimizer_D_A.zero_grad()
            fake_B = fake_B_pool.query(fake_B.data)

            pred_real = net_D_A(real_B)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_A(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            """
            update D_B
            """
            optimizer_D_B.zero_grad()
            fake_A = fake_A_pool.query(fake_A.data)

            pred_real = net_D_B(real_A)
            loss_D_real = criterionGAN(pred_real, True)

            pred_fake = net_D_B(fake_A.detach())
            loss_D_fake = criterionGAN(pred_fake, False)

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()


            optimizer_D_B.step()

            ret_loss = np.array([
                loss_G_A.data.detach().cpu().numpy(), loss_D_A.data.detach().cpu().numpy(),
                loss_G_B.data.detach().cpu().numpy(), loss_D_B.data.detach().cpu().numpy(),
                loss_cycle_A.data.detach().cpu().numpy(), loss_cycle_B.data.detach().cpu().numpy(),
                loss_idt_A.data.detach().cpu().numpy(), loss_idt_B.data.detach().cpu().numpy()
            ])
            running_loss += ret_loss

            """
            Save checkpoints
            """
            if (epoch + 1) % save_freq == 0:
                save_network(net_G_A, 'G_A', str(epoch + 1))
                save_network(net_D_A, 'D_A', str(epoch + 1))
                save_network(net_G_B, 'G_B', str(epoch + 1))
                save_network(net_D_B, 'D_B', str(epoch + 1))

        running_loss /= len(train_loader)
        losses = running_loss
        print('epoch %d, losses: %s' % (epoch + 1, running_loss))

        writer.add_scalar('loss_G_A', losses[0], epoch)
        writer.add_scalar('loss_D_A', losses[1], epoch)
        writer.add_scalar('loss_G_B', losses[2], epoch)
        writer.add_scalar('loss_D_B', losses[3], epoch)
        writer.add_scalar('loss_cycle_A', losses[4], epoch)
        writer.add_scalar('loss_cycle_B', losses[5], epoch)
        writer.add_scalar('loss_idt_A', losses[6], epoch)
        writer.add_scalar('loss_idt_B', losses[7], epoch)
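
Note: examples #4 and #7 route generated images through an ImagePool before the discriminator update, but the class itself is not shown. A minimal sketch of such a history buffer, following the CycleGAN idea of sometimes replaying an older fake image to stabilize the discriminator (hypothetical implementation):

import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        # Fill the pool first; afterwards, with probability 0.5 swap the new
        # image for a randomly chosen stored one and return the stored copy.
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)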
Example #8
File: models.py Project: yoooo233/sem-pcyc
    def __init__(self, params_model):
        super(SEM_PCYC, self).__init__()

        print('Initializing model variables...', end='')
        # Dimension of embedding
        self.dim_out = params_model['dim_out']
        # Dimension of semantic embedding
        self.sem_dim = params_model['sem_dim']
        # Number of classes
        self.num_clss = params_model['num_clss']
        # Sketch model: pre-trained on ImageNet
        self.sketch_model = VGGNetFeats(pretrained=False, finetune=False)
        self.load_weight(self.sketch_model, params_model['path_sketch_model'],
                         'sketch')
        # Image model: pre-trained on ImageNet
        self.image_model = VGGNetFeats(pretrained=False, finetune=False)
        self.load_weight(self.image_model, params_model['path_image_model'],
                         'image')
        # Semantic model embedding
        self.sem = []
        for f in params_model['files_semantic_labels']:
            self.sem.append(np.load(f, allow_pickle=True).item())
        self.dict_clss = params_model['dict_clss']
        print('Done')

        print('Initializing trainable models...', end='')
        # Generators
        # Sketch to semantic generator
        self.gen_sk2se = Generator(in_dim=512,
                                   out_dim=self.dim_out,
                                   noise=False,
                                   use_dropout=True)
        # Image to semantic generator
        self.gen_im2se = Generator(in_dim=512,
                                   out_dim=self.dim_out,
                                   noise=False,
                                   use_dropout=True)
        # Semantic to sketch generator
        self.gen_se2sk = Generator(in_dim=self.dim_out,
                                   out_dim=512,
                                   noise=False,
                                   use_dropout=True)
        # Semantic to image generator
        self.gen_se2im = Generator(in_dim=self.dim_out,
                                   out_dim=512,
                                   noise=False,
                                   use_dropout=True)
        # Discriminators
        # Common semantic discriminator
        self.disc_se = Discriminator(in_dim=self.dim_out,
                                     noise=True,
                                     use_batchnorm=True)
        # Sketch discriminator
        self.disc_sk = Discriminator(in_dim=512,
                                     noise=True,
                                     use_batchnorm=True)
        # Image discriminator
        self.disc_im = Discriminator(in_dim=512,
                                     noise=True,
                                     use_batchnorm=True)
        # Semantic autoencoder
        self.aut_enc = AutoEncoder(dim=self.sem_dim,
                                   hid_dim=self.dim_out,
                                   nlayer=1)
        # Classifiers
        self.classifier_sk = nn.Linear(512, self.num_clss, bias=False)
        self.classifier_im = nn.Linear(512, self.num_clss, bias=False)
        self.classifier_se = nn.Linear(self.dim_out, self.num_clss, bias=False)
        for param in self.classifier_sk.parameters():
            param.requires_grad = False
        for param in self.classifier_im.parameters():
            param.requires_grad = False
        for param in self.classifier_se.parameters():
            param.requires_grad = False
        print('Done')

        # Optimizers
        print('Defining optimizers...', end='')
        self.lr = params_model['lr']
        self.gamma = params_model['gamma']
        self.momentum = params_model['momentum']
        self.milestones = params_model['milestones']
        self.optimizer_gen = optim.Adam(list(self.gen_sk2se.parameters()) +
                                        list(self.gen_im2se.parameters()) +
                                        list(self.gen_se2sk.parameters()) +
                                        list(self.gen_se2im.parameters()),
                                        lr=self.lr)
        self.optimizer_disc = optim.SGD(list(self.disc_se.parameters()) +
                                        list(self.disc_sk.parameters()) +
                                        list(self.disc_im.parameters()),
                                        lr=self.lr,
                                        momentum=self.momentum)
        self.optimizer_ae = optim.SGD(self.aut_enc.parameters(),
                                      lr=100 * self.lr,
                                      momentum=self.momentum)
        self.scheduler_gen = optim.lr_scheduler.MultiStepLR(
            self.optimizer_gen, milestones=self.milestones, gamma=self.gamma)
        self.scheduler_disc = optim.lr_scheduler.MultiStepLR(
            self.optimizer_disc, milestones=self.milestones, gamma=self.gamma)
        self.scheduler_ae = optim.lr_scheduler.MultiStepLR(
            self.optimizer_ae, milestones=self.milestones, gamma=self.gamma)
        print('Done')

        # Loss function
        print('Defining losses...', end='')
        self.lambda_se = params_model['lambda_se']
        self.lambda_im = params_model['lambda_im']
        self.lambda_sk = params_model['lambda_sk']
        self.lambda_gen_cyc = params_model['lambda_gen_cyc']
        self.lambda_gen_adv = params_model['lambda_gen_adv']
        self.lambda_gen_cls = params_model['lambda_gen_cls']
        self.lambda_gen_reg = params_model['lambda_gen_reg']
        self.lambda_disc_se = params_model['lambda_disc_se']
        self.lambda_disc_sk = params_model['lambda_disc_sk']
        self.lambda_disc_im = params_model['lambda_disc_im']
        self.lambda_regular = params_model['lambda_regular']
        self.criterion_gan = GANLoss(use_lsgan=True)
        self.criterion_cyc = nn.L1Loss()
        self.criterion_cls = nn.CrossEntropyLoss()
        self.criterion_reg = nn.MSELoss()
        print('Done')

        # Intermediate variables
        print('Initializing variables...', end='')
        self.sk_fe = torch.zeros(1)
        self.sk_em = torch.zeros(1)
        self.im_fe = torch.zeros(1)
        self.im_em = torch.zeros(1)
        self.se_em_enc = torch.zeros(1)
        self.se_em_rec = torch.zeros(1)
        self.im2se_em = torch.zeros(1)
        self.sk2se_em = torch.zeros(1)
        self.se2im_em = torch.zeros(1)
        self.se2sk_em = torch.zeros(1)
        self.im_em_hat = torch.zeros(1)
        self.sk_em_hat = torch.zeros(1)
        self.se_em_hat1 = torch.zeros(1)
        self.se_em_hat2 = torch.zeros(1)
        print('Done')
Example #9
    def __init__(self, opt):
        super(TwoStreamAE_mask, self).__init__(opt)
        if opt.resize_or_crop != 'none':
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.which_stream = opt.which_stream
        self.use_gan = opt.use_gan
        self.which_gan = opt.which_gan
        self.gan_weight = opt.gan_weight
        self.rec_weight = opt.rec_weight
        self.cond_in = opt.cond_in
        self.use_output_gate = opt.use_output_gate
        self.opt = opt
        
        if opt.no_comb:
            from MaskTwoStreamConvSwitch_NET import MaskTwoStreamConvSwitch_NET as model_factory
        else:
            from MaskTwoStreamConv_NET import MaskTwoStreamConv_NET as model_factory

        model = self.get_model(model_factory)
        self.netG = model(opt)
        self.netG.initialize()
        # move networks to gpu
        if len(opt.gpu_ids) > 0:
            assert(torch.cuda.is_available())
            self.netG.cudafy(opt.gpu_ids[0])
 
        print('---------- Networks initialized -------------')
       
        # set loss functions and optimizers
        if self.isTrain:
            self.old_lr = opt.lr
            
            # define loss functions
            self.criterionRecon = MaskReconLoss()
            if opt.objReconLoss == 'l1':
                self.criterionObjRecon = nn.L1Loss()
            elif opt.objReconLoss == 'bce':
                self.criterionObjRecon = nn.BCELoss()
            else:
                self.criterionObjRecon = None

            # Names so we can break out individual loss terms
            self.loss_names = ['G_Recon_comb', 'G_Recon_obj', \
                    'KL_loss', 'loss_G_GAN', 'loss_D_GAN', 'loss_G_GAN_Feat']

            params = self.netG.trainable_parameters
            self.optimizer = torch.optim.Adam(params, lr=opt.lr, \
                    betas=(opt.beta1, opt.beta2))
            
            ########## define discriminator
            if self.use_gan:
                label_nc = opt.label_nc if not (opt.cond_in=='ctx_obj') \
                        else opt.label_nc * 2
                if self.which_gan=='patch':
                    use_lsgan=False
                    self.netD = NLayerDiscriminator( \
                            input_nc=1+label_nc, 
                            ndf=opt.ndf,
                            n_layers=opt.num_layers_D,
                            norm_layer=opt.norm_layer,
                            use_sigmoid=not use_lsgan, getIntermFeat=False)
                elif self.which_gan=='patch_res':
                    use_lsgan=False
                    self.netD = NLayerResDiscriminator( \
                            input_nc=1+label_nc, 
                            ndf=opt.ndf,
                            n_layers=opt.num_layers_D,
                            norm_layer=opt.norm_layer,
                            use_sigmoid=not use_lsgan, getIntermFeat=False)
                elif self.which_gan=='patch_multiscale':
                    use_lsgan=True
                    self.netD = MultiscaleDiscriminator(
                            1+label_nc, 
                            opt.ndf, 
                            opt.num_layers_D, 
                            opt.norm_layer, 
                            not use_lsgan, 
                            2, 
                            True)
                self.ganloss = GANLoss(use_lsgan=use_lsgan, 
                        tensor=self.Tensor)
                if opt.use_ganFeat_loss:
                    self.criterionFeat = torch.nn.L1Loss()

                if len(opt.gpu_ids) > 0:
                    self.netD.cuda(opt.gpu_ids[0])
                params_D = [param for param in self.netD.parameters() \
                        if param.requires_grad]
                self.optimizer_D = torch.optim.Adam(
                        params_D, lr=opt.lr, betas=(opt.beta1, 0.999))

            # load networks
            if opt.continue_train or opt.load_pretrain:
                pretrained_path = '' if not self.isTrain else opt.load_pretrain
                self.load_network_dict(
                    self.netG.params_dict, self.optimizer, 'G',
                    opt.which_epoch, opt.load_pretrain)
                if opt.use_gan:
                    # TODO(sh): add loading for discriminator optimizer
                    self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
        else:
            self.load_network_dict(
                self.netG.params_dict, None, 'G', opt.which_epoch, '')
Example #10
    def __init__(self, args):
        self.args = args

        args.logger.info('Initializing trainer')
        # if not os.path.isdir('../predict'):       only used in validation
        #     os.makedirs('../predict')
        self.model = get_model(args)
        if self.args.lock_coarse:
            for p in self.model.coarse_model.parameters():
                p.requires_grad = False
        torch.cuda.set_device(args.rank)
        self.model.cuda(args.rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.rank])
        train_dataset, val_dataset = get_dataset(args)

        if not args.val:
            # train loss
            self.coarse_RGBLoss = RGBLoss(args, sharp=False)
            self.refine_RGBLoss = RGBLoss(args, sharp=False, refine=True)
            self.SegLoss = nn.CrossEntropyLoss()
            self.GANLoss = GANLoss(tensor=torch.FloatTensor)

            self.coarse_RGBLoss.cuda(args.rank)
            self.refine_RGBLoss.cuda(args.rank)
            self.SegLoss.cuda(args.rank)
            self.GANLoss.cuda(args.rank)

            if args.optimizer == "adamax":
                self.optG = torch.optim.Adamax(list(self.model.module.coarse_model.parameters()) + list(self.model.module.refine_model.parameters()), lr=args.learning_rate)
            elif args.optimizer == "adam":
                self.optG = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
            elif args.optimizer == "sgd":
                self.optG = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate, momentum=0.9)

            # self.optD = torch.optim.Adam(self.model.module.discriminator.parameters(), lr=args.learning_rate)
            self.optD = torch.optim.SGD(self.model.module.discriminator.parameters(), lr=args.learning_rate, momentum=0.9)


            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        else:
            # val criteria
            self.L1Loss  = nn.L1Loss().cuda(args.rank)
            self.PSNRLoss = PSNR().cuda(args.rank)
            self.SSIMLoss = SSIM().cuda(args.rank)
            self.IoULoss = IoU().cuda(args.rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
            self.val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=val_sampler)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch=1
        if args.resume or (args.val and not args.checkepoch_range):
            self.load_checkpoint()

        if args.rank == 0:
            if args.val:
                self.writer =  SummaryWriter(args.path+'/val_logs') if args.interval == 2 else\
                                SummaryWriter(args.path+'/val_int_1_logs')
            else:
                self.writer = SummaryWriter(args.path+'/logs')
        self.heatmap = self.create_stand_heatmap()
Example #11
class RefinerGAN:
    def __init__(self, args):
        self.args = args

        args.logger.info('Initializing trainer')
        # if not os.path.isdir('../predict'):       only used in validation
        #     os.makedirs('../predict')
        self.model = get_model(args)
        if self.args.lock_coarse:
            for p in self.model.coarse_model.parameters():
                p.requires_grad = False
        torch.cuda.set_device(args.rank)
        self.model.cuda(args.rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.rank])
        train_dataset, val_dataset = get_dataset(args)

        if not args.val:
            # train loss
            self.coarse_RGBLoss = RGBLoss(args, sharp=False)
            self.refine_RGBLoss = RGBLoss(args, sharp=False, refine=True)
            self.SegLoss = nn.CrossEntropyLoss()
            self.GANLoss = GANLoss(tensor=torch.FloatTensor)

            self.coarse_RGBLoss.cuda(args.rank)
            self.refine_RGBLoss.cuda(args.rank)
            self.SegLoss.cuda(args.rank)
            self.GANLoss.cuda(args.rank)

            if args.optimizer == "adamax":
                self.optG = torch.optim.Adamax(list(self.model.module.coarse_model.parameters()) + list(self.model.module.refine_model.parameters()), lr=args.learning_rate)
            elif args.optimizer == "adam":
                self.optG = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
            elif args.optimizer == "sgd":
                self.optG = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate, momentum=0.9)

            # self.optD = torch.optim.Adam(self.model.module.discriminator.parameters(), lr=args.learning_rate)
            self.optD = torch.optim.SGD(self.model.module.discriminator.parameters(), lr=args.learning_rate, momentum=0.9)


            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        else:
            # val criteria
            self.L1Loss  = nn.L1Loss().cuda(args.rank)
            self.PSNRLoss = PSNR().cuda(args.rank)
            self.SSIMLoss = SSIM().cuda(args.rank)
            self.IoULoss = IoU().cuda(args.rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
            self.val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=val_sampler)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch=1
        if args.resume or (args.val and not args.checkepoch_range):
            self.load_checkpoint()

        if args.rank == 0:
            if args.val:
                self.writer =  SummaryWriter(args.path+'/val_logs') if args.interval == 2 else\
                                SummaryWriter(args.path+'/val_int_1_logs')
            else:
                self.writer = SummaryWriter(args.path+'/logs')
        self.heatmap = self.create_stand_heatmap()

    def prepare_heat_map(self, prob_map):
        bs, c, h, w = prob_map.size()
        if h != 128:
            # align_corners is not valid with nearest interpolation; resize in place
            prob_map = F.interpolate(prob_map, size=(128, 256), mode='nearest')
        return prob_map

    def create_heatmap(self, prob_map):
        c, h, w = prob_map.size()
        assert c==1, c
        assert h==128, h
        rgb_prob_map = torch.zeros(3, h, w)
        minimum, maximum = 0.0, 1.0
        ratio = 2 * (prob_map-minimum) / (maximum - minimum)

        rgb_prob_map[0] = 1-ratio
        rgb_prob_map[1] = ratio-1
        rgb_prob_map[:2].clamp_(0,1)
        rgb_prob_map[2] = 1-rgb_prob_map[0]-rgb_prob_map[1]
        return rgb_prob_map

    def create_stand_heatmap(self):
        heatmap = torch.zeros(3, 128, 256)
        for i in range(256):
            heatmap[0, :, i] = max(0, 1 - 2.*i/256)
            heatmap[1, :, i] = max(0, 2.*i/256 - 1)
            heatmap[2, :, i] = 1-heatmap[0, :, i]-heatmap[1, :, i]
        return heatmap


    def set_epoch(self, epoch):
        self.args.logger.info("Start of epoch %d" % (epoch+1))
        self.epoch = epoch + 1
        self.train_loader.sampler.set_epoch(epoch)
        # self.val_loader.sampler.set_epoch(epoch)

    def get_input(self, data):
        if self.args.mode == 'xs2xs':
            if self.args.syn_type == 'extra':
                x = torch.cat([data['frame1'], data['frame2'], data['seg1'], data['seg2']], dim=1)
                mask = torch.cat([data['fg_mask1'],data['fg_mask2']], dim=1)
                gt = torch.cat([data['frame3'], data['seg3']], dim=1)
            else:
                x = torch.cat([data['frame1'], data['frame3'], data['seg1'], data['seg3']], dim=1)
                mask = torch.cat([data['fg_mask1'],data['fg_mask3']], dim=1)
                gt = torch.cat([data['frame2'], data['seg2']], dim=1)        
        elif self.args.mode == 'xss2x':
            if self.args.syn_type == 'extra':
                x = torch.cat([data['frame1'], data['frame2'], data['seg1'], data['seg2'], data['seg3']], dim=1)
                gt = data['frame3']   
            else:
                x = torch.cat([data['frame1'], data['frame3'], data['seg1'], data['seg2'], data['seg3']], dim=1)
                gt = data['frame2']   
        return x, mask, gt   

    def normalize(self, img):
        return (img+1)/2

    def prepare_image_set(self, data, coarse_img, refined_imgs, seg, pred_fake, pred_real, extra=False):
        view_rgbs = [   self.normalize(data['frame1'][0]), 
                        self.normalize(data['frame2'][0]), 
                        self.normalize(data['frame3'][0])   ]
        view_segs = [vis_seg_mask(data['seg'+str(i)][0].unsqueeze(0), 20).squeeze(0) for i in range(1, 4)]


        # gan
        view_probs = []
        view_probs.append(self.heatmap)

        for i in range(self.args.num_D):
            toDraw = F.interpolate(pred_real[i][-1][0].unsqueeze(0).cpu(), (128, 256), mode='bilinear', align_corners=True).squeeze(0)
            view_probs.append(self.create_heatmap(toDraw))
            toDraw = F.interpolate(pred_fake[i][-1][0].unsqueeze(0).cpu(), (128, 256), mode='bilinear', align_corners=True).squeeze(0)
            view_probs.append(self.create_heatmap(toDraw))        

        if not extra:
            # coarse
            pred_rgb = self.normalize(coarse_img[0])
            pred_seg = vis_seg_mask(seg[0].unsqueeze(0), 20).squeeze(0) if self.args.mode == 'xs2xs' else torch.zeros_like(view_segs[0])
            insert_index = 2 if self.args.syn_type == 'inter' else 3
            
            view_rgbs.insert(insert_index, pred_rgb)
            view_segs.insert(insert_index, pred_seg)
            view_segs.append(torch.zeros_like(view_segs[-1]))
            # refine
            refined_bs_imgs = [ refined_img[0].unsqueeze(0) for refined_img in refined_imgs ] 
            for i in range(self.args.n_scales):
                insert_img = F.interpolate(refined_bs_imgs[i], size=(128,256))[0].clamp_(-1, 1) 

                pred_rgb = self.normalize(insert_img)
                insert_ind = insert_index + i+1
                view_rgbs.insert(insert_ind, pred_rgb)

            write_in_img = make_grid(view_rgbs + view_segs + view_probs, nrow=4+self.args.n_scales)
        # else:
        #     view_rgbs.insert(3, torch.zeros_like(view_rgbs[-1]))
        #     view_segs.insert(3, torch.zeros_like(view_segs[-1]))

        #     view_pred_rgbs = []
        #     view_pred_segs = []
        #     for i in range(self.args.extra_length):
        #         pred_rgb = self.normalize(img[i][0].cpu())
        #         pred_seg = vis_seg_mask(seg[i].cpu(), 20).squeeze(0) if self.args.mode == 'xs2xs' else torch.zeros_like(view_segs[0])
        #         view_pred_rgbs.append(pred_rgb)
        #         view_pred_segs.append(pred_seg)
        #     write_in_img = make_grid(view_rgbs + view_segs + view_pred_rgbs + view_pred_segs, nrow=4)

        
        return write_in_img

    def train(self):
        self.args.logger.info('Training started')
        self.model.train()
        end = time()
        load_time = 0
        comp_time = 0
        for step, data in enumerate(self.train_loader):
            self.step = step
            load_time += time() - end
            end = time()
            # for tensorboard
            self.global_step += 1
            # forward pass
            x, fg_mask, gt = self.get_input(data)
            x = x.cuda(self.args.rank, non_blocking=True)
            fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
            gt = gt.cuda(self.args.rank, non_blocking=True)

            coarse_img, refined_imgs, seg, pred_fake_D, pred_real_D, pred_fake_G = self.model(x, fg_mask, gt)
            if not self.args.lock_coarse:
                loss_dict = self.coarse_RGBLoss(coarse_img, gt[:, :3], False)
                if self.args.mode == 'xs2xs':
                   loss_dict['ce_loss'] = self.args.ce_weight*self.SegLoss(seg, torch.argmax(gt[:,3:], dim=1))   
            else:
                loss_dict = OrderedDict()
            for i in range(self.args.n_scales):
                # print(i, refined_imgs[-i].size())
                loss_dict.update(self.refine_RGBLoss(refined_imgs[-i-1], F.interpolate(gt[:,:3], scale_factor=(1/2)**i, mode='bilinear', align_corners=True),\
                                                                     refine_scale=1/(2**i), step=self.global_step, normed=False))
            # loss and accuracy
            loss = 0
            for i in loss_dict.values():
                loss += torch.mean(i)
            loss_dict['loss_all'] = loss

            if self.global_step > 1000:
                loss_dict['adv_loss'] = self.args.refine_adv_weight*self.GANLoss(pred_fake_G, True)
                
                g_loss = loss_dict['loss_all'] + loss_dict['adv_loss']

                loss_dict['d_real_loss'] = self.args.refine_d_weight*self.GANLoss(pred_real_D, True)
                loss_dict['d_fake_loss'] = self.args.refine_d_weight*self.GANLoss(pred_fake_D, False)
                loss_dict['d_loss'] = loss_dict['d_real_loss'] + loss_dict['d_fake_loss']

            else:
                g_loss = loss_dict['loss_all'] 

                loss_dict['d_real_loss'] = 0*self.GANLoss(pred_real_D, True)
                loss_dict['d_fake_loss'] = 0*self.GANLoss(pred_fake_D, False)
                loss_dict['d_loss'] = loss_dict['d_real_loss'] + loss_dict['d_fake_loss']               

            self.sync(loss_dict)

            self.optG.zero_grad()
            g_loss.backward()
            self.optG.step()

            # discriminator backward pass
            self.optD.zero_grad()
            loss_dict['d_loss'].backward()
            self.optD.step()
            comp_time += time() - end
            end = time()

            if self.args.rank == 0:
                # add info to tensorboard
                info = {key:value.item() for key,value in loss_dict.items()}
                # add discriminator value
                pred_value = 0
                real_value = 0
                for i in range(self.args.num_D):
                    pred_value += torch.mean(pred_fake_D[i][-1])
                    real_value += torch.mean(pred_real_D[i][-1])
                pred_value/=self.args.num_D
                real_value/=self.args.num_D
                info["fake"] = pred_value.item()
                info["real"] = real_value.item()
                self.writer.add_scalars("losses", info, self.global_step)
                # print
                if self.step % self.args.disp_interval == 0:
                    self.args.logger.info(
                        'Epoch [{epoch:d}/{tot_epoch:d}][{cur_batch:d}/{tot_batch:d}] '
                        'load [{load_time:.3f}s] comp [{comp_time:.3f}s] '
                        'loss [{loss:.4f}]'.format(
                            epoch=self.epoch, tot_epoch=self.args.epochs,
                            cur_batch=self.step+1, tot_batch=len(self.train_loader),
                            load_time=load_time, comp_time=comp_time,
                            loss=loss.item()
                        )
                    )
                    comp_time = 0
                    load_time = 0
                if self.step % 50 == 0: 
                    image_set = self.prepare_image_set(data, coarse_img.cpu(), [ refined_img.cpu() for refined_img in refined_imgs], seg.cpu(), \
                                            pred_fake_D, pred_real_D)
                    self.writer.add_image('image_{}'.format(self.global_step), image_set, self.global_step)


    def validate(self):
        self.args.logger.info('Validation epoch {} started'.format(self.epoch))
        self.model.eval()

        val_criteria = {
            'l1': AverageMeter(),
            'psnr':AverageMeter(),
            'ssim':AverageMeter(),
            'iou':AverageMeter(),
            'vgg':AverageMeter()
        }
        step_losses = OrderedDict()

        with torch.no_grad():
            end = time()
            load_time = 0
            comp_time = 0
            for i, data in enumerate(self.val_loader):
                load_time += time()-end
                end = time()
                self.step=i

                # forward pass
                x, fg_mask, gt = self.get_input(data)
                size = x.size(0)
                x = x.cuda(self.args.rank, non_blocking=True)
                fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
                gt = gt.cuda(self.args.rank, non_blocking=True)
                
                coarse_img, refined_imgs, seg, pred_fake_D, pred_real_D= self.model(x, fg_mask, gt)
                # rgb criteria
                step_losses['l1'] =   self.L1Loss(refined_imgs[-1], gt[:,:3])
                step_losses['psnr'] = self.PSNRLoss((refined_imgs[-1]+1)/2, (gt[:,:3]+1)/2)
                step_losses['ssim'] = 1-self.SSIMLoss(refined_imgs[-1], gt[:,:3])
                step_losses['iou'] =  self.IoULoss(torch.argmax(seg, dim=1), torch.argmax(gt[:,3:], dim=1))
                step_losses['vgg'] =  self.VGGCosLoss(refined_imgs[-1], gt[:, :3], False)
                self.sync(step_losses) # sum
                for key in list(val_criteria.keys()):
                    val_criteria[key].update(step_losses[key].cpu().item(), size*self.args.gpus)

                if self.args.syn_type == 'extra': # not implemented
                    imgs = []
                    segs = []
                    img = img[0].unsqueeze(0)
                    seg = seg[0].unsqueeze(0)
                    x = x[0].unsqueeze(0)
                    for i in range(self.args.extra_length):
                        if i!=0:
                            x = torch.cat([x[:,3:6], img, x[:, 26:46], seg_fil], dim=1).cuda(self.args.rank, non_blocking=True)
                            img, seg = self.model(x)
                        seg_fil = torch.argmax(seg, dim=1)
                        seg_fil = transform_seg_one_hot(seg_fil, 20, cuda=True)*2-1
                        imgs.append(img)
                        segs.append(seg_fil)
                        
                comp_time += time() - end
                end = time()

                # print
                if self.args.rank == 0:
                    if self.step % self.args.disp_interval == 0:
                        self.args.logger.info(
                            'Epoch [{epoch:d}][{cur_batch:d}/{tot_batch:d}] '
                            'load [{load_time:.3f}s] comp [{comp_time:.3f}s]'.format(
                                epoch=self.epoch, cur_batch=self.step+1, tot_batch=len(self.val_loader),
                                load_time=load_time, comp_time=comp_time
                            )
                        )
                        comp_time = 0
                        load_time = 0
                    if self.step % 3 == 0:
                        if self.args.syn_type == 'inter':
                            image_set = self.prepare_image_set(data, coarse_img.cpu(), [ refined_img.cpu() for refined_img in refined_imgs], seg.cpu(), \
                                            pred_fake_D, pred_real_D)
                        else:
                            image_set = self.prepare_image_set(data, imgs, segs, True)
                        image_name = 'e{}_img_{}'.format(self.epoch, self.step)
                        self.writer.add_image(image_name, image_set, self.step)

        if self.args.rank == 0:
            self.args.logger.info(
'Epoch [{epoch:d}]      \n \
L1\t: {l1:.4f}     \n\
PSNR\t: {psnr:.4f}   \n\
SSIM\t: {ssim:.4f}   \n\
IoU\t: {iou:.4f}    \n\
vgg\t: {vgg:.4f}\n'.format(
                    epoch=self.epoch,
                    l1=val_criteria['l1'].avg,
                    psnr=val_criteria['psnr'].avg,
                    ssim=val_criteria['ssim'].avg,
                    iou=val_criteria['iou'].avg,
                    vgg = val_criteria['vgg'].avg
                )
            )
            tfb_info = {key:value.avg for key,value in val_criteria.items()}
            self.writer.add_scalars('val/score', tfb_info, self.epoch)

    def test(self):
        self.args.logger.info('testing started')
        self.model.eval()

        with torch.no_grad():
            end = time()
            load_time = 0
            comp_time = 0

            img_count = 0
            for i, data in enumerate(self.val_loader):
                load_time += time()-end
                end = time()
                self.step=i

                # forward pass
                x, fg_mask, gt = self.get_input(data)
                size = x.size(0)
                x = x.cuda(self.args.rank, non_blocking=True)
                fg_mask = fg_mask.cuda(self.args.rank, non_blocking=True)
                gt = gt.cuda(self.args.rank, non_blocking=True)
                
                img, seg = self.model(x, fg_mask)                        

                bs = img.size(0)
                for i in range(bs):
                    pred_img = self.normalize(img[i])
                    gt_img = self.normalize(gt[i, :3])

                    save_img(pred_img, '{}/{}_pred.png'.format(self.args.save_dir, img_count))
                    save_img(gt_img, '{}/{}_gt.png'.format(self.args.save_dir, img_count))
                    img_count+=1

                comp_time += time() - end
                end = time()

                # print
                if self.args.rank == 0:
                    if self.step % self.args.disp_interval == 0:
                        self.args.logger.info(
                            'img [{}] load [{load_time:.3f}s] comp [{comp_time:.3f}s]'.format(img_count,
                                load_time=load_time, comp_time=comp_time
                            )
                        )
                        comp_time = 0
                        load_time = 0


    def sync(self, loss_dict, mean=True):
        '''Synchronize all tensors given using mean or sum.'''
        for tensor in loss_dict.values():
            dist.all_reduce(tensor)
            if mean:
                tensor.div_(self.args.gpus)

    def save_checkpoint(self):
        save_md_dir = '{}_{}_{}_{}'.format(self.args.model, self.args.mode, self.args.syn_type, self.args.session)
        save_name = os.path.join(self.args.path, 
                                'checkpoint',
                                save_md_dir + '_{}_{}.pth'.format(self.epoch, self.step))
        self.args.logger.info('Saving checkpoint..')
        torch.save({
            'session': self.args.session,
            'epoch': self.epoch + 1,
            'model': self.model.module.state_dict(),
            'optG': self.optG.state_dict(),
            'optD': self.optD.state_dict()
        }, save_name)
        self.args.logger.info('save model: {}'.format(save_name))

    def load_checkpoint(self):
        load_md_dir = '{}_{}_{}_{}'.format("RefineNet", self.args.mode, self.args.syn_type, self.args.checksession)
        if self.args.load_dir is not None:
            load_name = os.path.join(self.args.load_dir,
                                    'checkpoint',
                                    load_md_dir+'_{}_{}.pth'.format(self.args.checkepoch, self.args.checkpoint))
        else:
            load_name = os.path.join(load_md_dir+'_{}_{}.pth'.format(self.args.checkepoch, self.args.checkpoint))
        self.args.logger.info('Loading checkpoint %s' % load_name)
        ckpt = torch.load(load_name)
        if self.args.lock_coarse:
            model_dict = self.model.module.state_dict()
            new_ckpt = OrderedDict()
            for key,item in ckpt['model'].items():
                if 'coarse' in key:
                    new_ckpt[key] = item
            model_dict.update(new_ckpt)
            self.model.module.load_state_dict(model_dict)
        else:
            self.model.module.load_state_dict(ckpt['model'])
        # restore optimizer states (saved under 'optG'/'optD' in save_checkpoint)
        # and move their tensors to the current device
        if not self.args.lock_coarse:
            if not self.args.val:
                self.optG.load_state_dict(ckpt['optG'])
                self.optD.load_state_dict(ckpt['optD'])
                self.epoch = ckpt['epoch']
                self.global_step = (self.epoch-1)*len(self.train_loader)
                for optimizer in (self.optG, self.optD):
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda(self.args.rank)
            else :
                assert ckpt['epoch']-1 == self.args.checkepoch, [ckpt['epoch'], self.args.checkepoch]
                self.epoch = ckpt['epoch'] - 1
        self.args.logger.info('checkpoint loaded')
예제 #12
0
def main():
    opt = get_model_config()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(opt)

    # Model setting
    logger.info('Build Model')

    generator = define_G(3, 3, opt.ngf).to(device)
    total_param = sum([p.numel() for p in generator.parameters()])
    logger.info(f'Generator size: {total_param} parameters')

    discriminator = define_D(3 + 3, opt.ndf, opt.disc).to(device)
    total_param = sum([p.numel() for p in discriminator.parameters()])
    logger.info(f'Discriminator size: {total_param} parameters')

    if torch.cuda.device_count() > 1:
        logger.info(f"Let's use {torch.cuda.device_count()} GPUs!")
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)

    if opt.mode == 'train':
        dirname = datetime.now().strftime("%m%d%H%M") + f'_{opt.name}'
        log_dir = os.path.join('./experiments', dirname)
        os.makedirs(log_dir, exist_ok=True)
        logger.info(f'LOG DIR: {log_dir}')

        # Dataset setting
        logger.info('Set the dataset')
        image_size: Tuple[int, int] = (opt.image_h, opt.image_w)
        train_transform, val_transform = get_transforms(
            image_size,
            augment_type=opt.augment_type,
            image_norm=opt.image_norm)

        trainset = TrainDataset(image_dir=os.path.join(opt.data_dir, 'train'),
                                transform=train_transform)
        valset = TrainDataset(image_dir=os.path.join(opt.data_dir, 'val'),
                              transform=val_transform)

        train_loader = DataLoader(dataset=trainset,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(dataset=valset,
                                batch_size=opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

        # Loss setting
        criterion = {}
        criterion['gan'] = GANLoss(use_lsgan=True).to(device)
        criterion['l1'] = torch.nn.L1Loss().to(device)

        # Optimizer setting
        g_optimizer = get_optimizer(generator.parameters(), opt.optimizer,
                                    opt.lr, opt.weight_decay)
        d_optimizer = get_optimizer(discriminator.parameters(), opt.optimizer,
                                    opt.lr, opt.weight_decay)
        logger.info(
            f'Initial Learning rate(G): {g_optimizer.param_groups[0]["lr"]:.6f}'
        )
        logger.info(
            f'Initial Learning rate(D): {d_optimizer.param_groups[0]["lr"]:.6f}'
        )

        # Scheduler setting
        g_scheduler = get_scheduler(g_optimizer, opt.scheduler, opt)
        d_scheduler = get_scheduler(d_optimizer, opt.scheduler, opt)

        # Tensorboard setting
        writer = SummaryWriter(log_dir=log_dir)

        logger.info('Start to train!')
        train_process(opt,
                      generator,
                      discriminator,
                      criterion,
                      g_optimizer,
                      d_optimizer,
                      g_scheduler,
                      d_scheduler,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      log_dir=log_dir,
                      writer=writer,
                      device=device)

    # TODO: write inference code
    elif opt.mode == 'test':
        logger.info(f'Model loaded from {opt.checkpoint}')

        model.eval()
        logger.info('Start to test!')
        test_status = inference(model=model,
                                test_loader=test_loader,
                                device=device,
                                criterion=criterion)
예제 #13
0
def train(opt):
    print("train...")

    #Create results directories
    os.makedirs(f'{opt.result_imgs_path}/{opt.dataset_name}-{opt.version}',
                exist_ok=True)
    os.makedirs(f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}',
                exist_ok=True)

    # Losses
    mixed_loss = MixedLoss()
    l1_loss = nn.L1Loss()
    l2_loss = nn.MSELoss()
    loss_gan = GANLoss("lsgan")
    val_ssim = SSIM()

    #GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(777)
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)
        Tensor = torch.cuda.FloatTensor
    else:
        Tensor = torch.FloatTensor

    # Model Initialize
    G_AB = Generator().to(device)
    Dis = Discriminator().to(device)

    #Loss Initialize
    l1_loss = l1_loss.to(device)
    l2_loss = l2_loss.to(device)
    mixed_loss = mixed_loss.to(device)
    loss_gan = loss_gan.to(device)

    #Load Pre-trained Models
    if opt.epoch != 0:
        G_AB.load_state_dict(
            torch.load(
                f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/G_AB_{opt.epoch:0>4}.pth'
            ))
        Dis.load_state_dict(
            torch.load(
                f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/Dis_{opt.epoch:0>4}.pth'
            ))
    # Initialize weights
    else:
        G_AB.apply(weights_init_normal)
        Dis.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(G_AB.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(Dis.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
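    # LambdaLR is not defined in this snippet; it is presumably the usual
    # CycleGAN-style helper whose step(epoch) returns a multiplier of 1.0 until
    # decay_epoch and then decays linearly to 0 at n_epochs, e.g. (hedged sketch):
    #   class LambdaLR:
    #       def __init__(self, n_epochs, offset, decay_start_epoch):
    #           self.n_epochs = n_epochs
    #           self.offset = offset
    #           self.decay_start_epoch = decay_start_epoch
    #       def step(self, epoch):
    #           return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) \
    #                  / (self.n_epochs - self.decay_start_epoch)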

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()
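    # ReplayBuffer is also external to this snippet; in the common CycleGAN
    # implementation, push_and_pop(batch) stores up to 50 generated images and,
    # with probability 0.5 per sample, returns an older stored image instead of
    # the current one, so the discriminator also sees past generator outputs.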

    # Image transformations
    transforms_ = [
        transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
        # transforms.RandomCrop((opt.img_height, opt.img_width)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    # validation data loading
    val_dataloader = DataLoader(valDataset(f'data/{opt.dataset_name}',
                                           transforms_=transforms_,
                                           mode='val'),
                                batch_size=6,
                                shuffle=True,
                                num_workers=1)

    # real rainy data loading
    real_dataset = DataLoader(RealDataset(f'data/{opt.dataset_name}',
                                          transforms_=transforms_,
                                          mode='test'),
                              batch_size=6,
                              shuffle=True,
                              num_workers=1)

    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):

        dataloader = DataLoader(
            train_dataset(
                f'data/{opt.dataset_name}/training',
                transforms_=transforms_,
                # rand=1,
                mode='train'),
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=opt.n_cpu)

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(batch['A'].type(Tensor))
            real_B = Variable(batch['B'].type(Tensor))

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_G.zero_grad()

            gen = G_AB(real_A)

            (a, b, c, d, d_fake_b) = Dis(gen)
            (A, B, C, D, d_real_b) = Dis(real_B)

            loss_p = []
            g_adv = []
            for j, (q, p) in enumerate(
                    zip((a, b, c, d, d_fake_b), (A, B, C, D, d_real_b))):
                p_ = Variable(p, requires_grad=False)
                # perceptual loss
                bat, ch, h, w = q.size()
                loss_p.append(l1_loss(q, p_) * 10)

            g_loss = loss_gan(torch.cat((d_fake_b, d_real_b), 1), True)

            mixed = mixed_loss(gen, real_B)
            loss_perc = torch.mean(torch.stack(loss_p))

            loss_G = 10 * mixed + 10 * g_loss + loss_perc  # + 20 * st_loss

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator
            # -----------------------
            optimizer_D.zero_grad()
            gen_ = fake_B_buffer.push_and_pop(gen)

            (q, w, e, r, D_fake) = Dis(gen_)
            (z, x, y, u, D_real) = Dis(real_B)

            fake_list = []
            real_list = []
            loss_pp = []
            for j, (g, n) in enumerate(
                    zip((q, w, e, r, D_fake), (z, x, y, u, D_real))):
                n_ = Variable(n, requires_grad=False)
                loss_pp.append(l1_loss(g, n_))

            loss_fake = loss_gan(torch.cat((D_fake, D_real), 1), False)
            loss_real = loss_gan(torch.cat((D_real, D_real), 1), True)

            loss_pa = torch.mean(torch.stack(loss_pp))

            if loss_pa > opt.margin:
                loss_pa = 0
            else:
                loss_pa = opt.margin - loss_pa

            # Total loss
            loss_D = (loss_real + loss_fake) * 0.5 + loss_pa

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------
            # Determine approximate time left

            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] "
                             "[G_loss: %f, "
                             # "g_loss: %f, "
                             "Adv_loss: %f, "
                             # "mixed: %f, "
                             "GPU Memory Usage: %d, "
                             "lr: %s, "
                             "ETA: %s]" % (
                                 epoch,
                                 opt.n_epochs,
                                 i,
                                 len(dataloader),
                                 loss_G.item(),
                                 # g_loss.item(),
                                 loss_D.item(),
                                 # mixed.item(),
                                 (torch.cuda.memory_allocated() / 1024) / 1024,
                                 lr_scheduler_G.get_lr(),
                                 time_left))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                imgs = next(iter(val_dataloader))
                A = imgs['A'].type(Tensor)

                generated = G_AB(A)
                B = imgs['B'].type(Tensor)

                ssim = val_ssim(B, generated)
                # PSNR
                mse = nn.MSELoss()
                mm = mse(generated, B)
                pp = 10 * log10(1 / mm.item())
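                # PSNR = 10 * log10(MAX^2 / MSE); MAX is taken as 1 here, which
                # assumes the compared tensors are in a unit range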

                real_test = next(iter(real_dataset))
                real_rain = real_test['A'].type(Tensor)
                generated_real_snow = G_AB(real_rain)

                img_sample = torch.cat(
                    (real_rain.data, generated_real_snow.data), 0)
                save_image(
                    img_sample,
                    f'{opt.result_imgs_path}/{opt.dataset_name}-{opt.version}/{epoch:0>4}_{batches_done:0>4}-ssim_{ssim.item():0.3f} psnr_{pp:0.3f}.png',
                    nrow=6,
                    normalize=True)

            if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
                # Save model checkpoints
                torch.save(
                    G_AB.state_dict(),
                    f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/G_AB_{epoch:0>4}.pth'
                )
                torch.save(
                    Dis.state_dict(),
                    f'{opt.result_models_path}/{opt.dataset_name}-{opt.version}/Dis_{epoch:0>4}.pth'
                )
        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
예제 #14
0
    def train(self):
        """ Train UEGAN ."""
        self.fetcher = InputFetcher(self.loaders.ref)
        self.fetcher_val = InputFetcher(self.loaders.val)

        self.train_steps_per_epoch = len(self.loaders.ref)
        self.model_save_step = int(self.args.model_save_epoch *
                                   self.train_steps_per_epoch)

        # set nima, psnr, ssim global parameters
        if self.args.is_test_nima:
            self.best_nima_epoch, self.best_nima = 0, 0.0
        if self.args.is_test_psnr_ssim:
            self.best_psnr_epoch, self.best_psnr = 0, 0.0
            self.best_ssim_epoch, self.best_ssim = 0, 0.0

        # set loss functions
        self.criterionPercep = PerceptualLoss()
        self.criterionIdt = MultiscaleRecLoss(
            scale=3, rec_loss_type=self.args.idt_loss_type, multiscale=True)
        self.criterionGAN = GANLoss(self.args.adv_loss_type,
                                    tensor=torch.cuda.FloatTensor)

        # start from scratch or trained models
        if self.args.pretrained_model:
            start_step = int(self.args.pretrained_model *
                             self.train_steps_per_epoch)
            self.load_pretrained_model(self.args.pretrained_model)
        else:
            start_step = 0

        # start training
        print(
            "======================================= start training ======================================="
        )
        self.start_time = time.time()
        total_steps = int(self.args.total_epochs * self.train_steps_per_epoch)
        self.val_start_steps = int(self.args.num_epochs_start_val *
                                   self.train_steps_per_epoch)
        self.val_each_steps = int(self.args.val_each_epochs *
                                  self.train_steps_per_epoch)

        print(
            "=========== start to iteratively train generator and discriminator ==========="
        )
        pbar = tqdm(total=total_steps,
                    desc='Train epoches',
                    initial=start_step)
        for step in range(start_step, total_steps):
            ########## model train
            self.G.train()
            self.D.train()

            ########## data iter
            input = next(self.fetcher)
            self.real_raw, self.real_exp, self.real_raw_name = input.img_raw, input.img_exp, input.img_name

            ########## forward
            self.fake_exp = self.G(self.real_raw)
            self.fake_exp_store = self.fake_exp_pool.query(self.fake_exp)

            ########## update D
            self.d_optimizer.zero_grad()
            real_exp_preds = self.D(self.real_exp)
            fake_exp_preds = self.D(self.fake_exp_store.detach())
            d_loss = self.criterionGAN(real_exp_preds,
                                       fake_exp_preds,
                                       None,
                                       None,
                                       for_discriminator=True)
            if self.args.adv_input:
                input_preds = self.D(self.real_raw)
                d_loss += self.criterionGAN(real_exp_preds,
                                            input_preds,
                                            None,
                                            None,
                                            for_discriminator=True)
            d_loss.backward()
            self.d_optimizer.step()
            self.d_loss = d_loss.item()

            ########## update G
            self.g_optimizer.zero_grad()
            real_exp_preds = self.D(self.real_exp)
            fake_exp_preds = self.D(self.fake_exp)
            g_adv_loss = self.args.lambda_adv * self.criterionGAN(
                real_exp_preds,
                fake_exp_preds,
                None,
                None,
                for_discriminator=False)
            self.g_adv_loss = g_adv_loss.item()
            g_loss = g_adv_loss

            g_percep_loss = self.args.lambda_percep * self.criterionPercep(
                (self.fake_exp + 1.) / 2., (self.real_raw + 1.) / 2.)
            self.g_percep_loss = g_percep_loss.item()
            g_loss += g_percep_loss

            self.real_exp_idt = self.G(self.real_exp)
            g_idt_loss = self.args.lambda_idt * self.criterionIdt(
                self.real_exp_idt, self.real_exp)
            self.g_idt_loss = g_idt_loss.item()
            g_loss += g_idt_loss

            g_loss.backward()
            self.g_optimizer.step()
            self.g_loss = g_loss.item()

            ### print info and save models
            self.print_info(step, total_steps, pbar)

            ### logging using tensorboard
            self.logging(step)

            ### validation
            self.model_validation(step)

            ### learning rate update
            if step % self.train_steps_per_epoch == 0:
                current_epoch = step // self.train_steps_per_epoch
                self.lr_scheduler_g.step(epoch=current_epoch)
                self.lr_scheduler_d.step(epoch=current_epoch)
                for param_group in self.g_optimizer.param_groups:
                    pbar.write(
                        "====== Epoch: {:>3d}/{}, Learning rate(lr) of Encoder(E) and Generator(G): [{}], "
                        .format(((step + 1) // self.train_steps_per_epoch),
                                self.args.total_epochs, param_group['lr']),
                        end='')
                for param_group in self.d_optimizer.param_groups:
                    pbar.write(
                        "Learning rate (lr) of Discriminator(D): [{}] ======".
                        format(param_group['lr']))

            pbar.update(1)
            pbar.set_description(f"Train epoch %.2f" %
                                 ((step + 1.0) / self.train_steps_per_epoch))

        self.val_best_results()

        pbar.write("=========== Complete training ===========")
        pbar.close()
예제 #15
0
    # load dataset and data loader
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
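    # Normalize((0.5,), (0.5,)) rescales MNIST pixels from [0, 1] to [-1, 1],
    # the output range usually assumed for a tanh-based generator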

    dataset = datasets.MNIST('.', transform=transform, download=True)
    dataloader = data.DataLoader(dataset, batch_size=4)

    # model
    g = Generator()
    d = Discriminator()

    # losses
    gan_loss = GANLoss()

    # use
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        g = g.cuda()
        d = d.cuda()

    # optimizer
    optim_G = optim.Adam(g.parameters())
    optim_D = optim.Adam(d.parameters())

    # train
    for epoch in range(num_epoch):
        total_batch = len(dataloader)
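        # the snippet breaks off here; a minimal hedged sketch of the per-batch
        # step such a loop typically runs (the latent size of 100 and the
        # GANLoss(prediction, target_is_real) call convention are assumptions):
        for batch_idx, (real, _) in enumerate(dataloader):
            if is_cuda:
                real = real.cuda()
            z = torch.randn(real.size(0), 100, device=real.device)
            fake = g(z)

            # discriminator: real images toward "real", generated toward "fake"
            optim_D.zero_grad()
            d_loss = gan_loss(d(real), True) + gan_loss(d(fake.detach()), False)
            d_loss.backward()
            optim_D.step()

            # generator: push the discriminator's score on fakes toward "real"
            optim_G.zero_grad()
            g_loss = gan_loss(d(fake), True)
            g_loss.backward()
            optim_G.step()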
예제 #16
0
def train(args):

    # check if results path exists, if not create the folder
    check_folder(args.results_path)

    # generator model
    generator = HourglassNet(high_res=args.high_resolution)
    generator.to(device)

    # discriminator model
    discriminator = Discriminator(input_nc=1)
    discriminator.to(device)

    # optimizer
    optimizer_g = torch.optim.Adam(generator.parameters())
    optimizer_d = torch.optim.Adam(discriminator.parameters())

    # training parameters
    feature_weight = 0.5
    skip_count = 0
    use_gan = args.use_gan
    print_frequency = 5

    # dataloader
    illum_dataset = IlluminationDataset()
    illum_dataloader = DataLoader(illum_dataset, batch_size=args.batch_size)

    # gan loss based on lsgan that uses squared error
    gan_loss = GANLoss(gan_mode='lsgan')
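    # with gan_mode='lsgan' the loss is mean-squared error against soft labels:
    #   L_D = E[(D(x) - 1)^2] + E[D(G(z))^2],  L_G = E[(D(G(z)) - 1)^2]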

    # training
    for epoch in range(1, args.epochs + 1):

        for data_idx, data in enumerate(illum_dataloader):
            source_img, source_light, target_img, target_light = data

            source_img = source_img.to(device)
            source_light = source_light.to(device)
            target_img = target_img.to(device)
            target_light = target_light.to(device)

            optimizer_g.zero_grad()

            # use all skip connections if requested; otherwise drop them
            # according to the training scheme for low-res/high-res images
            if args.use_skip:
                skip_count = 0
            else:
                skip_count = 5 if args.high_resolution else 4

            output = generator(source_img, target_light, skip_count,
                               target_img)

            source_face_feats, source_light_pred, target_face_feats, source_relit_pred = output

            img_loss = image_and_light_loss(source_relit_pred, target_img,
                                            source_light_pred, target_light)
            feat_loss = feature_loss(source_face_feats, target_face_feats)

            # if gan loss is used
            if use_gan:
                g_loss = gan_loss(discriminator(source_relit_pred),
                                  target_is_real=True)
            else:
                # keep the placeholder loss on the same device as img_loss
                g_loss = torch.zeros(1, device=device)

            total_g_loss = img_loss + g_loss + (feature_weight * feat_loss)
            total_g_loss.backward()
            optimizer_g.step()

            # training the discriminator
            if use_gan:
                optimizer_d.zero_grad()
                pred_real = discriminator(target_img)
                pred_fake = discriminator(source_relit_pred.detach())

                loss_real = gan_loss(pred_real, target_is_real=True)
                loss_fake = gan_loss(pred_fake, target_is_real=False)

                d_loss = (loss_real + loss_fake) * 0.5
                d_loss.backward()
                optimizer_d.step()
            else:
                loss_real = torch.Tensor([0])
                loss_fake = torch.Tensor([0])

            if data_idx % print_frequency == 0:
                print(
                    "Epoch: [{}]/[{}], Iteration: [{}]/[{}], image loss: {}, feature loss: {}, gen fake loss: {}, dis real loss: {}, dis fake loss: {}"
                    .format(epoch, args.epochs, data_idx + 1,
                            len(illum_dataloader), img_loss.item(),
                            feat_loss.item(), g_loss.item(), loss_real.item(),
                            loss_fake.item()))

        # saving model
        checkpoint_path = os.path.join(args.results_path,
                                       'checkpoint_epoch_{}.pth'.format(epoch))
        checkpoint = {
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_d': optimizer_d.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)