def get_mask(mask_file_name, real_img, opt):
    """Load a mask image, fit it to ``real_img``'s size, and dilate it.

    The mask is read with ``functions.read_image_dir``, resized with
    ``imresize_to_shape`` when its dims 2 and 3 do not match those of
    ``real_img``, and finally passed through ``functions.dilate_mask``.

    Args:
        mask_file_name: Path of the mask image, forwarded to
            ``functions.read_image_dir``.
        real_img: Reference image; only ``real_img.shape[2]`` and
            ``real_img.shape[3]`` are read here (presumably height and
            width of an NCHW tensor -- TODO confirm).
        opt: Options object forwarded unchanged to the helper functions.

    Returns:
        The dilated mask as returned by ``functions.dilate_mask``.
    """
    mask = functions.read_image_dir(mask_file_name, opt)
    target_shape = [real_img.shape[2], real_img.shape[3]]
    # Compare both target dimensions rather than dim 3 alone, so a mask that
    # differs from the reference only in dim 2 is still resized.
    if [mask.shape[2], mask.shape[3]] != target_shape:
        mask = imresize_to_shape(mask, target_shape, opt)
    return functions.dilate_mask(mask, opt)
Example #2
0
                             reals_shapes,
                             noise_amp,
                             scale_w=1,
                             scale_h=2,
                             n=opt.num_samples)
            generate_samples(netG,
                             reals_shapes,
                             noise_amp,
                             scale_w=2,
                             scale_h=2,
                             n=opt.num_samples)

    elif opt.train_mode == "harmonization" or opt.train_mode == "editing":
        opt.noise_scaling = 0.1
        _name = "harmonized" if opt.train_mode == "harmonization" else "edited"
        real = functions.read_image_dir(opt.naive_img, opt)
        real = imresize_to_shape(real, reals_shapes[0][2:], opt)
        fixed_noise[0] = real
        if opt.train_mode == "editing":
            fixed_noise[0] = fixed_noise[0] + opt.noise_scaling * \
                                              functions.generate_noise([opt.nc_im, fixed_noise[0].shape[2],
                                                                        fixed_noise[0].shape[3]],
                                                                        device=opt.device)

        out = generate_samples(netG, reals_shapes, noise_amp, reconstruct=True)

        mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4],
                                            opt.naive_img[-4:])
        if os.path.exists(mask_file_name):
            mask = functions.read_image_dir(mask_file_name, opt)
            if mask.shape[3] != out.shape[3]:
def train(opt):
    """Train the generator one scale at a time and save the results.

    Steps performed, in order:

    1. Print a summary of the training parameters.
    2. Read the training image, adjust its scales and build the pyramid of
       real images (``functions.create_reals_pyramid``).
    3. If ``opt.naive_img`` is set, load it and resize it to the shapes of
       the last and the first pyramid entries.
    4. Build the generator (loading saved weights when ``opt.fine_tune``).
    5. For every scale in ``range(opt.start_scale, opt.train_stages)``:
       create the output directory, save the real image of that scale,
       build a discriminator, run ``train_single_scale`` and save
       ``fixed_noise``, the generator, ``reals`` and ``noise_amp`` under
       ``opt.out_``.

    Args:
        opt: Options object. Attributes read here: ``train_stages``,
            ``train_depth``, ``lr_scale``, ``activation``, ``naive_img``,
            ``fine_tune``, ``train_mode``, ``model_dir`` and
            ``start_scale``. Attributes written here: ``noise_scaling``
            (only in "editing" mode), ``out_`` and ``outf``.

    Returns:
        None.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(
        opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real = functions.read_image(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    print("Training on image pyramid: {}".format([r.shape for r in reals]))
    print("")

    if opt.naive_img != "":
        naive_img = functions.read_image_dir(opt.naive_img, opt)
        # Keep one copy at the shape of the last pyramid entry and one at the
        # shape of the first; the latter is converted and scaled by 255.
        naive_img_large = imresize_to_shape(naive_img, reals[-1].shape[2:],
                                            opt)
        naive_img = imresize_to_shape(naive_img, reals[0].shape[2:], opt)
        naive_img = functions.convert_image_np(naive_img) * 255.0
    else:
        naive_img = None
        naive_img_large = None

    # NOTE(review): with fine_tune set and an empty opt.naive_img this is
    # None -- confirm that train_single_scale tolerates it.
    if opt.fine_tune:
        img_to_augment = naive_img
    else:
        img_to_augment = functions.convert_image_np(reals[0]) * 255.0

    if opt.train_mode == "editing":
        opt.noise_scaling = 0.1

    generator = init_G(opt)
    if opt.fine_tune:
        # Grow the generator to its full number of stages before loading the
        # saved weights of the last stage.
        for _ in range(opt.train_stages - 1):
            generator.init_next_stage()
        generator.load_state_dict(
            torch.load(
                '{}/{}/netG.pth'.format(opt.model_dir, opt.train_stages - 1),
                map_location="cuda:{}".format(torch.cuda.current_device())))

    fixed_noise = []
    noise_amp = []

    for scale_num in range(opt.start_scale, opt.train_stages):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            # An already existing directory is fine; anything else is
            # reported but, as before, does not abort training here.
            os.makedirs(opt.outf, exist_ok=True)
        except OSError as err:
            print(err)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf),
                             reals[scale_num])

        d_curr = init_D(opt)
        if opt.fine_tune:
            d_curr.load_state_dict(
                torch.load('{}/{}/netD.pth'.format(opt.model_dir,
                                                   opt.train_stages - 1),
                           map_location="cuda:{}".format(
                               torch.cuda.current_device())))
        elif scale_num > 0:
            # Initialise from the discriminator saved for the previous scale
            # and add the next generator stage.
            d_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()

        # One writer per scale; close it as soon as the scale is done so
        # writers of earlier scales are not left open, and so nothing is
        # closed when the loop body never runs.
        writer = SummaryWriter(log_dir=opt.outf)
        try:
            fixed_noise, noise_amp, generator, d_curr = train_single_scale(
                d_curr, generator, reals, img_to_augment, naive_img,
                naive_img_large, fixed_noise, noise_amp, opt, scale_num,
                writer)
        finally:
            writer.close()

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        del d_curr