Beispiel #1
0
def create_compressed_model(model: Module,
                            config: Config,
                            dummy_forward_fn: Callable[[Module], Any] = None):
    """Create a compression algorithm for ``model`` as described by ``config``.

    ``dummy_forward_fn`` will be used instead of a *forward* function call to build
    the graph - useful when the original training pipeline has special formats of
    data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model will be made with a single Tensor with
    a shape and type specified in config.

    On the main process the original graph (and, if some compression is actually applied,
    the compressed graph) is dumped in .dot format into ``config.log_dir``.

    :return: A tuple of (compression algorithm, compressed model).
    """
    # Pick the function used to drive graph tracing.
    if dummy_forward_fn is not None:
        tracing_forward_fn = dummy_forward_fn
    else:
        tracing_forward_fn = create_dummy_forward_fn(create_input_infos(config))
    graph_builder = GraphBuilder(custom_forward_fn=tracing_forward_fn)

    if is_main_process():
        module_names = get_all_modules(model).keys()
        print(*module_names, sep="\n")
        reset_context('create_model')
        original_graph = graph_builder.build_graph(model, 'create_model')
        original_graph.dump_graph(osp.join(config.log_dir, "original_graph.dot"))

    compression_algo = create_compression_algorithm(model, config, dummy_forward_fn)
    compressed_model = compression_algo.model

    if is_main_process() and not isinstance(compression_algo, NoCompressionAlgorithm):
        # Quantized networks trace in their own context; everything else uses a generic one.
        context_name = (compressed_model.get_context_name()
                        if isinstance(compressed_model, QuantizedNetwork)
                        else 'create_compressed_graph')
        compressed_graph = graph_builder.build_graph(compression_algo.model, context_name)
        compressed_graph.dump_graph(osp.join(config.log_dir, "compressed_graph.dot"))

    return compression_algo, compressed_model
def test_activation_shape_tracing(input_shape: Tuple):
    """Check that the traced output tensor shapes of each ModelForTest node match the expected ones."""
    model = ModelForTest()
    tracing_fn = create_dummy_forward_fn([ModelInputInfo(input_shape)])
    graph = GraphBuilder(tracing_fn).build_graph(model)

    conv1_shape = (input_shape[0], ModelForTest.CONV1_OUT_CHANNELS,
                   input_shape[2], input_shape[3])
    pooled_shape = (conv1_shape[0], conv1_shape[1],
                    conv1_shape[2] // ModelForTest.MAXPOOL_SIZE,
                    conv1_shape[3] // ModelForTest.MAXPOOL_SIZE)
    cat_shape = (input_shape[0], ModelForTest.CONV2_IN_CHANNELS,
                 input_shape[2], input_shape[3])

    # TODO: extend with checking input tensor size once proper input node marking is implemented
    # TODO: extend with checking output tensor size once proper output node marking is implemented
    expected = (
        ("0 ModelForTest/Conv2d[conv1]/conv2d", [conv1_shape]),
        ("1 ModelForTest/BatchNorm2d[bn1]/batch_norm", [conv1_shape]),
        ("2 ModelForTest/ReLU[relu1]/RELU", [conv1_shape, conv1_shape]),
        ("3 ModelForTest/max_pool2d", [pooled_shape]),
        ("4 ModelForTest/ConvTranspose2d[convt1]/conv_transpose2d", [input_shape]),
        ("5 ModelForTest/cat", [cat_shape]),
    )

    for node_key, expected_shapes in expected:
        # pylint:disable=protected-access
        edges = graph._get_nncf_graph_pattern_input_output([node_key]).output_edges
        actual_shapes = [edge.tensor_shape for edge in edges]
        assert actual_shapes == expected_shapes, "Failed for {}".format(node_key)
Beispiel #3
0
def get_all_node_names(model, input_sample_size, builder=None):
    """Trace ``model`` and return its graph node names without their numeric id prefix.

    :param model: The model to trace.
    :param input_sample_size: Shape of the single dummy input used when no builder is given.
    :param builder: Optional pre-configured GraphBuilder; a default one is created if falsy.
    :return: List of node names, i.e. each graph node key with its leading "<id> " stripped.
    """
    graph_builder = builder if builder else GraphBuilder(
        create_dummy_forward_fn([ModelInputInfo(input_sample_size)]))
    graph = graph_builder.build_graph(model)
    names = []
    for node_key in graph.get_all_node_keys():
        names.append(node_key.split(' ', 1)[1])
    return names
def get_all_node_names(model,
                       input_sample_size,
                       graph_scope=None,
                       builder=None):
    """Trace ``model`` in a freshly reset context and return its node names.

    :param model: The model to trace.
    :param input_sample_size: Shape of the single dummy input used when no builder is given.
    :param graph_scope: Tracing context name; defaults to 'utils'. The context is reset first.
    :param builder: Optional pre-configured GraphBuilder; a default one is created if falsy.
    :return: List of node names, i.e. each graph node key with its leading "<id> " stripped.
    """
    scope = 'utils' if graph_scope is None else graph_scope
    reset_context(scope)
    graph_builder = builder if builder else GraphBuilder(
        create_dummy_forward_fn([ModelInputInfo(input_sample_size)]))
    graph = graph_builder.build_graph(model, scope)
    names = []
    for node_key in graph.get_all_node_keys():
        names.append(node_key.split(' ', 1)[1])
    return names
 def test_build_graph(self, desc: ModelDesc):
     """Trace the model described by ``desc`` and compare its graph with the reference .dot file."""
     net = desc.model_builder()
     sample_sizes = desc.input_sample_sizes
     # A tuple is interpreted as a collection of per-input sample sizes.
     if isinstance(sample_sizes, tuple):
         input_info_list = [ModelInputInfo(size) for size in sample_sizes]
     else:
         input_info_list = [ModelInputInfo(sample_sizes)]
     forward_fn = desc.dummy_forward_fn or create_dummy_forward_fn(input_info_list)
     graph = GraphBuilder(custom_forward_fn=forward_fn).build_graph(net)
     check_graph(graph, desc.dot_filename, 'original')
def test_ambiguous_function():
    """Check that every traced node gets a distinct operation execution context.

    The model calls ``F.relu`` twice from the same scope (inside a loop), so the
    tracer must tell the two calls apart instead of merging them into one node.
    """
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)])

        def forward(self, x):
            # NOTE(review): forward returns None; tracing only needs the calls to happen.
            for layer in self.layers:
                x = F.relu(layer(x))

    mod = Model()
    input_info = ModelInputInfo([1, 1, 1, 1])

    graph_builder = GraphBuilder(custom_forward_fn=create_dummy_forward_fn([
        input_info,
    ]))
    graph = graph_builder.build_graph(mod)

    unique_op_exec_contexts = set()
    # pylint:disable=protected-access
    for _, node in graph._nx_graph.nodes.items():
        node_op_exec_context = node[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
        assert node_op_exec_context not in unique_op_exec_contexts
        # Record the context, otherwise the uniqueness assertion above can never fail.
        unique_op_exec_contexts.add(node_op_exec_context)
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            dump_graphs=True,) \
    -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original PyTorch
    model and a configuration object.

    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model. Its "log_dir" entry (default ".") is where graphs and debug output are written.
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building. If given, compression algorithm builders are created with should_init=False.
    :param dummy_forward_fn: will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param dump_graphs: Whether or not should also dump the internal graph representation of the
    original and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""
    log_dir = config.get("log_dir", ".")

    if dump_graphs:
        if dummy_forward_fn is None:
            input_info_list = create_input_infos(config)
            graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.dump_graph(osp.join(log_dir, "original_graph.dot"),
                             extended=True)

    if is_debug():
        set_debug_log_dir(log_dir)

    # Fresh input infos for the network itself (kept separate from the ones used for dumping).
    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching',
                                               [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    # Initialization of compression parameters is skipped when resuming, since the
    # state dict loaded below supplies them.
    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(
        config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    if dump_graphs and is_main_process() and compression_algo_builder_list:
        if dummy_forward_fn is None:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=False))
        else:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=dummy_forward_fn)

        graph = compressed_graph_builder.build_graph(
            compressed_model, compressed_model.get_tracing_context())
        graph.dump_graph(osp.join(log_dir, "compressed_graph.dot"),
                         extended=True)

    if resuming_state_dict is not None:
        load_state(compressed_model, resuming_state_dict, is_resume=True)

    return compression_ctrl, compressed_model
