Пример #1
0
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model, or an
        already-constructed torch.nn.Module.
      pthfile: (optional) filename of .pth file for the model.
      layer: (optional) a single layer name; if given, it replaces `layers`.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing.
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: True to use CUDA (treated as False when absent).

    Keyword arguments override attributes of the same name in `args`.
    Returns None (after printing a message) when no model is specified.

    The constructed model will be decorated with the following attributes:
      meta: map of the numeric entries stored next to 'state_dict' in the
        .pth file (empty if there is none).
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`
    '''

    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    meta = {}
    if getattr(args, 'pthfile', None) is not None:
        # Load to CPU so checkpoints saved from a GPU can be read on CPU-only
        # machines; load_state_dict copies the values into the model's own
        # parameters, and the model is moved to CUDA below when requested.
        data = torch.load(args.pthfile, map_location='cpu')
        if 'state_dict' in data:
            # Keep scalar side entries of the checkpoint as metadata.
            meta = {key: value for key, value in data.items()
                    if isinstance(value, numbers.Number)}
            data = data['state_dict']
        model.load_state_dict(data)
    model.meta = meta

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    if getattr(args, 'layers', None) is None:
        # Skip wrappers with only one named child module.
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to the top-level children that are not activation or
        # pooling modules, dropping the last of those.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))

    # Instrument the layers.
    retain_layers(model, args.layers)
    if getattr(args, 'edit', False):
        edit_layers(model, args.layers)
    model.eval()
    # Like the other optional settings above, a missing `cuda` means "off".
    if getattr(args, 'cuda', False):
        model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model
Пример #2
0
def main():
    '''
    Command-line entry point: render sample images from a GAN.

    Draws z vectors with standard_z_sample. With --maximize_units, it draws a
    --test_size universe (default 20 * --size) and keeps the --size samples
    selected by get_highest_znums. With --ablate_units, it sets those units'
    ablation to 1 at --layer. Images are written by save_znum_images into
    --outdir, followed by copy_lightbox_to(--outdir).
    '''
    parser = argparse.ArgumentParser(description='GAN sample making utility')
    parser.add_argument('--model', type=str, default=None,
            help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
            help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='images',
            help='directory for image output')
    parser.add_argument('--size', type=int, default=100,
            help='number of images to output')
    parser.add_argument('--test_size', type=int, default=None,
            help='number of images to test')
    parser.add_argument('--layer', type=str, default=None,
            help='layer to inspect')
    parser.add_argument('--seed', type=int, default=1,
            help='seed')
    parser.add_argument('--maximize_units', type=int, nargs='+', default=None,
            help='units to maximize')
    parser.add_argument('--ablate_units', type=int, nargs='+', default=None,
            help='units to ablate')
    parser.add_argument('--quiet', action='store_true', default=False,
            help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Validate up front. Otherwise autoimport_eval would receive None, and
    # the layer instrumentation below would receive [None] as its layer list.
    if args.model is None:
        parser.error('--model is required')
    if args.layer is None and (args.maximize_units is not None
                               or args.ablate_units is not None):
        parser.error('--layer is required with --maximize_units or '
                     '--ablate_units')
    verbose_progress(not args.quiet)

    # Instantiate the model
    model = autoimport_eval(args.model)
    if args.pthfile is not None:
        data = torch.load(args.pthfile)
        # Checkpoints may store the weights under 'state_dict' alongside
        # other entries; only the weights are needed here.
        if 'state_dict' in data:
            data = data['state_dict']
        model.load_state_dict(data)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())
    # The first conv/linear module (in registration order) determines the
    # size of the z input.
    first_layer = next(
        (c for c in model.modules()
         if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                           torch.nn.Linear))),
        None)
    if first_layer is None:
        sys.exit('Model has no Conv2d, ConvTranspose2d or Linear layer; '
                 'cannot infer the z input size')
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        z_channels = first_layer.in_channels
        spatialdims = (1, 1)
    else:
        z_channels = first_layer.in_features
        spatialdims = ()
    # Retained features are only needed to rank samples by unit activation.
    if args.maximize_units is not None:
        retain_layers(model, [args.layer])
    model.cuda()

    # Get the sample of z vectors
    if args.maximize_units is None:
        indexes = torch.arange(args.size)
        z_sample = standard_z_sample(args.size, z_channels, seed=args.seed)
        z_sample = z_sample.view(tuple(z_sample.shape) + spatialdims)
    else:
        # By default, if maximizing units, get a 'top 5%' sample.
        if args.test_size is None:
            args.test_size = args.size * 20
        z_universe = standard_z_sample(args.test_size, z_channels,
                seed=args.seed)
        z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims)
        indexes = get_highest_znums(model, z_universe, args.maximize_units,
                args.size, seed=args.seed)
        z_sample = z_universe[indexes]

    if args.ablate_units:
        edit_layers(model, [args.layer])
        dims = max(2, max(args.ablate_units) + 1)  # >=2 to avoid broadcast
        model.ablation[args.layer] = torch.zeros(dims)
        model.ablation[args.layer][args.ablate_units] = 1

    save_znum_images(args.outdir, model, z_sample, indexes,
            args.layer, args.ablate_units)
    copy_lightbox_to(args.outdir)
Пример #3
0
def run_command(args):
    '''
    Compute and score an ablation for one segmentation class at one layer
    of a dissected generator.

    Reads defaults (model, pthfile, segmenter) from <outdir>/dissect.json.
    Intermediate results are cached under
    <outdir>/<safe_dir_name(layer)>/<variant>/<classname>/.
    Exits with status 0 if pidfile_taken reports the lock as taken, or if
    the final snapshot and report.json already exist (unless --no_cache).
    Exits with status 1 if no model is available or the class is unknown
    to the segmenter. For the 'ace' variant, the scores are also added to
    the dissection.
    '''
    verbose_progress(True)
    progress = default_progress()
    classname = args.classname  # e.g. 'door'
    layer = args.layer  # e.g. 'layer4'

    assert os.path.isfile(os.path.join(args.outdir, 'dissect.json')), (
            "Should be a dissection directory")

    if args.variant is None:
        args.variant = 'ace'

    # A non-default regularization strength gets its own variant directory.
    if args.l2_lambda != 0.005:
        args.variant = '%s_reg%g' % (args.variant, args.l2_lambda)

    cachedir = os.path.join(args.outdir, safe_dir_name(layer), args.variant,
            classname)

    if pidfile_taken(os.path.join(cachedir, 'lock.pid'), True):
        sys.exit(0)

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    settings = dissection.settings
    if args.model is None:
        args.model = settings.model
    if args.pthfile is None:
        args.pthfile = settings.pthfile
    # Dissections made without a --segmenter option may not record one.
    # Fall through to the default below instead of failing on the missing
    # attribute.
    if args.segmenter is None:
        args.segmenter = getattr(settings, 'segmenter', None)
    # Default segmenter class
    if args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                "segsizes=[256], segdiv='quad')")

    if (not args.no_cache and
        os.path.isfile(os.path.join(cachedir, 'snapshots', 'epoch-%d.npy' % (
            args.train_epochs - 1))) and
        os.path.isfile(os.path.join(cachedir, 'report.json'))):
        print('%s already done' % cachedir)
        sys.exit(0)

    os.makedirs(cachedir, exist_ok=True)

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True,
            layers=[args.layer])
    if model is None:
        print('No model specified')
        sys.exit(1)
    # Instantiate segmenter
    segmenter = autoimport_eval(args.segmenter)
    labelnames, _ = segmenter.get_label_and_category_names()
    matching = [i for i, (n, c) in enumerate(labelnames) if n == classname]
    if not matching:
        print('Class %s not found in segmenter labels' % classname)
        sys.exit(1)
    classnum = matching[0]
    with open(os.path.join(cachedir, 'labelnames.json'), 'w') as f:
        json.dump(labelnames, f, indent=1)

    # Sample sets for training.
    full_sample = netdissect.zdataset.z_sample_for_model(model,
            args.search_size, seed=10)
    second_sample = netdissect.zdataset.z_sample_for_model(model,
            args.search_size, seed=11)
    # Load any cached data. The cache is best-effort: a missing, partial or
    # unreadable file just means everything is recomputed. The npz file is
    # closed after its arrays have been read.
    cache_filename = os.path.join(cachedir, 'corpus.npz')
    corpus = EasyDict()
    if not args.no_cache:
        try:
            with numpy.load(cache_filename) as cached:
                corpus = EasyDict({k: torch.from_numpy(v)
                    for k, v in cached.items()})
        except Exception:
            pass

    # The steps for the computation.
    compute_present_locations(args, corpus, cache_filename,
            model, segmenter, classnum, full_sample)
    compute_mean_present_features(args, corpus, cache_filename, model)
    compute_feature_quantiles(args, corpus, cache_filename, model, full_sample)
    compute_candidate_locations(args, corpus, cache_filename, model, segmenter,
            classnum, second_sample)
    init_ablation = initial_ablation(args, args.outdir)
    scores = train_ablation(args, corpus, cache_filename,
            model, segmenter, classnum, init_ablation)
    summarize_scores(args, corpus, cachedir, layer, classname,
            args.variant, scores)
    if args.variant == 'ace':
        add_ace_ranking_to_dissection(args.outdir, layer, classname, scores)
Пример #4
0
def main():
    '''
    Command-line entry point for `python -m netdissect`.

    Parses arguments and loads the model, its optional .pth weights, and any
    --meta json files. It then instruments the requested (or default)
    layers. Next it builds the input dataset: a Broden or multiseg
    segmentation dataset for ordinary models, or z samples from
    standard_z_sample when --gan is given. Finally it runs `dissect`,
    writing results into --outdir.
    '''
    # Training settings
    def strpair(arg):
        # 'name' -> ('name', 'name'); 'name:reported' -> ('name', 'reported')
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    def intpair(arg):
        # '227' -> (227, 227); '224,256' -> (224, 256)
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(
        description='Net dissect utility',
        prog='python -m netdissect',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to dissect' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--segments',
                        type=str,
                        default='dataset/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--download',
                        action='store_true',
                        default=False,
                        help='downloads Broden dataset if needed')
    parser.add_argument('--imgsize',
                        type=intpair,
                        default=(227, 227),
                        help='input image size to use')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--meta',
                        type=str,
                        nargs='+',
                        help='json files of metadata to add to report')
    parser.add_argument('--examples',
                        type=int,
                        default=20,
                        help='number of image examples per unit')
    parser.add_argument('--size',
                        type=int,
                        default=10000,
                        help='dataset subset size to use')
    parser.add_argument('--broden_version',
                        type=int,
                        default=1,
                        help='broden version number')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='batch size for forward pass')
    parser.add_argument('--num_workers',
                        type=int,
                        default=24,
                        help='number of DataLoader workers')
    parser.add_argument('--quantile_threshold',
                        type=float,
                        default=None,
                        help='quantile to use for masks')
    parser.add_argument('--no-labels',
                        action='store_true',
                        default=False,
                        help='disables labeling of units')
    parser.add_argument('--ablation',
                        action='store_true',
                        default=False,
                        help='enables single unit ablation of units')
    parser.add_argument('--iqr',
                        action='store_true',
                        default=False,
                        help='enables iqr calculation')
    parser.add_argument('--maxiou',
                        action='store_true',
                        default=False,
                        help='enables maxiou calculation')
    parser.add_argument('--covariance',
                        action='store_true',
                        default=False,
                        help='enables covariance calculation')
    parser.add_argument('--no-images',
                        action='store_true',
                        default=False,
                        help='disables generation of unit images')
    parser.add_argument('--single-images',
                        action='store_true',
                        default=False,
                        help='generates single images also')
    parser.add_argument('--no-report',
                        action='store_true',
                        default=False,
                        help='disables generation report summary')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--gan',
                        type=str,
                        default=None,
                        help='netdissect.GanImageSegmenter() to probe a GAN')
    parser.add_argument('--perturbation',
                        default=None,
                        help='filename of perturbation attack to apply')
    # default=None (not False) so that "not given" can be told apart below.
    parser.add_argument('--add_scale_offset',
                        action='store_true',
                        default=None,
                        help='offsets masks according to stride and padding')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution,
                                     args.broden_version)
        sys.exit(0)

    # Help if broden is not present
    if not os.path.isdir(args.segments):
        print_progress('Segmentation dataset not found at %s.' % args.segments)
        print_progress('Specify dataset directory using --segments [DIR]')
        print_progress('To download Broden, run: netdissect --download')
        sys.exit(1)

    # Default threshold
    if args.quantile_threshold is None:
        if args.gan:
            args.quantile_threshold = 0.01
        else:
            args.quantile_threshold = 0.005

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Default add_scale_offset only for AlexNet-looking models.
    if args.add_scale_offset is None and not args.gan:
        args.add_scale_offset = ('Alex' in model.__class__.__name__)

    # Load its state dict
    meta = {}
    if args.pthfile is None:
        print_progress('Dissecting model without pth file.')
    else:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            # Keep scalar side entries of the checkpoint as report metadata.
            meta = {key: value for key, value in data.items()
                    if isinstance(value, numbers.Number)}
            data = data['state_dict']
        model.load_state_dict(data)

    # Update any metadata from files, if any
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))

    # Instrument it and prepare it for eval
    if not args.layers:
        # Skip wrappers with only one named child module.
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to the top-level children that are not activation or
        # pooling modules, dropping the last of those.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))
    retain_layers(model, args.layers, args.add_scale_offset)
    if args.gan:
        ablate_layers(model, args.layers)
    model.eval()
    if args.cuda:
        model.cuda()

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    if not args.gan:
        # Load dataset for ordinary case.
        # Load perturbation
        perturbation = numpy.load(
            args.perturbation) if args.perturbation else None

        # Load broden dataset
        dataset = try_to_load_broden(args.segments, args.imgsize,
                                     args.broden_version, perturbation,
                                     args.download, args.size)
        if dataset is None:
            # Otherwise try the multiseg loader on the same directory. The
            # result must land in `dataset` for the check below to see it.
            dataset = try_to_load_multiseg(args.segments, args.imgsize,
                                           perturbation, args.size)
        if dataset is None:
            print_progress('No segmentation dataset found in %s' %
                           args.segments)
            print_progress('use --download to download Broden.')
            sys.exit(1)

        recovery = ReverseNormalize(IMAGE_MEAN, IMAGE_STDEV)
    else:
        # Examine first conv in model to determine input feature size.
        first_layer = [
            c for c in model.modules()
            if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                              torch.nn.Linear))
        ][0]
        # 4d input if convolutional, 2d input if first layer is linear.
        if isinstance(first_layer,
                      (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            sample = standard_z_sample(args.size,
                                       first_layer.in_channels)[:, :, None,
                                                                None]
            train_sample = standard_z_sample(args.size,
                                             first_layer.in_channels,
                                             seed=2)[:, :, None, None]
        else:
            sample = standard_z_sample(args.size, first_layer.in_features)
            train_sample = standard_z_sample(args.size,
                                             first_layer.in_features,
                                             seed=2)
        dataset = TensorDataset(sample)
        train_dataset = TensorDataset(train_sample)
        recovery = autoimport_eval(args.gan)

    # Run dissect
    dissect(args.outdir,
            model,
            dataset,
            train_dataset=train_dataset,
            recover_image=recovery,
            examples_per_unit=args.examples,
            netname=args.netname,
            quantile_threshold=args.quantile_threshold,
            meta=meta,
            make_images=args.images or args.single_images,
            make_labels=args.labels,
            make_ablation=args.ablation,
            make_iqr=args.iqr,
            make_maxiou=args.maxiou,
            make_covariance=args.covariance,
            make_report=args.report,
            make_row_images=args.images,
            make_single_images=args.single_images,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            settings=vars(args))
Пример #5
0
import PIL
import torch
from torchvision import transforms

from netdissect.autoeval import autoimport_eval
from netdissect.modelconfig import annotate_model_shapes
from netdissect.nethook import InstrumentedModel
from netdissect.segviz import segment_visualization

batch_size = 1

data = get_segments_dataset('dataset/Adissect_toy')
# NOTE(review): get_segments_dataset, from_pth_file and GeneratorSegRunner are
# not imported in this snippet -- confirm their source modules. from_pth_file
# presumably comes from p2pgan; the same model can be built with
# autoimport_eval("p2pgan.from_pth_file('models/pix2pix/p2p_churches.pth')").
model = from_pth_file('models/pix2pix/p2p_churches.pth')
segmenter = (
    "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')"
)
segrunner = GeneratorSegRunner(autoimport_eval(segmenter))

# (module path inside the generator, short alias) pairs to retain.
layer5 = ('model.model.1.model.3.model.3.model.3.model.1', 'layer5')
layer9 = (
    'model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3',
    'layer9')
layer12 = ('model.model.1.model.3.model.3.model.3.model.5', 'layer12')

model = InstrumentedModel(model)
model.retain_layers([layer5, layer9, layer12])
annotate_model_shapes(model, gen=True, imgsize=None)

# Pinned host memory only speeds up copies to a CUDA device; without CUDA,
# requesting it is pointless, so it is enabled only when CUDA is available.
segloader = torch.utils.data.DataLoader(data,
                                        batch_size=batch_size,
                                        pin_memory=torch.cuda.is_available())
Пример #6
0
def main():
    '''
    Command-line entry point for ablation evaluation of a dissected generator.

    Loads defaults (model, pthfile, segmenter) from <outdir>/dissect.json.
    For each class in --classes and each edited layer, it takes the ranking
    named '<class>-<metric>' from the dissection and orders units by
    numpy.argsort of its scores. It then calls measure_ablation on the first
    --unitcount units and writes the result to
    <outdir>/<layer>/pixablation/<class>-<metric>.json. The first
    measurement is stored as the baseline and the rest as ablation_effects.
    '''
    # Training settings
    def strpair(arg):
        # 'name' -> ('name', 'name'); 'name:reported' -> ('name', 'reported')
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Ablation eval',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        required=True,
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    # The loop below iterates over the classes, so they must be given.
    parser.add_argument('--classes',
                        type=str,
                        nargs='+',
                        required=True,
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric',
                        type=str,
                        default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount',
                        type=int,
                        default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter',
                        type=str,
                        help='constructor for segmenter class')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size',
                        type=int,
                        default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size',
                        type=int,
                        default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    pbar.verbose(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    # Dissections made without a --segmenter option may not record one.
    if args.segmenter is None:
        args.segmenter = getattr(dissection.settings, 'segmenter', None)

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    # Instantiate model
    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1],
                                   seed=2).view((args.size, ) +
                                                input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    # Now do the actual work.
    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=args.batch_size,
                                            num_workers=10,
                                            pin_memory=(device.type == 'cuda'))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None

    # For each sort-order, do an ablation
    for classname in pbar(args.classes):
        pbar.post(c=classname)
        if classname not in labelnum_from_name:
            print('%s is not a segmenter class' % classname)
            sys.exit(1)
        classnum = labelnum_from_name[classname]
        for layername in pbar(model.ablation):
            pbar.post(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, StopIteration, AttributeError):
                # Layer missing from the dissection, or no such ranking.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'pixablation')
            if os.path.isfile(os.path.join(ablationdir, '%s.json' % rankname)):
                with open(os.path.join(ablationdir,
                                       '%s.json' % rankname)) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, something is wrong.
                # NOTE(review): such files are skipped, not recomputed.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                # Otherwise the file is incomplete and is recomputed below.
            measurements = measure_ablation(segmenter, segloader, model,
                                            classnum, layername,
                                            ordering[:args.unitcount])
            measurements = measurements.cpu().numpy().tolist()
            os.makedirs(ablationdir, exist_ok=True)
            with open(os.path.join(ablationdir, '%s.json' % rankname),
                      'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=measurements[0],
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements[1:]), f)
Пример #7
0
def main():
    """Command-line entry point for network dissection (``python -m netdissect``).

    Parses the command line, builds the instrumented model with
    ``create_instrumented_model``, loads the evaluation data and runs
    ``dissect``, writing results into ``--outdir``. The data is:

    * classifiers (no ``--gen``): an image-only directory (``--imagedir``),
      or else a Broden dataset, falling back to a multiseg dataset, found in
      ``--segments``;
    * generators (``--gen``/``--gan``): the fixed image folder
      ``dataset/Adissect``, segmented by ``--segmenter``.

    Calls ``sys.exit(1)`` when run without arguments or when no model or
    dataset can be found, and ``sys.exit(0)`` after a ``--download``-only run.
    """
    # Argument converters.
    def strpair(arg):
        # 'a:b' -> ('a', 'b'); a bare 'a' -> ('a', 'a').
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    def intpair(arg):
        # '224,256' -> (224, 256); a bare '227' -> (227, 227).
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(
        description='Net dissect utility',
        prog='python -m netdissect',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--unstrict',
                        action='store_true',
                        default=False,
                        help='ignore unexpected pth parameters')
    parser.add_argument('--submodule',
                        type=str,
                        default=None,
                        help='submodule to load from pthfile')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to dissect' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--segments',
                        type=str,
                        default='dataset/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--segmenter',
                        type=str,
                        default=None,
                        help='constructor for a segmenter class')
    parser.add_argument('--download',
                        action='store_true',
                        default=False,
                        help='downloads Broden dataset if needed')
    parser.add_argument('--imagedir',
                        type=str,
                        default=None,
                        help='directory containing image-only dataset')
    parser.add_argument('--imgsize',
                        type=intpair,
                        default=(227, 227),
                        help='input image size to use')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--meta',
                        type=str,
                        nargs='+',
                        help='json files of metadata to add to report')
    parser.add_argument('--merge',
                        type=str,
                        help='json file of unit data to merge in report')
    parser.add_argument('--examples',
                        type=int,
                        default=20,
                        help='number of image examples per unit')
    parser.add_argument('--size',
                        type=int,
                        default=10000,
                        help='dataset subset size to use')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='batch size for forward pass')
    parser.add_argument('--num_workers',
                        type=int,
                        default=24,
                        help='number of DataLoader workers')
    parser.add_argument('--quantile_threshold',
                        type=strfloat,
                        default=None,
                        choices=[FloatRange(0.0, 1.0), 'iqr'],
                        help='quantile to use for masks')
    parser.add_argument('--no-labels',
                        action='store_true',
                        default=False,
                        help='disables labeling of units')
    parser.add_argument('--maxiou',
                        action='store_true',
                        default=False,
                        help='enables maxiou calculation')
    parser.add_argument('--covariance',
                        action='store_true',
                        default=False,
                        help='enables covariance calculation')
    parser.add_argument('--rank_all_labels',
                        action='store_true',
                        default=False,
                        help='include low-information labels in rankings')
    parser.add_argument('--no-images',
                        action='store_true',
                        default=False,
                        help='disables generation of unit images')
    parser.add_argument('--no-report',
                        action='store_true',
                        default=False,
                        help='disables generation of report summary')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--gen',
                        action='store_true',
                        default=False,
                        help='test a generator model (e.g., a GAN)')
    parser.add_argument('--gan',
                        action='store_true',
                        default=False,
                        help='synonym for --gen')
    parser.add_argument('--perturbation',
                        default=None,
                        help='filename of perturbation attack to apply')
    parser.add_argument('--add_scale_offset',
                        action='store_true',
                        default=None,
                        help='offsets masks according to stride and padding')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    # Turn the negative --no-* flags into positive switches for dissect().
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels
    if args.gan:
        args.gen = args.gan

    # Set up console output
    verbose_progress(not args.quiet)

    # Exit right away if job is already done or being done.
    if args.outdir is not None:
        exit_if_job_done(args.outdir)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution, 1)
        from netdissect.segmenter import ensure_upp_segmenter_downloaded
        ensure_upp_segmenter_downloaded('dataset/segmodel')
        sys.exit(0)

    # Help if broden is not present
    if not args.gen and not args.imagedir and not os.path.isdir(args.segments):
        print_progress('Segmentation dataset not found at %s.' % args.segments)
        print_progress('Specify dataset directory using --segments [DIR]')
        print_progress('To download Broden, run: netdissect --download')
        sys.exit(1)

    # Default segmenter class
    if args.gen and args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                          "segsizes=[256], segdiv=None)")

    # Default threshold
    if args.quantile_threshold is None:
        args.quantile_threshold = 'iqr' if args.gen else 0.005

    # Set up CUDA (cudnn.benchmark was already enabled above).
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Construct the network with specified layers instrumented
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)
    model = create_instrumented_model(args)

    # Update any metadata from files, if any
    meta = getattr(model, 'meta', {})
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))

    # Load any merge data from files
    mergedata = None
    if args.merge:
        with open(args.merge) as f:
            mergedata = json.load(f)

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        exit_if_job_done(args.outdir)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    if not args.gen:
        # Classifier case: load the (optionally perturbed) evaluation images.
        perturbation = numpy.load(
            args.perturbation) if args.perturbation else None
        segrunner = None

        if args.imagedir is not None:
            dataset = try_to_load_images(args.imagedir, args.imgsize,
                                         perturbation, args.size)
            datadir = args.imagedir
        else:
            # Try Broden first, then a multiseg dataset in the same directory.
            dataset = try_to_load_broden(args.segments, args.imgsize, 1,
                                         perturbation, args.download,
                                         args.size)
            if dataset is None:
                dataset = try_to_load_multiseg(args.segments, args.imgsize,
                                               perturbation, args.size)
            datadir = args.segments
        if dataset is None:
            print_progress('No segmentation dataset found in %s' % datadir)
            print_progress('use --download to download Broden.')
            sys.exit(1)
        if args.imagedir is not None:
            # Build the runner only once the image dataset is known to exist.
            segrunner = ImageOnlySegRunner(dataset)
    else:
        # Generator case: evaluate on a fixed image folder, with pixels
        # normalized from [0, 1] to [-1, 1]; the same data is also passed
        # as train_dataset.
        dataset = datasets.ImageFolder('dataset/Adissect',
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize(
                                               (0.5, 0.5, 0.5),
                                               (0.5, 0.5, 0.5))
                                       ]))
        train_dataset = dataset
        segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter))

    # Run dissect with autograd disabled. A bare torch.no_grad() call has no
    # effect; it only works as a context manager.
    with torch.no_grad():
        dissect(args.outdir,
                model,
                dataset,
                train_dataset=train_dataset,
                segrunner=segrunner,
                examples_per_unit=args.examples,
                netname=args.netname,
                quantile_threshold=args.quantile_threshold,
                meta=meta,
                merge=mergedata,
                make_images=args.images,
                make_labels=args.labels,
                make_maxiou=args.maxiou,
                make_covariance=args.covariance,
                make_report=args.report,
                make_row_images=args.images,
                make_single_images=True,
                rank_all_labels=args.rank_all_labels,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                settings=vars(args))

    # Mark the directory so that it's not done again.
    mark_job_done(args.outdir)
Пример #8
0
def main():
    """Evaluate full ablation of one layer's units for a single class.

    Defaults for ``--model``, ``--pthfile``, ``--segmenter`` and ``--layer``
    are read from ``<outdir>/dissect.json``. Units are ordered by the
    ``<class>-<metric>`` ranking (ties broken by the ``<class>-iou`` ranking),
    and the first ``--unitcount`` of them, with per-unit alpha values (all 1
    unless ``--mixed_units``), are passed to ``measure_full_ablation`` over
    ``--size`` fixed-seed z samples. Results are written to
    ``<outdir>/<layer>/fullablation/<class>-<metric>.json``.
    """
    # Training settings
    def strpair(arg):
        # 'a:b' -> ('a', 'b'); a bare 'a' -> ('a', 'a').
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(description='Ablation eval',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir', type=str, default='dissect', required=True,
                        help='directory for dissection output')
    parser.add_argument('--layer', type=strpair,
                        help='layer name to edit' +
                        ', in the form layername[:reportedname]')
    # Every code path below needs the class name, so require it up front
    # instead of failing later with a TypeError/KeyError.
    parser.add_argument('--classname', type=str, required=True,
                        help='class name to ablate')
    parser.add_argument('--metric', type=str, default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--unitcount', type=int, default=30,
                        help='number of units to ablate')
    parser.add_argument('--segmenter', type=str,
                        help='constructor for the segmenter to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size', type=int, default=25,
                        help='batch size for forward pass')
    parser.add_argument('--mixed_units', action='store_true', default=False,
                        help='true to keep alpha for non-zeroed units')
    parser.add_argument('--size', type=int, default=200,
                        help='number of images to test')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Take defaults for model constructor etc from dissect.json settings.
    with open(os.path.join(args.outdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))
    if args.model is None:
        args.model = dissection.settings.model
    if args.pthfile is None:
        args.pthfile = dissection.settings.pthfile
    if args.segmenter is None:
        args.segmenter = dissection.settings.segmenter
    if args.layer is None:
        args.layer = dissection.settings.layers[0]
    args.layers = [args.layer]

    # Also load specific analysis; rankings are looked up by reported name.
    layername = args.layer[1]
    if args.metric == 'iou':
        summary = dissection
    else:
        with open(os.path.join(args.outdir, layername, args.metric,
                args.classname, 'summary.json')) as f:
            summary = EasyDict(json.load(f))

    # Instantiate generator
    model = create_instrumented_model(args, gen=True, edit=True)
    if model is None:
        print('No model specified')
        sys.exit(1)

    device = next(model.parameters()).device
    input_shape = model.input_shape

    # 4d input if convolutional, 2d input if first layer is linear.
    raw_sample = standard_z_sample(args.size, input_shape[1], seed=3).view(
            (args.size,) + input_shape[1:])
    dataset = TensorDataset(raw_sample)

    # Create the segmenter
    segmenter = autoimport_eval(args.segmenter)

    labelnames, _ = segmenter.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                batch_size=args.batch_size, num_workers=10,
                pin_memory=(device.type == 'cuda'))

    # Clear any ablation so the model starts out unedited.
    for l in model.ablation:
        model.ablation[l] = None

    classname = args.classname
    classnum = labelnum_from_name[classname]

    def find_ranking(layer_records, name):
        # Return ranking `name` of layer `layername`; exit with status 1 and
        # a message if either the layer or the ranking is missing.
        layer_rec = {lrec.layer: lrec for lrec in layer_records}.get(layername)
        ranking = None if layer_rec is None else next(
            (r for r in layer_rec.rankings if r.name == name), None)
        if ranking is None:
            print('%s not found' % name)
            sys.exit(1)
        return ranking

    # IoU ranking from dissect.json, trained ranking from summary.json.
    iou_ranking = find_ranking(dissection.layers, '%s-%s' % (classname, 'iou'))
    rankname = '%s-%s' % (classname, args.metric)
    ranking = find_ranking(summary.layers, rankname)

    # Get ordering, first by ranking, then break ties by iou.
    ordering = [i for _, _, i in sorted(
        (s1, s2, i)
        for i, (s1, s2) in enumerate(zip(ranking.score, iou_ranking.score)))]
    values = (-numpy.array(ranking.score))[ordering]
    if not args.mixed_units:
        values[...] = 1

    ablationdir = os.path.join(args.outdir, layername, 'fullablation')
    measurements = measure_full_ablation(segmenter, segloader,
            model, classnum, layername,
            ordering[:args.unitcount], values[:args.unitcount])
    measurements = measurements.cpu().numpy().tolist()
    os.makedirs(ablationdir, exist_ok=True)
    with open(os.path.join(ablationdir, '%s.json' % rankname), 'w') as f:
        json.dump(dict(
            classname=classname,
            classnum=classnum,
            baseline=measurements[0],
            layer=layername,
            metric=args.metric,
            ablation_units=ordering,
            ablation_values=values.tolist(),
            ablation_effects=measurements[1:]), f)
Пример #9
0
def main():
    """Render GAN samples and their segmentations into ``--outdir``.

    For each of ``--size`` latent samples (drawn with ``--seed``) this writes
    ``<i>_img.jpg`` (the generated image), ``<i>_seg.mat`` (the segmentation
    under key ``seg``) and ``<i>_seg.png`` (its visualization). It also writes
    ``labels.txt``, listing each segmenter label with its category, and a copy
    of ``lightbox.html`` named ``+lightbox.html``. The model is always moved
    to CUDA.
    """
    parser = argparse.ArgumentParser(
        description='GAN output segmentation util')
    parser.add_argument('--model',
                        type=str,
                        default='netdissect.proggan.from_pth_file("' +
                        'models/karras/churchoutdoor_lsun.pth")',
                        help='constructor for the model to test')
    parser.add_argument('--outdir',
                        type=str,
                        default='images',
                        help='directory for image output')
    parser.add_argument('--size',
                        type=int,
                        default=100,
                        help='number of images to output')
    parser.add_argument('--seed', type=int, default=1, help='seed')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    args = parser.parse_args()
    verbose_progress(not args.quiet)

    # Instantiate the model
    model = autoimport_eval(args.model)

    # Make the standard z; pass --seed through so the option actually
    # controls which samples are drawn.
    z_dataset = z_dataset_for_model(model, size=args.size, seed=args.seed)

    # Make the segmenter
    segmenter = UnifiedParsingSegmenter()

    # Create the output directory first: labels.txt is written before any
    # image, and open() fails if the directory does not exist yet.
    dirname = args.outdir
    os.makedirs(dirname, exist_ok=True)

    # Write out text labels, one "<label> <category>" per line.
    labels, _ = segmenter.get_label_and_category_names()
    with open(os.path.join(dirname, 'labels.txt'), 'w') as f:
        for label, cat in labels:
            f.write('%s %s\n' % (label, cat))

    # Move models to cuda
    model.cuda()

    batch_size = 10
    progress = default_progress()

    with torch.no_grad():
        z_loader = torch.utils.data.DataLoader(z_dataset,
                                               batch_size=batch_size,
                                               num_workers=2,
                                               pin_memory=True)
        for batch_num, [z] in enumerate(progress(z_loader,
                                                 desc='Saving images')):
            z = z.cuda()
            start_index = batch_num * batch_size
            tensor_im = model(z)
            # NCHW -> NHWC uint8; assumes generator output lies in [-1, 1].
            byte_im = ((tensor_im + 1) / 2 * 255).clamp(0, 255).byte().permute(
                0, 2, 3, 1).cpu()
            seg = segmenter.segment_batch(tensor_im)
            for i in range(len(tensor_im)):
                index = start_index + i
                seg_array = seg[i].cpu().numpy()
                Image.fromarray(byte_im[i].numpy()).save(
                    os.path.join(dirname, '%d_img.jpg' % index),
                    optimize=True,
                    quality=100)
                savemat(os.path.join(dirname, '%d_seg.mat' % index),
                        dict(seg=seg_array))
                Image.fromarray(
                    segment_visualization(seg_array, tensor_im.shape[2:])
                ).save(os.path.join(dirname, '%d_seg.png' % index))
    srcdir = os.path.realpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))
    shutil.copy(os.path.join(srcdir, 'lightbox.html'),
                os.path.join(dirname, '+lightbox.html'))
Пример #10
0
def main():
    """Measure how ablating top-ranked units changes a class's segment count.

    Builds the generator from ``--model``/``--pthfile``, instruments
    ``--layers`` for editing and draws ``--size`` standard z samples. For
    every class in ``--classes`` and every instrumented layer, the layer's
    units are sorted by ascending score of the ``<class>-<metric>`` ranking
    in ``<outdir>/dissect.json``. For each count from ``--startcount`` to
    ``--unitcount``, the first ``count`` units get ablation alpha 1 and the
    class's ``count_segments`` value is recorded. Results are saved as
    ``<outdir>/<layer>/ablation/<class>-<metric>.json``. A file that already
    holds ``--unitcount`` effects is skipped, and a partial one is resumed.
    """
    # Training settings
    def strpair(arg):
        # 'a:b' -> ('a', 'b'); a bare 'a' -> ('a', 'a').
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p

    parser = argparse.ArgumentParser(
        description='Net dissect utility',
        epilog=textwrap.dedent(help_epilog),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile',
                        type=str,
                        default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--outdir',
                        type=str,
                        default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers',
                        type=strpair,
                        nargs='+',
                        help='space-separated list of layer names to edit' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--classes',
                        type=str,
                        nargs='+',
                        help='space-separated list of class names to ablate')
    parser.add_argument('--metric',
                        type=str,
                        default='iou',
                        help='ordering metric for selecting units')
    parser.add_argument('--startcount',
                        type=int,
                        default=1,
                        help='smallest number of units to ablate')
    parser.add_argument('--unitcount',
                        type=int,
                        default=30,
                        help='largest number of units to ablate')
    # NOTE(review): this default looks like a dataset path, not a segmenter
    # constructor expression for autoimport_eval - confirm.
    parser.add_argument('--segmenter',
                        type=str,
                        default='dataset/broden',
                        help='constructor for the segmenter to use')
    parser.add_argument('--netname',
                        type=str,
                        default=None,
                        help='name for network in generated reports')
    parser.add_argument('--batch_size',
                        type=int,
                        default=5,
                        help='batch size for forward pass')
    parser.add_argument('--size',
                        type=int,
                        default=1000,
                        help='number of images to test')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA usage')
    parser.add_argument('--quiet',
                        action='store_true',
                        default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    # Set up console output
    verbose_progress(not args.quiet)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        sys.exit(1)

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict; checkpoints may wrap it under 'state_dict'.
    if args.pthfile is None:
        print_progress('Dissecting model without pth file.')
    else:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            data = data['state_dict']
        model.load_state_dict(data)

    # Instrument it and prepare it for eval
    if not args.layers:
        # Skip wrappers with only one named module
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))
    edit_layers(model, args.layers)
    model.eval()
    if args.cuda:
        model.cuda()

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        print_progress('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)

    # Examine first conv in model to determine input feature size.
    first_layer = [
        c for c in model.modules()
        if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
                          torch.nn.Linear))
    ][0]
    # 4d input if convolutional, 2d input if first layer is linear.
    if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        sample = standard_z_sample(args.size,
                                   first_layer.in_channels)[:, :, None, None]
    else:
        sample = standard_z_sample(args.size, first_layer.in_features)
    dataset = TensorDataset(sample)
    recovery = autoimport_eval(args.segmenter)

    # Now do the actual work.
    device = next(model.parameters()).device
    labelnames, _ = recovery.get_label_and_category_names(dataset)
    labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}

    segloader = torch.utils.data.DataLoader(dataset,
                                            batch_size=args.batch_size,
                                            num_workers=10,
                                            pin_memory=(device.type == 'cuda'))

    with open(os.path.join(args.outdir, 'dissect.json'), 'r') as f:
        dissection = EasyDict(json.load(f))

    # Index the dissection layers by layer name.
    dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}

    # First, collect a baseline
    for l in model.ablation:
        model.ablation[l] = None
    baseline = count_segments(recovery, segloader, model)

    # For each sort-order, do an ablation
    progress = default_progress()
    for classname in progress(args.classes):
        post_progress(c=classname)
        for layername in progress(model.ablation):
            post_progress(l=layername)
            rankname = '%s-%s' % (classname, args.metric)
            measurements = {}
            classnum = labelnum_from_name[classname]
            try:
                ranking = next(r for r in dissect_layer[layername].rankings
                               if r.name == rankname)
            except (KeyError, StopIteration):
                # Layer missing from dissect.json, or no such ranking.
                print('%s not found' % rankname)
                sys.exit(1)
            ordering = numpy.argsort(ranking.score)
            # Check if already done
            ablationdir = os.path.join(args.outdir, layername, 'ablation')
            ablationfile = os.path.join(ablationdir, '%s.json' % rankname)
            if os.path.isfile(ablationfile):
                with open(ablationfile) as f:
                    data = EasyDict(json.load(f))
                # If the unit ordering is not the same, the saved file came
                # from a different ranking: leave it alone and move on.
                if not all(a == o
                           for a, o in zip(data.ablation_units, ordering)):
                    print_progress('%s: unit ordering differs from %s; '
                                   'skipping' % (layername, ablationfile))
                    continue
                if len(data.ablation_effects) >= args.unitcount:
                    continue  # file already done.
                # Resume: counts already present (as string keys) are skipped.
                measurements = data.ablation_effects
            for count in progress(range(args.startcount,
                                        min(args.unitcount, len(ordering)) +
                                        1),
                                  desc='units'):
                if str(count) in measurements:
                    continue
                ablation = numpy.zeros(len(ranking.score), dtype='float32')
                ablation[ordering[:count]] = 1
                # Ablate only the current layer.
                for l in model.ablation:
                    model.ablation[l] = ablation if layername == l else None
                m = count_segments(recovery, segloader, model)[classnum].item()
                print_progress(
                    '%s %s %d units (#%d), %g -> %g' %
                    (layername, rankname, count, ordering[count - 1].item(),
                     baseline[classnum].item(), m))
                measurements[str(count)] = m
            os.makedirs(ablationdir, exist_ok=True)
            with open(ablationfile, 'w') as f:
                json.dump(
                    dict(classname=classname,
                         classnum=classnum,
                         baseline=baseline[classnum].item(),
                         layer=layername,
                         metric=args.metric,
                         ablation_units=ordering.tolist(),
                         ablation_effects=measurements), f)