Exemplo n.º 1
0
def get_all_modules_in_blocks(
        compressed_model: NNCFNetwork,
        op_adresses_in_blocks: List[OperationAddress]
) -> List[torch.nn.Module]:
    """
    Collects the modules that stand behind the NNCF-module operations of a building block.

    Only addresses whose operator name is listed in NNCF_MODULES_OP_NAMES contribute
    a module; all other addresses are ignored.

    :param compressed_model: Model that is queried for modules by scope.
    :param op_adresses_in_blocks: Operation addresses that make up the building block.
    :return: List of the matching modules, in the order of the given addresses.
    """
    return [
        compressed_model.get_module_by_scope(address.scope_in_model)
        for address in op_adresses_in_blocks
        if address.operator_name in NNCF_MODULES_OP_NAMES
    ]
Exemplo n.º 2
0
class TestInsertionCommands:
    """
    Tests for inserting hooks and operations into an NNCFNetwork built around InsertionPointTestModel.

    The target points declared at class level are shared by the parametrized tests.
    """

    @pytest.fixture()
    def setup(self):
        """Wraps a fresh InsertionPointTestModel into an NNCFNetwork and stores it on the test instance."""
        model = InsertionPointTestModel()
        input_infos = [ModelInputInfo([1, 1, 10, 10])]
        self.compressed_model = NNCFNetwork(model, input_infos)  # type: NNCFNetwork

    # Target points around the first convolution
    conv1_node_name = 'InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0'
    point_for_conv1_weights = PTTargetPoint(
        target_type=TargetType.OPERATION_WITH_WEIGHTS,
        target_node_name=conv1_node_name)
    point_for_conv1_inputs = PTTargetPoint(
        target_type=TargetType.OPERATOR_PRE_HOOK,
        target_node_name=conv1_node_name)
    point_for_conv1_activations = PTTargetPoint(
        target_type=TargetType.POST_LAYER_OPERATION,
        target_node_name=conv1_node_name)

    # Target points around the second convolution
    conv2_node_name = 'InsertionPointTestModel/NNCFConv2d[conv2]/conv2d_0'
    point_for_conv2_weights = PTTargetPoint(
        target_type=TargetType.OPERATION_WITH_WEIGHTS,
        target_node_name=conv2_node_name)
    point_for_conv2_inputs = PTTargetPoint(
        target_type=TargetType.OPERATOR_PRE_HOOK,
        target_node_name=conv2_node_name)
    point_for_conv2_activations = PTTargetPoint(
        target_type=TargetType.POST_LAYER_OPERATION,
        target_node_name=conv2_node_name)

    # Target points around the functional linear call
    linear_node_name = 'InsertionPointTestModel/linear_0'
    point_for_linear_weight_input = PTTargetPoint(
        target_type=TargetType.OPERATOR_PRE_HOOK,
        target_node_name=linear_node_name,
        input_port_id=0)
    point_for_linear_activation = PTTargetPoint(
        target_type=TargetType.OPERATOR_POST_HOOK,
        target_node_name=linear_node_name)

    # Target points around the ReLU
    relu_node_name = 'InsertionPointTestModel/ReLU[relu]/relu_0'
    point_for_relu_inputs = PTTargetPoint(
        target_type=TargetType.OPERATOR_PRE_HOOK,
        target_node_name=relu_node_name,
        input_port_id=0)
    point_for_relu_activations = PTTargetPoint(
        target_type=TargetType.OPERATOR_POST_HOOK,
        target_node_name=relu_node_name)

    # Every point exercised by test_single_insertions
    available_points = [
        point_for_conv1_weights,
        point_for_conv2_weights,
        point_for_conv1_inputs,
        point_for_conv2_inputs,
        point_for_conv1_activations,
        point_for_conv2_activations,
        point_for_linear_activation,
        point_for_linear_weight_input,
        point_for_relu_activations,
        point_for_relu_inputs,
    ]

    @pytest.mark.parametrize("target_point", available_points)
    def test_single_insertions(self, setup, target_point: PTTargetPoint):
        """Inserts one hook at the given target point and checks that the model holds exactly that object."""
        op_address = OperationAddress.from_str(target_point.target_node_name)
        insertion_point = PTInsertionPoint(target_point.target_type, op_address, target_point.input_port_id)

        # Module pre/post ops are wrapped into BaseOp; operator hooks are inserted as plain callables.
        operator_hook_types = (PTInsertionType.OPERATOR_PRE_HOOK, PTInsertionType.OPERATOR_POST_HOOK)
        if insertion_point.insertion_type not in operator_hook_types:
            hook = BaseOp(lambda x: x)
        else:
            hook = lambda x: x

        self.compressed_model.insert_at_point(insertion_point, [hook])

        # pylint:disable=protected-access
        inserted_as = insertion_point.insertion_type
        if inserted_as == PTInsertionType.OPERATOR_PRE_HOOK:
            tracing_ctx = self.compressed_model.get_tracing_context()
            hook_id = PreHookId(insertion_point.op_address, input_port_id=insertion_point.input_port_id)
            assert tracing_ctx._pre_hooks[hook_id][0] is hook
        elif inserted_as == PTInsertionType.OPERATOR_POST_HOOK:
            tracing_ctx = self.compressed_model.get_tracing_context()
            assert tracing_ctx._post_hooks[insertion_point.op_address][0] is hook
        elif inserted_as == PTInsertionType.NNCF_MODULE_PRE_OP:
            target_module = self.compressed_model.get_module_by_scope(insertion_point.module_scope)
            assert target_module.pre_ops["0"] is hook
        elif inserted_as == PTInsertionType.NNCF_MODULE_POST_OP:
            target_module = self.compressed_model.get_module_by_scope(insertion_point.module_scope)
            assert target_module.post_ops["0"] is hook

    # Each priority mode is combined with every member of TargetType.
    priority_types = ["same", "different"]
    insertion_types = TargetType
    priority_test_cases = list(itertools.product(priority_types, insertion_types))

    @staticmethod
    def check_order(iterable1: List, iterable2: List, ordering: List):
        """
        Asserts that iterable1 holds the objects of iterable2 rearranged according to ordering.

        :param iterable1: Sequence under test.
        :param iterable2: Reference sequence.
        :param ordering: For every position in iterable1, the index in iterable2 of the object expected there.
        """
        for position, source_index in enumerate(ordering):
            assert iterable1[position] is iterable2[source_index]

    # pylint:disable=undefined-variable
    @pytest.mark.parametrize("case", priority_test_cases,
                             ids=[target.name + '-' + mode for mode, target in priority_test_cases])
    def test_priority(self, case, setup):
        """
        Registers three insertion commands at one target point and checks the order in which
        the corresponding hooks end up in the transformed model.
        """
        # pylint:disable=too-many-branches
        priority_type, insertion_type = case

        # Module-level insertions get BaseOp-wrapped functions, operator hooks get the bare functions.
        scale_fns = [lambda x: x, lambda x: 2 * x, lambda x: 3 * x]
        if insertion_type in [TargetType.OPERATION_WITH_WEIGHTS, TargetType.POST_LAYER_OPERATION]:
            hook1, hook2, hook3 = [BaseOp(fn) for fn in scale_fns]
        else:
            hook1, hook2, hook3 = scale_fns

        points_by_type = {
            TargetType.OPERATION_WITH_WEIGHTS: self.point_for_conv2_weights,
            TargetType.POST_LAYER_OPERATION: self.point_for_conv1_activations,
            TargetType.OPERATOR_PRE_HOOK: self.point_for_linear_weight_input,
            TargetType.OPERATOR_POST_HOOK: self.point_for_relu_activations,
        }
        if insertion_type not in points_by_type:
            pytest.skip("Insertion type {} currently unsupported in PT".format(insertion_type))
        point = points_by_type[insertion_type]

        if priority_type == "same":
            # Same-priority commands will be executed in registration order
            priorities = [TransformationPriority.DEFAULT_PRIORITY] * 3
        else:
            # Prioritized commands will be executed in ascending priority order
            priorities = [TransformationPriority.SPARSIFICATION_PRIORITY,
                          TransformationPriority.QUANTIZATION_PRIORITY,
                          TransformationPriority.DEFAULT_PRIORITY]

        hook_list = [hook1, hook2, hook3]
        commands = [PTInsertionCommand(point, hook, priority)
                    for hook, priority in zip(hook_list, priorities)]

        layout = PTTransformationLayout()
        for command in commands:
            layout.register(command)
        self.compressed_model = PTModelTransformer(self.compressed_model).transform(layout)

        # Indices into hook_list, listed in the order the hooks are expected in the model.
        if priority_type == "same":
            expected_order = [0, 1, 2]
        elif priority_type == "different":
            expected_order = [2, 0, 1]

        # pylint:disable=protected-access
        if insertion_type == TargetType.OPERATOR_PRE_HOOK:
            tracing_ctx = self.compressed_model.get_tracing_context()
            hook_id = PreHookId(OperationAddress.from_str(point.target_node_name),
                                input_port_id=point.input_port_id)
            self.check_order(tracing_ctx._pre_hooks[hook_id], hook_list, expected_order)
        elif insertion_type == TargetType.OPERATOR_POST_HOOK:
            tracing_ctx = self.compressed_model.get_tracing_context()
            registered = tracing_ctx._post_hooks[OperationAddress.from_str(point.target_node_name)]
            self.check_order(registered, hook_list, expected_order)
        elif insertion_type == TargetType.OPERATION_WITH_WEIGHTS:
            target_module = self.compressed_model.get_containing_module(point.target_node_name)
            # Works because Pytorch ModuleDict is ordered;
            # the registered hooks are read back from the `operand` attribute of each pre-op
            registered = [pre_op.operand for pre_op in target_module.pre_ops.values()]
            self.check_order(registered, hook_list, expected_order)
        elif insertion_type == TargetType.POST_LAYER_OPERATION:
            target_module = self.compressed_model.get_containing_module(point.target_node_name)
            # Works because Pytorch ModuleDict is ordered
            self.check_order(list(target_module.post_ops.values()), hook_list, expected_order)