def show_seg_results(self):
        """Print the 20 units of this layer that best match a segmentation
        concept (ranked by IoU at the 99th-percentile activation level) and
        show each unit's top-activating image patches.

        Side effects: populates ``self.iou99``; lazily triggers
        ``self.compute_top_unit_imgs()`` when unit images are missing.
        """
        if self.unit_images is None:
            self.compute_top_unit_imgs()
        # Per-unit 99th-percentile threshold, shaped [1, units, 1, 1] so it
        # broadcasts over NCHW activation tensors.
        level_at_99 = self.rq.quantiles(0.99).cuda()[None, :, None, None]
        sample_size = 20

        def compute_selected_segments(batch, *args):
            # For each segmentation label, sample the binary indicator
            # "unit activation exceeds its 99th percentile".
            img, seg = batch
            #     show(iv.segmentation(seg))
            image_batch = img.cuda()
            seg_batch = seg.cuda()
            _ = self.model(image_batch)
            acts = self.model.retained_layer(self.layername)
            hacts = self.upfn(acts)  # upsample acts toward segmentation size
            iacts = (hacts >
                     level_at_99).float()  # indicator where > 0.99 percentile.
            return tally.conditional_samples(iacts, seg_batch)

        # Mean of the indicator conditioned on each segmentation label.
        condi99 = tally.tally_conditional_mean(compute_selected_segments,
                                               dataset=self.ds,
                                               sample_size=sample_size,
                                               loader=self.ds_loader,
                                               pass_with_lbl=True)

        self.iou99 = tally.iou_from_conditional_indicator_mean(condi99)
        bolded_string = "\033[1m" + self.layername + "\033[0m"  # ANSI bold

        print(bolded_string)
        # For each unit pick its best-IoU concept, then sort units by that
        # IoU, descending.
        iou_unit_label_99 = sorted([
            (unit, concept.item(), self.seglabels[int(concept)],
             bestiou.item())
            for unit, (bestiou, concept) in enumerate(zip(*self.iou99.max(0)))
        ],
                                   key=lambda x: -x[-1])
        for unit, concept, label, score in iou_unit_label_99[:20]:
            show([
                'unit %d; iou %g; label "%s"' % (unit, score, label),
                [self.unit_images[unit]]
            ])
