Ejemplo n.º 1
0
    def __init__(self):
        # Traced graph plus the switches/comparators consulted while tracing.
        self.graph = NNCFGraph()
        self.is_tracing = True
        self._input_comparators_per_scope = []

        # Operator hook registries and context bookkeeping.
        self._pre_hooks = {}  # type: Dict[PreHookId, List[Callable]]
        self._post_hooks = {}
        self._num_nested_hooks = 0
        self._save_context = None

        # Threading state: per-thread storage, and a condition used together with _n_instance.
        self._thread_local = threading.local()
        self._cond = threading.Condition()
        self._n_instance = 0
    def __init__(self, name):
        # Identity and traced graph.
        self.name = name
        self.graph = NNCFGraph()
        self.is_tracing = True
        self._input_comparators_per_scope = []

        # Operator hook registries and context bookkeeping.
        self._pre_hooks = {}
        self._post_hooks = {}
        self._num_nested_hooks = 0
        self._save_context = None

        # Threading state: per-thread storage, and a condition used together with _n_instance.
        self._thread_local = threading.local()
        self._cond = threading.Condition()
        self._n_instance = 0
Ejemplo n.º 3
0
def find_first_ops_with_type(nncf_graph: NNCFGraph, nodes, required_types, forward: bool = True):
    """
    Looking for first nodes with type from required_types that are reachable from nodes.

    The search is a depth-first traversal (along out-edges when ``forward`` is True, along
    in-edges otherwise) that does not continue past a node whose type is in ``required_types``.
    Each graph node is visited at most once.

    :param nncf_graph: NNCFGraph to work with
    :param nodes: nodes from which search begins (nx node attribute dicts carrying a 'key' entry)
    :param required_types: types of nodes for search
    :param forward: whether the search will be forward or backward
    :return: list of the first reachable nodes whose type (per ``nncf_graph.node_type_fn``)
        is in ``required_types``
    """
    graph = nncf_graph._nx_graph
    get_edges_fn = graph.out_edges if forward else graph.in_edges

    found_nodes = []
    visited = set()
    node_stack = deque(nodes)
    while node_stack:
        last_node = node_stack.pop()
        last_node_key = last_node['key']
        if last_node_key in visited:
            continue
        visited.add(last_node_key)

        # Stop descending at the first node of a required type.
        if nncf_graph.node_type_fn(last_node) in required_types:
            found_nodes.append(last_node)
            continue

        for in_node_name, out_node_name in get_edges_fn(last_node_key):
            cur_node = graph.nodes[out_node_name if forward else in_node_name]
            if cur_node['key'] not in visited:
                node_stack.append(cur_node)
    return found_nodes
