def test_multi_input(self):
        """ Test building ConnectedGraph on a model with multiple inputs """
        # pylint: disable=protected-access
        model = test_models.MultiInput()
        model.eval()
        inp_shape_1 = (1, 3, 32, 32)
        inp_shape_2 = (1, 3, 20, 20)
        inp_tensor_list = create_rand_tensors_given_shapes(
            [inp_shape_1, inp_shape_2])
        conn_graph = ConnectedGraph(model, inp_tensor_list)
        self.assertEqual(11, len(conn_graph.ordered_ops))
        # Split count of 1 due to reshape having a split
        self.assertEqual(1, conn_graph._split_count)
        conv1 = conn_graph.get_op_from_module_name('MultiInput.conv1')
        self.assertEqual(model.conv1, conv1.get_module())
        self.assertEqual(2, len(conv1.inputs))
        conv2 = conn_graph.get_op_from_module_name('MultiInput.conv2')
        self.assertEqual(model.conv2, conv2.get_module())
        self.assertEqual(3, len(conv2.inputs))
        conv3 = conn_graph.get_op_from_module_name('MultiInput.conv3')
        self.assertEqual(model.conv3, conv3.get_module())
        self.assertEqual(3, len(conv3.inputs))

        input_ops = get_all_input_ops(conn_graph)
        input_modules = [op.get_module() for op in input_ops]
        self.assertEqual(2, len(input_ops))
        self.assertTrue(model.conv1 in input_modules)
        self.assertTrue(model.conv3 in input_modules)
        output_ops = get_all_output_ops(conn_graph)
        self.assertEqual(1, len(output_ops))
        self.assertEqual(model.fc, output_ops[0].get_module())
Example 3
def _parse_graph(graph: torch._C.Graph,
                 model: torch.nn.Module) -> List[IrNode]:
    """
    Implements a depth-first graph extraction to obtain connectivity information in the form of a list of IrNodes.
    Depth-first extraction is realized using recursion.

    :param graph: PyTorch JIT graph of the model or a submodule
    :param model: Pytorch model to create connected graph from
    :return List of IrNodes created from traversing the trace graph
    """
    ir_nodes_list = []
    curr_inputs = [inp for inp in graph.inputs()]

    # Map from graph node names to module instances; seeded with the graph's
    # first input (the module itself) and extended by the GetAttr nodes below
    node_name_to_module = {curr_inputs[0].debugName(): model}
    for node in graph.nodes():
        outputs = [output for output in node.outputs()]

        # retrieving a module reference
        if 'GetAttr' in node.kind():
            # For GetAttr lines, the output name will be referring to the module, and not the module's output(s)
            assert len(outputs) == 1
            node_name = outputs[0].debugName()
            assert node_name not in node_name_to_module
            module = _get_module_instance(node, node_name_to_module)
            node_name_to_module[node_name] = module
        else:
            op_type: str = ConnectedGraph._parse_op_type(node)
            if "Constant" not in op_type:
                outputs = [output for output in node.outputs()]
                ir_node = IrNode(node_type=op_type,
                                 inputs=[
                                     inp for inp in node.inputs()
                                     if "Constant" not in
                                     ConnectedGraph._parse_op_type(inp.node())
                                 ],
                                 outputs=outputs,
                                 module=None)
                ir_nodes_list.append(ir_node)

    for ir_node in ir_nodes_list:
        inputs = []
        for inp in ir_node.inputs:
            if "GetAttr" in inp.node().kind():
                if ir_node.node_type in ConnectedGraph.op_type_map.values():
                    module = node_name_to_module[
                        inp.node().input().debugName()]
                    assert is_leaf_module(module)
                    if ir_node.module is None:
                        ir_node.module = module
                    else:
                        assert ir_node.module == module
            else:
                inputs.append(inp)
        ir_node.inputs = inputs

    return ir_nodes_list
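
For context, _parse_graph consumes the raw torch._C.Graph produced by JIT tracing. Below is a minimal driving sketch, assuming _parse_graph and its helpers (ConnectedGraph, IrNode, is_leaf_module) are importable; the two-layer model is hypothetical:

import torch

# Hypothetical model used only to produce a trace
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU())
model.eval()

# trace.graph is the torch._C.Graph that _parse_graph walks; its first graph
# input is the module itself, which seeds node_name_to_module above
trace = torch.jit.trace(model, torch.rand(1, 3, 8, 8))
ir_nodes = _parse_graph(trace.graph, model)
for ir_node in ir_nodes:
    print(ir_node.node_type, len(ir_node.inputs), len(ir_node.outputs))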
 def test_concat(self):
     """ Test building ConnectedGraph on a model with concat """
     model = test_models.ConcatModel()
     model.eval()
     inp_shape_1 = (1, 3, 8, 8)
     inp_shape_2 = (1, 3, 8, 8)
     inp_shape_3 = (1, 3, 8, 8)
     inp_tensor_list = create_rand_tensors_given_shapes(
         [inp_shape_1, inp_shape_2, inp_shape_3])
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     concat_op = conn_graph.get_all_ops()['cat_3']
     self.assertEqual(3, len(concat_op.inputs))
     self.assertEqual(14, concat_op.output_shape[1])
 def test_module_list(self):
     """ Test building ConnectedGraph on a model with module list """
     model = test_models.ModuleListModel()
     model.eval()
     inp_data_1 = torch.rand(1, 3, 8, 8)
     conn_graph = ConnectedGraph(model, (inp_data_1, ))
     self.assertEqual(10, len(conn_graph.ordered_ops))
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.mod_list.4'),
         conn_graph.ordered_ops[0])
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.seq_list.2'),
         conn_graph.ordered_ops[1])
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.mod_list.1'),
         conn_graph.ordered_ops[2])
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.mod_list.0'),
         conn_graph.ordered_ops[3])
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.mod_list.2'),
         conn_graph.ordered_ops[4])
     self.assertEqual(
         conn_graph.get_op_from_module_name('ModuleListModel.seq_list.0'),
         conn_graph.ordered_ops[5])
Example 7
    def test_get_all_ops_in_neighborhood(self):
        """ Test that default quantization parameters are set correctly when using json config file """
        model = SingleResidual()
        model.eval()
        input_shapes = (1, 3, 32, 32)

        random_inputs = utils.create_rand_tensors_given_shapes(input_shapes)
        conn_graph = ConnectedGraph(model, random_inputs)
        starting_op = conn_graph.get_all_ops()['convolution_7']
        add_10_op = conn_graph.get_all_ops()['add_10']
        adaptive_avg_pool2d_9_op = conn_graph.get_all_ops()['adaptive_avg_pool2d_9']
        neighborhood = _get_all_ops_in_neighborhood(starting_op, 'output')
        assert len(neighborhood) == 3
        assert starting_op in neighborhood
        assert add_10_op in neighborhood
        assert adaptive_avg_pool2d_9_op in neighborhood
Example 8
    def __init__(self,
                 model: torch.nn.Module,
                 dummy_input: Union[torch.Tensor, Tuple],
                 quant_scheme: Union[
                     str, QuantScheme] = QuantScheme.post_training_tf_enhanced,
                 rounding_mode: str = 'nearest',
                 default_output_bw: int = 8,
                 default_param_bw: int = 8,
                 in_place: bool = False,
                 config_file: str = None):
        """
        Constructor

        :param model: Model to add simulation ops to
        :param dummy_input: Dummy input to the model. Used to parse model graph. If the model has more than one input,
                            pass a tuple. User is expected to place the tensors on the appropriate device.
        :param quant_scheme: Quantization scheme. Supported options are 'tf' or 'tf_enhanced', or the QuantScheme
                             enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced
        :param rounding_mode: Rounding mode. Supported options are 'nearest' or 'stochastic'
        :param default_output_bw: Default bitwidth (4-31) to use for quantizing layer inputs and outputs
        :param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters
        :param in_place: If True, then the given 'model' is modified in-place to add quant-sim nodes.
                Only suggested use of this option is when the user wants to avoid creating a copy of the model
        :param config_file: Path to Configuration file for model quantizers
        """
        # Perform sanity checks on inputs
        QuantizationSimModel._validate_quantsim_inputs(quant_scheme,
                                                       rounding_mode,
                                                       default_output_bw,
                                                       default_param_bw)
        # save some parameters
        if in_place:
            self.model = model
        else:
            self.model = copy.deepcopy(model)

        try:
            self.connected_graph = ConnectedGraph(self.model, dummy_input)
        except (torch.jit.TracingCheckError, AssertionError):
            self.connected_graph = None

        if isinstance(quant_scheme, str):
            if quant_scheme == 'tf':
                quant_scheme = QuantScheme.post_training_tf
            elif quant_scheme == 'tf_enhanced':
                quant_scheme = QuantScheme.post_training_tf_enhanced
        self._quant_scheme = quant_scheme
        self._rounding_mode = rounding_mode
        self._default_output_bw = default_output_bw
        self._default_param_bw = default_param_bw

        # Add quantization layers
        num_inout_tensors = utils.find_num_inout_tensors_per_module(
            self.model, dummy_input)
        self._add_quantization_wrappers(self.model, num_inout_tensors)

        # Disable bias quantization
        self.exclude_param_from_quantization("bias")

        self.configure_quantization_ops(config_file)
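
A minimal usage sketch for this constructor, assuming QuantizationSimModel is importable from the AIMET quantsim module; the model and shapes are placeholders:

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU())
model.eval()
dummy_input = torch.rand(1, 3, 32, 32)

# in_place defaults to False, so a deepcopy of the model is instrumented
# and the original model is left untouched
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme='tf_enhanced',
                           default_output_bw=8,
                           default_param_bw=8)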
Example 9
def get_module_act_func_pair(model: torch.nn.Module, model_input: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> \
        Dict[torch.nn.Module, Union[torch.nn.Module, None]]:
    """
    For a given model, returns a dictionary mapping each module to the activation function that
    immediately follows it, or to None if no activation follows.

    Activation functions should be defined as nn.Modules in model and not as functional in the forward pass.

    :param model: Pytorch model
    :param model_input:  Model input, Can be a list/tuple of input tensor(s)
    :return: Dictionary of module to activation function
    """
    # Keep model in evaluation mode
    model.eval()

    # Create ConnectedGraph
    graph = ConnectedGraph(model, model_input)

    # Maps each module to the activation function that immediately follows it, else None
    module_act_func_pair = {}

    # Get all the ops
    all_ops = graph.get_all_ops()

    for op in all_ops.values():

        # Get module associated with op
        cur_module = op.get_module()

        if cur_module:
            module_act_func_pair[cur_module] = None

            if op.output:
                assert op.output.consumers, 'op output should have at least one consumer op.'
                # Get the next op
                next_op = op.output.consumers[0]
                # Get module associated with next op
                next_module = next_op.get_module()

                # Get the appropriate activation function
                if isinstance(next_module, ActivationTypes):
                    module_act_func_pair[cur_module] = next_module
                    logger.debug(
                        "Module: %s is followed by activation function: %s",
                        op.dotted_name, next_op.dotted_name)

    return module_act_func_pair
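
A short usage sketch (the model is illustrative; per the docstring, activations must be nn.Module instances rather than functional calls to be discovered):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),                       # nn.Module activation, so discoverable
    torch.nn.Conv2d(8, 8, kernel_size=3))  # no following activation
model.eval()
pairs = get_module_act_func_pair(model, (torch.rand(1, 3, 16, 16),))
# Expected: pairs[model[0]] is the ReLU instance; pairs[model[2]] is None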
Example 10
 def test_passthrough_op_last_module(self):
     """ Test building a connected graph on a model where a PassThroughOp is the last module in the graph. """
     model = test_models.PassThroughOpLastLayerModel()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(1, len(conn_graph.ordered_ops))
Example 11
 def test_dropouts(self):
     """ Test building ConnectedGraph on a model with dropouts """
     # pylint: disable=protected-access
     model = test_models.ModelWithDropouts()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(9, len(conn_graph.ordered_ops))
      # Split count of 1 due to reshape having a split
     self.assertEqual(1, conn_graph._split_count)
      # All ops will include 1 inserted split op
     self.assertEqual(10, len(conn_graph.get_all_ops().keys()))
     dropout_1_op = conn_graph.get_all_ops()['dropout_3']
     dropout_2_op = conn_graph.get_all_ops()['feature_dropout_4']
     self.assertEqual(model.dropout1, dropout_1_op.get_module())
     self.assertEqual(model.dropout2, dropout_2_op.get_module())
Example 12
 def test_nested_sequential(self):
     # pylint: disable=protected-access
     """ Test building ConnectedGraph on a model constructed with nested nn.Sequential Module """
     model = test_models.NestedSequentialModel()
     model.eval()
     inp_data_1 = torch.rand(1, 3, 8, 8)
     conn_graph = ConnectedGraph(model, (inp_data_1, ))
     self.assertEqual(10, len(conn_graph.ordered_ops))
     # Expect 1 split for the reshape operation
     self.assertEqual(1, conn_graph._split_count)
Example 13
 def test_single_residual(self):
     """ Test building ConnectedGraph on single residual model """
     # pylint: disable=protected-access
     model = test_models.SingleResidual()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(17, len(conn_graph.ordered_ops))
     # Split count of 2 due to residual as well as reshape having a split
     self.assertEqual(2, conn_graph._split_count)
     # All ops will include 2 inserted split ops
     self.assertEqual(19, len(conn_graph.get_all_ops().keys()))
     input_ops = get_all_input_ops(conn_graph)
     self.assertEqual(1, len(input_ops))
     self.assertEqual(model.conv1, input_ops[0].get_module())
     output_ops = get_all_output_ops(conn_graph)
     self.assertEqual(1, len(output_ops))
     self.assertEqual(model.fc, output_ops[0].get_module())
Example 14
def create_connected_graph_with_input_shapes(model: torch.nn.Module, input_shapes: Union[Tuple, List[Tuple]]) \
        -> ConnectedGraph:
    """
    Create connected graph, using random inputs generated from given input shapes.
    :param model: torch model to create a connected graph from
    :param input_shapes: input shapes to the torch model
    :return: ConnectedGraph representation of the model
    """
    random_inputs = create_rand_tensors_given_shapes(input_shapes)
    device = get_device(model)
    random_inputs = tuple([inp.to(device) for inp in random_inputs])
    return ConnectedGraph(model, random_inputs)
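
Usage is a one-liner; the helper generates random inputs and moves them to the model's device (the model below is a placeholder):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 4, kernel_size=3), torch.nn.ReLU())
model.eval()
conn_graph = create_connected_graph_with_input_shapes(model, (1, 3, 32, 32))
print(len(conn_graph.ordered_ops))  # number of ops discovered in the graph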
Example 15
 def test_hierarchial_model(self):
     """ Test building ConnectedGraph on model which multi-level aggregation of nn.Modules  """
     # pylint: disable=protected-access
     model = test_models.HierarchicalModel()
     model.eval()
     conv_shape = (1, 64, 32, 32)
     inp_shape = (1, 3, 32, 32)
     seq_shape = (1, 3, 8, 8)
     inp_tensor_list = create_rand_tensors_given_shapes(
         [conv_shape, inp_shape, conv_shape, inp_shape, seq_shape])
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(95, len(conn_graph.ordered_ops))
     self.assertEqual(5, conn_graph._split_count)
     self.assertEqual(
         conn_graph.get_op_from_module_name('HierarchicalModel.conv1.conv'),
         conn_graph.ordered_ops[0])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm1.tm1.conv1'), conn_graph.ordered_ops[5])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm1.tm2.conv1'), conn_graph.ordered_ops[20])
     self.assertEqual(
         conn_graph.get_op_from_module_name('HierarchicalModel.conv2.conv'),
         conn_graph.ordered_ops[36])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.multi_conv.seq_list.0.conv'),
         conn_graph.ordered_ops[40])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm2.tm1.conv1'), conn_graph.ordered_ops[53])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm2.tm2.conv1'), conn_graph.ordered_ops[68])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.sq.seq_list.0'), conn_graph.ordered_ops[84])
Example 16
    def test_multi_output_with_unuse_model(self):
        """ Test multi-output model with Tuple Tensor as intermediate output and with one of tuple tensor not used """
        class MultiOutputWithUnuseModel(torch.nn.Module):
            """
            Model with Tuple of Tensors as output with one output tensor unused
            """
            def __init__(self):
                super(MultiOutputWithUnuseModel, self).__init__()
                self.layer = test_models.TupleOutputModel()
                self.conv1 = torch.nn.Conv2d(2, 4, kernel_size=3, padding=1)
                self.conv2 = torch.nn.Conv2d(6, 4, kernel_size=3, padding=1)

            def forward(self, *inputs):
                x, _, z = self.layer(inputs[0])
                x1 = self.conv1(x)
                z1 = self.conv2(z)
                return torch.cat([x1, z1], 1)

        inp_data = torch.rand(1, 3, 8, 8)
        model = MultiOutputWithUnuseModel()
        conn_graph = ConnectedGraph(model, (inp_data, ))
        self.assertEqual(6, len(conn_graph.ordered_ops))
        self.assertEqual(
            5,
            len([
                op for op in conn_graph.get_all_ops().keys()
                if 'convolution' in op
            ]))
        self.assertEqual(
            0,
            len([
                op for op in conn_graph.get_all_ops().keys() if 'Tuple' in op
            ]))
        self.assertEqual('cat', conn_graph.ordered_ops[-1].type)

        product_names = conn_graph.get_all_products().keys()
        self.assertEqual(
            0,
            len([product for product in product_names if 'Tuple' in product]))

        expected_products = [
            # layer #1 to conv1,conv2
            'convolution_0_to_convolution_3',
            'convolution_2_to_convolution_4',

            # conv1,conv2 to cat
            'convolution_3_to_cat_5',
            'convolution_4_to_cat_5'
        ]

        products = conn_graph.get_all_products()
        for product_name in product_names:
            if product_name in expected_products:
                product = products[product_name]
                self.assertEqual(product.shape, product.producer.output_shape)
                expected_products.remove(product_name)
        self.assertEqual(0, len(expected_products))
Example 17
    def test_module_reuse_model(self):
        class ReuseReluLeafModel(torch.nn.Module):
            """ A model with Relu instance used multiple times
            Expected one input of size (1, 64, 8, 8) """
            def __init__(self):
                super(ReuseReluLeafModel, self).__init__()
                self.conv1 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
                self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, *inputs):
                x = self.conv1(inputs[0])
                x = self.relu(x)
                x = self.conv2(x)
                return self.relu(x)

        inp_data = torch.rand(1, 64, 8, 8)
        model = ReuseReluLeafModel()
        conn_graph = ConnectedGraph(model, (inp_data, ))
        self.assertEqual(4, len(conn_graph.ordered_ops))
        self.assertEqual(
            2,
            len([
                op for name, op in conn_graph.get_all_ops().items()
                if 'relu' in name and op.get_module() == model.relu
            ]))

        class ReluModel(torch.nn.Module):
            def __init__(self):
                super(ReluModel, self).__init__()
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, *inputs):
                return self.relu(inputs[0])

        class ReuseReluLayerModel(torch.nn.Module):
            """ A model with Relu Layer instance used multiple times
            Expected one input of size (1, 64, 8, 8) """
            def __init__(self):
                super(ReuseReluLayerModel, self).__init__()
                self.conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
                self.layer = ReluModel()

            def forward(self, *inputs):
                x = self.layer(inputs[0])
                x = self.conv(x)
                return self.layer(x)

        layer_model = ReuseReluLayerModel()
        conn_graph = ConnectedGraph(layer_model, (inp_data, ))
        self.assertEqual(3, len(conn_graph.ordered_ops))
        self.assertEqual(
            2,
            len([
                op for name, op in conn_graph.get_all_ops().items()
                if 'relu' in name and op.get_module() == layer_model.layer.relu
            ]))
Example 18
    def test_submodules_with_sequence_and_module_list(self):
        """ Test building ConnectedGraph on a model with sequence and module list """
        class ModuleListAndSequentialModel(torch.nn.Module):
            def __init__(self):
                super(ModuleListAndSequentialModel, self).__init__()
                self.mod_list = torch.nn.ModuleList([
                    torch.nn.Sequential(
                        test_models.BasicConv2d(kernel_size=3),
                        test_models.BasicConv2d(kernel_size=3)),
                    torch.nn.Sequential(
                        torch.nn.Sequential(
                            test_models.BasicConv2d(kernel_size=3),
                            test_models.BasicConv2d(kernel_size=3)), ),
                    torch.nn.ModuleList([
                        torch.nn.ModuleList(
                            [test_models.BasicConv2d(kernel_size=3)])
                    ]),
                    test_models.ModuleListModel()
                ])

            def forward(self, *inputs):
                s1 = self.mod_list[0](inputs[0])
                s2 = self.mod_list[1](inputs[0])
                m1 = self.mod_list[2][0][0](inputs[0])
                m2 = self.mod_list[3](inputs[1])
                return s1, s2, m1, m2

        inp_data_1 = torch.rand(1, 64, 8, 8)
        inp_data_2 = torch.rand(1, 3, 8, 8)
        conn_graph = ConnectedGraph(ModuleListAndSequentialModel(),
                                    (inp_data_1, inp_data_2))
        self.assertEqual(30, len(conn_graph.ordered_ops))
        self.assertEqual(
            0,
            len([
                op for op in conn_graph.get_all_ops().keys() if 'Tuple' in op
            ]))
Example 19
def get_ops_with_missing_modules(
        model: torch.nn.Module, model_input: Union[torch.Tensor,
                                                   Tuple]) -> List[str]:
    """
    Utility function to ensure that all connected graph ops of a certain type have associated modules
    :param model: Pytorch model to create connected graph from
    :param model_input: Example input to model.  Can be a single tensor or a list/tuple of input tensors
    :return: List of op names with missing modules
    """
    try:
        conn_graph = ConnectedGraph(model, model_input)
    except Exception:
        logger.error(
            'A connected graph failed to be built. This may prevent AIMET features from being able to '
            'run on the model. Please address the errors shown.')
        raise AssertionError

    missing_modules = []
    for op_name, op in conn_graph.get_all_ops().items():
        if not op.get_module() and op.type not in ConnectedGraph.functional_ops:
            missing_modules.append(op_name)

    return missing_modules
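
A hedged usage sketch; a non-empty result typically means some ops in the forward pass have no corresponding nn.Module (the model is a placeholder):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 4, kernel_size=3), torch.nn.ReLU())
model.eval()
missing = get_ops_with_missing_modules(model, torch.rand(1, 3, 16, 16))
for op_name in missing:
    print('Op with no associated module:', op_name)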
Example 20
 def test_multi_output_model(self):
     """ Test multi-output model with Tuple Tensor as intermediate  output. """
     model = test_models.MultiOutputModel()
     inp_data = torch.rand(1, 3, 8, 8)
     conn_graph = ConnectedGraph(model, (inp_data, ))
     self.assertEqual(7, len(conn_graph.ordered_ops))
     self.assertEqual(
         6,
         len([
             op for op in conn_graph.get_all_ops().keys()
             if 'convolution' in op
         ]))
     self.assertEqual(
         0,
         len([
             op for op in conn_graph.get_all_ops().keys() if 'Tuple' in op
         ]))
     self.assertEqual(
         0,
         len([
             product for product in conn_graph.get_all_products().keys()
             if 'Tuple' in product
         ]))
     self.assertEqual('cat', conn_graph.ordered_ops[-1].type)
Example 22
def find_all_conv_bn_with_activation(model: torch.nn.Module,
                                     input_shape: Tuple) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer
    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """

    activation_types = ['relu', 'hardtanh']

    # initialize all patterns to be matched and associated call back functions
    patterns_with_callbacks = []
    layer_select_handler = ConvBnPatternHandler()
    patterns_with_callbacks.append(
        PatternType(pattern=['batch_norm', 'convolution'],
                    action=layer_select_handler))

    patterns_with_callbacks.append(
        PatternType(pattern=['convolution'], action=layer_select_handler))

    patterns_with_callbacks.append(
        PatternType(pattern=['addmm'], action=layer_select_handler))

    for activation in activation_types:
        patterns_with_callbacks.append(
            PatternType(pattern=['batch_norm', activation, 'convolution'],
                        action=layer_select_handler))

    device = utils.get_device(model)
    connected_graph = ConnectedGraph(model,
                                     (torch.rand(input_shape).to(device), ))

    # create graph searcher instance with connected graph and patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict()

    return convs_bn_activation_dict
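
A usage sketch for this function (the model is a placeholder; the returned dict maps each conv/linear layer to the batch-norm and activation info collected by ConvBnPatternHandler):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU())
model.eval()
conv_bn_dict = find_all_conv_bn_with_activation(model, input_shape=(1, 3, 16, 16))
for layer, bn_info in conv_bn_dict.items():
    print(type(layer).__name__, bn_info)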
Example 23
def find_all_conv_bn_with_activation(model: torch.nn.Module,
                                     input_shape: Tuple) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer
    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """

    # initialize all patterns to be matched and associated call back functions
    patterns_with_callbacks = []
    layer_select_handler = ConvBnPatternHandler()

    patterns_with_callbacks.append(
        PatternType(pattern=['batch_norm', 'convolution'],
                    action=layer_select_handler))
    patterns_with_callbacks.append(
        PatternType(pattern=['convolution', 'batch_norm'],
                    action=layer_select_handler))
    linear_types = ['addmm', 'matmul']
    for linear_type in linear_types:
        patterns_with_callbacks.append(
            PatternType(pattern=['batch_norm', linear_type],
                        action=layer_select_handler))
        patterns_with_callbacks.append(
            PatternType(pattern=[linear_type, 'batch_norm'],
                        action=layer_select_handler))

    inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shape)
    connected_graph = ConnectedGraph(model, inp_tensor_list)

    # create graph searcher instance with connected graph and patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict()

    return convs_bn_activation_dict
Example 24
    def test_multi_output_with_shuffled_layers(self):
        """ Test a multiple layer multi-output model with intermediate Tuple Tensors shuffled """
        class MultiOutputShuffledModel(torch.nn.Module):
            """
            Model with Tuple of Tensors as output shuffled between layers
            """
            def __init__(self):
                super(MultiOutputShuffledModel, self).__init__()
                self.layer1 = test_models.ConfigurableTupleOutputModel(
                    channels=(1, 2, 3))
                self.layer2 = test_models.ConfigurableTupleOutputModel(
                    channels=(2, 3, 1))
                self.layer3 = test_models.ConfigurableTupleOutputModel(
                    channels=(3, 1, 2))

            def forward(self, *inputs):
                x1, x2, x3 = self.layer1(inputs[0], inputs[1], inputs[2])
                y2, y3, y1 = self.layer2(x2, x3, x1)
                z3, z1, z2 = self.layer3(y3, y1, y2)
                return torch.cat([z1, z2, z3, x1], 1)

        model = MultiOutputShuffledModel()
        inp_tensor_list = create_rand_tensors_given_shapes([(1, 1, 8, 8),
                                                            (1, 2, 8, 8),
                                                            (1, 3, 8, 8)])
        conn_graph = ConnectedGraph(model, inp_tensor_list)
        self.assertEqual(10, len(conn_graph.ordered_ops))
        self.assertEqual(
            9,
            len([
                op for op in conn_graph.get_all_ops().keys()
                if 'convolution' in op
            ]))
        self.assertEqual(
            0,
            len([
                op for op in conn_graph.get_all_ops().keys() if 'Tuple' in op
            ]))
        self.assertEqual('cat', conn_graph.ordered_ops[-1].type)

        product_names = conn_graph.get_all_products().keys()
        self.assertEqual(
            0,
            len([product for product in product_names if 'Tuple' in product]))

        expected_products = [
            # TODO fix order of products

            # layer #1 to layer #2
            'convolution_0__to__Split_0',
            'convolution_1_to_convolution_3',
            'convolution_2_to_convolution_4',

            # layer #2 to layer #3
            'convolution_3_to_convolution_8',
            'convolution_4_to_convolution_6',
            'convolution_5_to_convolution_7',

            # layer #3, layer#1.conv1 to cat
            'convolution_6_to_cat_9',
            'convolution_7_to_cat_9',
            'convolution_8_to_cat_9'
        ]

        products = conn_graph.get_all_products()
        for product_name in product_names:
            if product_name in expected_products:
                product = products[product_name]
                self.assertEqual(product.shape, product.producer.output_shape)
                expected_products.remove(product_name)
        self.assertEqual(0, len(expected_products))
        split_product = conn_graph.get_all_products()['Split_0__to__multiple_ops']
        self.assertTrue(conn_graph.get_all_ops()['convolution_5'] in
                        split_product.consumers)
        self.assertTrue(
            conn_graph.get_all_ops()['cat_9'] in split_product.consumers)
class MaskPropagationWinnower(AimetCommonMaskPropagationWinnower):
    """ The MaskPropagationWinnower class implements winnowing based on propagating masks corresponding to each
    module's input channels identified to be winnowed.  """
    def __init__(self,
                 model: torch.nn.Module,
                 input_shape: Tuple,
                 list_of_modules_to_winnow: List[Tuple[torch.nn.Module,
                                                       List]] = None,
                 reshape=True,
                 in_place=False,
                 verbose=False):
        """
        MaskPropagationWinnower object initialization.
        :param model: The model to be winnowed.
        :param input_shape: The input shape of the model.
        :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
        channels to be winnowed for that module.
        :param reshape: If set to True, a Down Sample Layer is added between modules to match the number of channels.
                    If set to False, the modules that need a Down Sample Layer will not be winnowed.
        :param in_place: If set to True, the model will be winnowed in place.
                     If set to False, a copy of the model will be winnowed.
        :param verbose: If set to True, logs detailed winnowing log messages.
        """

        super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
        model.apply(has_hooks)

        debug_level = logger.getEffectiveLevel()
        logger.debug("Current log level: %s", debug_level)

        self._using_cuda = next(model.parameters()).is_cuda

        if self._in_place is False:
            # Do not winnow the model in place
            self._model = copy.deepcopy(model)
            logger.info("A copy of the model will be winnowed")
        else:
            # Winnow the model in place
            logger.info("Model will be winnowed in place")
            self._model = model

        # Construct connected graph representation of the computational graph
        dummy_input = torch.rand(input_shape)
        if self._using_cuda:
            dummy_input = dummy_input.cuda()

        self._graph = ConnectedGraph(self._model, (dummy_input, ))
        self.list_of_modules_to_winnow_with_names = \
            generate_and_add_module_winnow_list_with_names(model, self._list_of_modules_to_winnow)
        self._mask_propagator = MaskPropagator(self._graph, ModelApi.pytorch)
        self._module_reducer = ModuleReducer(
            self._model, self._using_cuda, self._reshape,
            self._mask_propagator.op_to_mask_dict)

    def propagate_masks_and_winnow(self):
        """  For the modules to be winnowed, create and propagate the masks.
        Once mask propagation is completed, winnow the model. """

        # Propagate the masks
        self._propagate_masks()

        modified_op_list = self._mask_propagator.get_ops_with_non_default_ip_op_masks()
        for name in modified_op_list:
            logger.info("Modified Op: %s", name)

        modified_modules_dict = self._module_reducer.reduce_modules(
            modified_op_list)

        if modified_modules_dict:
            ordered_module_list = self._create_modified_modules_list(
                modified_modules_dict)
        else:
            ordered_module_list = None
            logger.info(
                "No modules were winnowed. Original model is returned.")

        return self._model, ordered_module_list

    def _propagate_masks(self):
        """  For the modules to be winnowed, set the channels to winnow and propagate the masks."""
        for module, list_of_channels_to_winnow, name in self.list_of_modules_to_winnow_with_names:
            self.validate_winnow_api_parameters(module, name,
                                                list_of_channels_to_winnow)

            input_channels_to_winnow = list_of_channels_to_winnow
            output_channels_to_winnow = None
            if isinstance(module,
                          (torch.nn.Linear, torch.nn.modules.conv.Conv2d)):
                self._mask_propagator.update_channels_to_winnow(
                    name, self._reshape, input_channels_to_winnow,
                    output_channels_to_winnow)

        # The channels to winnow have been updated
        # Propagate the masks.
        self._mask_propagator.propagate_masks()

    @staticmethod
    def _create_modified_modules_list(modified_modules: Dict[str,
                                                             torch.nn.Module]):
        """ Creates and returns a list of tuples with each tuple containing
        the original module and its replacement module
        :param modified_modules: dictionary of modules modified during module reduction
        :return list of tuples of name of the original module in the model and corresponding new module
        """

        modified_module_list = []
        for orig_module_name, new_module in modified_modules.items():
            # Remove prefix of the model name
            # E.g. the module_name maybe Net.layer1.conv1, we only want layer1.conv1
            first_dot_position = orig_module_name.find('.')
            if first_dot_position != -1:
                orig_module_name = orig_module_name[first_dot_position + 1:]
            modified_module_list.append((orig_module_name, new_module))

        return modified_module_list
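
        # Illustration (hypothetical names): given
        #     {'Net.layer1.conv1': conv_a, 'Net.fc': linear_b}
        # this method returns [('layer1.conv1', conv_a), ('fc', linear_b)]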

    def validate_winnow_api_parameters(self, module, name,
                                       list_of_channels_to_winnow):
        """
        For a given module, validate Winnow API parameters.
        :param module: module whose channel numbers are being validated.
        :param name: module's name
        :param list_of_channels_to_winnow: list of channels that must be winnowed.
        """

        if not isinstance(module, torch.nn.Conv2d):
            logger.critical(
                "Winnowing is currently only supported for torch.nn.Conv2d modules. Attempting to winnow "
                "module of type %s", type(module))
            raise NotImplementedError(type(module))

        # Validate the list of channels.
        num_channels_to_winnow = len(list_of_channels_to_winnow)
        if num_channels_to_winnow == 0:
            raise ValueError(
                "The list of channels to winnow is empty for the module: %s" %
                name)

        max_channel_num = max(list_of_channels_to_winnow)
        max_in_channel_index = (module.in_channels - 1)
        if max_channel_num > max_in_channel_index:
            raise ValueError(
                "Channel number: %s exceeds module's max channel number index: %s for module: %s"
                % (max_channel_num, max_in_channel_index, name))

        if num_channels_to_winnow == module.in_channels:
            raise ValueError(
                "Winnowing all the input channels is not allowed, module: %s" %
                name)

        module_op = self._graph.get_op_from_module_name(name)
        input_index = 0  # Using op index 0 to examine input to op
        if module_op.inputs[input_index].is_model_input:
            logger.critical(
                "Winnowing the first module of a model is NOT supported. Please ignore the first "
                "module and try again. First module: %s, shape %s, channels to winnow: %s",
                module_op.dotted_name, module_op.inputs[input_index].shape,
                list_of_channels_to_winnow)
            raise NotImplementedError(module_op.dotted_name)
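
A hedged end-to-end sketch of the winnowing flow above (model and channel choices are illustrative; per validate_winnow_api_parameters, only torch.nn.Conv2d modules other than the model's first module can be winnowed):

import torch

class TwoConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3)

    def forward(self, x):
        return self.conv2(self.conv1(x))

model = TwoConvModel()
model.eval()
# Winnow input channels 0 and 1 of conv2 (winnowing the first module is unsupported)
winnower = MaskPropagationWinnower(model, (1, 3, 32, 32),
                                   [(model.conv2, [0, 1])],
                                   reshape=True, in_place=False, verbose=False)
winnowed_model, modified_modules = winnower.propagate_masks_and_winnow()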
class GraphSearchUtils:
    """
    Code to search a model graph to find nodes to use for cross-layer-scaling and high-bias-fold
    """
    def __init__(self, model: torch.nn.Module,
                 input_shapes: Union[Tuple, List[Tuple]]):
        inp_tensor_list = tuple(
            utils.create_rand_tensors_given_shapes(input_shapes))
        self._connected_graph = ConnectedGraph(model, inp_tensor_list)
        self._ordered_module_list = utils.get_ordered_list_of_conv_modules(
            model, inp_tensor_list)

    @staticmethod
    def find_downstream_layer_groups_to_scale(op,
                                              layer_groups,
                                              current_group=None,
                                              visited_nodes=None):
        """
        Recursive function to find cls layer groups downstream from a given op
        :param op: Starting op to search from
        :param layer_groups: Running list of layer groups
        :param current_group: Running current layer group
        :param visited_nodes: Running list of visited nodes (to short-circuit recursion)
        :return: None
        """

        if not visited_nodes:
            visited_nodes = []
        if not current_group:
            current_group = []

        if op in visited_nodes:
            return
        visited_nodes.append(op)
        # print("Visiting node: {}".format(op.dotted_name))

        # If current node is Conv2D, add to the current group
        if op.model_module and isinstance(
                op.model_module.get_module(),
            (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            current_group.append(op.model_module.get_module())

        # Terminating condition for current group
        if not op.model_module or not isinstance(
                op.model_module.get_module(),
            (torch.nn.Conv2d, torch.nn.ReLU, torch.nn.PReLU,
             torch.nn.ConvTranspose2d)):
            if (len(current_group) > 1) and (current_group
                                             not in layer_groups):
                layer_groups.append(current_group)
            current_group = []

        if op.output:
            for consumer in op.output.consumers:
                GraphSearchUtils.find_downstream_layer_groups_to_scale(
                    consumer, layer_groups, current_group, visited_nodes)

        # Reached a leaf. See if the current group has something to grab
        if (len(current_group) > 1) and (current_group not in layer_groups):
            layer_groups.append(current_group)

    @staticmethod
    def convert_layer_group_to_cls_sets(layer_group):
        """
        Helper function to convert a layer group to a list of cls sets
        :param layer_group: Given layer group to convert
        :return: List of cls sets
        """
        cls_sets = []

        prev_layer_to_scale = layer_group.pop(0)
        while layer_group:
            next_layer_to_scale = layer_group.pop(0)

            if next_layer_to_scale.groups > 1:

                if layer_group:
                    next_non_depthwise_conv_layer = layer_group.pop(0)
                    cls_sets.append((prev_layer_to_scale, next_layer_to_scale,
                                     next_non_depthwise_conv_layer))
                    prev_layer_to_scale = next_non_depthwise_conv_layer

            else:
                cls_sets.append((prev_layer_to_scale, next_layer_to_scale))
                prev_layer_to_scale = next_layer_to_scale

        return cls_sets
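
        # Worked example (illustrative layers): with dw = Conv2d(8, 8, 3, groups=8)
        # (depthwise) and c1, c2, c3 plain convs (groups=1),
        #     convert_layer_group_to_cls_sets([c1, dw, c2]) -> [(c1, dw, c2)]
        #     convert_layer_group_to_cls_sets([c2, c3])     -> [(c2, c3)]
        # i.e. a depthwise conv folds into a 3-tuple with its neighbours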

    def find_layer_groups_to_scale(self) -> List[List[torch.nn.Conv2d]]:
        """
        :return: List of groups of layers. Each group can be independently equalized
        """

        # Find the input node(s) in the graph
        input_nodes = []
        for op in self._connected_graph.get_all_ops().values():
            if op.inputs and op.inputs[0].is_model_input:
                input_nodes.append(op)

        layer_groups = []
        for op in input_nodes:
            self.find_downstream_layer_groups_to_scale(op, layer_groups)

        # Sort the layer groups in order of occurrence in the model
        ordered_layer_groups = []
        for _, module in self._ordered_module_list:
            for layer_group in layer_groups:
                if layer_group[0] is module:
                    ordered_layer_groups.append(layer_group)

        return ordered_layer_groups

    @staticmethod
    def does_module_have_relu_activation(connected_graph: ConnectedGraph,
                                         module: torch.nn.Module) -> bool:
        """
        Finds if a given module has a ReLU activation
        :param connected_graph: Reference to ConnectedGraph instance
        :param module: PyTorch module to find activation for
        :return: True if module has a relu activation
        """

        for op in connected_graph.get_all_ops().values():

            if op.model_module and op.model_module.get_module() is module:
                assert len(op.output.consumers) == 1
                is_relu_activation = isinstance(
                    op.output.consumers[0].model_module.get_module(),
                    (torch.nn.ReLU, torch.nn.PReLU))
                return is_relu_activation

        return False

    def is_relu_activation_present_in_cls_sets(self, cls_sets: List[ClsSet]):
        """
        :param cls_sets: CLS sets to find relu activations in
        :return: List of groups of layers. Each group can be independently equalized
        """

        is_relu_activation_in_cls_sets = []
        for cls_set in cls_sets:

            # We need to check activation functions for all layers but the last one in the set
            # Because we are only interested in checking activation functions between the layers we will scale
            cls_set = cls_set[:-1]

            is_relu_activation_in_cls_set = ()
            for module in cls_set:
                is_relu_activation_in_cls_set += (
                    self.does_module_have_relu_activation(
                        self._connected_graph, module), )

            if len(is_relu_activation_in_cls_set) == 1:
                is_relu_activation_in_cls_set = is_relu_activation_in_cls_set[0]

            is_relu_activation_in_cls_sets.append(
                is_relu_activation_in_cls_set)

        return is_relu_activation_in_cls_sets
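
Putting the pieces together, a sketch of the cross-layer-scaling discovery flow (the model is a placeholder; convert_layer_group_to_cls_sets consumes its argument, so a copy is passed):

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3))
model.eval()
search_utils = GraphSearchUtils(model, (1, 3, 32, 32))
for layer_group in search_utils.find_layer_groups_to_scale():
    cls_sets = GraphSearchUtils.convert_layer_group_to_cls_sets(list(layer_group))
    relu_flags = search_utils.is_relu_activation_present_in_cls_sets(cls_sets)
    print(cls_sets, relu_flags)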