def test_backward_lstm_cell(self, sizes, _seed):
    """Check that LSTMCellNNCF back-propagates the same gradients as nn.LSTMCell.

    Both cells receive identical weights and inputs (``test_data`` is a clone of
    ``ref_data``). At every time step the first element of each cell's output is
    summed and back-propagated. The gradients of the initial hidden/cell state
    and of the four weight/bias tensors are then compared pairwise.
    """
    p = sizes
    ref_data = TestLSTMCell.generate_lstm_data(p, batch_first=False, is_backward=True)
    with torch.no_grad():
        # NOTE(review): assumes clone_test_data returns h0/c0 clones that are
        # gradient-tracking leaves even though they are created under no_grad.
        # Confirm, otherwise the state gradients below will be missing.
        test_data = LSTMTestData(*clone_test_data(ref_data))

    ref_rnn = nn.LSTMCell(p.input_size, p.hidden_size)
    TestLSTMCell.set_weights(ref_rnn, ref_data)
    test_rnn = LSTMCellNNCF(p.input_size, p.hidden_size)
    TestLSTMCell.set_weights(test_rnn, test_data)

    for i in range(p.seq_length):
        ref_result = ref_rnn(ref_data.x[i], (ref_data.h0[0], ref_data.c0[0]))
        test_result = test_rnn(test_data.x[i], (test_data.h0[0], test_data.c0[0]))
        ref_result[0].sum().backward()
        test_result[0].sum().backward()
        # Nothing zeroes .grad between steps, so the compared values are running
        # sums. Both cells accumulate the same way, so the comparison stays valid.
        ref_grads = get_grads([ref_data.h0[0], ref_data.c0[0]])
        ref_grads += get_grads([
            ref_rnn.weight_ih, ref_rnn.weight_hh, ref_rnn.bias_ih, ref_rnn.bias_hh
        ])
        # The state gradients must come from test_data. Reading ref_data here
        # would compare the reference gradients against themselves.
        test_grads = get_grads([test_data.h0[0], test_data.c0[0]])
        test_grads += get_grads([
            test_rnn.weight_ih, test_rnn.weight_hh, test_rnn.bias_ih, test_rnn.bias_hh
        ])
        # Pair each reference gradient with its test counterpart. Call
        # assert_allclose as (actual, expected).
        for ref, test in zip(ref_grads, test_grads):
            torch.testing.assert_allclose(test, ref)
def test_forward_lstm_cell(self, sizes, _seed):
    """Check that LSTMCellNNCF matches nn.LSTMCell output-for-output at every step.

    The NNCF cell and the reference cell get the same weights and cloned inputs.
    Each element of the tuple returned by the two cells is compared per time step.
    """
    params = sizes
    reference_data = TestLSTMCell.generate_lstm_data(params, batch_first=False)
    nncf_data = LSTMTestData(*clone_test_data(reference_data))

    # Keep this construction order. Cell initialisation may consume the
    # seeded RNG before set_weights overrides it.
    reference_cell = nn.LSTMCell(params.input_size, params.hidden_size)
    TestLSTMCell.set_weights(reference_cell, reference_data)
    nncf_cell = LSTMCellNNCF(params.input_size, params.hidden_size)
    TestLSTMCell.set_weights(nncf_cell, nncf_data)

    for step in range(params.seq_length):
        reference_state = (reference_data.h0[0], reference_data.c0[0])
        nncf_state = (nncf_data.h0[0], nncf_data.c0[0])
        expected = reference_cell(reference_data.x[step], reference_state)
        actual = nncf_cell(nncf_data.x[step], nncf_state)
        for expected_part, actual_part in zip(expected, actual):
            torch.testing.assert_allclose(actual_part, expected_part)
def test_export_lstm_cell(tmp_path):
    """Export a quantized 1x1 LSTMCellNNCF to ONNX and count its FakeQuantize nodes.

    NOTE(review): a later module-level function with this same name rebinds
    ``test_export_lstm_cell``, so pytest never collects this version. Rename
    one of them if both are meant to run.
    """
    config = get_empty_config(model_size=1, input_sample_size=(1, 1))
    config['compression'] = {'algorithm': 'quantization'}
    _, algo = create_compressed_model_and_algo_for_test(LSTMCellNNCF(1, 1), config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)  # pylint: disable=no-member
    fake_quantize_count = sum(
        1 for node in exported.graph.node if node.op_type == 'FakeQuantize'
    )
    assert fake_quantize_count == 12
def test_export_lstm_cell(tmp_path):
    """Export a quantized 1x1 LSTMCellNNCF (legacy API) to ONNX and count FakeQuantize nodes.

    NOTE(review): an earlier module-level function has this same name. This
    later definition replaces it at import time, so this is the version pytest
    collects.
    """
    patch_torch_operators()
    config = get_empty_config(model_size=1, input_sample_size=(1, 1))
    config['compression'] = {'algorithm': 'quantization'}
    config.log_dir = str(tmp_path)
    # Clear the named graph contexts before building the compressed model.
    reset_context('orig')
    reset_context('quantized_graphs')
    algo, _ = create_compressed_model(LSTMCellNNCF(1, 1), config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)  # pylint: disable=no-member
    fake_quantize_count = sum(
        1 for node in exported.graph.node if node.op_type == 'FakeQuantize'
    )
    assert fake_quantize_count == 13