示例#1
0
def train(opt):
    """Train the two-domain model progressively, one scale at a time.

    Reads both domains via ``functions.read_two_domains``, builds an image
    pyramid for each, then for every scale in ``0 .. opt.stop_scale``:

    * creates ``<out_>/<scale>`` and stores the first domain's real image
      for that scale as ``real_scale.jpg``,
    * initialises fresh discriminators (warm-started from the previous
      scale's ``netD.pth`` / ``netD2.pth`` when ``scale_num > 0``) and grows
      both generators by one stage,
    * runs ``train_single_scale`` with a per-scale ``SummaryWriter``,
    * checkpoints fixed noise, generators, pyramids and noise amplitudes
      for both domains directly under ``opt.out_``.

    Args:
        opt: Options object. This function reads ``train_stages``,
            ``train_depth``, ``lr_scale``, ``activation`` and ``stop_scale``
            and overwrites ``out_`` and ``outf``.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real, real2 = functions.read_two_domains(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)

    real2 = functions.adjust_scales2image(real2, opt)
    reals2 = functions.create_reals_pyramid(real2, opt)

    generator, generator2 = init_G(opt)
    fixed_noise = []
    noise_amp = []
    fixed_noise2 = []
    noise_amp2 = []
    for scale_num in range(opt.stop_scale + 1):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as err:
            # Best effort: report the actual failure (e.g. directory already
            # exists) and carry on, as before.
            print(err)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr, d_curr2 = init_D(opt)
        if scale_num > 0:
            # Warm-start both discriminators from the previous scale and add
            # a new stage to each generator.
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()
            d_curr2.load_state_dict(torch.load('%s/%d/netD2.pth' % (opt.out_, scale_num - 1)))
            generator2.init_next_stage()

        # One writer per scale; close it as soon as the scale is done so
        # earlier writers are not leaked.
        writer = SummaryWriter(log_dir=opt.outf)
        try:
            (fixed_noise, noise_amp, generator, d_curr,
             fixed_noise2, noise_amp2, generator2, d_curr2) = train_single_scale(
                d_curr, generator, reals, fixed_noise, noise_amp,
                d_curr2, generator2, reals2, fixed_noise2, noise_amp2,
                opt, scale_num, writer)
        finally:
            writer.close()

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        torch.save(fixed_noise2, '%s/fixed_noise2.pth' % (opt.out_))
        torch.save(generator2, '%s/G2.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(noise_amp2, '%s/noise_amp2.pth' % (opt.out_))
        del d_curr, d_curr2
    return
示例#2
0
def generate_samples(netG, opt, depth, noise_amp, writer, reals, iter, n=25):
    """Draw ``n`` random samples from ``netG`` and log them.

    Each sample is written to
    ``<out_>/gen_samples_stage_<depth>/gen_sample_<idx>.jpg``. All samples
    are then concatenated, the first entry is replaced by the real image at
    ``depth``, and the resulting grid is sent to ``writer.add_image`` under
    the tag ``gen_images_<depth>`` with ``iter`` as the step.

    Args:
        netG: Generator, called as ``netG(noise, shapes, noise_amp)``.
        opt: Options object; ``opt.out_`` is overwritten here.
        depth: Index of the current stage.
        noise_amp: Noise amplitudes forwarded to the generator.
        writer: Object exposing ``add_image(tag, image, step)``.
        reals: Sequence of real images, one per scale.
        iter: Step value passed to ``writer.add_image``.
        n: Number of samples to generate.
    """
    opt.out_ = functions.generate_dir2save(opt)
    target_dir = '{}/gen_samples_stage_{}'.format(opt.out_, depth)
    shapes = [level.shape for level in reals]
    try:
        os.makedirs(target_dir)
    except OSError:
        # Directory creation is best effort (it usually already exists).
        pass

    batch = []
    with torch.no_grad():
        for sample_idx in range(n):
            z = functions.sample_random_noise(depth, shapes, opt)
            fake = netG(z, shapes, noise_amp)
            batch.append(fake)
            sample_path = '{}/gen_sample_{}.jpg'.format(target_dir, sample_idx)
            functions.save_image(sample_path, fake.detach())

        stacked = torch.cat(batch, 0)
        # Show the real image in the first grid cell for visual reference.
        stacked[0] = reals[depth].squeeze()
        overview = make_grid(stacked, nrow=min(5, n), normalize=True)
        writer.add_image('gen_images_{}'.format(depth), overview, iter)
示例#3
0
        if opt.fine_tune:
            if opt.model_dir == "":
                print("Model for fine tuning not specified.")
                print("Please use --model_dir to define model location.")
                exit()
            else:
                if not os.path.exists(opt.model_dir):
                    print("Model does not exist: {}".format(opt.model_dir))
                    print("Please specify a valid model.")
                    exit()
            if not os.path.exists(opt.naive_img):
                print("Image for harmonization/editing not found: {}".format(opt.naive_img))
                exit()
        from ConSinGAN.training_harmonization_editing import *

    dir2save = functions.generate_dir2save(opt)

    if osp.exists(dir2save):
        print('Trained model already exist: {}'.format(dir2save))
        exit()

    # create log dir
    try:
        os.makedirs(dir2save)
    except OSError:
        pass

    # save hyperparameters and code files
    with open(osp.join(dir2save, 'parameters.txt'), 'w') as f:
        for o in opt.__dict__:
            f.write("{}\t-\t{}\n".format(o, opt.__dict__[o]))
def _mask_path_for(image_path):
    """Return the mask path that accompanies ``image_path``.

    ``"dir/img.jpg"`` maps to ``"dir/img_mask.jpg"``. Splitting with
    ``os.path.splitext`` keeps this correct for extensions that are not
    exactly three characters long (e.g. ``.jpeg``).
    """
    stem, ext = os.path.splitext(image_path)
    return '{}_mask{}'.format(stem, ext)


def _draw_input_image(img_to_augment, naive_img, aug, opt):
    """Return the scale-0 input image tensor for one generated sample."""
    if opt.fine_tune:
        # Fine-tuning always feeds the naive image itself.
        return functions.np2torch(naive_img, opt)
    if opt.train_mode == "harmonization":
        augmented = aug.transform(image=img_to_augment)
        return functions.np2torch(augmented["image"], opt)
    if opt.train_mode == "editing":
        return functions.np2torch(functions.shuffle_grid(img_to_augment), opt)
    raise ValueError("Unsupported train_mode: {}".format(opt.train_mode))


def _build_generator_input(first_scale_image, reals_shapes, depth, opt):
    """Build the per-scale input list for the generator up to ``depth``.

    Scale 0 receives ``first_scale_image`` (plus scaled noise in editing
    mode); every higher scale receives ``opt.nfc``-channel noise.
    """
    noise = []
    for d in range(depth + 1):
        height, width = reals_shapes[d][2], reals_shapes[d][3]
        if d == 0:
            if opt.train_mode == "harmonization":
                noise.append(first_scale_image)
            elif opt.train_mode == "editing":
                noise.append(first_scale_image + opt.noise_scaling *
                             functions.generate_noise(
                                 [opt.nc_im, height, width],
                                 device=opt.device).detach())
            else:
                raise ValueError(
                    "Unsupported train_mode: {}".format(opt.train_mode))
        else:
            noise.append(
                functions.generate_noise([opt.nfc, height, width],
                                         device=opt.device).detach())
    return noise


def generate_samples(netG,
                     img_to_augment,
                     naive_img,
                     naive_img_large,
                     aug,
                     opt,
                     depth,
                     noise_amp,
                     writer,
                     reals,
                     iter,
                     n=16):
    """Generate harmonized/edited samples, save them and log an image grid.

    Files are written to ``<out_>/harmonized_samples_stage_<depth>``. For
    every generated sample both the input ("naive") image and the generator
    output are saved. If a mask file named ``<naive stem>_mask<ext>`` exists
    next to ``opt.naive_img``, a masked blend of the naive image and the
    output is saved as well.

    Args:
        netG: Generator, called as ``netG(noise, shapes, noise_amp)``.
        img_to_augment: Image that is augmented (harmonization) or
            grid-shuffled (editing) to form the generator input when not
            fine-tuning.
        naive_img: Naive image at the coarsest scale, or ``None``.
        naive_img_large: Naive image to be resized to the sample resolution
            for display; only used when a naive image is available.
        aug: Augmenter exposing ``transform(image=...)``.
        opt: Options object; ``opt.out_`` is overwritten here.
        depth: Index of the current stage.
        noise_amp: Noise amplitudes forwarded to the generator.
        writer: Object exposing ``add_image(tag, image, step)``.
        reals: Sequence of real images, one per scale.
        iter: Step value used in file names and for ``writer.add_image``.
        n: Number of (input, output) pairs shown in the grid; forced to 1
            when ``opt.fine_tune`` is set.

    Raises:
        ValueError: If ``opt.train_mode`` is neither ``"harmonization"``
            nor ``"editing"``.
    """
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/harmonized_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    _name = "harmonized" if opt.train_mode == "harmonization" else "edited"
    grid_tag = '{}_images_{}'.format(_name, depth)
    images = []
    # The stage directory is typically reused between calls, so an existing
    # directory is not an error.
    os.makedirs(dir2save, exist_ok=True)

    if naive_img is not None:
        # One (input, output) pair is reserved for the naive image itself,
        # which is prepended further below.
        n = n - 1
    if opt.fine_tune:
        n = 1
    with torch.no_grad():
        for idx in range(n):
            augmented_image = _draw_input_image(img_to_augment, naive_img,
                                                aug, opt)
            noise = _build_generator_input(augmented_image, reals_shapes,
                                           depth, opt)
            sample = netG(noise, reals_shapes, noise_amp)
            functions.save_image(
                '{}/{}_naive_sample.jpg'.format(dir2save, idx),
                augmented_image)
            functions.save_image(
                '{}/{}_{}_sample.jpg'.format(dir2save, idx, _name),
                sample.detach())
            # Bring the input to the output resolution so both fit one grid.
            augmented_image = imresize_to_shape(augmented_image,
                                                sample.shape[2:], opt)
            images.append(augmented_image)
            images.append(sample.detach())

        if opt.fine_tune:
            mask_file_name = _mask_path_for(opt.naive_img)
            augmented_image = imresize_to_shape(naive_img_large,
                                                sample.shape[2:], opt)
            if os.path.exists(mask_file_name):
                mask = get_mask(mask_file_name, augmented_image, opt)
                # Keep the naive image outside the mask, the output inside.
                sample_w_mask = (
                    1 - mask) * augmented_image + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(dir2save, _name, iter),
                    sample_w_mask.detach())
                panels = [augmented_image, sample.detach(), sample_w_mask]
            else:
                print(
                    "Warning: no mask with name {} exists for image {}".format(
                        mask_file_name, opt.input_name))
                print("Only showing results without mask.")
                panels = [augmented_image, sample.detach()]
            grid = make_grid(torch.cat(panels, 0), nrow=len(panels),
                             normalize=True)
            writer.add_image(grid_tag, grid, iter)
            functions.save_image(
                '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                sample.detach())
        else:
            if naive_img is not None:
                noise = _build_generator_input(
                    functions.np2torch(naive_img, opt), reals_shapes, depth,
                    opt)
                sample = netG(noise, reals_shapes, noise_amp)
                _naive_img = imresize_to_shape(naive_img_large,
                                               sample.shape[2:], opt)
                # Put the naive image and its result at the front of the grid.
                images.insert(0, sample.detach())
                images.insert(0, _naive_img)
                functions.save_image(
                    '{}/{}_sample_{}.jpg'.format(dir2save, _name, iter),
                    sample.detach())

                mask_file_name = _mask_path_for(opt.naive_img)
                if os.path.exists(mask_file_name):
                    mask = get_mask(mask_file_name, _naive_img, opt)
                    sample_w_mask = (
                        1 - mask) * _naive_img + mask * sample.detach()
                    functions.save_image(
                        '{}/{}_sample_w_mask_{}.jpg'.format(
                            dir2save, _name, iter), sample_w_mask)

            # torch.cat rejects an empty list; nothing to log in that case.
            if images:
                grid = make_grid(torch.cat(images, 0), nrow=4, normalize=True)
                writer.add_image(grid_tag, grid, iter)
def train(opt):
    """Train (or fine-tune) the harmonization/editing model stage by stage.

    Builds the image pyramid for the training image, optionally loads the
    naive image named by ``opt.naive_img``, and then trains every stage in
    ``opt.start_scale .. opt.train_stages - 1`` with ``train_single_scale``.
    After each stage the fixed noise, generator, pyramid and noise
    amplitudes are saved under ``opt.out_``.

    When ``opt.fine_tune`` is set, the generator is grown to its full depth
    and both generator and discriminator weights are loaded from
    ``opt.model_dir``; the naive image is then used as the image to augment.

    Args:
        opt: Options object. ``out_`` and ``outf`` are overwritten, and
            ``noise_scaling`` is set to ``0.1`` in editing mode.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(
        opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real = functions.read_image(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    print("Training on image pyramid: {}".format([r.shape for r in reals]))
    print("")

    if opt.naive_img != "":
        naive_img = functions.read_image_dir(opt.naive_img, opt)
        # Full-resolution copy for display, coarsest-scale copy as input.
        naive_img_large = imresize_to_shape(naive_img, reals[-1].shape[2:],
                                            opt)
        naive_img = imresize_to_shape(naive_img, reals[0].shape[2:], opt)
        naive_img = functions.convert_image_np(naive_img) * 255.0
    else:
        naive_img = None
        naive_img_large = None

    if opt.fine_tune:
        img_to_augment = naive_img
    else:
        img_to_augment = functions.convert_image_np(reals[0]) * 255.0

    if opt.train_mode == "editing":
        opt.noise_scaling = 0.1

    generator = init_G(opt)
    map_location = None
    if opt.fine_tune:
        # Map checkpoints onto the current CUDA device, falling back to the
        # CPU when CUDA is unavailable instead of failing.
        if torch.cuda.is_available():
            map_location = "cuda:{}".format(torch.cuda.current_device())
        else:
            map_location = "cpu"
        # Grow the generator to its final depth before loading the weights.
        for _ in range(opt.train_stages - 1):
            generator.init_next_stage()
        generator.load_state_dict(
            torch.load(
                '{}/{}/netG.pth'.format(opt.model_dir, opt.train_stages - 1),
                map_location=map_location))

    fixed_noise = []
    noise_amp = []

    for scale_num in range(opt.start_scale, opt.train_stages):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as err:
            # Best effort: report the actual failure (e.g. directory already
            # exists) and carry on, as before.
            print(err)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf),
                             reals[scale_num])

        d_curr = init_D(opt)
        if opt.fine_tune:
            d_curr.load_state_dict(
                torch.load('{}/{}/netD.pth'.format(opt.model_dir,
                                                   opt.train_stages - 1),
                           map_location=map_location))
        elif scale_num > 0:
            # Warm-start from the previous stage's discriminator and add a
            # new stage to the generator.
            d_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()

        # One writer per stage; close it as soon as the stage is done so
        # earlier writers are not leaked.
        writer = SummaryWriter(log_dir=opt.outf)
        try:
            fixed_noise, noise_amp, generator, d_curr = train_single_scale(
                d_curr, generator, reals, img_to_augment, naive_img,
                naive_img_large, fixed_noise, noise_amp, opt, scale_num,
                writer)
        finally:
            writer.close()

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        del d_curr
    return