# Beispiel #2
# 0
def main():
    """Run the network-dissection pipeline for one (model, dataset, seg) setup.

    Steps:
      1. Tally per-unit activation quantiles (rq) over the dataset.
      2. Find each unit's top-k activating images and save masked thumbnails.
      3. Tally, per segmentation label, the mean of the indicator
         "activation above the (1 - quantile) level" and convert it into
         per-unit/per-concept IoU scores.
      4. Write the concept graph SVG and a JSON/HTML report.

    All heavy tallies are cached under the result directory, and
    ``pidfile.exclusive_dirfn`` ensures only one process works on it.
    """
    args = parseargs()
    # Result directory name encodes the configuration; optional suffixes are
    # added only when a setting differs from its default.
    resdir = 'results/%s-%s-%s' % (args.model, args.dataset, args.seg)
    if args.layer is not None:
        resdir += '-' + args.layer
    if args.quantile != 0.005:
        resdir += ('-%g' % (args.quantile * 1000))
    if args.thumbsize != 100:
        resdir += ('-t%d' % (args.thumbsize))
    resfile = pidfile.exclusive_dirfn(resdir)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)  # hook so activations can be read back
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    is_generator = (args.model == 'progan')  # generator input is z, not images
    percent_level = 1.0 - args.quantile
    iou_threshold = args.miniou
    image_row_width = 5
    torch.set_grad_enabled(False)  # inference only

    # Tally rq.np (representation quantile, unconditional).
    pbar.descnext('rq')

    def compute_samples(batch, *args):
        # Flatten upsampled activations into (pixels, units) samples.
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    rq = tally.tally_quantile(compute_samples, dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Create visualizations - first we need to know the topk
    pbar.descnext('topk')

    def compute_image_max(batch, *args):
        # Per-image, per-unit spatial maximum of the raw activations.
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        acts = acts.view(acts.shape[0], acts.shape[1], -1)
        acts = acts.max(2)[0]
        return acts

    topk = tally.tally_topk(compute_image_max, dataset, sample_size=sample_size,
            batch_size=50, num_workers=30, pin_memory=True,
            cachefile=resfile('topk.npz'))

    # Visualize top-activating patches of top-activating images.
    pbar.descnext('unit_images')
    image_size, image_source = None, None
    if is_generator:
        # For a generator the images are its own outputs; probe one sample
        # to discover the output spatial size.
        image_size = model(dataset[0][0].cuda()[None, ...]).shape[2:]
    else:
        image_source = dataset
    iv = imgviz.ImageVisualizer((args.thumbsize, args.thumbsize),
        image_size=image_size,
        source=dataset,
        quantiles=rq,
        level=rq.quantiles(percent_level))

    def compute_acts(data_batch, *ignored_class):
        # Returns (activations, images): generated output for a generator,
        # otherwise the input batch itself.
        data_batch = data_batch.cuda()
        out_batch = model(data_batch)
        acts_batch = model.retained_layer(layername)
        if is_generator:
            return (acts_batch, out_batch)
        else:
            return (acts_batch, data_batch)

    unit_images = iv.masked_images_for_topk(
            compute_acts, dataset, topk,
            k=image_row_width, num_workers=30, pin_memory=True,
            cachefile=resfile('top%dimages.npz' % image_row_width))
    pbar.descnext('saving images')
    imgsave.save_image_set(unit_images, resfile('image/unit%d.jpg'),
            sourcefile=resfile('top%dimages.npz' % image_row_width))

    # Compute IoU agreement between segmentation labels and every unit
    # Grab the 99th percentile, and tally conditional means at that level.
    # Threshold shaped [1, units, 1, 1] to broadcast over NCHW activations.
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')

    def compute_conditional_indicator(batch, *args):
        # Segment the (rendered or renormalized) image and sample the
        # above-threshold indicator per segmentation label.
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
            dataset, sample_size=sample_size,
            num_workers=3, pin_memory=True,
            cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    # One (concept, label, category-label, iou) tuple per unit, taking each
    # unit's best-matching concept.
    unit_label_99 = [
            (concept.item(), seglabels[concept],
                segcatlabels[concept], bestiou.item())
            for (bestiou, concept) in zip(*iou_99.max(0))]
    labelcat_list = [labelcat
            for concept, label, labelcat, iou in unit_label_99
            if iou > iou_threshold]
    save_conceptcat_graph(resfile('concepts_99.svg'), labelcat_list)
    dump_json_file(resfile('report.json'), dict(
            header=dict(
                name='%s %s %s' % (args.model, args.dataset, args.seg),
                image='concepts_99.svg'),
            units=[
                dict(image='image/unit%d.jpg' % u,
                    unit=u, iou=iou, label=label, cat=labelcat[1])
                for u, (concept, label, labelcat, iou)
                in enumerate(unit_label_99)])
            )
    copy_static_file('report.html', resfile('+report.html'))
    resfile.done()  # release the exclusive lock on the result directory
# Beispiel #3
# 0
def main():
    """Dissect one layer of a MoCo/CMC ResNet50: tally activation statistics,
    visualize top-activating images per unit, and match units to
    segmentation concepts by IoU at the 99th-percentile activation level.

    Outputs (plots, pickles, tally caches) are written under
    ``results/<dataset>/`` and the layer's ``web_path`` directory.
    """

    # Load the arguments
    args = parse_option()

    dataset = args.dataset
    sample_size = args.sample_size
    layername = args.layer

    # Other values for places and imagenet MoCo model.
    # (crop=0.2 is baked into the checkpoint directory name below.)
    epoch = 240
    image_size = 224
    crop_padding = 32
    batch_size = 1
    num_workers = 24
    train_sampler = None

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Set appropriate paths
    folder_path = "/data/vision/torralba/ganprojects/yyou/CMC_data/{}_models".format(
        dataset)
    model_name = "/{}_MoCo0.999_softmax_16384_resnet50_lr_0.03".format(dataset) \
                     + "_decay_0.0001_bsz_128_crop_0.2_aug_CJ"
    epoch_name = "/ckpt_epoch_{}.pth".format(epoch)
    my_path = folder_path + model_name + epoch_name

    data_path = "/data/vision/torralba/datasets/"
    web_path = "/data/vision/torralba/scratch/yyou/wednesday/dissection/"

    if dataset == "imagenet":
        data_path += "imagenet_pytorch"
        web_path += dataset + "/" + layername
    elif dataset == "places365":
        data_path += "places/places365_standard/places365standard_easyformat"
        web_path += dataset + "/" + layername

    # Create web path folder directory for this layer
    os.makedirs(web_path, exist_ok=True)

    # Load validation data loader
    val_folder = data_path + "/val"
    val_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    ds = QuickImageFolder(val_folder,
                          transform=val_transform,
                          shuffle=True,
                          two_crop=False)
    ds_loader = torch.utils.data.DataLoader(ds,
                                            batch_size=batch_size,
                                            shuffle=(train_sampler is None),
                                            num_workers=num_workers,
                                            pin_memory=True,
                                            sampler=train_sampler)

    # Load model from checkpoint, stripping DataParallel ".module" from keys.
    checkpoint = torch.load(my_path)
    model_checkpoint = {
        key.replace(".module", ""): val
        for key, val in checkpoint['model'].items()
    }

    model = InsResNet50(parallel=False)
    model.load_state_dict(model_checkpoint)
    model = nethook.InstrumentedModel(model)
    model.cuda()

    # Renormalize RGB data from the statistical scaling in ds to [-1...1] range
    renorm = renormalize.renormalizer(source=ds, target='zc')

    # Retain desired layer with nethook
    batch = next(iter(ds_loader))[0]
    model.retain_layer(layername)
    model(batch.cuda())
    acts = model.retained_layer(layername).cpu()

    # Upsample 7x7 activation maps to 56x56 for comparison with segmentations.
    upfn = upsample.upsampler(
        target_shape=(56, 56),
        data_shape=(7, 7),
    )

    def flatten_activations(batch, *args):
        # Flatten upsampled activations into (pixels, units) samples.
        image_batch = batch
        _ = model(image_batch.cuda())
        acts = model.retained_layer(layername).cpu()
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    def tally_quantile_for_layer(layername):
        # Per-unit activation quantiles, cached per (dataset, layer).
        rq = tally.tally_quantile(
            flatten_activations,
            dataset=ds,
            sample_size=sample_size,
            batch_size=100,
            cachefile='results/{}/{}_rq_cache.npz'.format(dataset, layername))
        return rq

    rq = tally_quantile_for_layer(layername)

    # Visualize range of activations (statistics of each filter over the sample images)
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    axs = axs.flatten()
    quantiles = [0.5, 0.8, 0.9, 0.99]
    for i in range(4):
        axs[i].plot(rq.quantiles(quantiles[i]))
        axs[i].set_title("Rq quantiles ({})".format(quantiles[i]))
    fig.suptitle("{}  -  sample size of {}".format(dataset, sample_size))
    plt.savefig(web_path + "/rq_quantiles")

    # Set the image visualizer with the rq and percent level
    iv = imgviz.ImageVisualizer(224,
                                source=ds,
                                percent_level=0.95,
                                quantiles=rq)

    # Tally top k images that maximize the mean activation of the filter
    def max_activations(batch, *args):
        # Per-image, per-unit spatial max (not used below; kept for parity).
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).max(2)[0]

    def mean_activations(batch, *args):
        # Per-image, per-unit spatial mean activation.
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).mean(2)

    topk = tally.tally_topk(
        mean_activations,
        dataset=ds,
        sample_size=sample_size,
        batch_size=100,
        cachefile='results/{}/{}_cache_mean_topk.npz'.format(
            dataset, layername))

    top_indexes = topk.result()[1]

    # Visualize top-activating images for a particular unit
    os.makedirs(web_path + "/top_activating_imgs", exist_ok=True)

    def top_activating_imgs(unit):
        # Plot the 12 top-activating images for `unit`, masked by the unit's
        # activation, along with the model's predicted class for each.
        img_ids = [i for i in top_indexes[unit, :12]]
        # NOTE(review): masked_image reads the layer retained from the *last*
        # forward pass, which has not yet been run on ds[i] here, so the mask
        # likely comes from a stale batch — verify intent before relying on it.
        images = [iv.masked_image(ds[i][0], \
                      model.retained_layer(layername)[0], unit) \
                      for i in img_ids]
        preds = [ds.classes[model(ds[i][0][None].cuda()).max(1)[1].item()]\
                    for i in img_ids]

        fig, axs = plt.subplots(3, 4, figsize=(16, 12))
        axs = axs.flatten()

        for i in range(12):
            axs[i].imshow(images[i])
            axs[i].tick_params(axis='both', which='both', bottom=False, \
                               left=False, labelbottom=False, labelleft=False)
            axs[i].set_title("img {} \n pred: {}".format(img_ids[i], preds[i]))
        fig.suptitle("unit {}".format(unit))

        plt.savefig(web_path + "/top_activating_imgs/unit_{}".format(unit))

    for unit in np.random.randint(len(top_indexes), size=10):
        top_activating_imgs(unit)

    def compute_activations(image_batch):
        # Raw retained activations for a batch (for masked_images_for_topk).
        image_batch = image_batch.cuda()
        _ = model(image_batch)
        acts_batch = model.retained_layer(layername)
        return acts_batch

    unit_images = iv.masked_images_for_topk(
        compute_activations,
        ds,
        topk,
        k=5,
        num_workers=10,
        pin_memory=True,
        cachefile='results/{}/{}_cache_top10images.npz'.format(
            dataset, layername))

    # NOTE(review): the original passed layername as an extra (ignored)
    # positional to .format(); the filename has no slot for it, so every
    # layer shares this pickle — possibly the name was meant to include
    # layername.  Filename kept as-is to preserve behavior.
    with open("results/{}/unit_images.pkl".format(dataset), 'wb') as f:
        pickle.dump(unit_images, f)

    # Load a segmentation model
    segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')

    # Intersections between every unit's 99th activation
    # and every segmentation class identified
    level_at_99 = rq.quantiles(0.99).cuda()[None, :, None, None]

    def compute_selected_segments(batch, *args):
        # For each segmentation label, sample the indicator of activations
        # exceeding their unit's 99th percentile.
        image_batch = batch.cuda()
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts >
                 level_at_99).float()  # indicator where > 0.99 percentile.
        return tally.conditional_samples(iacts, seg)

    condi99 = tally.tally_conditional_mean(
        compute_selected_segments,
        dataset=ds,
        sample_size=sample_size,
        cachefile='results/{}/{}_cache_condi99.npz'.format(dataset, layername))

    iou99 = tally.iou_from_conditional_indicator_mean(condi99)
    with open("results/{}/{}_iou99.pkl".format(dataset, layername), 'wb') as f:
        pickle.dump(iou99, f)

    # Show units with best match to a segmentation class
    iou_unit_label_99 = sorted(
        [(unit, concept.item(), seglabels[concept], bestiou.item())
         for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
        key=lambda x: -x[-1])

    fig, axs = plt.subplots(20, 1, figsize=(20, 80))
    axs = axs.flatten()

    for i, (unit, concept, label, score) in enumerate(iou_unit_label_99[:20]):
        axs[i].imshow(unit_images[unit])
        axs[i].set_title('unit %d; iou %g; label "%s"' % (unit, score, label))
        axs[i].set_xticks([])
        axs[i].set_yticks([])
    plt.savefig(web_path + "/best_unit_segmentation")
# Beispiel #4
# 0
#         r=8192, cachefile=resfile('rq.npz'))  # orphaned fragment of a call truncated by the scrape

if False:
    # Disabled block: tallies full conditional quantiles of the raw
    # (upsampled) activations per segmentation label, instead of the
    # indicator means computed below — kept for reference.
    def compute_conditional_samples(batch, *args):
        # Upsampled activations sampled per segmentation label.
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        hacts = upfn(acts)
        return tally.conditional_samples(hacts, seg)

    pbar.descnext('condq')
    condq = tally.tally_conditional_quantile(compute_conditional_samples,
            dataset, sample_size=sample_size, cachefile=resfile('condq.npz'))

# Per-unit 99th-percentile activation threshold, shaped [1, units, 1, 1]
# so it broadcasts over NCHW activation tensors.
level_at_99 = rq.quantiles(0.99).cuda()[None,:,None,None]

def compute_conditional_indicator(batch, *args):
    # For each segmentation label, sample the binary indicator
    # "unit activation exceeds its 99th percentile".
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)  # upsample activations toward segmentation resolution
    iacts = (hacts > level_at_99).float() # indicator where > 0.99 percentile.
    return tally.conditional_samples(iacts, seg)

