示例#1
0
def gen_animation(G,
                  deformator,
                  direction_index,
                  out_file,
                  z=None,
                  size=None,
                  r=8):
    """Save an animation of the generator output moving along one direction.

    Args:
        G: generator; ``G.dim_z`` is read when ``z`` is not given.
        deformator: direction model forwarded to ``interpolate``.
        direction_index: latent direction to traverse (``dim`` of ``interpolate``).
        out_file: destination path handed to ``imageio.mimsave``.
        z: optional latent code. When omitted, a ``[1, G.dim_z]`` standard
            normal sample is drawn on the ``'cuda'`` device.
        size: optional size for ``Resize``; frames are left unresized when None.
        r: shift radius forwarded to ``interpolate`` as ``shifts_r``.
    """
    import imageio

    latent = torch.randn([1, G.dim_z], device='cuda') if z is None else z
    raw_frames = interpolate(G,
                             latent,
                             shifts_r=r,
                             shifts_count=5,
                             dim=direction_index,
                             deformator=deformator,
                             with_central_border=False)

    if size is None:
        def resize(x):
            return x
    else:
        resize = Resize(size)

    frames = []
    for frame in raw_frames:
        # Clamp to [-1, 1] before converting, then optionally resize.
        frames.append(resize(to_image(torch.clamp(frame, -1, 1))))

    # Append the reversed sequence so playback ends back on the first frame.
    imageio.mimsave(out_file, frames + frames[::-1])
示例#2
0
def log_training(writer, imgs_dir, params, iteration, batch_it, epoch, d_loss,
                 g_loss, g_diversity, generator, fake_imgs):
    """Periodically log training progress to stdout, the writer and disk.

    Each kind of logging fires when ``batch_it`` is a multiple of its period
    taken from ``params``:

    * ``params.steps_per_log``: print a progress line and add the
      discriminator loss, generator loss, generator diversity and epoch
      scalars to ``writer``.
    * ``params.steps_per_activations_log`` (``None`` disables it): call
      ``log_buckets_activations`` on ``generator``.
    * ``params.steps_per_img_save``: save the first 25 ``fake_imgs`` as a
      5-per-row PNG grid in ``imgs_dir`` and add the same grid to ``writer``.

    Args:
        writer: object exposing ``add_scalar`` / ``add_image``.
        imgs_dir: directory receiving ``e{epoch}_{iteration}.png`` files.
        params: config holding the logging periods and ``steps`` (the total
            step count used for the progress percentage).
        iteration: global step; used as the x-axis for every logged value.
        batch_it: counter the logging periods are checked against.
        epoch: current epoch number.
        d_loss, g_loss, g_diversity: objects with ``.item()`` (scalar tensors).
        generator: model passed to ``log_buckets_activations``.
        fake_imgs: batch of generated images.
    """
    if batch_it % params.steps_per_log == 0:
        # Convert each value once; it is used for both the print and the writer.
        d_value = d_loss.item()
        g_value = g_loss.item()
        diversity_value = g_diversity.item()
        print(
            '{}% | Step {} | Epoch {}: [D loss: {}] [G loss: {}] [Diversity: {}]'
            .format(int(100.0 * iteration / params.steps), iteration, epoch,
                    d_value, g_value, diversity_value))

        writer.add_scalar('discriminator loss', d_value, iteration)
        writer.add_scalar('generator loss', g_value, iteration)
        # Record diversity alongside the losses; previously it was only printed.
        writer.add_scalar('generator diversity', diversity_value, iteration)
        writer.add_scalar('epoch', epoch, iteration)

    if params.steps_per_activations_log is not None and \
            batch_it % params.steps_per_activations_log == 0:
        log_buckets_activations(writer, generator, iteration)

    if batch_it % params.steps_per_img_save == 0:
        torchvision.utils.save_image(
            fake_imgs.data[:25],
            os.path.join(imgs_dir, 'e{}_{}.png'.format(epoch, iteration)),
            nrow=5,
            normalize=True)
        writer.add_image(
            'generated',
            torchvision.transforms.ToTensor()(to_image(
                torchvision.utils.make_grid(fake_imgs[:25], 5))), iteration)
