def gen_animation(G, deformator, direction_index, out_file, z=None, size=None, r=8,
                  shifts_count=5):
    """Save an animation of G's output while shifting along one latent direction.

    Frames are produced by ``interpolate`` and written forwards and then
    backwards (``frames + frames[::-1]``), so the animation returns to its
    starting frame.

    Args:
        G: generator. ``G.dim_z`` is only read when ``z`` is None, to sample a
            latent. It may be an int or a list/tuple of ints for a non-flat
            latent.
        deformator: forwarded to ``interpolate``.
        direction_index: forwarded to ``interpolate`` as ``dim``.
        out_file: destination passed to ``imageio.mimsave``.
        z: optional latent. If None, one is drawn with ``torch.randn`` on
            the CUDA device.
        size: optional size for ``Resize``. Frames are left unresized when
            None.
        r: forwarded to ``interpolate`` as ``shifts_r``.
        shifts_count: forwarded to ``interpolate`` as ``shifts_count``.
            Defaults to 5, the value previously hard-coded.
    """
    import imageio

    if z is None:
        # dim_z can be a plain int or a list of ints (multi-dimensional latent).
        if isinstance(G.dim_z, int):
            z_shape = [1, G.dim_z]
        else:
            z_shape = [1] + list(G.dim_z)
        z = torch.randn(z_shape, device='cuda')

    interpolation_deformed = interpolate(
        G, z, shifts_r=r, shifts_count=shifts_count, dim=direction_index,
        deformator=deformator, with_central_border=False)

    resize = Resize(size) if size is not None else (lambda x: x)
    frames = [resize(to_image(torch.clamp(im, -1, 1)))
              for im in interpolation_deformed]
    # Forward then reversed sequence ("ping-pong"). The boundary frames appear
    # twice in a row at each turnaround, as in the original implementation.
    imageio.mimsave(out_file, frames + frames[::-1])
def log_training(writer, imgs_dir, params, iteration, batch_it, epoch,
                 d_loss, g_loss, g_diversity, generator, fake_imgs):
    """Periodically report training progress to stdout, ``writer`` and disk.

    Three independent schedules are keyed on ``batch_it``:

    * every ``params.steps_per_log`` batches: print progress and write the
      discriminator loss, generator loss and epoch as scalars;
    * every ``params.steps_per_activations_log`` batches (skipped when that
      setting is None): call ``log_buckets_activations``;
    * every ``params.steps_per_img_save`` batches: save the first 25 of
      ``fake_imgs`` as a 5-per-row PNG in ``imgs_dir`` and add the same grid
      to ``writer`` as an image.

    ``d_loss``, ``g_loss`` and ``g_diversity`` must support ``.item()``.
    The progress percentage is ``iteration / params.steps``.
    """
    if batch_it % params.steps_per_log == 0:
        # Read each scalar once. .item() copies the value to the host (and
        # synchronizes for device tensors), so avoid doing it twice per loss.
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        print(
            '{}% | Step {} | Epoch {}: [D loss: {}] [G loss: {}] [Diversity: {}]'
            .format(int(100.0 * iteration / params.steps), iteration, epoch,
                    d_loss_value, g_loss_value, g_diversity.item()))
        writer.add_scalar('discriminator loss', d_loss_value, iteration)
        writer.add_scalar('generator loss', g_loss_value, iteration)
        writer.add_scalar('epoch', epoch, iteration)

    if params.steps_per_activations_log is not None and \
            batch_it % params.steps_per_activations_log == 0:
        log_buckets_activations(writer, generator, iteration)

    if batch_it % params.steps_per_img_save == 0:
        # Detach once (instead of the deprecated .data) and use the same
        # slice for both the saved file and the logged grid.
        samples = fake_imgs.detach()[:25]
        torchvision.utils.save_image(
            samples,
            os.path.join(imgs_dir, 'e{}_{}.png'.format(epoch, iteration)),
            nrow=5, normalize=True)
        writer.add_image(
            'generated',
            torchvision.transforms.ToTensor()(
                to_image(torchvision.utils.make_grid(samples, 5))),
            iteration)
