Example #1
0
def train_ablation(args, corpus, cachefile, model, segmenter, classnum,
        initial_ablation=None):
    """Optimize a per-unit ablation vector for class ``classnum``.

    A learnable tensor ``ablation`` (same shape as the chosen replacement
    feature) is trained with Adam (lr=0.01) to minimize the loss returned by
    ``ace_loss`` plus an L2 penalty weighted by ``args.l2_lambda``.  After
    every optimizer step it is clamped to [0, 1].  Model parameters are
    frozen; only ``ablation`` receives gradient updates.

    ``args.variant`` selects the replacement feature and the loss mode:
      * ``'_h99'``  -> ``corpus.feature_99``
      * ``'_tcm'``  -> ``corpus.mean_present_feature`` (top-conditional mean)
      * otherwise   -> ``corpus.weighted_mean_present_feature``
      * ``'_fim'``  -> full-image measurement
      * ``'_fia'``  -> full-image measurement + ablation-only + full-image
                       ablation

    Snapshots are written to ``<dirname(cachefile)>/snapshots``:
    ``epoch-N.pth`` holds the ablation, optimizer state, ``avg_loss`` and the
    per-experiment eval losses.  ``epoch-N.npy`` holds the ablation as a
    numpy array.  ``epoch--1.pth`` is the state evaluated before any
    training.  Unless ``args.no_cache`` is set, training resumes after the
    latest existing ``epoch-N.pth`` with N in ``[0, args.train_epochs)``.

    Args:
        args: namespace read for ``variant``, ``layer``, ``eval_only``,
            ``no_cache``, ``train_epochs``, ``train_batch_size``,
            ``train_update_freq``, ``inference_batch_size``, ``eval_size``
            and ``l2_lambda``.
        corpus: object providing the feature statistics and the train/eval
            sample and location tensors referenced below.
        cachefile: path whose directory holds the ``snapshots`` folder.
        model: model whose parameters are frozen; passed to ``ace_loss``.
        segmenter: passed through to ``ace_loss``.
        classnum: class index passed through to ``ace_loss``.
        initial_ablation: optional starting values, assigned into the
            flattened ablation tensor.

    Returns:
        The final ablation as a flat numpy array.  With ``args.eval_only``,
        this is the last snapshot that was loaded and re-evaluated.
    """
    progress = default_progress()
    cachedir = os.path.dirname(cachefile)
    snapdir = os.path.join(cachedir, 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    # Replacement feature handed to ace_loss together with the ablation.
    if '_h99' in args.variant:
        high_replacement = corpus.feature_99[None,:,None,None].cuda()
    elif '_tcm' in args.variant:
        # variant: top-conditional-mean
        high_replacement = (
                corpus.mean_present_feature[None,:,None,None].cuda())
    else: # default: weighted mean
        high_replacement = (
                corpus.weighted_mean_present_feature[None,:,None,None].cuda())

    # Loss-mode flags.  All three are forwarded to every ace_loss call for the
    # main objective.  In particular, fullimage_ablation must be passed,
    # otherwise the '_fia' variant silently runs without full-image ablation.
    fullimage_measurement = False
    ablation_only = False
    fullimage_ablation = False
    if '_fim' in args.variant:
        fullimage_measurement = True
    elif '_fia' in args.variant:
        fullimage_measurement = True
        ablation_only = True
        fullimage_ablation = True
    high_replacement.requires_grad = False
    # Only the ablation vector is optimized; freeze the model.
    for p in model.parameters():
        p.requires_grad = False

    ablation = torch.zeros(high_replacement.shape).cuda()
    if initial_ablation is not None:
        ablation.view(-1)[...] = initial_ablation
    ablation.requires_grad = True
    optimizer = torch.optim.Adam([ablation], lr=0.01)
    start_epoch = 0
    # eval_loss_and_reg reads `epoch` from this scope (closure) for logging.
    epoch = 0

    def eval_loss_and_reg():
        """Evaluate the current ablation on the held-out eval sample.

        Returns ``(avg_loss, regularizer, avg_d_losses)``:
          * ``avg_loss``: summed ``ace_loss`` (with the variant's flags)
            divided by ``args.eval_size``.
          * ``regularizer``: ``args.l2_lambda * sum(ablation**2)`` as a
            tensor.  Unlike the training penalty, it is not multiplied by
            the update size.
          * ``avg_d_losses``: the same average for each configuration in
            ``discrete_experiments``.
        """
        # Extra evaluation configurations reported alongside the main loss.
        # Commented-out entries are alternative configurations kept for
        # reference.
        discrete_experiments = dict(
           # dpixel=dict(discrete_pixels=True),
           # dunits20=dict(discrete_units=20),
           # dumix20=dict(discrete_units=20, mixed_units=True),
           # dunits10=dict(discrete_units=10),
           # abonly=dict(ablation_only=True),
           # fimabl=dict(ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           dboth20=dict(discrete_units=20, discrete_pixels=True),
           # dbothm20=dict(discrete_units=20, mixed_units=True,
           #              discrete_pixels=True),
           # abdisc20=dict(discrete_units=20, discrete_pixels=True,
           #             ablation_only=True),
           # abdiscm20=dict(discrete_units=20, mixed_units=True,
           #             discrete_pixels=True,
           #             ablation_only=True),
           # fimadp=dict(discrete_pixels=True,
           #             ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           # fimadu10=dict(discrete_units=10,
           #             ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           # fimadb10=dict(discrete_units=10, discrete_pixels=True,
           #             ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           fimadbm10=dict(discrete_units=10, mixed_units=True,
                       discrete_pixels=True,
                       ablation_only=True,
                       fullimage_ablation=True,
                       fullimage_measurement=True),
           # fimadu20=dict(discrete_units=20,
           #             ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           # fimadb20=dict(discrete_units=20, discrete_pixels=True,
           #             ablation_only=True,
           #             fullimage_ablation=True,
           #             fullimage_measurement=True),
           fimadbm20=dict(discrete_units=20, mixed_units=True,
                       discrete_pixels=True,
                       ablation_only=True,
                       fullimage_ablation=True,
                       fullimage_measurement=True)
           )
        with torch.no_grad():
            total_loss = 0
            discrete_losses = {k: 0 for k in discrete_experiments}
            for [pbatch, ploc, cbatch, cloc] in progress(
                    torch.utils.data.DataLoader(TensorDataset(
                        corpus.eval_present_sample,
                        corpus.eval_present_location,
                        corpus.eval_candidate_sample,
                        corpus.eval_candidate_location),
                    batch_size=args.inference_batch_size, num_workers=10,
                    shuffle=False, pin_memory=True),
                    desc="Eval"):
                # Main objective, using the same flags as training.
                total_loss = total_loss + ace_loss(segmenter, classnum,
                        model, args.layer, high_replacement, ablation,
                        pbatch, ploc, cbatch, cloc, run_backward=False,
                        ablation_only=ablation_only,
                        fullimage_measurement=fullimage_measurement,
                        fullimage_ablation=fullimage_ablation)
                for k, config in discrete_experiments.items():
                    discrete_losses[k] = discrete_losses[k] + ace_loss(
                        segmenter, classnum,
                        model, args.layer, high_replacement, ablation,
                        pbatch, ploc, cbatch, cloc, run_backward=False,
                        **config)
            avg_loss = (total_loss / args.eval_size).item()
            avg_d_losses = {k: (d / args.eval_size).item()
                    for k, d in discrete_losses.items()}
            regularizer = (args.l2_lambda * ablation.pow(2).sum())
            print_progress('Epoch %d Loss %g Regularizer %g' %
                    (epoch, avg_loss, regularizer))
            print_progress(' '.join('%s: %g' % (k, d)
                    for k, d in avg_d_losses.items()))
            print_progress(scale_summary(ablation.view(-1), 10, 3))
            return avg_loss, regularizer, avg_d_losses

    if args.eval_only:
        # For eval_only, just load each snapshot and re-run validation eval
        # pass on each one.  A missing epoch -1 snapshot is evaluated with
        # the initial ablation; other missing epochs are skipped.
        for epoch in range(-1, args.train_epochs):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch)
            if not os.path.exists(snapfile):
                data = {}
                if epoch >= 0:
                    print('No epoch %d' % epoch)
                    continue
            else:
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
            avg_loss, regularizer, new_extra = eval_loss_and_reg()
            # Keep old values, and update any new ones.
            extra = {k: v for k, v in data.items()
                    if k not in ['ablation', 'optimizer', 'avg_loss']}
            extra.update(new_extra)
            torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
                avg_loss=avg_loss, **extra),
                os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        # Return loaded ablation.
        return ablation.view(-1).detach().cpu().numpy()

    # Resume from the latest snapshot, if any.  When none exists the loop
    # leaves start_epoch at 0 (the last value of the reversed range).
    if not args.no_cache:
        for start_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch)
            if os.path.exists(snapfile):
                data = torch.load(snapfile)
                with torch.no_grad():
                    ablation[...] = data['ablation'].to(ablation.device)
                    optimizer.load_state_dict(data['optimizer'])
                start_epoch += 1
                break

    # Log the starting point.  Only a from-scratch run (epoch -1) saves it.
    if start_epoch < args.train_epochs:
        epoch = start_epoch - 1
        avg_loss, regularizer, extra = eval_loss_and_reg()
        if epoch == -1:
            torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
                avg_loss=avg_loss, **extra),
                os.path.join(snapdir, 'epoch-%d.pth' % epoch))

    # Gradients are accumulated over train_update_freq batches per step.
    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(start_epoch, args.train_epochs):
        candidate_shuffle = torch.randperm(len(corpus.candidate_sample))
        train_loss = 0
        for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(progress(
                torch.utils.data.DataLoader(TensorDataset(
                    corpus.object_present_sample,
                    corpus.object_present_location,
                    corpus.candidate_sample[candidate_shuffle],
                    corpus.candidate_location[candidate_shuffle]),
                batch_size=args.train_batch_size, num_workers=10,
                shuffle=True, pin_memory=True),
                desc="ACE opt epoch %d" % epoch)):
            if batch_num % args.train_update_freq == 0:
                optimizer.zero_grad()
            # run_backward=True: ace_loss is asked to run the backward pass
            # itself; the returned value is only accumulated for logging.
            loss = ace_loss(segmenter, classnum,
                    model, args.layer, high_replacement, ablation,
                    pbatch, ploc, cbatch, cloc, run_backward=True,
                    ablation_only=ablation_only,
                    fullimage_measurement=fullimage_measurement,
                    fullimage_ablation=fullimage_ablation)
            with torch.no_grad():
                train_loss = train_loss + loss
            if (batch_num + 1) % args.train_update_freq == 0:
                # Add an L2 penalty on the ablation (scaled by update_size to
                # match the accumulated loss) to encourage sparsity.
                regularizer = (args.l2_lambda * update_size
                        * ablation.pow(2).sum())
                regularizer.backward()
                optimizer.step()
                with torch.no_grad():
                    ablation.clamp_(0, 1)
                    post_progress(l=(train_loss/update_size).item(),
                            r=(regularizer/update_size).item())
                    train_loss = 0

        avg_loss, regularizer, extra = eval_loss_and_reg()
        torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(),
            avg_loss=avg_loss, **extra),
            os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                ablation.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return ablation.view(-1).detach().cpu().numpy()
