Code example #1
def draw_lang_model_to_file(model, png_fname):
    """Draw a language model graph to a PNG file.

    Caveat: the PNG that is produced has some problems, which we suspect are due to
    PyTorch issues related to RNN ONNX export.
    """
    try:
        # if dataset == 'wikitext2':
        batch_size = 1
        seq_len = 255
        dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size, 1).to(device)
        hidden = model.init_hidden(batch_size)
        dummy_input = (dummy_input, hidden)
        # else:
        #     msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
        #     return
        g = distiller.SummaryGraph(model, dummy_input)
        distiller.draw_model_to_file(g, png_fname)
        msglogger.info("Network PNG image generation completed")

    except FileNotFoundError as e:
        msglogger.info("An error has occured while generating the network PNG image.")
        msglogger.info("Please check that you have graphviz installed.")
        msglogger.info("\t$ sudo apt-get install graphviz")
        raise e
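The same two calls used above (constructing a distiller.SummaryGraph and passing it to distiller.draw_model_to_file) also work for a plain feed-forward model, where the dummy input is just an image-shaped tensor. A minimal sketch, assuming distiller and graphviz are installed; torchvision's resnet18 and the 224x224 input shape are illustrative choices, not part of the example above:

import torch
import torchvision
import distiller

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)        # one RGB image, batch size 1 (illustrative shape)
g = distiller.SummaryGraph(model, dummy_input)   # trace the model into a summary graph
distiller.draw_model_to_file(g, 'resnet18.png')  # render the graph to PNG via graphviz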
Code example #2
    def prepare_model(self, dummy_input=None):
        """
        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
        and overrides configuration provided to __init__(), and according to the replacement_factory as
        defined by the Quantizer sub-class being used.

        Note:
            If multiple sub-modules within the model actually reference the same module, then that module
            is replaced only once, according to the configuration (bit-width and/or overrides) of the
            first encountered reference.
            Toy Example - say a module is constructed using this bit of code:

                shared_relu = nn.ReLU()
                self.relu1 = shared_relu
                self.relu2 = shared_relu

            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
            Let's call it 'new_relu1'. When 'self.relu2' is encountered, it will simply be replaced
            with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
            will be ignored, and a warning message will be shown.
        """
        msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

        self.model.quantizer_metadata["dummy_input"] = dummy_input
        if dummy_input is not None:
            summary_graph = distiller.SummaryGraph(self.model, dummy_input)
            self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

        self._pre_prepare_model(dummy_input)

        self._pre_process_container(self.model)
        for module_name, module in self.model.named_modules():
            qbits = self.module_qbits_map[module_name]
            curr_parameters = dict(module.named_parameters())
            for param_name, param in curr_parameters.items():
                n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
                if n_bits is None:
                    continue
                fp_attr_name = param_name
                if self.train_with_fp_copy:
                    hack_float_backup_parameter(module, param_name, n_bits)
                    fp_attr_name = FP_BKP_PREFIX + param_name
                self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

                param_full_name = '.'.join([module_name, param_name])
                msglogger.info(
                    "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

        # If an optimizer was passed, assume we need to update it
        if self.optimizer:
            optimizer_type = type(self.optimizer)
            new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
            self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

        self._post_prepare_model()

        msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
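The shared-module caveat described in the docstring can be seen directly in PyTorch: when a single module instance is referenced under two attribute names, named_modules() reports it only once, so only the first reference's configuration can be applied. A small illustration in plain PyTorch ('SharedReluNet' is a hypothetical name used only for this sketch):

import torch.nn as nn

class SharedReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        shared_relu = nn.ReLU()     # a single module instance...
        self.relu1 = shared_relu    # ...referenced by two attributes
        self.relu2 = shared_relu

net = SharedReluNet()
# The shared instance appears only once in the module hierarchy (under 'relu1'),
# which is why prepare_model() honors only the first encountered configuration.
print([name for name, _ in net.named_modules()])   # ['', 'relu1']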
Code example #3
File: policy.py  Project: xcffl/distiller
    def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
        msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
        self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
        if self.levels is not None:
            self.pruner.levels = self.levels

        meta['model'] = model
        is_initialized = self.is_initialized

        if self.fold_bn:
            # Cache this information (required for BN-folding) to improve performance
            self.named_modules = OrderedDict(model.named_modules())
            dummy_input = torch.randn(model.input_shape)
            self.sg = distiller.SummaryGraph(model, dummy_input)

        for param_name, param in model.named_parameters():
            if self.fold_bn:
                param = self._fold_batchnorm(model, param_name, param,
                                             self.named_modules, self.sg)
            if not is_initialized:
                # Initialize the maskers
                masker = zeros_mask_dict[param_name]
                masker.use_double_copies = self.use_double_copies
                masker.mask_on_forward_only = self.mask_on_forward_only
                # register for the backward hook of the parameters
                if self.mask_gradients:
                    masker.backward_hook_handle = param.register_hook(
                        masker.mask_gradient)

                self.is_initialized = True
                if not self.skip_first_minibatch:
                    self.pruner.set_param_mask(param, param_name,
                                               zeros_mask_dict, meta)
            else:
                self.pruner.set_param_mask(param, param_name, zeros_mask_dict,
                                           meta)
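In practice on_epoch_begin() is not called directly; the PruningPolicy is registered on a CompressionScheduler, which invokes it at the start of each epoch. A rough wiring sketch; the pruner class, the 'conv1.weight' target, the sparsity level and the epoch range are illustrative, and exact constructor arguments may differ between distiller versions:

import torch
import torchvision
import distiller
from distiller.pruning import SparsityLevelParameterPruner
from distiller.policy import PruningPolicy

model = torchvision.models.resnet18()
pruner = SparsityLevelParameterPruner(name='sparsity_pruner',
                                      levels={'conv1.weight': 0.5})   # prune to 50% sparsity
policy = PruningPolicy(pruner, pruner_args=None)

scheduler = distiller.CompressionScheduler(model)
scheduler.add_policy(policy, starting_epoch=0, ending_epoch=10, frequency=1)

# The training loop calls the scheduler hooks, which forward to the policy above:
scheduler.on_epoch_begin(0)   # epoch 0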
Code example #4
def fuse_modules(model,
                 types_sequence,
                 fuse_fn,
                 dummy_input=None,
                 adjacency_map=None):
    """
    Scans the module for sequences of modules of the specified types and "fuses" them. As an example, consider the
    following sequence of 3 modules: 'm_1' --> 'm_2' --> 'm_3'. Assuming they match the specified sequence of types,
    they will be fused such that the "fused" module replaces 'm_1', and 'm_2' and 'm_3' are replaced with identity
    operations.

    For a sequence of modules to be fused, it must not contain splits. That is - no module in the sequence can
    have more than a single output. For example, consider the following sequence:

    m_1 --> m_2 --> m_3
        |
        |
        --> m_4

    Even if m_1, m_2 and m_3 match the types sequence, they can't be fused because m_1's output also goes to m_4.

    The fused module is generated by the user specified function 'fuse_fn'.

    To infer the order of modules, a forward pass on the model is required. Hence the need to pass a dummy
    input (unless a pre-computed adjacency map is provided).

    Args:
        model (nn.Module): Model instance on which the transformation is performed
        types_sequence (list or tuple): Sequence of module types. Each item in the sequence may itself be a
          tuple of types. For example - to fuse all possible convolution types with ReLU, pass:
          [(nn.Conv1d, nn.Conv2d, nn.Conv3d), nn.ReLU]
        fuse_fn (function): Function that takes a list of modules to be fused, and returns a single fused module.
          If the sequence cannot be fused, this function should return None
        dummy_input (torch.Tensor or tuple): Dummy input to the model. Required if adjacency_map is None
        adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map(). Must be based
          on the passed model, otherwise results are unexpected. If None, then the adjacency map will be created
          internally using the passed dummy_input.
    """
    distiller.assign_layer_fq_names(model)
    if adjacency_map is None:
        if dummy_input is None:
            raise ValueError(
                'Must pass either valid adjacency map instance or valid dummy input'
            )
        summary_graph = distiller.SummaryGraph(model, dummy_input)
        adjacency_map = summary_graph.adjacency_map(
            dedicated_modules_only=False)
    named_modules = OrderedDict(model.named_modules())
    in_sequence_idx = 0
    curr_sequence = []

    for node_name, adj_entry in adjacency_map.items():
        module = named_modules.get(node_name, None)
        if module is None:
            reset = True
        else:
            reset = False
            if isinstance(module, types_sequence[in_sequence_idx]):
                curr_sequence.append(module)
                in_sequence_idx += 1
                if in_sequence_idx == len(types_sequence):
                    _fuse_sequence(curr_sequence, named_modules, fuse_fn)
                    reset = True
                elif len(adj_entry.successors) > 1:
                    msglogger.debug(
                        node_name +
                        " is connected to multiple outputs, not fuse-able")
                    reset = True
            elif isinstance(module, types_sequence[0]):
                # Current module breaks the current sequence, check if it's the start of a new sequence
                in_sequence_idx = 1
                curr_sequence = [module]
            else:
                reset = True
        if reset:
            in_sequence_idx = 0
            curr_sequence = []
    return model
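One way to exercise fuse_modules() is convolution/batch-norm folding: match each Conv2d followed by BatchNorm2d and collapse the pair into the convolution. The sketch below is illustrative only; 'fuse_conv_bn' is a user-supplied callback as required by the fuse_fn contract, and it relies on PyTorch's fuse_conv_bn_eval helper (available in recent PyTorch versions) rather than distiller's own batch-norm folding code:

import torch
import torch.nn as nn
import torchvision
from torch.nn.utils.fusion import fuse_conv_bn_eval

def fuse_conv_bn(sequence):
    # 'sequence' is the matched module list, here [Conv2d, BatchNorm2d].
    # Return the fused module, or None to leave this pair untouched.
    conv, bn = sequence
    return fuse_conv_bn_eval(conv, bn)

model = torchvision.models.resnet18().eval()   # fusion assumes inference mode
dummy_input = torch.randn(1, 3, 224, 224)
model = fuse_modules(model,
                     types_sequence=[nn.Conv2d, nn.BatchNorm2d],
                     fuse_fn=fuse_conv_bn,
                     dummy_input=dummy_input)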
Code example #5
def create_graph(model, input_tensor=None):
    # Use the caller-provided input if given, otherwise a CIFAR-sized dummy batch
    if input_tensor is not None:
        dummy_input = input_tensor
    else:
        dummy_input = torch.randn(16, 3, 32, 32)
    return distiller.SummaryGraph(model, dummy_input)
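Once the SummaryGraph is built, a typical next step (as in examples #2 and #4 above) is to query its connectivity via adjacency_map(). A short usage sketch; the resnet18 model is only a stand-in, and the successor count is read the same way the fuse_modules() example above does:

import torchvision

model = torchvision.models.resnet18()
g = create_graph(model)                                  # SummaryGraph built from a dummy input
adj = g.adjacency_map(dedicated_modules_only=False)      # node name -> adjacency entry
for node_name, adj_entry in adj.items():
    print(node_name, 'has', len(adj_entry.successors), 'successor(s)')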