示例#1
0
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args.  Any extra keyword
    arguments are handed to EasyDict together with vars(args).

    Settings read from the namespace:
      model: a torch.nn.Module instance, or a string to evaluate as a
          constructor for the model.
      pthfile: (optional) filename of .pth file for the model.
      submodule: (optional) only state_dict keys starting with
          "<submodule>." are loaded, with that prefix removed.
      unstrict: (optional) True to load the state dict with strict=False.
      layer: (optional) a single layer to instrument; overrides layers.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing.
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: (optional) True to use CUDA.

    The constructed model will be decorated with the following attributes:
      meta: numeric top-level entries of the checkpoint (empty if none).
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layernames to tensors, filled after every evaluation.
      ablation: if editing, map of layernames to [0..1] alpha values to fill.
      replacement: if editing, map of layernames to values to fill.

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`

    Returns the instrumented model, or None when no model is specified or
    when the requested submodule prefix matches no key in the checkpoint.
    '''

    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    if getattr(args, 'model', None) is None:
        print_progress('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        # NOTE(review): the model string is evaluated as code; only pass
        # trusted values here.
        model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    meta = {}
    pthfile = getattr(args, 'pthfile', None)
    if pthfile is not None:
        # Load onto the CPU so checkpoints saved from a GPU also work on
        # CPU-only hosts; the model is moved to CUDA below when requested.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        data = torch.load(pthfile, map_location='cpu')
        if 'state_dict' in data:
            # Keep the numeric top-level entries as metadata.
            meta = {
                key: value for key, value in data.items()
                if isinstance(value, numbers.Number)
            }
            data = data['state_dict']
        submodule = getattr(args, 'submodule', None)
        if submodule:
            # Keep only the keys under the submodule, stripping its prefix.
            remove_prefix = submodule + '.'
            data = {
                k[len(remove_prefix):]: v
                for k, v in data.items() if k.startswith(remove_prefix)
            }
            if not data:
                print_progress('No submodule %s found in %s' %
                               (submodule, pthfile))
                return None
        model.load_state_dict(data,
                              strict=not getattr(args, 'unstrict', False))

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    if getattr(args, 'layers', None) is None:
        # Skip wrappers with only one named model
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        skipped_modules = (
            # Skip ReLU and other activations.
            'torch.nn.modules.activation',
            # Skip pooling layers.
            'torch.nn.modules.pooling',
        )
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in skipped_modules
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))

    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    model.meta = meta

    # Instrument the layers.
    model.retain_layers(args.layers)
    model.eval()
    # cuda is optional like the other settings; absent means CPU.
    if getattr(args, 'cuda', False):
        model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model
示例#2
0
# Demo: run a pix2pix generator over a toy dataset, segment its output, and
# save both the input image and the segmentation visualization per batch.
data = get_segments_dataset('dataset/Adissect_toy')
# Alternative way to build the model, kept for reference:
#model = autoimport_eval("p2pgan.from_pth_file('models/pix2pix/p2p_churches.pth')")
model = from_pth_file('models/pix2pix/p2p_churches.pth')

# The segmenter is described by a constructor string and built on demand.
segmenter = "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')"
segrunner = GeneratorSegRunner(autoimport_eval(segmenter))

# Layers to retain, each given as a pair of strings.
# NOTE(review): presumably (module path, short alias) -- confirm against
# InstrumentedModel.retain_layers.
layer5 = (
    'model.model.1.model.3.model.3.model.3.model.1',
    'layer5',
)
layer9 = (
    'model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3',
    'layer9',
)
layer12 = (
    'model.model.1.model.3.model.3.model.3.model.5',
    'layer12',
)

model = InstrumentedModel(model)
model.retain_layers([layer5, layer9, layer12])
annotate_model_shapes(model, gen=True, imgsize=None)

# NOTE(review): batch_size is not defined in this snippet; it must already
# exist in the enclosing scope.
segloader = torch.utils.data.DataLoader(
    data, batch_size=batch_size, pin_memory=True)

to_img = transforms.ToPILImage()
# NOTE(review): the 'test/' output directory is assumed to exist already.
for i, batch in enumerate(segloader):
    # Save the first element of the batch as the input image.
    input_image = to_img(batch[0].squeeze())
    input_image.save('test/input%d.png' % i, 'PNG')

    # Only the segmentation result is used below; the other returned
    # values are unpacked but ignored.
    outseg, bc, rgb, shape = segrunner.run_and_segment_batch(
        batch, model, want_bincount=False, want_rgb=True)

    seg_array = outseg.squeeze().cpu().numpy()
    seg_image = PIL.Image.fromarray(
        segment_visualization(seg_array, size=256))
    seg_image.save('test/outseg%d.png' % i, 'PNG')