def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model.
      pthfile: (optional) filename of .pth file for the model.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing.
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: True to use CUDA.
  
    The constructed model will be decorated with the following attributes:
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layer names to tensors, filled after every evaluation.
      ablation: if editing, map of layer names to [0..1] alpha values to fill.
      replacement: if editing, map of layer names to values to fill.

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`
    '''

    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        return None
    if isinstance(args.model, torch.nn.Module):
        model = args.model
    else:
        model = autoimport_eval(args.model)
    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    meta = {}
    if getattr(args, 'pthfile', None) is not None:
        data = torch.load(args.pthfile)
        if 'state_dict' in data:
            meta = {}
            for key in data:
                if isinstance(data[key], numbers.Number):
                    meta[key] = data[key]
            data = data['state_dict']
        model.load_state_dict(data)
    model.meta = meta

    # Decide which layers to instrument.
    if getattr(args, 'layer', None) is not None:
        args.layers = [args.layer]
    if getattr(args, 'layers', None) is None:
        # Skip wrappers with only one named model
        container = model
        prefix = ''
        while len(list(container.named_children())) == 1:
            name, container = next(container.named_children())
            prefix += name + '.'
        # Default to all nontrivial top-level layers except last.
        args.layers = [
            prefix + name for name, module in container.named_children()
            if type(module).__module__ not in [
                # Skip ReLU and other activations.
                'torch.nn.modules.activation',
                # Skip pooling layers.
                'torch.nn.modules.pooling'
            ]
        ][:-1]
        print_progress('Defaulting to layers: %s' % ' '.join(args.layers))

    # Instrument the layers.
    retain_layers(model, args.layers)
    if getattr(args, 'edit', False):
        edit_layers(model, args.layers)
    model.eval()
    if args.cuda:
        model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model
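# A minimal usage sketch, assuming the surrounding module already provides
# EasyDict, print_progress, autoimport_eval, retain_layers, edit_layers and
# annotate_model_shapes. The constructor string, image size and batch below
# are illustrative placeholders, not values taken from the function above.
import argparse
import torch

example_args = argparse.Namespace(
    model='torchvision.models.resnet18()',  # string evaluated by autoimport_eval
    pthfile=None,        # no checkpoint to load
    layers=None,         # let the function pick default layers
    edit=True,           # install ablation/replacement hooks
    gen=False,
    imgsize=(224, 224),
    cuda=False)

instrumented = create_instrumented_model(example_args)
if instrumented is not None:
    batch = torch.zeros((1, 3) + tuple(example_args.imgsize))
    _ = instrumented(batch)
    # After a forward pass, each instrumented layer's featuremap is retained.
    for name, feat in instrumented.retained.items():
        print(name, tuple(feat.shape))
    # With edit=True, filling instrumented.ablation[name] (alpha in [0, 1]) and
    # instrumented.replacement[name] applies the blend described in the docstring.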
Example #2
def create_instrumented_model(args, **kwargs):
    '''
    Creates an instrumented model out of a namespace of arguments that
    correspond to ArgumentParser command-line args:
      model: a string to evaluate as a constructor for the model.
      pthfile: (optional) filename of .pth file for the model.
      layers: a list of layers to instrument, defaulted if not provided.
      edit: True to instrument the layers for editing.
      gen: True for a generator model.  One-pixel input assumed.
      imgsize: For non-generator models, (y, x) dimensions for RGB input.
      cuda: True to use CUDA.
  
    The constructed model will be decorated with the following attributes:
      input_shape: (usually 4d) tensor shape for single-image input.
      output_shape: 4d tensor shape for output.
      feature_shape: map of layer names to 4d tensor shape for featuremaps.
      retained: map of layer names to tensors, filled after every evaluation.
      ablation: if editing, map of layer names to [0..1] alpha values to fill.
      replacement: if editing, map of layer names to values to fill.

    When editing, the feature value x will be replaced by:
        `x = (replacement * ablation) + (x * (1 - ablation))`

    In this variant the model (UnetNormalized loaded from 'p2p_churches.pth')
    and the instrumented layers are hard-coded below, so the model, pthfile,
    layers, edit, and cuda arguments above are ignored.
    '''

    args = EasyDict(vars(args), **kwargs)

    # Construct the network
    if args.model is None:
        print_progress('No model specified')
        return None
    #if isinstance(args.model, torch.nn.Module):
    #    model = args.model
    #else:
    #    model = autoimport_eval(args.model)
    model = UnetNormalized()
    checkpoint = torch.load('p2p_churches.pth')
    model.load_state_dict(checkpoint)
    model.cuda()
    #model.eval()
    torch.set_grad_enabled(False)  # a bare torch.no_grad() call has no effect; disable gradients globally instead

    # Unwrap any DataParallel-wrapped model
    if isinstance(model, torch.nn.DataParallel):
        model = next(model.children())

    # Load its state dict
    #meta = {}
    #if getattr(args, 'pthfile', None) is not None:
    #    data = torch.load(args.pthfile)
    #    if 'state_dict' in data:
    #        meta = {}
    #        for key in data:
    #            if isinstance(data[key], numbers.Number):
    #                meta[key] = data[key]
    #        data = data['state_dict']
    #    submodule = getattr(args, 'submodule', None)
    #    if submodule is not None and len(submodule):
    #        remove_prefix = submodule + '.'
    #        data = { k[len(remove_prefix):]: v for k, v in data.items()
    #                if k.startswith(remove_prefix)}
    #        if not len(data):
    #            print_progress('No submodule %s found in %s' %
    #                    (submodule, args.pthfile))
    #            return None
    #    model.load_state_dict(data, strict=not getattr(args, 'unstrict', False))

    # Decide which layers to instrument.
    #if getattr(args, 'layer', None) is not None:
    #    args.layers = [args.layer]
    args.layers = ['layer5', 'layer9', 'layer12']  # names of layers in UnetNormalized to retain

    #if getattr(args, 'layers', None) is None:
    #    # Skip wrappers with only one named model
    #    container = model
    #    prefix = ''
    #    while len(list(container.named_children())) == 1:
    #        name, container = next(container.named_children())
    #        prefix += name + '.'
    #    # Default to all nontrivial top-level layers except last.
    #    args.layers = [prefix + name
    #            for name, module in container.named_children()
    #            if type(module).__module__ not in [
    #                # Skip ReLU and other activations.
    #                'torch.nn.modules.activation',
    #                # Skip pooling layers.
    #                'torch.nn.modules.pooling']
    #            ][:-1]
    #    print_progress('Defaulting to layers: %s' % ' '.join(args.layers))

    # Now wrap the model for instrumentation.
    model = InstrumentedModel(model)
    #    model.meta = meta

    # Instrument the layers.
    model.retain_layers(args.layers)
    #model.eval()
    #if args.cuda:
    #  model.cuda()

    # Annotate input, output, and feature shapes
    annotate_model_shapes(model,
                          gen=getattr(args, 'gen', False),
                          imgsize=getattr(args, 'imgsize', None))
    return model
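# A minimal sketch of the editing rule quoted in the docstring above: for a
# featuremap x, an alpha mask `ablation` in [0, 1] and a `replacement` tensor,
# the instrumented layer blends the two. Shapes here are illustrative only.
import torch

x = torch.randn(1, 8, 4, 4)          # original featuremap
replacement = torch.zeros_like(x)    # values to write in (zeroing, in this sketch)
ablation = torch.zeros(1, 8, 1, 1)   # per-channel alpha in [0, 1]
ablation[:, :3] = 1.0                # fully replace the first three channels

edited = replacement * ablation + x * (1 - ablation)
assert torch.allclose(edited[:, :3], replacement[:, :3])
assert torch.allclose(edited[:, 3:], x[:, 3:])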