コード例 #1
0
def make_upfn(args, dataset, model, layername):
    '''Creates an upsampling function for the activations of `layername`.

    The behaviour depends on ``args.model``:
      * 'alexnet': the conv*/pool* children of ``model.model`` are handed to
        the upsampler as ``convolutions``; no forward pass is run and
        ``data_shape`` is left as None.
      * 'progan': one sample is run through ``model`` to read both the
        retained layer's spatial shape and the output image shape; the
        upsampler targets (64, 64).
      * anything else: one sample is run through ``model`` to read the
        retained layer's spatial shape; the upsampler targets (56, 56) with
        ``dataset`` as its source.
    '''
    def probe():
        # A single forward pass on the first sample fills the retained layer.
        output = model(dataset[0][0][None,...].cuda())
        return output, model.retained_layer(layername).shape[2:]

    if args.model == 'progan':
        output, layer_shape = probe()
        return upsample.upsampler(
                (64, 64),
                data_shape=layer_shape,
                image_size=output.shape[2:])

    layer_convs = None
    layer_shape = None
    if args.model == 'alexnet':
        layer_convs = [
            child for child_name, child in model.model.named_children()
            if child_name.startswith(('conv', 'pool'))]
    else:
        _, layer_shape = probe()
        pbar.print('upsampling from data_shape', tuple(layer_shape))
    return upsample.upsampler(
            (56, 56),
            data_shape=layer_shape,
            source=dataset,
            convolutions=layer_convs)
コード例 #2
0
def make_upfn_without_hooks(args, dataset, layername, layers_output):
    '''Creates a (56, 56) upsampling function from precomputed layer outputs.

    No forward pass is run: the spatial size of the layer is read from
    ``layers_output[layername].shape[2:]``. ``args`` is not used.
    '''
    spatial_shape = layers_output[layername].shape[2:]
    pbar.print('upsampling from data_shape', tuple(spatial_shape))
    return upsample.upsampler((56, 56),
                              data_shape=spatial_shape,
                              source=dataset,
                              convolutions=None)
コード例 #3
0
def load_cached_state(cachefile, args):
    '''Loads a cached ``numpy`` archive if it was produced with ``args``.

    Args:
        cachefile: path handed to ``numpy.load``, or None to disable caching.
        args: mapping of setting name -> value. Every entry must be present
            in the archive with an equal value for the cache to be used.

    Returns:
        The open archive on a cache hit. Returns None if ``cachefile`` is
        None, cannot be loaded, or any setting is missing or differs. On a
        miss the archive is closed before returning.
    '''
    if cachefile is None:
        return None
    try:
        dat = numpy.load(cachefile, allow_pickle=True)
    except Exception:
        # Missing, unreadable or corrupt cache file: treat as a cache miss.
        return None
    matched = False
    try:
        for a, v in args.items():
            if a not in dat:
                # Do not index dat[a] here: the key does not exist.
                pbar.print('%s %s missing from cache' % (cachefile, a))
                return None
            # array_equal also handles array-valued settings, where a plain
            # `dat[a] != v` would raise on an ambiguous truth value.
            if not numpy.array_equal(dat[a], v):
                pbar.print('%s %s changed from %s to %s' %
                           (cachefile, a, dat[a], v))
                return None
        matched = True
    except Exception:
        # Best effort: any problem inspecting the cache is a cache miss.
        return None
    finally:
        if not matched:
            # Release the file handle held by an NpzFile on a cache miss.
            close = getattr(dat, 'close', None)
            if close is not None:
                close()
    pbar.print('Loading cached %s' % cachefile)
    return dat
コード例 #4
0
 def printstat(s):
     '''Appends ``str(s)`` plus a newline to ``experiment_dir``/log.txt and echoes ``s`` via pbar.'''
     logpath = os.path.join(experiment_dir, 'log.txt')
     with open(logpath, 'a') as logfile:
         print(str(s), file=logfile)
     pbar.print(s)