pbar.descnext('condi99')
condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
        dataset, sample_size=sample_size, cachefile=resfile('condi99.npz'))

# Beispiel #5
# 0
def main():
    """Dissect a generator layer and run unit-ablation and unit-insertion
    experiments against segmentation classes.

    Pipeline: tally activation quantiles; compute unit-vs-concept IoU at
    the (1 - quantile) activation level; measure segmentation statistics
    under single-unit and cumulative 'tree'-unit ablations; and measure
    the contextual effect of forcing 'door' units to a high value at each
    spatial location.
    """
    args = parseargs()
    resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg,
                                         args.layer, int(args.quantile * 1000))

    def resfile(f):
        # All result artifacts live under the experiment's directory.
        return os.path.join(resdir, f)

    model = load_model(args)
    layername = instrumented_layername(args)
    model.retain_layer(layername)  # hook so activations can be read back
    dataset = load_dataset(args, model=model.model)
    upfn = make_upfn(args, dataset, model, layername)
    sample_size = len(dataset)
    is_generator = (args.model == 'progan')
    percent_level = 1.0 - args.quantile

    # Tally rq.np (representation quantile, unconditional).
    torch.set_grad_enabled(False)
    pbar.descnext('rq')

    def compute_samples(batch, *args):
        # Flatten upsampled activations into (pixels, units) samples.
        data_batch = batch.cuda()
        _ = model(data_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    rq = tally.tally_quantile(compute_samples,
                              dataset,
                              sample_size=sample_size,
                              r=8192,
                              num_workers=100,
                              pin_memory=True,
                              cachefile=resfile('rq.npz'))

    # Grab the 99th percentile, and tally conditional means at that level.
    # Threshold shaped [1, units, 1, 1] to broadcast over NCHW activations.
    level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

    segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
    renorm = renormalize.renormalizer(dataset, target='zc')

    def compute_conditional_indicator(batch, *args):
        # For each segmentation label, sample the indicator "activation
        # above its (1 - quantile) level".
        data_batch = batch.cuda()
        out_batch = model(data_batch)
        image_batch = out_batch if is_generator else renorm(data_batch)
        seg = segmodel.segment_batch(image_batch, downsample=4)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts > level_at_99).float()  # indicator
        return tally.conditional_samples(iacts, seg)

    pbar.descnext('condi99')
    condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
                                           dataset,
                                           sample_size=sample_size,
                                           num_workers=3,
                                           pin_memory=True,
                                           cachefile=resfile('condi99.npz'))

    # Now summarize the iou stats and graph the units
    iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
    # Best-matching concept per unit: (concept, label, category, iou).
    unit_label_99 = [(concept.item(), seglabels[concept],
                      segcatlabels[concept], bestiou.item())
                     for (bestiou, concept) in zip(*iou_99.max(0))]

    def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
        # Mean per-class segmentation fractions of generated images while
        # the given units are zeroed.  Not called in the visible code; the
        # intended invocations are commented out below.
        model.remove_edits()

        def zero_some_units(x, *args):
            x[:, zeroed_units] = 0
            return x

        model.edit_layer(layername, rule=zero_some_units)
        num_seglabels = len(segmodel.get_label_and_category_names()[0])

        def compute_mean_seg_in_images(batch_z, *args):
            # Fraction of pixels per segmentation class for each image;
            # bincount is batched by offsetting each image's labels by
            # image_index * num_seglabels.
            img = model(batch_z.cuda())
            seg = segmodel.segment_batch(img, downsample=4)
            seg_area = seg.shape[2] * seg.shape[3]
            seg_counts = torch.bincount(
                (seg + (num_seglabels * torch.arange(
                    seg.shape[0], dtype=seg.dtype,
                    device=seg.device)[:, None, None, None])).view(-1),
                minlength=num_seglabels * seg.shape[0]).view(seg.shape[0], -1)
            seg_fracs = seg_counts.float() / seg_area
            return seg_fracs

        result = tally.tally_mean(compute_mean_seg_in_images,
                                  dataset,
                                  batch_size=30,
                                  sample_size=sample_size,
                                  pin_memory=True)
        model.remove_edits()
        return result

    # Intervention experiment here:
    # segs_baseline = measure_segclasses_with_zeroed_units([])
    # segs_without_treeunits = measure_segclasses_with_zeroed_units(tree_units)
    num_units = len(unit_label_99)
    baseline_segmean = test_generator_segclass_stats(
        model,
        dataset,
        segmodel,
        layername=layername,
        cachefile=resfile('segstats/baseline.npz')).mean()

    pbar.descnext('unit ablation')
    # Ablate each unit individually; iteration order is randomized
    # (presumably so concurrent runs sharing the cache directory work on
    # different units — results are cached per unit either way).
    unit_ablation_segmean = torch.zeros(num_units, len(baseline_segmean))
    for unit in pbar(random.sample(range(num_units), num_units)):
        stats = test_generator_segclass_stats(
            model,
            dataset,
            segmodel,
            layername=layername,
            zeroed_units=[unit],
            cachefile=resfile('segstats/ablated_unit_%d.npz' % unit))
        unit_ablation_segmean[unit] = stats.mean()

    # Cumulatively ablate the 0..29 units that best match the 'tree'
    # concept and record the change in tree segmentation.
    ablate_segclass_name = 'tree'
    ablate_segclass = seglabels.index(ablate_segclass_name)
    best_iou_units = iou_99[ablate_segclass, :].sort(0)[1].flip(0)
    byiou_unit_ablation_seg = torch.zeros(30)
    for unitcount in pbar(random.sample(range(0, 30), 30)):
        zero_units = best_iou_units[:unitcount].tolist()
        stats = test_generator_segclass_delta_stats(
            model,
            dataset,
            segmodel,
            layername=layername,
            zeroed_units=zero_units,
            cachefile=resfile('deltasegstats/ablated_best_%d_iou_%s.npz' %
                              (unitcount, ablate_segclass_name)))
        byiou_unit_ablation_seg[unitcount] = stats.mean()[ablate_segclass]

    # Generator context experiment.
    num_segclass = len(seglabels)
    door_segclass = seglabels.index('door')
    # The 20 units with best 'door' IoU, and a high (99.5th pct) value each.
    door_units = iou_99[door_segclass].sort(0)[1].flip(0)[:20]
    door_high_values = rq.quantiles(0.995)[door_units].cuda()

    def compute_seg_impact(zbatch, *args):
        # At every spatial location in turn, force the door units to their
        # high values, regenerate, and yield (context-label, change in
        # per-class pixel counts) pairs conditioned on the segmentation
        # label at that location's center.
        zbatch = zbatch.cuda()
        model.remove_edits()
        orig_img = model(zbatch)
        orig_seg = segmodel.segment_batch(orig_img, downsample=4)
        orig_segcount = tally.batch_bincount(orig_seg, num_segclass)
        rep = model.retained_layer(layername).clone()
        # Scale factors from activation grid to segmentation grid.
        ysize = orig_seg.shape[2] // rep.shape[2]
        xsize = orig_seg.shape[3] // rep.shape[3]

        def gen_conditions():
            for y in range(rep.shape[2]):
                for x in range(rep.shape[3]):
                    # Take as the context location the segmentation
                    # labels at the center of the square.
                    selsegs = orig_seg[:, :, y * ysize + ysize // 2,
                                       x * xsize + xsize // 2]
                    changed_rep = rep.clone()
                    changed_rep[:, door_units, y,
                                x] = (door_high_values[None, :])
                    model.edit_layer(layername,
                                     ablation=1.0,
                                     replacement=changed_rep)
                    changed_img = model(zbatch)
                    changed_seg = segmodel.segment_batch(changed_img,
                                                         downsample=4)
                    changed_segcount = tally.batch_bincount(
                        changed_seg, num_segclass)
                    delta_segcount = (changed_segcount - orig_segcount).float()
                    for sel, delta in zip(selsegs, delta_segcount):
                        for cond in torch.bincount(sel).nonzero()[:, 0]:
                            # Skip label 0 (presumably unlabeled/background
                            # — verify against the segmenter's label list).
                            if cond == 0:
                                continue
                            yield (cond.item(), delta)

        return gen_conditions()

    cond_changes = tally.tally_conditional_mean(
        compute_seg_impact,
        dataset,
        sample_size=10000,
        batch_size=20,
        cachefile=resfile('big_door_cond_changes.npz'))