# Example No. 1
# 0
def make_grid(latent,
              lat_mean,
              lat_comp,
              lat_stdev,
              act_mean,
              act_comp,
              act_stdev,
              scale=1,
              n_rows=10,
              n_cols=5,
              make_plots=True,
              edit_type='latent'):
    """Render a grid of edits, one row per component, and optionally plot it.

    Row ``r`` is produced by ``create_strip_centered`` using component ``r``
    (``act_comp[r]`` / ``lat_comp[r]`` with spreads ``act_stdev[r]`` /
    ``lat_stdev[r]``), with ``n_cols`` edit strengths evenly spaced in
    ``[-scale, scale]``.

    NOTE(review): relies on module-level names defined elsewhere in the file
    (``inst``, ``layer_key``, ``pad_frames``, ``np``, ``plt``).

    :param latent: latent code, passed to create_strip_centered as ``[latent]``
    :param lat_mean: forwarded to create_strip_centered
    :param lat_comp: indexable; row ``r`` uses ``lat_comp[r]``
    :param lat_stdev: indexable; row ``r`` uses ``lat_stdev[r]``
    :param act_mean: forwarded to create_strip_centered
    :param act_comp: indexable; row ``r`` uses ``act_comp[r]``
    :param act_stdev: indexable; row ``r`` uses ``act_stdev[r]``
    :param scale: half-width of the edit-strength range (in sigmas)
    :param n_rows: number of components, i.e. grid rows
    :param n_cols: number of edit strengths, i.e. images per row
    :param make_plots: if True, draw the grid into the current matplotlib figure
    :param edit_type: forwarded to create_strip_centered as its mode argument
    :return: flat list of ``(name, image)`` tuples, row by row, where
        ``name`` is ``'c{row}_{strength:.2f}'``
    """
    from notebooks.notebook_utils import create_strip_centered

    # Edit strengths shown along each row (scale in sigmas).
    x_range = np.linspace(-scale, scale, n_cols, dtype=np.float32)

    rows = []
    inst.remove_edits()
    try:
        for r in range(n_rows):
            out_batch = create_strip_centered(inst, edit_type, layer_key,
                                              [latent], act_comp[r],
                                              lat_comp[r], act_stdev[r],
                                              lat_stdev[r], act_mean,
                                              lat_mean, scale, 0, -1,
                                              n_cols)[0]
            # zip() pairs each image with its strength and stops after
            # n_cols entries, so an over-long strip is truncated instead of
            # raising IndexError on x_range.
            rows.append([('c{}_{:.2f}'.format(r, x), img)
                         for x, img in zip(x_range, out_batch)])
    finally:
        # Clear the edits even if rendering one of the strips fails.
        inst.remove_edits()

    if make_plots:
        # If more rows than columns, make several blocks side by side
        n_blocks = 2 if n_rows > n_cols else 1
        # Ceil division so an odd n_rows still fits in the subplot grid.
        rows_per_block = -(-n_rows // n_blocks)

        for r, data in enumerate(rows):
            # Add white borders
            imgs = pad_frames([img for _, img in data])

            # Fill the first block top-to-bottom, then the next block.
            block, row_in_block = divmod(r, rows_per_block)
            plt.subplot(rows_per_block, n_blocks,
                        1 + row_in_block * n_blocks + block)
            plt.imshow(np.hstack(imgs))

            # Custom x-axis labels
            W = imgs[0].shape[1]  # image width
            # NOTE(review): assumes pad_frames interleaves padding strips
            # between frames, so imgs[1] is a padding strip — TODO confirm.
            P = imgs[1].shape[1]  # padding width
            locs = [(0.5 * W + i * (W + P)) for i in range(n_cols)]
            plt.xticks(locs, ["{:.2f}".format(v) for v in x_range])
            plt.yticks([])
            plt.ylabel(f'C{r}')

        plt.tight_layout()
        plt.subplots_adjust(top=0.96)  # make room for suptitle

    # Each entry is a (name, image) tuple, not a bare image.
    return [entry for row in rows for entry in row]
# Example No. 2
# 0
def img_list_generator(latent,
                       lat_mean,
                       lat_comp,
                       lat_stdev,
                       act_mean,
                       act_comp,
                       act_stdev,
                       scale=1,
                       num_frames=5,
                       make_plots=True,
                       edit_type='latent',
                       english_name=None,
                       allrand=False,
                       args=None):
    """Render one strip of edits along the first component and name each frame.

    The strip is produced by ``create_strip_centered`` using component 0
    (``act_comp[0]`` / ``lat_comp[0]`` with spreads ``act_stdev[0]`` /
    ``lat_stdev[0]``).

    NOTE(review): relies on module-level names defined elsewhere in the file
    (``inst``, ``layer_key``, ``np``, ``random``).

    :param latent: latent code, passed to create_strip_centered as ``[latent]``
    :param lat_mean: forwarded to create_strip_centered
    :param lat_comp: indexable; only ``lat_comp[0]`` is used
    :param lat_stdev: indexable; only ``lat_stdev[0]`` is used
    :param act_mean: forwarded to create_strip_centered
    :param act_comp: indexable; only ``act_comp[0]`` is used
    :param act_stdev: indexable; only ``act_stdev[0]`` is used
    :param scale: half-width of the edit-strength range ``[-scale, scale]``
    :param num_frames: number of evenly spaced strengths in that range
    :param make_plots: not used by this function (kept for signature parity
        with make_grid)
    :param edit_type: forwarded to create_strip_centered as ``mode``
    :param english_name: inserted verbatim into every generated file name
    :param allrand: falsy to use every evenly spaced strength; otherwise the
        number of strengths to draw at random, without replacement, from
        that grid (``True`` counts as 1). Also forwarded to
        create_strip_centered.
    :param args: forwarded to create_strip_centered
    :return: the image list, as ``(imgname, image_np_array)`` tuples, where
        ``imgname`` is ``'{index}_sigma_{strength:.2f}_{english_name}.png'``
    :raises ValueError: from ``random.sample`` if ``allrand > num_frames``
    """
    from notebooks.notebook_utils import create_strip_centered

    x_range = np.linspace(-scale, scale, num_frames, dtype=np.float32)

    inst.remove_edits()
    r = 0  # only the first component is rendered

    sigs = None
    if allrand:
        # Draw `allrand` distinct strengths from the evenly spaced grid.
        # This grid is float64, unlike x_range, and these exact values are
        # forwarded to create_strip_centered.
        sigma_range = np.linspace(-scale, scale, num_frames)
        selected = random.sample(range(num_frames), allrand)
        sigs = sigma_range[selected]

    out_batch = create_strip_centered(
        inst,
        mode=edit_type,
        layer=layer_key,
        latents=[latent],
        x_comp=act_comp[r],
        z_comp=lat_comp[r],
        act_stdev=act_stdev[r],
        lat_stdev=lat_stdev[r],
        act_mean=act_mean,
        lat_mean=lat_mean,
        sigma=scale,
        layer_start=0,
        layer_end=-1,
        num_frames=num_frames,
        allrand=allrand,
        sigs=sigs,
        args=args)[0]  # use [0] since only one latent layer

    # Label each frame with the strength it was rendered at: the random
    # picks when allrand is set, otherwise the evenly spaced grid.
    strengths = sigs if allrand else x_range
    return [('{}_sigma_{:.2f}_{}.png'.format(i, strengths[i], english_name),
             img)
            for i, img in enumerate(out_batch)]