コード例 #5
0
def main():
    '''Per-class unit-ablation experiment for one classifier layer.

    Steps (every measurement is cached via ``sharedfile(...)`` paths):
      1. Measure the per-class baseline on the val split, then the effect
         of ablating each unit of ``args.layer`` on its own.
      2. Repeat the baseline and single-unit ablations on the train split.
         The train-split results are used to rank units per class, so the
         ranking is not fit on the val set.
      3. For every class, order units by the class's balanced accuracy when
         each is ablated alone (train split, lowest first). Then ablate:
         - the first ``num_best`` of them, and separately
         - all the remaining units.
         For each, report the change in that class's val balanced accuracy
         relative to the val baseline.
    '''
    args = parseargs()

    model = setting.load_classifier(args.model)
    model = nethook.InstrumentedModel(model).cuda().eval()
    layername = args.layer
    model.retain_layer(layername)
    dataset = setting.load_dataset(args.dataset, crop_size=224)
    train_dataset = setting.load_dataset(args.dataset,
                                         crop_size=224,
                                         split='train')

    # Probe layer to get sizes
    model(dataset[0][0][None].cuda())
    num_units = model.retained_layer(layername).shape[1]
    classlabels = dataset.classes

    # Measure baseline classification accuracy on val set, and cache.
    pbar.descnext('baseline_pra')
    _, _, _, baseline_ba = test_perclass_pra(
        model,
        dataset,
        cachefile=sharedfile('pra-%s-%s/pra_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Now erase each unit, one at a time, and retest accuracy.
    # Shuffled order so that concurrent runs sharing the cache files tend to
    # work on different units.
    unit_list = random.sample(list(range(num_units)), num_units)
    # Not read later in this function; the per-unit cache files are the
    # lasting output of this loop.
    val_single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        # Get binary accuracy of the model after ablating the unit.
        _, _, _, ablation_ba = test_perclass_pra(
            model,
            dataset,
            layername=layername,
            ablated_units=[unit],
            cachefile=sharedfile('pra-%s-%s/pra_ablate_unit_%d.npz' %
                                 (args.model, args.dataset, unit)))
        val_single_unit_ablation_ba[unit] = ablation_ba

    # For the purpose of ranking units by importance to a class, we
    # measure using the training set (to avoid training unit ordering
    # on the test set).
    sample_size = None  # None: evaluate the whole training split.
    # Measure baseline classification accuracy, and cache.
    pbar.descnext('train_baseline_pra')
    _, _, _, baseline_ba = test_perclass_pra(
        model,
        train_dataset,
        sample_size=sample_size,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_train_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('baseline acc', baseline_ba.mean().item())

    # Measure accuracy on the val set.
    pbar.descnext('val_baseline_pra')
    _, _, _, val_baseline_ba = test_perclass_pra(
        model,
        dataset,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_val_baseline.npz' %
                             (args.model, args.dataset)))
    pbar.print('val baseline acc', val_baseline_ba.mean().item())

    # Do in shuffled order to allow multiprocessing.
    single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
    for unit in pbar(unit_list):
        pbar.descnext('test unit %d' % unit)
        _, _, _, ablation_ba = test_perclass_pra(
            model,
            train_dataset,
            layername=layername,
            ablated_units=[unit],
            sample_size=sample_size,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_train_ablate_unit_%d.npz' %
                                 (args.model, args.dataset, unit)))
        single_unit_ablation_ba[unit] = ablation_ba

    # Now for every class, remove a set of the N most-important
    # and the remaining least-important units for that class, and measure
    # accuracy.
    for classnum in pbar(
            random.sample(range(len(classlabels)), len(classlabels))):
        # For a few classes, let's chart the whole range of ablations.
        if classnum in [100, 169, 351, 304]:
            num_best_list = range(1, num_units)
        else:
            num_best_list = [1, 2, 3, 4, 5, 20, 64, 128, 256]
        # Unit indices ordered so that the units whose single ablation
        # lowers this class's train balanced accuracy the most come first.
        unit_order = single_unit_ablation_ba[:, classnum].sort(0)[1]
        pbar.descnext('numbest')
        for num_best in pbar(random.sample(num_best_list, len(num_best_list))):
            unitlist = unit_order[:num_best]
            _, _, _, testba = test_perclass_pra(
                model,
                dataset,
                layername=layername,
                ablated_units=unitlist,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' %
                    (args.model, args.dataset, classlabels[classnum],
                     len(unitlist))))
            # The complement of the list above. Slicing from num_best avoids
            # the old `[-(num_units - num_best):]` form, which became `[-0:]`
            # (i.e. every unit) when num_best == num_units.
            unitlist = unit_order[num_best:]
            _, _, _, testba2 = test_perclass_pra(
                model,
                dataset,
                layername=layername,
                ablated_units=unitlist,
                cachefile=sharedfile(
                    'ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz'
                    % (args.model, args.dataset, classlabels[classnum],
                       len(unitlist))))
            pbar.print('%s: best %d %.3f vs worst N %.3f' %
                       (classlabels[classnum], num_best,
                        testba[classnum] - val_baseline_ba[classnum],
                        testba2[classnum] - val_baseline_ba[classnum]))