Example #2
0
def train_ablation(args,
                   corpus,
                   cachefile,
                   model,
                   segmenter,
                   classnum,
                   initial_ablation=None):
    """Learn per-unit ablation weights for ``classnum`` at ``args.layer``.

    The replacement feature is ``corpus.weighted_mean_present_feature``.
    A weight tensor of the same shape is optimized with Adam (lr=0.01)
    against ``ace_loss`` plus an L2 penalty, clamped to [0, 1] after each
    optimizer step.  Model parameters are frozen.  Gradients are accumulated
    over ``args.train_update_freq`` batches before each step.

    After every epoch the weights are evaluated on the eval sample and saved
    to ``<dirname(cachefile)>/snapshots/epoch-N.pth`` (weights, optimizer
    state, ``avg_loss``) and ``epoch-N.npy``.  Unless ``args.no_cache`` is
    set, training resumes after the newest existing ``epoch-N.pth``.

    Returns:
        The final weights as a flat numpy array.
    """
    progress = default_progress()
    snapdir = os.path.join(os.path.dirname(cachefile), 'snapshots')
    os.makedirs(snapdir, exist_ok=True)

    replacement = (corpus.weighted_mean_present_feature[None, :, None,
                                                        None].cuda())
    replacement.requires_grad = False
    # Freeze the model; only the ablation weights are trainable.
    for param in model.parameters():
        param.requires_grad = False

    weights = torch.zeros(replacement.shape).cuda()
    if initial_ablation is not None:
        weights.view(-1)[...] = initial_ablation
    weights.requires_grad = True
    optimizer = torch.optim.Adam([weights], lr=0.01)

    def make_loader(tensors, batch_size, shuffle):
        """Wrap a tuple of aligned tensors in a pinned-memory DataLoader."""
        return torch.utils.data.DataLoader(TensorDataset(*tensors),
                                           batch_size=batch_size,
                                           num_workers=10,
                                           shuffle=shuffle,
                                           pin_memory=True)

    def batch_loss(pbatch, ploc, cbatch, cloc, backward):
        """Call ace_loss on one batch with the current weights."""
        return ace_loss(segmenter, classnum, model, args.layer, replacement,
                        weights, pbatch, ploc, cbatch, cloc,
                        run_backward=backward)

    def evaluate(epoch_index):
        """Log and return the mean eval-set loss for the current weights."""
        with torch.no_grad():
            eval_loader = make_loader(
                (corpus.eval_present_sample, corpus.eval_present_location,
                 corpus.eval_candidate_sample,
                 corpus.eval_candidate_location),
                args.inference_batch_size, False)
            summed = 0
            for pbatch, ploc, cbatch, cloc in progress(eval_loader,
                                                       desc="Eval"):
                summed = summed + batch_loss(pbatch, ploc, cbatch, cloc,
                                             False)
            mean_loss = (summed / args.eval_size).item()
            penalty = (args.l2_lambda * weights.pow(2).sum())
            print_progress('Epoch %d Loss %g, Regularizer %g' %
                           (epoch_index, mean_loss, penalty))
            print_progress(scale_summary(weights.view(-1), 10, 3))
        return mean_loss

    # Find the newest snapshot (if caching is allowed) and resume after it.
    first_epoch = 0
    if not args.no_cache:
        for saved_epoch in reversed(range(args.train_epochs)):
            snapfile = os.path.join(snapdir, 'epoch-%d.pth' % saved_epoch)
            if not os.path.exists(snapfile):
                continue
            state = torch.load(snapfile)
            with torch.no_grad():
                weights[...] = state['ablation'].to(weights.device)
                optimizer.load_state_dict(state['optimizer'])
            first_epoch = saved_epoch + 1
            break

    update_size = args.train_update_freq * args.train_batch_size
    for epoch in range(first_epoch, args.train_epochs):
        perm = torch.randperm(len(corpus.candidate_sample))
        running = 0
        train_loader = make_loader(
            (corpus.object_present_sample, corpus.object_present_location,
             corpus.candidate_sample[perm], corpus.candidate_location[perm]),
            args.train_batch_size, True)
        batches = progress(train_loader, desc="ACE opt epoch %d" % epoch)
        for step, (pbatch, ploc, cbatch, cloc) in enumerate(batches):
            if step % args.train_update_freq == 0:
                optimizer.zero_grad()
            loss = batch_loss(pbatch, ploc, cbatch, cloc, True)
            with torch.no_grad():
                running = running + loss
            if (step + 1) % args.train_update_freq != 0:
                continue
            # End of an accumulation group: add the L2 sparsity penalty,
            # step, and project the weights back into [0, 1].
            penalty = (args.l2_lambda * update_size *
                       weights.pow(2).sum())
            penalty.backward()
            optimizer.step()
            with torch.no_grad():
                weights.clamp_(0, 1)
                post_progress(l=(running / update_size).item(),
                              r=(penalty / update_size).item())
                running = 0

        mean_loss = evaluate(epoch)
        torch.save(
            dict(ablation=weights,
                 optimizer=optimizer.state_dict(),
                 avg_loss=mean_loss),
            os.path.join(snapdir, 'epoch-%d.pth' % epoch))
        numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch),
                   weights.detach().cpu().numpy())

    # The output of this phase is this set of scores.
    return weights.view(-1).detach().cpu().numpy()
