Code example #1
File: test_utils.py Project: Rohan-Chaudhury/aimet
    def test_create_rand_tensors_given_shapes(self):
        shape_1 = (1, 32)
        shape_2 = (3, 3)
        rand_tensors = utils.create_rand_tensors_given_shapes(
            [shape_1, shape_2], device=torch.device('cpu'))
        self.assertEqual(2, len(rand_tensors))
        self.assertEqual(shape_1, rand_tensors[0].shape)
        self.assertEqual(shape_2, rand_tensors[1].shape)
        self.assertEqual(torch.device('cpu'), rand_tensors[0].device)

        rand_tensors = utils.create_rand_tensors_given_shapes(
            [shape_1, shape_2], device=torch.device('cuda:0'))
        self.assertEqual(torch.device('cuda:0'), rand_tensors[0].device)
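
For reference, the helper exercised above can be approximated by the minimal sketch below. This is an assumption reconstructed from the test's behavior, not the actual aimet implementation:

import torch

def create_rand_tensors_given_shapes(input_shape, device=torch.device('cpu')):
    # A single shape tuple is treated as a one-element list of shapes (assumed behavior)
    shapes = input_shape if isinstance(input_shape, list) else [input_shape]
    # One random tensor per requested shape, placed on the requested device
    return [torch.rand(shape, device=device) for shape in shapes]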
Code example #2
 def __init__(self, model: torch.nn.Module,
              input_shapes: Union[Tuple, List[Tuple]]):
     inp_tensor_list = tuple(
         utils.create_rand_tensors_given_shapes(input_shapes))
     self._connected_graph = ConnectedGraph(model, inp_tensor_list)
     self._ordered_module_list = utils.get_ordered_list_of_conv_modules(
         model, inp_tensor_list)
Code example #3
    def test_multi_input(self):
        """ Test building ConnectedGraph on a model with multiple inputs """
        # pylint: disable=protected-access
        model = test_models.MultiInput()
        model.eval()
        inp_shape_1 = (1, 3, 32, 32)
        inp_shape_2 = (1, 3, 20, 20)
        inp_tensor_list = create_rand_tensors_given_shapes(
            [inp_shape_1, inp_shape_2])
        conn_graph = ConnectedGraph(model, inp_tensor_list)
        self.assertEqual(11, len(conn_graph.ordered_ops))
        # Split count of 1 due to reshape having a split
        self.assertEqual(1, conn_graph._split_count)
        conv1 = conn_graph.get_op_from_module_name('MultiInput.conv1')
        self.assertEqual(model.conv1, conv1.get_module())
        self.assertEqual(2, len(conv1.inputs))
        conv2 = conn_graph.get_op_from_module_name('MultiInput.conv2')
        self.assertEqual(model.conv2, conv2.get_module())
        self.assertEqual(3, len(conv2.inputs))
        conv3 = conn_graph.get_op_from_module_name('MultiInput.conv3')
        self.assertEqual(model.conv3, conv3.get_module())
        self.assertEqual(3, len(conv3.inputs))

        input_ops = get_all_input_ops(conn_graph)
        input_modules = [op.get_module() for op in input_ops]
        self.assertEqual(2, len(input_ops))
        self.assertTrue(model.conv1 in input_modules)
        self.assertTrue(model.conv3 in input_modules)
        output_ops = get_all_output_ops(conn_graph)
        self.assertEqual(1, len(output_ops))
        self.assertEqual(model.fc, output_ops[0].get_module())
Code example #4
 def test_passthrough_op_last_module(self):
     """ Test building a connected graph on a model where a PassThroughOp is the last module in the graph. """
     model = test_models.PassThroughOpLastLayerModel()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(1, len(conn_graph.ordered_ops))
Code example #5
def create_connected_graph_with_input_shapes(model: torch.nn.Module, input_shapes: Union[Tuple, List[Tuple]]) \
        -> ConnectedGraph:
    """
    Create connected graph, using random inputs generated from given input shapes.
    :param model: torch model to create a connected graph from
    :param input_shapes: input shapes to the torch model
    :return: ConnectedGraph representation of the model
    """
    random_inputs = create_rand_tensors_given_shapes(input_shapes)
    device = get_device(model)
    random_inputs = tuple([inp.to(device) for inp in random_inputs])
    return ConnectedGraph(model, random_inputs)
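
A short usage sketch for this helper, reusing the MultiInput test model from code example #3 (the model choice and shapes are assumptions matching that test):

model = test_models.MultiInput()
model.eval()
conn_graph = create_connected_graph_with_input_shapes(
    model, [(1, 3, 32, 32), (1, 3, 20, 20)])
print(len(conn_graph.ordered_ops))  # 11 for this model, per code example #3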
Code example #6
 def test_concat(self):
     """ Test building ConnectedGraph on a model with concat """
     model = test_models.ConcatModel()
     model.eval()
     inp_shape_1 = (1, 3, 8, 8)
     inp_shape_2 = (1, 3, 8, 8)
     inp_shape_3 = (1, 3, 8, 8)
     inp_tensor_list = create_rand_tensors_given_shapes(
         [inp_shape_1, inp_shape_2, inp_shape_3])
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     concat_op = conn_graph.get_all_ops()['cat_3']
     self.assertEqual(3, len(concat_op.inputs))
     self.assertEqual(14, concat_op.output_shape[1])
Code example #7
    def test_get_all_ops_in_neighborhood(self):
        """ Test that default quantization parameters are set correctly when using json config file """
        model = SingleResidual()
        model.eval()
        input_shapes = (1, 3, 32, 32)

        random_inputs = utils.create_rand_tensors_given_shapes(input_shapes)
        conn_graph = ConnectedGraph(model, random_inputs)
        starting_op = conn_graph.get_all_ops()['convolution_7']
        add_10_op = conn_graph.get_all_ops()['add_10']
        adaptive_avg_pool2d_9_op = conn_graph.get_all_ops()['adaptive_avg_pool2d_9']
        neighborhood = _get_all_ops_in_neighborhood(starting_op, 'output')
        assert len(neighborhood) == 3
        assert starting_op in neighborhood
        assert add_10_op in neighborhood
        assert adaptive_avg_pool2d_9_op in neighborhood
Code example #8
 def test_dropouts(self):
     """ Test building ConnectedGraph on a model with dropouts """
     # pylint: disable=protected-access
     model = test_models.ModelWithDropouts()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(9, len(conn_graph.ordered_ops))
     # Split count of 1 due to reshape having a split
     self.assertEqual(1, conn_graph._split_count)
     # All ops will include 1 inserted split op
     self.assertEqual(10, len(conn_graph.get_all_ops().keys()))
     dropout_1_op = conn_graph.get_all_ops()['dropout_3']
     dropout_2_op = conn_graph.get_all_ops()['feature_dropout_4']
     self.assertEqual(model.dropout1, dropout_1_op.get_module())
     self.assertEqual(model.dropout2, dropout_2_op.get_module())
Code example #9
    def compute_and_save_weight_encodings(self, path: str,
                                          filename_prefix: str,
                                          input_shape: Union[Tuple,
                                                             List[Tuple]]):
        """
        Save the quantized model weight encodings

        :param path: path where to store model pth and encodings
        :param filename_prefix: filename to store exported weight encodings in json format
        :param input_shape: shape of the input parameter to the model
        :return: None
        """

        device = utils.get_device(self._model)
        self._model.cpu()
        inputs = utils.create_rand_tensors_given_shapes(input_shape)

        # compute weight encodings
        weight_encoding_dict = {}
        weight_encoding_dict_with_onnx_names = {}
        quantized_layers = self.__get_qc_quantized_layers(self._model)
        pytorch_onnx_names_dict = su.SaveUtils.get_name_of_op_from_graph(
            self._model, *inputs)
        for layer_name, layer in quantized_layers:
            if isinstance(layer, QcQuantizeWrapper):
                layer_wt_encoding = layer.compute_weight_encodings()
                # skip the dictionary update when a layer has no weight encoding
                if layer_wt_encoding is not None:
                    value = (layer_wt_encoding.max, layer_wt_encoding.min,
                             layer_wt_encoding.delta, layer_wt_encoding.offset,
                             layer_wt_encoding.bw)
                    weight_encoding_dict[layer_name] = value
                    if layer_name in pytorch_onnx_names_dict:
                        weight_encoding_dict_with_onnx_names[
                            pytorch_onnx_names_dict[layer_name]] = value
        # export weight encodings to output json file
        su.SaveUtils.save_weight_encodings_to_files(
            path=path,
            filename_prefix=filename_prefix,
            weight_encoding_dict=weight_encoding_dict,
            weight_encoding_dict_with_onnx_names=
            weight_encoding_dict_with_onnx_names)

        self._model.to(device)
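
Assuming this method is exposed on a quantization-simulation object (the sim instance and input shape below are hypothetical), a call might look like:

sim.compute_and_save_weight_encodings(path='./output',
                                      filename_prefix='model_weights',
                                      input_shape=(1, 3, 224, 224))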
Code example #10
File: test_utils.py Project: Rohan-Chaudhury/aimet
    def _collect_inp_out_data_multi_input(self, device):
        model = MultiInput().to(device=device)
        inp_shape_1 = (1, 3, 32, 32)
        inp_shape_2 = (1, 3, 20, 20)
        model_input = utils.create_rand_tensors_given_shapes(
            [inp_shape_1, inp_shape_2])

        module_data = utils.ModuleData(model, model.conv1)
        inp, out = module_data.collect_inp_out_data(model_input,
                                                    collect_input=True,
                                                    collect_output=False)
        self.assertTrue(
            np.array_equal(utils.to_numpy(inp),
                           utils.to_numpy(model_input[0])))
        self.assertEqual(out, None)

        module_data = utils.ModuleData(model, model.conv1)
        inp, out = module_data.collect_inp_out_data(model_input,
                                                    collect_input=False,
                                                    collect_output=True)
        conv1_out = model.conv1(model_input[0])
        self.assertTrue(
            np.array_equal(utils.to_numpy(out), utils.to_numpy(conv1_out)))
        self.assertEqual(inp, None)

        module_data = utils.ModuleData(model, model.conv3)
        inp, out = module_data.collect_inp_out_data(model_input,
                                                    collect_input=True,
                                                    collect_output=True)
        conv3_out = model.conv3(model_input[1])
        self.assertTrue(
            np.array_equal(utils.to_numpy(out), utils.to_numpy(conv3_out)))
        self.assertTrue(
            np.array_equal(utils.to_numpy(inp),
                           utils.to_numpy(model_input[1])))

        module_data = utils.ModuleData(model, model.fc)
        inp, out = module_data.collect_inp_out_data(model_input,
                                                    collect_input=False,
                                                    collect_output=True)
        fc_out = model(*model_input)
        self.assertTrue(
            np.array_equal(utils.to_numpy(out), utils.to_numpy(fc_out)))
        self.assertEqual(inp, None)
Code example #11
 def test_single_residual(self):
     """ Test building ConnectedGraph on single residual model """
     # pylint: disable=protected-access
     model = test_models.SingleResidual()
     model.eval()
     inp_shape = (1, 3, 32, 32)
     inp_tensor_list = create_rand_tensors_given_shapes(inp_shape)
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(17, len(conn_graph.ordered_ops))
     # Split count of 2 due to residual as well as reshape having a split
     self.assertEqual(2, conn_graph._split_count)
     # All ops will include 2 inserted split ops
     self.assertEqual(19, len(conn_graph.get_all_ops().keys()))
     input_ops = get_all_input_ops(conn_graph)
     self.assertEqual(1, len(input_ops))
     self.assertEqual(model.conv1, input_ops[0].get_module())
     output_ops = get_all_output_ops(conn_graph)
     self.assertEqual(1, len(output_ops))
     self.assertEqual(model.fc, output_ops[0].get_module())
Code example #12
def find_all_conv_bn_with_activation(model: torch.nn.Module,
                                     input_shape: Tuple) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer
    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """

    # initialize all patterns to be matched and associated call back functions
    patterns_with_callbacks = []
    layer_select_handler = ConvBnPatternHandler()

    patterns_with_callbacks.append(
        PatternType(pattern=['batch_norm', 'convolution'],
                    action=layer_select_handler))
    patterns_with_callbacks.append(
        PatternType(pattern=['convolution', 'batch_norm'],
                    action=layer_select_handler))
    linear_types = ['addmm', 'matmul']
    for linear_type in linear_types:
        patterns_with_callbacks.append(
            PatternType(pattern=['batch_norm', linear_type],
                        action=layer_select_handler))
        patterns_with_callbacks.append(
            PatternType(pattern=[linear_type, 'batch_norm'],
                        action=layer_select_handler))

    inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shape)
    connected_graph = ConnectedGraph(model, inp_tensor_list)

    # create graph searcher instance with connected graph and patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict()

    return convs_bn_activation_dict
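
A usage sketch, borrowing the SingleResidual test model from code example #7 (model choice and shape are assumptions):

model = SingleResidual()
model.eval()
conv_bn_dict = find_all_conv_bn_with_activation(model, input_shape=(1, 3, 32, 32))
for conv_or_linear, bn_info in conv_bn_dict.items():
    # bn_info carries the batch-norm and activation info associated with each conv/linear layer
    print(type(conv_or_linear).__name__, bn_info)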
Code example #13
 def test_hierarchial_model(self):
     """ Test building ConnectedGraph on model which multi-level aggregation of nn.Modules  """
     # pylint: disable=protected-access
     model = test_models.HierarchicalModel()
     model.eval()
     conv_shape = (1, 64, 32, 32)
     inp_shape = (1, 3, 32, 32)
     seq_shape = (1, 3, 8, 8)
     inp_tensor_list = create_rand_tensors_given_shapes(
         [conv_shape, inp_shape, conv_shape, inp_shape, seq_shape])
     conn_graph = ConnectedGraph(model, inp_tensor_list)
     self.assertEqual(95, len(conn_graph.ordered_ops))
     self.assertEqual(5, conn_graph._split_count)
     self.assertEqual(
         conn_graph.get_op_from_module_name('HierarchicalModel.conv1.conv'),
         conn_graph.ordered_ops[0])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm1.tm1.conv1'), conn_graph.ordered_ops[5])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm1.tm2.conv1'), conn_graph.ordered_ops[20])
     self.assertEqual(
         conn_graph.get_op_from_module_name('HierarchicalModel.conv2.conv'),
         conn_graph.ordered_ops[36])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.multi_conv.seq_list.0.conv'),
         conn_graph.ordered_ops[40])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm2.tm1.conv1'), conn_graph.ordered_ops[53])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.nm2.tm2.conv1'), conn_graph.ordered_ops[68])
     self.assertEqual(
         conn_graph.get_op_from_module_name(
             'HierarchicalModel.sq.seq_list.0'), conn_graph.ordered_ops[84])
Code example #14
def correct_bias(model: torch.nn.Module,
                 quant_params: qsim.QuantParams,
                 num_quant_samples: int,
                 data_loader,
                 num_bias_correct_samples: int,
                 conv_bn_dict: Union[Dict[torch.nn.Module, ConvBnInfoType],
                                     None] = None,
                 perform_only_empirical_bias_corr: bool = True,
                 layers_to_ignore: List[torch.nn.Module] = None):
    """
    Corrects bias for each Conv layer of the model (unless ignored). A combination of Analytical and Empirical Bias
    Correction is used, i.e. all layers that can be corrected using Analytical Bias Correction are corrected that
    way, and the remaining layers are corrected using the Empirical method.

    Returns an in-place corrected floating point model

    :param model: Model to be corrected
    :param quant_params: Named tuple for quantization simulation for bias correction
    :param num_quant_samples: number of samples of images to pass through quantization sim for bias correction.
    :param data_loader: data loader for the model
    :param num_bias_correct_samples: number of samples for Bias correction
    :param conv_bn_dict: Dict of conv and bn with information related to activation. If None, the function computes it
    :param perform_only_empirical_bias_corr: Default True. If True, performs only Empirical Bias Correction for all
           layers, regardless of whether a layer is eligible for Analytical Bias Correction.
    :param layers_to_ignore: list of layer names for which we need to skip bias correction.

    """

    if layers_to_ignore is None:
        layers_to_ignore = []

    # Find batch size and shape of input tensor
    batch_size, input_shape = utils.get_input_shape_batch_size(data_loader)

    # Rounding up number of samples to batch size
    n_batches_bias_correction = int(
        np.ceil(num_bias_correct_samples / batch_size))
    n_batches_quantization = int(np.ceil(num_quant_samples / batch_size))

    data_loader_n_samples_bias_corr = utils.IterFirstX(
        data_loader, n_batches_bias_correction)
    data_loader_n_samples_quant = utils.IterFirstX(data_loader,
                                                   n_batches_quantization)

    # TODO: Remove wrapper function
    # Create a wrapping function for data loader for quantization
    def pass_data_through_model(model,
                                early_stopping_iterations=None,
                                use_cuda=False):
        # pylint: disable=unused-argument
        # forward pass for given number of batches for model
        for (images_in_one_batch, _) in data_loader_n_samples_quant:
            forward_pass(model, images_in_one_batch)

    ordered_conv_linear_nodes = get_ordered_lists_of_conv_fc(
        model, input_shape)

    if conv_bn_dict is None:
        conv_bn_dict = find_all_conv_bn_with_activation(model, input_shape)

    # Create a copy of the model as reference model
    model_copy = copy.deepcopy(model)

    # Add bias for all the layers whose bias is None
    for name, module in ordered_conv_linear_nodes:
        if module.bias is None:
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                output_size = module.out_channels
            elif isinstance(module, torch.nn.Linear):
                output_size = module.out_features
            module.bias = torch.nn.Parameter(torch.zeros(output_size))
            module.bias.data = module.bias.data.to(device=module.weight.device)

    # Quantize full model
    dummy_tensors = utils.create_rand_tensors_given_shapes(input_shape)
    dummy_tensors = [
        tensor.to(utils.get_device(model)) for tensor in dummy_tensors
    ]
    q = qsim.QuantizationSimModel(model=model,
                                  quant_scheme=quant_params.quant_scheme,
                                  rounding_mode=quant_params.round_mode,
                                  default_output_bw=quant_params.act_bw,
                                  default_param_bw=quant_params.weight_bw,
                                  in_place=True,
                                  dummy_input=dummy_tensors,
                                  config_file=quant_params.config_file)

    # make sure the model got updated in-place before we use it for bias-correction updates
    assert (q.model is model)

    # Skip output activation quantization for all layers (layers_to_ignore is handled in the loop below)
    for name, module in model.named_modules():
        # Disable each wrapped layer's output quantizer
        if isinstance(module, QcQuantizeWrapper):
            module.output_quantizers[0].enabled = False

    q.compute_encodings(pass_data_through_model, None)

    # For first conv layer, perform analytical bc if perform_only_empirical_bias_corr is set to False
    # and layer is not marked to be ignored during bc.
    if not perform_only_empirical_bias_corr:
        module_name, module = ordered_conv_linear_nodes[0]
        if module not in layers_to_ignore:
            logger.info('Correcting layer %s using Analytical Bias Correction',
                        module_name)
            quantize_layer = utils.get_layer_by_name(model, module_name)
            call_analytical_mo_correct_bias(quantize_layer, None, None)
            logger.info('Corrected bias for the layer')
            ordered_conv_linear_nodes.pop(0)

    for module_name, module in ordered_conv_linear_nodes:
        # Ignore all layers which are skipped by user
        if module in layers_to_ignore:
            continue
        else:
            # make sure module is in the model used by qsim.
            assert (module in list(q.model.modules()))
            # Analytical Bias Correction is only done for Conv layers
            reference_layer = utils.get_layer_by_name(model_copy, module_name)
            quantize_layer = utils.get_layer_by_name(model, module_name)

            if module in conv_bn_dict.keys():

                bn_layer_info = conv_bn_dict[module]

                if perform_only_empirical_bias_corr or bn_layer_info is None or bn_layer_info.input_bn is None:
                    logger.info(
                        'Correcting layer %s using Empirical Bias Correction',
                        module_name)
                    bias_correction = libpymo.BiasCorrection()

                    # Get output from quantized model and reference model

                    for images_in_one_batch, _ in data_loader_n_samples_bias_corr:
                        reference_output_batch = get_output_data(
                            reference_layer, model_copy, images_in_one_batch)
                        quantized_model_output_batch = get_output_data(
                            quantize_layer, model, images_in_one_batch)

                        if isinstance(reference_layer, torch.nn.Linear):
                            extended_shape = np.concatenate(
                                (reference_output_batch.shape, np.array([1,
                                                                         1])))
                            reference_output_batch = reference_output_batch.reshape(
                                extended_shape)
                            quantized_model_output_batch = quantized_model_output_batch.reshape(
                                extended_shape)

                        bias_correction.storePreActivationOutput(
                            reference_output_batch)
                        bias_correction.storeQuantizedPreActivationOutput(
                            quantized_model_output_batch)

                    call_empirical_mo_correct_bias(module, bias_correction)

                else:
                    logger.info(
                        'Correcting layer %s using Analytical Bias Correction',
                        module_name)
                    call_analytical_mo_correct_bias(
                        quantize_layer, bn_layer_info.input_bn,
                        bn_layer_info.in_activation_type)

                logger.info('Corrected bias for the layer')

    SaveUtils.remove_quantization_wrappers(model)

    logger.info('Completed bias correction')
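
A sketch of how correct_bias might be invoked. The QuantParams keyword arguments are assumptions inferred from the attributes read above (quant_scheme, round_mode, act_bw, weight_bw, config_file), and data_loader is any loader yielding (images, labels) batches:

quant_params = qsim.QuantParams(weight_bw=8, act_bw=8,
                                round_mode='nearest',
                                quant_scheme='tf_enhanced')
correct_bias(model, quant_params,
             num_quant_samples=1024,
             data_loader=data_loader,
             num_bias_correct_samples=1024,
             perform_only_empirical_bias_corr=True)
# model is now bias-corrected in place, per the docstring above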
Code example #15
    def save_encodings_to_files(self, model, path, filename_prefix, input_shape):
        """
        Save quantization encodings for the given model in json format
        :param model: Model whose encodings are to be saved
        :param path: Directory path to save
        :param filename_prefix: Filename of the file to save
        :param input_shape: shape of the input parameter to the model
        :return: None
        """
        # pylint: disable=too-many-locals
        device = utils.get_device(model)
        model.cpu()

        encodings_path_onnx_names = os.path.join(path, filename_prefix + '_onnx_names' + '.encodings')
        encodings_path_python_names = os.path.join(path, filename_prefix + '_pytorch_names' + '.encodings')

        encoding_dict_with_pytorch_names = {}
        encoding_dict_with_onnx_names = {}

        inputs = utils.create_rand_tensors_given_shapes(input_shape)

        pytorch_onnx_names_dict = self.get_name_of_op_from_graph(model, *inputs)

        for layer_name, layer in model.named_modules():

            if isinstance(layer, QcQuantizeStandalone):
                value = (layer.output_quantizers[0].encoding.max,
                         layer.output_quantizers[0].encoding.min,
                         layer.output_quantizers[0].encoding.delta,
                         layer.output_quantizers[0].encoding.offset,
                         layer.output_quantizers[0].bitwidth,  # hack - standalone layers have no parameters
                         layer.output_quantizers[0].bitwidth)
                encoding_dict_with_onnx_names[layer_name] = value
                encoding_dict_with_pytorch_names[layer_name] = value

            elif isinstance(layer, QcQuantizeWrapper):

                # This is a hack to keep this working for now. New json definitions need to be created.
                # In reality, layers may have more than one parameter, or even zero parameters;
                # this code does not handle that currently.
                if layer.param_quantizers:
                    param_bw = next(iter(layer.param_quantizers.values())).bitwidth
                else:
                    param_bw = layer.output_quantizers[0].bitwidth

                value = (layer.output_quantizers[0].encoding.max,
                         layer.output_quantizers[0].encoding.min,
                         layer.output_quantizers[0].encoding.delta,
                         layer.output_quantizers[0].encoding.offset,
                         param_bw,
                         layer.output_quantizers[0].encoding.bw)
                if layer_name in pytorch_onnx_names_dict:
                    encoding_dict_with_onnx_names[pytorch_onnx_names_dict[layer_name]] = value
                    encoding_dict_with_pytorch_names[layer_name] = value

        if not encoding_dict_with_onnx_names:
            raise RuntimeError('Could not find any QcQuantizeOps in the model for saving encodings!')

        save_json_yaml(encodings_path_onnx_names, encoding_dict_with_onnx_names)
        save_json_yaml(encodings_path_python_names, encoding_dict_with_pytorch_names)

        model.to(device)
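
Assuming this method lives on a SaveUtils-style helper (code example #9 accesses it via su.SaveUtils), usage might look like:

save_utils = SaveUtils()
save_utils.save_encodings_to_files(model,
                                   path='./output',
                                   filename_prefix='quantized_model',
                                   input_shape=(1, 3, 224, 224))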
Code example #16
    def test_multi_output_with_shuffled_layers(self):
        """ Test a multiple layer multi-output model with intermediate Tuple Tensors shuffled """
        class MultiOutputShuffledModel(torch.nn.Module):
            """
            Model with Tuple of Tensors as output shuffled between layers
            """
            def __init__(self):
                super(MultiOutputShuffledModel, self).__init__()
                self.layer1 = test_models.ConfigurableTupleOutputModel(
                    channels=(1, 2, 3))
                self.layer2 = test_models.ConfigurableTupleOutputModel(
                    channels=(2, 3, 1))
                self.layer3 = test_models.ConfigurableTupleOutputModel(
                    channels=(3, 1, 2))

            def forward(self, *inputs):
                x1, x2, x3 = self.layer1(inputs[0], inputs[1], inputs[2])
                y2, y3, y1 = self.layer2(x2, x3, x1)
                z3, z1, z2 = self.layer3(y3, y1, y2)
                return torch.cat([z1, z2, z3, x1], 1)

        model = MultiOutputShuffledModel()
        inp_tensor_list = create_rand_tensors_given_shapes([(1, 1, 8, 8),
                                                            (1, 2, 8, 8),
                                                            (1, 3, 8, 8)])
        conn_graph = ConnectedGraph(model, inp_tensor_list)
        self.assertEqual(10, len(conn_graph.ordered_ops))
        self.assertEqual(
            9,
            len([
                op for op in conn_graph.get_all_ops().keys()
                if 'convolution' in op
            ]))
        self.assertEqual(
            0,
            len([
                op for op in conn_graph.get_all_ops().keys() if 'Tuple' in op
            ]))
        self.assertEqual('cat', conn_graph.ordered_ops[-1].type)

        product_names = conn_graph.get_all_products().keys()
        self.assertEqual(
            0,
            len([product for product in product_names if 'Tuple' in product]))

        expected_products = [
            # TODO fix order of products

            # layer #1 to layer #2
            'convolution_0__to__Split_0',
            'convolution_1_to_convolution_3',
            'convolution_2_to_convolution_4',

            # layer #2 to layer #3
            'convolution_3_to_convolution_8',
            'convolution_4_to_convolution_6',
            'convolution_5_to_convolution_7',

            # layer #3, layer#1.conv1 to cat
            'convolution_6_to_cat_9',
            'convolution_7_to_cat_9',
            'convolution_8_to_cat_9'
        ]

        products = conn_graph.get_all_products()
        for product_name in product_names:
            if product_name in expected_products:
                product = products[product_name]
                self.assertEqual(product.shape, product.producer.output_shape)
                expected_products.remove(product_name)
        self.assertEqual(0, len(expected_products))
        split_product = conn_graph.get_all_products(
        )['Split_0__to__multiple_ops']
        self.assertTrue(conn_graph.get_all_ops()['convolution_5'] in
                        split_product.consumers)
        self.assertTrue(
            conn_graph.get_all_ops()['cat_9'] in split_product.consumers)