Exemplo n.º 1
0
def get_all_modules_by_type(model,
                            module_types=None,
                            current_scope=None,
                            ignored_scopes=None,
                            target_scopes=None) -> Dict['Scope', Module]:
    """
    Recursively collect the submodules of `model` whose class name is listed in `module_types`.

    :param model: The module whose children (and their descendants) are searched.
    :param module_types: A single class name, or any container of class names (list, tuple,
        set, ...), compared against ``type(module).__name__``. If None, every visited
        submodule is collected.
    :param current_scope: Scope of `model` itself. When None (the top-level call), a scope
        holding just `model`'s class name is created.
    :param ignored_scopes: Patterns passed to `matches_any`; a child whose scope string
        matches is skipped together with its whole subtree.
    :param target_scopes: Patterns passed to `matches_any`; when not None, only children whose
        scope string matches are collected or descended into.
    :return: An ordered mapping from each matching submodule's Scope to the submodule; a parent
        is inserted before the entries found in its subtree.
    """
    if isinstance(module_types, str):
        module_types = [module_types]
    found = OrderedDict()
    from nncf.torch.dynamic_graph.scope import Scope
    from nncf.torch.dynamic_graph.scope import ScopeElement
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))
    for name, module in model.named_children():
        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))

        if matches_any(str(child_scope), ignored_scopes):
            continue
        # A child outside target_scopes is neither collected nor descended into.
        if target_scopes is not None and not matches_any(str(child_scope), target_scopes):
            continue

        # `in` (rather than list.count) also works when module_types is a set/frozenset.
        if module_types is None or type(module).__name__ in module_types:
            found[child_scope] = module
        found.update(get_all_modules_by_type(module,
                                             module_types,
                                             current_scope=child_scope,
                                             ignored_scopes=ignored_scopes,
                                             target_scopes=target_scopes))
    return found
Exemplo n.º 2
0
def get_mock_nncf_node_attrs(op_name=None, scope_str=None):
    """
    Build the node attribute dict used by mock NNCF graphs.

    :param op_name: Operator name used both as the node type and inside the node name;
        defaults to MOCK_OPERATOR_NAME.
    :param scope_str: String form of the node's scope, parsed with Scope.from_str;
        an empty Scope() is used when None.
    :return: Dict with NNCFGraph.NODE_NAME_ATTR (string form of an OperationAddress with
        call order 0) and NNCFGraph.NODE_TYPE_ATTR (the operator name).
    """
    if op_name is None:
        op_name = MOCK_OPERATOR_NAME
    if scope_str is not None:
        scope = Scope.from_str(scope_str)
    else:
        scope = Scope()
    node_name = str(OperationAddress(op_name, scope, 0))
    return {NNCFGraph.NODE_NAME_ATTR: node_name, NNCFGraph.NODE_TYPE_ATTR: op_name}
Exemplo n.º 3
0
 def scope(self) -> Scope:
     """
     Concatenate the scope elements of every relative scope on the stack into one Scope.

     A copy of `self.relative_scopes_stack` is walked in its iteration order, and each
     relative scope contributes its `scope_elements` in order.
     """
     snapshot = self.relative_scopes_stack.copy()
     return Scope([element
                   for relative in snapshot
                   for element in relative.scope_elements])
Exemplo n.º 4
0
 def _get_scope_relative_to_last_registered_module_call(self,
                                                        module) -> Scope:
     """
     Return the scope of `module` relative to the most recently registered module call.

     Runs a breadth-first search over `named_children()` starting at
     `self.module_call_stack[-1]`. Each step appends ScopeElement(child class name,
     attribute name). The path to the first node that *is* `module` is returned; this is an
     empty Scope when `module` is the stack top itself.

     If the call stack is empty, or `module` is not reachable from its top, a one-element
     Scope holding only `module`'s class name is returned.
     """
     class_name = module.__class__.__name__
     if self.module_call_stack:
         pending = deque([((), self.module_call_stack[-1])])
         while pending:
             path, candidate = pending.popleft()
             if candidate is module:
                 return Scope(list(path))
             pending.extend(
                 (path + (ScopeElement(child.__class__.__name__, child_name), ), child)
                 for child_name, child in candidate.named_children())
     return Scope([ScopeElement(class_name)])
Exemplo n.º 5
0
def _mock_node_name(node_key, node_attrs) -> str:
    """Return NODE_NAME_ATTR from `node_attrs`, else an OperationAddress of `node_key` in an empty Scope."""
    if NNCFGraph.NODE_NAME_ATTR in node_attrs:
        return node_attrs[NNCFGraph.NODE_NAME_ATTR]
    return str(OperationAddress(node_key, Scope(), 0))


