Esempio n. 1
0
    def test_number_of_nodes_for_module_in_loop__not_input_node(self):
        """Check the node count when a looped submodule is fed a derived tensor.

        Every iteration passes ``F.relu(x)`` (not the raw input) into ``Inner``.
        The expected count is ``Inner.nodes_number()`` (7) plus ``num_iter``.
        """
        num_iter = 5
        patch_torch_operators()

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                def forward(self, x):
                    sig = F.sigmoid(x)
                    tnh = F.tanh(x)
                    # Evaluation order is sigmoid, tanh, sigmoid, mul, tanh, mul, add
                    # (7 operations), matching nodes_number() below.
                    return F.sigmoid(x) * tnh + F.tanh(x) * sig

                @staticmethod
                def nodes_number():
                    return 7

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                for _ in range(num_iter):
                    activated = F.relu(x)
                    x = self.inner(activated)
                return x

            def nodes_number(self):
                return num_iter + self.inner.nodes_number()

        test_module = LoopModule()
        reset_context('test')
        with context('test') as tracing_ctx:
            test_module(torch.zeros(1))
            actual_nodes = tracing_ctx.graph.get_nodes_count()
            assert actual_nodes == test_module.nodes_number()
Esempio n. 2
0
    def test_number_of_nodes_for_module_with_nested_loops(self):
        """Check the graph node count for two nested loops with registered wrappers.

        ``TestIterModule.forward`` calls its ``TestIterModule_ResetPoint`` wrapper
        ``num_iter`` times. Each call runs ``LoopModule2``, whose ``forward`` calls
        its own ``LoopModule2_ResetPoint`` wrapper ``num_iter`` times. Both wrapper
        classes are decorated with ``ITERATION_MODULES.register()``. After a single
        forward pass inside the ``'test'`` context, the graph is expected to hold
        exactly ``num_iter`` nodes.
        """
        num_iter = 5
        patch_torch_operators()

        class TestIterModule(nn.Module):
            # NOTE(review): registering in ITERATION_MODULES presumably marks this
            # class as a loop-iteration boundary for graph tracing — confirm
            # against the ITERATION_MODULES registry implementation.
            @ITERATION_MODULES.register()
            class TestIterModule_ResetPoint(nn.Module):
                def __init__(self, loop_module):
                    super().__init__()
                    # Shared with TestIterModule.loop_module (same instance).
                    self.loop_module = loop_module

                def forward(self, x):
                    return self.loop_module(F.relu(x))

            def __init__(self):
                super().__init__()
                # LoopModule2 is defined further down in this class body; the
                # attribute lookup happens at instantiation time, so that is fine.
                self.loop_module = self.LoopModule2()
                self.reset_point = self.TestIterModule_ResetPoint(
                    self.loop_module)

            def forward(self, x):
                # Outer loop: num_iter calls through the registered wrapper.
                for _ in range(num_iter):
                    x = self.reset_point(x)
                return x

            class LoopModule2(nn.Module):
                @ITERATION_MODULES.register()
                class LoopModule2_ResetPoint(nn.Module):
                    def __init__(self, inner):
                        super().__init__()
                        self.inner = inner

                    def forward(self, x):
                        return self.inner(F.relu(x))

                def __init__(self):
                    super().__init__()
                    self.inner = self.Inner()
                    self.reset_helper = self.LoopModule2_ResetPoint(self.inner)

                def forward(self, x):
                    # Inner loop: num_iter calls through the second registered wrapper.
                    for _ in range(num_iter):
                        # NOTE(review): the result of reset_helper is discarded and
                        # the input x is returned unchanged — confirm intentional.
                        self.reset_helper(x)
                    return x

                class Inner(nn.Module):
                    def forward(self, x):
                        s = F.sigmoid(x)
                        t = F.tanh(x)
                        result = t + s
                        return result

        test_module = TestIterModule()
        reset_context('test')
        with context('test') as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == num_iter
    def test_build_graph(self, model_name, model_builder, input_size):
        """Trace a freshly built model twice and compare its graph to the reference.

        Both forward passes run inside the ``'test'`` context. The scope-operator
        call counters are reset between them. The graph of the context returned by
        ``reset_context`` is then checked against the ``'original'`` reference for
        ``model_name``.
        """
        net = model_builder()
        ctx = reset_context('test')
        with context('test') as tracing_ctx:
            for pass_idx in range(2):
                if pass_idx:
                    tracing_ctx.reset_scope_operator_call_counters()
                _ = net(torch.zeros(input_size))

        check_graph(to_networkx(ctx), model_name, 'original')
