def save_znum_images(dirname, model, z_sample, indexes, layer, ablated_units,
                     name_template="image_{}.png", lightbox=False,
                     batch_size=100, seed=1):
    """Generate one image per row of ``z_sample`` and save it under ``dirname``.

    Args:
        dirname: output directory; created if it does not already exist.
        model: generator called as ``model(z)`` on CUDA batches. Its output is
            scaled with ``(x + 1) / 2 * 255``, clamped to bytes and permuted to
            channels-last before saving (this presumes outputs in [-1, 1] —
            confirm against the generator).
        z_sample: tensor of latent vectors; each row along dim 0 yields one
            image.
        indexes: optional sequence (tensor, array or list of ints) mapping the
            position of each row of ``z_sample`` to the number used in its
            filename. When None, the row position itself is used.
        layer: key into ``model.ablation`` where the unit mask is stored.
        ablated_units: optional iterable of unit numbers. When not None, a 0/1
            mask with 1 at those units is written to ``model.ablation[layer]``
            before generating; an empty iterable writes an all-zero mask. The
            mask is NOT removed afterwards, so it stays on the model after
            this function returns.
        name_template: ``str.format`` template that receives the image number.
        lightbox: unused; kept for interface compatibility.
        batch_size: number of z vectors generated per forward pass.
        seed: unused; kept for interface compatibility.
    """
    progress = default_progress()
    os.makedirs(dirname, exist_ok=True)
    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(
            TensorDataset(z_sample), batch_size=batch_size,
            num_workers=2, pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        if ablated_units is not None:
            # Normalize to a list of ints: a tuple would be read by tensor
            # indexing as one index per dimension, and max() of an empty
            # sequence would raise.
            units = [int(u) for u in ablated_units]
            # At least 2 entries so the mask is never a single value that
            # broadcasts across every channel.
            dims = max(2, max(units, default=-1) + 1)
            mask = torch.zeros(dims)
            if units:
                mask[units] = 1
            # Shaped (1, dims, 1, 1); presumably applied to an (N, C, H, W)
            # activation by the model's ablation hook — confirm.
            model.ablation[layer] = mask[None, :, None, None].cuda()
        for batch_num, [z] in enumerate(
                progress(z_loader, desc='Saving images')):
            z = z.cuda()
            # The loader is sequential without drop_last, so batch k starts
            # at row k * batch_size.
            start_index = batch_num * batch_size
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = i + start_index
                if indexes is not None:
                    # int() accepts 0-d tensors, numpy scalars and plain ints.
                    index = int(indexes[index])
                filename = os.path.join(dirname, name_template.format(index))
                saver.add(im[i].numpy(), filename)
    saver.join()
def visualize_training_locations(args, corpus, cachedir, model):
    """Save corpus images with their recorded feature location tinted yellow.

    Phase 2.5: create visualizations of the corpus images. For each of the
    'present' and 'candidate' groups of ``corpus``, regenerate every image
    from its z vector, blend yellow over the image region corresponding to
    its flattened feature-map location, and save the result as
    ``<cachedir>/image/<group>_<index>.jpg``.

    Args:
        args: namespace providing ``layer`` (key into ``model.feature_shape``)
            and ``inference_batch_size``.
        corpus: object providing the tensors ``object_present_sample``,
            ``object_present_location``, ``present_indices``,
            ``candidate_sample``, ``candidate_location`` and
            ``candidate_indices``. Locations are used as scatter indices into
            the flattened feature map, so they must be integer (int64).
        cachedir: directory under which the ``image`` subdirectory is made.
        model: generator called as ``model(z)`` on CUDA batches; also exposes
            ``feature_shape[layer]``, whose entries after the first two give
            the spatial size of the feature map.
    """
    progress = default_progress()
    # tuple() so the shape concatenates with a tuple below even if it is
    # stored as a list.
    feature_shape = tuple(model.feature_shape[args.layer][2:])
    with torch.no_grad():
        imagedir = os.path.join(cachedir, 'image')
        os.makedirs(imagedir, exist_ok=True)
        image_saver = WorkerPool(SaveImageWorker)
        # Loop-invariant highlight colour (presumably in the generator's
        # [-1, 1] RGB range); built and moved to the GPU once.
        yellow = torch.Tensor([1.0, 1.0, -1.0])[None, :, None, None].cuda()
        for group, group_sample, group_location, group_indices in [
                ('present', corpus.object_present_sample,
                 corpus.object_present_location, corpus.present_indices),
                ('candidate', corpus.candidate_sample,
                 corpus.candidate_location, corpus.candidate_indices)]:
            loader = torch.utils.data.DataLoader(
                TensorDataset(group_sample, group_location, group_indices),
                batch_size=args.inference_batch_size,
                num_workers=10, pin_memory=True)
            for [zbatch, featloc, indices] in progress(
                    loader, desc="Visualize %s" % group):
                zbatch = zbatch.cuda()
                count = len(zbatch)
                tensor_image = model(zbatch)
                # One-hot mask: a 1 at each sample's flattened feature
                # location, shaped (count, 1, *feature_shape).
                feature_mask = torch.zeros((count, 1) + feature_shape,
                                           dtype=torch.float)
                feature_mask.view(count, -1).scatter_(1, featloc[:, None], 1)
                # Resize the mask to the image's spatial size; each output
                # pixel takes the max over the feature cell(s) it overlaps.
                feature_mask = torch.nn.functional.adaptive_max_pool2d(
                    feature_mask, tensor_image.shape[-2:]).cuda()
                # 50/50 blend of image and yellow inside the mask.
                tensor_image = tensor_image * (1 - 0.5 * feature_mask) + (
                    0.5 * feature_mask * yellow)
                byte_image = (((tensor_image + 1) / 2) * 255).clamp(
                    0, 255).byte()
                numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy()
                for i, index in enumerate(indices):
                    image_saver.add(
                        numpy_image[i],
                        os.path.join(imagedir, '%s_%d.jpg' % (group, index)))
    image_saver.join()
def save_chosen_unit_images(dirname, model, z_universe, indices,
                            shared_dir="shared_images",
                            unitdir_template="unit_{}",
                            name_template="image_{}.jpg",
                            lightbox=False, batch_size=50, seed=1):
    """Render each image referenced by ``indices`` once, then link it per unit.

    Each distinct value in ``indices`` selects a row of ``z_universe``. That
    image is generated once and saved to
    ``<dirname>/<shared_dir>/<hashdir(index)>/<name_template.format(index)>``.
    Then, for every row ``u`` of ``indices``, the directory
    ``<dirname>/<unitdir_template.format(u)>`` is created. For each index in
    that row, a ``MakeLinkWorker`` job is queued from the shared file to a
    same-named entry in the unit directory.

    Args:
        dirname: root output directory.
        model: generator called as ``model(z)`` on CUDA batches. Its output is
            scaled with ``(x + 1) / 2 * 255`` to bytes (presumes outputs in
            [-1, 1] — confirm).
        z_universe: tensor of latent vectors, indexed by image number.
        indices: integer tensor read as ``indices[u]`` = the image numbers
            for unit ``u`` (presumably shape (num_units, images_per_unit)).
        shared_dir: name of the subdirectory holding the rendered images.
        unitdir_template: ``str.format`` template for per-unit directories.
        name_template: ``str.format`` template for image filenames.
        lightbox: when true, ``copy_lightbox_to`` is called on each unit dir.
        batch_size: number of z vectors generated per forward pass.
        seed: unused; kept for interface compatibility.
    """
    # reshape (not view) so non-contiguous index tensors are accepted.
    all_indices = torch.unique(indices.reshape(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    # Only the hash buckets that will actually receive images are created,
    # instead of one lookup per entry of the whole z_universe.
    for hd in {hashdir(index) for index in all_indices.tolist()}:
        os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(
            TensorDataset(z_sample), batch_size=batch_size,
            num_workers=2, pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(
                progress(z_loader, desc='Saving images')):
            z = z.cuda()
            # Sequential loader: batch k covers rows k*batch_size onwards of
            # z_sample, i.e. the same positions in all_indices.
            start_index = batch_num * batch_size
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        saver.join()
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for index in indices[u].tolist():
            fn = name_template.format(index)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()