Example 1
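This function is excerpted from a training script; a minimal sketch of the module-level imports it assumes (utils and dataset are project-local helpers from the same repository):

import time
import datetime

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

import dataset
import utils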
def Continue_train_WGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()

    # Initialize the Generator and the Discriminator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Model checkpointing
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at every multiple of "save_by_epoch" epochs or "save_by_iter" iterations."""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'WGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Record the start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target = generator(true_input)

                # Fake samples
                fake_scalar_d = discriminator(true_input, fake_target.detach())
                true_scalar_d = discriminator(true_input, true_target)

                # WGAN critic loss: minimize -E[D(real)] + E[D(fake)], then optimize
                loss_D = -torch.mean(true_scalar_d) + torch.mean(fake_scalar_d)
                loss_D.backward()
                optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # GAN Loss (WGAN generator term): minimize -E[D(fake)]
            fake_scalar = discriminator(true_input, fake_target)
            GAN_Loss = -torch.mean(fake_scalar)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)
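
A minimal usage sketch; the attribute names below are the ones read by Continue_train_WGAN, while the values are hypothetical placeholders:

import argparse

# Hypothetical option values for illustration only; the project-local
# utils.create_generator / create_discriminator and dataset.NormalRGBDataset
# will typically require additional fields.
opt = argparse.Namespace(
    cudnn_benchmark=True, multi_gpu=False,
    lr_g=1e-4, lr_d=1e-4, b1=0.5, b2=0.999, weight_decay=0,
    lr_decrease_mode='epoch', lr_decrease_factor=0.5,
    lr_decrease_epoch=10, lr_decrease_iter=10000,
    save_mode='epoch', save_by_epoch=10, save_by_iter=10000,
    save_name_mode=True, task='denoising', batch_size=8,
    num_workers=4, epochs=100, additional_training_d=1,
    lambda_gan=0.01,
)

Continue_train_WGAN(opt)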
Example 2
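Beyond the imports listed for Example 1, this variant also relies on numpy and on a project-local noise helper, and it reads a few extra option fields; a hedged sketch of the additions:

import numpy as np  # used to build the adversarial ground-truth maps

# Extra option fields read by this function (descriptions inferred from usage):
#   opt.z_dim       - dimensionality of the noise vector fed to the generator
#   opt.random_type - distribution name passed to utils.get_noise
#   opt.lambda_l1   - weight of the pixel-level L1 loss
#   opt.lambda_ms   - weight of the mode-seeking loss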
def Continue_train_LSGAN(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    criterion_MSE = torch.nn.MSELoss().cuda()

    # Initialize the Generator and the Discriminator
    generator = utils.create_generator(opt)
    discriminator = utils.create_discriminator(opt)

    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
        generator = generator.cuda()
        discriminator = nn.DataParallel(discriminator)
        discriminator = discriminator.cuda()
    else:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.lr_d,
                                   betas=(opt.b1, opt.b2))

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Set the learning rate to the initial LR decayed by "lr_decrease_factor" every "lr_decrease_epoch" epochs
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Model checkpointing
    def save_model(opt, epoch, iteration, len_dataset, generator):
        """Save the model at every multiple of "save_by_epoch" epochs or "save_by_iter" iterations."""
        if opt.multi_gpu:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator.module, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))
        else:
            if opt.save_mode == 'epoch':
                if (epoch % opt.save_by_epoch
                        == 0) and (iteration % len_dataset == 0):
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_epoch%d_bs%d.pth' %
                            (opt.task, epoch, opt.batch_size))
                        print(
                            'The trained model is successfully saved at epoch %d'
                            % (epoch))
            if opt.save_mode == 'iter':
                if iteration % opt.save_by_iter == 0:
                    if opt.save_name_mode:
                        torch.save(
                            generator, 'LSGAN_%s_iter%d_bs%d.pth' %
                            (opt.task, iteration, opt.batch_size))
                        print(
                            'The trained model is successfully saved at iteration %d'
                            % (iteration))

    # Tensor type
    Tensor = torch.cuda.FloatTensor

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.NormalRGBDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)

    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Record the start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Sample noise and get data
            noise1 = utils.get_noise(true_input.shape[0], opt.z_dim,
                                     opt.random_type)
            noise1 = noise1.cuda()  # out: batch * z_dim
            noise2 = utils.get_noise(true_input.shape[0], opt.z_dim,
                                     opt.random_type)
            noise2 = noise2.cuda()  # out: batch * z_dim
            concat_noise = torch.cat((noise1, noise2),
                                     0)  # out: 2batch * z_dim
            concat_input = torch.cat((true_input, true_input),
                                     0)  # out: 2batch * 1 * 256 * 256
            concat_target = torch.cat((true_target, true_target),
                                      0)  # out: 2batch * 3 * 256 * 256

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(
                concat_input, concat_noise)  # out: 2batch * 3 * 256 * 256

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, concat_target)

            # Mode-seeking (MSGAN) regularizer: penalize generated pairs that vary little relative to their noise vectors
            fake_target1, fake_target2 = fake_target.split(
                true_input.shape[0], 0)
            ms_value = torch.mean(
                torch.abs(fake_target2 - fake_target1)) / torch.mean(
                    torch.abs(noise2 - noise1))
            eps = 1e-5
            ModeSeeking_Loss = 1 / (ms_value + eps)

            # GAN Loss (LSGAN generator term): least-squares loss against an all-ones target
            fake_scalar = discriminator(concat_input, fake_target)
            # Adversarial ground truth: all-ones map matching the discriminator's 30x30 output
            valid = Tensor(np.ones((fake_scalar.shape[0], 1, 30, 30)))
            GAN_Loss = criterion_MSE(fake_scalar, valid)

            # Overall Loss and optimize
            loss = opt.lambda_l1 * Pixellevel_L1_Loss + opt.lambda_ms * ModeSeeking_Loss + opt.lambda_gan * GAN_Loss
            loss.backward()
            optimizer_G.step()

            # Train Discriminator
            for j in range(opt.additional_training_d):
                optimizer_D.zero_grad()

                # Generator output
                fake_target = generator(concat_input, concat_noise)
                fake_target1, fake_target2 = fake_target.split(
                    true_input.shape[0], 0)

                # Fake samples
                fake_scalar_d1 = discriminator(true_input,
                                               fake_target1.detach())
                fake_scalar_d2 = discriminator(true_input,
                                               fake_target2.detach())
                # Adversarial ground truth
                fake = Tensor(np.zeros((fake_scalar_d1.shape[0], 1, 30, 30)))
                loss_fake = criterion_MSE(fake_scalar_d1,
                                          fake) + criterion_MSE(
                                              fake_scalar_d2, fake)

                # True samples (the single true-pair loss is counted twice, balancing the two fake-sample terms)
                true_scalar_d = discriminator(true_input, true_target)
                # Adversarial ground truth
                valid = Tensor(np.ones((true_scalar_d.shape[0], 1, 30, 30)))
                loss_true = criterion_MSE(true_scalar_d,
                                          valid) + criterion_MSE(
                                              true_scalar_d, valid)

                # Overall Loss and optimize
                loss_D = 0.5 * (loss_fake + loss_true)
                loss_D.backward()
                optimizer_D.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [GAN Loss: %.4f] [D Loss: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), GAN_Loss.item(), loss_D.item(),
                   time_left))

            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader),
                       generator)

            # Learning rate decrease at certain epochs
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1),
                                 optimizer_G)