def _mock_node_metatype(node_attrs, node_type, layer_attributes):
    """Return METATYPE_ATTR from `node_attrs`, else look the metatype (and its subtype) up by `node_type`."""
    if NNCFGraph.METATYPE_ATTR in node_attrs:
        return node_attrs[NNCFGraph.METATYPE_ATTR]
    metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
    if metatype is not UnknownMetatype and metatype.subtypes:
        subtype = metatype.determine_subtype(layer_attributes=layer_attributes)
        if subtype is not None:
            metatype = subtype
    return metatype


def get_nncf_graph_from_mock_nx_graph(nx_graph: nx.DiGraph) -> PTNNCFGraph:
    """
    Build a PTNNCFGraph that mirrors the nodes and edges of a mock networkx graph.

    Nodes are added in `lexicographical_topological_sort` order, and the i-th node is created
    with `node_id_override=i`. Per-node attributes are used when present:
      * NODE_NAME_ATTR defaults to an OperationAddress of the node key in an empty Scope;
      * NODE_TYPE_ATTR defaults to the node key;
      * METATYPE_ATTR defaults to a lookup by node type, refined to a subtype when possible.

    Each edge is added with shape [1, 1, 1, 1]. Its input port is the source's position among
    the target's predecessors, and its output port is the edge's position among the source's
    out-edges. Its dtype comes from DTYPE_EDGE_ATTR, defaulting to Dtype.FLOAT.
    """
    from networkx.algorithms.dag import lexicographical_topological_sort
    mock_graph = PTNNCFGraph()
    # (source_key, target_key) -> (output port index on the source, id of the created source node)
    edge_vs_output_idx_and_creator_id = {}  # type: Dict[Tuple[str, str], Tuple[int, int]]
    for node_id, node_key in enumerate(lexicographical_topological_sort(nx_graph)):
        node_attrs = nx_graph.nodes[node_key]
        node_name = _mock_node_name(node_key, node_attrs)
        node_type = node_attrs.get(NNCFGraph.NODE_TYPE_ATTR, node_key)
        layer_attributes = node_attrs.get(NNCFGraph.LAYER_ATTRIBUTES)
        metatype = _mock_node_metatype(node_attrs, node_type, layer_attributes)

        created_node = mock_graph.add_nncf_node(
            node_name=node_name,
            node_type=node_type,
            node_metatype=metatype,
            layer_attributes=layer_attributes,
            node_id_override=node_id)

        # Topological order guarantees every predecessor's out-edges were registered already.
        for input_port_id, pred_key in enumerate(nx_graph.predecessors(node_key)):
            in_edge = (pred_key, node_key)
            output_port_id, creator_id = edge_vs_output_idx_and_creator_id[in_edge]
            dtype = nx_graph.edges[in_edge].get(NNCFGraph.DTYPE_EDGE_ATTR, Dtype.FLOAT)
            mock_graph.add_edge_between_nncf_nodes(creator_id, node_id,
                                                   [1, 1, 1, 1], input_port_id=input_port_id,
                                                   output_port_id=output_port_id,
                                                   dtype=dtype)

        for output_port_id, out_edge in enumerate(nx_graph.out_edges(node_key)):
            edge_vs_output_idx_and_creator_id[out_edge] = (output_port_id, created_node.node_id)
    return mock_graph
Exemplo n.º 6
0
    def test_inplace_apply_filter_binary_mask(mask, reference_weight, reference_bias):
        """
        Test that inplace_apply_filter_binary_mask changes the input weight and returns valid result.

        The mask is applied to the conv weight and then to its bias. For each parameter, both
        the returned tensor and the parameter itself must equal the reference.
        """
        conv = NNCFConv2d(1, 2, 2)
        fill_conv_weight(conv, 1)
        fill_bias(conv, 1)

        for param_name, expected in (("weight", reference_weight), ("bias", reference_bias)):
            param = getattr(conv, param_name)
            returned = inplace_apply_filter_binary_mask(mask, param.data, Scope())
            assert torch.allclose(returned, expected)
            # The parameter must have been modified in place, not just the returned tensor.
            assert torch.allclose(param, expected)
Exemplo n.º 7
0
def test_assert_broadcastable_mask_and_weight_shape():
    """
    Both the in-place and out-of-place filter mask helpers raise RuntimeError when given
    a mask (here 10 zeros) that does not fit the weight of an NNCFConv2d(1, 2, 2).
    """
    conv = NNCFConv2d(1, 2, 2)
    fill_conv_weight(conv, 1)
    fill_bias(conv, 1)
    wrong_size_mask = torch.zeros(10)

    with pytest.raises(RuntimeError):
        inplace_apply_filter_binary_mask(wrong_size_mask, conv.weight.data, Scope())
    with pytest.raises(RuntimeError):
        apply_filter_binary_mask(wrong_size_mask, conv.weight.data)
