Exemplo n.º 1
0
def save_frames(title, model_name, rootdir, frames, strip_width=10):
    """Save per-example image strips and, when possible, a two-column grid.

    Output goes to ``{rootdir}/{model_name}/{prettify_name(title)}``:

    * ``{test_name}_all.png``: the first ``strip_width`` strips (unpadded)
      split into a left and a right column of equal size, separated by a
      30 px column of ones (white for [0, 1] data). Written only when at
      least ``strip_width`` examples are available and ``strip_width >= 2``.
      For an odd ``strip_width`` the last strip is left out of the grid so
      both columns have the same height.
    * ``{test_name}_{i}.png``: one strip per example ``i`` (frames passed
      through ``pad_frames`` and stacked horizontally), for the first
      ``strip_width`` examples.

    Args:
        title: Human-readable name; ``prettify_name(title)`` is used for the
            output sub-directory and the file-name prefix.
        model_name: Sub-directory under ``rootdir``.
        rootdir: Root output directory (created if missing).
        frames: Sequence of examples. Each example is a sequence of image
            arrays that are stacked horizontally into one strip. Values are
            multiplied by 255 before the ``uint8`` cast, so presumably
            floats in [0, 1]; TODO confirm with callers.
        strip_width: Number of examples used both for the grid and for the
            individual strip files.
    """
    test_name = prettify_name(title)
    outdir = f'{rootdir}/{model_name}/{test_name}'
    makedirs(outdir, exist_ok=True)

    # Limit maximum resolution: downscale (never upscale) so that a single
    # frame is at most max_H pixels tall.
    max_H = 512
    real_H = frames[0][0].shape[0]
    ratio = min(1.0, max_H / real_H)

    def _save_scaled(image_f, path):
        # Scale [0, 1]-range data to 8 bits, apply `ratio`, write to `path`.
        im = Image.fromarray((255 * image_f).astype(np.uint8))
        size = (int(ratio * im.size[0]), int(ratio * im.size[1]))
        # LANCZOS is the filter formerly aliased as ANTIALIAS, which was
        # removed in Pillow 10.
        im.resize(size, Image.LANCZOS).save(path)

    # One unpadded strip per example (its frames side by side).
    strips = [np.hstack(example) for example in frames[:strip_width]]
    half = strip_width // 2
    if len(strips) >= strip_width and half > 0:
        # Equal-sized halves keep both columns the same height, which
        # np.hstack below requires.
        left_col = np.vstack(strips[:half])
        right_col = np.vstack(strips[half:2 * half])
        gap = np.ones_like(left_col[:, :30])  # 30 px separator column
        grid = np.hstack([left_col, gap, right_col])
        _save_scaled(grid, f'{outdir}/{test_name}_all.png')
    else:
        print('Too few strips to create grid, creating just strips!')

    for ex_num, strip in enumerate(frames[:strip_width]):
        _save_scaled(np.hstack(pad_frames(strip)),
                     f'{outdir}/{test_name}_{ex_num}.png')
