def test_multi_input(self):
    """ Verify the ConnectedGraph built for a model that takes more than one input """
    # pylint: disable=protected-access
    model = test_models.MultiInput()
    model.eval()

    input_shapes = [(1, 3, 32, 32), (1, 3, 20, 20)]
    inp_tensor_list = create_rand_tensors_given_shapes(input_shapes)
    conn_graph = ConnectedGraph(model, inp_tensor_list)

    self.assertEqual(11, len(conn_graph.ordered_ops))
    # Split count of 1 due to reshape having a split
    self.assertEqual(1, conn_graph._split_count)

    # Each conv op must wrap the matching module and expose the expected number of inputs
    expected_convs = (
        ('MultiInput.conv1', model.conv1, 2),
        ('MultiInput.conv2', model.conv2, 3),
        ('MultiInput.conv3', model.conv3, 3),
    )
    for module_name, expected_module, expected_num_inputs in expected_convs:
        conv_op = conn_graph.get_op_from_module_name(module_name)
        self.assertEqual(expected_module, conv_op.get_module())
        self.assertEqual(expected_num_inputs, len(conv_op.inputs))

    # conv1 and conv3 are the ops that consume the two model inputs
    input_ops = get_all_input_ops(conn_graph)
    input_modules = [input_op.get_module() for input_op in input_ops]
    self.assertEqual(2, len(input_ops))
    for entry_module in (model.conv1, model.conv3):
        self.assertTrue(entry_module in input_modules)

    # The fully connected layer is the single output op
    output_ops = get_all_output_ops(conn_graph)
    self.assertEqual(1, len(output_ops))
    self.assertEqual(model.fc, output_ops[0].get_module())
def test_module_list(self):
    """ Verify the op ordering of a ConnectedGraph built for a model that uses module lists """
    model = test_models.ModuleListModel()
    model.eval()

    inp_data_1 = torch.rand(1, 3, 8, 8)
    conn_graph = ConnectedGraph(model, (inp_data_1, ))
    self.assertEqual(10, len(conn_graph.ordered_ops))

    # Module names in the order their ops are expected at the head of ordered_ops
    expected_leading_modules = [
        'ModuleListModel.mod_list.4',
        'ModuleListModel.seq_list.2',
        'ModuleListModel.mod_list.1',
        'ModuleListModel.mod_list.0',
        'ModuleListModel.mod_list.2',
        'ModuleListModel.seq_list.0',
    ]
    for position, module_name in enumerate(expected_leading_modules):
        self.assertEqual(conn_graph.get_op_from_module_name(module_name),
                         conn_graph.ordered_ops[position])
def test_hierarchial_model(self):
    """ Verify the ConnectedGraph built for a model with multi-level aggregation of nn.Modules """
    # pylint: disable=protected-access
    model = test_models.HierarchicalModel()
    model.eval()

    conv_shape = (1, 64, 32, 32)
    inp_shape = (1, 3, 32, 32)
    seq_shape = (1, 3, 8, 8)
    all_shapes = [conv_shape, inp_shape, conv_shape, inp_shape, seq_shape]
    inp_tensor_list = create_rand_tensors_given_shapes(all_shapes)

    conn_graph = ConnectedGraph(model, inp_tensor_list)
    self.assertEqual(95, len(conn_graph.ordered_ops))
    self.assertEqual(5, conn_graph._split_count)

    # (module name, expected position of its op in ordered_ops), checked in this order
    expected_positions = (
        ('HierarchicalModel.conv1.conv', 0),
        ('HierarchicalModel.nm1.tm1.conv1', 5),
        ('HierarchicalModel.nm1.tm2.conv1', 20),
        ('HierarchicalModel.conv2.conv', 36),
        ('HierarchicalModel.multi_conv.seq_list.0.conv', 40),
        ('HierarchicalModel.nm2.tm1.conv1', 53),
        ('HierarchicalModel.nm2.tm2.conv1', 68),
        ('HierarchicalModel.sq.seq_list.0', 84),
    )
    for module_name, position in expected_positions:
        self.assertEqual(conn_graph.get_op_from_module_name(module_name),
                         conn_graph.ordered_ops[position])
