Example 1
    def replace_compressed_modules(self):
        """
        Replace all the modules that have changed (weights/inputs/output) shape.
        The new module is created using the same arguments of the to-be-replaced module,
        and correctly inherits its weights.

        NOTE: ```func``` type cannot be replaced as it is not a module, thus, one limitation
        is that ```func``` should be not required to be replaced.
        """
        for module_name in self.inferred_masks:
            g_node = self.torch_graph.name_to_node[module_name]
            _logger.debug("replace %s, in %s type, with op_type %s",
                          module_name, g_node.type, g_node.op_type)
            if g_node.type == 'module':
                super_module, leaf_module = get_module_by_name(
                    self.bound_model, g_node.name)
                m_type = g_node.op_type
                if m_type not in replace_module:
                    raise RuntimeError(
                        "Replacing modules of type `{}` is not supported".format(
                            m_type))
                _logger.info("replace module (name: %s, op_type: %s)",
                             g_node.name, m_type)
                compressed_module = replace_module[m_type](
                    leaf_module, self.inferred_masks[module_name])
                setattr(super_module,
                        g_node.name.split('.')[-1], compressed_module)
            elif g_node.type == 'func':
                _logger.info(
                    "Warning: cannot replace (name: %s, op_type: %s) which is func type",
                    module_name, g_node.op_type)
            else:
                raise RuntimeError("Unsupported node type: {}".format(
                    g_node.type))
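
The replacement above boils down to resolving a submodule by its dotted name and swapping the leaf with setattr. Below is a minimal, self-contained sketch of that pattern on a toy model; the helper parent_and_leaf only mimics what get_module_by_name does and is not part of the NNI API.

import torch.nn as nn

def parent_and_leaf(model, dotted_name):
    # Hypothetical stand-in for get_module_by_name: walk the dotted path and
    # return (parent module, leaf module).
    parent = model
    *path, leaf_name = dotted_name.split('.')
    for part in path:
        parent = getattr(parent, part)
    return parent, getattr(parent, leaf_name)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
parent, leaf = parent_and_leaf(model, '2')   # the second Linear layer
compressed = nn.Linear(8, 4)                 # e.g. fewer input features after upstream pruning
setattr(parent, '2', compressed)             # same in-place swap as setattr(super_module, ...)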
Example 2
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.
        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
            The dimension of the re-index operation.
        reindex: Reindex
            The index tensor. Normally this variable is None. If we want to reindex the
            output of this submodule, we can pass the index by this parameter.
        """
        class ReindexModule(nn.Module):
            """
            ReindexModule is used to resolve the mask conflict when replace the submodule.
            Basically, we can use two ways to resolve the mask conflict: (1) unmask some
            values(will introduce more computation overhead) (2) reindex and padd the output
            tensor of the target op(introduce more memory access overhad). Currently this
            method is shutdown, in the future, we will merge these two methods into a graph
            pass which is used to resolve the mask conflict.
            """
            def __init__(self, ori_module, reindex_dim, reindex):
                super(ReindexModule, self).__init__()
                self.ori_module = ori_module
                self.reindex_dim = reindex_dim
                self.reindex = reindex
                tmp_index = [slice(None, None) for i in range(reindex_dim+1)]
                # the index for the tensor
                tmp_index[reindex_dim] = reindex
                self.t_index = tuple(tmp_index)

            def forward(self, x):
                tmpout = self.ori_module(x)
                shape = list(tmpout.size())
                shape[self.reindex_dim] = self.reindex.size(0)
                out = torch.zeros(tuple(shape), device=tmpout.device,
                                  requires_grad=tmpout.requires_grad)
                out[self.t_index] = tmpout
                return out

        assert unique_name in self.auto_inferences
        g_node = self.torch_graph.name_to_node[unique_name]
        _logger.debug("replace %s, in %s type, with op_type %s",
                      unique_name, g_node.type, g_node.op_type)
        auto_infer = self.auto_inferences[unique_name]
        if g_node.type == 'module':
            if g_node.unique_name in self.torch_graph.reused_module:
                if reindex_dim is not None:
                    _logger.warning(
                        'Cannot replace a reused module with padding operator!!')
                    return None
            super_module, leaf_module = get_module_by_name(
                self.bound_model, g_node.name)
            m_type = g_node.op_type
            if m_type not in replace_module:
                raise RuntimeError(
                    "Replacing modules of type `{}` is not supported".format(m_type))
            _logger.info("replace module (name: %s, op_type: %s)",
                         g_node.name, m_type)
            compressed_module = replace_module[m_type](
                leaf_module, auto_infer.get_masks())
            new_submodule = compressed_module
            if reindex_dim is None:
                setattr(super_module, g_node.name.split(
                    '.')[-1], compressed_module)
            elif reindex_dim is not None and reindex is not None:
                # reindex the output of this submodule and replace the original module
                new_submodule = ReindexModule(
                    compressed_module, reindex_dim, reindex)
                setattr(super_module, g_node.name.split(
                    '.')[-1], new_submodule)
            return new_submodule
        elif g_node.type == 'func':
            _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                         unique_name, g_node.op_type)
            return None
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))
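
To make the reindex-and-pad trick in ReindexModule concrete, here is a standalone toy version of what its forward does along dimension 1. The boolean form of the index and all shapes are assumptions made for illustration only.

import torch

kept = torch.tensor([True, False, True, True, False])  # plays the role of `reindex`, length = original channel count
compressed_out = torch.randn(2, 3, 4, 4)               # output of the compressed op, only the 3 surviving channels

shape = list(compressed_out.shape)
shape[1] = kept.size(0)                                # pad back to the original 5 channels (reindex_dim = 1)
out = torch.zeros(shape)
out[:, kept] = compressed_out                          # pruned channels stay zero, like out[self.t_index] = tmpout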
Example 3
    def update_direct_sparsity(self, node):
        """
        Update the direct sparsity for the target node. Here the direct sparsity
        means that the sparsity in the output tensor that caused by the sparsity
        in the input tensors/weight tensors.
        """
        # this name is consistent with the name returned by named_modules()
        module_name = node.name
        _logger.info('Update mask for %s', module_name)
        unique_name = node.unique_name
        dummy_input, input_debugname = self._prepare_dummy_input(node)
        # get the input mask from self.masks
        # Note: the input masks of the successor nodes have
        # already been created by the predecessor nodes
        in_masks = [self.masks[debugname] for debugname in input_debugname]
        in_constants = [self.constant[debugname]
                        for debugname in input_debugname]
        if node.type == 'func':
            # we cannot get the runnable function directly from the jit traced
            # graph, so we translate it back to a python function. Note: the function
            # is applicable to both cpu/gpu devices; the output tensors will be on the
            # same device as the input tensors
            func = jit_to_python_function(node, self)
            if func is None:
                # no need to infer the sparsity for this node
                self.auto_inferences[unique_name] = None
                return
            # function doesn't have weights
            _auto_infer = AutoMaskInference(
                func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim)
        else:
            weight_mask = None
            if module_name in self.masks:
                weight_mask = self.masks[module_name]
            _, module = get_module_by_name(self.bound_model, module_name)
            _auto_infer = AutoMaskInference(
                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
        self.auto_inferences[unique_name] = _auto_infer
        _auto_infer.name = node.unique_name

        _auto_infer.update_direct_sparsity()
        # also save the input debug names into the auto_infer
        _auto_infer.input_debugname = input_debugname
        # update the mask tensor and the internal output of the submodules
        # after manually unpacking the tuple/list of tensors, the number of outputs
        # of each node should always be one (except for the TupleUnpack node at the
        # end of the whole model)
        assert len(node.outputs) == 1, \
            'The number of outputs should be one after the tuple is unpacked manually'

        out_debugname = node.outputs[0]
        # update the output mask into self.masks
        self.masks[out_debugname] = _auto_infer.output_mask
        self.constant[out_debugname] = _auto_infer.out_constant
        # update the output result into self.internal_result, so that
        # the successor nodes can take these output tensors as inputs.
        self.internal_result[out_debugname] = _auto_infer.output
        # update the parameter mask of the node
        self.masks[module_name] = _auto_infer.weight_mask
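
As a toy illustration of the "direct sparsity" idea (zeros in the weight mask showing up as zeros in the output tensor), consider a convolution with one filter masked out. The mask handling below is deliberately simplified and does not follow the AutoMaskInference internals.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
weight_mask = torch.ones_like(conv.weight)
weight_mask[1] = 0                                     # prune the second output filter entirely

with torch.no_grad():
    conv.weight.mul_(weight_mask)
    out = conv(torch.randn(1, 3, 8, 8))

# channel 1 of the output is all zero, so its output-channel mask can be set to zero
out_mask = (out.abs().sum(dim=(0, 2, 3)) != 0).float() # expected: tensor([1., 0., 1., 1.])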
Example 4
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.

        Parameters
        ----------
        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
            The dimension of the re-index operation.
        reindex: Reindex
            The index tensor. Normally this variable is None. If we want to reindex the
            output of this submodule, we can pass the index via this parameter.
        """
        class ReindexModule(nn.Module):
            """
            ReindexModule is used to resolve the mask conflict when replace the submodule.
            Basically, we can use two ways to resolve the mask conflict: (1) unmask some
            values(will introduce more computation overhead) (2) reindex and padd the output
            tensor of the target op(introduce more memory access overhad). Currently this
            method is shutdown, in the future, we will merge these two methods into a graph
            pass which is used to resolve the mask conflict.
            """
            def __init__(self, ori_module, reindex_dim, reindex):
                super(ReindexModule, self).__init__()
                self.ori_module = ori_module
                self.reindex_dim = reindex_dim
                self.reindex = reindex
                tmp_index = [slice(None, None) for i in range(reindex_dim + 1)]
                # the index for the tensor
                tmp_index[reindex_dim] = reindex
                self.t_index = tuple(tmp_index)

            def forward(self, x):
                tmpout = self.ori_module(x)
                shape = list(tmpout.size())
                shape[self.reindex_dim] = self.reindex.size(0)
                out = torch.zeros(tuple(shape),
                                  device=tmpout.device,
                                  requires_grad=tmpout.requires_grad)
                out[self.t_index] = tmpout
                return out

        assert unique_name in self.auto_inferences
        g_node = self.torch_graph.name_to_node[unique_name]
        _logger.debug("replace %s, in %s type, with op_type %s", unique_name,
                      g_node.type, g_node.op_type)
        auto_infer = self.auto_inferences[unique_name]
        if g_node.type == 'module':
            if g_node.unique_name in self.torch_graph.reused_module:
                if reindex_dim is not None:
                    _logger.warning(
                        'Cannot replace a reused module with padding operator!!'
                    )
                    return None
            super_module, leaf_module = get_module_by_name(
                self.bound_model, g_node.name)
            m_type = g_node.op_type
            if (m_type not in replace_module) and (
                    m_type not in self.customized_replace_func):
                err_msg = f"Replacing modules of type `{m_type}` is not supported, "
                err_msg += "you could report an issue at https://github.com/microsoft/nni. "
                err_msg += f"If you know how to replace {m_type}, "
                err_msg += "you could implement module replacement by passing in "
                err_msg += f"`customized_replace_func` to `{self.__class__.__name__}`. "
                err_msg += "You are welcome to contribute the replacement function back to nni as native support, "
                err_msg += "so that more users can benefit from your contribution."
                raise RuntimeError(err_msg)
            _logger.info("replace module (name: %s, op_type: %s)", g_node.name,
                         m_type)
            replace_function = self.customized_replace_func.get(
                m_type, replace_module.get(m_type, None))
            compressed_module = replace_function(leaf_module,
                                                 auto_infer.get_masks())
            new_submodule = compressed_module
            if reindex_dim is None:
                setattr(super_module,
                        g_node.name.split('.')[-1], compressed_module)
            elif reindex_dim is not None and reindex is not None:
                # reindex the output of this submodule and replace the original module
                new_submodule = ReindexModule(compressed_module, reindex_dim,
                                              reindex)
                setattr(super_module,
                        g_node.name.split('.')[-1], new_submodule)
            return new_submodule
        elif g_node.type == 'func':
            _logger.info(
                "Warning: cannot replace (name: %s, op_type: %s) which is func type",
                unique_name, g_node.op_type)
            return None
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))
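
The error message above points at customized_replace_func as the extension hook: it is consulted before the built-in replace_module table, and judging from the call site a replacement function receives the original leaf module plus the inferred masks and returns the compressed module. A hypothetical sketch, with a made-up op_type and no real mask handling:

def replace_my_layer(leaf_module, masks):
    # Hypothetical replacement function. A real implementation would use the masks
    # to build a smaller module; here the module is returned unchanged as a stub.
    # Signature assumed from the call replace_function(leaf_module, auto_infer.get_masks()).
    return leaf_module

# op_type -> replacement function, looked up before the built-in table
customized_replace_func = {'MyLayer': replace_my_layer}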
Example 5
    def infer_module_mask(self,
                          module_name,
                          last_module,
                          mask=None,
                          in_shape=None,
                          out_shape=None):
        """
        Infer input shape / output shape based on the module's weight mask / input shape / output shape.

        For a module:
            Infer its input and output shape from its weight mask
            Infer its output shape from its input shape
            Infer its input shape from its output shape

        If its input shape is changed, continue inferring its predecessors.
        If its output shape is changed, continue inferring its successors.

        Parameters
        ----------
        module_name : str
            The name of the node
        last_module : str
            The name of last visited node
        mask : tensor of mask or ModuleMasks
            Mask of the weights in this node (i.e., module)
        in_shape : ModuleMasks
            Input shape of this node
        out_shape : ModuleMasks
            Output shape of this node
        """
        input_cmask = output_cmask = None
        if module_name in self.inferred_masks:
            module_masks = self.inferred_masks[module_name]
        else:
            _, m = get_module_by_name(self.bound_model, module_name)
            module_masks = ModuleMasks(module_name, m)
            self.inferred_masks[module_name] = module_masks

        m_type = self.torch_graph.name_to_node[module_name].op_type
        _logger.debug("infer mask of module %s with op_type %s", module_name,
                      m_type)
        if mask is not None:
            _logger.debug("mask is not None")
            if m_type not in infer_from_mask:
                raise RuntimeError(
                    "Inferring input/output shape from mask is not supported for module/function: `{}`, {}"
                    .format(m_type, module_name))
            if m_type in ['Linear']:
                input_cmask, output_cmask = infer_from_mask[m_type](
                    module_masks, mask,
                    self.torch_graph.name_to_node[module_name].auxiliary)
            else:
                input_cmask, output_cmask = infer_from_mask[m_type](
                    module_masks, mask)
        if in_shape is not None:
            _logger.debug("in_shape is not None")
            if m_type not in infer_from_inshape:
                raise RuntimeError(
                    "Inferring output shape from input shape is not supported for module/function: `{}`, {}"
                    .format(m_type, module_name))
            if m_type in [
                    'aten::view', 'aten::flatten', 'aten::mean',
                    'aten::reshape'
            ]:
                output_cmask = infer_from_inshape[m_type](
                    module_masks, in_shape,
                    self.torch_graph.name_to_node[module_name].auxiliary)
            elif m_type in ['aten::cat', 'Concat']:
                # To calculate the mask for the concat operation, we need the output
                # shape, the cat dimension, and the order of the input parameters.
                output_cmask = infer_from_inshape[m_type](
                    module_masks, in_shape,
                    self.torch_graph.name_to_node[module_name].auxiliary,
                    last_module)
            else:
                output_cmask = infer_from_inshape[m_type](module_masks,
                                                          in_shape)
        if out_shape is not None:
            _logger.debug("out_shape is not None")
            if m_type not in infer_from_outshape:
                raise RuntimeError(
                    "Inferring input shape from output shape is not supported for module/function: `{}`, {}"
                    .format(m_type, module_name))
            if m_type in [
                    'aten::view', 'aten::flatten', 'aten::mean',
                    'aten::reshape'
            ]:
                input_cmask = infer_from_outshape[m_type](
                    module_masks, out_shape,
                    self.torch_graph.name_to_node[module_name].auxiliary)
            else:
                input_cmask = infer_from_outshape[m_type](module_masks,
                                                          out_shape)

        if input_cmask:
            predecessors = self.torch_graph.find_predecessors(module_name)
            for _module_name in predecessors:
                self.infer_module_mask(_module_name,
                                       module_name,
                                       out_shape=input_cmask)
        if output_cmask:
            successors = self.torch_graph.find_successors(module_name)
            for _module_name in successors:
                self.infer_module_mask(_module_name,
                                       module_name,
                                       in_shape=output_cmask)
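
As a concrete, simplified instance of "infer its output shape from its input shape": for a Conv2d, a pruned input-channel mask coming from a predecessor removes the matching slices of the weight but leaves the output channels untouched, so in this case no mask has to be propagated further to the successors. A sketch under these assumptions, independent of the actual infer_from_inshape implementation:

import torch
import torch.nn as nn

conv = nn.Conv2d(4, 6, kernel_size=3)
in_channel_mask = torch.tensor([1., 0., 1., 1.])       # channel 1 was pruned by the predecessor
kept = in_channel_mask.bool()

# Shrink the conv along its input-channel dimension; output channels are unchanged.
new_conv = nn.Conv2d(int(kept.sum()), 6, kernel_size=3)
with torch.no_grad():
    new_conv.weight.copy_(conv.weight[:, kept])        # (6, 3, 3, 3)
    new_conv.bias.copy_(conv.bias)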