Ejemplo n.º 1
0
    def test_char_rnn_equivalent(self):
        """Check that switching off ``onnx_export_model`` leaves outputs unchanged.

        A three-model ``CharSourceEncoderEnsemble`` is run twice on the same
        random batch: once as constructed, and once after setting
        ``onnx_export_model = False`` on every member encoder.  Each pair of
        outputs must agree in shape and in value (default ``assert_allclose``
        tolerances).
        """
        # Size of the character vocabulary; also bounds the random char ids.
        char_vocab_size = 126

        test_args = test_utils.ModelParamsDict(
            encoder_bidirectional=True, sequence_lstm=True
        )
        test_args.vocab_reduction_params = {
            "lexical_dictionaries": test_utils.create_lexical_dictionaries(),
            "num_top_words": 5,
            "max_translation_candidates_per_word": 1,
        }

        test_args.arch = "char_source"
        test_args.char_source_dict_size = char_vocab_size
        test_args.char_embed_dim = 8
        test_args.char_rnn_units = 12
        test_args.char_rnn_layers = 2

        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
        encoder_ensemble = CharSourceEncoderEnsemble(
            [task.build_model(test_args) for _ in range(3)]
        )

        # One sentence of `seq_len` words, each made of `chars_per_word` chars.
        seq_len, chars_per_word = 5, 3
        src_tokens = torch.LongTensor(
            np.random.randint(0, len(src_dict), (seq_len, 1), dtype="int64")
        )
        src_lengths = torch.IntTensor(np.array([seq_len], dtype="int32"))
        char_inds = torch.LongTensor(
            np.random.randint(
                0, char_vocab_size, (1, seq_len, chars_per_word), dtype="int64"
            )
        )
        word_lengths = torch.IntTensor(
            np.full(seq_len, chars_per_word, dtype="int32")
        ).reshape((1, seq_len))
        encoder_inputs = (src_tokens, src_lengths, char_inds, word_lengths)

        # First pass: ensemble exactly as built (presumably the ONNX-export
        # code path -- confirm against CharSourceEncoderEnsemble).
        onnx_path_outputs = encoder_ensemble(*encoder_inputs)

        # Second pass: same inputs with the export flag explicitly disabled.
        for member in encoder_ensemble.models:
            member.encoder.onnx_export_model = False
        original_path_outputs = encoder_ensemble(*encoder_inputs)

        for onnx_out, original_out in zip(onnx_path_outputs, original_path_outputs):
            actual = onnx_out.detach().numpy()
            expected = original_out.detach().numpy()
            assert actual.shape == expected.shape
            np.testing.assert_allclose(actual, expected)
Ejemplo n.º 2
0
    def _test_ensemble_encoder_export_char_source(self, test_args):
        """Export a char-source encoder ensemble and compare it with PyTorch.

        Builds a three-model ``CharSourceEncoderEnsemble`` from ``test_args``,
        exports it with ``onnx_export``, reloads the export through
        ``caffe2_backend.prepare_zip_archive`` and asserts that the Caffe2
        outputs match the PyTorch outputs on an all-ones batch
        (``rtol=1e-4``, ``atol=1e-6``).  Finally exercises ``save_to_db``.

        Every file is written inside a temporary directory that is deleted
        when this helper returns, including when an assertion fails, so
        repeated test runs do not accumulate exported models on disk.

        Args:
            test_args: Model parameters forwarded to
                ``test_utils.prepare_inputs`` and ``task.build_model``.
        """
        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.DictionaryHolderTask(src_dict, tgt_dict)

        num_models = 3
        model_list = [task.build_model(test_args) for _ in range(num_models)]
        encoder_ensemble = CharSourceEncoderEnsemble(model_list)

        # TemporaryDirectory (unlike mkdtemp) cleans up after itself.
        with tempfile.TemporaryDirectory() as tmp_dir:
            encoder_pb_path = os.path.join(tmp_dir, "char_encoder.pb")
            encoder_ensemble.onnx_export(encoder_pb_path)

            # One sentence of `length` words, `word_length` chars per word.
            length = 5
            word_length = 3
            src_tokens = torch.LongTensor(np.ones((length, 1), dtype="int64"))
            src_lengths = torch.IntTensor(np.array([length], dtype="int32"))
            char_inds = torch.LongTensor(
                np.ones((1, length, word_length), dtype="int64")
            )
            word_lengths = torch.IntTensor(
                np.array([word_length] * length, dtype="int32")
            ).reshape((1, length))

            pytorch_encoder_outputs = encoder_ensemble(
                src_tokens, src_lengths, char_inds, word_lengths
            )

            onnx_encoder = caffe2_backend.prepare_zip_archive(encoder_pb_path)
            caffe2_encoder_outputs = onnx_encoder.run(
                (
                    src_tokens.numpy(),
                    src_lengths.numpy(),
                    char_inds.numpy(),
                    word_lengths.numpy(),
                )
            )

            # Index the Caffe2 outputs explicitly so that a missing output
            # raises instead of being silently skipped.
            for i, pytorch_out in enumerate(pytorch_encoder_outputs):
                np.testing.assert_allclose(
                    caffe2_encoder_outputs[i],
                    pytorch_out.detach().numpy(),
                    rtol=1e-4,
                    atol=1e-6,
                )

            encoder_ensemble.save_to_db(
                os.path.join(tmp_dir, "encoder.predictor_export")
            )