Ejemplo n.º 1
0
def get_all_modules_by_type(model,
                            module_types,
                            current_scope=None,
                            ignored_scopes=None,
                            target_scopes=None) -> Dict['Scope', Module]:
    """Recursively collect descendants of `model` whose class name is in `module_types`.

    :param model: Module whose `named_children()` are traversed.
    :param module_types: A single class name or a list of class names to look for.
    :param current_scope: Scope of `model` itself; when omitted, a fresh scope
        rooted at the class name of `model` is created.
    :param ignored_scopes: Forwarded to `in_scope_list`; a child whose scope
        string matches is skipped together with everything below it.
    :param target_scopes: When not None, only children whose scope string
        matches (per `in_scope_list`) are inspected and descended into.
    :return: An ordered mapping from the scope of each matching module to the
        module, in traversal order (a parent precedes its own descendants).
    """
    # Imported at call time, exactly as the original did.
    from nncf.dynamic_graph.context import Scope
    from nncf.dynamic_graph.context import ScopeElement

    if isinstance(module_types, str):
        module_types = [module_types]

    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    matches = OrderedDict()
    for child_name, child in model.named_children():
        scope = current_scope.copy()
        scope.push(ScopeElement(child.__class__.__name__, child_name))

        if in_scope_list(str(scope), ignored_scopes):
            continue
        if target_scopes is not None and not in_scope_list(str(scope), target_scopes):
            # Outside the requested scopes: neither recorded nor descended into.
            continue

        type_name = str(type(child).__name__)
        if module_types.count(type_name) != 0:
            matches[scope] = child

        # Merging an empty result is a no-op, so no emptiness check is needed.
        matches.update(
            get_all_modules_by_type(child,
                                    module_types,
                                    current_scope=scope,
                                    ignored_scopes=ignored_scopes,
                                    target_scopes=target_scopes))
    return matches
Ejemplo n.º 2
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None):
    """Walk the children of `model` recursively and swap in what `replace_fn` returns.

    :param model: Module whose children are inspected; it is modified in place.
    :param replace_fn: Callable applied to every child. It may return None,
        the same module object, or a different module object.
    :param affected_scopes: List that receives the scope of every module that
        was replaced, and of every unchanged module for which `is_nncf_module`
        is true.
    :param ignored_scopes: Forwarded to `in_scope_list`; a child that would be
        replaced but matches is logged and skipped, and its children are not
        visited.
    :param target_scopes: When not None, a replacement is only installed if
        the child's scope string matches (per `in_scope_list`).
    :param memo: Set of already visited modules. Leave as None on the
        top-level call; in that case `current_scope` is (re)initialised too.
    :param current_scope: Scope of `model`; supplied by the recursive calls.
    :return: The tuple `(model, affected_scopes)`.
    """
    if memo is None:
        # Top-level call: start the visited set and root the scope at the model class name.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for child_name, child in model.named_children():
        if child is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(child.__class__.__name__, child_name))
        candidate = replace_fn(child)

        if candidate is not None:
            candidate_scope = current_scope.copy()
            candidate_scope.push(
                ScopeElement(candidate.__class__.__name__, child_name))

            if candidate is child:
                if is_nncf_module(candidate):
                    # Got an NNCF-wrapped module from a previous compression stage, track its scope as well
                    affected_scopes.append(candidate_scope)
            else:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules in scope: {}".format(
                            child_scope))
                    # Skips the recursive descent below as well.
                    continue

                in_target = target_scopes is None or in_scope_list(
                    str(child_scope), target_scopes)
                if in_target:
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(candidate_scope)))
                    if isinstance(model, nn.Sequential):
                        # pylint: disable=protected-access
                        model._modules[child_name] = candidate
                    else:
                        setattr(model, child_name, candidate)
                    affected_scopes.append(candidate_scope)

        # Descend into the original child, not into the replacement.
        _, affected_scopes = replace_modules(child,
                                             replace_fn,
                                             affected_scopes,
                                             ignored_scopes,
                                             target_scopes,
                                             memo,
                                             child_scope)
    return model, affected_scopes