コード例 #6
0
def test_perclass_pra(model,
                      dataset,
                      layername=None,
                      ablated_units=None,
                      sample_size=None,
                      cachefile=None,
                      batch_size=100,
                      num_workers=20):
    '''Classifier precision/recall/accuracy measurement.

    Disables a set of units in the specified layer, and then measures
    per-class precision, recall, accuracy and balanced (binary
    classification) accuracy for each class, compared to the ground truth
    in the given dataset.

    Args:
        model: instrumented classifier supporting ``edit_layer`` and
            ``remove_edits``. The argmax of its output over dim 1 is taken
            as the predicted class.
        dataset: yields ``(image, class_index)`` pairs and has a
            ``classes`` attribute listing the class labels.
        layername: layer to edit; only used when ``ablated_units`` is given.
        ablated_units: indices along dim 1 of ``layername``'s output that are
            set to zero, or None to measure the unedited model.
        sample_size: if given, a ``FixedSubsetSampler`` over indices
            ``0 .. sample_size - 1`` restricts evaluation; None uses the
            whole dataset.
        cachefile: optional ``.npz`` path. Results are loaded from it when it
            exists and holds every field; otherwise they are computed and
            saved there.
        batch_size: DataLoader batch size (default 100).
        num_workers: DataLoader worker count (default 20).

    Returns:
        ``(precision, recall, accuracy, balanced_accuracy)``: CPU float
        tensors with one entry per class. An entry whose denominator is zero
        (e.g. a class that is never predicted) comes out as NaN.
    '''
    if cachefile is not None:
        result = None
        try:
            with numpy.load(cachefile) as data:
                # Verify that this field is computed; presumably this rejects
                # cache files saved before true_negative_rate was stored.
                data['true_negative_rate']
                result = tuple(
                    torch.tensor(data[key]) for key in
                    ['precision', 'recall', 'accuracy', 'balanced_accuracy'])
        except Exception:
            # Missing, unreadable or incomplete cache: recompute below.
            result = None
        if result is not None:
            pbar.print('Loading cached %s' % cachefile)
            return result
    model.remove_edits()
    try:
        if ablated_units is not None:

            def ablate_the_units(x, *args):
                x[:, ablated_units] = 0
                return x

            model.edit_layer(layername, rule=ablate_the_units)
        with torch.no_grad():
            num_classes = len(dataset.classes)
            true_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
            pred_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
            correct_counts = torch.zeros(num_classes, dtype=torch.int64).cuda()
            total_count = 0
            sampler = None if sample_size is None else (FixedSubsetSampler(
                list(range(sample_size))))
            loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 num_workers=num_workers,
                                                 sampler=sampler,
                                                 pin_memory=True)
            for image_batch, class_batch in pbar(loader):
                total_count += len(image_batch)
                image_batch, class_batch = [
                    d.cuda() for d in [image_batch, class_batch]
                ]
                scores = model(image_batch)
                preds = scores.max(1)[1]
                correct = (preds == class_batch)
                true_counts.add_(class_batch.bincount(minlength=num_classes))
                pred_counts.add_(preds.bincount(minlength=num_classes))
                # Weighting by `correct` counts only correctly-classified
                # samples of each true class.
                correct_counts.add_(
                    class_batch.bincount(correct, minlength=num_classes).long())
    finally:
        # Never leave the ablation installed, even if evaluation fails.
        model.remove_edits()
    # Samples that are neither of class c nor predicted as c.
    true_neg_counts = ((total_count - true_counts) -
                       (pred_counts - correct_counts))
    precision = (correct_counts.float() / pred_counts.float()).cpu()
    recall = (correct_counts.float() / true_counts.float()).cpu()
    accuracy = (correct_counts + true_neg_counts).float().cpu() / total_count
    true_neg_rate = (true_neg_counts.float() /
                     (total_count - true_counts).float()).cpu()
    balanced_accuracy = (recall + true_neg_rate) / 2
    if cachefile is not None:
        numpy.savez(cachefile,
                    precision=precision.numpy(),
                    recall=recall.numpy(),
                    accuracy=accuracy.numpy(),
                    true_negative_rate=true_neg_rate.numpy(),
                    balanced_accuracy=balanced_accuracy.numpy())
    return precision, recall, accuracy, balanced_accuracy
