Beispiel #1
0
    def test_single_sentence_multiple_steps(self):
        """One four-token sentence becomes three (input, target) steps.

        Each target is the input shifted left by one token, and since there is
        no padding every mask position is 1.
        """
        inputs, targets, mask = masked_tensor_from_sentences([[0, 1, 2, 3]])

        expected_input = tensor([[0, 1, 2]])
        expected_target = tensor([[1, 2, 3]])
        expected_mask = tensor([[1, 1, 1]])

        checks = (
            (inputs, expected_input),
            (targets, expected_target),
            (mask, expected_mask),
        )
        for actual, expected in checks:
            self.assertEqual(actual, expected)
Beispiel #2
0
    def test_multiple_sentences_nonmatching(self):
        """Sentences of different lengths are padded with 0 and masked out.

        The shorter sentence ([4, 5, 6]) yields only two real steps; its third
        column is padding, which shows up as 0 in input/target and 0 in the mask.
        """
        long_sentence = [0, 1, 2, 3]
        short_sentence = [4, 5, 6]
        inputs, targets, mask = masked_tensor_from_sentences(
            [long_sentence, short_sentence]
        )

        expected_input = tensor([[0, 1, 2], [4, 5, 0]])
        expected_target = tensor([[1, 2, 3], [5, 6, 0]])
        expected_mask = tensor([[1, 1, 1], [1, 1, 0]])

        for actual, expected in zip(
            (inputs, targets, mask),
            (expected_input, expected_target, expected_mask),
        ):
            self.assertEqual(actual, expected)
Beispiel #3
0
    def test_single_token_multi_token(self):
        """A one-token sentence still gets scored when ``target_all=True``.

        With ``target_all`` the target covers every token of each sentence, so
        it is one column wider than the input. The single-token sentence [0]
        contributes one real target position; its second position is padding.
        """
        batch = [[0], [1, 2]]
        inputs, targets, mask = masked_tensor_from_sentences(batch, target_all=True)

        expected_input = tensor([[0], [1]])
        expected_target = tensor([[0, 0], [1, 2]])
        expected_mask = tensor([[1, 0], [1, 1]])

        pairs = [(inputs, expected_input), (targets, expected_target), (mask, expected_mask)]
        for actual, expected in pairs:
            self.assertEqual(actual, expected)
Beispiel #4
0
    def test_target_all(self):
        """With ``target_all=True`` targets include every token, padded to the longest.

        The input drops each sentence's last token, while the target keeps all
        tokens. The shorter first sentence is padded with 0, and that padding
        is masked out.
        """
        batch = [[4, 5, 6], [0, 1, 2, 3]]
        inputs, targets, mask = masked_tensor_from_sentences(batch, target_all=True)

        expected_input = tensor([[4, 5, 0], [0, 1, 2]])
        expected_target = tensor([[4, 5, 6, 0], [0, 1, 2, 3]])
        expected_mask = tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

        for actual, expected in zip(
            [inputs, targets, mask],
            [expected_input, expected_target, expected_mask],
        ):
            self.assertEqual(actual, expected)
Beispiel #5
0
    def batch_nll_idxs(self, idxs, h0_provider=None, return_h=False):
        '''Provides the negative log-probability of a batch of sequences of indexes.

        The sequences are packed by ``masked_tensor_from_sentences`` with
        ``target_all=True``, so every token of every sequence is scored,
        including the first one. Sequences of different lengths are allowed;
        padded positions are zeroed by the returned mask.

        Args:
            idxs: batch of sequences of token indexes.
            h0_provider: callable taking the batch size and returning the
                initial hidden state; defaults to ``self.model.init_hidden``.
            return_h: if true, also return the hidden state produced by
                ``self.model``.

        Returns:
            The per-position output of ``self.decoder.neg_log_prob_raw``
            multiplied by the padding mask. When ``return_h`` is true, returns
            a ``(masked_nll, new_h)`` tuple instead.
        '''
        if h0_provider is None:
            h0_provider = self.model.init_hidden

        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        inputs, targets, mask = masked_tensor_from_sentences(
            idxs, device=self.device, target_all=True,
        )
        batch_size = inputs.shape[0]

        h0 = h0_provider(batch_size)
        o, new_h = self.model(inputs, h0)

        # With target_all=True the targets are one step longer than the inputs:
        # the first token has no preceding input, so it is predicted from the
        # output derived from the initial hidden state. That output is prepended
        # along dim 1 (presumably the time axis, batch-first -- TODO confirm).
        o0 = self.model.extract_output_from_h(h0).unsqueeze(1)
        o = torch.cat([o0, o], dim=1)

        masked_nll = self.decoder.neg_log_prob_raw(o, targets) * mask

        if return_h:
            return masked_nll, new_h
        return masked_nll
Beispiel #6
0
    def test_cuda(self):
        """Requesting ``device='cuda'`` places all three outputs on the GPU.

        The test is skipped, rather than erroring, on machines without a CUDA
        device.
        """
        import torch  # local import keeps this block self-contained

        if not torch.cuda.is_available():
            self.skipTest('CUDA is not available')

        x, t, m = masked_tensor_from_sentences([[0, 1]], device='cuda')

        self.assertTrue(x.is_cuda)
        self.assertTrue(t.is_cuda)
        self.assertTrue(m.is_cuda)