Example #1
import os

import torch
from torch.utils.data import TensorDataset

# default_progress, WorkerPool, and SaveImageWorker are helpers assumed to
# come from the surrounding project's utility modules.

def save_znum_images(dirname, model, z_sample, indexes, layer, ablated_units,
        name_template="image_{}.png", lightbox=False, batch_size=100, seed=1):
    progress = default_progress()
    os.makedirs(dirname, exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                    batch_size=batch_size, num_workers=2,
                    pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        if ablated_units is not None:
            # Build a per-channel mask marking the units to ablate and
            # register it on the model under the requested layer.
            dims = max(2, max(ablated_units) + 1) # >=2 to avoid broadcast
            mask = torch.zeros(dims)
            mask[ablated_units] = 1
            model.ablation[layer] = mask[None,:,None,None].cuda()
        for batch_num, [z] in enumerate(progress(z_loader,
                desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            # Map generator output from [-1, 1] to byte images in HWC order.
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                    0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = i + start_index
                if indexes is not None:
                    index = indexes[index].item()
                filename = os.path.join(dirname, name_template.format(index))
                saver.add(im[i].numpy(), filename)
    saver.join()
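
A hypothetical usage sketch (the generator, latent size, layer name, and unit
indices below are illustrative stand-ins, not values from the original
project):

# Assumes `model` is a generator whose `ablation` dict is keyed by layer name.
z_sample = torch.randn(500, 128)         # a batch of latent vectors
save_znum_images('output/ablated', model, z_sample,
                 indexes=None,           # keep sequential file numbering
                 layer='layer4',         # hypothetical key in model.ablation
                 ablated_units=[10, 20], # channels to mask at that layer
                 batch_size=100)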
Example #2
import os

import numpy
import torch
from torch.utils.data import TensorDataset

def visualize_training_locations(args, corpus, cachedir, model):
    # Phase 2.5: create visualizations of the corpus images.
    progress = default_progress()
    feature_shape = model.feature_shape[args.layer][2:]
    num_locations = numpy.prod(feature_shape).item()
    with torch.no_grad():
        imagedir = os.path.join(cachedir, 'image')
        os.makedirs(imagedir, exist_ok=True)
        image_saver = WorkerPool(SaveImageWorker)
        for group, group_sample, group_location, group_indices in [
            ('present', corpus.object_present_sample,
             corpus.object_present_location, corpus.present_indices),
            ('candidate', corpus.candidate_sample, corpus.candidate_location,
             corpus.candidate_indices)
        ]:
            loader = torch.utils.data.DataLoader(
                TensorDataset(group_sample, group_location, group_indices),
                batch_size=args.inference_batch_size,
                num_workers=10,
                pin_memory=True)
            for [zbatch, featloc, indices] in progress(
                    loader, desc="Visualize %s" % group):
                zbatch = zbatch.cuda()
                tensor_image = model(zbatch)
                # Mark the selected feature location with a one-hot mask,
                # upsample it to the image resolution, and blend in yellow
                # ([1, 1, -1] in [-1, 1] color space is RGB yellow).
                feature_mask = torch.zeros((len(zbatch), 1) + feature_shape)
                feature_mask.view(len(zbatch), -1).scatter_(
                    1, featloc[:, None], 1)
                feature_mask = torch.nn.functional.adaptive_max_pool2d(
                    feature_mask.float(), tensor_image.shape[-2:]).cuda()
                yellow = torch.Tensor([1.0, 1.0, -1.0])[None, :, None,
                                                        None].cuda()
                tensor_image = tensor_image * (1 - 0.5 * feature_mask) + (
                    0.5 * feature_mask * yellow)
                byte_image = (((tensor_image + 1) / 2) * 255).clamp(
                    0, 255).byte()
                numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy()
                for i, index in enumerate(indices):
                    image_saver.add(
                        numpy_image[i],
                        os.path.join(imagedir, '%s_%d.jpg' % (group, index)))
    image_saver.join()
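
The highlighting trick above, scattering a one-hot mask at a flattened
feature location and upsampling it with adaptive_max_pool2d, is easy to
demonstrate in isolation. A minimal self-contained sketch (all shapes and
values below are invented for illustration):

import torch
import torch.nn.functional as F

# Minimal sketch of the location-highlight trick (illustrative shapes only).
image = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in generator output, [-1, 1]
feature_shape = (8, 8)                    # spatial size of the feature layer
featloc = torch.tensor([27])              # flattened location, 0 <= loc < 64

mask = torch.zeros((1, 1) + feature_shape)
mask.view(1, -1).scatter_(1, featloc[:, None], 1)     # one-hot at the location
mask = F.adaptive_max_pool2d(mask, image.shape[-2:])  # nearest-style upsample
yellow = torch.tensor([1.0, 1.0, -1.0])[None, :, None, None]
highlighted = image * (1 - 0.5 * mask) + 0.5 * mask * yellow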
Example #3
import os

import torch
from torch.utils.data import TensorDataset

def save_chosen_unit_images(dirname,
                            model,
                            z_universe,
                            indices,
                            shared_dir="shared_images",
                            unitdir_template="unit_{}",
                            name_template="image_{}.jpg",
                            lightbox=False,
                            batch_size=50,
                            seed=1):
    all_indices = torch.unique(indices.view(-1), sorted=True)
    z_sample = z_universe[all_indices]
    progress = default_progress()
    sdir = os.path.join(dirname, shared_dir)
    created_hashdirs = set()
    # Pre-create every hashed subdirectory under the shared image directory.
    for index in range(len(z_universe)):
        hd = hashdir(index)
        if hd not in created_hashdirs:
            created_hashdirs.add(hd)
            os.makedirs(os.path.join(sdir, hd), exist_ok=True)
    with torch.no_grad():
        # Pass 2: now generate images
        z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample),
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        saver = WorkerPool(SaveImageWorker)
        for batch_num, [z] in enumerate(
                progress(z_loader, desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            for i in range(len(im)):
                index = all_indices[i + start_index].item()
                filename = os.path.join(sdir, hashdir(index),
                                        name_template.format(index))
                saver.add(im[i].numpy(), filename)
        saver.join()
    # Populate each unit's directory with links into the shared image pool.
    linker = WorkerPool(MakeLinkWorker)
    for u in progress(range(len(indices)), desc='Making links'):
        udir = os.path.join(dirname, unitdir_template.format(u))
        os.makedirs(udir, exist_ok=True)
        for r in range(indices.shape[1]):
            index = indices[u, r].item()
            fn = name_template.format(index)
            # sourcename = os.path.join('..', shared_dir, fn)
            sourcename = os.path.join(sdir, hashdir(index), fn)
            targname = os.path.join(udir, fn)
            linker.add(sourcename, targname)
        if lightbox:
            copy_lightbox_to(udir)
    linker.join()
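
The hashdir helper this last example depends on is not shown. A plausible
stand-in, consistent with how it is used here (bucketing an integer index
into a short subdirectory name so the shared pool stays shallow); the
two-digit modulus is an assumption, not the project's actual scheme:

def hashdir(index):
    # Hypothetical stand-in for the project's hashdir helper: bucket image
    # indices into at most 100 subdirectories. The real scheme may differ.
    return '%02d' % (index % 100)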