コード例 #1
0
    def test_multi_input(self):
        """ Test building ConnectedGraph on a model with multiple inputs.

        Checks the number of ops and splits in the graph, the number of inputs recorded on each conv op,
        and which modules are reported as the graph's input and output ops.
        """
        # pylint: disable=protected-access
        model = test_models.MultiInput()
        model.eval()
        inp_shape_1 = (1, 3, 32, 32)
        inp_shape_2 = (1, 3, 20, 20)
        inp_tensor_list = create_rand_tensors_given_shapes([inp_shape_1, inp_shape_2])
        conn_graph = ConnectedGraph(model, inp_tensor_list)
        self.assertEqual(11, len(conn_graph.ordered_ops))
        # Split count of 1 due to reshape having a split
        self.assertEqual(1, conn_graph._split_count)

        # Each conv op must map back to its module and record the expected number of inputs.
        conv1 = conn_graph.get_op_from_module_name('MultiInput.conv1')
        self.assertEqual(model.conv1, conv1.get_module())
        self.assertEqual(2, len(conv1.inputs))
        conv2 = conn_graph.get_op_from_module_name('MultiInput.conv2')
        self.assertEqual(model.conv2, conv2.get_module())
        self.assertEqual(3, len(conv2.inputs))
        conv3 = conn_graph.get_op_from_module_name('MultiInput.conv3')
        self.assertEqual(model.conv3, conv3.get_module())
        self.assertEqual(3, len(conv3.inputs))

        input_ops = get_all_input_ops(conn_graph)
        input_modules = [op.get_module() for op in input_ops]
        self.assertEqual(2, len(input_ops))
        # assertIn reports the missing module and the container on failure, unlike assertTrue(x in y).
        self.assertIn(model.conv1, input_modules)
        self.assertIn(model.conv3, input_modules)

        output_ops = get_all_output_ops(conn_graph)
        self.assertEqual(1, len(output_ops))
        self.assertEqual(model.fc, output_ops[0].get_module())
コード例 #2
0
 def test_module_list(self):
     """ Build a ConnectedGraph for a model that holds its layers in a module list.

     Verifies the total op count and that the first six entries of ordered_ops correspond,
     in order, to the expected modules.
     """
     model = test_models.ModuleListModel()
     model.eval()
     single_input = torch.rand(1, 3, 8, 8)
     conn_graph = ConnectedGraph(model, (single_input, ))
     self.assertEqual(10, len(conn_graph.ordered_ops))

     # Dotted module names in the order they are expected to appear at the head of ordered_ops.
     expected_order = [
         'ModuleListModel.mod_list.4',
         'ModuleListModel.seq_list.2',
         'ModuleListModel.mod_list.1',
         'ModuleListModel.mod_list.0',
         'ModuleListModel.mod_list.2',
         'ModuleListModel.seq_list.0',
     ]
     for position, module_name in enumerate(expected_order):
         expected_op = conn_graph.get_op_from_module_name(module_name)
         self.assertEqual(expected_op, conn_graph.ordered_ops[position])
コード例 #3
0
 def test_hierarchial_model(self):
     """ Build a ConnectedGraph for a model that nests nn.Modules across several levels of aggregation.

     Verifies the total op count, the split count, and that selected modules appear at the
     expected positions of ordered_ops.
     """
     # pylint: disable=protected-access
     model = test_models.HierarchicalModel()
     model.eval()
     conv_shape = (1, 64, 32, 32)
     inp_shape = (1, 3, 32, 32)
     seq_shape = (1, 3, 8, 8)
     input_shapes = [conv_shape, inp_shape, conv_shape, inp_shape, seq_shape]
     conn_graph = ConnectedGraph(model, create_rand_tensors_given_shapes(input_shapes))

     self.assertEqual(95, len(conn_graph.ordered_ops))
     self.assertEqual(5, conn_graph._split_count)

     # (position in ordered_ops, dotted module name expected at that position)
     expected_positions = (
         (0, 'HierarchicalModel.conv1.conv'),
         (5, 'HierarchicalModel.nm1.tm1.conv1'),
         (20, 'HierarchicalModel.nm1.tm2.conv1'),
         (36, 'HierarchicalModel.conv2.conv'),
         (40, 'HierarchicalModel.multi_conv.seq_list.0.conv'),
         (53, 'HierarchicalModel.nm2.tm1.conv1'),
         (68, 'HierarchicalModel.nm2.tm2.conv1'),
         (84, 'HierarchicalModel.sq.seq_list.0'),
     )
     for position, module_name in expected_positions:
         expected_op = conn_graph.get_op_from_module_name(module_name)
         self.assertEqual(expected_op, conn_graph.ordered_ops[position])