def make_interpolation_chart(G, deformator=None, z=None,
                             shifts_r=10, shifts_count=5,
                             dims=None, dims_count=10, texts=None,
                             direction_size=10, **kwargs):
    """Build a matplotlib figure with one interpolation row per direction.

    Each row is ``make_grid`` over the images returned by ``interpolate`` for
    one entry of ``dims``, laid out with ``2 * shifts_count + 1`` images per
    grid row. When ``z`` holds a single sample, an extra first row shows
    ``G(z)``.

    Args:
        G: generator. ``G.dim_z`` is read when ``z`` is None.
        deformator: optional. Switched to eval mode while plotting and put
            back into train mode afterwards (also on error) if it was
            training.
        z: latent batch. A single sample from ``make_noise`` is moved to CUDA
            when None.
        shifts_r, shifts_count, direction_size: forwarded to ``interpolate``.
        dims: direction indices to plot. Defaults to
            ``range(min(dims_count, direction_size))``.
        texts: row labels. Defaults to ``dims``.
        **kwargs: forwarded to ``plt.subplots``.

    Returns:
        The created matplotlib figure.
    """
    import numpy as np

    deformator_was_training = deformator is not None and deformator.training
    if deformator is not None:
        deformator.eval()
    try:
        z = z if z is not None else make_noise(1, G.dim_z).cuda()
        single_sample = z.shape[0] == 1
        # The reference image is only drawn for a single-sample z.
        original_img = G(z).cpu() if single_sample else None

        if dims is None:
            dims = range(min(dims_count, direction_size))
        imgs = [interpolate(G, z, shifts_r, shifts_count, i, direction_size, deformator)
                for i in dims]

        rows_count = len(imgs) + 1 if single_sample else len(imgs)
        fig, axs = plt.subplots(rows_count, **kwargs)
        # plt.subplots returns a bare Axes (not an array) for a single row.
        axs = np.atleast_1d(axs)
        if single_sample:
            axs[0].axis('off')
            axs[0].imshow(to_image(original_img, True))
            axs = axs[1:]

        if texts is None:
            texts = dims
        for ax, shifts_imgs, text in zip(axs, imgs, texts):
            ax.axis('off')
            ax.imshow(to_image(
                make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=0), True))
            ax.text(-20, 21, str(text), fontsize=10)

        # Shrink the left margin and remove the vertical spacing between rows
        # (previously the left margin was re-applied on every loop iteration).
        fig.subplots_adjust(left=0.001, hspace=0)
    finally:
        if deformator_was_training:
            deformator.train()
    return fig
def make_interpolation_chart(G, deformator=None, z=None,
                             shifts_r=10.0, shifts_count=5,
                             dims=None, dims_count=10, texts=None, **kwargs):
    """Build a matplotlib figure: ``G(z)`` on top, one interpolation row per direction.

    Each row after the first is ``make_grid`` over the images returned by
    ``interpolate`` for one entry of ``dims``, with
    ``2 * shifts_count + 1`` images per grid row and a text label from
    ``texts``.

    Args:
        G: generator. ``G.dim_z`` is read when ``z`` is None.
        deformator: optional. Switched to eval mode while plotting and put
            back into train mode afterwards (also on error) if it was
            training.
        z: latent. A single sample from ``make_noise`` is moved to CUDA when
            None.
        shifts_r, shifts_count: forwarded to ``interpolate``.
        dims: direction indices to plot. Defaults to ``range(dims_count)``.
        texts: row labels. Defaults to ``dims``.
        **kwargs: forwarded to ``plt.subplots``.

    Returns:
        The created matplotlib figure.
    """
    import numpy as np

    deformator_was_training = deformator is not None and deformator.training
    if deformator is not None:
        deformator.eval()
    try:
        z = z if z is not None else make_noise(1, G.dim_z).cuda()
        original_img = G(z).cpu()

        if dims is None:
            dims = range(dims_count)
        imgs = [interpolate(G, z, shifts_r, shifts_count, i, deformator) for i in dims]

        fig, axs = plt.subplots(len(imgs) + 1, **kwargs)
        # With no directions there is a single row, and plt.subplots then
        # returns a bare Axes rather than an array.
        axs = np.atleast_1d(axs)
        axs[0].axis('off')
        axs[0].imshow(to_image(original_img, True))

        if texts is None:
            texts = dims
        for ax, shifts_imgs, text in zip(axs[1:], imgs, texts):
            ax.axis('off')
            ax.imshow(to_image(
                make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=1), True))
            ax.text(-20, 21, str(text), fontsize=10)

        # Loop-invariant layout tweak, applied once instead of per row.
        fig.subplots_adjust(left=0.5)
    finally:
        if deformator_was_training:
            deformator.train()
    return fig