def make_interpolation_chart(G, deformator=None, z=None,
                             shifts_r=10, shifts_count=5,
                             dims=None, dims_count=10, texts=None, direction_size=10, **kwargs):
    """Plot one row of latent-shift interpolations per direction.

    Args:
        G: generator; ``G.dim_z`` is read when ``z`` is not given.
        deformator: optional direction model forwarded to ``interpolate``. It is
            put in eval mode while plotting and switched back to train mode
            afterwards (even on error) if it was training.
        z: optional latent batch; defaults to ``make_noise(1, G.dim_z).cuda()``.
            When it holds a single sample, an extra top row shows ``G(z)``.
        shifts_r, shifts_count: forwarded to ``interpolate``; each row grid is
            ``2 * shifts_count + 1`` images wide.
        dims: directions to plot; defaults to ``range(min(dims_count, direction_size))``.
        dims_count: number of default directions (capped by ``direction_size``).
        texts: per-row labels; defaults to ``dims``.
        direction_size: forwarded to ``interpolate``.
        **kwargs: forwarded to ``plt.subplots`` (``squeeze`` is forced to False).

    Returns:
        The matplotlib figure.
    """
    was_training = deformator is not None and deformator.training
    if deformator is not None:
        deformator.eval()

    try:
        if z is None:
            z = make_noise(1, G.dim_z).cuda()
        show_original = z.shape[0] == 1
        # The reference image is only drawn for a single latent, so skip the
        # extra generator pass otherwise.
        original_img = G(z).cpu() if show_original else None

        if dims is None:
            dims = range(min(dims_count, direction_size))
        imgs = [interpolate(G, z, shifts_r, shifts_count, i, direction_size, deformator)
                for i in dims]

        rows_count = len(imgs) + 1 if show_original else len(imgs)
        # squeeze=False always returns a 2-D array of Axes, so a single-row
        # chart can still be indexed and zipped like a multi-row one.
        kwargs['squeeze'] = False
        fig, axs = plt.subplots(rows_count, **kwargs)
        axs = axs[:, 0]
        # Adjust subplot spacing: tiny left margin, no vertical gap between rows.
        fig.subplots_adjust(left=0.001, hspace=0)

        if show_original:
            axs[0].axis('off')
            axs[0].imshow(to_image(original_img, True))
            axs = axs[1:]

        if texts is None:
            texts = dims
        for ax, shifts_imgs, text in zip(axs, imgs, texts):
            ax.axis('off')
            ax.imshow(to_image(make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=0), True))
            ax.text(-20, 21, str(text), fontsize=10)
    finally:
        if was_training:
            deformator.train()

    return fig
示例#4
0
def make_interpolation_chart(G,
                             deformator=None,
                             z=None,
                             shifts_r=10.0,
                             shifts_count=5,
                             dims=None,
                             dims_count=10,
                             texts=None,
                             **kwargs):
    """Plot ``G(z)`` on top, then one row of latent-shift interpolations per direction.

    Args:
        G: generator; ``G.dim_z`` is read when ``z`` is not given.
        deformator: optional direction model forwarded to ``interpolate``. It is
            put in eval mode while plotting and switched back to train mode
            afterwards (even on error) if it was training.
        z: optional latent; defaults to ``make_noise(1, G.dim_z).cuda()``.
        shifts_r, shifts_count: forwarded to ``interpolate``; each row grid is
            ``2 * shifts_count + 1`` images wide.
        dims: directions to plot; defaults to ``range(dims_count)``.
        dims_count: number of default directions.
        texts: per-row labels; defaults to ``dims``.
        **kwargs: forwarded to ``plt.subplots`` (``squeeze`` is forced to False).

    Returns:
        The matplotlib figure.
    """
    was_training = deformator is not None and deformator.training
    if deformator is not None:
        deformator.eval()

    try:
        if z is None:
            z = make_noise(1, G.dim_z).cuda()
        original_img = G(z).cpu()

        if dims is None:
            dims = range(dims_count)
        imgs = [interpolate(G, z, shifts_r, shifts_count, i, deformator) for i in dims]

        # squeeze=False keeps a 2-D Axes array, so an empty ``dims`` (a single
        # row holding only the original image) still supports axs[0].
        kwargs['squeeze'] = False
        fig, axs = plt.subplots(len(imgs) + 1, **kwargs)
        axs = axs[:, 0]
        fig.subplots_adjust(left=0.5)

        axs[0].axis('off')
        axs[0].imshow(to_image(original_img, True))

        if texts is None:
            texts = dims
        for ax, shifts_imgs, text in zip(axs[1:], imgs, texts):
            ax.axis('off')
            ax.imshow(
                to_image(
                    make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=1),
                    True))
            ax.text(-20, 21, str(text), fontsize=10)
    finally:
        if was_training:
            deformator.train()

    return fig
    discovered_annotation += '{}: {}\n'.format(d[0], d[1])
print('human-annotated directions:\n' + discovered_annotation)

rows = 8
plt.figure(figsize=(5, rows), dpi=250)

# For a conditional GAN, set the desired class first, e.g.:
#     if is_conditional(G):
#         G.set_classes(12)