Ejemplo n.º 3
0
def test_can_load_quant_algo__with_defaults():
    """Check that the basic quantization config wraps every Conv2d of the test model.

    Verifies that exactly one builder, a `QuantizationBuilder`, is created,
    that the compressed model has as many `NNCFConv2d` modules as the original
    has `Conv2d` modules, and that each wrapped module has an `UpdateWeight`
    pre-op with a `SymmetricQuantizer` operand, with no pre-op class occurring
    twice among the quantizing pre-ops.
    """
    model = BasicConvTestModel()
    config = get_basic_quantization_config()

    builders = create_compression_algorithm_builders(config)
    assert len(builders) == 1
    assert isinstance(builders[0], QuantizationBuilder)

    quant_model, _ = create_compressed_model_and_algo_for_test(deepcopy(model), config)

    original_convs = get_all_modules_by_type(model, 'Conv2d')
    wrapped_convs = get_all_modules_by_type(quant_model.get_nncf_wrapped_model(), 'NNCFConv2d')
    assert len(original_convs) == len(wrapped_convs)

    for original_scope in original_convs:
        # Build the scope expected in the compressed model by swapping the last element.
        expected_scope = deepcopy(original_scope)  # type: Scope
        expected_scope.pop()
        expected_scope.push(ScopeElement('NNCFConv2d', 'conv'))
        assert expected_scope in wrapped_convs

        seen_op_names = []
        for pre_op in wrapped_convs[expected_scope].pre_ops.values():
            is_update_op = isinstance(pre_op, (UpdateInputs, UpdateWeight))
            if is_update_op and isinstance(pre_op.operand, SymmetricQuantizer):
                op_name = pre_op.__class__.__name__
                # Each kind of quantizing pre-op may appear at most once per module.
                assert op_name not in seen_op_names
                seen_op_names.append(op_name)
        assert UpdateWeight.__name__ in seen_op_names
Ejemplo n.º 4
0
def get_mock_quantizer_id(key: str) -> WeightQuantizerId:
    """Build a `WeightQuantizerId` whose scope holds a single element made from `key`.

    :param key: Value passed as the only argument to `ScopeElement`.
    :return: A `WeightQuantizerId` wrapping the one-element scope.
    """
    return WeightQuantizerId(Scope([ScopeElement(key)]))
Ejemplo n.º 5
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None,
                    eval_ops_exec_ctx_str: List[str] = None,
                    reset: bool = False):
    """Walk the children of `model` recursively and swap in what `replace_fn` returns.

    :param model: Module whose children are inspected; it is modified in place.
    :param replace_fn: Callable applied to every child. It may return None,
        the same module object, or a different module object.
    :param affected_scopes: List that receives the scope of every module that
        was replaced, and of every unchanged module for which `is_nncf_module`
        is true.
    :param ignored_scopes: Forwarded to `in_scope_list`; a child that would be
        replaced but matches is logged and skipped, and its children are not
        visited.
    :param target_scopes: When not None, a replacement is only installed if
        the child's scope string matches (per `in_scope_list`).
    :param memo: Set of already visited modules. Leave as None on the
        top-level call; in that case `current_scope` is (re)initialised too.
    :param current_scope: Scope of `model`; supplied by the recursive calls.
    :param eval_ops_exec_ctx_str: Strings parsed with `Scope.from_str`. When
        the list is non-empty, a child that would be replaced is skipped
        (together with its children) unless at least one parsed scope
        satisfies `parsed_scope in child_scope`.
    :param reset: When true, `reset()` is called on every unchanged module for
        which `is_nncf_module` is true.
    :return: The tuple `(model, affected_scopes)`.
    """
    if memo is None:
        # Top-level call: start the visited set and root the scope at the model class name.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for child_name, child in model.named_children():
        if child is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(child.__class__.__name__, child_name))
        candidate = replace_fn(child)

        if candidate is not None:
            candidate_scope = current_scope.copy()
            candidate_scope.push(
                ScopeElement(candidate.__class__.__name__, child_name))

            if candidate is child:
                if is_nncf_module(candidate):
                    # Got an NNCF-wrapped module from a previous compression stage, track its scope as well
                    affected_scopes.append(candidate_scope)
                    if reset:
                        candidate.reset()
            else:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules specified in scope: {}".
                        format(child_scope))
                    # Skips the recursive descent below as well.
                    continue

                if eval_ops_exec_ctx_str is None:
                    eval_ops_exec_ctx_str = []
                # child_scope isn't ignored if there's at least a single operation or a module
                # called in eval mode inside it; an empty list disables this filter entirely.
                if eval_ops_exec_ctx_str and not any(
                        Scope.from_str(op_ctx_str) in child_scope
                        for op_ctx_str in eval_ops_exec_ctx_str):
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}"
                        .format(child_scope))
                    continue

                in_target = target_scopes is None or in_scope_list(
                    str(child_scope), target_scopes)
                if in_target:
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(candidate_scope)))
                    set_replaced_module_by_name(model, child_name, candidate)
                    affected_scopes.append(candidate_scope)

        # Descend into the original child, not into the replacement.
        _, affected_scopes = replace_modules(child,
                                             replace_fn,
                                             affected_scopes,
                                             ignored_scopes,
                                             target_scopes,
                                             memo,
                                             child_scope,
                                             eval_ops_exec_ctx_str,
                                             reset=reset)
    return model, affected_scopes