def compute_topk_imgs(self, mode='mean'):
    """Tally the top-k images for each unit and store them in ``self.topk``.

    When ``mode`` is ``'mean'`` the ranking uses ``self._mean_activations``.
    Any other value selects ``self._max_activations``.
    """
    if mode == 'mean':
        source = self._mean_activations
    else:
        # Anything that is not 'mean' is treated as 'max'.
        source = self._max_activations
    self.topk = tally.tally_topk(source,
                                 dataset=self.ds,
                                 sample_size=self.sample_size,
                                 batch_size=10)
Esempio n. 2
0
def main():
    """Run one dissection experiment end to end and write a JSON/HTML report.

    Each tally below is cached under the results directory. Steps in order:
      1. ``rq.npz``: per-unit quantiles of upsampled activations.
      2. ``topk.npz``: top-k images per unit, ranked by the unit's spatial
         max activation.
      3. ``top%dimages.npz`` and ``image/unit%d.jpg``: masked thumbnails of
         the top ``image_row_width`` images per unit.
      4. ``condi99.npz``: conditional means of the thresholded unit indicator
         given segmentation labels. These are converted to per-unit IoU.
      5. ``concepts_99.svg``, ``report.json`` and ``+report.html``.
    """
    args = parseargs()
    # The results directory encodes model/dataset/segmenter. It also encodes
    # the layer, plus the quantile and thumbnail size when they differ from
    # 0.005 and 100, so that differently-configured runs do not collide.
    resdir = 'results/%s-%s-%s' % (args.model, args.dataset, args.seg)
    if args.layer is not None:
        resdir += '-' + args.layer
    if args.quantile != 0.005:
        resdir += ('-%g' % (args.quantile * 1000))
    if args.thumbsize != 100:
        resdir += ('-t%d' % (args.thumbsize))
    # resfile(name) is used below to build paths inside resdir.
    # NOTE(review): presumably exclusive_dirfn also claims resdir via a pid
    # file, released by resfile.done() at the end -- confirm in pidfile.
    resfile = pidfile.exclusive_dirfn(resdir)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    # For 'progan', images are the model's outputs (see compute_acts and
    # compute_conditional_indicator). Otherwise the dataset batch itself
    # is the image.
    is_generator = (args.model == 'progan')
    # Activation quantile used as the "unit is on" threshold.
    percent_level = 1.0 - args.quantile
    iou_threshold = args.miniou
    image_row_width = 5
    # Everything below is inference only.
    torch.set_grad_enabled(False)

    # Tally rq.npz (representation quantile, unconditional).
    pbar.descnext('rq')
    # The nested compute functions take *args, which shadows the outer
    # parsed `args`. None of them read it.
    def compute_samples(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        # Move channels last and flatten: one row per upsampled pixel, one
        # column per unit.
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])
    rq = tally.tally_quantile(compute_samples, dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Create visualizations - first we need to know the topk
    pbar.descnext('topk')
    def compute_image_max(batch, *args):
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        # Max over all spatial positions -> one score per (image, unit).
        acts = acts.view(acts.shape[0], acts.shape[1], -1)
        acts = acts.max(2)[0]
        return acts
    topk = tally.tally_topk(compute_image_max, dataset, sample_size=sample_size,
            batch_size=50, num_workers=30, pin_memory=True,
            cachefile=resfile('topk.npz'))

    # Visualize top-activating patches of top-activating images.
    pbar.descnext('unit_images')
    image_size, image_source = None, None
    if is_generator:
        # Run one dataset item through the generator to learn the output
        # image size (spatial dims of the output).
        image_size = model(dataset[0][0].cuda()[None,...]).shape[2:]
    else:
        image_source = dataset
    # NOTE(review): image_source is computed above but never used. The
    # visualizer always receives source=dataset, even for the generator.
    # Confirm whether source=image_source was intended.
    iv = imgviz.ImageVisualizer((args.thumbsize, args.thumbsize),
        image_size=image_size,
        source=dataset,
        quantiles=rq,
        level=rq.quantiles(percent_level))
    # Returns (activations, images): generator outputs for the GAN, the
    # input batch otherwise.
    def compute_acts(data_batch, *ignored_class):
        data_batch = data_batch.cuda()
        out_batch = model(data_batch)
        acts_batch = model.retained_layer(layername)
        if is_generator:
            return (acts_batch, out_batch)
        else:
            return (acts_batch, data_batch)
    unit_images = iv.masked_images_for_topk(
            compute_acts, dataset, topk,
            k=image_row_width, num_workers=30, pin_memory=True,
            cachefile=resfile('top%dimages.npz' % image_row_width))
    pbar.descnext('saving images')
    imgsave.save_image_set(unit_images, resfile('image/unit%d.jpg'),
            sourcefile=resfile('top%dimages.npz' % image_row_width))

    # Compute IoU agreement between segmentation labels and every unit.
    # Take each unit's activation level at percent_level (= 1 - quantile).
    # Despite the "99" names this is not necessarily the 99th percentile:
    # with quantile 0.005 (the value the resdir logic treats as default) it
    # is the 99.5th. It is reshaped to broadcast over (batch, unit, H, W).
    level_at_99 = rq.quantiles(percent_level).cuda()[None,:,None,None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')
    def compute_conditional_indicator(batch, *args):
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        # Segment the generated image for the GAN; otherwise segment the
        # input renormalized to the 'zc' target.
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float() # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
            dataset, sample_size=sample_size,
            num_workers=3, pin_memory=True,
            cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units.
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    # One entry per unit: (best concept index, label, category label, best
    # IoU). iou_99.max(0) reduces over the first axis, so presumably iou_99
    # is (concepts, units) -- confirm.
    unit_label_99 = [
            (concept.item(), seglabels[concept],
                segcatlabels[concept], bestiou.item())
            for (bestiou, concept) in zip(*iou_99.max(0))]
    # Category labels of units whose best IoU exceeds the threshold.
    labelcat_list = [labelcat
            for concept, label, labelcat, iou in unit_label_99
            if iou > iou_threshold]
    save_conceptcat_graph(resfile('concepts_99.svg'), labelcat_list)
    # NOTE(review): cat=labelcat[1] assumes segcatlabels entries are
    # (label, category) pairs -- confirm against setting.load_segmenter.
    dump_json_file(resfile('report.json'), dict(
            header=dict(
                name='%s %s %s' % (args.model, args.dataset, args.seg),
                image='concepts_99.svg'),
            units=[
                dict(image='image/unit%d.jpg' % u,
                    unit=u, iou=iou, label=label, cat=labelcat[1])
                for u, (concept, label, labelcat, iou)
                in enumerate(unit_label_99)])
            )
    copy_static_file('report.html', resfile('+report.html'))
    resfile.done();
Esempio n. 3
0
            sr = inst_net.retained_layer('features.' + layername)
            return ((sr > s_99).float() * (ur > u_99).float()).permute(
                0, 2, 3, 1).reshape(-1, ur.size(1))

    # Mean of the joint indicator produced by intersect_99_fn (both
    # activations above their thresholds), cached per layer. Presumably the
    # mean is per unit column -- confirm against tally.tally_mean.
    intersect_99 = tally.tally_mean(intersect_99_fn,
                                    ds,
                                    cachefile=os.path.join(
                                        qd.dir(layername), 'intersect_99.npz'))

    # Score every image by each unit's maximum activation over the two
    # trailing (spatial) axes.
    def compute_image_max(batch):
        inst_net(batch.cuda())
        return inst_net.retained_layer('features.' +
                                       layername).max(3)[0].max(2)[0]

    # Top-k images of sds per unit under that score, cached per layer.
    s_topk = tally.tally_topk(compute_image_max,
                              sds,
                              cachefile=os.path.join(qd.dir(layername),
                                                     's_topk.npz'))

    # Pair a batch's retained activations with its input images, the
    # (acts, images) form passed to masked_images_for_topk below.
    def compute_acts(image_batch):
        inst_net(image_batch.cuda())
        acts_batch = inst_net.retained_layer('features.' + layername)
        return (acts_batch, image_batch)

    # Masked top-5 strips per unit over sds, saved as s_imgs/unit%d.jpg.
    iv = imgviz.ImageVisualizer(128, quantiles=s_rq, source=sds)
    unit_images = iv.masked_images_for_topk(compute_acts, sds, s_topk, k=5)
    os.makedirs(os.path.join(qd.dir(layername), 's_imgs'), exist_ok=True)
    imgsave.save_image_set(
        unit_images, os.path.join(qd.dir(layername), 's_imgs/unit%d.jpg'))

    # Same visualization over uds, thresholded with u_rq.
    # NOTE(review): this indexes uds with s_topk, which was ranked on sds.
    # That is only meaningful if sds and uds are index-aligned -- confirm.
    # These unit_images are not saved within the visible code.
    iv = imgviz.ImageVisualizer(128, quantiles=u_rq, source=uds)
    unit_images = iv.masked_images_for_topk(compute_acts, uds, s_topk, k=5)
Esempio n. 4
0
def main():
    """Dissect one layer of a MoCo-trained ResNet-50 on ImageNet or Places365.

    Steps:
      1. Load the validation split and the MoCo checkpoint, and retain
         ``args.layer`` with nethook.
      2. Tally per-unit activation quantiles and plot four of them.
      3. Tally the top-k images per unit, ranked by mean activation. Plot the
         12 best images for 10 random units.
      4. Render masked top-5 strips for every unit and pickle them.
      5. Compute the IoU between each unit's 99th-percentile mask and the
         segmenter classes. Pickle the IoU and plot the 20 best-matching
         units.

    Caches and pickles go under ``results/<dataset>/``. Plots go under a
    per-dataset, per-layer web directory.

    Raises:
        ValueError: if ``args.dataset`` is not ``'imagenet'`` or
            ``'places365'``. No data or web paths are defined for other
            datasets.
    """
    # Load the arguments
    args = parse_option()

    dataset = args.dataset
    sample_size = args.sample_size
    layername = args.layer

    # Every forward pass below is inference only. Disabling autograd avoids
    # building and holding computation graphs for them.
    torch.set_grad_enabled(False)

    # Other values for places and imagenet MoCo model
    epoch = 240
    image_size = 224
    crop_padding = 32
    batch_size = 1
    num_workers = 24
    train_sampler = None

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Set appropriate paths
    folder_path = "/data/vision/torralba/ganprojects/yyou/CMC_data/{}_models".format(
        dataset)
    model_name = "/{}_MoCo0.999_softmax_16384_resnet50_lr_0.03".format(dataset) \
                     + "_decay_0.0001_bsz_128_crop_0.2_aug_CJ"
    epoch_name = "/ckpt_epoch_{}.pth".format(epoch)
    my_path = folder_path + model_name + epoch_name

    data_path = "/data/vision/torralba/datasets/"
    web_path = "/data/vision/torralba/scratch/yyou/wednesday/dissection/"

    if dataset == "imagenet":
        data_path += "imagenet_pytorch"
    elif dataset == "places365":
        data_path += "places/places365_standard/places365standard_easyformat"
    else:
        # Without this branch, data_path and web_path would silently stay
        # at the bare root directories.
        raise ValueError(
            "unsupported dataset {!r}: expected 'imagenet' or 'places365'"
            .format(dataset))
    web_path += dataset + "/" + layername

    # Create web path folder directory for this layer
    os.makedirs(web_path, exist_ok=True)

    # Load validation data loader
    val_folder = data_path + "/val"
    val_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    ds = QuickImageFolder(val_folder,
                          transform=val_transform,
                          shuffle=True,
                          two_crop=False)
    ds_loader = torch.utils.data.DataLoader(ds,
                                            batch_size=batch_size,
                                            shuffle=(train_sampler is None),
                                            num_workers=num_workers,
                                            pin_memory=True,
                                            sampler=train_sampler)

    # Load model from checkpoint
    checkpoint = torch.load(my_path)
    model_checkpoint = {
        key.replace(".module", ""): val
        for key, val in checkpoint['model'].items()
    }

    model = InsResNet50(parallel=False)
    model.load_state_dict(model_checkpoint)
    model = nethook.InstrumentedModel(model)
    model.cuda()

    # Renormalize RGB data from the statistical scaling in ds to [-1...1] range
    renorm = renormalize.renormalizer(source=ds, target='zc')

    # Retain the desired layer with nethook and probe one batch to learn the
    # layer's activation shape.
    batch = next(iter(ds_loader))[0]
    model.retain_layer(layername)
    model(batch.cuda())
    acts = model.retained_layer(layername).cpu()

    # Upsample activations to 56x56. The source grid comes from the probe,
    # so layers other than the 7x7 output of layer4 are handled correctly.
    upfn = upsample.upsampler(
        target_shape=(56, 56),
        data_shape=tuple(acts.shape[2:]),
    )

    def flatten_activations(batch, *args):
        """Return upsampled activations as (pixels, units) rows."""
        image_batch = batch
        _ = model(image_batch.cuda())
        acts = model.retained_layer(layername).cpu()
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    def tally_quantile_for_layer(layername):
        rq = tally.tally_quantile(
            flatten_activations,
            dataset=ds,
            sample_size=sample_size,
            batch_size=100,
            cachefile='results/{}/{}_rq_cache.npz'.format(dataset, layername))
        return rq

    rq = tally_quantile_for_layer(layername)

    # Visualize range of activations (statistics of each filter over the sample images)
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    quantiles = [0.5, 0.8, 0.9, 0.99]
    for ax, q in zip(axs.flatten(), quantiles):
        ax.plot(rq.quantiles(q))
        ax.set_title("Rq quantiles ({})".format(q))
    fig.suptitle("{}  -  sample size of {}".format(dataset, sample_size))
    fig.savefig(web_path + "/rq_quantiles")
    plt.close(fig)

    # Set the image visualizer with the rq and percent level
    iv = imgviz.ImageVisualizer(224,
                                source=ds,
                                percent_level=0.95,
                                quantiles=rq)

    # Alternative ranking by peak (spatial max) activation. Not used below.
    def max_activations(batch, *args):
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).max(2)[0]

    # Score each image by each unit's spatial mean activation.
    def mean_activations(batch, *args):
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).mean(2)

    # Tally top k images that maximize the mean activation of the filter
    topk = tally.tally_topk(
        mean_activations,
        dataset=ds,
        sample_size=sample_size,
        batch_size=100,
        cachefile='results/{}/{}_cache_mean_topk.npz'.format(
            dataset, layername))

    top_indexes = topk.result()[1]

    # Visualize top-activating images for a particular unit
    os.makedirs(web_path + "/top_activating_imgs", exist_ok=True)

    def top_activating_imgs(unit):
        """Save a 3x4 grid of the 12 top-ranked images for ``unit``."""
        img_ids = list(top_indexes[unit, :12])
        images = []
        preds = []
        for i in img_ids:
            image = ds[i][0]
            # Forward this image first, so that the retained activations
            # (and therefore the mask) belong to it rather than to whatever
            # batch happened to run last.
            output = model(image[None].cuda())
            unit_acts = model.retained_layer(layername)[0]
            images.append(iv.masked_image(image, unit_acts, unit))
            # NOTE(review): assumes the model output is class scores aligned
            # with ds.classes -- confirm for InsResNet50.
            preds.append(ds.classes[output.max(1)[1].item()])

        fig, axs = plt.subplots(3, 4, figsize=(16, 12))
        for ax, img_id, image, pred in zip(axs.flatten(), img_ids, images,
                                           preds):
            ax.imshow(image)
            ax.tick_params(axis='both', which='both', bottom=False,
                           left=False, labelbottom=False, labelleft=False)
            ax.set_title("img {} \n pred: {}".format(img_id, pred))
        fig.suptitle("unit {}".format(unit))

        fig.savefig(web_path + "/top_activating_imgs/unit_{}".format(unit))
        # Close each figure; otherwise every call leaks one open figure.
        plt.close(fig)

    for unit in np.random.randint(len(top_indexes), size=10):
        top_activating_imgs(unit)

    def compute_activations(image_batch):
        image_batch = image_batch.cuda()
        _ = model(image_batch)
        acts_batch = model.retained_layer(layername)
        return acts_batch

    # The cache name says "top10" although k=5. The name is kept so that
    # existing caches are reused.
    unit_images = iv.masked_images_for_topk(
        compute_activations,
        ds,
        topk,
        k=5,
        num_workers=10,
        pin_memory=True,
        cachefile='results/{}/{}_cache_top10images.npz'.format(
            dataset, layername))

    # Name the pickle per layer (like the iou99 pickle below) so that runs
    # on different layers do not overwrite each other.
    with open("results/{}/{}_unit_images.pkl".format(dataset, layername),
              'wb') as f:
        pickle.dump(unit_images, f)

    # Load a segmentation model
    segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')

    # Intersections between every unit's 99th activation
    # and every segmentation class identified
    level_at_99 = rq.quantiles(0.99).cuda()[None, :, None, None]

    def compute_selected_segments(batch, *args):
        image_batch = batch.cuda()
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts >
                 level_at_99).float()  # indicator where > 0.99 percentile.
        return tally.conditional_samples(iacts, seg)

    condi99 = tally.tally_conditional_mean(
        compute_selected_segments,
        dataset=ds,
        sample_size=sample_size,
        cachefile='results/{}/{}_cache_condi99.npz'.format(dataset, layername))

    iou99 = tally.iou_from_conditional_indicator_mean(condi99)
    with open("results/{}/{}_iou99.pkl".format(dataset, layername), 'wb') as f:
        pickle.dump(iou99, f)

    # Show units with best match to a segmentation class: one entry per unit,
    # (unit, best concept, label, best IoU), sorted by IoU descending.
    iou_unit_label_99 = sorted(
        [(unit, concept.item(), seglabels[concept], bestiou.item())
         for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
        key=lambda x: -x[-1])

    fig, axs = plt.subplots(20, 1, figsize=(20, 80))
    for ax, (unit, concept, label, score) in zip(axs.flatten(),
                                                 iou_unit_label_99[:20]):
        ax.imshow(unit_images[unit])
        ax.set_title('unit %d; iou %g; label "%s"' % (unit, score, label))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.savefig(web_path + "/best_unit_segmentation")
    plt.close(fig)