コード例 #4
0
class MaskPropagationWinnower(AimetCommonMaskPropagationWinnower):
    """ The MaskPropagationWinnower class implements winnowing based on propagating masks corresponding to each
    module's input channels identified to be winnowed.  """

    def __init__(self,
                 model: torch.nn.Module,
                 input_shape: Tuple,
                 list_of_modules_to_winnow: List[Tuple[torch.nn.Module, List]] = None,
                 reshape=True,
                 in_place=False,
                 verbose=False):
        """
        MaskPropagationWinnower object initialization.
        :param model: The model to be winnowed.
        :param input_shape: The input shape of the model. A random tensor of exactly this shape is used as the
                            dummy input when building the ConnectedGraph.
        :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
        channels to be winnowed for that module.
        :param reshape: If set to True a Down Sample Layer is added between modules to match the number of channels.
                    If set to False, the modules that need a Down Sample Layer will not be winnowed.
        :param in_place: If set to True, the model will be winnowed in place.
                     If set to False, a copy of the model will be winnowed.
        :param verbose: If set to True, logs detailed winnowing log messages.
        """
        super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
        model.apply(has_hooks)

        debug_level = logger.getEffectiveLevel()
        logger.debug("Current log level: %s", debug_level)

        # The device of the model's first parameter decides where the dummy input is placed.
        self._using_cuda = next(model.parameters()).is_cuda

        if self._in_place is False:
            # Do not winnow the model in place
            self._model = copy.deepcopy(model)
            logger.info("A copy of the model will be winnowed")
        else:
            # Winnow the model in place
            logger.info("Model will be winnowed in place")
            self._model = model

        # Construct connected graph representation of the computational graph
        dummy_input = torch.rand(input_shape)
        if self._using_cuda:
            # Move the existing tensor directly. Re-wrapping it with torch.tensor() made a redundant
            # copy and triggered PyTorch's copy-construct UserWarning.
            dummy_input = dummy_input.cuda()

        self._graph = ConnectedGraph(self._model, (dummy_input, ))
        # Names are resolved against the caller's original `model`. The winnow list is expected to hold
        # module objects from that model, and after copy.deepcopy those objects are not part of self._model.
        self.list_of_modules_to_winnow_with_names = \
            generate_and_add_module_winnow_list_with_names(model, self._list_of_modules_to_winnow)
        self._mask_propagator = MaskPropagator(self._graph, ModelApi.pytorch)
        self._module_reducer = ModuleReducer(self._model, self._using_cuda, self._reshape,
                                             self._mask_propagator.op_to_mask_dict)

    def propagate_masks_and_winnow(self):
        """  For the modules to be winnowed, create and propagate the masks.
        Once mask propagation is completed, winnow the model.

        :return: Tuple of (the winnowed model, list of (original module name, replacement module) tuples).
                 The list is None when no module was modified.
        """
        # Propagate the masks
        self._propagate_masks()

        modified_op_list = self._mask_propagator.get_ops_with_non_default_ip_op_masks()
        for name in modified_op_list:
            logger.info("Modified Op: %s", name)

        modified_modules_dict = self._module_reducer.reduce_modules(modified_op_list)

        if modified_modules_dict:
            ordered_module_list = self._create_modified_modules_list(modified_modules_dict)
        else:
            ordered_module_list = None
            logger.info("No modules were winnowed. Original model is returned.")

        return self._model, ordered_module_list

    def _propagate_masks(self):
        """  For the modules to be winnowed, set the channels to winnow and propagate the masks."""
        for module, list_of_channels_to_winnow, name in self.list_of_modules_to_winnow_with_names:
            self.validate_winnow_api_parameters(module, name, list_of_channels_to_winnow)

            # Only input channels are requested for winnowing; output channels are left unset.
            input_channels_to_winnow = list_of_channels_to_winnow
            output_channels_to_winnow = None
            if isinstance(module, (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
                self._mask_propagator.update_channels_to_winnow(
                    name, self._reshape, input_channels_to_winnow, output_channels_to_winnow)

        # The channels to winnow have been updated
        # Propagate the masks.
        self._mask_propagator.propagate_masks()

    @staticmethod
    def _create_modified_modules_list(modified_modules: Dict[str, torch.nn.Module]):
        """ Creates and returns a list of tuples with each tuple containing
        the original module name and its replacement module
        :param modified_modules: dictionary of modules modified during module reduction
        :return list of tuples of name of the original module in the model and corresponding new module
        """
        # Remove the model-name prefix: e.g. 'Net.layer1.conv1' -> 'layer1.conv1'.
        # Names without a dot are kept unchanged.
        return [(orig_module_name.split('.', 1)[-1], new_module)
                for orig_module_name, new_module in modified_modules.items()]

    def validate_winnow_api_parameters(self, module, name, list_of_channels_to_winnow):
        """
        For a given module, validate Winnow API parameters.
        :param module: module whose channel numbers are being validated.
        :param name: module's name
        :param list_of_channels_to_winnow: list of channels that must be winnowed.
        :raises NotImplementedError: if the module is not a torch.nn.Conv2d, or its op consumes a model input.
        :raises ValueError: if the channel list is empty, contains an out-of-range (negative or too large)
                            channel index, or covers every input channel of the module.
        """
        if not isinstance(module, torch.nn.Conv2d):
            logger.critical(
                "Winnowing is currently only supported for torch.nn.Conv2d modules. Attempting to winnow "
                "module of type %s", type(module))
            raise NotImplementedError(type(module))

        # Validate the list of channels.
        if not list_of_channels_to_winnow:
            raise ValueError(
                "The list of channels to winnow is empty for the module: %s" % name)

        min_channel_num = min(list_of_channels_to_winnow)
        if min_channel_num < 0:
            raise ValueError(
                "Channel number: %s is negative for module: %s" % (min_channel_num, name))

        max_channel_num = max(list_of_channels_to_winnow)
        max_in_channel_index = (module.in_channels - 1)
        if max_channel_num > max_in_channel_index:
            raise ValueError(
                "Channel number: %s exceeds module's max channel number index: %s for module: %s"
                % (max_channel_num, max_in_channel_index, name))

        # Count distinct channels: with duplicate entries, len() could both miss a request that covers
        # every input channel and wrongly reject a valid one.
        if len(set(list_of_channels_to_winnow)) == module.in_channels:
            raise ValueError(
                "Winnowing all the input channels is not allowed, module: %s" % name)

        module_op = self._graph.get_op_from_module_name(name)
        input_index = 0  # Using op index 0 to examine input to op
        if module_op.inputs[input_index].is_model_input:
            logger.critical(
                "Winnowing the first module of a model is NOT supported. Please ignore the first "
                "module and try again. First module: %s, shape %s, channels to winnow: %s",
                module_op.dotted_name, module_op.inputs[input_index].shape,
                list_of_channels_to_winnow)
            raise NotImplementedError(module_op.dotted_name)