# Inspect the first annotated direction of the deformator.
annotated = list(deformator.annotation.values())
inspection_dim = annotated[0]

# Latents are sampled on the default device; pass device='cuda' to
# torch.randn to sample directly on the GPU instead.
z_shape = [rows, G.dim_z] if isinstance(G.dim_z, int) else [rows] + G.dim_z
zs = torch.randn(z_shape)

# One subplot row per latent, showing its interpolation along inspection_dim.
for i, z in enumerate(zs):
    interpolation_deformed = interpolate(G,
                                         z.unsqueeze(0),
                                         shifts_r=16,
                                         shifts_count=3,
                                         dim=inspection_dim,
                                         deformator=deformator,
                                         with_central_border=True)

    plt.subplot(rows, 1, i + 1)
    plt.axis('off')
    grid = make_grid(interpolation_deformed, nrow=11, padding=1, pad_value=0.0)
    grid = torch.clamp(grid, -1, 1)

    plt.imshow(to_image(grid))
示例#6
0
def inspect_path_generator_freeze(generator, samples_to_take=7, out_file=None):
    """Plot how each bucket (and the noise) of a path generator affects its output.

    A random block is frozen in every bucket, and the noise is frozen too: at a
    random index when it is discrete, via ``freeze_noise(True)`` otherwise.
    This produces a reference image. Then, one bucket at a time, the other
    blocks of that bucket are frozen in turn to show the variation it controls.
    Finally the noise alone is varied while the buckets keep their reference
    blocks. The generator is always unfrozen afterwards, even if a forward pass
    raises.

    The current matplotlib figure gets stacked rows: the reference image, then
    the noise-variation grid (if any), then one grid per bucket.

    Args:
        generator: path generator exposing ``buckets()``, ``noise_is_discrete``,
            ``const_noise``, ``freeze_noise``, ``unfreeze_all`` and callable as
            ``generator(1)``.
        samples_to_take: maximum number of variants per grid; also the number
            of images per grid row.
        out_file: if given, the figure is saved there at 200 dpi.
    """
    buckets = list(generator.buckets())

    def samples_to_grid(samples):
        return torchvision.utils.make_grid(torch.cat(samples), samples_to_take, pad_value=1)

    try:
        # Freeze the noise for the reference configuration.
        if generator.noise_is_discrete:
            noises_count = generator.const_noise.shape[0]
            noise_to_take = random.randint(0, noises_count - 1)
            generator.freeze_noise(noise_to_take)
        else:
            noises_count = 1
            generator.freeze_noise(True)

        blocks_to_take = \
            [random.randint(0, len(bucket.blocks) - 1) for bucket in buckets]

        def reset_model():
            # Re-freeze every bucket at its reference block.
            for i, bucket in enumerate(generator.buckets()):
                bucket.freeze(blocks_to_take[i])

        reset_model()
        original = generator(1).detach()

        grids_with_varying = []
        for i_layer, varying_bucket in enumerate(buckets):
            varying_images = []

            # Skip the reference block so the grid only shows alternatives,
            # unless it is the bucket's only block.
            all_indices = list(range(len(varying_bucket.blocks)))
            if len(all_indices) > 1:
                all_indices.remove(blocks_to_take[i_layer])
            for i in all_indices[:samples_to_take]:
                varying_bucket.freeze(i)
                varying_images.append(generator(1).detach())

            grids_with_varying.append(samples_to_grid(varying_images))
            reset_model()

        # Add the noise-variation grid; it is placed first among the grids.
        if noises_count > 1 or not generator.noise_is_discrete:
            varying_images = []

            if generator.noise_is_discrete:
                noise_indices = list(range(noises_count))
                noise_indices.remove(noise_to_take)
                for i in noise_indices[:samples_to_take]:
                    generator.freeze_noise(i)
                    varying_images.append(generator(1).detach())
            else:
                generator.freeze_noise(False)
                for _ in range(samples_to_take):
                    varying_images.append(generator(1).detach())

            grids_with_varying.insert(0, samples_to_grid(varying_images))
    finally:
        # Never leave the generator frozen, even if sampling failed.
        generator.unfreeze_all()

    rows_count = len(grids_with_varying) + 1
    plt.subplot(rows_count, 1, 1)
    plt.axis('off')
    plt.imshow(to_image(original, False))
    for row, grid_with_varying in enumerate(grids_with_varying, start=2):
        plt.subplot(rows_count, 1, row)
        plt.axis('off')
        plt.imshow(to_image(grid_with_varying, False))
    if out_file is not None:
        plt.savefig(out_file, dpi=200)