def save_images_png(run_name, data_loader, num_samples, num_classes, generator,
                    discriminator, is_generate, truncated_factor, prior, latent_op,
                    latent_op_step, latent_op_alpha, latent_op_beta, device):
    """Save `num_samples` images as individual PNGs under ./samples/<run_name>/<fake|real>/png/<class>/.

    When `is_generate` is True, images come from `generate_images(...)` (the GAN);
    otherwise they are drawn from `data_loader`. Images are assumed to be in
    [-1, 1] and are rescaled to [0, 1] before saving — TODO confirm against the
    data pipeline. Any existing output directory for this run/type is wiped first.

    Args:
        run_name: experiment name used to build the output path.
        data_loader: torch DataLoader; supplies `batch_size` and (if not
            generating) the real images/labels.
        num_samples: total number of images to save.
        num_classes: number of class subdirectories to create.
        generator, discriminator: models forwarded to `generate_images`.
        is_generate: True → save generated ("fake") images, False → real ones.
        truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha,
        latent_op_beta, device: forwarded to `generate_images`.
    """
    # Both generate and real-data paths use the loader's batch size; the old
    # per-branch duplication (and the unused total_instance) is collapsed here.
    batch_size = data_loader.batch_size
    n_batches = math.ceil(float(num_samples) / float(batch_size))
    data_iter = iter(data_loader)  # only consumed when is_generate is False

    # Avoid shadowing the builtin `type`; runtime strings are unchanged.
    img_type = "fake" if is_generate else "real"
    print("Save {num_samples} {type} images in png format....".format(
        num_samples=num_samples, type=img_type))

    # Start from a clean directory tree, one subfolder per class label.
    directory = join('./samples', run_name, img_type, "png")
    if exists(abspath(directory)):
        shutil.rmtree(abspath(directory))
    os.makedirs(directory)
    for class_idx in range(num_classes):
        os.makedirs(join(directory, str(class_idx)))

    # Gradients are only needed when latent optimization is active; otherwise
    # run under no_grad to save memory. The unused `as mpc` binding is dropped.
    with torch.no_grad() if latent_op is False else dummy_context_mgr():
        for i in tqdm(range(n_batches), disable=False):
            if is_generate:
                images, labels = generate_images(batch_size, generator, discriminator,
                                                 truncated_factor, prior, latent_op,
                                                 latent_op_step, latent_op_alpha,
                                                 latent_op_beta, device)
            else:
                try:
                    images, labels = next(data_iter)
                except StopIteration:
                    # Loader exhausted before num_samples were saved.
                    break
            for idx, img in enumerate(images.detach()):
                sample_idx = batch_size * i + idx
                if sample_idx < num_samples:
                    # Map [-1, 1] -> [0, 1] for save_image.
                    save_image((img + 1) / 2,
                               join(directory, str(labels[idx].item()),
                                    '{idx}.png'.format(idx=sample_idx)))
    # Bug fix: the old message claimed ./generated_images/, but files are
    # written under ./samples/ — report the real output directory.
    print('Save png to %s' % directory)
def save_images_npz(run_name, data_loader, num_samples, num_classes, generator,
                    discriminator, is_generate, truncated_factor, prior, latent_op,
                    latent_op_step, latent_op_alpha, latent_op_beta, device):
    """Save `num_samples` images and labels as a single .npz archive under
    ./samples/<run_name>/<fake|real>/npz/samples.npz.

    When `is_generate` is True, images come from `generate_images(...)` (the GAN);
    otherwise they are drawn from `data_loader`. Images are assumed to be in
    [-1, 1] and are rescaled to uint8 [0, 255] — TODO confirm against the data
    pipeline. The archive stores arrays 'x' (images) and 'y' (labels).

    Args:
        run_name: experiment name used to build the output path.
        data_loader: torch DataLoader; supplies `batch_size` and (if not
            generating) the real images/labels.
        num_samples: total number of images/labels to store.
        num_classes: unused here but kept for signature parity with
            save_images_png.
        generator, discriminator: models forwarded to `generate_images`.
        is_generate: True → store generated ("fake") images, False → real ones.
        truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha,
        latent_op_beta, device: forwarded to `generate_images`.
    """
    # Both generate and real-data paths use the loader's batch size; the old
    # per-branch duplication (and the unused total_instance) is collapsed here.
    batch_size = data_loader.batch_size
    n_batches = math.ceil(float(num_samples) / float(batch_size))
    data_iter = iter(data_loader)  # only consumed when is_generate is False

    # Avoid shadowing the builtin `type`; runtime strings are unchanged.
    img_type = "fake" if is_generate else "real"
    print("Save {num_samples} {type} images in npz format....".format(
        num_samples=num_samples, type=img_type))

    # Start from a clean output directory for this run/type.
    directory = join('./samples', run_name, img_type, "npz")
    if exists(abspath(directory)):
        shutil.rmtree(abspath(directory))
    os.makedirs(directory)

    x = []
    y = []
    # Gradients are only needed when latent optimization is active; otherwise
    # run under no_grad to save memory. The unused `as mpc` binding is dropped.
    with torch.no_grad() if latent_op is False else dummy_context_mgr():
        for _ in tqdm(range(n_batches), disable=False):
            if is_generate:
                images, labels = generate_images(batch_size, generator, discriminator,
                                                 truncated_factor, prior, latent_op,
                                                 latent_op_step, latent_op_alpha,
                                                 latent_op_beta, device)
            else:
                try:
                    images, labels = next(data_iter)
                except StopIteration:
                    # Loader exhausted before num_samples were collected.
                    break
            # Map [-1, 1] floats to uint8 [0, 255] for compact storage.
            x.append(np.uint8(255 * (images.detach().cpu().numpy() + 1) / 2.))
            y.append(labels.detach().cpu().numpy())

    # Concatenate batches and trim the last partial batch to exactly num_samples.
    x = np.concatenate(x, 0)[:num_samples]
    y = np.concatenate(y, 0)[:num_samples]
    print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
    npz_filename = join(directory, "samples.npz")
    print('Saving npz to %s' % npz_filename)
    np.savez(npz_filename, **{'x': x, 'y': y})