Exemplo n.º 8
0
 def _normalize_variable_recurrent_scope(scope: Scope):
     """
     Two scopes pointing to an NNCF module that only differ in a Recurrent/VariableRecurrent/VariableRecurrentReverse
     scope node actually point to one and the same module.

     Returns a copy of `scope` in which every element whose calling_module_class_name is one
     of those three names is renamed to "NormalizedName_Recurrent".
     """
     # NOTE(review): elements of the copy are mutated in place. This relies on Scope.copy()
     # duplicating the elements themselves, not just the list; confirm, or `scope` is modified too.
     normalized = scope.copy()
     recurrent_names = ("Recurrent", "VariableRecurrent", "VariableRecurrentReverse")
     for element in normalized:
         if element.calling_module_class_name in recurrent_names:
             element.calling_module_class_name = "NormalizedName_Recurrent"
     return normalized
Exemplo n.º 9
0
def replace_modules(model: nn.Module, replace_fn, affected_scopes, ignored_scopes=None, target_scopes=None, memo=None,
                    current_scope=None, eval_op_scopes: List[Scope] = None, reset: bool = False):
    """
    Recursively replace the children of `model` with the modules returned by `replace_fn`.

    `replace_fn(child)` is called for every child from `named_children()`, depth-first:
      * It returns None: the child is left as is.
      * It returns a different module: that module is installed under the child's name via
        `set_replaced_module_by_name`, and its scope is appended to `affected_scopes`. This is
        skipped (and the child's subtree is not visited) when the child's scope matches
        `ignored_scopes`, or when `eval_op_scopes` is non-empty and none of its scopes lies
        inside the child's scope. The replacement is also not installed when `target_scopes`
        is given and does not match (the subtree is still visited).
      * It returns the child itself and that child is an NNCF module (e.g. from a previous
        compression stage): its scope is appended to `affected_scopes`, and `reset()` is
        called on it when `reset` is True.
    The recursion descends into the original child module.

    :param model: Module whose children are processed; modified in place.
    :param replace_fn: Callable mapping a module to its replacement, itself, or None.
    :param affected_scopes: List that is appended to in place and returned.
    :param ignored_scopes: Patterns for `matches_any`; matching children are not replaced.
    :param target_scopes: Patterns for `matches_any`; when given, only matching children are replaced.
    :param memo: Set of already-visited modules; a module seen again is returned untouched.
    :param current_scope: Scope of `model`; defaults to a scope holding only `model`'s class name.
    :param eval_op_scopes: Scopes of operations/modules called in eval mode; None or empty
        disables the eval-mode filter.
    :param reset: Whether to call `reset()` on already-wrapped NNCF modules.
    :return: Tuple `(model, affected_scopes)`.
    """
    if memo is None:
        memo = set()
    # Initialised independently of `memo`: an explicitly passed current_scope is no longer
    # overwritten, and passing `memo` without `current_scope` no longer fails on None.copy().
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))
    if eval_op_scopes is None:
        eval_op_scopes = []

    if model in memo:
        return model, affected_scopes

    memo.add(model)
    for name, module in model.named_children():
        if module is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope = current_scope.copy()
            replaced_scope.push(ScopeElement(replaced_module.__class__.__name__, name))
            if module is not replaced_module:
                if matches_any(str(child_scope), ignored_scopes):
                    nncf_logger.info("Ignored wrapping modules specified in scope: {}".format(child_scope))
                    continue
                # child_scope isn't ignored if there's at least a single operation or a module
                # called in eval mode inside it
                if eval_op_scopes and not any(eval_op_scope in child_scope for eval_op_scope in eval_op_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}".format(child_scope))
                    continue

                if target_scopes is None or matches_any(str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(str(child_scope),
                                                                       str(replaced_scope)))
                    set_replaced_module_by_name(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)
                if reset:
                    replaced_module.reset()
        _, affected_scopes = replace_modules(module, replace_fn, affected_scopes, ignored_scopes, target_scopes,
                                             memo, child_scope, eval_op_scopes, reset=reset)
    return model, affected_scopes
Exemplo n.º 10
0
 def from_str(s: str):
     """
     Parse the string form "<scope>/<op_name>_<call_order>" back into an OperationAddress.

     The call order is the text after the last '_' (converted with int()); the op name is
     the text between the last '/' and that '_'; everything before that '/' is passed to
     Scope.from_str.
     """
     head, _, order_text = s.rpartition('_')
     scope_text, _, operator_name = head.rpartition('/')
     parsed_scope = Scope.from_str(scope_text)
     return OperationAddress(operator_name, parsed_scope, int(order_text))
Exemplo n.º 11
0
def module_scope_from_node_name(name):
    """
    Extract the module Scope from a node name of the form "<prefix> <scope>/<op>".

    The part after the last '/' is dropped, then everything up to and including the first
    space is dropped. The remainder is parsed with Scope.from_str. Raises IndexError if the
    remaining text contains no space.
    """
    without_op = name.rsplit('/', 1)[0]
    scope_text = without_op.split(' ', 1)[1]
    return Scope.from_str(scope_text)