def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_stdev,
              scale=1, n_rows=10, n_cols=5, make_plots=True, edit_type='latent'):
    """Render a grid of edits: one row per component, one column per sigma step.

    For each of the first ``n_rows`` components, ``create_strip_centered`` is
    called to produce a strip of ``latent`` moved along that component over
    ``[-scale, scale]`` sigmas in ``n_cols`` steps. Edits on the module-level
    ``inst`` are cleared before generation and again afterwards (even if
    generation raises).

    :param latent: latent code to edit (passed as a one-element list).
    :param lat_mean: latent-space mean, forwarded to ``create_strip_centered``.
    :param lat_comp: indexable of latent-space components; ``lat_comp[r]`` is used for row r.
    :param lat_stdev: indexable of latent stdevs; ``lat_stdev[r]`` is used for row r.
    :param act_mean: activation-space mean, forwarded to ``create_strip_centered``.
    :param act_comp: indexable of activation components; ``act_comp[r]`` is used for row r.
    :param act_stdev: indexable of activation stdevs; ``act_stdev[r]`` is used for row r.
    :param scale: half-width of the sweep, in sigmas.
    :param n_rows: number of components (rows) to render.
    :param n_cols: number of sigma steps (columns) per row.
    :param make_plots: if true, draw the grid into the current matplotlib figure.
    :param edit_type: edit mode forwarded to ``create_strip_centered``.
    :return: flat, row-major list of ``(name, image)`` tuples, where ``name`` is
        ``'c{row}_{sigma:.2f}'``.
    """
    from notebooks.notebook_utils import create_strip_centered

    x_range = np.linspace(-scale, scale, n_cols, dtype=np.float32)  # scale in sigmas

    rows = []
    inst.remove_edits()
    try:
        for r in range(n_rows):
            out_batch = create_strip_centered(
                inst, edit_type, layer_key, [latent],
                act_comp[r], lat_comp[r], act_stdev[r], lat_stdev[r],
                act_mean, lat_mean, scale, 0, -1, n_cols)[0]
            # zip pairs each frame with its sigma and caps the row at n_cols frames.
            curr_row = [('c{}_{:.2f}'.format(r, sigma), img)
                        for sigma, img in zip(x_range, out_batch)]
            rows.append(curr_row)
    finally:
        # Always leave the shared instance edit-free, even if generation fails.
        inst.remove_edits()

    if make_plots:
        # If more rows than columns, make two blocks of rows side by side.
        n_blocks = 2 if n_rows > n_cols else 1
        # Round up so an odd number of rows still fits in the subplot grid.
        rows_per_block = (n_rows + n_blocks - 1) // n_blocks
        for r, data in enumerate(rows):
            # Add white borders
            imgs = pad_frames([img for _, img in data])
            # Fill blocks column-major: the first rows_per_block components go
            # in the left block, the remainder in the right one.
            coord = (r % rows_per_block) * n_blocks + r // rows_per_block
            plt.subplot(rows_per_block, n_blocks, 1 + coord)
            plt.imshow(np.hstack(imgs))

            # Custom x-axis labels, one centred on each frame.
            W = imgs[0].shape[1]  # image width
            # NOTE(review): assumes pad_frames interleaves padding strips between
            # frames, so imgs[1] is a pad; a single-frame row has no pad.
            P = imgs[1].shape[1] if len(imgs) > 1 else 0  # padding width
            locs = [(0.5 * W + i * (W + P)) for i in range(n_cols)]
            plt.xticks(locs, ["{:.2f}".format(v) for v in x_range])
            plt.yticks([])
            plt.ylabel(f'C{r}')
        plt.tight_layout()
        plt.subplots_adjust(top=0.96)  # leave room for a caller-supplied suptitle

    return [item for row in rows for item in row]
def img_list_generator(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_stdev,
                       scale=1, num_frames=5, make_plots=True, edit_type='latent',
                       english_name=None, allrand=False, args=None):
    """Generate a named strip of images moving ``latent`` along component 0.

    Edits on the module-level ``inst`` are cleared before generation and again
    afterwards (even if generation raises).

    :param latent: latent code to edit (passed as a one-element list).
    :param lat_mean: latent-space mean, forwarded to ``create_strip_centered``.
    :param lat_comp: indexable of latent components; only ``lat_comp[0]`` is used.
    :param lat_stdev: indexable of latent stdevs; only ``lat_stdev[0]`` is used.
    :param act_mean: activation-space mean, forwarded to ``create_strip_centered``.
    :param act_comp: indexable of activation components; only ``act_comp[0]`` is used.
    :param act_stdev: indexable of activation stdevs; only ``act_stdev[0]`` is used.
    :param scale: half-width of the sweep ``[-scale, scale]``, in sigmas.
    :param num_frames: number of evenly spaced sigma values in the sweep.
    :param make_plots: unused; kept for signature compatibility.
    :param edit_type: edit mode forwarded to ``create_strip_centered``.
    :param english_name: label embedded in every image name.
    :param allrand: if truthy, used as the sample size for ``random.sample``:
        that many sigma values are drawn without replacement from the evenly
        spaced sweep and forwarded as ``sigs`` (``True`` draws one value).
    :param args: forwarded unchanged to ``create_strip_centered``.
    :return: The image list: ``(imgname, image_np_array)`` tuples, with
        ``imgname`` formatted as ``'{index}_sigma_{sigma:.2f}_{english_name}.png'``.
    """
    from notebooks.notebook_utils import create_strip_centered

    component = 0  # only the first component is swept
    x_range = np.linspace(-scale, scale, num_frames, dtype=np.float32)
    inst.remove_edits()

    sigs = None
    if allrand:
        sigma_range = np.linspace(-scale, scale, num_frames)
        selected = random.sample(range(num_frames), allrand)
        sigs = sigma_range[selected]
    # Sigma value used to label each generated frame.
    labels = sigs if allrand else x_range

    try:
        out_batch = create_strip_centered(
            inst, mode=edit_type, layer=layer_key, latents=[latent],
            x_comp=act_comp[component], z_comp=lat_comp[component],
            act_stdev=act_stdev[component], lat_stdev=lat_stdev[component],
            act_mean=act_mean, lat_mean=lat_mean, sigma=scale,
            layer_start=0, layer_end=-1, num_frames=num_frames,
            allrand=allrand, sigs=sigs, args=args,
        )[0]  # presumably one strip per entry in `latents`; only one was passed
    finally:
        # Leave the shared instance edit-free for subsequent callers.
        inst.remove_edits()

    return [('{}_sigma_{:.2f}_{}.png'.format(i, labels[i], english_name), img)
            for i, img in enumerate(out_batch)]