コード例 #7
0
def main():
    '''Command-line entry point for network dissection.

    Steps:
      1. Parse the options below and prepare the output directory.
      2. Build the model with the requested layers instrumented.
      3. Load the input data: a segmentation or image-only dataset for a
         classifier, or random z inputs for a generator (--gen/--gan).
      4. Run ``dissect`` and finally mark the output directory as done.

    Exits via ``sys.exit`` when run without arguments, when only --download
    is requested, or when the model or dataset cannot be found.
    '''
    # Training settings
    def strpair(arg):
        # 'layer:reported' -> ('layer', 'reported'); 'layer' -> ('layer', 'layer')
        p = tuple(arg.split(':'))
        if len(p) == 1:
            p = p + p
        return p
    def intpair(arg):
        # '224,320' -> (224, 320); '224' -> (224, 224)
        p = arg.split(',')
        if len(p) == 1:
            p = p + p
        return tuple(int(v) for v in p)

    parser = argparse.ArgumentParser(description='Net dissect utility',
            prog='python -m netdissect',
            epilog=textwrap.dedent(help_epilog),
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--model', type=str, default=None,
                        help='constructor for the model to test')
    parser.add_argument('--pthfile', type=str, default=None,
                        help='filename of .pth file for the model')
    parser.add_argument('--unstrict', action='store_true', default=False,
                        help='ignore unexpected pth parameters')
    parser.add_argument('--modelkey', type=str, default=None,
                        help='key within pthfile containing state_dict')
    parser.add_argument('--submodule', type=str, default=None,
                        help='submodule to load from pthfile state dict')
    parser.add_argument('--outdir', type=str, default='dissect',
                        help='directory for dissection output')
    parser.add_argument('--layers', type=strpair, nargs='+',
                        help='space-separated list of layer names to dissect' +
                        ', in the form layername[:reportedname]')
    parser.add_argument('--segments', type=str, default='datasets/broden',
                        help='directory containing segmentation dataset')
    parser.add_argument('--segmenter', type=str, default=None,
                        help='constructor for asegmenter class')
    parser.add_argument('--normalizer', type=str, default=None,
                        help='Normalize rgb with imagenet, zc, or pt ranges')
    parser.add_argument('--download', action='store_true', default=False,
                        help='downloads Broden dataset if needed')
    parser.add_argument('--imagedir', type=str, default=None,
                        help='directory containing image-only dataset')
    parser.add_argument('--imgsize', type=intpair, default=(227, 227),
                        help='input image size to use')
    parser.add_argument('--netname', type=str, default=None,
                        help='name for network in generated reports')
    parser.add_argument('--meta', type=str, nargs='+',
                        help='json files of metadata to add to report')
    parser.add_argument('--merge', type=str,
                        help='json file of unit data to merge in report')
    parser.add_argument('--examples', type=int, default=20,
                        help='number of image examples per unit')
    parser.add_argument('--size', type=int, default=10000,
                        help='dataset subset size to use')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='batch size for forward pass')
    parser.add_argument('--num_workers', type=int, default=24,
                        help='number of DataLoader workers')
    parser.add_argument('--quantile_threshold', type=strfloat, default=None,
                        choices=[FloatRange(0.0, 1.0), 'iqr'],
                        help='quantile to use for masks')
    parser.add_argument('--whiten', default=None,
                        help='set to pca to whiten units')
    parser.add_argument('--no-labels', action='store_true', default=False,
                        help='disables labeling of units')
    parser.add_argument('--maxiou', action='store_true', default=False,
                        help='enables maxiou calculation')
    parser.add_argument('--covariance', action='store_true', default=False,
                        help='enables covariance calculation')
    parser.add_argument('--rank_all_labels', action='store_true', default=False,
                        help='include low-information labels in rankings')
    parser.add_argument('--no-images', action='store_true', default=False,
                        help='disables generation of unit images')
    parser.add_argument('--no-report', action='store_true', default=False,
                        help='disables generation report summary')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA usage')
    parser.add_argument('--gen', action='store_true', default=False,
                        help='test a generator model (e.g., a GAN)')
    parser.add_argument('--gan', action='store_true', default=False,
                        help='synonym for --gen')
    parser.add_argument('--perturbation', default=None,
                        help='filename of perturbation attack to apply')
    parser.add_argument('--add_scale_offset', action='store_true', default=None,
                        help='offsets masks according to stride and padding')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='silences console output')
    if len(sys.argv) == 1:
        parser.print_usage(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    args.images = not args.no_images
    args.report = not args.no_report
    args.labels = not args.no_labels
    if args.gan:
        args.gen = args.gan

    # Set up console output
    pbar.verbose(not args.quiet)

    # Exit right away if job is already done or being done.
    if args.outdir is not None:
        exit_if_job_done(args.outdir)

    # Speed up pytorch
    torch.backends.cudnn.benchmark = True

    # Special case: download flag without model to test.
    if args.model is None and args.download:
        from netdissect.broden import ensure_broden_downloaded
        for resolution in [224, 227, 384]:
            ensure_broden_downloaded(args.segments, resolution, 1)
        from netdissect.segmenter import ensure_upp_segmenter_downloaded
        ensure_upp_segmenter_downloaded('datasets/segmodel')
        sys.exit(0)

    # Help if broden is not present
    if not args.gen and not args.imagedir and not os.path.isdir(args.segments):
        pbar.print('Segmentation dataset not found at %s.' % args.segments)
        pbar.print('Specify dataset directory using --segments [DIR]')
        pbar.print('To download Broden, run: netdissect --download')
        sys.exit(1)

    # Default segmenter class
    if args.gen and args.segmenter is None:
        args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" +
                "segsizes=[256], segdiv='quad')")

    # Default threshold
    if args.quantile_threshold is None:
        if args.gen:
            args.quantile_threshold = 'iqr'
        else:
            args.quantile_threshold = 0.005

    # Set up CUDA
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Construct the network with specified layers instrumented
    if args.model is None:
        pbar.print('No model specified')
        sys.exit(1)
    model = create_instrumented_model(args)

    # Update any metadata from files, if any
    meta = getattr(model, 'meta', {})
    if args.meta:
        for mfilename in args.meta:
            with open(mfilename) as f:
                meta.update(json.load(f))

    # Load any merge data from files
    mergedata = None
    if args.merge:
        with open(args.merge) as f:
            mergedata = json.load(f)

    # Set up the output directory, verify write access
    if args.outdir is None:
        args.outdir = os.path.join('dissect', type(model).__name__)
        exit_if_job_done(args.outdir)
        pbar.print('Writing output into %s.' % args.outdir)
    os.makedirs(args.outdir, exist_ok=True)
    train_dataset = None

    if not args.gen:
        # Load dataset for classifier case.
        # Load perturbation
        perturbation = numpy.load(args.perturbation
                ) if args.perturbation else None
        segrunner = None

        # Load broden dataset
        if args.imagedir is not None:
            dataset = try_to_load_images(args.imagedir, args.imgsize,
                    perturbation, args.size)
            segrunner = ImageOnlySegRunner(dataset)
        else:
            dataset = try_to_load_broden(args.segments, args.imgsize, 1,
                perturbation, args.download, args.size,
                normalizer_named(args.normalizer))
        if dataset is None:
            dataset = try_to_load_multiseg(args.segments, args.imgsize,
                    perturbation, args.size)
        if dataset is None:
            # Format the path into the message; passing it as a second
            # argument left the '%s' placeholder unfilled.
            pbar.print('No segmentation dataset found in %s' %
                    args.segments)
            pbar.print('use --download to download Broden.')
            sys.exit(1)
    else:
        # For segmenter case the dataset is just a random z
        dataset = z_dataset_for_model(model, args.size)
        train_dataset = z_dataset_for_model(model, args.size, seed=2)
        segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter))

    # Run dissect
    dissect(args.outdir, model, dataset,
            train_dataset=train_dataset,
            segrunner=segrunner,
            examples_per_unit=args.examples,
            netname=args.netname,
            quantile_threshold=args.quantile_threshold,
            meta=meta,
            merge=mergedata,
            pca_units=(args.whiten == 'pca'),
            make_images=args.images,
            make_labels=args.labels,
            make_maxiou=args.maxiou,
            make_covariance=args.covariance,
            make_report=args.report,
            make_row_images=args.images,
            make_single_images=True,
            rank_all_labels=args.rank_all_labels,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            settings=vars(args))

    # Mark the directory so that it's not done again.
    mark_job_done(args.outdir)
