def test_backward_lstm_cell(self, sizes, _seed):
        """Check that LSTMCellNNCF produces the same gradients as nn.LSTMCell.

        Runs one backward pass per time step on both cells (identical weights
        and inputs) and compares gradients of the initial hidden/cell states
        and of all four weight/bias tensors.
        """
        p = sizes
        ref_data = TestLSTMCell.generate_lstm_data(p,
                                                   batch_first=False,
                                                   is_backward=True)
        # Clone inputs so both cells see identical (but independent) tensors.
        with torch.no_grad():
            test_data = LSTMTestData(*clone_test_data(ref_data))

        ref_rnn = nn.LSTMCell(p.input_size, p.hidden_size)
        TestLSTMCell.set_weights(ref_rnn, ref_data)
        test_rnn = LSTMCellNNCF(p.input_size, p.hidden_size)
        TestLSTMCell.set_weights(test_rnn, test_data)

        for i in range(p.seq_length):
            ref_result = ref_rnn(ref_data.x[i],
                                 (ref_data.h0[0], ref_data.c0[0]))
            test_result = test_rnn(test_data.x[i],
                                   (test_data.h0[0], test_data.c0[0]))
            ref_result[0].sum().backward()
            test_result[0].sum().backward()
            ref_grads = get_grads([ref_data.h0[0], ref_data.c0[0]])
            ref_grads += get_grads([
                ref_rnn.weight_ih, ref_rnn.weight_hh, ref_rnn.bias_ih,
                ref_rnn.bias_hh
            ])
            # BUG FIX: previously read gradients from ref_data here, so the
            # test compared the reference gradients against themselves and
            # could never fail. Read them from test_data / test_rnn instead.
            test_grads = get_grads([test_data.h0[0], test_data.c0[0]])
            test_grads += get_grads([
                test_rnn.weight_ih, test_rnn.weight_hh, test_rnn.bias_ih,
                test_rnn.bias_hh
            ])
            # (names were also swapped relative to the zip order before)
            for (ref, test) in zip(ref_grads, test_grads):
                torch.testing.assert_allclose(test, ref)
    def test_forward_lstm_cell(self, sizes, _seed):
        """Verify LSTMCellNNCF matches nn.LSTMCell forward output per step."""
        params = sizes
        reference_data = TestLSTMCell.generate_lstm_data(params,
                                                         batch_first=False)
        nncf_data = LSTMTestData(*clone_test_data(reference_data))

        reference_cell = nn.LSTMCell(params.input_size, params.hidden_size)
        TestLSTMCell.set_weights(reference_cell, reference_data)
        nncf_cell = LSTMCellNNCF(params.input_size, params.hidden_size)
        TestLSTMCell.set_weights(nncf_cell, nncf_data)

        # Feed the same sequence through both cells and compare h/c outputs.
        for step in range(params.seq_length):
            reference_out = reference_cell(
                reference_data.x[step],
                (reference_data.h0[0], reference_data.c0[0]))
            nncf_out = nncf_cell(
                nncf_data.x[step],
                (nncf_data.h0[0], nncf_data.c0[0]))
            for expected, actual in zip(reference_out, nncf_out):
                torch.testing.assert_allclose(actual, expected)
def test_export_lstm_cell(tmp_path):
    """Export a quantized LSTMCellNNCF to ONNX and count FakeQuantize nodes."""
    config = get_empty_config(model_size=1, input_sample_size=(1, 1))
    config['compression'] = {'algorithm': 'quantization'}

    compressed_model, algo = create_compressed_model_and_algo_for_test(
        LSTMCellNNCF(1, 1), config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)
    # pylint: disable=no-member
    fq_count = sum(1 for node in exported.graph.node
                   if node.op_type == 'FakeQuantize')
    # 12 quantizers are expected for a single quantized LSTM cell.
    assert fq_count == 12
# Example #4 (alternate/legacy version of test_export_lstm_cell below)
# 0
def test_export_lstm_cell(tmp_path):
    """Export a quantized LSTMCellNNCF to ONNX and count FakeQuantize nodes.

    Legacy-API variant: uses patch_torch_operators / reset_context and
    create_compressed_model, and expects 13 quantizers.
    """
    patch_torch_operators()
    config = get_empty_config(model_size=1, input_sample_size=(1, 1))
    config['compression'] = {'algorithm': 'quantization'}

    config.log_dir = str(tmp_path)
    reset_context('orig')
    reset_context('quantized_graphs')
    algo, compressed_model = create_compressed_model(LSTMCellNNCF(1, 1),
                                                     config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)
    # pylint: disable=no-member
    fq_count = sum(1 for node in exported.graph.node
                   if node.op_type == 'FakeQuantize')
    # 13 quantizers are expected for this (legacy) quantization setup.
    assert fq_count == 13