def replace_fn(module_):
    """Swap an ``nn.LSTM`` for an equivalent ``NNCF_RNN('LSTM', ...)``.

    Modules that are not ``nn.LSTM`` instances are returned unchanged.
    For an LSTM, a new ``NNCF_RNN`` is created with the same input size,
    hidden size, layer count, directionality, ``batch_first``, dropout and
    bias settings. Every per-layer / per-direction ``weight_*`` (and, when
    bias is enabled, ``bias_*``) tensor is then copied across by attribute
    name. Finally the replacement is moved to the device of the original's
    first parameter.

    NOTE(review): the replacement is not switched to ``module_``'s
    train/eval mode here — confirm that callers handle this.
    """
    if not isinstance(module_, nn.LSTM):
        return module_

    # Remember where the original lives so the replacement ends up there too.
    target_device = next(module_.parameters()).device

    replacement = NNCF_RNN(
        'LSTM',
        input_size=module_.input_size,
        hidden_size=module_.hidden_size,
        num_layers=module_.num_layers,
        bidirectional=module_.bidirectional,
        batch_first=module_.batch_first,
        dropout=module_.dropout,
        bias=module_.bias,
    )

    # Attribute name stems shared by nn.LSTM and NNCF_RNN; biases exist only
    # when the replacement was built with bias enabled.
    stems = ['weight_ih', 'weight_hh']
    if replacement.bias:
        stems.extend(['bias_ih', 'bias_hh'])

    for layer_idx in range(replacement.num_layers):
        for direction_idx in range(replacement.num_directions):
            # The second direction of a bidirectional layer uses a
            # '_reverse' suffix on its parameter names.
            direction_tag = '_reverse' if direction_idx == 1 else ''
            for stem in stems:
                attr = '{}_l{}{}'.format(stem, layer_idx, direction_tag)
                source_tensor = getattr(module_, attr).data
                getattr(replacement, attr).data.copy_(source_tensor)

    replacement.to(target_device)
    return replacement
Esempio n. 2
0
    def test_number_of_calling_fq_for_lstm(self):
        """Check how often each quantizer fires during one LSTM forward pass.

        Builds a 2-layer bidirectional ``NNCF_RNN('LSTM')`` with reference
        weights and compresses it with input quantization enabled. A
        forward pre-hook is attached to every quantizer in
        ``algo.all_quantizations``, then one forward pass is run. The test
        asserts that:

        * the compressed graph has 112 nodes;
        * there are 55 quantizers in total (8 WQ + 46 AQ + 1 input AQ);
        * every quantizer except the model-input one fires ``p.seq_length``
          times;
        * the model-input quantizer exists and fires exactly once.
        """
        p = LSTMTestSizes(1, 1, 1, 5)
        num_layers = 2
        bidirectional = True
        num_directions = 2 if bidirectional else 1
        bias = True
        batch_first = False
        config = get_empty_config(
            input_sample_sizes=[p.seq_length, p.batch, p.input_size])
        config['compression'] = {
            'algorithm': 'quantization',
            'quantize_inputs': True
        }

        test_data = TestLSTMCell.generate_lstm_data(p,
                                                    num_layers,
                                                    num_directions,
                                                    bias=bias,
                                                    batch_first=batch_first)

        test_rnn = NNCF_RNN('LSTM',
                            input_size=p.input_size,
                            hidden_size=p.hidden_size,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            bias=bias,
                            batch_first=batch_first)
        TestLSTM.set_ref_lstm_weights(test_data, test_rnn, num_layers,
                                      num_directions, bias)
        test_hidden = TestLSTM.get_test_lstm_hidden(test_data)

        model, algo = create_compressed_model_and_algo_for_test(
            test_rnn, config)

        class Counter:
            """Mutable call counter shared with a forward pre-hook."""

            def __init__(self):
                self.count = 0

            def next(self):
                self.count += 1

        # The first positional argument is the hooked quantizer module. It is
        # named ``module`` so it does not shadow the compressed ``model``.
        def hook(module, input_, counter):
            counter.next()

        counters = {}
        counter_for_input_quantizer = None
        for name, quantizer in algo.all_quantizations.items():
            counter = Counter()
            quantizer.register_forward_pre_hook(partial(hook, counter=counter))
            # The model-input quantizer is expected to fire once per forward
            # call rather than once per time step, so it is tracked apart
            # from the others.
            if str(name) == '/nncf_model_input_0|OUTPUT':
                counter_for_input_quantizer = counter
                continue
            counters[name] = counter
        # Fail with a clear message instead of an AttributeError on None later.
        assert counter_for_input_quantizer is not None, \
            "input quantizer '/nncf_model_input_0|OUTPUT' was not found"

        _ = model(test_data.x, test_hidden)
        assert model.get_graph().get_nodes_count(
        ) == 112  # NB: may always fail in debug due to superfluous 'cat' nodes
        assert len(counters) + 1 == 55  # 8 WQ + 46 AQ + 1 input AQ
        for counter in counters.values():
            assert counter.count == p.seq_length
        assert counter_for_input_quantizer.count == 1