Esempio n. 4
0
    def test_build_graph(self, model_name, model_builder, forward_fn_):
        """Trace a model on ``self.device`` twice via ``forward_fn_`` and check its graph.

        The scope-operator call counters are reset between the two passes. The
        resulting ``ctx.graph`` is compared with the ``'original'`` reference for
        ``model_name``.
        """
        net = model_builder()
        net.to(self.device)
        ctx = reset_context('test')
        with context('test') as tracing_ctx:
            for pass_idx in range(2):
                if pass_idx:
                    tracing_ctx.reset_scope_operator_call_counters()
                forward_fn_(net)

        check_graph(ctx.graph, model_name, 'original')
    def build_graph(self, model: torch.nn.Module, context_name: str) -> NNCFGraph:
        """Trace ``model`` inside the named context and return the resulting graph.

        A deep copy of ``model.state_dict()`` is taken before tracing. It is loaded
        back afterwards, also when ``self.custom_forward_fn`` raises, so the
        tracing forward pass never leaves modified parameters/buffers behind.

        :param model: module to trace. If it is a ``PostGraphBuildActing``
            instance, its ``post_build_graph_actions()`` runs after a successful
            trace.
        :param context_name: name passed to ``get_context`` and ``context``.
        :return: the ``graph`` attribute of the named context.
        """
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Building graph with context: %s", context_name)
        saved_state = deepcopy(model.state_dict())

        ctx = get_context(context_name)
        try:
            with context(context_name):
                self.custom_forward_fn(model)
        finally:
            # Restore even on failure so a partially-run trace cannot leak
            # mutated state into the caller's model.
            model.load_state_dict(saved_state)

        if isinstance(model, PostGraphBuildActing):
            model.post_build_graph_actions()
        return ctx.graph
