def replace_compressed_modules(self):
    """
    Replace all the modules whose shape (weights/inputs/outputs) has changed.
    The new module is created with the same arguments as the to-be-replaced
    module and correctly inherits its weights.

    NOTE: a node of ``func`` type cannot be replaced because it is not a
    module, so one limitation is that ``func`` nodes must not require
    replacement.
    """
    for module_name in self.inferred_masks:
        g_node = self.torch_graph.name_to_node[module_name]
        _logger.debug("replace %s, in %s type, with op_type %s",
                      module_name, g_node.type, g_node.op_type)
        if g_node.type == 'module':
            super_module, leaf_module = get_module_by_name(
                self.bound_model, g_node.name)
            m_type = g_node.op_type
            if m_type not in replace_module:
                raise RuntimeError(
                    "Replacing the module `{}` is not supported yet".format(m_type))
            _logger.info("replace module (name: %s, op_type: %s)",
                         g_node.name, m_type)
            compressed_module = replace_module[m_type](
                leaf_module, self.inferred_masks[module_name])
            setattr(super_module, g_node.name.split('.')[-1], compressed_module)
        elif g_node.type == 'func':
            _logger.info(
                "Warning: cannot replace (name: %s, op_type: %s) which is of func type",
                module_name, g_node.op_type)
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))
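# --- Hedged example (not from the original source) --------------------------
# A minimal sketch of what one entry of the `replace_module` dispatch table
# used above might look like for a Linear layer, assuming the mask dict
# stores a binary 'weight' mask of shape (out_features, in_features).
# The helper name `_replace_linear_sketch` is hypothetical.
import torch
import torch.nn as nn

def _replace_linear_sketch(linear, masks):
    weight_mask = masks['weight']
    out_keep = weight_mask.abs().sum(dim=1).nonzero(as_tuple=True)[0]
    in_keep = weight_mask.abs().sum(dim=0).nonzero(as_tuple=True)[0]
    new_linear = nn.Linear(in_keep.numel(), out_keep.numel(),
                           bias=linear.bias is not None)
    # inherit the surviving weights (and bias) from the original module
    new_linear.weight.data = linear.weight.data[out_keep][:, in_keep].clone()
    if linear.bias is not None:
        new_linear.bias.data = linear.bias.data[out_keep].clone()
    return new_linear

# The dispatch table would then map op types to such builders,
# e.g. {'Linear': _replace_linear_sketch}.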
def update_direct_sparsity(self, node):
    """
    Update the direct sparsity for the target node. Here direct sparsity
    means the sparsity in the output tensor that is caused by the sparsity
    in the input tensors / weight tensors.
    """
    # this name is consistent with the name returned by named_modules()
    module_name = node.name
    _logger.info('Update mask for %s', module_name)
    unique_name = node.unique_name
    dummy_input, input_debugname = self._prepare_dummy_input(node)
    # get the input masks from self.masks
    # Note: the input mask of a successor node is already
    # created by its predecessor node
    in_masks = [self.masks[debugname] for debugname in input_debugname]
    in_constants = [self.constant[debugname] for debugname in input_debugname]
    if node.type == 'func':
        # we cannot get a runnable function directly from the jit traced
        # graph, so we translate it back into a python function. Note: the
        # function is applicable to both cpu/gpu devices; the output tensors
        # will be on the same device as the input tensors
        func = jit_to_python_function(node, self)
        if func is None:
            # no need to infer the sparsity for this node
            self.auto_inferences[unique_name] = None
            return
        # a function doesn't have weights
        _auto_infer = AutoMaskInference(
            func, dummy_input, in_masks, in_constants=in_constants,
            batch_dim=self.batch_dim)
    else:
        weight_mask = None
        if module_name in self.masks:
            weight_mask = self.masks[module_name]
        _, module = get_module_by_name(self.bound_model, module_name)
        _auto_infer = AutoMaskInference(
            module, dummy_input, in_masks, weight_mask,
            in_constants=in_constants,
            state_dict=copy.deepcopy(module.state_dict()),
            batch_dim=self.batch_dim)
    self.auto_inferences[unique_name] = _auto_infer
    _auto_infer.name = node.unique_name
    _auto_infer.update_direct_sparsity()
    # also save the input debug names into the auto_infer
    _auto_infer.input_debugname = input_debugname
    # update the mask tensor and the internal output of the submodules.
    # after manually unpacking the tuple/list of tensors, each node should
    # always have exactly one output (except for the TupleUnpack node at the
    # end of the whole model)
    assert len(node.outputs) == 1, \
        'The number of outputs should be one after the tuple is unpacked manually'
    out_debugname = node.outputs[0]
    # update the output mask into self.masks
    self.masks[out_debugname] = _auto_infer.output_mask
    self.constant[out_debugname] = _auto_infer.out_constant
    # update the output result into self.internal_result, so that
    # the successor nodes can take these output tensors as inputs
    self.internal_result[out_debugname] = _auto_infer.output
    # update the parameter mask of the node
    self.masks[module_name] = _auto_infer.weight_mask
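# --- Hedged example (not from the original source) --------------------------
# A standalone sketch of the "direct sparsity" idea carried out by
# AutoMaskInference.update_direct_sparsity() above: feed several random inputs
# with the pruned positions zeroed out, and treat output positions that stay
# zero in every trial as pruned. The helper name and the trial count are
# assumptions for illustration.
import torch

def direct_output_mask_sketch(func, input_shape, in_mask, n_trials=3):
    outs = []
    for _ in range(n_trials):
        x = torch.randn(input_shape) * in_mask  # pruned positions stay zero
        with torch.no_grad():
            outs.append(func(x))
    stacked = torch.stack(outs)
    # positions that are zero for every random input are considered pruned
    return (stacked.abs().sum(dim=0) != 0).float()

# e.g. for func = lambda x: x * torch.tensor([1., 0.]) with in_mask of ones,
# the inferred output mask is [1., 0.]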
def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
    """
    Replace the submodule according to the inferred sparsity.

    Parameters
    ----------
    unique_name: str
        The unique_name of the submodule to replace.
    reindex_dim: int
        The dimension of the re-index operation.
    reindex: Reindex
        The index tensor. Normally this variable is None. If we want to
        reindex the output of this submodule, we can pass the index by this
        parameter.
    """
    class ReindexModule(nn.Module):
        """
        ReindexModule is used to resolve the mask conflict when replacing a
        submodule. Basically, there are two ways to resolve a mask conflict:
        (1) unmask some values (introduces more computation overhead);
        (2) reindex and pad the output tensor of the target op (introduces
        more memory access overhead). Currently this method is disabled; in
        the future, we will merge these two methods into a graph pass that
        resolves the mask conflict.
        """

        def __init__(self, ori_module, reindex_dim, reindex):
            super(ReindexModule, self).__init__()
            self.ori_module = ori_module
            self.reindex_dim = reindex_dim
            self.reindex = reindex
            tmp_index = [slice(None, None) for i in range(reindex_dim + 1)]
            # the index for the tensor
            tmp_index[reindex_dim] = reindex
            self.t_index = tuple(tmp_index)

        def forward(self, x):
            tmpout = self.ori_module(x)
            shape = list(tmpout.size())
            shape[self.reindex_dim] = self.reindex.size(0)
            out = torch.zeros(tuple(shape), device=tmpout.device,
                              requires_grad=tmpout.requires_grad)
            out[self.t_index] = tmpout
            return out

    assert unique_name in self.auto_inferences
    g_node = self.torch_graph.name_to_node[unique_name]
    _logger.debug("replace %s, in %s type, with op_type %s",
                  unique_name, g_node.type, g_node.op_type)
    auto_infer = self.auto_inferences[unique_name]
    if g_node.type == 'module':
        if g_node.unique_name in self.torch_graph.reused_module:
            if reindex_dim is not None:
                _logger.warning(
                    'Cannot replace a reused module with the padding operator!')
                return None
        super_module, leaf_module = get_module_by_name(
            self.bound_model, g_node.name)
        m_type = g_node.op_type
        if (m_type not in replace_module) and (m_type not in self.customized_replace_func):
            err_msg = f"Replacing modules of type {m_type} is not supported yet; "
            err_msg += "you could report an issue at https://github.com/microsoft/nni. "
            err_msg += f"If you know how to replace {m_type}, "
            err_msg += "you could implement module replacement by passing in "
            err_msg += f"`customized_replace_func` to `{self.__class__.__name__}`. "
            err_msg += "You are welcome to contribute the replacement function back to nni as native support, "
            err_msg += "so that more users can benefit from your contribution."
            raise RuntimeError(err_msg)
        _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
        replace_function = self.customized_replace_func.get(
            m_type, replace_module.get(m_type, None))
        compressed_module = replace_function(leaf_module, auto_infer.get_masks())
        new_submodule = compressed_module
        if reindex_dim is None:
            setattr(super_module, g_node.name.split('.')[-1], compressed_module)
        elif reindex_dim is not None and reindex is not None:
            # reindex the output of this submodule and replace the original module
            new_submodule = ReindexModule(compressed_module, reindex_dim, reindex)
            setattr(super_module, g_node.name.split('.')[-1], new_submodule)
        return new_submodule
    elif g_node.type == 'func':
        _logger.info(
            "Warning: cannot replace (name: %s, op_type: %s) which is of func type",
            unique_name, g_node.op_type)
        return None
    else:
        raise RuntimeError("Unsupported node type: {}".format(g_node.type))
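# --- Hedged example (not from the original source) --------------------------
# A standalone sketch of the reindex-and-pad trick that ReindexModule applies
# above, assuming `reindex` is a boolean keep-mask over the original length of
# `dim` (so the zero-filled output keeps the pre-pruning shape for downstream
# ops). The helper name is hypothetical.
import torch

def reindex_pad_sketch(compressed, keep_mask, dim):
    # keep_mask.sum() must equal compressed.size(dim)
    shape = list(compressed.size())
    shape[dim] = keep_mask.size(0)
    out = torch.zeros(shape, device=compressed.device, dtype=compressed.dtype)
    index = [slice(None)] * (dim + 1)
    index[dim] = keep_mask
    out[tuple(index)] = compressed  # scatter the surviving slices back
    return out

# e.g. 2 of 4 channels survived pruning:
# reindex_pad_sketch(torch.ones(2, 2, 5, 5),
#                    torch.tensor([False, True, False, True]), dim=1)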
def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None):
    """
    Infer the input shape / output shape based on the module's weight mask /
    input shape / output shape.

    For a module:
        Infer its input and output shape from its weight mask
        Infer its output shape from its input shape
        Infer its input shape from its output shape

    If its input shape is changed, continue inferring its predecessors.
    If its output shape is changed, continue inferring its successors.

    Parameters
    ----------
    module_name : str
        The name of the node
    last_module : str
        The name of the last visited node
    mask : tensor of mask or ModuleMasks
        Mask of the weights in this node (i.e., module)
    in_shape : ModuleMasks
        Input shape of this node
    out_shape : ModuleMasks
        Output shape of this node
    """
    input_cmask = output_cmask = None
    if module_name in self.inferred_masks:
        module_masks = self.inferred_masks[module_name]
    else:
        _, m = get_module_by_name(self.bound_model, module_name)
        module_masks = ModuleMasks(module_name, m)
        self.inferred_masks[module_name] = module_masks

    m_type = self.torch_graph.name_to_node[module_name].op_type
    _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
    if mask is not None:
        _logger.debug("mask is not None")
        if m_type not in infer_from_mask:
            raise RuntimeError(
                "Inferring the input/output shape from a mask is not supported for module/function: `{}`, {}"
                .format(m_type, module_name))
        if m_type in ['Linear']:
            input_cmask, output_cmask = infer_from_mask[m_type](
                module_masks, mask,
                self.torch_graph.name_to_node[module_name].auxiliary)
        else:
            input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
    if in_shape is not None:
        _logger.debug("in_shape is not None")
        if m_type not in infer_from_inshape:
            raise RuntimeError(
                "Inferring the output shape from the input shape is not supported for module/function: `{}`, {}"
                .format(m_type, module_name))
        if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
            output_cmask = infer_from_inshape[m_type](
                module_masks, in_shape,
                self.torch_graph.name_to_node[module_name].auxiliary)
        elif m_type in ['aten::cat', 'Concat']:
            # calculating the mask for a concat operation requires the output
            # shape, the cat dimension, and the order of the input parameters
            output_cmask = infer_from_inshape[m_type](
                module_masks, in_shape,
                self.torch_graph.name_to_node[module_name].auxiliary,
                last_module)
        else:
            output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
    if out_shape is not None:
        _logger.debug("out_shape is not None")
        if m_type not in infer_from_outshape:
            raise RuntimeError(
                "Inferring the input shape from the output shape is not supported for module/function: `{}`, {}"
                .format(m_type, module_name))
        if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
            input_cmask = infer_from_outshape[m_type](
                module_masks, out_shape,
                self.torch_graph.name_to_node[module_name].auxiliary)
        else:
            input_cmask = infer_from_outshape[m_type](module_masks, out_shape)

    if input_cmask:
        predecessors = self.torch_graph.find_predecessors(module_name)
        for _module_name in predecessors:
            self.infer_module_mask(_module_name, module_name, out_shape=input_cmask)
    if output_cmask:
        successors = self.torch_graph.find_successors(module_name)
        for _module_name in successors:
            self.infer_module_mask(_module_name, module_name, in_shape=output_cmask)
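# --- Hedged example (not from the original source) --------------------------
# A minimal sketch of the per-module rule behind an `infer_from_mask` entry
# for a Conv2d: fully-pruned filters induce an output-channel mask (pushed to
# successors as their in_shape), and fully-pruned input channels induce an
# input-channel mask (pushed back to predecessors as their out_shape). The
# helper name is hypothetical.
import torch

def conv2d_channel_masks_sketch(weight_mask):
    # weight_mask shape: (out_channels, in_channels, kH, kW)
    out_mask = (weight_mask.abs().sum(dim=(1, 2, 3)) != 0).float()
    in_mask = (weight_mask.abs().sum(dim=(0, 2, 3)) != 0).float()
    return in_mask, out_mask

# e.g. a (4, 3, 3, 3) mask whose filter 2 is all zeros yields
# out_mask == [1., 1., 0., 1.]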