class QuantizedNetwork(nn.Module, PostGraphBuildActing):
    """Wrapper that inserts quantization operations into a PyTorch module.

    Construction replaces the eligible submodules of the wrapped module with NNCF modules,
    attaches weight quantizers to them, traces the original graph in the "orig" tracing
    context and then registers activation, input and function quantizers. Subsequent
    ``forward`` calls run the wrapped module inside the "quantized_graphs" context, in
    which the registered quantization hooks are active.
    """

    def __init__(self,
                 module,
                 quantize_module_creator_fn,
                 input_infos=None,
                 dummy_forward_fn=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 quantize_inputs=True,
                 quantize_outputs=False,
                 quantizable_subgraph_patterns=None,
                 scopes_without_shape_matching=None,
                 disable_function_quantization_hooks=False):
        """
        :param module: The model to be wrapped and quantized.
        :param quantize_module_creator_fn: Factory called as ``fn(scope_str, is_weights=..., [input_shape=...])``
            that returns the quantizer module to insert for the given scope.
        :param input_infos: Model input descriptions; used to create the default dummy forward function and
            (only the 0-th entry) to supply an input shape when creating input quantizers.
        :param dummy_forward_fn: Forward function used for graph tracing; if None, one is created from
            ``input_infos``.
        :param ignored_scopes: Scopes excluded from quantization.
        :param target_scopes: If not None, only these scopes are considered for quantization.
        :param quantize_inputs: Whether to register input quantizers on the model's input modules.
        :param quantize_outputs: Whether to quantize activations of graph output nodes.
        :param quantizable_subgraph_patterns: Extra patterns (a node name string or a sequence of them)
            OR-ed into the default quantizable subgraph pattern.
        :param scopes_without_shape_matching: Scopes for which node matching in the tracing contexts
            ignores tensor shapes.
        :param disable_function_quantization_hooks: If True, function input quantizers are not registered.
        """
        super().__init__()
        self.set_nncf_wrapped_module(module)
        self.quantize_inputs = quantize_inputs
        self.quantize_outputs = quantize_outputs
        self.input_infos = input_infos
        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        # Quantizers keyed by str(InputAgnosticOperationExecutionContext).
        self.activation_quantizers = nn.ModuleDict()
        # Quantizers keyed by str(FunctionQuantizerKey).
        self.function_quantizers = nn.ModuleDict()
        self.quantized_weight_modules = OrderedDict()
        self.quantized_activation_modules = OrderedDict()
        self.quantize_module_creator_fn = quantize_module_creator_fn
        self.quantizable_subgraph_patterns = quantizable_subgraph_patterns
        self._dummy_forward_fn = dummy_forward_fn
        self._nncf_module_scopes = []  # type: List[Scope]
        self.debug_interface = QuantizationDebugInterface() if is_debug() else None
        self.scopes_without_shape_matching = scopes_without_shape_matching

        device = next(module.parameters()).device

        self.all_quantizations = OrderedDict()
        # Used to detect ambiguous (duplicate) insertion points.
        self._processed_input_agnostic_op_exec_contexts = set()
        self._processed_function_quantizers = set()

        # all modules should be replaced prior to graph building
        self._replace_quantized_modules_by_nncf_modules(device)
        self._register_weight_quantization_operations(device)

        if self._dummy_forward_fn is None:
            self._dummy_forward_fn = create_dummy_forward_fn(self.input_infos)

        self._graph_builder = GraphBuilder(
            custom_forward_fn=self._dummy_forward_fn)

        # Trace the original (not yet activation-quantized) graph in a dedicated context.
        self._context_name = "orig"
        if self.scopes_without_shape_matching:
            get_context(self._context_name).add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self, self._context_name)

        # From here on, forward() and all registered hooks use the quantized context.
        self._context_name = "quantized_graphs"
        self._ctx = get_context("quantized_graphs")
        if self.scopes_without_shape_matching:
            get_context(self._context_name).add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._register_activation_quantization_hooks(device)
        if self.quantize_inputs:
            self._register_input_quantization_operations(device)

        if not disable_function_quantization_hooks:
            self._register_function_quantization_hooks(device)

        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        self.all_quantizations = get_state_dict_names_with_modules(
            self, quantization_types)
        self.load_listener = LoadStateListener(self, self.all_quantizations)
        if self.debug_interface is not None:
            self.debug_interface.init_actual(self.all_quantizations.keys(),
                                             self.activation_quantizers.keys(),
                                             self.function_quantizers.keys())

    @debuggable_forward
    def forward(self, *args, **kwargs):
        """Run the wrapped module inside this network's tracing context so quantization hooks fire."""
        with context(self._context_name) as ctx:  # type: TracingContext
            # Hooks look up quantizers on this replica (see ActivationQuantizationHook).
            ctx.base_module_thread_local_replica = self
            retval = self.get_nncf_wrapped_module()(*args, **kwargs)
        return retval

    # Cannnot use property syntax here, otherwise the wrapped module will end up
    # being twice in the same checkpoint with different prefixes
    def get_nncf_wrapped_module(self):
        return getattr(self, MODULE_WRAPPED_BY_NNCF_ATTR_NAME)

    def set_nncf_wrapped_module(self, value):
        setattr(self, MODULE_WRAPPED_BY_NNCF_ATTR_NAME, value)

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped module, falling back to nn.Module lookup."""
        wrapped_module = super().__getattr__(MODULE_WRAPPED_BY_NNCF_ATTR_NAME)
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)

    def get_quantized_graph(self) -> NNCFGraph:
        return self._ctx.graph

    def get_context_name(self) -> str:
        return self._context_name

    def _should_consider_scope(self, scope_str: str) -> bool:
        """True if ``scope_str`` is in the target scopes (or no targets are set) and is not ignored."""
        return (self.target_scopes is None or in_scope_list(scope_str, self.target_scopes)) \
               and not in_scope_list(scope_str, self.ignored_scopes)

    def _replace_quantized_modules_by_nncf_modules(self, device):
        """Swap eligible submodules for NNCF modules, remembering their scopes, and move to ``device``."""
        module, self._nncf_module_scopes = replace_modules_by_nncf_modules(
            self.get_nncf_wrapped_module(),
            ignored_scopes=self.ignored_scopes,
            target_scopes=self.target_scopes,
            logger=logger)
        self.set_nncf_wrapped_module(module.to(device))

    def _register_weight_quantization_operation(self, module_name, module,
                                                device):
        """Attach a weight quantizer as a pre-forward UpdateWeight operation of ``module``."""
        logger.info(
            "Adding signed Weight quantizer in scope: {}".format(module_name))
        op = UpdateWeight(
            self.quantize_module_creator_fn(module_name,
                                            is_weights=True)).to(device)
        module.register_pre_forward_operation(op)

    def _register_input_quantization_operation(self, module_name, module,
                                               device):
        """Attach an input (activation) quantizer as a pre-forward UpdateInputs operation of ``module``."""
        # Only use the shape of the 0-th input info specified in config. TODO: fix this
        input_shape = self.input_infos[
            0].shape if self.input_infos is not None else None
        quantizer = self.quantize_module_creator_fn(module_name,
                                                    is_weights=False,
                                                    input_shape=input_shape)

        logger.info("Adding {} input quantizer in scope: {}".format(
            "signed" if quantizer.signed else "unsigned", module_name))

        module.register_pre_forward_operation(
            UpdateInputs(quantizer).to(device))

    def _register_weight_quantization_operations(self, device):
        """Register weight quantizers on every NNCF module whose scope is considered."""
        modules = get_all_modules_by_type(self.get_nncf_wrapped_module(),
                                          NNCF_MODULES)

        for name, module in modules.items():
            if not self._should_consider_scope(name):
                logger.info(
                    "Ignored adding Weight quantizer in scope: {}".format(
                        name))
                continue

            self.quantized_weight_modules[name] = module
            self._register_weight_quantization_operation(name, module, device)

    def _register_input_quantization_operations(self, device):
        """Register input quantizers on the weight-quantized modules reachable from the graph roots."""
        # limitations:
        # graph is incorrectly split into subgraphs and there are no quantize layers before QuantizeMixin

        graph_roots = self._original_graph.get_graph_roots()

        def get_first_noncompression_module_node_after(
                graph_node: NNCFNode, graph: NNCFGraph) -> NNCFNode:
            """ Gets the pre-op node immediately preceding the first non-COMPRESSION_MODULES node
                after `graph_node`.
                This is required in case there are multiple compression operations applied to the actual input node;
                for instance, in case of sparsity + quantization the input convolution might be preceded
                by 2 pre-ops - binary sparsity mask application and weight quantization
                """
            curr_m = get_module_for_scope(
                self.get_nncf_wrapped_module(),
                graph_node.op_exec_context.scope_in_model)
            if not isinstance(
                    curr_m, tuple(COMPRESSION_MODULES.registry_dict.values())):
                return graph_node
            next_node_list = graph.get_next_nodes(graph_node)
            next_node = next(iter(
                next_node_list))  # not handling the branching case for now

            m = get_module_for_scope(self.get_nncf_wrapped_module(),
                                     next_node.op_exec_context.scope_in_model)
            if isinstance(m,
                          tuple(COMPRESSION_MODULES.registry_dict.values())):
                return get_first_noncompression_module_node_after(
                    next_node, graph)
            return graph_node

        for idx, node in enumerate(graph_roots):
            graph_roots[idx] = get_first_noncompression_module_node_after(
                node, self._original_graph)

        inputs = []
        for node in graph_roots:
            scope_str = str(node.op_exec_context.scope_in_model)
            # if the node is quantizer, we get its successor to get the input of original graph
            if self._should_consider_scope(scope_str):
                module = get_module_for_scope(
                    self.get_nncf_wrapped_module(),
                    node.op_exec_context.scope_in_model)
                if isinstance(
                        module,
                        tuple(QUANTIZATION_MODULES.registry_dict.values())):
                    next_node_list = self._original_graph.get_next_nodes(node)
                    if next_node_list:
                        next_node = next(
                            iter(next_node_list
                                 ))  # not handling the branching case for now
                        next_module = get_module_for_scope(
                            self.get_nncf_wrapped_module(),
                            next_node.op_exec_context.scope_in_model)
                        if next_module in self.quantized_weight_modules.values() and \
                            self._original_graph.get_inputs_count(next_node) == 1:
                            # Quantizer is the only input of the node
                            inputs.append(next_node)
                else:
                    inputs.append(node)

        def _add_input_quantizers_traverser(node: NNCFNode) -> bool:
            """Quantize inputs of weight-quantized modules; returning True stops traversal past ``node``."""
            module = get_module_for_scope(self.get_nncf_wrapped_module(),
                                          node.op_exec_context.scope_in_model)
            if module is None:
                return True
            is_quantized_weight = module in self.quantized_weight_modules.values(
            )
            module_name = str(node.op_exec_context.scope_in_model)
            if is_quantized_weight and module not in self.quantized_activation_modules.values(
            ):
                self.quantized_activation_modules[module_name] = module
                self._register_input_quantization_operation(
                    module_name, module, device)

            if isinstance(module,
                          tuple(QUANTIZATION_MODULES.registry_dict.values())
                          ) or is_quantized_weight:
                return True
            return False

        for node in inputs:
            self._original_graph.traverse_graph(
                node, _add_input_quantizers_traverser)

    def _make_custom_quantizable_subgraph_pattern(self):
        """Return the default quantizable pattern OR-ed with any user-supplied patterns."""
        full_pattern = _make_quantizable_subgraph_pattern()
        if self.quantizable_subgraph_patterns is not None:
            for pattern in self.quantizable_subgraph_patterns:
                if not isinstance(pattern, str):
                    # A sequence of node names forms a chain pattern.
                    custom_pattern = functools.reduce(
                        operator.add, [N(node) for node in pattern])
                else:
                    custom_pattern = N(pattern)
                full_pattern = full_pattern | custom_pattern
        return full_pattern

    class ActivationQuantizationHook:
        """Cannot simply register the quantizer module as a callable hook, since we need to call
        a thread-local version of the quantizer module during base module execution."""
        def __init__(
                self,
                context_name: str,
                ia_op_exec_context: InputAgnosticOperationExecutionContext,
                debug_interface: QuantizationDebugInterface = None):
            self.context_name = context_name
            self.ia_op_exec_context = ia_op_exec_context
            self.debug_interface = debug_interface

        def __call__(self, *args, **kwargs):
            if self.debug_interface is not None:
                self.debug_interface.register_activation_quantize_call(
                    str(self.ia_op_exec_context))
            replica = get_context(
                self.context_name).base_module_thread_local_replica
            return replica.activation_quantizers[str(self.ia_op_exec_context)](
                *args, **kwargs)

    def _register_activation_quantization_hooks(self, device):
        """Create activation quantizers after each pattern match and register them as post-hooks.

        :raises RuntimeError: If the same input-agnostic operation context is matched twice.
        """
        pattern = self._make_custom_quantizable_subgraph_pattern()
        insertion_point_nncf_nodes = self._original_graph.get_insertion_point_nodes_after_pattern(
            pattern)

        for ip_node in insertion_point_nncf_nodes:
            ia_op_exec_context = ip_node.op_exec_context.input_agnostic
            operator_scope_str = str(ia_op_exec_context)

            if not self.quantize_outputs and self._original_graph.is_output_node(
                    ip_node):
                logger.info(
                    "Ignored adding Activation Quantize "
                    "in scope (output scope, quantize_outputs=False): {}".
                    format(operator_scope_str))
                continue
            if not self._should_consider_scope(operator_scope_str):
                logger.info(
                    "Ignored adding Activation quantizer in scope: {}".format(
                        operator_scope_str))
                continue

            if ia_op_exec_context in self._processed_input_agnostic_op_exec_contexts:
                raise RuntimeError(
                    "Ambiguous call to {fn} with call order {co} in current scope. "
                    "Cannot insert quantization hooks "
                    "automatically!".format(
                        fn=ia_op_exec_context.operator_name,
                        co=ia_op_exec_context.call_order))
            self._processed_input_agnostic_op_exec_contexts.add(
                ia_op_exec_context)

            # activation_quantizers is keyed by the string form of the context.
            assert operator_scope_str not in self.activation_quantizers
            input_shape = ip_node.op_exec_context.tensor_metas[0].shape
            quantizer = self.quantize_module_creator_fn(
                operator_scope_str, is_weights=False,
                input_shape=input_shape).to(device)
            self.activation_quantizers[operator_scope_str] = quantizer

            if isinstance(quantizer, BaseQuantizer):
                logger.info(
                    "Adding {} Activation Quantize in scope: {}".format(
                        "signed" if quantizer.signed else "unsigned",
                        operator_scope_str))
            else:
                logger.info("Adding Activation Binarize in scope: {}".format(
                    operator_scope_str))

            self._ctx.register_post_hooks([
                self.ActivationQuantizationHook(self._context_name,
                                                ia_op_exec_context,
                                                self.debug_interface),
            ], ia_op_exec_context)

        # NOTE: Order of activations must be the same to correctly broadcast parameters (e.g. scales) in distributed
        # mode (see call of `_dist_broadcast_coalesced` in torch/nn/parallel/distributed.py for more details)
        # pylint: disable=protected-access
        self.activation_quantizers._modules = OrderedDict(
            sorted(self.activation_quantizers._modules.items()))

    def rebuild_graph(self, *input_args):
        """Reset this network's tracing-context graph and re-trace it; ``input_args`` are unused."""
        ctx = get_context(self._context_name)
        ctx.reset_graph()
        _ = self._graph_builder.build_graph(self, self._context_name)

    def post_build_graph_actions(self):
        # Reset initialization flags (`initialized`) for all quantization modules
        # after dummy `load_state_dict` call.
        for module in self.all_quantizations.values():
            module.initialized = False

    class FunctionQuantizerKey:
        """Identifies one quantized input argument of one traced function call."""
        def __init__(
                self,
                ia_op_exec_context: InputAgnosticOperationExecutionContext,
                input_arg_idx: int):
            self.ia_op_exec_context = ia_op_exec_context
            self.input_arg_idx = input_arg_idx

        def __str__(self):
            return str(self.ia_op_exec_context) + "_input" + str(
                self.input_arg_idx)

        def __hash__(self):
            return hash((self.ia_op_exec_context, self.input_arg_idx))

        def __eq__(self, other):
            # Needed alongside __hash__ so set membership compares by value, not identity;
            # without it the ambiguity check in _register_function_quantization_hooks never fires.
            if not isinstance(other, type(self)):
                return NotImplemented
            return (self.ia_op_exec_context == other.ia_op_exec_context
                    and self.input_arg_idx == other.input_arg_idx)

    class FunctionQuantizationPreHook:
        """Cannot simply register the quantizer module as a callable hook, since we need to call
        a thread-local version of the quantizer module during base module execution."""
        def __init__(self,
                     context_name: str,
                     func_in_quant_info: 'QuantizedNetwork.FunctionQuantizerKey',
                     debug_interface: QuantizationDebugInterface = None):
            self.context_name = context_name
            self.func_in_quant_info = func_in_quant_info
            self.debug_interface = debug_interface

        def __call__(self, op_inputs: OperatorInput):
            quantizer_dict_key = str(self.func_in_quant_info)
            if self.debug_interface is not None:
                self.debug_interface.register_function_quantizer_call(
                    quantizer_dict_key)
            replica = get_context(
                self.context_name).base_module_thread_local_replica
            idx = self.func_in_quant_info.input_arg_idx
            # Replace the selected positional argument with its quantized version in place.
            op_inputs.op_args[idx] = replica.function_quantizers[
                quantizer_dict_key](op_inputs.op_args[idx])
            return op_inputs

    def _register_function_quantization_hooks(self, device):
        """Create input quantizers for calls to FUNCTIONS_TO_QUANTIZE and register them as pre-hooks.

        Calls located inside NNCF module scopes are skipped.

        :raises RuntimeError: If the same function call argument is matched twice.
        """
        if not FUNCTIONS_TO_QUANTIZE:
            return
        pattern = N(FUNCTIONS_TO_QUANTIZE[0].name)
        for i in range(1, len(FUNCTIONS_TO_QUANTIZE)):
            pattern |= N(FUNCTIONS_TO_QUANTIZE[i].name)

        insertion_points = self._original_graph.get_insertion_point_nodes_after_pattern(
            pattern)

        # Skip function calls that happen inside NNCF modules.
        non_shadowed_insertion_points = [
            ip_node for ip_node in insertion_points
            if not any(ip_node.op_exec_context.scope_in_model in nncf_scope
                       for nncf_scope in self._nncf_module_scopes)
        ]

        for ip_node in non_shadowed_insertion_points:
            ia_op_exec_context = ip_node.op_exec_context.input_agnostic
            scope_str = str(ia_op_exec_context.scope_in_model)

            if not self._should_consider_scope(scope_str):
                logger.info(
                    "Ignored adding function input quantizer in scope: {}".
                    format(scope_str))
                continue

            function_arg_positions_to_quantize = get_arg_positions_to_quantize(
                ia_op_exec_context.operator_name)
            assert function_arg_positions_to_quantize is not None, "Function with inputs to be quantized has " \
                                                                   "no info struct registered in " \
                                                                   "QUANTIZED_INPUT_FUNCTIONS!"

            pre_hooks_to_register = []
            for input_arg_idx in function_arg_positions_to_quantize:
                ip_arg_quant_key = self.FunctionQuantizerKey(
                    ia_op_exec_context, input_arg_idx)
                if ip_arg_quant_key in self._processed_function_quantizers:
                    raise RuntimeError(
                        "Ambiguous call to {fn} with call order {co} and argname {arg} in current scope. "
                        "Cannot insert quantization hooks "
                        "automatically!".format(
                            fn=ia_op_exec_context.operator_name,
                            co=ia_op_exec_context.call_order,
                            arg=input_arg_idx))

                self._processed_function_quantizers.add(ip_arg_quant_key)

                ip_arg_quant_name = str(ip_arg_quant_key)
                assert ip_arg_quant_name not in self.function_quantizers
                # NOTE(review): always uses the 0-th tensor meta regardless of input_arg_idx — confirm intended.
                input_shape = ip_node.op_exec_context.tensor_metas[0].shape
                self.function_quantizers[ip_arg_quant_name] = \
                    self.quantize_module_creator_fn(scope_str, is_weights=False,
                                                    input_shape=input_shape).to(device)

                logger.info("Adding {} Function Quantize: {}".format(
                    "signed"
                    if self.function_quantizers[ip_arg_quant_name].signed else
                    "unsigned", ip_arg_quant_name))
                pre_hooks_to_register.append(
                    self.FunctionQuantizationPreHook(self._context_name,
                                                     ip_arg_quant_key,
                                                     self.debug_interface))
            self._ctx.register_pre_hooks(pre_hooks_to_register,
                                         ia_op_exec_context)

        # NOTE: Order of input quantizers must be the same to correctly broadcast parameters (e.g. scales) in
        # distributed mode (see call of `_dist_broadcast_coalesced` in torch/nn/parallel/distributed.py for more
        # details) pylint: disable=protected-access
        self.function_quantizers._modules = OrderedDict(
            sorted(self.function_quantizers._modules.items()))