コード例 #8
0
 def eval_loss_and_reg():
     '''Evaluates the current ``ablation`` on the held-out eval corpus.

     Relies on names from the enclosing scope, which are not visible here:
     ``corpus``, ``args``, ``segmenter``, ``classnum``, ``model``,
     ``high_replacement``, ``ablation``, ``ablation_only``,
     ``fullimage_measurement``, ``epoch``, ``ace_loss``, ``scale_summary``.

     For every eval batch it accumulates ``ace_loss`` (run_backward=False):
       * under the current ablation_only/fullimage_measurement settings, and
       * under each configuration in ``discrete_experiments``.
     Each sum is divided by ``args.eval_size`` (assumed to be the number of
     eval samples; TODO confirm).

     Returns:
         ``(avg_loss, regularizer, avg_d_losses)``:
           * avg_loss: the average loss as a Python number.
           * regularizer: ``args.l2_lambda * ablation.pow(2).sum()``,
             returned as computed rather than converted.
           * avg_d_losses: dict of average losses keyed by experiment name.
     '''
     # Disabled alternatives are kept as comments for quick re-enabling.
     discrete_experiments = {
         # 'dpixel': dict(discrete_pixels=True),
         # 'dunits20': dict(discrete_units=20),
         # 'dumix20': dict(discrete_units=20, mixed_units=True),
         # 'dunits10': dict(discrete_units=10),
         # 'abonly': dict(ablation_only=True),
         # 'fimabl': dict(ablation_only=True,
         #                fullimage_ablation=True,
         #                fullimage_measurement=True),
         'dboth20': dict(discrete_units=20, discrete_pixels=True),
         # 'dbothm20': dict(discrete_units=20, mixed_units=True,
         #                  discrete_pixels=True),
         # 'abdisc20': dict(discrete_units=20, discrete_pixels=True,
         #                  ablation_only=True),
         # 'abdiscm20': dict(discrete_units=20, mixed_units=True,
         #                   discrete_pixels=True,
         #                   ablation_only=True),
         # 'fimadp': dict(discrete_pixels=True,
         #                ablation_only=True,
         #                fullimage_ablation=True,
         #                fullimage_measurement=True),
         # 'fimadu10': dict(discrete_units=10,
         #                  ablation_only=True,
         #                  fullimage_ablation=True,
         #                  fullimage_measurement=True),
         # 'fimadb10': dict(discrete_units=10, discrete_pixels=True,
         #                  ablation_only=True,
         #                  fullimage_ablation=True,
         #                  fullimage_measurement=True),
         'fimadbm10': dict(discrete_units=10,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
         # 'fimadu20': dict(discrete_units=20,
         #                  ablation_only=True,
         #                  fullimage_ablation=True,
         #                  fullimage_measurement=True),
         # 'fimadb20': dict(discrete_units=20, discrete_pixels=True,
         #                  ablation_only=True,
         #                  fullimage_ablation=True,
         #                  fullimage_measurement=True),
         'fimadbm20': dict(discrete_units=20,
                           mixed_units=True,
                           discrete_pixels=True,
                           ablation_only=True,
                           fullimage_ablation=True,
                           fullimage_measurement=True),
     }

     def batch_loss(pbatch, ploc, cbatch, cloc, **options):
         # Shared ace_loss call: evaluation only, never backpropagates.
         return ace_loss(segmenter, classnum, model, args.layer,
                         high_replacement, ablation,
                         pbatch, ploc, cbatch, cloc,
                         run_backward=False, **options)

     with torch.no_grad():
         total_loss = 0
         discrete_losses = dict.fromkeys(discrete_experiments, 0)
         eval_loader = torch.utils.data.DataLoader(
             TensorDataset(corpus.eval_present_sample,
                           corpus.eval_present_location,
                           corpus.eval_candidate_sample,
                           corpus.eval_candidate_location),
             batch_size=args.inference_batch_size,
             num_workers=10,
             shuffle=False,
             pin_memory=True)
         for pbatch, ploc, cbatch, cloc in pbar(eval_loader, desc="Eval"):
             # First, put in zeros for the selected units.
             # Loss is amount of remaining object.
             total_loss = total_loss + batch_loss(
                 pbatch, ploc, cbatch, cloc,
                 ablation_only=ablation_only,
                 fullimage_measurement=fullimage_measurement)
             for name, config in discrete_experiments.items():
                 discrete_losses[name] = discrete_losses[name] + batch_loss(
                     pbatch, ploc, cbatch, cloc, **config)
         avg_loss = (total_loss / args.eval_size).item()
         avg_d_losses = {}
         for name, summed in discrete_losses.items():
             avg_d_losses[name] = (summed / args.eval_size).item()
         regularizer = args.l2_lambda * ablation.pow(2).sum()
         pbar.print('Epoch %d Loss %g Regularizer %g' %
                    (epoch, avg_loss, regularizer))
         pbar.print(' '.join('%s: %g' % (name, value)
                             for name, value in avg_d_losses.items()))
         pbar.print(scale_summary(ablation.view(-1), 10, 3))
         return avg_loss, regularizer, avg_d_losses
コード例 #9
0
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model, or an
        already-constructed torch.nn.Module.
      pthfile: (optional) filename of .pth file for the model.
      modelkey: (optional) key within pthfile holding the state dict;
        None or absent means 'state_dict'. If the key is not present the
        whole file is treated as the state dict.
      submodule: (optional) only load state-dict entries under this prefix.
      unstrict: (optional) True to load the state dict non-strictly.
      layers: a list of layers to instrument, defaulted if not provided.
        A single '?' (or ('?', '?')) prints all module names and exits.
      layer: (optional) a single layer name; overrides layers.
      edit: True to instrument the layers for editing.
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: True to use CUDA (treated as False if absent).
    Any keyword arguments override the corresponding fields of args.

    Returns None if no model is specified, or if the requested submodule
    has no entries in the pthfile.

    The constructed model will be decorated with the following attributes:
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.
      meta: numeric top-level entries of the pthfile checkpoint, if any.

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`
    '''

    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    if args.model is None:
        pbar.print('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    meta = {}
    if getattr(args, 'pthfile', None) is not None:
        # Load onto the CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies into the model's own
        # parameters. NOTE: torch.load unpickles - only use trusted files.
        data = torch.load(args.pthfile, map_location='cpu')
        # argparse stores an explicit None for an omitted --modelkey, so
        # fall back to 'state_dict' for None as well as a missing attribute.
        modelkey = getattr(args, 'modelkey', None) or 'state_dict'
        if modelkey in data:
            # Keep the checkpoint's top-level numeric entries as metadata.
            for key in data:
                if isinstance(data[key], numbers.Number):
                    meta[key] = data[key]
            data = data[modelkey]
        submodule = getattr(args, 'submodule', None)
        if submodule is not None and len(submodule):
            remove_prefix = submodule + '.'
            data = {
                k[len(remove_prefix):]: v
                for k, v in data.items() if k.startswith(remove_prefix)
            }
            if not len(data):
                pbar.print('No submodule %s found in %s' %
                           (submodule, args.pthfile))
                return None
        model.load_state_dict(data,
                              strict=not getattr(args, 'unstrict', False))

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    # If the layer '?' is the only specified, just print out all layers.
    # It may arrive as a plain string (from `layer`) or as a
    # ('?', '?') pair (from the --layers strpair parser).
    if getattr(args, 'layers', None) is not None:
        if len(args.layers) == 1 and args.layers[0] in ('?', ('?', '?')):
            for name, layer in model.named_modules():
                pbar.print(name)
            import sys
            sys.exit(0)
    if getattr(args, 'layers', None) is None:
        # Skip wrappers with only one named model
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        pbar.print('Defaulting to layers: %s' % ' '.join(args.layers))

    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    model.meta = meta

    # Instrument the layers.
    model.retain_layers(args.layers)
    model.eval()
    if getattr(args, 'cuda', False):
        model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model