def __init__(self):
    """Create an empty tracing state: a fresh graph, no hooks and no per-thread data yet."""
    self.graph = NNCFGraph()

    # Previously active context (set/restored by enter/leave-style methods, presumably).
    self._save_context = None

    # Hook registries, keyed by operation context.
    self._pre_hooks = {}  # type: Dict[PreHookId, List[Callable]]
    self._post_hooks = {}
    self._num_nested_hooks = 0

    # Per-thread tracing data is stored on this object.
    self._thread_local = threading.local()

    # Counter plus condition variable used to coordinate concurrent graph lookups.
    self._cond = threading.Condition()
    self._n_instance = 0

    self.is_tracing = True
    self._input_comparators_per_scope = []
def __init__(self, name):
    """Create a named, empty tracing state: a fresh graph, no hooks and no per-thread data yet."""
    self.name = name
    self.graph = NNCFGraph()

    # Previously active context (presumably restored when this one is left).
    self._save_context = None

    # Hook registries, keyed by operation context.
    self._pre_hooks = {}
    self._post_hooks = {}
    self._num_nested_hooks = 0

    # Per-thread tracing data is stored on this object.
    self._thread_local = threading.local()

    # Counter plus condition variable used to coordinate concurrent graph lookups.
    self._cond = threading.Condition()
    self._n_instance = 0

    self.is_tracing = True
    self._input_comparators_per_scope = []