Exemplo n.º 2
0
def make_grid(latent,
              lat_mean,
              lat_comp,
              lat_stdev,
              act_mean,
              act_comp,
              act_stdev,
              scale=1,
              n_rows=10,
              n_cols=5,
              make_plots=True,
              edit_type='latent'):
    """Generate (and optionally plot) edits of ``latent`` along components.

    For each row ``r`` in ``range(n_rows)``, calls ``create_strip_centered``
    with the module-level ``inst`` and ``layer_key`` and the ``r``-th entries
    of the component/stdev arguments, producing up to ``n_cols`` images.
    Edits on ``inst`` are cleared before and after generation, including
    when generation raises.

    Args:
        latent: Latent to edit (passed to ``create_strip_centered`` as ``[latent]``).
        lat_mean, act_mean: Means passed through unchanged.
        lat_comp, act_comp: Indexable per-component values; entry ``r`` is
            used for row ``r``.
        lat_stdev, act_stdev: Indexable per-component values; entry ``r`` is
            used for row ``r``.
        scale: Passed to ``create_strip_centered``. Column labels are
            ``n_cols`` evenly spaced values in ``[-scale, scale]``.
        n_rows: Number of rows (components) to render.
        n_cols: Number of images per row.
        make_plots: If True, draw one matplotlib subplot per row. When
            ``n_rows > n_cols``, rows are split across two side-by-side
            blocks (first half on the left, the rest on the right).
        edit_type: Passed through to ``create_strip_centered``.

    Returns:
        Flat list of ``(label, image)`` tuples in row-major order, where
        ``label`` is ``'c{row}_{offset:.2f}'``.
    """
    from notebooks.notebook_utils import create_strip_centered

    inst.remove_edits()
    x_range = np.linspace(-scale, scale, n_cols,
                          dtype=np.float32)  # scale in sigmas

    rows = []
    try:
        for r in range(n_rows):
            out_batch = create_strip_centered(inst, edit_type, layer_key,
                                              [latent], act_comp[r],
                                              lat_comp[r], act_stdev[r],
                                              lat_stdev[r], act_mean,
                                              lat_mean, scale, 0, -1,
                                              n_cols)[0]
            # zip stops at n_cols labels, so a longer batch can no longer
            # overrun x_range.
            rows.append([('c{}_{:.2f}'.format(r, x), img)
                         for x, img in zip(x_range, out_batch)])
    finally:
        # Never leave edits applied to the shared instance.
        inst.remove_edits()

    if make_plots:
        # If more rows than columns, make several blocks side by side
        n_blocks = 2 if n_rows > n_cols else 1
        # Ceil division so an odd row count still fits the subplot grid.
        rows_per_block = (n_rows + n_blocks - 1) // n_blocks

        for r, data in enumerate(rows):
            # Add white borders
            imgs = pad_frames([img for _, img in data])

            # Fill block by block: the first rows_per_block rows go in the
            # left block, the remaining rows in the next one.
            block, sub_row = divmod(r, rows_per_block)
            plt.subplot(rows_per_block, n_blocks, 1 + sub_row * n_blocks + block)
            plt.imshow(np.hstack(imgs))

            # Custom x-axis labels centered on each image. Assumes
            # pad_frames interleaves images with padding frames (image at
            # [0], padding at [1]); TODO confirm.
            W = imgs[0].shape[1]  # image width
            P = imgs[1].shape[1]  # padding width
            locs = [(0.5 * W + i * (W + P)) for i in range(n_cols)]
            plt.xticks(locs, ["{:.2f}".format(v) for v in x_range])
            plt.yticks([])
            plt.ylabel(f'C{r}')

        plt.tight_layout()
        plt.subplots_adjust(top=0.96)  # make room for suptitle

    return [img for row in rows for img in row]