class MaskPropagationWinnower(AimetCommonMaskPropagationWinnower):
    """ The MaskPropagationWinnower class implements winnowing based on propagating masks corresponding to each
    module's input channels identified to be winnowed. """

    def __init__(self, model: torch.nn.Module, input_shape: Tuple,
                 list_of_modules_to_winnow: List[Tuple[torch.nn.Module, List]] = None, reshape=True,
                 in_place=False, verbose=False):
        """
        MaskPropagationWinnower object initialization.

        :param model: The model to be winnowed.
        :param input_shape: The input shape of the model.
        :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
            channels to be winnowed for that module.
        :param reshape: If set to True a Down Sample Layer is added between modules to match the number of
            channels. If set to False, the modules that need a Down Sample Layer will not be winnowed.
        :param in_place: If set to True, the model will be winnowed in place. If set to False, a copy of the
            model will be winnowed.
        :param verbose: If set to True, logs detailed winnowing log messages.
        """
        super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
        model.apply(has_hooks)

        debug_level = logger.getEffectiveLevel()
        logger.debug("Current log level: %s", debug_level)

        # The first parameter decides whether (and on which device) the model is on CUDA
        first_param = next(model.parameters())
        self._using_cuda = first_param.is_cuda

        if self._in_place is False:
            # Do not winnow the model in place
            self._model = copy.deepcopy(model)
            logger.info("A copy of the model will be winnowed")
        else:
            # Winnow the model in place
            logger.info("Model will be winnowed in place")
            self._model = model

        # Construct connected graph representation of the computational graph
        dummy_input = torch.rand(input_shape)
        if self._using_cuda:
            # Move the tensor to the model's own device. Copy-constructing with torch.tensor(tensor)
            # emits a UserWarning, and a bare .cuda() targets the current CUDA device, which is not
            # necessarily the device holding the model's parameters.
            dummy_input = dummy_input.to(first_param.device)

        self._graph = ConnectedGraph(self._model, (dummy_input, ))

        # Names are resolved against the model passed in by the caller, since the modules in
        # list_of_modules_to_winnow are the caller's module objects.
        self.list_of_modules_to_winnow_with_names = \
            generate_and_add_module_winnow_list_with_names(model, self._list_of_modules_to_winnow)

        self._mask_propagator = MaskPropagator(self._graph, ModelApi.pytorch)
        self._module_reducer = ModuleReducer(self._model, self._using_cuda, self._reshape,
                                             self._mask_propagator.op_to_mask_dict)

    def propagate_masks_and_winnow(self):
        """
        For the modules to be winnowed, create and propagate the masks.
        Once mask propagation is completed, winnow the model.

        :return: Tuple of the winnowed model and a list of (module name, new module) tuples. The list is None
            when no modules were winnowed.
        """
        # Propagate the masks
        self._propagate_masks()

        modified_op_list = self._mask_propagator.get_ops_with_non_default_ip_op_masks()
        for name in modified_op_list:
            logger.info("Modified Op: %s", name)

        modified_modules_dict = self._module_reducer.reduce_modules(modified_op_list)

        if modified_modules_dict:
            ordered_module_list = self._create_modified_modules_list(modified_modules_dict)
        else:
            ordered_module_list = None
            logger.info("No modules were winnowed. Original model is returned.")

        return self._model, ordered_module_list

    def _propagate_masks(self):
        """ For the modules to be winnowed, set the channels to winnow and propagate the masks."""
        for module, list_of_channels_to_winnow, name in self.list_of_modules_to_winnow_with_names:
            self.validate_winnow_api_parameters(module, name, list_of_channels_to_winnow)

            input_channels_to_winnow = list_of_channels_to_winnow
            output_channels_to_winnow = None
            # NOTE(review): validation above rejects everything but Conv2d, so the Linear case
            # cannot currently be reached through this path.
            if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
                self._mask_propagator.update_channels_to_winnow(name, self._reshape, input_channels_to_winnow,
                                                                output_channels_to_winnow)

        # The channels to winnow have been updated
        # Propagate the masks.
        self._mask_propagator.propagate_masks()

    @staticmethod
    def _create_modified_modules_list(modified_modules: Dict[str, torch.nn.Module]):
        """
        Creates and returns a list of tuples with each tuple containing the original module and its replacement
        module

        :param modified_modules: dictionary of modules modified during module reduction
        :return list of tuples of name of the original module in the model and corresponding new module
        """
        modified_module_list = []

        for orig_module_name, new_module in modified_modules.items():
            # Remove prefix of the model name
            # E.g. the module_name maybe Net.layer1.conv1, we only want layer1.conv1
            first_dot_position = orig_module_name.find('.')
            if first_dot_position != -1:
                orig_module_name = orig_module_name[first_dot_position + 1:]
            modified_module_list.append((orig_module_name, new_module))

        return modified_module_list

    def validate_winnow_api_parameters(self, module, name, list_of_channels_to_winnow):
        """
        For a given module, validate Winnow API parameters.

        :param module: module whose channel numbers are being validated.
        :param name: module's name
        :param list_of_channels_to_winnow: list of channels that must be winnowed.
        :raises NotImplementedError: if the module is not a torch.nn.Conv2d, or if its first input is a
            model input.
        :raises ValueError: if the channel list is empty, contains an index beyond the module's input
            channels, or covers every input channel of the module.
        """
        if not isinstance(module, torch.nn.Conv2d):
            logger.critical("Winnowing is currently only supported for torch.nn.Conv2d modules. "
                            "Attempting to winnow module of type %s", type(module))
            raise NotImplementedError(type(module))

        # Validate the list of channels.
        num_channels_to_winnow = len(list_of_channels_to_winnow)
        if num_channels_to_winnow == 0:
            raise ValueError("The list of channels to winnow is empty for the module: %s" % name)

        max_channel_num = max(list_of_channels_to_winnow)
        max_in_channel_index = (module.in_channels - 1)
        if max_channel_num > max_in_channel_index:
            raise ValueError("Channel number: %s exceeds module's max channel number index: %s for module: %s" %
                             (max_channel_num, max_in_channel_index, name))

        # Count distinct channels: a raw length comparison is fooled by repeated indices, which can
        # either falsely trip this check or let a list covering every channel slip through.
        num_distinct_channels = len(set(list_of_channels_to_winnow))
        if num_distinct_channels >= module.in_channels:
            raise ValueError("Winnowing all the input channels is not allowed, module: %s" % name)

        module_op = self._graph.get_op_from_module_name(name)
        input_index = 0     # Using op index 0 to examine input to op
        if module_op.inputs[input_index].is_model_input:
            logger.critical("Winnowing the first module of a model is NOT supported. Please ignore the first "
                            "module and try again. First module: %s, shape %s, channels to winnow: %s",
                            module_op.dotted_name, module_op.inputs[input_index].shape,
                            list_of_channels_to_winnow)
            raise NotImplementedError(module_op.dotted_name)