Code Example #1
    def _binarize_weights_and_module_inputs(self, target_model: NNCFNetwork) -> List[InsertionCommand]:
        device = next(target_model.parameters()).device
        modules = target_model.get_nncf_modules()

        insertion_commands = []
        for scope, module in modules.items():
            scope_str = str(scope)

            if not self._should_consider_scope(scope_str):
                nncf_logger.info("Ignored adding binarizers in scope: {}".format(scope_str))
                continue

            if isinstance(module, torch.nn.modules.Conv2d):
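                # Each Conv2d gets two pre-op hooks: a weight binarizer and an
                # activation (input) binarizer, registered at the same insertion
                # point with the same priority.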
                nncf_logger.info("Adding Weight binarizer in scope: {}".format(scope_str))
                op_weights = UpdateWeight(
                    self.__create_binarize_module()
                ).to(device)

                nncf_logger.info("Adding Activation binarizer in scope: {}".format(scope_str))
                op_inputs = UpdateInputs(ActivationBinarizationScaleThreshold(module.weight.shape)).to(device)

                insertion_commands.append(InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext("", scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), op_weights, OperationPriority.QUANTIZATION_PRIORITY))

                insertion_commands.append(InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext("", scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), op_inputs, OperationPriority.QUANTIZATION_PRIORITY))
        return insertion_commands
Code Example #2
    def _sparsify_weights(self,
                          target_model: NNCFNetwork) -> List[InsertionCommand]:
        device = next(target_model.parameters()).device
        sparsified_modules = target_model.get_nncf_modules()
        insertion_commands = []
        for module_scope, module in sparsified_modules.items():
            scope_str = str(module_scope)

            if not self._should_consider_scope(scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Sparsifier in scope: {}".format(
                        scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Sparsifier in scope: {}".format(scope_str))
            operation = self.create_weight_sparsifying_operation(module)
            hook = UpdateWeight(operation).to(device)
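            # UpdateWeight applies the sparsifying operation to the module's weight;
            # registered as an NNCF_MODULE_PRE_OP, it runs before each forward pass.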
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.SPARSIFICATION_PRIORITY))
            self._sparsified_module_info.append(
                SparseModuleInfo(scope_str, module, hook.operand))

        return insertion_commands
Code Example #3
def make_op_exec_context_for_coalescing_test(scope_str: str) -> OperationExecutionContext:
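    # Build a full OperationExecutionContext from the scope string, attaching a
    # single placeholder TensorMeta as the only input descriptor.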
    ia_op_exec_context = InputAgnosticOperationExecutionContext.from_str(scope_str)
    op_exec_context = OperationExecutionContext(ia_op_exec_context.operator_name,
                                                ia_op_exec_context.scope_in_model,
                                                ia_op_exec_context.call_order,
                                                [TensorMeta(0, 0, [1])])
    return op_exec_context
Code Example #4
def generate_qp(scope_str: str,
                target: QuantizerGroup,
                in_port_id: int = None) -> SingleConfigQuantizationPoint:
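    # Weight quantizers are attached as module pre-ops; activation quantizers hook
    # the operator itself: a post-hook on its output, or a pre-hook on a specific
    # input port when in_port_id is given.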
    if target is QuantizerGroup.WEIGHTS:
        ip = InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                            module_scope=Scope.from_str(scope_str))
    elif target is QuantizerGroup.ACTIVATIONS:
        ip = InsertionPoint(
            InsertionType.OPERATOR_POST_HOOK
            if in_port_id is None else InsertionType.OPERATOR_PRE_HOOK,
            ia_op_exec_context=InputAgnosticOperationExecutionContext.from_str(
                scope_str),
            input_port_id=in_port_id)
    else:
        raise RuntimeError()
    return SingleConfigQuantizationPoint(ip, QuantizerConfig())
Code Example #5
    def get_caller_context(self, operator_type: str) -> InputAgnosticOperationExecutionContext:
        """
        Designed to work in the following way - for each scope the context will track the number of the calls to the
        operators with the name operator_type (call_order). The counter values are preserved until reset by a
        corresponding member function of the context, which must be called after each model iteration - this is
        usually handled inside NNCF. This mechanism allows to discern between multiple function calls inside the same
        module that would each require their own instance of compression layers - for instance, multiple `relu`
        function calls (either on their own or inside a `for` cycle), and at the same moment allow the checkpoints to
        be loaded if the model had changed in the meantime in a way that does not impact the major function call
        order (e.g. if comments were added to the .py file with the model)
        """
        version_agnostic_operator_type = get_version_agnostic_name(operator_type)
        if version_agnostic_operator_type is not None:
            operator_type = version_agnostic_operator_type

        call_order = self.get_operator_call_count_in_scope(operator_type, self.scope)

        ia_op_exec_context = InputAgnosticOperationExecutionContext(operator_type,
                                                                    self.scope,
                                                                    call_order)
        return ia_op_exec_context
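
To illustrate the per-scope call counting described in the docstring above, here is a minimal standalone sketch of the same idea. It is a toy re-implementation for illustration only, not NNCF code:

from collections import defaultdict


class ToyCallCounter:
    """Toy version of the per-scope operator call counter described above."""

    def __init__(self):
        self._counts = defaultdict(int)  # (scope, operator_type) -> calls seen so far

    def next_call_order(self, scope: str, operator_type: str) -> int:
        order = self._counts[(scope, operator_type)]
        self._counts[(scope, operator_type)] += 1
        return order

    def reset(self):
        # To be called after each model iteration, mirroring the real tracing context.
        self._counts.clear()


counter = ToyCallCounter()
assert counter.next_call_order("MyModel/ReLU[relu]", "relu") == 0
assert counter.next_call_order("MyModel/ReLU[relu]", "relu") == 1  # second relu call in the same scope
counter.reset()
assert counter.next_call_order("MyModel/ReLU[relu]", "relu") == 0  # counting starts over each iteration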
Code Example #6
    def _prune_weights(self, target_model: NNCFNetwork):
        device = next(target_model.parameters()).device
        modules_to_prune = target_model.get_nncf_modules()
        insertion_commands = []
        bn_for_depthwise = {}

        input_non_pruned_modules = get_first_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
        output_non_pruned_modules = get_last_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
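        # These are the first/last prunable modules of the network (including linear
        # layers); they are skipped in the loop below unless prune_first / prune_last is set.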

        for module_scope, module in modules_to_prune.items():
            # Check that we need to prune weights in this op
            if not self._is_pruned_module(module):
                continue

            module_scope_str = str(module_scope)
            if self.ignore_frozen_layers and not module.weight.requires_grad:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " the layer appears to be frozen (requires_grad=False)".
                    format(module_scope_str))
                continue

            if not self._should_consider_scope(module_scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {}".format(
                        module_scope_str))
                continue

            if not self.prune_first and module in input_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the first convolutions".format(
                        module_scope_str))
                continue
            if not self.prune_last and module in output_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the last convolutions".format(
                        module_scope_str))
                continue

            if is_grouped_conv(module):
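                # Grouped convolutions are not pruned directly; for a depthwise
                # convolution, remember its BatchNorm keyed by the preceding
                # convolution's scope so it can be pruned together with that conv.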
                if is_depthwise_conv(module):
                    previous_conv = get_previous_conv(target_model, module,
                                                      module_scope)
                    if previous_conv:
                        depthwise_bn = get_bn_for_module_scope(
                            target_model, module_scope)
                        prev_conv_scope = str(
                            previous_conv.op_exec_context.scope_in_model)
                        bn_for_depthwise[prev_conv_scope] = depthwise_bn

                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is grouped convolution".format(
                        module_scope_str))
                continue

            if not self.prune_downsample_convs and is_conv_with_downsampling(
                    module):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is convolution with downsample".format(
                        module_scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                related_modules[
                    PrunedModuleInfo.BN_MODULE_NAME] = get_bn_for_module_scope(
                        target_model, module_scope)

            self._pruned_module_info.append(
                PrunedModuleInfo(module_scope_str, module, hook.operand,
                                 related_modules))

        if self.prune_batch_norms:
            self.update_minfo_with_depthwise_bn(bn_for_depthwise)

        return insertion_commands
Code Example #7
class TestInsertionCommands:
    @pytest.fixture()
    def setup(self):
        self.compressed_model = NNCFNetwork(
            InsertionPointTestModel(),
            [ModelInputInfo([1, 1, 10, 10])])  # type: NNCFNetwork

    conv1_module_scope = Scope.from_str(
        'InsertionPointTestModel/NNCFConv2d[conv1]')
    conv1_module_context = InputAgnosticOperationExecutionContext(
        '', conv1_module_scope, 0)
    point_for_conv1_weights = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_inputs = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv1_activations = InsertionPoint(
        ia_op_exec_context=conv1_module_context,
        insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    conv2_module_scope = Scope.from_str(
        'InsertionPointTestModel/NNCFConv2d[conv2]')
    conv2_module_context = InputAgnosticOperationExecutionContext(
        '', conv2_module_scope, 0)
    point_for_conv2_weights = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_inputs = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_PRE_OP)
    point_for_conv2_activations = InsertionPoint(
        ia_op_exec_context=conv2_module_context,
        insertion_type=InsertionType.NNCF_MODULE_POST_OP)

    linear_op_scope = Scope.from_str('InsertionPointTestModel/linear_0')
    linear_op_context = InputAgnosticOperationExecutionContext(
        'linear', linear_op_scope, 0)
    point_for_linear_weight_input = InsertionPoint(
        ia_op_exec_context=linear_op_context,
        insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_linear_activation = InsertionPoint(
        ia_op_exec_context=linear_op_context,
        insertion_type=InsertionType.OPERATOR_POST_HOOK)

    relu_op_scope = Scope.from_str('InsertionPointTestModel/ReLU[relu]/relu')
    relu_op_context = InputAgnosticOperationExecutionContext(
        'relu', relu_op_scope, 0)
    point_for_relu_inputs = InsertionPoint(
        ia_op_exec_context=relu_op_context,
        insertion_type=InsertionType.OPERATOR_PRE_HOOK)
    point_for_relu_activations = InsertionPoint(
        ia_op_exec_context=relu_op_context,
        insertion_type=InsertionType.OPERATOR_POST_HOOK)

    available_points = [
        point_for_conv1_weights, point_for_conv2_weights,
        point_for_conv1_inputs, point_for_conv2_inputs,
        point_for_conv1_activations, point_for_conv2_activations,
        point_for_linear_activation, point_for_linear_weight_input,
        point_for_relu_activations, point_for_relu_inputs
    ]
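    # All four InsertionType values are covered: module pre/post-ops on the NNCF
    # conv modules and operator pre/post-hooks on the traced linear and relu ops.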

    @pytest.mark.parametrize("insertion_point", available_points)
    def test_single_insertions(self, setup, insertion_point):
        if insertion_point.insertion_type in [
                InsertionType.OPERATOR_PRE_HOOK,
                InsertionType.OPERATOR_POST_HOOK
        ]:
            hook = lambda x: x
        else:
            hook = BaseOp(lambda x: x)

        command = InsertionCommand(insertion_point, hook)
        self.compressed_model.register_insertion_command(command)
        self.compressed_model.commit_compression_changes()

        #pylint:disable=protected-access
        if insertion_point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._pre_hooks[
                command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            assert ctx._post_hooks[
                command.insertion_point.ia_op_exec_context][0] is hook
        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.pre_ops["0"] is hook

        if insertion_point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(
                command.insertion_point.ia_op_exec_context.scope_in_model)
            assert module.post_ops["0"] is hook

    priority_types = ["same", "different"]
    insertion_types = InsertionType
    priority_test_cases = list(
        itertools.product(priority_types, insertion_types))

    @staticmethod
    def check_order(iterable1: List, iterable2: List, ordering: List):
        for idx, order in enumerate(ordering):
            assert iterable1[idx] is iterable2[order]

    # pylint:disable=undefined-variable
    @pytest.mark.parametrize(
        "case",
        priority_test_cases,
        ids=[x[1].name + '-' + x[0] for x in priority_test_cases])
    def test_priority(self, case, setup):
        #pylint:disable=too-many-branches
        priority_type = case[0]
        insertion_type = case[1]
        if insertion_type in [
                InsertionType.NNCF_MODULE_PRE_OP,
                InsertionType.NNCF_MODULE_POST_OP
        ]:
            hook1 = BaseOp(lambda x: x)
            hook2 = BaseOp(lambda x: 2 * x)
            hook3 = BaseOp(lambda x: 3 * x)
        else:
            hook1 = lambda x: x
            hook2 = lambda x: 2 * x
            hook3 = lambda x: 3 * x

        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            point = self.point_for_conv2_weights
        elif insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            point = self.point_for_conv1_activations
        elif insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            point = self.point_for_linear_weight_input
        elif insertion_type == InsertionType.OPERATOR_POST_HOOK:
            point = self.point_for_relu_activations

        if priority_type == "same":
            # Same-priority commands will be executed in registration order
            command1 = InsertionCommand(point, hook1,
                                        OperationPriority.DEFAULT_PRIORITY)
            command2 = InsertionCommand(point, hook2,
                                        OperationPriority.DEFAULT_PRIORITY)
            command3 = InsertionCommand(point, hook3,
                                        OperationPriority.DEFAULT_PRIORITY)
        else:
            # Prioritized commands will be executed in ascending priority order
            command1 = InsertionCommand(
                point, hook1, OperationPriority.SPARSIFICATION_PRIORITY)
            command2 = InsertionCommand(
                point, hook2, OperationPriority.QUANTIZATION_PRIORITY)
            command3 = InsertionCommand(point, hook3,
                                        OperationPriority.DEFAULT_PRIORITY)

        self.compressed_model.register_insertion_command(command1)
        self.compressed_model.register_insertion_command(command2)
        self.compressed_model.register_insertion_command(command3)
        self.compressed_model.commit_compression_changes()

        hook_list = [hook1, hook2, hook3]

        if priority_type == "same":
            order = [0, 1, 2]
        elif priority_type == "different":
            order = [2, 0, 1]

        #pylint:disable=protected-access
        if insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._pre_hooks[point.ia_op_exec_context],
                             hook_list, order)
        if insertion_type == InsertionType.OPERATOR_POST_HOOK:
            ctx = self.compressed_model.get_tracing_context()
            self.check_order(ctx._post_hooks[point.ia_op_exec_context],
                             hook_list, order)

        if insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
            module = self.compressed_model.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            # Works because PyTorch ModuleDict is ordered
            self.check_order(list(module.pre_ops.values()), hook_list, order)

        if insertion_type == InsertionType.NNCF_MODULE_POST_OP:
            module = self.compressed_model.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            # Works because PyTorch ModuleDict is ordered
            self.check_order(list(module.post_ops.values()), hook_list, order)
Code Example #8
def prepare_potential_quantizer_graph(
        graph: NNCFGraph, potential_activations_quantizers: Dict[
            InsertionInfo, Optional[List[QuantizerConfig]]],
        potential_weights_modules: List[PotentialQuantizedModule]
) -> NNCFGraph:
    quantizers_weights_attr = {}
    quantizers_activations_attr = {}
    # pylint:disable=protected-access
    for _, module_scope, qconfig_list in potential_weights_modules:
        matching_graph_op_nodes = graph.get_op_nodes_in_scope(module_scope)

        # Not correct when an NNCF module corresponds to more than one graph node
        assert len(matching_graph_op_nodes) == 1

        op_name = matching_graph_op_nodes[0][
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
        ia_op_exec_context = InputAgnosticOperationExecutionContext(
            op_name, module_scope, 0)
        str_qconfig_list = ''

        for qconfig in qconfig_list:
            str_qconfig_list += '[' + str(qconfig) + '] '
        quantizers_weights_attr[ia_op_exec_context] = str_qconfig_list

    for insertion_info, qconfig_list in potential_activations_quantizers.items(
    ):
        ia_op_exec_context = insertion_info.op_exec_context.input_agnostic
        str_qconfig_list = ''
        for qconfig in qconfig_list:
            str_qconfig_list += '[' + str(qconfig) + '] '
        quantizers_activations_attr[ia_op_exec_context] = str_qconfig_list
        for linked_op_exec_context in insertion_info.linked_op_exec_contexts:
            quantizers_activations_attr[
                linked_op_exec_context.input_agnostic] = str_qconfig_list

    nx_graph = graph._nx_graph
    nodes = deepcopy(nx_graph.nodes)
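    # For each graph node that has potential quantizers, insert a synthetic
    # "Quantizer: ..." node into the visualization graph: after the op node for
    # activation quantizers, before the op node for weight quantizers.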
    for node_name, node in sorted(nodes.items()):
        ia_op_exec_context_for_node = nx_graph.nodes[node_name][
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].input_agnostic
        node_scope = str(ia_op_exec_context_for_node)
        if ia_op_exec_context_for_node in quantizers_activations_attr:
            label = "Quantizer: {}".format(
                quantizers_activations_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope,
                              label=label,
                              color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=nx_graph.nodes[node_name][
                                  NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR])
            next_nodes = deepcopy(nx_graph._succ[node_name])
            for next_node_name, _ in next_nodes.items():
                nx_graph.add_edge(node_scope, next_node_name)
                nx_graph.remove_edge(node_name, next_node_name)
            nx_graph.add_edge(node_name, node_scope)
        elif ia_op_exec_context_for_node in quantizers_weights_attr:
            label = "Quantizer: {}".format(
                quantizers_weights_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope,
                              label=label,
                              color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=nx_graph.nodes[node_name][
                                  NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR])
            nx_graph.add_edge(node_scope, node_name)

    return graph
Code Example #9
    def _prune_weights(self, target_model: NNCFNetwork):
        device = next(target_model.parameters()).device
        modules_to_prune = target_model.get_nncf_modules()
        insertion_commands = []

        input_non_pruned_modules = get_first_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
        output_non_pruned_modules = get_last_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
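        # These are the first/last prunable modules of the network (including linear
        # layers); they are skipped in the loop below unless prune_first / prune_last is set.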

        for module_scope, module in modules_to_prune.items():
            # Check that we need to prune weights in this op
            if not self._is_pruned_module(module):
                continue

            module_scope_str = str(module_scope)
            if not self._should_consider_scope(module_scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {}".format(
                        module_scope_str))
                continue

            if not self.prune_first and module in input_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the first convolutions".format(
                        module_scope_str))
                continue
            if not self.prune_last and module in output_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the last convolutions".format(
                        module_scope_str))
                continue

            if not self.prune_downsample_convs and is_conv_with_downsampling(
                    module):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is convolution with downsample".format(
                        module_scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                related_modules['bn_module'] = get_bn_for_module_scope(
                    target_model, module_scope)

            self._pruned_module_info.append(
                PrunedModuleInfo(module_scope_str, module, hook.operand,
                                 related_modules))

        return insertion_commands
Code Example #10
def test_quantizer_scale_linking():
    nncf_config = get_quantization_config_without_range_init(model_size=1)
    nncf_config['quantizer_setup_type'] = 'pattern_based'
    nncf_config["compression"]["quantize_outputs"] = True
    nncf_config["compression"]["quantize_inputs"] = False
    nncf_config["input_info"] = [{
        "sample_size": [1, 1, 1, 1],
    }, {
        "sample_size": [1, 1, 1, 1],
    }]
    nncf_config["compression"]["activations"] = {
        "linked_quantizer_scopes": [[
            # Note: Assuming that quantizers are attached as a post-op to the specified operation
            "QuantizerLinkingTestModel/Path[path2]/__mul___0",
            "QuantizerLinkingTestModel/Path[path2]/__add___0",
        ]],
        "ignored_scopes": [
            # Ignore path output averaging operations
            "QuantizerLinkingTestModel/__add___0",
            "QuantizerLinkingTestModel/__add___1",
            "QuantizerLinkingTestModel/__add___2",
        ]
    }

    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(
        QuantizerLinkingTestModel(), nncf_config)

    # 2 paths x 3 quantizers - 1 because two are shared in one path
    assert len(compression_ctrl.non_weight_quantizers) == 5

    test_input1 = torch.ones([1, 1, 1, 1])
    test_input2 = 2 * test_input1

    non_shared_mul_quantizer_id = NonWeightQuantizerId(
        InputAgnosticOperationExecutionContext.from_str(
            "QuantizerLinkingTestModel/Path[path1]/__mul___0"))

    non_shared_add_quantizer_id = NonWeightQuantizerId(
        InputAgnosticOperationExecutionContext.from_str(
            "QuantizerLinkingTestModel/Path[path1]/__add___0"))

    shared_quantizer_id = NonWeightQuantizerId(
        InputAgnosticOperationExecutionContext.from_str(
            "QuantizerLinkingTestModel/Path[path2]/__add___0"))

    non_shared_mul_quantizer = compression_ctrl.non_weight_quantizers[
        non_shared_mul_quantizer_id].quantizer_module_ref
    non_shared_add_quantizer = compression_ctrl.non_weight_quantizers[
        non_shared_add_quantizer_id].quantizer_module_ref
    shared_quantizer = compression_ctrl.non_weight_quantizers[
        shared_quantizer_id].quantizer_module_ref
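    # path1 keeps its own quantizer for each of __mul__ and __add__, while the two
    # linked operations in path2 share a single quantizer instance.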

    old_scale = 765.0  # so that the quantum is equal to 3
    with torch.no_grad():
        for quantizer in compression_ctrl.all_quantizations.values():
            quantizer.scale.fill_(old_scale)

    # Expected outputs without compression are 6, 12, 18; the scale is
    # deliberately set to preserve these values.
    uncompressed_expected_outputs = (6.0 * torch.ones([1]),
                                     12.0 * torch.ones([1]),
                                     18.0 * torch.ones([1]))
    outputs_with_shared_scale_1 = compressed_model(test_input1, test_input2)

    for uncomp_out, comp_out_1 in zip(uncompressed_expected_outputs,
                                      outputs_with_shared_scale_1):
        assert torch.allclose(uncomp_out, comp_out_1)

    # Specifically clip the shared quantizer's outputs by setting scale to 1.0
    new_shared_scale = 1.0
    with torch.no_grad():
        shared_quantizer.scale.fill_(new_shared_scale)
    outputs_with_shared_scale_2 = compressed_model(test_input1, test_input2)

    # __add___0 outputs
    assert torch.allclose(outputs_with_shared_scale_2[0],
                          4.0 * torch.ones([1]))
    # __mul___0 outputs
    assert torch.allclose(outputs_with_shared_scale_2[1],
                          7.0 * torch.ones([1]))
    # __add___1 outputs
    assert torch.allclose(outputs_with_shared_scale_2[2],
                          12.0 * torch.ones([1]))

    # Clipping the non-shared quantizers at the same position in the path as the two shared ones
    # in the same manner is required to simulate the same grad input for both the shared quantizers
    # and the unshared ones
    with torch.no_grad():
        non_shared_mul_quantizer.scale.fill_(new_shared_scale)
        non_shared_add_quantizer.scale.fill_(new_shared_scale)
    final_output = compressed_model(test_input1, test_input2)[2]
    final_output.backward()

    assert torch.allclose(
        shared_quantizer.scale.grad, non_shared_mul_quantizer.scale.grad +
        non_shared_add_quantizer.scale.grad)