Esempio n. 6
0
def test_iterate_module_list():
    """Trace a model that indexes into an ``nn.ModuleList`` and compare its graph.

    Both 1x1 convolutions of the list are applied to the same input. The graph
    recorded in the ``'test'`` context must match ``case_iterate_module_list.dot``.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.ml = nn.ModuleList([nn.Conv2d(1, 1, 1) for _ in range(2)])

        def forward(self, x):
            first, second = self.ml[0], self.ml[1]
            return [first(x), second(x)]

    net = Net()

    ctx = reset_context('test')
    dummy_input = torch.zeros(1, 1, 1, 1)
    with context('test'):
        _ = net(dummy_input)

    check_graph(ctx.graph, 'case_iterate_module_list.dot', 'original')
 def test_sparse_network(self, model_name, model_builder, input_size, algo, params):
     """Sparsify a model and compare its traced graph with the ``algo`` reference.

     The compression algorithm must report as many sparsified modules as
     ``get_all_modules_by_type`` finds for the value types of ``NNCF_MODULES_MAP``
     in the freshly built model. The compressed model then runs two forward passes
     in the ``'test'`` context, and the graph is checked against the reference for
     ``model_name``.
     """
     model = model_builder()
     from nncf.layers import NNCF_MODULES_MAP
     expected_sparsified = len(
         get_all_modules_by_type(model, list(NNCF_MODULES_MAP.values())))
     ctx = reset_context('test')

     config = get_empty_config(input_sample_size=input_size)
     config["compression"] = {"algorithm": algo, "params": params}
     compression_algo = create_compression_algorithm(model, config)
     assert len(compression_algo.sparsified_module_info) == expected_sparsified

     compressed_model = compression_algo.model
     with context('test') as tracing_ctx:
         for pass_idx in range(2):
             if pass_idx:
                 tracing_ctx.reset_scope_operator_call_counters()
             _ = compressed_model(torch.zeros(input_size))
     check_graph(to_networkx(ctx), model_name, algo)
Esempio n. 8
0
 def test_sparse_network(self, model_name, model_builder, forward_fn_, algo, params):
     """Sparsify a model, trace it on ``self.device`` via ``forward_fn_``, check the graph.

     The sparsified-module count reported by the algorithm must equal what
     ``get_all_modules_by_type`` finds for the value types of ``NNCF_MODULES_MAP``.
     ``forward_fn_`` is also handed to the algorithm as ``dummy_forward_fn``.
     """
     model = model_builder()
     from nncf.layers import NNCF_MODULES_MAP
     expected_sparsified = len(
         get_all_modules_by_type(model, list(NNCF_MODULES_MAP.values())))
     ctx = reset_context('test')

     config = get_empty_config()
     config["compression"] = {"algorithm": algo, "params": params}
     compression_algo = create_compression_algorithm(
         model, config, dummy_forward_fn=forward_fn_)
     assert len(compression_algo.sparsified_module_info) == expected_sparsified

     compressed_model = compression_algo.model
     compressed_model.to(self.device)
     with context('test') as tracing_ctx:
         for pass_idx in range(2):
             if pass_idx:
                 tracing_ctx.reset_scope_operator_call_counters()
             forward_fn_(compressed_model)
     check_graph(ctx.graph, model_name, algo)
Esempio n. 9
0
    def test_number_of_nodes_for_module_in_loop(self):
        """Check that a submodule called ``num_iter`` times contributes its nodes once.

        ``Inner`` keeps ``torch.sigmoid`` and ``torch.tanh`` as plain attributes
        (captured after ``patch_torch_operators()`` has run). It combines their
        results with ``+``, which is 3 operations. After ``num_iter`` loop
        iterations the graph is expected to contain exactly those 3 nodes.
        """
        num_iter = 5
        patch_torch_operators()

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.operator1 = torch.sigmoid
                    self.operator2 = torch.tanh

                def forward(self, x):
                    sig = self.operator1(x)
                    tnh = self.operator2(x)
                    return tnh + sig

                @staticmethod
                def nodes_number():
                    return 3

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def forward(self, x):
                out = x
                for _ in range(num_iter):
                    out = self.inner(out)
                return out

            def nodes_number(self):
                return self.inner.nodes_number()

        test_module = LoopModule()
        reset_context('test')
        with context('test') as tracing_ctx:
            test_module(torch.zeros(1))
            actual_nodes = tracing_ctx.graph.get_nodes_count()
            assert actual_nodes == test_module.nodes_number()
Esempio n. 10
0
    def test_number_of_nodes_for_repeated_module(self):
        """Check that re-tracing the same modules in one context adds new nodes.

        ``LoopModule`` applies two 1x1 convolutions, each followed by ``F.relu``,
        and the first forward pass is expected to yield 4 nodes. The second pass
        runs in the same ``'test'`` context with no call-counter reset in between,
        and the count is expected to reach 8.
        """
        patch_torch_operators()

        class LoopModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = nn.ModuleList(
                    [nn.Conv2d(1, 1, 1),
                     nn.Conv2d(1, 1, 1)])

            def forward(self, x):
                for layer in self.layers:
                    x = F.relu(layer(x))
                return x

        test_module = LoopModule()
        reset_context('test')
        with context('test') as ctx:
            x = test_module(torch.zeros(1, 1, 1, 1))
            assert ctx.graph.get_nodes_count() == 4
            _ = test_module(x)
            assert ctx.graph.get_nodes_count() == 8
Esempio n. 11
0
    def test_number_of_calling_fq_for_gnmt(self, tmp_path):
        """Count fake-quantizer calls in a quantized GNMT over two forward passes.

        A small GNMT is built and its LSTMs are swapped via ``replace_lstm``. The
        model is compressed with the quantization algorithm, and a forward pre-hook
        is registered on every entry of ``model.all_quantizations``. Expectations:

        * the traced graph has 239 nodes and there are 68 quantizers, both after
          the first and after the second pass;
        * quantizers whose name contains ``'cell'`` or ``"LSTMCellForwardNNCF"``
          are called ``sequence_size`` times on the first pass (presumably once
          per padded time step) and ``new_seq_len`` more times on the second;
        * every other quantizer is called exactly once per pass.

        Requires CUDA device 0.
        """
        torch.cuda.set_device(0)
        device = torch.device('cuda')
        batch_first = False
        vocab_size = 32000
        model_config = {
            'hidden_size': 100,
            'vocab_size': vocab_size,
            'num_layers': 4,
            'dropout': 0.2,
            'batch_first': batch_first,
            'share_embedding': True,
        }
        batch_size = 128
        sequence_size = 50
        input_sample_size = (batch_size,
                             sequence_size) if batch_first else (sequence_size,
                                                                 batch_size)
        patch_torch_operators()
        config = get_empty_config(input_sample_size=input_sample_size)
        config['compression'] = \
            {'algorithm': 'quantization',
             'quantize_inputs': True,
             'quantizable_subgraph_patterns': [["linear", "__add__"],
                                               ["sigmoid", "__mul__", "__add__"],
                                               ["__add__", "tanh", "__mul__"],
                                               ["sigmoid", "__mul__"]],
             'scopes_without_shape_matching':
                 ['GNMT/ResidualRecurrentDecoder[decoder]/RecurrentAttention[att_rnn]/BahdanauAttention[attn]'],
             'disable_function_quantization_hooks': True}

        config.log_dir = str(tmp_path)
        reset_context('orig')
        reset_context('quantized_graphs')

        model = GNMT(**model_config)
        model = replace_lstm(model)
        model.to(device)

        def dummy_forward_fn(model, seq_len=sequence_size):
            def gen_packed_sequence():
                seq_list = []
                seq_lens = torch.LongTensor(batch_size).random_(1, seq_len + 1)
                seq_lens = torch.sort(seq_lens, descending=True).values
                # pad_sequence pads to the longest sequence, and the per-cell counts
                # asserted below expect that length to be exactly ``seq_len``. With
                # purely random lengths, every one falls short with probability
                # ((seq_len - 1) / seq_len) ** batch_size (~7.5% for 50 and 128),
                # which made the test flaky. Pinning the first (longest, since the
                # lengths are sorted descending) entry keeps the order intact.
                seq_lens[0] = seq_len
                for seq_size in seq_lens:
                    seq_list.append(
                        torch.LongTensor(seq_size.item()).random_(
                            1, vocab_size).to(device))
                padded_seq_batch = torch.nn.utils.rnn.pad_sequence(
                    seq_list, batch_first=batch_first)
                return padded_seq_batch, seq_lens

            x_data, seq_lens = gen_packed_sequence()
            input_encoder = x_data
            input_enc_len = seq_lens.to(device)
            input_decoder = gen_packed_sequence()[0]
            model.forward(input_encoder, input_enc_len, input_decoder)

        _, model = create_compressed_model(model, config, dummy_forward_fn)
        model.to(device)

        class Counter:
            def __init__(self):
                self.count = 0

            def next(self):
                self.count += 1

        def hook(module, input_, counter):
            # Forward pre-hook signature is (module, input); counter is bound via partial.
            counter.next()

        counters = {}
        for name, quantizer in model.all_quantizations.items():
            counter = Counter()
            counters[name] = counter
            quantizer.register_forward_pre_hook(partial(hook, counter=counter))
        with context('quantized_graphs') as ctx:
            dummy_forward_fn(model)
            assert ctx.graph.get_nodes_count() == 239
            assert len(counters) == 68
            for name, counter in counters.items():
                if 'cell' in name or "LSTMCellForwardNNCF" in name:
                    assert counter.count == sequence_size, name
                else:
                    assert counter.count == 1, name
            new_seq_len = sequence_size // 2
            dummy_forward_fn(model, new_seq_len)
            assert ctx.graph.get_nodes_count() == 239
            assert len(counters) == 68
            for name, counter in counters.items():
                if 'cell' in name or "LSTMCellForwardNNCF" in name:
                    assert counter.count == sequence_size + new_seq_len, name
                else:
                    assert counter.count == 2, name
Esempio n. 12
0
    def test_number_of_calling_fq_for_lstm(self, tmp_path):
        """Count fake-quantizer calls in a quantized two-layer bidirectional NNCF LSTM.

        After one forward pass the graph must have 110 nodes and there must be
        54 quantizers, each called exactly ``seq_length`` times. The compressed
        graph is dumped to ``compressed_graph_next.dot`` under ``tmp_path``.
        """
        sizes = LSTMTestSizes(1, 1, 1, 5)
        num_layers = 2
        bidirectional = True
        num_directions = 2 if bidirectional else 1
        bias = True
        batch_first = False
        patch_torch_operators()

        config = get_empty_config(
            input_sample_size=(sizes.seq_length, sizes.batch, sizes.input_size))
        config['compression'] = {'algorithm': 'quantization', 'quantize_inputs': True}
        config.log_dir = str(tmp_path)

        for ctx_name in ('orig', 'quantized_graphs'):
            reset_context(ctx_name)
        test_data = TestLSTMCell.generate_lstm_data(sizes, num_layers, num_directions,
                                                    bias=bias, batch_first=batch_first)

        test_rnn = NNCF_RNN('LSTM',
                            input_size=sizes.input_size,
                            hidden_size=sizes.hidden_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            bias=bias,
                            batch_first=batch_first)
        TestLSTM.set_ref_lstm_weights(test_data, test_rnn, num_layers, num_directions, bias)
        test_hidden = TestLSTM.get_test_lstm_hidden(test_data)

        # Start compression from freshly reset contexts.
        for ctx_name in ('orig', 'quantized_graphs'):
            reset_context(ctx_name)
        _, model = create_compressed_model(test_rnn, config)

        # Call count per quantizer, keyed by its name in model.all_quantizations.
        call_counts = {}

        def count_call(module, input_, name):
            call_counts[name] += 1

        for name, quantizer in model.all_quantizations.items():
            call_counts[name] = 0
            quantizer.register_forward_pre_hook(partial(count_call, name=name))

        with context('quantized_graphs') as ctx:
            _ = model(test_data.x, test_hidden)
            assert ctx.graph.get_nodes_count() == 110
            dump_path = os.path.join(config.log_dir, "compressed_graph_next.dot")
            ctx.graph.dump_graph(dump_path)
        assert len(call_counts) == 54
        assert all(count == sizes.seq_length for count in call_counts.values())