def test_single_sentence_multiple_steps(self):
    """A single four-token sentence yields three fully unmasked steps.

    Inputs are tokens [0..2], targets are the shifted tokens [1..3], and
    every position is marked valid in the mask.
    """
    inp, tgt, msk = masked_tensor_from_sentences([[0, 1, 2, 3]])
    pairs = (
        (inp, tensor([[0, 1, 2]])),
        (tgt, tensor([[1, 2, 3]])),
        (msk, tensor([[1, 1, 1]])),
    )
    for actual, expected in pairs:
        self.assertEqual(actual, expected)
def test_multiple_sentences_nonmatching(self):
    """Sentences of different lengths are right-padded with 0 and masked out.

    The shorter sentence [4, 5, 6] contributes two real steps; its third
    input/target position is padding (0) with mask 0.
    """
    batch = [
        [0, 1, 2, 3],
        [4, 5, 6],
    ]
    inp, tgt, msk = masked_tensor_from_sentences(batch)
    pairs = (
        (inp, tensor([[0, 1, 2], [4, 5, 0]])),
        (tgt, tensor([[1, 2, 3], [5, 6, 0]])),
        (msk, tensor([[1, 1, 1], [1, 1, 0]])),
    )
    for actual, expected in pairs:
        self.assertEqual(actual, expected)
def test_single_token_multi_token(self):
    """target_all=True with a one-token and a two-token sentence.

    Targets keep every token of each sentence (padded with 0), inputs drop
    the last token, and the mask marks only the real target positions.
    """
    batch = [
        [0],
        [1, 2],
    ]
    inp, tgt, msk = masked_tensor_from_sentences(batch, target_all=True)
    pairs = (
        (inp, tensor([[0], [1]])),
        (tgt, tensor([[0, 0], [1, 2]])),
        (msk, tensor([[1, 0], [1, 1]])),
    )
    for actual, expected in pairs:
        self.assertEqual(actual, expected)
def test_target_all(self):
    """target_all=True makes targets one step longer than inputs.

    Each target row is the whole sentence (right-padded with 0); each input
    row is the sentence without its last token (also right-padded).
    """
    batch = [
        [4, 5, 6],
        [0, 1, 2, 3],
    ]
    inp, tgt, msk = masked_tensor_from_sentences(batch, target_all=True)
    pairs = (
        (inp, tensor([[4, 5, 0], [0, 1, 2]])),
        (tgt, tensor([[4, 5, 6, 0], [0, 1, 2, 3]])),
        (msk, tensor([[1, 1, 1, 0], [1, 1, 1, 1]])),
    )
    for actual, expected in pairs:
        self.assertEqual(actual, expected)
def batch_nll_idxs(self, idxs, h0_provider=None, return_h=False):
    """Return the masked negative log-probability of a batch of index sequences.

    The sequences are batched by ``masked_tensor_from_sentences`` with
    ``target_all=True``, so the target tensor covers every token of every
    sequence (including the first), while the input tensor is one step
    shorter.

    Args:
        idxs: batch of sequences of token indexes.
        h0_provider: callable taking the batch size and returning the initial
            hidden state. Defaults to ``self.model.init_hidden``.
        return_h: if true, also return the hidden state produced by the model.

    Returns:
        ``self.decoder.neg_log_prob_raw(o, target) * mask``, i.e. per-position
        NLL with padded positions zeroed by the mask; or the tuple
        ``(masked_nll, new_h)`` when ``return_h`` is true.
    """
    if h0_provider is None:
        h0_provider = self.model.init_hidden

    # ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs, target, mask = masked_tensor_from_sentences(
        idxs, device=self.device, target_all=True)
    batch_size = inputs.shape[0]

    h0 = h0_provider(batch_size)
    o, new_h = self.model(inputs, h0)

    # Inputs are one step shorter than targets (target_all=True), so the
    # output for the first target position is derived from the initial hidden
    # state and prepended along dim=1 (presumably the time axis).
    o0 = self.model.extract_output_from_h(h0).unsqueeze(1)
    o = torch.cat([o0, o], dim=1)

    masked_nll = self.decoder.neg_log_prob_raw(o, target) * mask
    if return_h:
        return masked_nll, new_h
    return masked_nll
def test_cuda(self):
    """Tensors built with device='cuda' must all live on the GPU.

    Skipped (rather than erroring) on machines without a usable CUDA device.
    """
    import torch  # local import: only needed for the availability check

    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    x, t, m = masked_tensor_from_sentences([[0, 1]], device='cuda')
    for produced in (x, t, m):
        self.assertTrue(produced.is_cuda)