def test_export_stacked_bi_lstm(tmp_path):
    """Export a quantized 2-layer bidirectional LSTM to ONNX and count FQ ops.

    The model is compressed with the quantization algorithm and exported to
    ``tmp_path / 'test.onnx'``. The exported graph must contain exactly 50
    ``FakeQuantize`` nodes.
    """
    p = LSTMTestSizes(3, 3, 3, 3)
    # With batch_first=False the sample layout is (seq_len, batch, input_size).
    # The second dimension is therefore the batch size. Passing hidden_size
    # there only worked because every LSTMTestSizes argument here is 3.
    config = get_empty_config(input_sample_size=(1, p.batch,
                                                 p.input_size))
    config['compression'] = {'algorithm': 'quantization'}

    # TODO: batch_first=True fails with building graph: ambiguous call to mul or sigmoid
    test_rnn = NNCF_RNN('LSTM',
                        input_size=p.input_size,
                        hidden_size=p.hidden_size,
                        num_layers=2,
                        bidirectional=True,
                        batch_first=False)
    _, algo = create_compressed_model_and_algo_for_test(test_rnn, config)

    test_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(test_path)
    assert os.path.exists(test_path)

    # Use a distinct name so the ONNX proto is not confused with the torch model.
    onnx_model = onnx.load(test_path)
    # pylint: disable=no-member
    onnx_num = sum(1 for node in onnx_model.graph.node
                   if node.op_type == 'FakeQuantize')
    assert onnx_num == 50
Esempio n. 4
0
    def test_number_of_calling_fq_for_lstm(self, tmp_path):
        """Count quantizer invocations for a compressed 2-layer bi-LSTM.

        The test patches torch operators and builds an ``NNCF_RNN('LSTM')``
        with reference weights. It compresses the model with input
        quantization enabled and attaches a forward pre-hook to each entry of
        ``model.all_quantizations``. After one forward pass inside the
        ``'quantized_graphs'`` context it checks three things:

        * the traced graph has 110 nodes (the graph is also dumped as a
          .dot file under ``tmp_path``);
        * there are 54 quantizers;
        * each quantizer fired ``seq_length`` times.
        """
        sizes = LSTMTestSizes(1, 1, 1, 5)
        layers = 2
        is_bidirectional = True
        directions = 2 if is_bidirectional else 1
        use_bias = True
        is_batch_first = False

        patch_torch_operators()
        config = get_empty_config(input_sample_size=(sizes.seq_length,
                                                     sizes.batch,
                                                     sizes.input_size))
        config['compression'] = {
            'algorithm': 'quantization',
            'quantize_inputs': True
        }
        config.log_dir = str(tmp_path)
        for ctx_name in ('orig', 'quantized_graphs'):
            reset_context(ctx_name)

        lstm_data = TestLSTMCell.generate_lstm_data(sizes, layers, directions,
                                                    bias=use_bias,
                                                    batch_first=is_batch_first)
        reference_rnn = NNCF_RNN('LSTM',
                                 input_size=sizes.input_size,
                                 hidden_size=sizes.hidden_size,
                                 num_layers=layers,
                                 bidirectional=is_bidirectional,
                                 bias=use_bias,
                                 batch_first=is_batch_first)
        TestLSTM.set_ref_lstm_weights(lstm_data, reference_rnn, layers,
                                      directions, use_bias)
        initial_hidden = TestLSTM.get_test_lstm_hidden(lstm_data)

        # Start building the compressed model from clean graph contexts.
        for ctx_name in ('orig', 'quantized_graphs'):
            reset_context(ctx_name)
        _, model = create_compressed_model(reference_rnn, config)

        # Per-quantizer invocation tallies, keyed by the quantization name.
        call_counts = {}

        def count_call(_module, _inputs, key):
            call_counts[key] += 1

        for quantizer_name, quantizer in model.all_quantizations.items():
            call_counts[quantizer_name] = 0
            quantizer.register_forward_pre_hook(
                partial(count_call, key=quantizer_name))

        with context('quantized_graphs') as graph_ctx:
            _ = model(lstm_data.x, initial_hidden)
            assert graph_ctx.graph.get_nodes_count() == 110
            dot_path = os.path.join(config.log_dir, "compressed_graph_next.dot")
            graph_ctx.graph.dump_graph(dot_path)

        assert len(call_counts) == 54
        for calls in call_counts.values():
            assert calls == sizes.seq_length