Example No. 1

# Imports assumed by this snippet (dataset and utils are repo-local modules);
# the later examples additionally use os, cv2, torch.nn.functional as F and
# wandb where referenced.
import time
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

import dataset
import utils

def LSGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiples"""
        if epoch % opt.checkpoint_interval == 0:
            model_name = 'deepfillNet_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size)
            if opt.multi_gpu == True:
                torch.save(net.module, model_name)
            else:
                torch.save(net, model_name)
            print('The trained model is successfully saved at epoch %d' % epoch)

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            # LSGAN vectors
            valid = Tensor(np.ones((img.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((img.shape[0], 1, 8, 8)))

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # Compose outputs with the ground truth outside the hole
            first_out_wholeimg = img * (1 - mask) + first_out * mask  # in range [-1, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [-1, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Discriminator losses on fake and real samples
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img = (img + 1) / 2  # in range [0, 1]
            img = utils.normalize_ImageNet_stats(img)  # in range of ImageNet
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg = (second_out_wholeimg + 1) / 2  # in range [0, 1]
            second_out_wholeimg = utils.normalize_ImageNet_stats(second_out_wholeimg)
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            # Compute losses
            loss = first_MaskL1Loss + second_MaskL1Loss + opt.perceptual_param * second_PerceptualLoss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)
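
The trainer reads every hyper-parameter off the opt namespace. As a minimal sketch of how it might be driven (field names are inferred from the attributes the function accesses; the values, and any extra fields create_generator and InpaintDataset need, are illustrative assumptions, since the repository builds opt with argparse):

from types import SimpleNamespace

# Illustrative values only; dataset.InpaintDataset and utils.create_generator
# will likely require further fields (e.g. data paths) not shown here.
opt = SimpleNamespace(
    cudnn_benchmark=True, multi_gpu=False,
    lr_g=1e-4, lr_d=4e-4, b1=0.5, b2=0.999, weight_decay=0,
    lr_decrease_factor=0.5, lr_decrease_epoch=10,
    checkpoint_interval=1, batch_size=4, num_workers=4, epochs=40,
    perceptual_param=10, gan_param=1)
LSGAN_trainer(opt)
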
Example No. 2
                        type=int,
                        default=1,
                        help='test batch size, always 1')
    parser.add_argument('--load_name',
                        type=str,
                        default='deepfillNet_epoch4_batchsize4.pth',
                        help='test model name')
    opt = parser.parse_args()
    print(opt)

    # ----------------------------------------
    #       Initialize testing dataset
    # ----------------------------------------

    # Define the dataset
    testset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(testset))

    # Define the dataloader
    dataloader = DataLoader(testset,
                            batch_size=opt.batch_size,
                            pin_memory=True)

    # ----------------------------------------
    #                 Testing
    # ----------------------------------------

    model = torch.load(opt.load_name)

    for batch_idx, (img, mask) in enumerate(dataloader):
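
The example breaks off at the loop header. A plausible sketch of the loop body, following the conventions of the trainers on this page (the generator returns a coarse and a refined output, and mask == 1 marks the hole); this is an assumption, not the repository's actual test code:

        # Hypothetical loop body; mirrors the compositing used in the trainers.
        img = img.cuda()
        mask = mask.cuda()
        with torch.no_grad():
            first_out, second_out = model(img, mask)
        # Keep the known pixels, fill the hole with the refined prediction
        result = img * (1 - mask) + second_out * mask
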
Example No. 3
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = nn.DataParallel(perceptualnet)
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    RELU = nn.ReLU()

    # Optimizers
    # Under DataParallel the coarse sub-network lives on generator.module
    coarse = generator.module.coarse if opt.multi_gpu == True else generator.coarse
    optimizer_g1 = torch.optim.Adam(coarse.parameters(), lr=opt.lr_g)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=opt.lr_g)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=opt.lr_d)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt, batch=0, is_D=False):
        """Save the model at "checkpoint_interval" and its multiples"""
        if is_D == True:
            model_name = 'discriminator_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        else:
            model_name = 'deepfillv2_WGAN_epoch%d_batch%d.pth' % (epoch + 1, batch)
        model_name = os.path.join(save_folder, model_name)
        if epoch % opt.checkpoint_interval == 0:
            if opt.multi_gpu == True:
                torch.save(net.module.state_dict(), model_name)
            else:
                torch.save(net.state_dict(), model_name)
            print('The trained model is successfully saved at epoch %d batch %d' % (epoch + 1, batch))
    
    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    
    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        print("Start epoch ", epoch+1, "!")
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            # Generator output
            first_out, second_out = generator(img, mask)

            # forward propagation
            first_out_wholeimg = img * (1 - mask) + first_out * mask        # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask      # in range [0, 1]

            # Train the critic (a single step here; widen range() for more critic steps)
            for wk in range(1):
                optimizer_d.zero_grad()
                fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
                true_scalar = discriminator(img, mask)
                #W_Loss = -torch.mean(true_scalar) + torch.mean(fake_scalar)#+ gradient_penalty(discriminator, img, second_out_wholeimg, mask)
                hinge_loss = torch.mean(RELU(1 - true_scalar)) + torch.mean(RELU(fake_scalar + 1))
                loss_D = hinge_loss
                loss_D.backward(retain_graph=True)
                optimizer_d.step()

            ### Train Generator
            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)
            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = - torch.mean(fake_scalar)

            optimizer_g1.zero_grad()
            first_MaskL1Loss.backward(retain_graph=True)
            optimizer_g1.step()

            optimizer_g.zero_grad()

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            loss = 0.5*opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + GAN_Loss + second_PerceptualLoss * opt.lambda_perceptual
            loss.backward()

            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]" %
                  ((epoch + 1), opt.epochs, (batch_idx+1), len(dataloader), first_MaskL1Loss.item(),
                   second_MaskL1Loss.item()))
            print("\r[D Loss: %.5f] [Perceptual Loss: %.5f] [G Loss: %.5f] time_left: %s" %
                  (loss_D.item(), second_PerceptualLoss.item(), GAN_Loss.item(), time_left))



            if (batch_idx + 1) % 100 == 0:

                # Generate Visualization image
                masked_img = img * (1 - mask) + mask
                img_save = torch.cat((img, masked_img, first_out, second_out, first_out_wholeimg, second_out_wholeimg),3)
                # Recover normalization: * 255 because last layer is sigmoid activated
                img_save = F.interpolate(img_save, scale_factor=0.5)
                img_save = img_save * 255
                # Process img_copy and do not destroy the data of img
                img_copy = img_save.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
                #img_copy = np.clip(img_copy, 0, 255)
                img_copy = img_copy.astype(np.uint8)
                save_img_name = 'sample_batch' + str(batch_idx+1) + '.png'
                save_img_path = os.path.join(sample_folder, save_img_name)
                img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
                cv2.imwrite(save_img_path, img_copy)
            if (batch_idx + 1) % 5000 == 0:
                save_model(generator, epoch, opt, batch_idx+1)
                save_model(discriminator, epoch, opt, batch_idx+1, is_D=True)


        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, epoch, opt)
        save_model(discriminator, epoch, opt, is_D=True)
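
The commented-out W_Loss line above refers to a gradient_penalty helper that is not shown. A minimal sketch of such a helper, following the standard WGAN-GP formulation and assuming the discriminator takes (image, mask) as in the snippet:

def gradient_penalty(discriminator, real, fake, mask):
    # Random interpolates between real and generated samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_interp = discriminator(interp, mask)
    # Gradient of the critic output w.r.t. the interpolated input
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize deviations of the gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
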
Example No. 4
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # configurations
    save_folder = opt.save_path
    sample_folder = opt.sample_path
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    if not os.path.exists(sample_folder):
        os.makedirs(sample_folder)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(lr_in, optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = lr_in * (opt.lr_decrease_factor**(epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiples"""
        model_name = 'deepfillv2_LSGAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_name = os.path.join(save_folder, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_name)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # Compose outputs with the ground truth outside the hole
            first_out_wholeimg = img * (1 - mask) + first_out * mask  # in range [0, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * first_MaskL1Loss + opt.lambda_l1 * second_MaskL1Loss + \
                opt.lambda_perceptual * second_PerceptualLoss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [first Mask L1 Loss: %.5f] [second Mask L1 Loss: %.5f]"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   first_MaskL1Loss.item(), second_MaskL1Loss.item()))
            print(
                "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                % (loss_D.item(), GAN_Loss.item(),
                   second_PerceptualLoss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)

        ### Sample data every epoch
        masked_img = img * (1 - mask) + mask
        mask = torch.cat((mask, mask, mask), 1)
        if (epoch + 1) % 1 == 0:
            img_list = [img, mask, masked_img, first_out, second_out]
            name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
            utils.save_sample_png(sample_folder=sample_folder,
                                  sample_name='epoch%d' % (epoch + 1),
                                  img_list=img_list,
                                  name_list=name_list,
                                  pixel_max_cnt=255)
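
For reference, adjust_learning_rate implements a plain step decay, lr = lr_in * factor ** (epoch // step). A quick worked example with assumed values:

lr_in, factor, step = 1e-4, 0.5, 10  # hypothetical lr_g / lr_decrease_* values
for epoch in (1, 9, 10, 19, 20, 35):
    print(epoch, lr_in * factor ** (epoch // step))
# 1 0.0001   9 0.0001   10 5e-05   19 5e-05   20 2.5e-05   35 1.25e-05
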
Example No. 5
def Trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)

    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)

    # Build networks
    generator = utils.create_generator(opt)

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
    else:
        generator = generator.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiples"""
        model_name = 'GrayInpainting_epoch%d_batchsize%d.pth' % (
            epoch + 20, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):

            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # forward propagation
            optimizer_g.zero_grad()
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 - mask) + out * mask  # in range [0, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)

            # Compute losses
            loss = MaskL1Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
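
The key line in this trainer is the compositing step out_wholeimg = grayscale * (1 - mask) + out * mask: known pixels are copied from the input and only hole pixels come from the generator, so the L1 loss effectively penalizes the masked region. A tiny sanity check:

g = torch.tensor([[0.2, 0.8]])  # ground truth
m = torch.tensor([[0.0, 1.0]])  # 1 marks the hole
o = torch.tensor([[0.5, 0.5]])  # generator output
print(g * (1 - m) + o * m)      # tensor([[0.2000, 0.5000]])
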
Example No. 6
def Trainer_GAN(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Handle multiple GPUs
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs used" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    print("Batch size is changed to %d" % opt.batch_size)
    print("Number of workers is changed to %d" % opt.num_workers)

    # Build path folder
    utils.check_path(opt.save_path)
    utils.check_path(opt.sample_path)

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)
    perceptualnet = utils.create_perceptualnet()

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        perceptualnet = nn.DataParallel(perceptualnet)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        perceptualnet = perceptualnet.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()
    MSELoss = nn.MSELoss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(optimizer, epoch, opt, init_lr):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = init_lr * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiples"""
        model_name = 'GrayInpainting_GAN_epoch%d_batchsize%d.pth' % (
            epoch, opt.batch_size)
        model_path = os.path.join(opt.save_path, model_name)
        if opt.multi_gpu == True:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.module.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))
        else:
            if epoch % opt.checkpoint_interval == 0:
                torch.save(net.state_dict(), model_path)
                print('The trained model is successfully saved at epoch %d' %
                      (epoch))

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (grayscale, mask) in enumerate(dataloader):

            # Load and put to cuda
            grayscale = grayscale.cuda()  # out: [B, 1, 256, 256]
            mask = mask.cuda()  # out: [B, 1, 256, 256]

            # LSGAN vectors
            valid = Tensor(np.ones((grayscale.shape[0], 1, 8, 8)))
            fake = Tensor(np.zeros((grayscale.shape[0], 1, 8, 8)))

            # ----------------------------------------
            #           Train Discriminator
            # ----------------------------------------
            optimizer_d.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 - mask) + out * mask  # in range [0, 1]

            # Fake samples
            fake_scalar = discriminator(out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(grayscale, mask)
            # Discriminator losses on fake and real samples
            loss_fake = MSELoss(fake_scalar, fake)
            loss_true = MSELoss(true_scalar, valid)
            # Overall loss and optimize
            loss_D = 0.5 * (loss_fake + loss_true)
            loss_D.backward()
            optimizer_d.step()

            # ----------------------------------------
            #             Train Generator
            # ----------------------------------------
            optimizer_g.zero_grad()

            # forward propagation
            out = generator(grayscale, mask)  # out: [B, 1, 256, 256]
            out_wholeimg = grayscale * (1 - mask) + out * mask  # in range [0, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(out_wholeimg, grayscale)

            # GAN Loss
            fake_scalar = discriminator(out_wholeimg, mask)
            MaskGAN_Loss = MSELoss(fake_scalar, valid)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            out_3c = torch.cat((out_wholeimg, out_wholeimg, out_wholeimg), 1)
            grayscale_3c = torch.cat((grayscale, grayscale, grayscale), 1)
            out_featuremaps = perceptualnet(out_3c)
            gt_featuremaps = perceptualnet(grayscale_3c)
            PerceptualLoss = L1Loss(out_featuremaps, gt_featuremaps)

            # Compute losses
            loss = opt.lambda_l1 * MaskL1Loss + opt.lambda_perceptual * PerceptualLoss + opt.lambda_gan * MaskGAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [Perceptual Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                % ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                   MaskL1Loss.item(), PerceptualLoss.item(), loss_D.item(),
                   MaskGAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate(optimizer_g, (epoch + 1), opt, opt.lr_g)
        adjust_learning_rate(optimizer_d, (epoch + 1), opt, opt.lr_d)

        # Save the model
        save_model(generator, (epoch + 1), opt)
        utils.sample(grayscale, mask, out_wholeimg, opt.sample_path,
                     (epoch + 1))
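
Tensor = torch.cuda.FloatTensor with np.ones/np.zeros is a legacy pattern. Assuming the discriminator emits a [B, 1, 8, 8] patch map, as the code implies, an equivalent device-safe way to build the LSGAN targets would be:

valid = torch.ones(grayscale.size(0), 1, 8, 8, device=grayscale.device)
fake = torch.zeros(grayscale.size(0), 1, 8, 8, device=grayscale.device)
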
Example No. 7
def Trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # Build networks
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu == True:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        generator = generator.cuda()
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Loss functions
    L1Loss = nn.L1Loss()

    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate_g(optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = opt.lr_g * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_learning_rate_d(optimizer, epoch, opt):
        """Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs"""
        lr = opt.lr_d * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Save the model every "checkpoint_interval" epochs
    def save_model(net, epoch, opt):
        """Save the model at "checkpoint_interval" and its multiples"""
        if epoch % opt.checkpoint_interval == 0:
            model_name = 'ContextureEncoder_epoch%d_batchsize%d.pth' % (epoch, opt.batch_size)
            if opt.multi_gpu == True:
                torch.save(net.module, model_name)
            else:
                torch.save(net, model_name)
            print('The trained model is successfully saved at epoch %d' % epoch)

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    # Training loop
    for epoch in range(opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            ### Train Discriminator
            optimizer_d.zero_grad()

            # Generator output
            masked_img = img * (1 - mask)
            fake = generator(masked_img)

            # Fake samples
            fake_scalar = discriminator(fake.detach())
            # True samples
            true_scalar = discriminator(img)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # forward propagation
            fusion_fake = img * (1 - mask) + fake * mask  # in range [-1, 1]

            # Mask L1 Loss
            MaskL1Loss = L1Loss(fusion_fake, img)

            # GAN Loss
            fake_scalar = discriminator(fusion_fake)
            GAN_Loss = -torch.mean(fake_scalar)

            # Compute losses
            loss = MaskL1Loss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + batch_idx
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Mask L1 Loss: %.5f] [D Loss: %.5f] [G Loss: %.5f] time_left: %s"
                %
                ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                 MaskL1Loss.item(), loss_D.item(), GAN_Loss.item(), time_left))

        # Learning rate decrease
        adjust_learning_rate_g(optimizer_g, (epoch + 1), opt)
        adjust_learning_rate_d(optimizer_d, (epoch + 1), opt)

        # Save the model
        save_model(generator, (epoch + 1), opt)
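
Note that this critic loss, -mean(true) + mean(fake), is the WGAN objective, and the snippet applies no Lipschitz constraint. If create_discriminator does not use spectral normalization internally, one classic remedy (an assumption here, not shown in the source) is to clip the critic weights after each discriminator step:

# Original WGAN weight clipping; clip_value is a hypothetical hyper-parameter.
clip_value = 0.01
for p in discriminator.parameters():
    p.data.clamp_(-clip_value, clip_value)
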
Example No. 8
def WGAN_trainer(opt):
    # ----------------------------------------
    #      Initialize training parameters
    # ----------------------------------------
    logger = Logger(opt)
    checkpoint = restore(opt)

    # cudnn benchmark accelerates the network
    cudnn.benchmark = opt.cudnn_benchmark

    # --------------------------------------
    #         Initialize models
    # --------------------------------------
    generator, discriminator, perceptualnet = create_networks(opt, checkpoint)

    # Loss functions
    L1Loss = nn.L1Loss()
    #FeatureMatchingLoss = FML1Loss(opt.fm_param)

    # Optimizers
    optimizer_g, optimizer_d = create_optimizers(generator, discriminator, opt,
                                                 checkpoint)

    # Log metrics with wandb
    wandb.watch(generator)
    wandb.config.update(opt)
    auto_sync_checkpoints_to_wandb()

    # ----------------------------------------
    #       Initialize training dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.InpaintDataset(opt)
    print('The overall number of images equals %d' % len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #            Training and Testing
    # ----------------------------------------

    # Initialize start time
    prev_time = time.time()

    initial_epoch = checkpoint['epoch'] if opt.restore else 0
    n_iter = checkpoint['n_iter'] if opt.restore else 0

    # training loop
    for epoch in range(initial_epoch, opt.epochs):
        for batch_idx, (img, mask) in enumerate(dataloader):
            n_iter += 1
            logger.begin(n_iter)

            # Load img (shape: [B, 3, H, W]) and mask (shape: [B, 1, H, W]) and move them to CUDA
            img = img.cuda()
            mask = mask.cuda()

            ### Train discriminator
            optimizer_d.zero_grad()

            # Generator output
            first_out, second_out = generator(img, mask)

            # Compose outputs with the ground truth outside the hole
            first_out_wholeimg = img * (1 - mask) + first_out * mask  # in range [-1, 1]
            second_out_wholeimg = img * (1 - mask) + second_out * mask  # in range [-1, 1]

            if n_iter % opt.log_every == 1:
                logger.add_image(img, 'image/training')
                logger.add_image(mask, 'mask/training')
                logger.add_image(first_out_wholeimg, 'image/first iteration')
                logger.add_image(second_out_wholeimg, 'image/second iteration')

            # Fake samples
            fake_scalar = discriminator(second_out_wholeimg.detach(), mask)
            # True samples
            true_scalar = discriminator(img, mask)

            # Overall Loss and optimize
            loss_D = -torch.mean(true_scalar) + torch.mean(fake_scalar)
            loss_D.backward()
            optimizer_d.step()

            ### Train Generator
            optimizer_g.zero_grad()

            # Mask L1 Loss
            first_MaskL1Loss = L1Loss(first_out_wholeimg, img)
            second_MaskL1Loss = L1Loss(second_out_wholeimg, img)

            # GAN Loss
            fake_scalar = discriminator(second_out_wholeimg, mask)
            GAN_Loss = -torch.mean(fake_scalar)

            # Get the deep semantic feature maps, and compute Perceptual Loss
            img = (img + 1) / 2  # in range [0, 1]
            img = utils.normalize_ImageNet_stats(img)  # in range of ImageNet
            img_featuremaps = perceptualnet(img)  # feature maps
            second_out_wholeimg = (second_out_wholeimg + 1) / 2  # in range [0, 1]
            second_out_wholeimg = utils.normalize_ImageNet_stats(second_out_wholeimg)
            second_out_wholeimg_featuremaps = perceptualnet(second_out_wholeimg)
            second_PerceptualLoss = L1Loss(second_out_wholeimg_featuremaps, img_featuremaps)

            # Compute losses
            loss = first_MaskL1Loss + second_MaskL1Loss + opt.perceptual_param * second_PerceptualLoss + opt.gan_param * GAN_Loss
            loss.backward()
            optimizer_g.step()

            # Determine approximate time left
            batches_done = n_iter
            batches_left = opt.epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            logger.add_scalars({
                'Epoch': epoch + 1,
                'Iteration': n_iter,
                'loss/first Mask L1 Loss': first_MaskL1Loss.item(),
                'loss/second Mask L1 Loss': second_MaskL1Loss.item(),
                'gan/D Loss': loss_D.item(),
                'gan/G Loss': GAN_Loss.item(),
                'Perceptual Loss': second_PerceptualLoss.item()
            })

            # Print log
            if n_iter % opt.log_every == 1:
                print("\r[Epoch %d/%d] [Batch %d/%d] iteration %d" %
                      ((epoch + 1), opt.epochs, batch_idx, len(dataloader),
                       n_iter))
                print(
                    "\r[D Loss: %.5f] [G Loss: %.5f] [Perceptual Loss: %.5f] time_left: %s"
                    % (loss_D.item(), GAN_Loss.item(),
                       second_PerceptualLoss.item(), time_left))

            if n_iter % opt.checkpoint_every == 1:
                save_state(epoch=epoch,
                           batch=batch_idx,
                           n_iter=n_iter,
                           G=generator,
                           optimizer_g=optimizer_g,
                           D=discriminator,
                           optimizer_d=optimizer_d,
                           loss=loss,
                           opt=opt)

        # Learning rate decrease
        adjust_learning_rate(opt.lr_g, optimizer_g, (epoch + 1), opt)
        adjust_learning_rate(opt.lr_d, optimizer_d, (epoch + 1), opt)
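
save_state and restore are external helpers in this example. Judging from the keys read above (checkpoint['epoch'], checkpoint['n_iter']) and the arguments passed to save_state, a plausible sketch (the file name and dictionary layout are assumptions) is:

def save_state(epoch, batch, n_iter, G, optimizer_g, D, optimizer_d, loss, opt):
    # Bundle everything needed to resume training into a single file.
    torch.save({
        'epoch': epoch,
        'batch': batch,
        'n_iter': n_iter,
        'G': G.state_dict(),
        'optimizer_g': optimizer_g.state_dict(),
        'D': D.state_dict(),
        'optimizer_d': optimizer_d.state_dict(),
        'loss': loss.item(),
    }, os.path.join(opt.save_path, 'checkpoint_iter%d.pth' % n_iter))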