Exemplo n.º 3
0
def export_direction(idx, button_frame):
    """Ask for a name and export the direction controlled by slider ``idx``.

    Only slider ``idx`` may be non-zero; its value becomes ``sigma_range``.
    A small popup next to ``button_frame`` asks for the edit name. The slider
    label is updated with that name. Unless the popup was closed through the
    window manager (treated as cancel), the following is written:

    * the direction and its metadata, pickled to
      ``<dir of this file>/out/directions/<file_ident>.pkl``;
    * optionally, example images and strips in
      ``out/directions/<file_ident>/``.

    Relies on module-level state defined elsewhere in this file (``ui_state``,
    ``state``, ``components``, ``model``, ``model_name``, ``args``, ``inst``,
    ``root``, ``app``, ``img``) and on helpers (``get_edit_name``,
    ``prettify_name``, ``reset_sliders``, ``resample_latent``, ``on_draw``,
    ``pad_frames``).

    Args:
        idx: Index of the slider / component to export.
        button_frame: Tk widget the popup is positioned next to.
    """
    name = tk.StringVar(value='')
    # The Scale widgets for these two values are commented out below, so
    # num_strips stays 0 and no example images/strips are rendered.
    num_strips = tk.IntVar(value=0)
    strip_width = tk.IntVar(value=5)

    slider_values = np.array([s.get() for s in ui_state.sliders])
    slider_value = slider_values[idx]
    if (slider_values != 0).sum() > 1:
        print('Please modify only one slider')
        return
    elif slider_value == 0:
        print('Modify selected slider to set usable range (currently 0)')
        return

    popup = tk.Toplevel(root)
    popup.geometry("200x200+0+0")
    tk.Label(popup, text="Edit name").pack()
    tk.Entry(popup, textvariable=name).pack(pady=5)
    # tk.Scale(popup, from_=0, to=30, variable=num_strips,
    #    resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strips to export').pack()
    # tk.Scale(popup, from_=3, to=15, variable=strip_width,
    #    resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strip width').pack()
    tk.Button(popup, text='OK', command=popup.quit).pack()

    canceled = False

    def on_close():
        # Closing the window (instead of pressing OK) cancels the export.
        nonlocal canceled
        canceled = True
        popup.quit()

    popup.protocol("WM_DELETE_WINDOW", on_close)
    # Place the popup immediately to the right of button_frame.
    x = button_frame.winfo_rootx()
    y = button_frame.winfo_rooty()
    w = int(button_frame.winfo_geometry().split('x')[0])  # "WxH+X+Y" -> W
    popup.geometry('%dx%d+%d+%d' % (180, 90, x + w, y))
    # Block until OK (popup.quit) or window close (on_close).
    popup.mainloop()
    popup.destroy()

    # Update slider name (done even if the dialog was canceled)
    label = get_edit_name(idx, ui_state.edit_layer_start.get(),
                          ui_state.edit_layer_end.get(), name.get())
    ui_state.scales[idx].config(label=label)

    if canceled:
        return

    params = {
        'name': name.get(),
        'sigma_range': slider_value,
        'component_index': idx,
        'act_comp': components.X_comp[idx].detach().cpu().numpy(),
        'lat_comp': components.Z_comp[idx].detach().cpu().numpy(),  # either Z or W
        'latent_space': model.latent_space_name(),
        'act_stdev': components.X_stdev[idx].item(),
        'lat_stdev': components.Z_stdev[idx].item(),
        'model_name': model_name,
        'output_class': ui_state.outclass.get(),  # applied onto
        'decomposition': {
            'name': args.estimator,
            'components': args.components,
            'samples': args.n,
            'layer': args.layer,
            'class_name': state.component_class  # computed from
        },
        'edit_type': ui_state.mode.get(),
        'truncation': ui_state.truncation.get(),
        'edit_start': ui_state.edit_layer_start.get(),
        # Shown as inclusive in the UI, saved as exclusive.
        'edit_end': ui_state.edit_layer_end.get() + 1,
        'example_seed': state.seed,
    }

    edit_mode_str = params['edit_type']
    if edit_mode_str == 'latent':
        edit_mode_str = model.latent_space_name().lower()

    comp_class = state.component_class
    appl_class = params['output_class']
    if comp_class != appl_class:
        comp_class = f'{comp_class}_onto_{appl_class}'

    file_ident = "{model}-{name}-{cls}-{est}-{mode}-{layer}-comp{idx}-range{start}-{end}".format(
        model=model_name,
        name=prettify_name(params['name']),
        cls=comp_class,
        est=args.estimator,
        mode=edit_mode_str,
        layer=args.layer,
        idx=idx,
        start=params['edit_start'],
        end=params['edit_end'],
    )

    out_dir = Path(__file__).parent / 'out' / 'directions'
    makedirs(out_dir / file_ident, exist_ok=True)

    with open(out_dir / f"{file_ident}.pkl", 'wb') as outfile:
        pickle.dump(params, outfile)

    print(f'Direction "{name.get()}" saved as "{file_ident}.pkl"')

    batch_size = ui_state.batch_size.get()
    # Round num_strips up to a whole number of batches (0 stays 0).
    len_padded = ((num_strips.get() - 1) // batch_size + 1) * batch_size
    orig_seed = state.seed

    reset_sliders()

    # Limit max resolution. Presumably output_shape[2] is the image height
    # (N, C, H, W); TODO confirm.
    max_H = 512
    ratio = min(1.0, max_H / inst.output_shape[2])

    def _downscale(im):
        # LANCZOS is the filter formerly aliased as ANTIALIAS, which was
        # removed in Pillow 10.
        return im.resize((int(ratio * im.size[0]), int(ratio * im.size[1])),
                         Image.LANCZOS)

    name_nodots = file_ident.replace('.', '_')
    strips = [[] for _ in range(len_padded)]
    try:
        for b in range(0, len_padded, batch_size):
            # Resample: a distinct, reproducible seed per batch
            resample_latent((orig_seed + b) % np.iinfo(np.int32).max)

            sigmas = np.linspace(slider_value,
                                 -slider_value,
                                 strip_width.get(),
                                 dtype=np.float32)
            for sid, sigma in enumerate(sigmas):
                ui_state.sliders[idx].set(sigma)

                # Advance and show results on screen
                on_draw()
                root.update()
                app.update()

                # `img` is not assigned in this function. Assumes on_draw()
                # refreshes a module-level 4-D tensor with values in [0, 1];
                # TODO confirm. permute moves the channel dim last for PIL.
                batch_res = (255 * img).byte().permute(
                    0, 2, 3, 1).detach().cpu().numpy()

                for i, data in enumerate(batch_res):
                    # Save individual
                    outname = out_dir / file_ident / f"{name_nodots}_ex{b+i}_{sid}.png"
                    _downscale(Image.fromarray(data)).save(outname)
                    strips[b + i].append(data)

        for i, strip in enumerate(strips[:num_strips.get()]):
            print(f'Saving strip {i + 1}/{num_strips.get()}', end='\r', flush=True)
            data = np.hstack(pad_frames(strip))
            im = _downscale(Image.fromarray(data))
            im.save(out_dir / file_ident / f"{file_ident}_ex{i}.png")
    finally:
        # Reset to original state, even if rendering or saving failed
        resample_latent(orig_seed)
        ui_state.sliders[idx].set(slider_value)