def find_first_ops_with_type(nncf_graph: NNCFGraph, nodes, required_types, forward: bool = True):
    """
    Looks for the first nodes whose type is in `required_types` that are reachable from `nodes`.

    The search is depth-first (LIFO stack). A branch stops being expanded as soon as a node of a
    required type is met; such nodes are collected. Every node is processed at most once.

    :param nncf_graph: NNCFGraph to work with
    :param nodes: node attribute mappings (each indexed by 'key') from which the search begins
    :param required_types: types of nodes to search for, as returned by `nncf_graph.node_type_fn`
    :param forward: if True, follow outgoing edges; otherwise follow incoming edges
    :return: list of the found nodes, in the order they were reached
    """
    nx_graph = nncf_graph._nx_graph
    if forward:
        next_edges = nx_graph.out_edges
    else:
        next_edges = nx_graph.in_edges

    visited = dict.fromkeys(nx_graph.nodes, False)
    to_visit = deque(nodes)
    result = []

    while to_visit:
        current = to_visit.pop()
        current_type = nncf_graph.node_type_fn(current)
        current_key = current['key']
        if visited[current_key]:
            continue
        visited[current_key] = True

        if current_type in required_types:
            # Stop expanding this branch: the first node of a required type was found.
            result.append(current)
            continue

        for src_name, dst_name in next_edges(current_key):
            neighbour = nx_graph.nodes[dst_name if forward else src_name]
            if not visited[neighbour['key']]:
                to_visit.append(neighbour)

    return result
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes output channels of a ConvTranspose module by the node's output mask.

    Output channels live in dimension 1 of the weight here, so the 1-D keep-mask is repeated for
    every slice of dimension 0 and the selected values are reshaped back. The bias, if present,
    is pruned by the same mask.

    :param model: model that owns the module
    :param nx_node: node attribute mapping; reads 'output_mask' and 'key'
    :param graph: NNCFGraph used to map `nx_node` to its module scope
    :param nx_graph: unused here
    """
    output_mask = nx_node['output_mask']
    if output_mask is None:
        return

    keep = torch.tensor(output_mask, dtype=torch.bool)
    kept_count = int(torch.sum(keep))

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    conv_t = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    old_out_count = int(conv_t.weight.size(1))
    dim0_size = conv_t.weight.size(0)
    # Same keep pattern for every row of dimension 0 -> mask of shape (dim0_size, mask_len).
    mask_2d = keep.repeat(dim0_size).view(dim0_size, keep.size(0))
    target_shape = list(conv_t.weight.shape)
    target_shape[1] = kept_count

    conv_t.out_channels = kept_count
    conv_t.weight = torch.nn.Parameter(conv_t.weight[mask_2d].view(target_shape))
    if conv_t.bias is not None:
        conv_t.bias = torch.nn.Parameter(conv_t.bias[keep])

    nncf_logger.info(
        'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], old_out_count, conv_t.out_channels))
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes the features of a BatchNorm module by the node's first input mask.

    Affine parameters (weight/bias) and running statistics are all sliced by the same mask.
    Components the module does not have (affine=False or track_running_stats=False leave them
    as None) are skipped. Running statistics are assigned as plain tensors so they remain
    registered buffers; assigning an nn.Parameter would re-register them as parameters.

    :param model: model that owns the module
    :param nx_node: node attribute mapping; reads 'input_masks' and 'key'
    :param graph: NNCFGraph used to map `nx_node` to its module scope
    :param nx_graph: unused here
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(
        nncf_node.op_exec_context.scope_in_model)

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    # num_features is available even when the module has no affine weight.
    old_num_channels = int(node_module.num_features)
    # Count on the bool tensor: torch.sum would reject a plain-list mask.
    new_num_channels = int(torch.sum(bool_mask))

    node_module.num_features = new_num_channels
    if node_module.weight is not None:
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
    # Plain-tensor assignment keeps these as buffers (state_dict keys are unchanged).
    if node_module.running_mean is not None:
        node_module.running_mean = node_module.running_mean[bool_mask]
    if node_module.running_var is not None:
        node_module.running_var = node_module.running_var[bool_mask]

    nncf_logger.info(
        'Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
        ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def check_graph(graph: NNCFGraph, path_to_dot, graph_dir, sort_dot_graph=True):
    """
    Compares the structure of `graph` with a reference .dot file.

    The reference lives in `data/reference_graphs/<graph_dir>/<path_to_dot>` next to this file.
    If it does not exist, it is generated from `graph` (and sorted if `sort_dot_graph`); such
    freshly generated references must be validated manually.

    Node sets are compared first, then node attributes (as strings, ignoring surrounding double
    quotes in the loaded file), then edges.
    """
    # pylint:disable=protected-access
    nx_graph = graph._get_graph_for_structure_analysis()

    data_dir = os.path.join(os.path.dirname(__file__), 'data/reference_graphs')
    dot_dir = os.path.join(data_dir, graph_dir)
    path_to_dot = os.path.abspath(os.path.join(dot_dir, path_to_dot))

    # validate .dot file manually!
    if not os.path.exists(path_to_dot):
        os.makedirs(dot_dir, exist_ok=True)
        nx.drawing.nx_pydot.write_dot(nx_graph, path_to_dot)
        if sort_dot_graph:
            sort_dot(path_to_dot)

    load_graph = nx.drawing.nx_pydot.read_dot(path_to_dot)
    load_graph = get_version_agnostic_graph(load_graph)

    # Compare node sets first so a missing/extra node fails as an assertion,
    # not as a KeyError in the per-node attribute comparison below.
    assert load_graph.nodes.keys() == nx_graph.nodes.keys()

    # nx_graph is expected to have version-agnostic operator names already
    for node_key, node_attrs in nx_graph.nodes.items():
        actual_attrs = {name: str(value) for name, value in node_attrs.items()}
        expected_attrs = {
            name: str(value).strip('"') for name, value in load_graph.nodes[node_key].items()
        }
        assert actual_attrs == expected_attrs

    assert nx.DiGraph(load_graph).edges == nx_graph.edges
def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Computes pruning masks for a Convolution node and stores them on `nx_node`.

    Writes 'input_masks', 'output_mask', 'accept_pruned_input' and 'is_depthwise'.
    The output mask comes from the `binary_filter_pruning_mask` of pre-op '0' when the module
    has pre-ops (presumably the filter-pruning op — TODO confirm). Grouped convolutions are not
    pruned by output filters. In the depthwise case (weight.size(1) == 1) the output mask
    follows the first input mask. Other grouped convolutions refuse pruned inputs.
    """
    input_masks = get_input_masks(nx_node, nx_graph)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    conv = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    output_mask = conv.pre_ops['0'].op.binary_filter_pruning_mask if conv.pre_ops else None
    accept_pruned_input = True
    is_depthwise = False

    if conv.groups != 1:
        if conv.weight.size(1) == 1:
            # Depthwise: output channels are tied one-to-one to input channels.
            is_depthwise = True
            output_mask = input_masks[0]
        else:
            accept_pruned_input = False
            output_mask = None

    nx_node['input_masks'] = input_masks
    nx_node['output_mask'] = output_mask
    nx_node['accept_pruned_input'] = accept_pruned_input
    nx_node['is_depthwise'] = is_depthwise
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes input channels of a Convolution module by the node's first input mask.

    Depthwise convolutions only get their `groups`/`in_channels` updated here. In that case the
    weight is pruned by the output mask, which mask propagation sets equal to the input mask.
    Regular convolutions have weight dimension 1 sliced by the mask (broadcast over dimension 0).

    :param model: model that owns the module
    :param nx_node: node attribute mapping; reads 'input_masks', 'is_depthwise' and 'key'
    :param graph: NNCFGraph used to map `nx_node` to its module scope
    :param nx_graph: unused here
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    # Count on the bool tensor: torch.sum would reject a plain-list mask.
    new_num_channels = int(torch.sum(bool_mask))

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(
        nncf_node.op_exec_context.scope_in_model)
    is_depthwise = nx_node['is_depthwise']
    # in_channels is correct for both branches; weight.size(1) is always 1 for depthwise convs.
    old_num_channels = int(node_module.in_channels)

    if is_depthwise:
        # In depthwise case prune output channels by input mask, here only fix for new number of input channels
        node_module.groups = new_num_channels
        node_module.in_channels = new_num_channels
    else:
        out_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(out_channels).view(
            out_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.in_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(
            node_module.weight[broadcasted_mask].view(new_weight_shape))

    nncf_logger.info(
        'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def get_sources_of_node(nncf_node: NNCFNode, graph: NNCFGraph, sources_types):
    """
    Collects the sources of `nncf_node`.

    A source is a node of one of `sources_types` such that there is a path from it to
    `nncf_node` and no node on that path has one of `sources_types`. If `nncf_node` is itself
    of a source type, the search starts from its predecessors. The per-node logic is delegated to
    `traverse_function` (defined elsewhere). A single `visited` map is shared across all start
    nodes.

    :param nncf_node: NNCFNode to get sources for
    :param graph: NNCF graph to work with
    :param sources_types: list of source operator types
    :return: list of all source nodes
    """
    visited = dict.fromkeys(graph.get_all_node_idxs(), False)

    def _is_source_type(op_type):
        return op_type in sources_types

    traverse_fn = partial(traverse_function, nncf_graph=graph,
                          type_check_fn=_is_source_type, visited=visited)

    if nncf_node.op_exec_context.operator_name in sources_types:
        start_nodes = graph.get_previous_nodes(nncf_node)
    else:
        start_nodes = [nncf_node]

    source_nodes = []
    for start_node in start_nodes:
        # Third argument False: presumably "do not traverse forward" (i.e. walk backwards).
        source_nodes += graph.traverse_graph(start_node, traverse_fn, False)
    return source_nodes
def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Computes pruning masks for a ConvTranspose node and stores them on `nx_node`.

    Writes 'input_masks', 'output_mask' and 'accept_pruned_input' (always True here).
    The output mask is the `binary_filter_pruning_mask` of pre-op '0' when the module has
    pre-ops, otherwise None.
    """
    input_masks = get_input_masks(nx_node, nx_graph)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    conv_t = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    output_mask = conv_t.pre_ops['0'].op.binary_filter_pruning_mask if conv_t.pre_ops else None

    nx_node['input_masks'] = input_masks
    nx_node['output_mask'] = output_mask
    nx_node['accept_pruned_input'] = True
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes input channels of a ConvTranspose module by the node's first input mask.

    Input channels live in dimension 0 of the weight here, so the weight is sliced directly by
    the 1-D mask. `in_channels` is updated to the number of kept channels.
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    keep = torch.tensor(input_mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    conv_t = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    previous_in_count = int(conv_t.weight.size(0))
    conv_t.in_channels = int(torch.sum(keep))
    conv_t.weight = torch.nn.Parameter(conv_t.weight[keep])

    nncf_logger.info(
        'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], previous_in_count, conv_t.in_channels))
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes output filters (weight dimension 0) of a Convolution module by the node's output mask.

    The bias holds one value per output filter, so it is pruned by the same mask whenever it
    exists. This includes depthwise convolutions, whose weight is pruned here as well; skipping
    their bias would leave it longer than `out_channels`.

    :param model: model that owns the module
    :param nx_node: node attribute mapping; reads 'output_mask' and 'key'
    :param graph: NNCFGraph used to map `nx_node` to its module scope
    :param nx_graph: unused here
    """
    mask = nx_node['output_mask']
    if mask is None:
        return

    bool_mask = torch.tensor(mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(
        nncf_node.op_exec_context.scope_in_model)
    old_num_channels = int(node_module.weight.size(0))

    # Count on the bool tensor: torch.sum would reject a plain-list mask.
    node_module.out_channels = int(torch.sum(bool_mask))
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info(
        'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], old_num_channels, node_module.out_channels))
def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """
    Prunes the weight of an NNCF-wrapped elementwise user module by the node's first input mask.

    Only modules that are instances of NNCF_WRAPPED_USER_MODULES_DICT keys are touched. The weight
    is sliced along dimension 0, which is the only supported
    `target_weight_dim_for_compression` (asserted). `n_channels` is updated to the kept count.
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(
        nncf_node.op_exec_context.scope_in_model)

    if not isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
        return

    assert node_module.target_weight_dim_for_compression == 0,\
        "Implemented only for target_weight_dim_for_compression == 0"
    old_num_channels = int(node_module.weight.size(0))
    # Count on the bool tensor: torch.sum would reject a plain-list mask.
    new_num_channels = int(torch.sum(bool_mask))
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    node_module.n_channels = new_num_channels

    nncf_logger.info(
        'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
        ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def prepare_potential_quantizer_graph(
        graph: NNCFGraph,
        potential_activations_quantizers: Dict[InsertionInfo, Optional[List[QuantizerConfig]]],
        potential_weights_modules: List[PotentialQuantizedModule]) -> NNCFGraph:
    """
    Adds purple "Quantizer: [...]" visualization nodes to `graph` for potential quantizers.

    For an activation quantizer, a quantizer node is spliced between the operation node and all
    of its successors. For a weight quantizer, a quantizer node with an edge into the operation
    node is added. `graph` is modified in place and returned.

    :param graph: graph to annotate (its underlying nx graph is mutated)
    :param potential_activations_quantizers: insertion point -> candidate configs (may be None)
    :param potential_weights_modules: (_, module_scope, qconfig_list) triples for weighted modules
    :return: the same `graph` object
    """
    def _qconfigs_to_str(qconfig_list) -> str:
        # Renders configs as "[cfg1] [cfg2] "; a None list (allowed by the type) renders as ''.
        return ''.join('[' + str(qconfig) + '] ' for qconfig in (qconfig_list or []))

    quantizers_weights_attr = {}
    quantizers_activations_attr = {}
    # pylint:disable=protected-access
    for _, module_scope, qconfig_list in potential_weights_modules:
        matching_graph_op_nodes = graph.get_op_nodes_in_scope(module_scope)
        # Isn't correct when NNCF module has more than 1 graph node
        assert len(matching_graph_op_nodes) == 1
        op_name = matching_graph_op_nodes[0][NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
        ia_op_exec_context = InputAgnosticOperationExecutionContext(op_name, module_scope, 0)
        quantizers_weights_attr[ia_op_exec_context] = _qconfigs_to_str(qconfig_list)

    for insertion_info, qconfig_list in potential_activations_quantizers.items():
        str_qconfig_list = _qconfigs_to_str(qconfig_list)
        quantizers_activations_attr[insertion_info.op_exec_context.input_agnostic] = str_qconfig_list
        for linked_op_exec_context in insertion_info.linked_op_exec_contexts:
            quantizers_activations_attr[linked_op_exec_context.input_agnostic] = str_qconfig_list

    nx_graph = graph._nx_graph
    # Snapshot of the original nodes: quantizer nodes are added to nx_graph inside the loop.
    nodes = deepcopy(nx_graph.nodes)
    for node_name, node in sorted(nodes.items()):
        op_exec_context = nx_graph.nodes[node_name][NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR]
        ia_op_exec_context_for_node = op_exec_context.input_agnostic
        node_scope = str(ia_op_exec_context_for_node)
        if ia_op_exec_context_for_node in quantizers_activations_attr:
            label = "Quantizer: {}".format(quantizers_activations_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope, label=label, color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=op_exec_context)
            # Re-route every outgoing edge through the quantizer node (snapshot before mutating).
            for next_node_name in list(nx_graph.successors(node_name)):
                nx_graph.add_edge(node_scope, next_node_name)
                nx_graph.remove_edge(node_name, next_node_name)
            nx_graph.add_edge(node_name, node_scope)
        elif ia_op_exec_context_for_node in quantizers_weights_attr:
            label = "Quantizer: {}".format(quantizers_weights_attr[ia_op_exec_context_for_node])
            nx_graph.add_node(node_scope, label=label, color="purple",
                              id=node[NNCFGraph.ID_NODE_ATTR],
                              op_exec_context=op_exec_context)
            nx_graph.add_edge(node_scope, node_name)

    return graph
def _paint_activation_quantizer_node(
        nncf_graph: NNCFGraph, quantizer_id: NonWeightQuantizerId,
        quantizer_info: 'NonWeightQuantizerInfo', bits_color_map: Dict[int, str],
        groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers):
    """
    Colours and relabels the graph nodes that represent an activation quantizer.

    For every insertion affected by `quantizer_info`, this does two things:
    - Relabels the affected operation node as "<operator_name>_#<id>".
    - Finds the quantizer node and paints it. The quantizer node is an UpdateInputs pre-op for
      module inputs, the single successor for post-hooks, or the predecessor on the matching
      input port for pre-hooks. Painting means filling it with the colour for its bit width and
      labelling it "AFQ_[<config>]_#<id>".

    When `groups_of_adjacent_quantizers` is non-empty, "_G<group id>" is appended to the label.
    "UNDEFINED" is used, and an error is logged, if the quantizer has no group.
    """
    #pylint:disable=too-many-branches
    affected_insertion_infos_list = quantizer_info.affected_insertions  # type: List[InsertionInfo]

    for insertion_info in affected_insertion_infos_list:
        input_agnostic_op_exec_context = insertion_info.op_exec_context.input_agnostic
        affected_nncf_node_key = nncf_graph.get_node_key_by_iap_context(
            input_agnostic_op_exec_context)
        affected_nx_node = nncf_graph.get_nx_node_by_key(affected_nncf_node_key)
        operator_name = affected_nx_node[NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].operator_name
        node_id = affected_nx_node[NNCFGraph.ID_NODE_ATTR]
        affected_nncf_node = nncf_graph.get_node_by_id(node_id)
        affected_nx_node['label'] = '_#'.join([operator_name, str(node_id)])

        if insertion_info.is_input:
            # Module UpdateInputs pre-op used for activation quantization
            previous_nodes = nncf_graph.get_previous_nodes(affected_nncf_node)
            # Relying on the _quantize_inputs behaviour of only being able to quantize 0-th input
            # previous_nodes are either UpdateWeights, or UpdateWeights + UpdateInputs
            assert len(previous_nodes) in (1, 2)
            if len(previous_nodes) == 2:
                if "UpdateInputs" in str(previous_nodes[0].op_exec_context.input_agnostic):
                    target_node = previous_nodes[0]
                else:
                    target_node = previous_nodes[1]
            else:
                target_node = previous_nodes[0]
            target_nncf_node_key = nncf_graph.get_node_key_by_id(target_node.node_id)
        else:
            in_port_id = insertion_info.in_port_id
            if in_port_id is None:
                # Post-hooking used for activation quantization
                # Currently only a single post-hook can immediately follow an operation
                succs = list(nncf_graph.get_successors(affected_nncf_node_key))
                assert len(succs) == 1
                target_nncf_node_key = succs[0]
            else:
                # Pre-hooking used for activation quantization
                previous_nodes = nncf_graph.get_previous_nodes(affected_nncf_node)
                target_node = None
                for prev_node in previous_nodes:
                    prev_edge = nncf_graph.get_nx_edge(prev_node, affected_nncf_node)
                    if prev_edge[NNCFGraph.IN_PORT_NAME_EDGE_ATTR] == in_port_id:
                        target_node = prev_node
                        break

                assert target_node is not None, "Could not find a pre-hook quantizer node for a specific " \
                                                "input port!"
                target_nncf_node_key = nncf_graph.get_node_key_by_id(target_node.node_id)

        activation_fq_node = nncf_graph.get_nx_node_by_key(target_nncf_node_key)
        bits = quantizer_info.quantizer_module_ref.num_bits
        activation_fq_node['color'] = bits_color_map[bits]
        activation_fq_node['style'] = 'filled'
        node_id = activation_fq_node[NNCFGraph.ID_NODE_ATTR]

        activation_fq_node['label'] = 'AFQ_[{}]_#{}'.format(
            quantizer_info.quantizer_module_ref.get_current_config(), str(node_id))

        grouped_mode = bool(groups_of_adjacent_quantizers)
        if grouped_mode:
            group_id_str = 'UNDEFINED'
            group_id = groups_of_adjacent_quantizers.get_group_id_for_quantizer(quantizer_id)
            # Check the looked-up group id (previously node_id was tested by mistake,
            # so a missing group rendered as "_GNone" and was never reported).
            if group_id is None:
                nncf_logger.error(
                    'No group for activation quantizer: {}'.format(target_nncf_node_key))
            else:
                group_id_str = str(group_id)
            activation_fq_node['label'] += "_G" + group_id_str
class TracingContext:
    """
    Tracing state used while building an NNCFGraph from model execution.

    Holds the traced graph, registered pre/post hooks, input comparators, and per-thread data
    (scope stacks, operator call counters, node call trackers). The per-thread data is created
    lazily by `_init_thread_local`.
    """

    def __init__(self, name):
        self.name = name
        self.graph = NNCFGraph()

        self._save_context = None
        self._post_hooks = {}
        self._pre_hooks = {}
        self._num_nested_hooks = 0

        self._thread_local = threading.local()

        self._n_instance = 0
        self._cond = threading.Condition()
        self.is_tracing = True
        self._input_comparators_per_scope = []

    def find_operator_node(self, inputs,
                           ia_op_exec_context: InputAgnosticOperationExecutionContext) -> NNCFNode:
        """
        Returns the graph node matching the operation context and inputs, adding one if absent.

        Lookups may run concurrently. Additions wait until no lookup is in flight
        (`_n_instance == 0`) and re-check under the condition's lock.
        """
        with self._cond:
            self._n_instance += 1
        try:
            tensor_metas = make_input_infos(inputs)
            node = self.graph.find_node(ia_op_exec_context, tensor_metas,
                                        self._input_comparators_per_scope)
        finally:
            # Always release the in-flight counter, otherwise waiting adders would block forever.
            with self._cond:
                self._n_instance -= 1
                self._cond.notify_all()

        if node is None:
            with self._cond:
                while self._n_instance > 0:
                    self._cond.wait()
                # Another thread may have added a node inside this block,
                # so we need to check again if a node is already added.
                node = self.graph.find_node(ia_op_exec_context, tensor_metas,
                                            self._input_comparators_per_scope)
                if node is None:
                    node = self.graph.add_node(ia_op_exec_context, tensor_metas,
                                               self._input_comparators_per_scope, inputs)
        return node

    def get_caller_context(self, operator_type: str) -> InputAgnosticOperationExecutionContext:
        """
        Designed to work in the following way - for each scope the context will track the number of the calls to the
        operators with the name operator_type (call_order). The counter values are preserved until reset by a
        corresponding member function of the context, which must be called after each model iteration - this is
        usually handled inside NNCF. This mechanism allows discerning between multiple function calls inside the same
        module that would each require their own instance of compression layers - for instance, multiple `relu`
        function calls (either on their own or inside a `for` cycle), and at the same moment allows the checkpoints to
        be loaded if the model had changed in the meantime in a way that does not impact the major function call
        order (e.g. if comments were added to the .py file with the model)
        """
        version_agnostic_operator_type = get_version_agnostic_name(operator_type)
        if version_agnostic_operator_type is not None:
            operator_type = version_agnostic_operator_type

        call_order = self.get_operator_call_count_in_scope(operator_type, self.scope)

        ia_op_exec_context = InputAgnosticOperationExecutionContext(operator_type,
                                                                    self.scope,
                                                                    call_order)
        return ia_op_exec_context

    def reset_scope_operator_call_counters(self):
        """
        Must be called after each "forward" operation of the model that is made
        within this context
        """
        self._init_thread_local()
        self._thread_local.operator_counters = {}

    @staticmethod
    def _get_operator_counter_key(operator_name: str, scope: Scope):
        return "{}_{}".format(str(scope), operator_name)

    def register_operator_call(self, operator_name: str, scope: Scope):
        """Increments the current thread's call counter for `operator_name` in `scope`."""
        self._init_thread_local()
        counters = self._thread_local.operator_counters
        key = self._get_operator_counter_key(operator_name, scope)
        counters[key] = counters.get(key, 0) + 1

    def get_operator_call_count_in_scope(self, operator_name: str, scope: Scope):
        """Returns the current thread's call count for `operator_name` in `scope` (0 if never called)."""
        self._init_thread_local()
        key = self._get_operator_counter_key(operator_name, scope)
        return self._thread_local.operator_counters.get(key, 0)

    def reset_operator_call_count_in_scope(self, scope):
        """Zeroes every counter whose key contains the string form of `scope`."""
        self._init_thread_local()
        scoped_op_name = str(scope)
        counters = self._thread_local.operator_counters
        for key in counters:
            if scoped_op_name in key:
                counters[key] = 0

    def enter(self):
        """Makes this context the module-level current one, remembering the previous one."""
        global _CURRENT_CONTEXT
        self._save_context = _CURRENT_CONTEXT
        _CURRENT_CONTEXT = self
        self._init_thread_local()

    def leave(self):
        """Restores the module-level context that was current before `enter`."""
        global _CURRENT_CONTEXT
        _CURRENT_CONTEXT = self._save_context
        self._save_context = None

    def push_scope(self, scope_module):
        relative_scopes_list = self._get_scope_relative_to_last_registered_module_call(scope_module)
        self.scope_modules.append(scope_module)
        self.relative_scopes_stack.append(relative_scopes_list)

    def pop_scope(self):
        self.relative_scopes_stack.pop()
        self.scope_modules.pop()

    def register_pre_hooks(self, fn_list: List[Callable],
                           ia_op_exec_context: InputAgnosticOperationExecutionContext):
        if ia_op_exec_context in self._pre_hooks:
            raise KeyError("Pre hook for context {} is already registered".format(str(ia_op_exec_context)))
        self._pre_hooks[ia_op_exec_context] = fn_list

    def execute_pre_hooks(self, ia_op_exec_context: InputAgnosticOperationExecutionContext,
                          op_inputs: OperatorInput) -> OperatorInput:
        """Runs the pre-hooks registered for the context, chaining the operator inputs through them."""
        in_op = getattr(self, 'in_operator', False)
        self.in_operator = False
        self._thread_local.num_nested_hooks += 1
        try:
            if ia_op_exec_context in self._pre_hooks:
                for hook in self._pre_hooks[ia_op_exec_context]:
                    op_inputs = hook(op_inputs)
        finally:
            # Restore per-thread state even if a hook raises.
            self._thread_local.num_nested_hooks -= 1
            self.in_operator = in_op
        return op_inputs

    def register_post_hooks(self, fn_list: List[Callable],
                            ia_op_exec_context: InputAgnosticOperationExecutionContext):
        if ia_op_exec_context in self._post_hooks:
            raise KeyError("Post hook for context {} is already registered".format(str(ia_op_exec_context)))
        self._post_hooks[ia_op_exec_context] = fn_list

    def execute_post_hooks(self, ia_op_exec_context: InputAgnosticOperationExecutionContext, outputs):
        """Runs the post-hooks registered for the context, chaining the outputs through them."""
        in_op = getattr(self, 'in_operator', False)
        self.in_operator = False
        self._thread_local.num_nested_hooks += 1
        try:
            if ia_op_exec_context in self._post_hooks:
                for hook in self._post_hooks[ia_op_exec_context]:
                    outputs = hook(outputs)
        finally:
            # Restore per-thread state even if a hook raises.
            self._thread_local.num_nested_hooks -= 1
            self.in_operator = in_op
        return outputs

    def disable_tracing(self):
        self.is_tracing = False

    def enable_tracing(self):
        self.is_tracing = True

    def add_node_comparators(self, scopes_to_apply: List[str],
                             node_input_comparator: 'TensorMetaComparator' = None):
        self._input_comparators_per_scope.append((node_input_comparator, scopes_to_apply))

    @property
    def base_module_thread_local_replica(self):
        self._init_thread_local()
        return self._thread_local.base_module_replica

    @base_module_thread_local_replica.setter
    def base_module_thread_local_replica(self, value):
        self._init_thread_local()
        self._thread_local.base_module_replica = value

    @property
    def in_operator(self):
        self._init_thread_local()
        return self._thread_local.in_operator

    @in_operator.setter
    def in_operator(self, val):
        self._init_thread_local()
        self._thread_local.in_operator = val

    @property
    def scope_modules(self):
        self._init_thread_local()
        return self._thread_local.scope_modules

    @property
    def relative_scopes_stack(self) -> List[Scope]:
        self._init_thread_local()
        return self._thread_local.scopes

    def _init_thread_local(self):
        # todo: master node part!
        tl = self._thread_local
        if getattr(tl, 'ready', False):
            return
        tl.ready = True
        tl.scopes = []
        tl.scope_modules = []
        tl.in_operator = False
        tl.num_nested_hooks = 0
        tl.base_module_replica = None
        tl.operator_counters = {}
        tl.node_call_tracker = {}

    def register_node_call(self, node_key: str):
        """Increments the current thread's call counter for the node with `node_key`."""
        self._init_thread_local()
        tracker = self._thread_local.node_call_tracker
        tracker[node_key] = tracker.get(node_key, 0) + 1

    def reset_node_call_counters(self):
        """Zeroes all node call counters of the current thread (keys are kept)."""
        self._init_thread_local()
        tracker = self._thread_local.node_call_tracker
        for node_key in tracker:
            tracker[node_key] = 0

    def get_node_call_counter_dict(self):
        self._init_thread_local()
        return self._thread_local.node_call_tracker

    def _get_scope_relative_to_last_registered_module_call(self, module) -> Scope:
        """
        Returns the scope of `module` relative to the last pushed module, found by BFS over
        `named_children`. Falls back to a single-element scope of the module's class name.
        """
        module_class = module.__class__.__name__
        if not self.scope_modules:
            return Scope([ScopeElement(module_class), ])
        q = deque([(tuple(), self.scope_modules[-1])])
        while q:
            scope_parts, top = q.popleft()
            if module is top:
                return Scope(list(scope_parts))
            for name, child in top.named_children():
                scope_element = ScopeElement(child.__class__.__name__, name)
                q.append((scope_parts + (scope_element,), child))
        return Scope([ScopeElement(module_class), ])

    @property
    def scope(self) -> Scope:
        """Concatenation of all relative scopes on the current thread's stack."""
        stack_copy = self.relative_scopes_stack.copy()
        scope_el_list = [scope_element
                         for relative_scope in stack_copy
                         for scope_element in relative_scope.scope_elements]
        return Scope(scope_el_list)

    def reset_graph(self):
        self.graph = NNCFGraph()
def reset_graph(self):
    """Drop the traced graph and replace it with a new, empty NNCFGraph."""
    empty_graph = NNCFGraph()
    self.graph = empty_graph
def test_graph_pattern_io_building():
    """Checks input/output edges and nodes computed for several sub-patterns of a small DAG."""
    graph = NNCFGraph()
    #   1
    #  / \
    # 2   |
    # |   |
    # 3   |
    #  \ /
    #   4
    # / | \
    # 5 6 7
    # |
    # 8

    #pylint:disable=protected-access
    node_keys = ['1', '2', '3', '4', '5', '6', '7', '8']
    for node_id, node_key in enumerate(node_keys, start=1):
        graph._nx_graph.add_node(node_key, **{
            NNCFGraph.ID_NODE_ATTR: node_id,
            NNCFGraph.KEY_NODE_ATTR: node_key,
            NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR: None,
        })

    dag_edges = [('1', '2'), ('1', '4'), ('2', '3'), ('3', '4'),
                 ('4', '5'), ('4', '6'), ('4', '7'), ('5', '8')]
    graph._nx_graph.add_edges_from(dag_edges, **{NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR: None})
    graph._node_id_to_key_dict.update(dict(enumerate(node_keys, start=1)))

    def mock_edge(from_id: int, to_id: int):
        return NNCFGraphEdge(NNCFNode(from_id, None), NNCFNode(to_id, None), None)

    def mock_node(id_: int):
        return NNCFNode(id_, None)

    ref_patterns_and_ios = [
        (['1', '2'], NNCFGraphPatternIO(input_edges=[],
                                         input_nodes=[mock_node(1)],
                                         output_edges=[mock_edge(2, 3), mock_edge(1, 4)],
                                         output_nodes=[])),
        (['3'], NNCFGraphPatternIO(input_edges=[mock_edge(2, 3)],
                                   input_nodes=[],
                                   output_edges=[mock_edge(3, 4)],
                                   output_nodes=[])),
        (['1', '2', '3'], NNCFGraphPatternIO(input_edges=[],
                                              input_nodes=[mock_node(1)],
                                              output_edges=[mock_edge(3, 4), mock_edge(1, 4)],
                                              output_nodes=[])),
        (['4'], NNCFGraphPatternIO(input_edges=[mock_edge(3, 4), mock_edge(1, 4)],
                                   input_nodes=[],
                                   output_edges=[mock_edge(4, 5), mock_edge(4, 6), mock_edge(4, 7)],
                                   output_nodes=[])),
        (['5', '6', '8'], NNCFGraphPatternIO(input_edges=[mock_edge(4, 5), mock_edge(4, 6)],
                                              input_nodes=[],
                                              output_edges=[],
                                              output_nodes=[mock_node(6), mock_node(8)])),
        (['7'], NNCFGraphPatternIO(input_edges=[mock_edge(4, 7)],
                                   input_nodes=[],
                                   output_edges=[],
                                   output_nodes=[mock_node(7)])),
    ]

    # Order inside each list is irrelevant, so compare as multisets.
    compared_fields = ('input_edges', 'output_edges', 'input_nodes', 'output_nodes')
    for pattern, ref_pattern_io in ref_patterns_and_ios:
        test_pattern_io = graph._get_nncf_graph_pattern_io_list(pattern)
        for field_name in compared_fields:
            assert Counter(getattr(test_pattern_io, field_name)) == \
                Counter(getattr(ref_pattern_io, field_name))