# Code example #1
def train_single_scale(netD, netG, reals, img_to_augment, naive_img,
                       naive_img_large, fixed_noise, noise_amp, opt, depth,
                       writer):
    """Train the generator and discriminator at one scale of the pyramid.

    Runs ``opt.niter`` WGAN-GP iterations at stage ``depth``: the
    discriminator is updated ``opt.Dsteps`` times per iteration, then the
    generator once (gradients applied ``opt.Gsteps`` times), with an optional
    ``alpha``-weighted reconstruction loss on the fixed noise pyramid.

    Args:
        netD, netG: discriminator and multi-stage generator.
        reals: list of real images, one per stage.
        img_to_augment: numpy image used as augmentation source.
        naive_img, naive_img_large: naive (pasted/edited) image and its
            full-resolution version; used in fine-tune mode.
        fixed_noise: per-stage reconstruction inputs; extended in place,
            or loaded from ``opt.model_dir`` when fine-tuning.
        noise_amp: per-stage noise amplitudes; extended in place, or loaded
            from ``opt.model_dir`` when fine-tuning.
        opt: options namespace (lr_d, lr_g, beta1, gamma, niter, alpha, ...).
        depth: index of the stage being trained.
        writer: TensorBoard SummaryWriter.

    Returns:
        (fixed_noise, noise_amp, netG, netD)
    """
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    # Augmentation helper: used by the "harmonization" branch of the training
    # loop below and passed to generate_samples() in the logging section.
    # NOTE(review): restored from the commented-out line — without it `aug`
    # is unbound and the generate_samples() call on the first logging
    # iteration raises NameError. Confirm functions.Augment exists.
    aug = functions.Augment()

    alpha = opt.alpha

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if opt.fine_tune:
        # Reuse the reconstruction noise saved by the original training run.
        fixed_noise = torch.load('%s/fixed_noise.pth' % opt.model_dir,
                                 map_location="cuda:{}".format(
                                     torch.cuda.current_device()))
        z_opt = fixed_noise[depth]
    else:
        if depth == 0:
            # Coarsest stage: the reconstruction input is the real image
            # itself (plus scaled noise for editing).
            # NOTE(review): z_opt stays unbound for any other train_mode —
            # fixed_noise.append below would then raise NameError.
            if opt.train_mode == "harmonization":
                z_opt = reals[0]
            elif opt.train_mode == "editing":
                z_opt = reals[0] + opt.noise_scaling * functions.generate_noise(
                    [
                        opt.nc_im, reals_shapes[depth][2],
                        reals_shapes[depth][3]
                    ],
                    device=opt.device).detach()
        else:
            z_opt = functions.generate_noise(
                [opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
                device=opt.device)
        fixed_noise.append(z_opt.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages: geometrically decayed by
    # opt.lr_scale, the lowest trained stage getting the smallest rate
    parameter_list = [{
        "params":
        block.parameters(),
        "lr":
        opt.lr_g *
        (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
    } for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list += [{
            "params": netG.head.parameters(),
            "lr": opt.lr_g * (opt.lr_scale**depth)
        }]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]
    optimizerG = optim.Adam(parameter_list,
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))

    # define learning rate schedules: decay once at 80% of training.
    # Milestones are compared against integer epoch counters, so cast to int.
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerD, milestones=[int(0.8 * opt.niter)],
        gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerG, milestones=[int(0.8 * opt.niter)],
        gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    if opt.fine_tune:
        noise_amp = torch.load('%s/noise_amp.pth' % opt.model_dir,
                               map_location="cuda:{}".format(
                                   torch.cuda.current_device()))
    else:
        if depth == 0:
            noise_amp.append(1)
        else:
            # Scale the new stage's noise by the reconstruction RMSE of the
            # generator when this stage contributes no noise (amp = 0).
            noise_amp.append(0)
            z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)

            criterion = nn.MSELoss()
            rec_loss = criterion(z_reconstruction, real)

            RMSE = torch.sqrt(rec_loss).detach()
            _noise_amp = opt.noise_amp_init * RMSE
            noise_amp[-1] = _noise_amp

    # start training
    _iter = tqdm(range(opt.niter))
    for iter_idx in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample augmented training image
        ###########################
        # Build the generator input pyramid: level 0 carries the (augmented)
        # image, higher levels carry fresh random noise.
        noise = []
        for d in range(depth + 1):
            if d == 0:
                if opt.fine_tune:
                    if opt.train_mode == "harmonization":
                        noise.append(functions.np2torch(naive_img, opt))
                    elif opt.train_mode == "editing":
                        noise.append(
                            functions.np2torch(naive_img, opt) +
                            opt.noise_scaling * functions.generate_noise(
                                [
                                    opt.nc_im, reals_shapes[d][2],
                                    reals_shapes[d][3]
                                ],
                                device=opt.device).detach())
                else:
                    if opt.train_mode == "harmonization":
                        data = {"image": img_to_augment}
                        augmented = aug.transform(**data)
                        image = augmented["image"]
                        noise.append(functions.np2torch(image, opt))
                    elif opt.train_mode == "editing":
                        image = functions.shuffle_grid(img_to_augment)
                        image = functions.np2torch(image, opt) + \
                                opt.noise_scaling * functions.generate_noise([3, reals_shapes[d][2],
                                                                              reals_shapes[d][3]],
                                                                              device=opt.device).detach()
                        noise.append(image)
            else:
                noise.append(
                    functions.generate_noise(
                        [opt.nfc, reals_shapes[d][2], reals_shapes[d][3]],
                        device=opt.device).detach())

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            netD.zero_grad()

            output = netD(real)
            errD_real = -output.mean()

            # train with fake; only the last D step builds the generator
            # graph so the subsequent G update can backprop through it
            if j == opt.Dsteps - 1:
                fake = netG(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = netG(noise, reals_shapes, noise_amp)

            output = netD(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        output = netD(fake)
        errG = -output.mean()

        if alpha != 0:
            loss = nn.MSELoss()
            rec = netG(fixed_noise, reals_shapes, noise_amp)
            rec_loss = alpha * loss(rec, real)
        else:
            # No reconstruction term; keep `rec` bound so the logging
            # section can safely skip saving the reconstruction image.
            rec = None
            rec_loss = 0

        netG.zero_grad()
        errG_total = errG + rec_loss
        errG_total.backward()

        # NOTE(review): upstream behavior — the same gradients are applied
        # opt.Gsteps times without re-computation between steps.
        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter_idx % 250 == 0 or iter_idx + 1 == opt.niter:
            # `j` is the last discriminator-step index from the loop above.
            writer.add_scalar('Loss/train/D/real/{}'.format(j),
                              -errD_real.item(), iter_idx + 1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j),
                              errD_fake.item(), iter_idx + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j),
                              gradient_penalty.item(), iter_idx + 1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iter_idx + 1)
            # float() handles both the 0-dim tensor (alpha != 0) and the
            # plain int 0 (alpha == 0), where .item() would crash.
            writer.add_scalar('Loss/train/G/reconstruction', float(rec_loss),
                              iter_idx + 1)
            functions.save_image(
                '{}/fake_sample_{}.jpg'.format(opt.outf, iter_idx + 1),
                fake.detach())
            if rec is not None:
                functions.save_image(
                    '{}/reconstruction_{}.jpg'.format(opt.outf, iter_idx + 1),
                    rec.detach())
            generate_samples(netG, img_to_augment, naive_img, naive_img_large,
                             aug, opt, depth, noise_amp, writer, reals,
                             iter_idx + 1)
        elif opt.fine_tune and iter_idx % 100 == 0:
            generate_samples(netG, img_to_augment, naive_img, naive_img_large,
                             aug, opt, depth, noise_amp, writer, reals,
                             iter_idx + 1)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return fixed_noise, noise_amp, netG, netD
# Code example #2
def generate_samples(netG,
                     img_to_augment,
                     naive_img,
                     naive_img_large,
                     aug,
                     opt,
                     depth,
                     noise_amp,
                     writer,
                     reals,
                     iter,
                     n=16):
    """Generate harmonized/edited samples at stage `depth` and log them.

    For each sample, builds a noise pyramid (the augmented or naive image at
    the coarsest level, fresh random noise above), runs the generator, saves
    the individual images to disk, and writes an image grid to TensorBoard.

    Args:
        netG: multi-stage generator, called as netG(noise, reals_shapes, noise_amp).
        img_to_augment: numpy image used as augmentation source (non-fine-tune).
        naive_img: naive (pasted/edited) numpy image, or None.
        naive_img_large: full-resolution naive image used for overlays.
        aug: augmentation helper with a .transform(**data) method
            (used only in the harmonization, non-fine-tune branch).
        opt: options namespace (train_mode, fine_tune, device, nc_im, nfc,
            noise_scaling, naive_img path, ...).
        depth: current (finest) stage index.
        noise_amp: per-stage noise amplitudes.
        writer: TensorBoard SummaryWriter.
        reals: list of real images, one per stage.
        iter: training iteration, used in file names and TensorBoard tags.
        n: number of samples to draw; reduced by 1 when naive_img is given,
            forced to 1 when fine-tuning.
    """
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/harmonized_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    _name = "harmonized" if opt.train_mode == "harmonization" else "edited"
    images = []
    # Create the output directory; ignore "already exists".
    try:
        os.makedirs(dir2save)
    except OSError:
        pass

    if naive_img is not None:
        n = n - 1
    if opt.fine_tune:
        n = 1
    with torch.no_grad():
        for idx in range(n):
            # Build the input pyramid: level 0 carries the (augmented) image,
            # higher levels carry fresh random noise.
            noise = []
            for d in range(depth + 1):
                if d == 0:
                    if opt.fine_tune:
                        if opt.train_mode == "harmonization":
                            augmented_image = functions.np2torch(
                                naive_img, opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            # Editing adds scaled noise on top of the naive image.
                            augmented_image = functions.np2torch(
                                naive_img, opt)
                            noise.append(augmented_image + opt.noise_scaling *
                                         functions.generate_noise(
                                             [
                                                 opt.nc_im, reals_shapes[d][2],
                                                 reals_shapes[d][3]
                                             ],
                                             device=opt.device).detach())
                    else:
                        if opt.train_mode == "harmonization":
                            data = {"image": img_to_augment}
                            augmented = aug.transform(**data)
                            augmented_image = functions.np2torch(
                                augmented["image"], opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            # Editing augments by shuffling grid patches.
                            image = functions.shuffle_grid(img_to_augment)
                            augmented_image = functions.np2torch(image, opt)
                            noise.append(augmented_image + opt.noise_scaling *
                                         functions.generate_noise(
                                             [
                                                 opt.nc_im, reals_shapes[d][2],
                                                 reals_shapes[d][3]
                                             ],
                                             device=opt.device).detach())
                else:
                    noise.append(
                        functions.generate_noise(
                            [opt.nfc, reals_shapes[d][2], reals_shapes[d][3]],
                            device=opt.device).detach())
            sample = netG(noise, reals_shapes, noise_amp)
            # NOTE(review): `augmented_image` is only bound by the
            # harmonization/editing branches above — any other train_mode
            # would raise NameError here.
            functions.save_image(
                '{}/{}_naive_sample.jpg'.format(dir2save, idx),
                augmented_image)
            functions.save_image(
                '{}/{}_{}_sample.jpg'.format(dir2save, idx, _name),
                sample.detach())
            # Resize the input to the generator's output size so the pair can
            # be stacked into one grid.
            augmented_image = imresize_to_shape(augmented_image,
                                                sample.shape[2:], opt)
            images.append(augmented_image)
            images.append(sample.detach())

        if opt.fine_tune:
            # Derive the mask path from the naive-image path:
            # "img.png" -> "img_mask.png".
            mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                                opt.naive_img[-4:])
            # `sample` is the last sample from the loop above (n == 1 here).
            augmented_image = imresize_to_shape(naive_img_large,
                                                sample.shape[2:], opt)
            if os.path.exists(mask_file_name):
                # Composite: keep the original outside the mask, generated
                # content inside it.
                mask = get_mask(mask_file_name, augmented_image, opt)
                sample_w_mask = (
                    1 - mask) * augmented_image + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(dir2save, _name, iter),
                    sample_w_mask.detach())
                images = torch.cat(
                    [augmented_image,
                     sample.detach(), sample_w_mask], 0)
                grid = make_grid(images, nrow=3, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid,
                                 iter)
            else:
                print(
                    "Warning: no mask with name {} exists for image {}".format(
                        mask_file_name, opt.input_name))
                print("Only showing results without mask.")
                images = torch.cat([augmented_image, sample.detach()], 0)
                grid = make_grid(images, nrow=2, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid,
                                 iter)
            functions.save_image(
                '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                sample.detach())
        else:
            # Non-fine-tune path: additionally generate one sample from the
            # un-augmented naive image and prepend it to the grid.
            if naive_img is not None:
                noise = []
                for d in range(depth + 1):
                    if d == 0:
                        if opt.train_mode == "harmonization":
                            noise.append(functions.np2torch(naive_img, opt))
                        elif opt.train_mode == "editing":
                            noise.append(functions.np2torch(naive_img, opt) + opt.noise_scaling * \
                                              functions.generate_noise([opt.nc_im, reals_shapes[d][2],
                                                                        reals_shapes[d][3]],
                                                                        device=opt.device).detach())
                    else:
                        noise.append(
                            functions.generate_noise(
                                [
                                    opt.nfc, reals_shapes[d][2],
                                    reals_shapes[d][3]
                                ],
                                device=opt.device).detach())
                sample = netG(noise, reals_shapes, noise_amp)
                _naive_img = imresize_to_shape(naive_img_large,
                                               sample.shape[2:], opt)
                images.insert(0, sample.detach())
                images.insert(0, _naive_img)
                functions.save_image(
                    '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                    sample.detach())

                # Optionally save a mask-composited version as well.
                mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                                    opt.naive_img[-4:])
                if os.path.exists(mask_file_name):
                    mask = get_mask(mask_file_name, _naive_img, opt)
                    sample_w_mask = (
                        1 - mask) * _naive_img + mask * sample.detach()
                    functions.save_image(
                        '{}/{}_sample_w_mask_{}.jpg'.format(
                            dir2save, _name, iter), sample_w_mask)

            images = torch.cat(images, 0)
            grid = make_grid(images, nrow=4, normalize=True)
            writer.add_image('{}_images_{}'.format(_name, depth), grid, iter)