Code example #1
0
    def test_auto_regressive_seq_decoder_forward(self):
        """Forward pass contract: an empty dict without targets in train mode,
        a scalar loss with targets, and predictions once in eval mode."""
        batch_size, time_steps, decoder_inout_dim = 2, 3, 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        embedder = Embedding(num_embeddings=vocab.get_vocab_size(),
                             embedding_dim=decoder_inout_dim)
        decoder = AutoRegressiveSeqDecoder(vocab, decoder_net, 10, embedder)

        # Mask out everything after the first step of the first example.
        mask = torch.ones(batch_size, time_steps).bool()
        mask[0, 1:] = False
        state = {
            "source_mask": mask,
            "encoder_outputs": torch.rand(batch_size, time_steps,
                                          decoder_inout_dim),
        }
        targets = {
            "tokens": {
                "tokens": torch.ones(batch_size, time_steps).long()
            }
        }

        # Training mode without targets: nothing to compute.
        assert decoder.forward(state) == {}
        # With targets: a scalar loss attached to the autograd graph.
        loss = decoder.forward(state, targets)["loss"]
        assert loss.shape == torch.Size([]) and loss.requires_grad
        # Eval mode: decoding runs and predictions are returned.
        decoder.eval()
        assert "predictions" in decoder.forward(state)
Code example #2
0
    def test_auto_regressive_seq_decoder_indices_to_tokens(self):
        """indices_to_tokens maps index arrays back to token strings,
        dropping padding/end symbols."""
        decoder_inout_dim = 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        decoder = AutoRegressiveSeqDecoder(
            vocab,
            decoder_net,
            Embedding(num_embeddings=vocab.get_vocab_size(),
                      embedding_dim=decoder_inout_dim),
            beam_search=Lazy(BeamSearch, constructor_extras={"max_steps": 10}),
        )

        indices = torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]]).numpy()
        expected = [["B", "A"], ["A", "A", "B"]]
        assert decoder.indices_to_tokens(indices) == expected
Code example #3
0
    def test_auto_regressive_seq_decoder_init(self):
        """Construction succeeds when the target embedding dim matches the
        decoder net, and raises ConfigurationError on a mismatch."""
        decoder_inout_dim = 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        def build(embedding_dim):
            # Helper: build a decoder whose target embedder has the given dim.
            return AutoRegressiveSeqDecoder(
                vocab,
                decoder_net,
                10,
                Embedding(num_embeddings=vocab.get_vocab_size(),
                          embedding_dim=embedding_dim),
            )

        build(decoder_inout_dim)  # matching dims: no error
        with pytest.raises(ConfigurationError):
            build(decoder_inout_dim + 1)  # mismatched dims must be rejected
Code example #4
0
    def test_auto_regressive_seq_decoder_post_process(self):
        """post_process turns raw prediction indices into token strings
        under the "predicted_tokens" key."""
        decoder_inout_dim = 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        decoder = AutoRegressiveSeqDecoder(
            vocab,
            decoder_net,
            10,
            Embedding(num_embeddings=vocab.get_vocab_size(),
                      embedding_dim=decoder_inout_dim),
        )

        output_dict = {
            "predictions": torch.tensor([[3, 2, 5, 0, 0], [2, 2, 3, 5, 0]])
        }
        result = decoder.post_process(output_dict)
        assert result["predicted_tokens"] == [["B", "A"], ["A", "A", "B"]]
Code example #5
0
    def test_auto_regressive_seq_decoder_init(self):
        """Construction with a Lazy beam search succeeds when the embedding
        dim matches the decoder net; a mismatch raises ConfigurationError."""
        decoder_inout_dim = 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        def build(embedding_dim):
            # Helper: build a decoder whose target embedder has the given dim.
            return AutoRegressiveSeqDecoder(
                vocab,
                decoder_net,
                Embedding(num_embeddings=vocab.get_vocab_size(),
                          embedding_dim=embedding_dim),
                beam_search=Lazy(BeamSearch,
                                 constructor_extras={"max_steps": 10}),
            )

        build(decoder_inout_dim)  # matching dims: no error
        with pytest.raises(ConfigurationError):
            build(decoder_inout_dim + 1)  # mismatched dims must be rejected
Code example #6
0
    def test_auto_regressive_seq_decoder_tensor_and_token_based_metric(self):
        """A forward pass updates both the tensor-based metric (BLEU) and the
        token-based metric (DummyMetric), and get_metrics() reports them."""
        # Fix every RNG seed (torch, numpy, ...) so parameter initialization
        # and the torch.randn(...) encoder state below are deterministic,
        # which makes the exact metric values asserted at the end reproducible.
        prepare_environment(Params({}))

        batch_size, time_steps, decoder_inout_dim = 2, 3, 4
        vocab, decoder_net = create_vocab_and_decoder_net(decoder_inout_dim)

        # NOTE: the decoder must be constructed before torch.randn below so
        # that the seeded RNG stream is consumed in the same order.
        decoder = AutoRegressiveSeqDecoder(
            vocab,
            decoder_net,
            Embedding(num_embeddings=vocab.get_vocab_size(),
                      embedding_dim=decoder_inout_dim),
            beam_search=Lazy(BeamSearch,
                             constructor_extras={
                                 "max_steps": 10,
                                 "beam_size": 4
                             }),
            tensor_based_metric=BLEU(),
            token_based_metric=DummyMetric(),
        ).eval()

        # Mask out everything after the first step of the first example.
        mask = torch.ones(batch_size, time_steps).bool()
        mask[0, 1:] = False
        encoder_out = {
            "source_mask": mask,
            "encoder_outputs": torch.randn(batch_size, time_steps,
                                           decoder_inout_dim),
        }
        target_tokens = {
            "tokens": {
                "tokens": torch.ones(batch_size, time_steps).long()
            }
        }

        decoder.forward(encoder_out, target_tokens)
        # Exact values are reproducible because of the fixed seeds above.
        assert decoder.get_metrics()["BLEU"] == 1.388809517005903e-11
        assert decoder.get_metrics()["em"] == 0.0
        assert decoder.get_metrics()["f1"] == 1 / 3