Beispiel #9
0
def create_compressed_model(model: Module, config: NNCFConfig,
                            resuming_state_dict: dict = None,
                            dummy_forward_fn: Callable[[Module], Any] = None,
                            wrap_inputs_fn: Callable[[Tuple, Dict], Tuple[Tuple, Dict]] = None,
                            dump_graphs=True) \
    -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original PyTorch
    model and a configuration object.

    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source. Note that the model is switched to eval mode by this function.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model.
    :param resuming_state_dict: A PyTorch state dict object to load (strictly) into the compressed model after
    building.
    :param dummy_forward_fn: if supplied, will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param wrap_inputs_fn: if supplied, will be used on the module's input arguments during a regular, non-dummy
    forward call before passing the inputs to the underlying compressed model. This is required if the model's input
    tensors that are important for compression are not supplied as arguments to the model's forward call directly, but
    instead are located in a container (such as list), and the model receives the container as an argument.
    wrap_inputs_fn should take as input two arguments - the tuple of positional arguments to the underlying
    model's forward call, and a dict of keyword arguments to the same. The function should wrap each tensor among the
    supplied model's args and kwargs that is important for compression (e.g. quantization) with an nncf.nncf_model_input
    function, which is a no-operation function and marks the tensors as inputs to be traced by NNCF in the internal
    graph representation. Output is the tuple of (args, kwargs), where args and kwargs are the same as were supplied in
    input, but with each compression-relevant tensor of the original input wrapped by nncf.nncf_model_input.
    :param dump_graphs: Whether or not should also dump the internal graph representation of the
    original and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression parameter training wrapped
    as an object of NNCFNetwork."""

    # Compress model that will be deployed for the inference on target device. No need to compress parts of the
    # model that are used on training stage only (e.g. AuxLogits of Inception-v3 model) or unused modules with weights.
    # As a consequence, no need to care about spoiling BN statistics, as they're disabled in eval mode.
    model.eval()

    log_dir = config.get("log_dir", ".")
    input_info_list = create_input_infos(config)

    def make_graph_builder(with_input_tracing):
        # A user-supplied dummy forward function takes precedence over mock inputs built from the config.
        if dummy_forward_fn is not None:
            return GraphBuilder(custom_forward_fn=dummy_forward_fn)
        return GraphBuilder(custom_forward_fn=create_dummy_forward_fn(
            input_info_list, with_input_tracing=with_input_tracing))

    # Graphs are only dumped from the main process.
    if dump_graphs and is_main_process():
        graph = make_graph_builder(with_input_tracing=True).build_graph(model)
        graph.visualize_graph(osp.join(log_dir, "original_graph.dot"))

    set_debug_log_dir(log_dir)

    scopes_without_shape_matching = config.get('scopes_without_shape_matching', [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        wrap_inputs_fn=wrap_inputs_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    # Initialization of compression parameters is skipped when resuming from a state dict.
    should_init = resuming_state_dict is None
    compression_algo_builder_list = create_compression_algorithm_builders(
        config, should_init=should_init)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    try:
        if resuming_state_dict is not None:
            load_state(compressed_model, resuming_state_dict, is_resume=True)
    finally:
        # The compressed graph is dumped even if loading the state dict raised
        # (presumably to aid debugging); the exception still propagates afterwards.
        if dump_graphs and is_main_process() and compression_algo_builder_list:
            graph = make_graph_builder(with_input_tracing=False).build_graph(
                compressed_model, compressed_model.get_tracing_context())
            graph.visualize_graph(osp.join(log_dir, "compressed_graph.dot"))
    return compression_ctrl, compressed_model
Beispiel #10
0
 def rebuild_graph(self, *input_args):
     """Reset the compressed tracing context's graph and trace this model into it again.

     Tracing uses the graph-building forward function without input tracing.
     *input_args* is accepted but not used.
     """
     self._compressed_context.reset_graph()
     forward_fn = self._get_dummy_forward_fn_for_graph_building(with_input_tracing=False)
     GraphBuilder(forward_fn).build_graph(self, self._compressed_context)
Beispiel #11
0
class NNCFNetwork(nn.Module, PostGraphBuildActing):
    """Wrapper around an original PyTorch model that prepares it for NNCF compression.

    On construction, supported modules of the wrapped model are replaced by NNCF modules and the
    original graph is traced into ``_original_graph``. A separate tracing context
    (``_compressed_context``) is prepared for the compressed model. Compression algorithm builders
    queue insertion commands and register themselves. ``commit_compression_changes`` then applies
    the insertions and returns the corresponding controller.
    """

    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo] = None,
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None):
        """
        :param module: The original model to wrap. Its device is taken from its first parameter, so a
            model without parameters makes ``next()`` raise StopIteration here.
        :param input_infos: Input descriptions used to build mock inputs for graph tracing when
            *dummy_forward_fn* is not supplied.
        :param dummy_forward_fn: Optional callable used for graph tracing instead of mock-input calls.
        :param wrap_inputs_fn: Optional callable ``(args, kwargs) -> (args, kwargs)`` applied to the inputs
            of every regular forward call. If omitted, an ``InputInfoWrapManager`` built from *input_infos*
            and the wrapped model's forward signature is used.
        :param scopes_without_shape_matching: Scopes whose nodes are matched ignoring tensor shapes
            during tracing.
        :param ignored_scopes: Forwarded to ``replace_modules_by_nncf_modules``.
        :param target_scopes: Forwarded to ``replace_modules_by_nncf_modules``.
        """
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._dummy_forward_fn = dummy_forward_fn

        device = next(module.parameters()).device

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[CompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {}  # type: Dict[InsertionPoint, List[Tuple[Callable, OperationPriority]]]

        # all modules should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device)

        _orig_context = TracingContext()
        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True)

        self._graph_builder = GraphBuilder(_orig_graph_build_forward_fn)

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self.get_nncf_wrapped_model(), _orig_context)

        self._compressed_context = TracingContext()

        # From here on, dummy forwards (e.g. do_dummy_forward) run without input tracing.
        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None

        self._builders = []  # type: List['CompressionAlgorithmBuilder']

    @debuggable_forward
    def forward(self, *args, **kwargs):
        """Run the wrapped model inside the compressed tracing context.

        The inputs first pass through the configured input-wrapping function.
        """
        with self._compressed_context as ctx:  # type: TracingContext
            ctx.base_module_thread_local_replica = self
            args, kwargs = self._wrap_inputs_fn(args, kwargs)
            retval = self.get_nncf_wrapped_model()(*args, **kwargs)
        return retval

    def register_algorithm(self, builder: 'CompressionAlgorithmBuilder'):
        """Should be called during *builder*'s *apply_to* method, otherwise there will be no corresponding
        controller returned by the network on the *commit_compression_changes* stage"""
        self._builders.append(builder)

    # Cannot use property syntax here, otherwise the wrapped module will end up
    # being twice in the same checkpoint with different prefixes
    def get_nncf_wrapped_model(self):
        """Return the wrapped (original, NNCF-module-replaced) model."""
        return getattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME)

    def _set_nncf_wrapped_model(self, value):
        setattr(self, MODEL_WRAPPED_BY_NNCF_ATTR_NAME, value)

    def get_modules_in_nncf_modules_by_type(self,
                                            types) -> Dict['Scope', nn.Module]:
        """Return submodules of the given *types* found inside NNCF modules, keyed by their full scope."""
        nncf_modules = self.get_nncf_modules()
        retval = {}
        for nncf_module_scope, nncf_module in nncf_modules.items():
            # Drop the last element of the NNCF module's scope; presumably each relative scope
            # returned below starts with that element again - confirm against get_all_modules_by_type.
            nncf_module_scope.pop()
            for relative_scope, target_module in get_all_modules_by_type(
                    nncf_module, types).items():
                retval[nncf_module_scope + relative_scope] = target_module
        return retval

    def register_insertion_command(self, command: InsertionCommand):
        """Queue ``command.fn`` (with its priority) for insertion at ``command.insertion_point``.

        Nothing is inserted until ``commit_compression_changes`` is called.
        """
        self._insertions_into_original_graph.setdefault(command.insertion_point, []).append(
            (command.fn, command.priority))

    def commit_compression_changes(self) -> 'CompressionAlgorithmController':
        """Apply all queued insertions and build the compression controller.

        :return: ``NoCompressionAlgorithmController`` if no builders were registered, the single builder's
            controller if exactly one was registered, otherwise a ``CompositeCompressionAlgorithmController``.
        """
        for insertion_point, fn_list_with_priority in self._insertions_into_original_graph.items():
            # Functions are registered in ascending priority order; sorted() is stable,
            # so functions with equal priority keep their queueing order.
            fn_list_with_priority = sorted(fn_list_with_priority, key=lambda x: x[1])
            self._insertions_into_original_graph[insertion_point] = fn_list_with_priority
            self._insert_at_point(insertion_point, [x[0] for x in fn_list_with_priority])

        if self.debug_interface is not None:
            self.debug_interface.init_actual(self)

        self._load_listener = LoadStateListener(self, self._get_all_quantization_modules())

        if not self._builders:
            from nncf.algo_selector import NoCompressionAlgorithmController
            return NoCompressionAlgorithmController(self)

        if len(self._builders) == 1:
            return self._builders[0].build_controller(self)

        from nncf.composite_compression import CompositeCompressionAlgorithmController
        composite_controller = CompositeCompressionAlgorithmController(self)
        for algo_builder in self._builders:
            composite_controller.add(algo_builder.build_controller(self))
        return composite_controller

    def _get_all_quantization_modules(self):
        """Return ``get_state_dict_names_with_modules`` for every class in ``QUANTIZATION_MODULES``.

        The result's values are the quantization module instances (presumably keyed by state-dict name).
        """
        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        return get_state_dict_names_with_modules(self, quantization_types)

    def _insert_at_point(self, point: InsertionPoint, fn_list: List[Callable]):
        """Register *fn_list* at *point*: as tracing-context hooks for operator pre/post hooks,
        otherwise as pre/post forward operations of the NNCF module at the point's scope."""
        if point.insertion_type == InsertionType.OPERATOR_PRE_HOOK:
            self._compressed_context.register_pre_hooks(
                fn_list, point.ia_op_exec_context)
        elif point.insertion_type == InsertionType.OPERATOR_POST_HOOK:
            self._compressed_context.register_post_hooks(
                fn_list, point.ia_op_exec_context)
        else:
            norm_target_scope = self._normalize_variable_recurrent_scope(
                point.ia_op_exec_context.scope_in_model)
            norm_nncf_scopes = [
                self._normalize_variable_recurrent_scope(x)
                for x in self._nncf_module_scopes
            ]
            assert norm_target_scope in norm_nncf_scopes  # Required for proper Recurrent/VariableRecurrent addressing
            nncf_module = self.get_module_by_scope(
                point.ia_op_exec_context.scope_in_model)
            if point.insertion_type == InsertionType.NNCF_MODULE_PRE_OP:
                for fn in fn_list:
                    nncf_module.register_pre_forward_operation(fn)
            elif point.insertion_type == InsertionType.NNCF_MODULE_POST_OP:
                for fn in fn_list:
                    nncf_module.register_post_forward_operation(fn)

    def __getattr__(self, name):
        """Resolve unknown attributes on the wrapped model first, then fall back to nn.Module lookup."""
        wrapped_module = super().__getattr__(MODEL_WRAPPED_BY_NNCF_ATTR_NAME)
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)

    def get_graph(self) -> NNCFGraph:
        """Return the graph of the compressed tracing context."""
        return self._compressed_context.graph

    def get_original_graph(self) -> NNCFGraph:
        """Return the graph traced from the original model at construction time."""
        return self._original_graph

    def get_tracing_context(self) -> TracingContext:
        """Return the tracing context used for the compressed model."""
        return self._compressed_context

    def _get_dummy_forward_fn_for_graph_building(self, with_input_tracing):
        # A user-supplied (or previously stored) dummy forward function takes precedence over
        # mock inputs built from input_infos.
        if self._dummy_forward_fn is None:
            return create_dummy_forward_fn(
                self.input_infos,
                with_input_tracing=with_input_tracing,
                wrap_inputs_fn=self._wrap_inputs_fn)
        return self._dummy_forward_fn

    def _replace_modules_by_nncf_modules(self, device):
        module, self._nncf_module_scopes = replace_modules_by_nncf_modules(
            self.get_nncf_wrapped_model(),
            ignored_scopes=self.ignored_scopes,
            target_scopes=self.target_scopes)
        self._set_nncf_wrapped_model(module.to(device))

    def get_nncf_module_scopes(self) -> List['Scope']:
        """Return the scopes of modules replaced by NNCF modules."""
        return self._nncf_module_scopes

    def get_nncf_modules(self) -> Dict['Scope', torch.nn.Module]:
        """Return all NNCF modules (built-in and wrapped user modules) of the wrapped model, keyed by scope."""
        nncf_module_names_list = NNCF_MODULES + [
            x.__name__ for x in NNCF_WRAPPED_USER_MODULES_DICT.values()
        ]
        return get_all_modules_by_type(self.get_nncf_wrapped_model(),
                                       nncf_module_names_list)

    def rebuild_graph(self, *input_args):
        """Reset the compressed context's graph and trace it again (*input_args* is unused)."""
        self._compressed_context.reset_graph()
        dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False)
        builder = GraphBuilder(dummy_forward_fn)
        _ = builder.build_graph(self, self._compressed_context)

    def post_build_graph_actions(self):
        # Reset initialization flags (`initialized`) for all quantization modules
        # after dummy `load_state_dict` call.
        for module in self._get_all_quantization_modules().values():
            module.initialized = False

    def get_post_pattern_insertion_points(
            self,
            pattern: 'NNCFNodeExpression',
            omit_nodes_in_nncf_modules=False) -> List[InsertionInfo]:
        """Collect insertion points at the outputs of every match of *pattern* in the original graph.

        :param pattern: The pattern to match against the original (uncompressed) graph.
        :param omit_nodes_in_nncf_modules: If True, insertion points whose operation lies inside an NNCF
            module scope are dropped.
        :return: De-duplicated insertion infos in first-occurrence order. Order no longer depends on
            set iteration order, which can vary between processes.
        """
        io_infos = self._original_graph.get_matching_nncf_graph_pattern_io_list(
            pattern)

        # A dict is used as an insertion-ordered set. Multiple output edges in a pattern I/O info may
        # originate from one and the same node, and different matches may overlap (happens for GNMT).
        # All of these must resolve into just one insertion point.
        unique_insertion_infos = {}  # type: Dict[InsertionInfo, None]
        for io_info in io_infos:
            if len(io_info.output_edges) > 1:
                nncf_logger.debug(
                    "WARNING: pattern has more than one activation output")

            for nncf_node in io_info.output_nodes:
                unique_insertion_infos[InsertionInfo(nncf_node.op_exec_context,
                                                     is_output=True,
                                                     shape_to_operate_on=None)] = None
                # TODO: determine output shapes for output nodes to enable per-channel quantization

            # Ignore input nodes in the pattern for now, rely on the _quantize_inputs functions.
            # TODO: handle input quantization here as well

            # Since this function is currently only used for activation quantization purposes via operator
            # post-hook mechanism, we may take any edge and it will point from the same node where we will have to
            # insert a quantizer later. However, in the future the output edges may refer to activation tensors
            # with different sizes, in which case we have to insert different per-channel quantizers to
            # accommodate different trainable params if there is a difference in the channel dimension.
            # Furthermore, currently there is no distinction for single tensor output to multiple nodes and
            # multiple tensor output to multiple nodes ("chunk" operation is an example of the latter).
            # The pattern may also have unexpected outputs from a node in the middle of the pattern (see
            # "densenet121.dot" for an example of this) - need to decide what to do with that in terms
            # of quantization.
            # TODO: address the issues above.

            for nncf_edge in io_info.output_edges:
                unique_insertion_infos[InsertionInfo(nncf_edge.from_node.op_exec_context,
                                                     is_output=False,
                                                     shape_to_operate_on=nncf_edge.tensor_shape)] = None

        return [
            info for info in unique_insertion_infos
            if not (omit_nodes_in_nncf_modules and self.is_scope_in_nncf_module_scope(
                info.op_exec_context.scope_in_model))
        ]

    def is_scope_in_nncf_module_scope(self, scope: 'Scope'):
        """Return True if the normalized *scope* is ``in`` any normalized NNCF module scope."""
        # TODO: optimize
        norm_nncf_scopes = [
            self._normalize_variable_recurrent_scope(x)
            for x in self._nncf_module_scopes
        ]
        norm_op_scope = self._normalize_variable_recurrent_scope(scope)
        return any(norm_op_scope in nncf_scope for nncf_scope in norm_nncf_scopes)

    def register_compression_module_type(
            self, compression_module_type: CompressionModuleType):
        """Create an empty ModuleDict attribute for *compression_module_type*.

        :raises RuntimeError: If the type is unknown or already registered.
        """
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type in self._extra_module_types:
            raise RuntimeError("Module type {} is already registered".format(
                compression_module_type))
        self.__setattr__(attr_name, nn.ModuleDict())
        self._extra_module_types.append(compression_module_type)

    def add_compression_module(self, module_key: str, module: nn.Module,
                               compression_module_type: CompressionModuleType):
        """Store *module* under *module_key* in the ModuleDict for *compression_module_type*.

        :raises RuntimeError: If the type is unknown or was not registered.
        """
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        self.__getattr__(attr_name)[module_key] = module

    def get_compression_modules_by_type(
            self,
            compression_module_type: CompressionModuleType) -> nn.ModuleDict:
        """Return the ModuleDict for *compression_module_type*.

        :raises RuntimeError: If the type is unknown or was not registered.
        """
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        return self.__getattr__(attr_name)

    @staticmethod
    def _compression_module_type_to_attr_name(
            compression_module_type: CompressionModuleType):
        """Required for backward compatibility with checkpoints that store function and activation
        quantizers directly under corresponding attributes of NNCFNetwork."""
        if compression_module_type == CompressionModuleType.FUNCTION_QUANTIZER:
            return "function_quantizers"
        if compression_module_type == CompressionModuleType.ACTIVATION_QUANTIZER:
            return "activation_quantizers"
        raise RuntimeError("Unknown extra module type")

    def sort_compression_modules(
            self, compression_module_type: CompressionModuleType):
        """Sort the ModuleDict for *compression_module_type* by key, in place.

        :raises RuntimeError: If the type is unknown or was not registered.
        """
        attr_name = self._compression_module_type_to_attr_name(
            compression_module_type)
        if compression_module_type not in self._extra_module_types:
            raise RuntimeError("Module type {} was not registered".format(
                compression_module_type))
        module_dict = self.__getattr__(attr_name)
        # pylint: disable=protected-access
        module_dict._modules = OrderedDict(sorted(
            module_dict._modules.items()))
        self.__setattr__(attr_name, module_dict)

    @staticmethod
    def _normalize_variable_recurrent_scope(scope: 'Scope'):
        """
        Two scopes pointing to an NNCF module that only differ in a Recurrent/VariableRecurrent/VariableRecurrentReverse
        scope element actually point to one and the same module.
        """
        # NOTE(review): elements are mutated below; this relies on Scope.copy() copying the
        # elements too, otherwise the caller's scope is modified - confirm.
        ret_scope = scope.copy()
        for scope_element in ret_scope:
            if scope_element.calling_module_class_name in [
                    "Recurrent", "VariableRecurrent",
                    "VariableRecurrentReverse"
            ]:
                scope_element.calling_module_class_name = "NormalizedName_Recurrent"
        return ret_scope

    def do_dummy_forward(self, force_eval=False):
        """Run the stored dummy forward function on this model under ``torch.no_grad()``.

        Attention: If run with force_eval=False, this may spoil the batchnorm statistics,
        and an eval run of the model will perform much worse than the train run.

        :param force_eval: Switch to eval mode for the dummy forward. If the model was in training
            mode, ``self.train()`` is called afterwards, even if the forward raised.
        """
        train_mode = self.training
        if force_eval:
            self.eval()
        try:
            with torch.no_grad():
                self._dummy_forward_fn(self)
        finally:
            if force_eval and train_mode:
                self.train()

    def get_insertion_point_graph(self) -> InsertionPointGraph:
        """Build an InsertionPointGraph from the original graph, tagging operator nodes with metatypes."""
        ip_graph = InsertionPointGraph(
            self._original_graph.get_nx_graph_copy())

        # Mark IP graph operator nodes with associated op metatypes
        # Determining operator metatypes is more suited to occur at wrap_operator
        # stage, because it might be influenced by specific non-tensor function parameters,
        # but we have to inspect the containing module parameters as well, so the
        # TracingContext in wrap_operator would have to retain a reference to
        # the model that uses it. Since currently we do not need to inspect the
        # function arguments to determine the metatype, we can do this here, but
        # once we need to inspect the arguments, the code will have to be moved to
        # wrap_operator.

        for node_key in ip_graph.nodes:
            ip_graph_node = ip_graph.nodes[node_key]
            ip_graph_node_type = ip_graph_node[
                InsertionPointGraph.NODE_TYPE_NODE_ATTR]
            if ip_graph_node_type == InsertionPointGraphNodeType.OPERATOR:
                nncf_graph_node_ref = ip_graph_node[
                    InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR]
                op_exec_context = nncf_graph_node_ref[
                    NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
                op_name = op_exec_context.operator_name
                scope = op_exec_context.scope_in_model
                op_arch = OPERATOR_METATYPES.get_operator_metatype_by_op_name(
                    op_name)
                module = self.get_module_by_scope(scope)
                if module is not None:
                    subtype = op_arch.determine_subtype(
                        containing_module=module)
                    if subtype is not None:
                        op_arch = subtype
                ip_graph_node[
                    InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR] = op_arch
        return ip_graph

    def get_module_by_scope(self, scope: 'Scope') -> torch.nn.Module:
        """Walk the wrapped model along *scope* and return the addressed module.

        :return: The module, or None if a scope element has no calling field name.
        :raises RuntimeError: If a named member is missing along the path.
        """
        curr_module = self.get_nncf_wrapped_model()
        for scope_element in scope[
                1:]:  # omit first scope element which corresponds to base module
            if scope_element.calling_field_name is None:
                # The module used is being created in-place every time and never stored in the model,
                # happens for nn.Softmax in BERT implementations.
                return None
            # pylint: disable=protected-access
            next_module = curr_module._modules.get(
                scope_element.calling_field_name)
            if next_module is None:
                raise RuntimeError(
                    "Could not find a {} module member in {} module of scope {} during node search"
                    .format(scope_element.calling_field_name,
                            scope_element.calling_module_class_name,
                            str(scope)))
            curr_module = next_module
        return curr_module

    def get_parameters_count_in_model(self) -> int:
        """
        Return total amount of model parameters.
        """
        return sum(param.numel() for param in self.parameters())

    def get_flops_per_module(self):
        """
        Calculates FLOPS count for modules.

        Only Conv2d/ConvTranspose2d, Linear and BatchNorm2d modules are counted; the FLOPS value
        stored per module name is twice the estimated MAC count. Runs a dummy forward in eval mode.
        """
        model = self
        flops_count_dict = {}

        def get_hook(name):
            def compute_MACs_hook(module, input_, output):
                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                    ks = module.weight.data.shape
                    mac_count = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[
                        3] * output.shape[2]
                elif isinstance(module, nn.Linear):
                    mac_count = input_[0].shape[1] * output.shape[-1]
                elif isinstance(module, nn.BatchNorm2d):
                    mac_count = np.prod(list(input_[0].shape))
                else:
                    return
                flops_count_dict[name] = 2 * mac_count

            return compute_MACs_hook

        hook_list = [
            m.register_forward_hook(get_hook(n))
            for n, m in model.named_modules()
        ]

        model.do_dummy_forward(force_eval=True)

        for h in hook_list:
            h.remove()
        return flops_count_dict

    def get_MACs_in_model(self):
        """
            Calculates MAC units count for model.
        """
        flops_count_dict = self.get_flops_per_module()
        total_MACs_count = sum(v // 2 for v in flops_count_dict.values())
        return total_MACs_count

    def get_input_infos(self) -> List[ModelInputInfo]:
        """Return a deep copy of the model input infos."""
        return deepcopy(self.input_infos)
 def test_build_graph(self, model_name, model_builder, input_size):
     net = model_builder()
     graph_builder = GraphBuilder(create_dummy_forward_fn([ModelInputInfo(input_size), ]))
     graph = graph_builder.build_graph(net)
     check_graph(graph, model_name, 'original')