Example #3
0
def main():
    """Command-line entry point: measure pixel-level ablation effects.

    Reads ``<outdir>/dissect.json`` for default model settings and unit
    rankings.  For each class in ``--classes`` and each layer in
    ``model.ablation``, it sorts that layer's units by
    ``numpy.argsort(ranking.score)`` (ascending score).  It then calls
    ``measure_ablation`` on the first ``--unitcount`` units and writes
    ``<outdir>/<layer>/pixablation/<class>-<metric>.json``.

    A layer/class is skipped when its result file already has at least
    ``--unitcount`` effects.  It is also skipped (with a warning) when the
    file was produced with a different unit ordering.
    """
    # Training settings
    def strpair(arg):
        """Parse ``name[:reportedname]`` into a ``(name, reportedname)`` pair."""
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(description='Ablation eval',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to edit' + 
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes', type=str, nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Record whether CUDA is in use on args (args is handed to
    # create_instrumented_model below).  cudnn benchmarking is already on.
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Instantiate model
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
            (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, catnames = (
                segmenter.get_label_and_category_names(dataset))
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                batch_size=args.batch_size, num_workers=10,
                pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    progress = default_progress()
    for classname in progress(args.classes):
        post_progress(c=classname)
        for layername in progress(model.ablation):
            post_progress(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                        if r.name == rankname)
            except (KeyError, AttributeError, StopIteration):
                # Layer missing from dissect.json, or no ranking by that name.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            resultfile = os.path.join(ablationdir, '%s.json'%rankname)
            if os.path.isfile(resultfile):
                with open(resultfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong:
                # leave the existing file alone, but say so.
                if not all(a == o
                        for a, o in zip(data.ablation_units, ordering)):
                    print('%s: unit ordering differs from %s; skipping'
                            % (rankname, resultfile))
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue # file already done.
            # measurements[0] is stored as the baseline, the rest as effects.
            measurements = measure_ablation(segmenter, segloader,
                    model, classnum, layername, ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(resultfile, 'w') as f:
                json.dump(dict(
                    classname=classname,
                    classnum=classnum,
                    baseline=measurements[0],
                    layer=layername,
                    metric=args.metric,
                    ablation_units=ordering.tolist(),
                    ablation_effects=measurements[1:]), f)
Example #4
0
def main():
    """Command-line entry point: measure class-count ablation effects.

    Loads the model built by ``--model`` (optionally with ``--pthfile``
    weights) and instruments ``--layers`` (or a default set) with
    ``edit_layers``.  It then computes an unablated baseline with
    ``count_segments``.

    For each class in ``--classes`` and each layer in ``model.ablation``,
    units are ordered by ``numpy.argsort(ranking.score)`` using the ranking
    named ``<class>-<metric>`` in ``<outdir>/dissect.json``.  For every count
    in ``[--startcount, min(--unitcount, #units)]``, that layer's ablation is
    set to a vector that is 1 on the first ``count`` units, and the class's
    ``count_segments`` value is recorded.

    Results go to ``<outdir>/<layer>/ablation/<class>-<metric>.json``.
    Counts already present in an existing file are not recomputed.  A file
    produced with a different unit ordering is left untouched (with a
    warning).
    """
    # Training settings
    def strpair(arg):
        """Parse ``name[:reportedname]`` into a ``(name, reportedname)`` pair."""
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Net dissect utility',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes',
                        type=str,
                        nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric',
                        type=str,
                        default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--startcount',
                        type=int,
                        default=1,
                        help='smallest number of units to ablate')
    parser.add_argument('--unitcount',
                        type=int,
                        default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter',
                        type=str,
                        default='dataset/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size',
                        type=int,
                        default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size',
                        type=int,
                        default=1000,
                        help='number of images to test')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)

    # Set up CUDA (cudnn benchmarking is already enabled above).
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    meta = {}
    if args.pthfile is None:
        print_progress('Dissecting model without pth file.')
    else:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # Keep the scalar metadata stored next to the weights.
            meta = {key: value for key, value in data.items()
                    if isinstance(value, numbers.Number)}
            data = data['state_dict']
        model.load_state_dict(data)

    # Instrument it and prepare it for eval
    if not args.layers:
        # Skip wrappers with only one named module
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))
    edit_layers(model, args.layers)
    model.eval()
    if args.cuda:
        model.cuda()

    # Set up the output directory, verify write access.  (args.outdir
    # defaults to 'dissect', so the None branch is only a fallback.)
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    # Examine first conv in model to determine input feature size.
    first_layer = [
        c for c in model.modules()
        if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                          torch.nn.Linear))
    ][0]
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        sample = standard_z_sample(args.size,
                                   first_layer.in_channels)[:, :, None, None]
        train_sample = standard_z_sample(args.size,
                                         first_layer.in_channels,
                                         seed=2)[:, :, None, None]
    else:
        sample = standard_z_sample(args.size, first_layer.in_features)
        train_sample = standard_z_sample(args.size,
                                         first_layer.in_features,
                                         seed=2)
    dataset = TensorDataset(sample)
    # NOTE(review): train_dataset is not used further in this function.
    train_dataset = TensorDataset(train_sample)
    recovery = autoimport_eval(args.segmenter)

    # Now do the actual work.
    device = next(model.parameters()).device
    labelnames, catnames = (recovery.get_label_and_category_names(dataset))
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=args.batch_size,
                                            num_workers=10,
                                            pin_memory=(device.type == 'cuda'))

    with open(os.path.join(args.outdir, 'dissect.json'), 'r') as f:
        dissect = EasyDict(json.load(f))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissect.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None
    baseline = count_segments(recovery, segloader, model)

    # For each sort-order, do an ablation
    progress = default_progress()
    for classname in progress(args.classes):
        post_progress(c=classname)
        for layername in progress(model.ablation):
            post_progress(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            # Maps str(count) -> measured value; seeded from an existing file.
            measurements = {}
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, AttributeError, StopIteration):
                # Layer missing from dissect.json, or no ranking by that name.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'ablation')
            resultfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(resultfile):
                with open(resultfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong:
                # leave the existing file alone, but report it instead of
                # halting in a debugger.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    print('%s: unit ordering differs from %s; skipping' %
                          (rankname, resultfile))
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                measurements = data.ablation_effects
            for count in progress(range(args.startcount,
                                        min(args.unitcount, len(ordering)) +
                                        1),
                                  desc='units'):
                if str(count) in measurements:
                    continue
                # Ablate the first `count` units of this layer only.
                ablation = numpy.zeros(len(ranking.score), dtype='float32')
                ablation[ordering[:count]] = 1
                for l in model.ablation:
                    model.ablation[l] = ablation if layername == l else None
                m = count_segments(recovery, segloader, model)[classnum].item()
                print_progress(
                    '%s %s %d units (#%d), %g -> %g' %
                    (layername, rankname, count, ordering[count - 1].item(),
                     baseline[classnum].item(), m))
                measurements[str(count)] = m
            os.makedirs(ablationdir, exist_ok=True)
            with open(resultfile, 'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=baseline[classnum].item(),
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements), f)