def train(opt):
    """Train the two-domain model progressively, one scale at a time.

    Reads one image per domain with ``functions.read_two_domains``, builds an
    image pyramid for each, and then, for every scale ``0 .. opt.stop_scale``:

    * creates ``<opt.out_>/<scale>`` and saves the first domain's real image
      at that scale as ``real_scale.jpg``;
    * builds fresh discriminators (``init_D``); from scale 1 on, they are
      warm-started from the previous scale's ``netD.pth`` / ``netD2.pth`` and
      both generators are grown by one stage;
    * trains both domains for this scale with ``train_single_scale``;
    * overwrites the run-level checkpoint files in ``opt.out_``.

    Args:
        opt: Run configuration namespace. This function assigns
            ``opt.out_`` and ``opt.outf``.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real, real2 = functions.read_two_domains(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    real2 = functions.adjust_scales2image(real2, opt)
    reals2 = functions.create_reals_pyramid(real2, opt)

    generator, generator2 = init_G(opt)

    # Per-domain state threaded through train_single_scale across scales.
    fixed_noise = []
    noise_amp = []
    fixed_noise2 = []
    noise_amp2 = []

    for scale_num in range(opt.stop_scale + 1):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as err:
            # Best effort (e.g. directory already exists): report the actual
            # error instead of the exception class, then keep going.
            print(err)

        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr, d_curr2 = init_D(opt)
        if scale_num > 0:
            # Warm-start from the previous scale. The netD*.pth files are
            # presumably written by train_single_scale -- not visible here.
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()
            d_curr2.load_state_dict(torch.load('%s/%d/netD2.pth' % (opt.out_, scale_num - 1)))
            generator2.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        try:
            (fixed_noise, noise_amp, generator, d_curr,
             fixed_noise2, noise_amp2, generator2, d_curr2) = train_single_scale(
                d_curr, generator, reals, fixed_noise, noise_amp,
                d_curr2, generator2, reals2, fixed_noise2, noise_amp2,
                opt, scale_num, writer)

            # Checkpoints are overwritten after every scale, so the files in
            # opt.out_ always describe the most recently finished scale.
            checkpoints = (
                ("fixed_noise", fixed_noise),
                ("G", generator),
                ("reals", reals),
                ("noise_amp", noise_amp),
                ("fixed_noise2", fixed_noise2),
                ("G2", generator2),
                ("reals2", reals2),
                ("noise_amp2", noise_amp2),
            )
            for stem, obj in checkpoints:
                torch.save(obj, '%s/%s.pth' % (opt.out_, stem))

            del d_curr, d_curr2
        finally:
            # Always release the log writer, even if training this scale fails.
            writer.close()
def generate_samples(netG, opt, depth, noise_amp, writer, reals, iter, n=25):
    """Sample ``n`` random images from ``netG`` at stage ``depth`` and log them.

    Every sample is written to
    ``<opt.out_>/gen_samples_stage_<depth>/gen_sample_<idx>.jpg``. The batch is
    then arranged in a normalized grid (at most 5 per row) and sent to
    ``writer`` under the tag ``gen_images_<depth>`` at step ``iter``. The first
    grid tile is replaced by the real image of this stage, for comparison.

    Side effect: ``opt.out_`` is (re)assigned from ``functions.generate_dir2save``.
    ``iter`` shadows the builtin but is kept for caller compatibility.
    """
    opt.out_ = functions.generate_dir2save(opt)
    out_dir = '{}/gen_samples_stage_{}'.format(opt.out_, depth)
    shapes = [real.shape for real in reals]

    try:
        os.makedirs(out_dir)
    except OSError:
        pass  # best effort: any OSError (e.g. directory exists) is ignored

    collected = []
    with torch.no_grad():
        for sample_idx in range(n):
            z = functions.sample_random_noise(depth, shapes, opt)
            fake = netG(z, shapes, noise_amp)
            collected.append(fake)
            functions.save_image('{}/gen_sample_{}.jpg'.format(out_dir, sample_idx), fake.detach())

        batch = torch.cat(collected, 0)
        # Show the real image of this stage in place of the first sample.
        batch[0] = reals[depth].squeeze()
        writer.add_image('gen_images_{}'.format(depth),
                         make_grid(batch, nrow=min(5, n), normalize=True),
                         iter)
# NOTE(review): this is an excerpt of a larger script section; any enclosing
# branch (e.g. a train-mode dispatch) is not visible here.
#
# Fine-tuning needs an existing trained model and the naive image to
# harmonize/edit. Abort with a non-zero exit status if either is missing.
# `raise SystemExit(1)` replaces `exit()`, which is a `site` convenience that is
# not guaranteed to exist and which reported success (status 0) on these errors.
if opt.fine_tune:
    if opt.model_dir == "":
        print("Model for fine tuning not specified.")
        print("Please use --model_dir to define model location.")
        raise SystemExit(1)
    elif not os.path.exists(opt.model_dir):
        print("Model does not exist: {}".format(opt.model_dir))
        print("Please specify a valid model.")
        raise SystemExit(1)
    # Only fine-tuning requires the naive image: train() accepts an empty
    # opt.naive_img, so this check must not run for normal training.
    if not os.path.exists(opt.naive_img):
        print("Image for harmonization/editing not found: {}".format(opt.naive_img))
        raise SystemExit(1)
from ConSinGAN.training_harmonization_editing import *

dir2save = functions.generate_dir2save(opt)

# Refuse to overwrite an existing run.
if osp.exists(dir2save):
    print('Trained model already exist: {}'.format(dir2save))
    raise SystemExit(1)

# create log dir
try:
    os.makedirs(dir2save)
except OSError:
    pass

# save hyperparameters: one "<name>\t-\t<value>" line per option
with open(osp.join(dir2save, 'parameters.txt'), 'w') as f:
    for key, value in vars(opt).items():
        f.write("{}\t-\t{}\n".format(key, value))
def _mask_file_name(image_path):
    """Return the path of the optional mask stored next to ``image_path``.

    ``"dir/img.jpg"`` maps to ``"dir/img_mask.jpg"``. ``os.path.splitext`` keeps
    this correct for extensions of any length (e.g. ``.jpeg``) and for paths
    without an extension. Fixed ``[:-4]`` / ``[-4:]`` slicing only handles
    three-letter extensions.
    """
    root, ext = os.path.splitext(image_path)
    return '{}_mask{}'.format(root, ext)


def generate_samples(netG, img_to_augment, naive_img, naive_img_large, aug, opt,
                     depth, noise_amp, writer, reals, iter, n=16):
    """Generate harmonized/edited samples at stage ``depth`` and log them.

    The coarsest stage (``d == 0``) is driven by an image; every finer stage
    gets ``opt.nfc``-channel noise from ``functions.generate_noise``. The
    stage-0 image is:

    * fine-tuning: ``naive_img``;
    * otherwise: ``img_to_augment``, passed through ``aug.transform``
      (harmonization) or ``functions.shuffle_grid`` (editing).

    In editing mode the stage-0 image additionally gets
    ``opt.noise_scaling``-scaled noise added.

    Outputs go to ``<opt.out_>/harmonized_samples_stage_<depth>/``. A grid is
    logged through ``writer.add_image`` under ``<harmonized|edited>_images_<depth>``
    at step ``iter``. If a mask file exists next to ``opt.naive_img``, a
    composite ``(1 - mask) * naive + mask * sample`` is also saved.

    Args:
        netG: Generator, called as ``netG(noise, reals_shapes, noise_amp)``.
        img_to_augment: Source image for stage 0 when not fine-tuning
            (presumably a numpy array; it is converted with ``functions.np2torch``).
        naive_img: Naive image at the coarsest scale, or ``None``.
        naive_img_large: Naive image tensor; resized to the sample's spatial
            size before being logged or composited.
        aug: Object exposing ``transform(image=...)`` returning a dict with an
            ``"image"`` entry (harmonization augmentation).
        opt: Run configuration; ``opt.out_`` is reassigned here.
        depth: Index of the current (finest trained) stage.
        noise_amp: Per-stage noise amplitudes forwarded to ``netG``.
        writer: Logger with an ``add_image(tag, img, step)`` method.
        reals: Image pyramid; only the shapes are used.
        iter: Global step, used for logging and file names (name kept for
            caller compatibility even though it shadows the builtin).
        n: Number of random samples. Reduced by one when ``naive_img`` is
            given; forced to 1 when fine-tuning.
    """
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/harmonized_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    _name = "harmonized" if opt.train_mode == "harmonization" else "edited"
    images = []
    try:
        os.makedirs(dir2save)
    except OSError:
        pass  # best effort: any OSError (e.g. directory exists) is ignored

    if naive_img is not None:
        # One grid slot pair is reserved for the result generated from the
        # naive image (inserted at the front in the non-fine-tune branch).
        n = n - 1
    if opt.fine_tune:
        # Fine-tuning always feeds the naive image, so only one sample is made.
        n = 1

    with torch.no_grad():
        for idx in range(n):
            noise = []
            for d in range(depth + 1):
                if d == 0:
                    if opt.fine_tune:
                        if opt.train_mode == "harmonization":
                            augmented_image = functions.np2torch(naive_img, opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            augmented_image = functions.np2torch(naive_img, opt)
                            noise.append(augmented_image + opt.noise_scaling * functions.generate_noise(
                                [opt.nc_im, reals_shapes[d][2], reals_shapes[d][3]],
                                device=opt.device).detach())
                    else:
                        if opt.train_mode == "harmonization":
                            data = {"image": img_to_augment}
                            augmented = aug.transform(**data)
                            augmented_image = functions.np2torch(augmented["image"], opt)
                            noise.append(augmented_image)
                        elif opt.train_mode == "editing":
                            image = functions.shuffle_grid(img_to_augment)
                            augmented_image = functions.np2torch(image, opt)
                            noise.append(augmented_image + opt.noise_scaling * functions.generate_noise(
                                [opt.nc_im, reals_shapes[d][2], reals_shapes[d][3]],
                                device=opt.device).detach())
                else:
                    noise.append(functions.generate_noise(
                        [opt.nfc, reals_shapes[d][2], reals_shapes[d][3]],
                        device=opt.device).detach())
            sample = netG(noise, reals_shapes, noise_amp)
            functions.save_image('{}/{}_naive_sample.jpg'.format(dir2save, idx), augmented_image)
            functions.save_image('{}/{}_{}_sample.jpg'.format(dir2save, idx, _name), sample.detach())
            # Match the sample's spatial size so both can share one grid.
            augmented_image = imresize_to_shape(augmented_image, sample.shape[2:], opt)
            images.append(augmented_image)
            images.append(sample.detach())

        if opt.fine_tune:
            mask_file_name = _mask_file_name(opt.naive_img)
            augmented_image = imresize_to_shape(naive_img_large, sample.shape[2:], opt)
            if os.path.exists(mask_file_name):
                mask = get_mask(mask_file_name, augmented_image, opt)
                # Keep the naive image outside the mask, the generated sample inside.
                sample_w_mask = (1 - mask) * augmented_image + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(dir2save, _name, iter),
                    sample_w_mask.detach())
                images = torch.cat([augmented_image, sample.detach(), sample_w_mask], 0)
                grid = make_grid(images, nrow=3, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid, iter)
            else:
                # The mask is looked up next to the naive image, so name that image.
                print("Warning: no mask with name {} exists for image {}".format(
                    mask_file_name, opt.naive_img))
                print("Only showing results without mask.")
                images = torch.cat([augmented_image, sample.detach()], 0)
                grid = make_grid(images, nrow=2, normalize=True)
                writer.add_image('{}_images_{}'.format(_name, depth), grid, iter)
            functions.save_image('{}/{}_sample_{}.jpg'.format(dir2save, _name, iter), sample.detach())
        else:
            if naive_img is not None:
                # Also run the naive image (no augmentation) through the generator.
                noise = []
                for d in range(depth + 1):
                    if d == 0:
                        if opt.train_mode == "harmonization":
                            noise.append(functions.np2torch(naive_img, opt))
                        elif opt.train_mode == "editing":
                            noise.append(functions.np2torch(naive_img, opt) + opt.noise_scaling *
                                         functions.generate_noise(
                                             [opt.nc_im, reals_shapes[d][2], reals_shapes[d][3]],
                                             device=opt.device).detach())
                    else:
                        noise.append(functions.generate_noise(
                            [opt.nfc, reals_shapes[d][2], reals_shapes[d][3]],
                            device=opt.device).detach())
                sample = netG(noise, reals_shapes, noise_amp)
                _naive_img = imresize_to_shape(naive_img_large, sample.shape[2:], opt)
                # Put the naive input/result pair first in the grid.
                images.insert(0, sample.detach())
                images.insert(0, _naive_img)
                functions.save_image('{}/{}_sample_{}.jpg'.format(dir2save, _name, iter), sample.detach())

                mask_file_name = _mask_file_name(opt.naive_img)
                if os.path.exists(mask_file_name):
                    mask = get_mask(mask_file_name, _naive_img, opt)
                    sample_w_mask = (1 - mask) * _naive_img + mask * sample.detach()
                    functions.save_image(
                        '{}/{}_sample_w_mask_{}.jpg'.format(dir2save, _name, iter),
                        sample_w_mask)

            images = torch.cat(images, 0)
            grid = make_grid(images, nrow=4, normalize=True)
            writer.add_image('{}_images_{}'.format(_name, depth), grid, iter)
def _load_pretrained_state(opt, filename):
    """Load ``filename`` from the last stage directory of ``opt.model_dir``.

    Tensors are mapped straight onto ``opt.device``, the device the rest of the
    pipeline already uses. Hard-coding ``cuda:<current_device()>`` crashes on
    CPU-only machines and can pick a different GPU than the model lives on.
    """
    path = '{}/{}/{}'.format(opt.model_dir, opt.train_stages - 1, filename)
    return torch.load(path, map_location=opt.device)


def train(opt):
    """Train (or fine-tune) the harmonization/editing model scale by scale.

    Builds the real image pyramid and, if ``opt.naive_img`` is non-empty, loads
    the naive image at the coarsest scale (as a 0-255 array from
    ``functions.convert_image_np``) and at the finest scale (``naive_img_large``).

    When ``opt.fine_tune`` is set, the generator is grown to
    ``opt.train_stages`` stages and its weights are loaded from
    ``opt.model_dir``; the discriminator is re-loaded from there at every
    scale. Otherwise, discriminators are warm-started from the previous
    scale's ``netD.pth`` and the generator grows by one stage per scale.

    After each scale, ``fixed_noise``, ``G``, ``reals`` and ``noise_amp`` are
    saved (overwritten) under ``opt.out_``.

    Args:
        opt: Run configuration namespace. This function assigns ``opt.out_``,
            ``opt.outf`` and, in editing mode, ``opt.noise_scaling``.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real = functions.read_image(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    print("Training on image pyramid: {}".format([r.shape for r in reals]))
    print("")

    if opt.naive_img != "":
        naive_img = functions.read_image_dir(opt.naive_img, opt)
        naive_img_large = imresize_to_shape(naive_img, reals[-1].shape[2:], opt)
        naive_img = imresize_to_shape(naive_img, reals[0].shape[2:], opt)
        naive_img = functions.convert_image_np(naive_img) * 255.0
    else:
        naive_img = None
        naive_img_large = None

    # Image used to build the stage-0 input during training.
    if opt.fine_tune:
        img_to_augment = naive_img
    else:
        img_to_augment = functions.convert_image_np(reals[0]) * 255.0

    if opt.train_mode == "editing":
        # Scale of the extra noise added to the stage-0 image in editing mode.
        opt.noise_scaling = 0.1

    generator = init_G(opt)
    if opt.fine_tune:
        # Grow to the full architecture before loading the pretrained weights.
        for _ in range(opt.train_stages - 1):
            generator.init_next_stage()
        generator.load_state_dict(_load_pretrained_state(opt, 'netG.pth'))

    fixed_noise = []
    noise_amp = []

    for scale_num in range(opt.start_scale, opt.train_stages):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError as err:
            # Best effort (e.g. directory already exists): report the actual
            # error instead of the exception class, then keep going.
            print(err)

        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr = init_D(opt)
        if opt.fine_tune:
            d_curr.load_state_dict(_load_pretrained_state(opt, 'netD.pth'))
        elif scale_num > 0:
            # Warm-start from the previous scale (file presumably written by
            # train_single_scale) and grow the generator by one stage.
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        try:
            fixed_noise, noise_amp, generator, d_curr = train_single_scale(
                d_curr, generator, reals, img_to_augment, naive_img, naive_img_large,
                fixed_noise, noise_amp, opt, scale_num, writer)

            torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
            torch.save(generator, '%s/G.pth' % (opt.out_))
            torch.save(reals, '%s/reals.pth' % (opt.out_))
            torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))

            del d_curr
        finally:
            # Always release the log writer, even if training this scale fails.
            writer.close()