discovered_annotation += '{}: {}\n'.format(d[0], d[1]) print('human-annotated directions:\n' + discovered_annotation) rows = 8 plt.figure(figsize=(5, rows), dpi=250) # set desired class for conditional GAN # if is_conditional(G): # G.set_classes(12) annotated = list(deformator.annotation.values()) inspection_dim = annotated[0] # zs = torch.randn([rows, G.dim_z] if type(G.dim_z) == int else [rows] + G.dim_z, device='cuda') zs = torch.randn([rows, G.dim_z] if type(G.dim_z) == int else [rows] + G.dim_z) for z, i in zip(zs, range(rows)): interpolation_deformed = interpolate(G, z.unsqueeze(0), shifts_r=16, shifts_count=3, dim=inspection_dim, deformator=deformator, with_central_border=True) plt.subplot(rows, 1, i + 1) plt.axis('off') grid = make_grid(interpolation_deformed, nrow=11, padding=1, pad_value=0.0) grid = torch.clamp(grid, -1, 1) plt.imshow(to_image(grid))
def inspect_path_generator_freeze(generator, samples_to_take=7, out_file=None):
    """Plot how varying one bucket's block (or the noise) changes the output.

    A random block index is chosen and frozen for every bucket. The noise is
    also frozen: on a random index when ``generator.noise_is_discrete``,
    otherwise via ``freeze_noise(True)``. The top subplot shows
    ``generator(1)`` in that reference state.

    The following rows are:

    * when there is more than one discrete noise, or the noise is
      continuous: a grid varying only the noise;
    * then, one grid per bucket, varying that bucket's block (up to
      ``samples_to_take`` alternatives) while the others stay on their
      reference block.

    ``generator.unfreeze_all()`` is always called before plotting, even if
    sampling raises. The figure is drawn with ``plt`` on the current figure
    and saved to ``out_file`` (dpi=200) when given.
    """

    def reset_model():
        # Put every bucket back on its reference block.
        for bucket_index, bucket in enumerate(generator.buckets()):
            bucket.freeze(blocks_to_take[bucket_index])

    def samples_to_grid(samples):
        return torchvision.utils.make_grid(
            torch.cat(samples), samples_to_take, pad_value=1)

    buckets = list(generator.buckets())
    try:
        # inspect layers freeze
        if generator.noise_is_discrete:
            noises_count = generator.const_noise.shape[0]
            noise_to_take = random.randint(0, noises_count - 1)
            generator.freeze_noise(noise_to_take)
        else:
            noises_count = 1
            generator.freeze_noise(True)
        blocks_to_take = [random.randint(0, len(bucket.blocks) - 1)
                          for bucket in buckets]

        reset_model()
        original = generator(1).detach()

        grids_with_varying = []
        for varying_bucket, reference_block in zip(buckets, blocks_to_take):
            block_indices = list(range(len(varying_bucket.blocks)))
            if len(block_indices) > 1:
                block_indices.remove(reference_block)
            varying_images = []
            for block_index in block_indices[:samples_to_take]:
                varying_bucket.freeze(block_index)
                varying_images.append(generator(1).detach())
            grids_with_varying.append(samples_to_grid(varying_images))
            reset_model()

        # add noise variation images
        if noises_count > 1 or not generator.noise_is_discrete:
            varying_images = []
            if generator.noise_is_discrete:
                other_noises = [i for i in range(noises_count) if i != noise_to_take]
                for noise_index in other_noises[:samples_to_take]:
                    generator.freeze_noise(noise_index)
                    varying_images.append(generator(1).detach())
            else:
                generator.freeze_noise(False)
                varying_images = [generator(1).detach() for _ in range(samples_to_take)]
            grids_with_varying.insert(0, samples_to_grid(varying_images))
    finally:
        # Never leave the generator frozen, even if sampling failed.
        generator.unfreeze_all()

    rows = len(grids_with_varying) + 1
    plt.subplot(rows, 1, 1)
    plt.axis('off')
    plt.imshow(to_image(original, False))
    for row, grid_with_varying in enumerate(grids_with_varying, start=2):
        plt.subplot(rows, 1, row)
        plt.axis('off')
        plt.imshow(to_image(grid_with_varying, False))

    if out_file is not None:
        plt.savefig(out_file, dpi=200)