Ejemplo n.º 4
0
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        """
        Prune the output channels of a transposed convolution by the node's output mask.

        Keeps only the slices of weight dimension 1 (the output-channel dimension of
        ``ConvTranspose`` weights) selected by the mask, prunes the bias the same way when
        present, and updates ``out_channels``. Does nothing when ``nx_node['output_mask']`` is None.

        :param model: network that owns the module to prune.
        :param nx_node: graph node attributes; ``'output_mask'`` and ``'key'`` are read.
        :param graph: NNCFGraph used to map ``nx_node`` to its module.
        :param nx_graph: not used by this operation.
        """
        output_mask = nx_node['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(1))

        node_module.out_channels = new_num_channels
        # Select the kept output channels for every input channel at once; this is the same
        # result as broadcasting the mask over dim 0 and reshaping the flattened selection.
        node_module.weight = torch.nn.Parameter(node_module.weight[:, bool_mask])

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
Ejemplo n.º 5
0
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        """
        Prune the features of a BatchNorm module by the node's first input mask.

        Keeps the features selected by the mask in ``weight``, ``bias``, ``running_mean`` and
        ``running_var`` (each one only when it is not None) and updates ``num_features``.
        Does nothing when the first input mask is None.

        :param model: network that owns the module to prune.
        :param nx_node: graph node attributes; ``'input_masks'`` and ``'key'`` are read.
        :param graph: NNCFGraph used to map ``nx_node`` to its module.
        :param nx_graph: not used by this operation.
        """
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        # num_features is available even when the affine weight is None.
        old_num_channels = int(node_module.num_features)
        # Count from the bool tensor so a mask given as a list works as well as a tensor.
        new_num_channels = int(torch.sum(bool_mask))

        node_module.num_features = new_num_channels
        # Affine parameters may be None (e.g. torch BatchNorm built with affine=False).
        if node_module.weight is not None:
            node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
        # Running statistics may be None (e.g. track_running_stats=False).
        # NOTE(review): they are stored back as non-trainable Parameters, while torch registers
        # them as buffers; consider assigning plain tensors to keep them buffers.
        if node_module.running_mean is not None:
            node_module.running_mean = torch.nn.Parameter(
                node_module.running_mean[bool_mask], requires_grad=False)
        if node_module.running_var is not None:
            node_module.running_var = torch.nn.Parameter(
                node_module.running_var[bool_mask], requires_grad=False)

        nncf_logger.info(
            'Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
Ejemplo n.º 6
0
def check_graph(graph: NNCFGraph, path_to_dot, graph_dir, sort_dot_graph=True):
    """
    Compare the structure of ``graph`` against a reference .dot file.

    The reference lives at ``data/reference_graphs/<graph_dir>/<path_to_dot>`` next to this
    file. If it does not exist yet it is written from ``graph`` (and optionally sorted) and
    must then be validated manually.

    :param graph: graph under test.
    :param path_to_dot: reference file name, relative to the graph directory.
    :param graph_dir: sub-directory of ``data/reference_graphs``.
    :param sort_dot_graph: whether to sort a newly written reference file.
    :raises AssertionError: if nodes, node attributes or edges differ from the reference.
    """
    # pylint:disable=protected-access
    nx_graph = graph._get_graph_for_structure_analysis()

    data_dir = os.path.join(os.path.dirname(__file__), 'data/reference_graphs')
    dot_dir = os.path.join(data_dir, graph_dir)
    path_to_dot = os.path.abspath(os.path.join(dot_dir, path_to_dot))

    # validate .dot file manually!
    if not os.path.exists(path_to_dot):
        os.makedirs(dot_dir, exist_ok=True)
        nx.drawing.nx_pydot.write_dot(nx_graph, path_to_dot)
        if sort_dot_graph:
            sort_dot(path_to_dot)

    load_graph = nx.drawing.nx_pydot.read_dot(path_to_dot)
    load_graph = get_version_agnostic_graph(load_graph)

    # Check node sets first: a missing node then fails here with a clear assertion
    # instead of a KeyError in the attribute comparison below.
    assert load_graph.nodes.keys() == nx_graph.nodes.keys()

    # nx_graph is expected to have version-agnostic operator names already
    for node_key, node_attrs in nx_graph.nodes.items():
        actual_attrs = {name: str(value) for name, value in node_attrs.items()}
        # Values read back from the .dot file may carry surrounding double quotes; strip them.
        reference_attrs = {
            name: str(value).strip('"')
            for name, value in load_graph.nodes[node_key].items()
        }
        assert actual_attrs == reference_attrs, \
            'Attributes of node {} differ from the reference'.format(node_key)

    assert nx.DiGraph(load_graph).edges == nx_graph.edges
Ejemplo n.º 7
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Compute the pruning masks of a convolution node and store them in ``nx_node``.

        Sets ``'input_masks'``, ``'output_mask'``, ``'accept_pruned_input'`` and
        ``'is_depthwise'``. The output mask comes from the module's pre-op ``'0'`` when the
        module has pre-ops. For grouped convolutions the output cannot be pruned by its own
        filters: depthwise ones (weight dim 1 == 1) forward the first input mask, other
        grouped ones get no output mask and do not accept pruned inputs.
        """
        input_masks = get_input_masks(nx_node, nx_graph)

        module_scope = graph._nx_node_to_nncf_node(nx_node).op_exec_context.scope_in_model
        node_module = model.get_module_by_scope(module_scope)

        # NOTE(review): assumes the pruning operation is registered as pre-op '0' — confirm.
        output_mask = node_module.pre_ops['0'].op.binary_filter_pruning_mask if node_module.pre_ops else None
        accept_pruned_input = True
        is_depthwise = False

        if node_module.groups != 1:
            is_depthwise = node_module.weight.size(1) == 1
            if is_depthwise:
                output_mask = input_masks[0]
            else:
                accept_pruned_input = False
                output_mask = None

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = accept_pruned_input
        nx_node['is_depthwise'] = is_depthwise
Ejemplo n.º 8
0
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        """
        Prune the input channels of a convolution by the node's first input mask.

        For a depthwise convolution (``nx_node['is_depthwise']``) only ``groups`` and
        ``in_channels`` are updated here; its weight is pruned together with the output
        channels. Otherwise the weight keeps only the input channels (dim 1) selected by
        the mask. Does nothing when the first input mask is None.

        :param model: network that owns the module to prune.
        :param nx_node: graph node attributes; ``'input_masks'``, ``'is_depthwise'`` and ``'key'`` are read.
        :param graph: NNCFGraph used to map ``nx_node`` to its module.
        :param nx_graph: not used by this operation.
        """
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        # Count from the bool tensor so a mask given as a list works as well as a tensor.
        new_num_channels = int(torch.sum(bool_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        is_depthwise = nx_node['is_depthwise']
        # weight.size(1) is in_channels / groups (1 for depthwise), so read the real input count.
        old_num_channels = int(node_module.in_channels)

        if is_depthwise:
            # In depthwise case prune output channels by input mask, here only fix for new number of input channels
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
        else:
            node_module.in_channels = new_num_channels
            # Keep the selected input channels of every filter.
            node_module.weight = torch.nn.Parameter(node_module.weight[:, bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
Ejemplo n.º 9
0
def get_sources_of_node(nncf_node: NNCFNode, graph: NNCFGraph, sources_types):
    """
    Collect the sources of ``nncf_node``.

    A source is a node from which there is a path to ``nncf_node`` such that no node on
    this path has one of ``sources_types`` types. The search runs backwards through
    ``graph.traverse_graph`` with a shared visited map. If ``nncf_node`` itself has one of
    ``sources_types``, the search starts from its predecessors instead.

    :param nncf_node: NNCFNode to get sources for
    :param graph: NNCF graph to work with
    :param sources_types: list of source operator types
    :return: list of all source nodes
    """
    visited = dict.fromkeys(graph.get_all_node_idxs(), False)

    def _is_source_type(node_type):
        return node_type in sources_types

    traverse_step = partial(traverse_function, nncf_graph=graph, type_check_fn=_is_source_type,
                            visited=visited)

    if nncf_node.op_exec_context.operator_name in sources_types:
        start_nodes = graph.get_previous_nodes(nncf_node)
    else:
        start_nodes = [nncf_node]

    return [source
            for start_node in start_nodes
            for source in graph.traverse_graph(start_node, traverse_step, False)]
Ejemplo n.º 10
0
    def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                         nx_graph: nx.DiGraph):
        """
        Store the pruning masks of this node in ``nx_node``.

        ``'input_masks'`` comes from ``get_input_masks``. ``'output_mask'`` is the
        ``binary_filter_pruning_mask`` of the module's pre-op ``'0'`` when the module has
        pre-ops, otherwise None. ``'accept_pruned_input'`` is always True.
        """
        input_masks = get_input_masks(nx_node, nx_graph)

        module_scope = graph._nx_node_to_nncf_node(nx_node).op_exec_context.scope_in_model
        node_module = model.get_module_by_scope(module_scope)

        output_mask = None
        if node_module.pre_ops:
            output_mask = node_module.pre_ops['0'].op.binary_filter_pruning_mask

        nx_node['input_masks'] = input_masks
        nx_node['output_mask'] = output_mask
        nx_node['accept_pruned_input'] = True
Ejemplo n.º 11
0
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        """
        Prune the input channels (weight dimension 0) of a transposed convolution.

        Uses the node's first input mask, updates ``in_channels`` and logs the change.
        Does nothing when that mask is None.
        """
        first_mask = nx_node['input_masks'][0]
        if first_mask is None:
            return
        keep = torch.tensor(first_mask, dtype=torch.bool)

        module_scope = graph._nx_node_to_nncf_node(nx_node).op_exec_context.scope_in_model
        conv_module = model.get_module_by_scope(module_scope)
        channels_before = int(conv_module.weight.size(0))

        conv_module.in_channels = int(torch.sum(keep))
        conv_module.weight = torch.nn.Parameter(conv_module.weight[keep])

        message = ('Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
                   ' {}.').format(nx_node['key'], channels_before, conv_module.in_channels)
        nncf_logger.info(message)
Ejemplo n.º 12
0
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        """
        Prune the output filters (weight dimension 0) of a convolution by the node's output mask.

        Updates ``out_channels``, the weight and, unless the node is depthwise, the bias.
        Does nothing when ``nx_node['output_mask']`` is None.

        :param model: network that owns the module to prune.
        :param nx_node: graph node attributes; ``'output_mask'``, ``'is_depthwise'`` and ``'key'`` are read.
        :param graph: NNCFGraph used to map ``nx_node`` to its module.
        :param nx_graph: not used by this operation.
        """
        mask = nx_node['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(0))

        # Count from the bool tensor so a mask given as a list works as well as a tensor.
        node_module.out_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        # NOTE(review): the bias of a depthwise conv is left unpruned here although its weight
        # is pruned — confirm that the depthwise bias is handled elsewhere.
        if node_module.bias is not None and not nx_node['is_depthwise']:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
Ejemplo n.º 13
0
    def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        """
        Prune an NNCF-wrapped user module along weight dimension 0 by the node's first input mask.

        Only modules whose type is one of the keys of ``NNCF_WRAPPED_USER_MODULES_DICT`` are
        pruned; other modules are left untouched. Does nothing when the first input mask is None.

        :param model: network that owns the module to prune.
        :param nx_node: graph node attributes; ``'input_masks'`` and ``'key'`` are read.
        :param graph: NNCFGraph used to map ``nx_node`` to its module.
        :param nx_graph: not used by this operation.
        :raises AssertionError: if the module compresses a weight dimension other than 0.
        """
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        # tuple(dict) yields the dict keys, i.e. the wrapped module types.
        if not isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            return

        assert node_module.target_weight_dim_for_compression == 0,\
            "Implemented only for target_weight_dim_for_compression == 0"
        old_num_channels = int(node_module.weight.size(0))
        # Count from the bool tensor so a mask given as a list works as well as a tensor.
        new_num_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(
            node_module.weight[bool_mask])
        node_module.n_channels = new_num_channels

        nncf_logger.info(
            'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          new_num_channels))
Ejemplo n.º 14
0
def prepare_potential_quantizer_graph(
        graph: NNCFGraph, potential_activations_quantizers: Dict[
            InsertionInfo, Optional[List[QuantizerConfig]]],
        potential_weights_modules: List[PotentialQuantizedModule]
) -> NNCFGraph:
    """
    Add purple "Quantizer: ..." nodes to ``graph`` that visualize potential quantizer placements.

    The underlying nx graph is modified in place. An activation quantizer node is spliced
    between an operation node and all of its successors. A weight quantizer node is added
    as an extra predecessor of the operation node. A node that has both kinds only gets the
    activation quantizer.

    :param graph: graph to annotate.
    :param potential_activations_quantizers: candidate quantizer configurations per activation
        insertion point; a None list is rendered as an empty configuration string.
    :param potential_weights_modules: (_, module scope, candidate configurations) triples.
    :return: the same ``graph`` object.
    """
    def _format_qconfig_list(qconfig_list) -> str:
        # Renders as "[cfg1] [cfg2] "; a missing (None) list yields an empty string.
        return ''.join('[' + str(qconfig) + '] ' for qconfig in (qconfig_list or []))

    quantizers_weights_attr = {}
    quantizers_activations_attr = {}
    # pylint:disable=protected-access
    for _, module_scope, qconfig_list in potential_weights_modules:
        matching_graph_op_nodes = graph.get_op_nodes_in_scope(module_scope)

        # Isn't correct when NNCF module has more than 1 graph node
        assert len(matching_graph_op_nodes) == 1

        op_name = matching_graph_op_nodes[0][
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
        ia_op_exec_context = InputAgnosticOperationExecutionContext(
            op_name, module_scope, 0)
        quantizers_weights_attr[ia_op_exec_context] = _format_qconfig_list(qconfig_list)

    for insertion_info, qconfig_list in potential_activations_quantizers.items():
        str_qconfig_list = _format_qconfig_list(qconfig_list)
        quantizers_activations_attr[insertion_info.op_exec_context.input_agnostic] = str_qconfig_list
        # Linked operations are labelled with the same configurations.
        for linked_op_exec_context in insertion_info.linked_op_exec_contexts:
            quantizers_activations_attr[linked_op_exec_context.input_agnostic] = str_qconfig_list

    nx_graph = graph._nx_graph
    # sorted() materializes a snapshot of the original nodes, so adding quantizer nodes
    # inside the loop does not affect the iteration.
    for node_name, node in sorted(nx_graph.nodes(data=True), key=lambda item: item[0]):
        op_exec_context = nx_graph.nodes[node_name][NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
        ia_op_exec_context_for_node = op_exec_context.input_agnostic
        node_scope = str(ia_op_exec_context_for_node)
        if ia_op_exec_context_for_node in quantizers_activations_attr:
            label = "Quantizer: {}".format(
                quantizers_activations_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope,
                              label=label,
                              color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=op_exec_context)
            # Re-route every outgoing edge through the quantizer node.
            for next_node_name in list(nx_graph.successors(node_name)):
                nx_graph.add_edge(node_scope, next_node_name)
                nx_graph.remove_edge(node_name, next_node_name)
            nx_graph.add_edge(node_name, node_scope)
        elif ia_op_exec_context_for_node in quantizers_weights_attr:
            label = "Quantizer: {}".format(
                quantizers_weights_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope,
                              label=label,
                              color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=op_exec_context)
            nx_graph.add_edge(node_scope, node_name)

    return graph
Ejemplo n.º 15
0
    def _paint_activation_quantizer_node(
            nncf_graph: NNCFGraph, quantizer_id: NonWeightQuantizerId,
            quantizer_info: 'NonWeightQuantizerInfo',
            bits_color_map: Dict[int, str],
            groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers):
        """
        Colour and label, in place, the graph nodes that hold an activation quantizer.

        For every insertion point affected by the quantizer:
        - the affected operation node is relabelled as ``<operator>_#<id>``;
        - the node holding the quantizer is located. This is an UpdateInputs pre-op for
          module inputs, the single successor for post-hooks, or the predecessor on the
          matching input port for pre-hooks.
        - that node is filled with the colour for the quantizer's bit-width and labelled
          ``AFQ_[<config>]_#<id>``, plus ``_G<group id>`` when adjacent-quantizer groups are given.

        :param nncf_graph: graph whose nx node attributes are modified.
        :param quantizer_id: id used to look up the quantizer's group.
        :param quantizer_info: quantizer module reference and its affected insertion points.
        :param bits_color_map: fill colour per number of bits.
        :param groups_of_adjacent_quantizers: quantizer groups; group labels are added only when truthy.
        """
        #pylint:disable=too-many-branches
        affected_insertion_infos_list = quantizer_info.affected_insertions  # type: List[InsertionInfo]

        for insertion_info in affected_insertion_infos_list:
            input_agnostic_op_exec_context = insertion_info.op_exec_context.input_agnostic
            affected_nncf_node_key = nncf_graph.get_node_key_by_iap_context(
                input_agnostic_op_exec_context)
            affected_nx_node = nncf_graph.get_nx_node_by_key(
                affected_nncf_node_key)
            operator_name = affected_nx_node[
                NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
            node_id = affected_nx_node[NNCFGraph.ID_NODE_ATTR]

            affected_nncf_node = nncf_graph.get_node_by_id(node_id)
            affected_nx_node['label'] = '_#'.join(
                [operator_name, str(node_id)])

            if insertion_info.is_input:
                # Module UpdateInputs pre-op used for activation quantization
                previous_nodes = nncf_graph.get_previous_nodes(
                    affected_nncf_node)

                # Relying on the _quantize_inputs behaviour of only being able to quantize 0-th input

                # previous_nodes are either UpdateWeights, or UpdateWeights + UpdateInputs
                assert len(previous_nodes) in (1, 2)

                if len(previous_nodes) == 2:
                    if "UpdateInputs" in str(
                            previous_nodes[0].op_exec_context.input_agnostic):
                        target_node = previous_nodes[0]
                    else:
                        target_node = previous_nodes[1]
                else:
                    target_node = previous_nodes[0]
                target_nncf_node_id = target_node.node_id
                target_nncf_node_key = nncf_graph.get_node_key_by_id(
                    target_nncf_node_id)
            else:
                in_port_id = insertion_info.in_port_id

                if in_port_id is None:
                    # Post-hooking used for activation quantization
                    # Currently only a single post-hook can immediately follow an operation
                    succs = list(
                        nncf_graph.get_successors(affected_nncf_node_key))
                    assert len(succs) == 1
                    target_nncf_node_key = succs[0]
                else:
                    # Pre-hooking used for activation quantization
                    previous_nodes = nncf_graph.get_previous_nodes(
                        affected_nncf_node)
                    target_node = None
                    for prev_node in previous_nodes:
                        prev_edge = nncf_graph.get_nx_edge(
                            prev_node, affected_nncf_node)
                        if prev_edge[NNCFGraph.
                                     IN_PORT_NAME_EDGE_ATTR] == in_port_id:
                            target_node = prev_node
                            break

                    assert target_node is not None, "Could not find a pre-hook quantizer node for a specific " \
                                                    "input port!"
                    target_nncf_node_id = target_node.node_id
                    target_nncf_node_key = nncf_graph.get_node_key_by_id(
                        target_nncf_node_id)

            activation_fq_node = nncf_graph.get_nx_node_by_key(
                target_nncf_node_key)
            bits = quantizer_info.quantizer_module_ref.num_bits
            activation_fq_node['color'] = bits_color_map[bits]
            activation_fq_node['style'] = 'filled'
            node_id = activation_fq_node[NNCFGraph.ID_NODE_ATTR]

            activation_fq_node['label'] = 'AFQ_[{}]_#{}'.format(
                quantizer_info.quantizer_module_ref.get_current_config(),
                str(node_id))
            grouped_mode = bool(groups_of_adjacent_quantizers)
            if grouped_mode:
                group_id = groups_of_adjacent_quantizers.get_group_id_for_quantizer(
                    quantizer_id)
                # A missing group is reported and rendered as 'UNDEFINED'.
                # NOTE(review): assumes get_group_id_for_quantizer returns None for ungrouped quantizers.
                if group_id is None:
                    nncf_logger.error(
                        'No group for activation quantizer: {}'.format(
                            target_nncf_node_key))
                    group_id_str = 'UNDEFINED'
                else:
                    group_id_str = str(group_id)
                activation_fq_node['label'] += "_G" + group_id_str
class TracingContext:
    """
    Tracing state used by wrapped operators.

    Holds the traced NNCFGraph, the registered pre/post operator hooks, node input
    comparators and per-thread bookkeeping. The per-thread state covers scope stacks,
    operator call counters, node call counters, the ``in_operator`` flag and the
    base-module replica.
    """
    def __init__(self, name):
        self.name = name
        self.graph = NNCFGraph()

        self._save_context = None  # context that was current before enter(); restored by leave()
        self._post_hooks = {}
        self._pre_hooks = {}
        self._num_nested_hooks = 0

        # Per-thread state, lazily populated by _init_thread_local().
        self._thread_local = threading.local()

        # Number of threads currently looking up a node in find_operator_node(); guarded by _cond.
        self._n_instance = 0
        self._cond = threading.Condition()
        self.is_tracing = True
        self._input_comparators_per_scope = []

    def find_operator_node(self, inputs,
                           ia_op_exec_context: InputAgnosticOperationExecutionContext) -> NNCFNode:
        """
        Return the graph node matching ``ia_op_exec_context`` and the tensor metas of ``inputs``,
        adding a new node when there is no match.

        Lookups may run concurrently; a new node is added only while no other thread is inside
        a lookup, and the lookup is repeated under the lock before adding.
        """
        with self._cond:
            self._n_instance += 1
        try:
            tensor_metas = make_input_infos(inputs)
            node = self.graph.find_node(ia_op_exec_context, tensor_metas, self._input_comparators_per_scope)
        finally:
            # Always leave the lookup section, even if the lookup raised; otherwise threads
            # waiting below for _n_instance to drop to zero would block forever.
            with self._cond:
                self._n_instance -= 1
                self._cond.notify_all()

        if node is None:
            with self._cond:
                while self._n_instance > 0:
                    self._cond.wait()
                # Another thread may have added a node inside this block,
                # so we need to check again if a node is already added.
                node = self.graph.find_node(ia_op_exec_context, tensor_metas, self._input_comparators_per_scope)
                if node is None:
                    node = self.graph.add_node(ia_op_exec_context, tensor_metas, self._input_comparators_per_scope,
                                               inputs)
        return node

    def get_caller_context(self, operator_type: str) -> InputAgnosticOperationExecutionContext:
        """
        Designed to work in the following way - for each scope the context will track the number of the calls to the
        operators with the name operator_type (call_order). The counter values are preserved until reset by a
        corresponding member function of the context, which must be called after each model iteration - this is
        usually handled inside NNCF. This mechanism allows to discern between multiple function calls inside the same
        module that would each require their own instance of compression layers - for instance, multiple `relu`
        function calls (either on their own or inside a `for` cycle), and at the same moment allow the checkpoints to
        be loaded if the model had changed in the meantime in a way that does not impact the major function call
        order (e.g. if comments were added to the .py file with the model)
        """
        version_agnostic_operator_type = get_version_agnostic_name(operator_type)
        if version_agnostic_operator_type is not None:
            operator_type = version_agnostic_operator_type

        call_order = self.get_operator_call_count_in_scope(operator_type, self.scope)

        ia_op_exec_context = InputAgnosticOperationExecutionContext(operator_type,
                                                                    self.scope,
                                                                    call_order)
        return ia_op_exec_context

    def reset_scope_operator_call_counters(self):
        """
        Must be called after each "forward" operation of the model that is made
        within this context
        """
        self._init_thread_local()
        self._thread_local.operator_counters = {}

    @staticmethod
    def _get_operator_counter_key(operator_name: str, scope: Scope):
        return "{}_{}".format(str(scope), operator_name)

    def register_operator_call(self, operator_name: str, scope: Scope):
        """Increment this thread's call counter for ``operator_name`` in ``scope``."""
        self._init_thread_local()
        counters = self._thread_local.operator_counters
        key = self._get_operator_counter_key(operator_name, scope)
        counters[key] = counters.get(key, 0) + 1

    def get_operator_call_count_in_scope(self, operator_name: str, scope: Scope):
        """Return this thread's call count for ``operator_name`` in ``scope`` (0 if never registered)."""
        self._init_thread_local()
        key = self._get_operator_counter_key(operator_name, scope)
        return self._thread_local.operator_counters.get(key, 0)

    def reset_operator_call_count_in_scope(self, scope):
        """Zero every operator counter whose key contains ``str(scope)`` as a substring."""
        self._init_thread_local()
        scoped_op_name = str(scope)
        counters = self._thread_local.operator_counters
        for key in counters:
            if scoped_op_name in key:
                counters[key] = 0

    def enter(self):
        """Make this context the current one, remembering the previous one for leave()."""
        global _CURRENT_CONTEXT
        self._save_context = _CURRENT_CONTEXT
        _CURRENT_CONTEXT = self
        self._init_thread_local()

    def leave(self):
        """Restore the context that was current before enter()."""
        global _CURRENT_CONTEXT
        _CURRENT_CONTEXT = self._save_context
        self._save_context = None

    def push_scope(self, scope_module):
        relative_scopes_list = self._get_scope_relative_to_last_registered_module_call(scope_module)
        self.scope_modules.append(scope_module)
        self.relative_scopes_stack.append(relative_scopes_list)

    def pop_scope(self):
        self.relative_scopes_stack.pop()
        self.scope_modules.pop()

    def register_pre_hooks(self, fn_list: List[Callable], ia_op_exec_context: InputAgnosticOperationExecutionContext):
        if ia_op_exec_context in self._pre_hooks:
            raise KeyError("Pre hook for context {} is already registered".format(str(ia_op_exec_context)))
        self._pre_hooks[ia_op_exec_context] = fn_list

    def execute_pre_hooks(self, ia_op_exec_context: InputAgnosticOperationExecutionContext,
                          op_inputs: OperatorInput) -> OperatorInput:
        """
        Run the pre-hooks registered for ``ia_op_exec_context`` in order, threading ``op_inputs``
        through them. ``in_operator`` is cleared while the hooks run; it and the nested-hook
        counter are restored even if a hook raises.
        """
        in_op = getattr(self, 'in_operator', False)
        self.in_operator = False
        self._thread_local.num_nested_hooks += 1
        try:
            for hook in self._pre_hooks.get(ia_op_exec_context, []):
                op_inputs = hook(op_inputs)
        finally:
            self._thread_local.num_nested_hooks -= 1
            self.in_operator = in_op
        return op_inputs

    def register_post_hooks(self, fn_list: List[Callable], ia_op_exec_context: InputAgnosticOperationExecutionContext):
        if ia_op_exec_context in self._post_hooks:
            raise KeyError("Post hook for context {} is already registered".format(str(ia_op_exec_context)))
        self._post_hooks[ia_op_exec_context] = fn_list

    def execute_post_hooks(self, ia_op_exec_context: InputAgnosticOperationExecutionContext, outputs):
        """
        Run the post-hooks registered for ``ia_op_exec_context`` in order, threading ``outputs``
        through them. ``in_operator`` is cleared while the hooks run; it and the nested-hook
        counter are restored even if a hook raises.
        """
        in_op = getattr(self, 'in_operator', False)
        self.in_operator = False
        self._thread_local.num_nested_hooks += 1
        try:
            for hook in self._post_hooks.get(ia_op_exec_context, []):
                outputs = hook(outputs)
        finally:
            self._thread_local.num_nested_hooks -= 1
            self.in_operator = in_op
        return outputs

    def disable_tracing(self):
        self.is_tracing = False

    def enable_tracing(self):
        self.is_tracing = True

    def add_node_comparators(self, scopes_to_apply: List[str],
                             node_input_comparator: 'TensorMetaComparator' = None):
        self._input_comparators_per_scope.append((node_input_comparator, scopes_to_apply))

    @property
    def base_module_thread_local_replica(self):
        self._init_thread_local()
        return self._thread_local.base_module_replica

    @base_module_thread_local_replica.setter
    def base_module_thread_local_replica(self, value):
        self._init_thread_local()
        self._thread_local.base_module_replica = value

    @property
    def in_operator(self):
        self._init_thread_local()
        return self._thread_local.in_operator

    @in_operator.setter
    def in_operator(self, val):
        self._init_thread_local()
        self._thread_local.in_operator = val

    @property
    def scope_modules(self):
        self._init_thread_local()
        return self._thread_local.scope_modules

    @property
    def relative_scopes_stack(self) -> List[Scope]:
        self._init_thread_local()
        return self._thread_local.scopes

    def _init_thread_local(self):
        """Populate the calling thread's state on first use; later calls are no-ops."""
        # todo: master node part!
        tl = self._thread_local
        if getattr(tl, 'ready', False):
            return
        tl.ready = True
        tl.scopes = []
        tl.scope_modules = []
        tl.in_operator = False
        tl.num_nested_hooks = 0
        tl.base_module_replica = None
        tl.operator_counters = {}
        tl.node_call_tracker = {}

    def register_node_call(self, node_key: str):
        """Increment this thread's call counter for the graph node ``node_key``."""
        self._init_thread_local()
        tracker = self._thread_local.node_call_tracker
        tracker[node_key] = tracker.get(node_key, 0) + 1

    def reset_node_call_counters(self):
        """Zero all of this thread's node call counters, keeping the keys."""
        self._init_thread_local()
        tracker = self._thread_local.node_call_tracker
        for node_key in tracker:
            tracker[node_key] = 0

    def get_node_call_counter_dict(self):
        """Return this thread's live node-key -> call-count dict."""
        self._init_thread_local()
        return self._thread_local.node_call_tracker

    def _get_scope_relative_to_last_registered_module_call(self, module) -> Scope:
        """
        Breadth-first search for ``module`` among the children of the last pushed scope module.
        Returns the path of (class name, attribute name) elements to it, or a single element
        with the module's class name when there is no pushed module or it is not found.
        """
        module_class = module.__class__.__name__
        if not self.scope_modules:
            return Scope([ScopeElement(module_class), ])
        q = deque([(tuple(), self.scope_modules[-1])])
        while q:
            scope_parts, top = q.popleft()
            if module is top:
                return Scope(list(scope_parts))
            for name, child in top.named_children():
                scope_element = ScopeElement(child.__class__.__name__, name)
                q.append((scope_parts + (scope_element,), child))
        return Scope([ScopeElement(module_class), ])

    @property
    def scope(self) -> Scope:
        """Full current scope: the concatenation of all relative scopes on this thread's stack."""
        stack_copy = self.relative_scopes_stack.copy()
        scope_el_list = []
        for relative_scope in stack_copy:
            for scope_element in relative_scope.scope_elements:
                scope_el_list.append(scope_element)
        return Scope(scope_el_list)

    def reset_graph(self):
        self.graph = NNCFGraph()
 def reset_graph(self):
     """Discard the current graph and start over with an empty NNCFGraph."""
     empty_graph = NNCFGraph()
     self.graph = empty_graph
Ejemplo n.º 18
0
def test_graph_pattern_io_building():
    """
    Check ``NNCFGraph._get_nncf_graph_pattern_io_list`` on a hand-built 8-node graph.

    The graph (drawn below) is built directly on the underlying nx graph. For each node
    pattern, the input/output edges and nodes reported by the graph are compared with the
    expected ones as multisets (``Counter``), so ordering does not matter.
    """
    graph = NNCFGraph()
    #   1
    # /   \
    # 2   |
    # |   |
    # 3   |
    # \   /
    #   4
    # / | \
    # 5 6 7
    # |
    # 8

    #pylint:disable=protected-access
    node_keys = ['1', '2', '3', '4', '5', '6', '7', '8']
    for idx, node_key in enumerate(node_keys):
        # Node ids are 1-based and match the numbers in the diagram above.
        attrs = {
            NNCFGraph.ID_NODE_ATTR: idx + 1,
            NNCFGraph.KEY_NODE_ATTR: node_key,
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: None,
        }
        graph._nx_graph.add_node(node_key, **attrs)

    edge_attr = {NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR: None}
    graph._nx_graph.add_edges_from([('1', '2'), ('1', '4'), ('2', '3'),
                                    ('3', '4'), ('4', '5'), ('4', '6'),
                                    ('4', '7'), ('5', '8')], **edge_attr)
    # Keep the id -> key index consistent with the nodes added directly above.
    graph._node_id_to_key_dict.update(
        {k + 1: v
         for k, v in enumerate(node_keys)})

    def make_mock_edge(from_id: int, to_id: int):
        return NNCFGraphEdge(NNCFNode(from_id, None), NNCFNode(to_id, None),
                             None)

    def make_mock_node(id_: int):
        return NNCFNode(id_, None)

    # (pattern node keys, expected pattern IO) pairs.
    ref_patterns_and_ios = [
        (['1', '2'],
         NNCFGraphPatternIO(
             input_edges=[],
             input_nodes=[make_mock_node(1)],
             output_edges=[make_mock_edge(2, 3),
                           make_mock_edge(1, 4)],
             output_nodes=[])),
        (['3'],
         NNCFGraphPatternIO(input_edges=[make_mock_edge(2, 3)],
                            input_nodes=[],
                            output_edges=[make_mock_edge(3, 4)],
                            output_nodes=[])),
        (['1', '2', '3'],
         NNCFGraphPatternIO(
             input_edges=[],
             input_nodes=[make_mock_node(1)],
             output_edges=[make_mock_edge(3, 4),
                           make_mock_edge(1, 4)],
             output_nodes=[])),
        (['4'],
         NNCFGraphPatternIO(
             input_edges=[make_mock_edge(3, 4),
                          make_mock_edge(1, 4)],
             input_nodes=[],
             output_edges=[
                 make_mock_edge(4, 5),
                 make_mock_edge(4, 6),
                 make_mock_edge(4, 7)
             ],
             output_nodes=[])),
        (['5', '6', '8'],
         NNCFGraphPatternIO(
             input_edges=[make_mock_edge(4, 5),
                          make_mock_edge(4, 6)],
             input_nodes=[],
             output_edges=[],
             output_nodes=[make_mock_node(6),
                           make_mock_node(8)])),
        (['7'],
         NNCFGraphPatternIO(input_edges=[make_mock_edge(4, 7)],
                            input_nodes=[],
                            output_edges=[],
                            output_nodes=[make_mock_node(7)]))
    ]

    for pattern, ref_pattern_io in ref_patterns_and_ios:
        test_pattern_io = graph._get_nncf_graph_pattern_io_list(pattern)
        assert Counter(test_pattern_io.input_edges) == Counter(
            ref_pattern_io.input_edges)
        assert Counter(test_pattern_io.output_edges) == Counter(
            ref_pattern_io.output_edges)
        assert Counter(test_pattern_io.input_nodes) == Counter(
            ref_pattern_io.input_nodes)
        assert Counter(test_pattern_io.output_nodes) == Counter(
            ref_pattern_io.output_nodes)