Exemple #1
0
 def forward(self, bth, lengths):
     """Project the input, run the transformer on it, and package the result.

     :param bth: Input tensor; the size of its dimension 1 is used as the
         mask width.
     :param lengths: Per-example lengths handed to `sequence_mask`.

     :return: A `TransformerEncoderOutput` carrying the transformer output and
         the un-expanded source mask.
     """
     max_steps = bth.shape[1]
     # Cast the mask to the type of `lengths`, then place it on the input's device.
     src_mask = sequence_mask(lengths, max_steps)
     src_mask = src_mask.type_as(lengths.data).to(bth.device)
     projected = self.proj(bth)
     # Two singleton dims are inserted at position 1 before the mask is used.
     expanded_mask = src_mask.unsqueeze(1).unsqueeze(1)
     # The transformer receives its input and mask packed into a single tuple.
     encoded = self.transformer((projected, expanded_mask))
     return TransformerEncoderOutput(output=encoded, src_mask=src_mask)
Exemple #2
0
def viterbi(unary, trans, lengths, start_idx, end_idx, norm=lambda x, y: x):
    """Do Viterbi decode on a batch.

    The forward loop keeps, for every tag, the best score of any path ending
    in that tag and records a backpointer to the previous tag that produced it.
    The backward loop then follows those backpointers starting from the best
    final tag of each example. Examples whose length has already been passed
    keep their alphas unchanged during the forward loop.

    :param unary: torch.FloatTensor: [T, B, N]
    :param trans: torch.FloatTensor: [1, N, N]
    :param lengths: torch.LongTensor: [B]
    :param start_idx: int: The index of the go token
    :param end_idx: int: The index of the eos token
    :param norm: Callable: This function should take the initial alphas and a
        dim to normalize along. It is called once as `norm(alphas, -1)`; the
        default returns its first argument unchanged.

    :return: torch.LongTensor: [T, B] the padded paths
    :return: torch.FloatTensor: [B] the path scores
    """
    seq_len, batch_size, tag_size = unary.size()
    min_length = torch.min(lengths)
    backpointers = []

    # Alphas: [B, 1, N]
    # Every tag starts at -1e4 except the start tag, which starts at 0.
    alphas = torch.Tensor(batch_size, 1, tag_size).fill_(-1e4).to(unary.device)
    alphas[:, 0, start_idx] = 0
    alphas = norm(alphas, -1)

    for i, unary_t in enumerate(unary):
        # Broadcasts [B, 1, N] + [1, N, N] -> [B, N, N]; the last dim indexes
        # the previous tag, so the max below picks the best previous tag.
        next_tag_var = alphas + trans
        # NOTE: this local shadows the enclosing function's name inside the body.
        viterbi, best_tag_ids = torch.max(next_tag_var, 2)
        backpointers.append(best_tag_ids.data)
        new_alphas = viterbi + unary_t
        new_alphas.unsqueeze_(1)
        if i >= min_length:
            # At least one example may be finished: take the new alphas where
            # step i is still inside the example and keep the old ones elsewhere.
            mask = (i < lengths).view(-1, 1, 1)
            alphas = alphas.masked_fill(mask, 0) + new_alphas.masked_fill(mask == 0, 0)
        else:
            alphas = new_alphas

    # Add end tag
    terminal_var = alphas.squeeze(1) + trans[:, end_idx, :]
    path_score, best_tag_id = torch.max(terminal_var, 1)

    # Flip lengths
    rev_len = seq_len - lengths - 1

    best_path = [best_tag_id]
    for i, backpointer_t in enumerate(reversed(backpointers)):
        # Get new best tag candidate
        new_best_tag_id = backpointer_t.gather(1, best_tag_id.unsqueeze(1)).squeeze(1)
        # We are going backwards now, if you haven't passed your flipped length
        # then you aren't in your real results yet so we propagate best tag
        # from the argmax on the terminal_var
        mask = (i > rev_len)
        best_tag_id = best_tag_id.masked_fill(mask, 0) + new_best_tag_id.masked_fill(mask == 0, 0)
        best_path.append(best_tag_id)
    # best_path now holds T + 1 entries; the last one appended comes from the
    # backpointers of the first timestep and is dropped so that T remain.
    _ = best_path.pop()
    best_path.reverse()
    best_path = torch.stack(best_path)
    # Mask out the extra tags (This might be pointless given that anything that
    # will use this as a dense tensor downstream will mask it itself?)
    seq_mask = sequence_mask(lengths).to(best_path.device).transpose(0, 1)
    best_path = best_path.masked_fill(seq_mask == 0, 0)
    return best_path, path_score
Exemple #3
0
    def score_sentence(self, unary, tags, lengths, batch_size):
        """Score a batch of sentences.

        The score of a sentence is the sum, over its unmasked timesteps, of the
        unary score of each tag plus the transition score into that tag from the
        previous one, plus the transition score into the end tag.

        :param unary: torch.FloatTensor: [T, B, N]
        :param tags: torch.LongTensor: [T, B]
        :param lengths: torch.LongTensor: [B]
        :param batch_size: int: B

        :return: torch.FloatTensor: [B]
        """
        trans = self.transitions.squeeze(0)  # [N, N]
        start = torch.full((1, batch_size), self.start_idx, dtype=tags.dtype, device=tags.device)  # [1, B]
        tags = torch.cat([start, tags], 0)  # [T + 1, B]

        # Pair every tag with the one that precedes it (the start tag precedes
        # the first real tag). `trans` is indexed [next, previous].
        prev_tags = tags[:-1]  # [T, B]
        next_tags = tags[1:]  # [T, B]
        # Index with two separate tensors rather than a list of tensors:
        # indexing with a non-tuple sequence is deprecated in PyTorch.
        trans_score = trans[next_tags, prev_tags]  # [T, B]
        # Pull out the values of the tags from the unary scores.
        unary_score = unary.gather(2, next_tags.unsqueeze(-1)).squeeze(-1)

        mask = sequence_mask(lengths).transpose(0, 1).to(tags.device)
        scores = unary_score + trans_score
        # Positions where the mask is zero contribute nothing to the sum.
        scores = scores.masked_fill(mask == 0, 0)
        scores = scores.sum(0)

        # Add stop tag: because of the prepended start row, row `lengths[b]`
        # of `tags` is the tag the end transition is taken from.
        last_tags = tags.gather(0, lengths.unsqueeze(0)).squeeze(0)
        eos_scores = trans[self.end_idx, last_tags]
        scores = scores + eos_scores
        return scores
def test_attention_masked_valid_probs(lengths):
    """Check that every softmax row still sums to one after `masked_fill`.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    batch, seq_lengths, width = lengths
    fill_mask = sequence_mask(seq_lengths)
    raw_scores = torch.rand(batch, width)
    # Positions selected by the mask are set to -1e9 before the softmax.
    probs = F.softmax(raw_scores.masked_fill(fill_mask, -1e9), dim=1)
    for dist in probs:
        total = torch.sum(dist).numpy()
        np.testing.assert_allclose(total, 1.0, rtol=1e-5)
Exemple #5
0
def test_attention_masked_valid_probs(lengths):
    """Assert that softmax over mask-filled scores gives rows that sum to one.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    n_rows, lens, n_cols = lengths
    seq_mask = sequence_mask(lens)
    random_scores = torch.rand(n_rows, n_cols)
    # Entries selected by the mask get a large negative score.
    filled = random_scores.masked_fill(seq_mask, -1e9)
    weights = F.softmax(filled, dim=1)
    for weight_row in weights:
        np.testing.assert_allclose(
            torch.sum(weight_row).numpy(), 1.0, rtol=1e-5
        )
Exemple #6
0
 def test_mask(self):
     """Compare `sequence_mask` with a reference mask filled entry by entry."""
     produced = sequence_mask(self.lengths)
     expected = np.zeros((self.batch_size, self.seq_len))
     # An entry becomes one when its column index is below that row's length.
     for row, col in np.ndindex(expected.shape):
         if col < self.lengths.data[row]:
             expected[row, col] = 1
     np.testing.assert_allclose(produced.data.numpy(), expected)
def test_mask_valid_locs(lengths):
    """Check `sequence_mask` against a reference mask built in NumPy.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    n_rows, lens, n_cols = lengths
    produced = sequence_mask(lens)
    expected = np.zeros((n_rows, n_cols))
    # Set an entry to one when its column index is below that row's length.
    for row, col in np.ndindex(expected.shape):
        if col < lens.data[row]:
            expected[row, col] = 1
    np.testing.assert_allclose(produced.data.numpy(), expected)
Exemple #8
0
def test_mask_valid_locs(lengths):
    """Verify that `sequence_mask` matches a mask filled in position by position.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    batch, seq_lengths, width = lengths
    actual = sequence_mask(seq_lengths)
    reference = np.zeros((batch, width))
    for b, t in np.ndindex(reference.shape):
        # Columns below the row's length are the ones expected to be set.
        if t < seq_lengths.data[b]:
            reference[b, t] = 1
    np.testing.assert_allclose(actual.data.numpy(), reference)
Exemple #9
0
 def test_attention_masked_ignores_pad(self):
     """Check that the first `length` weights of each row are zero.

     Scores are filled with -1e9 wherever the sequence mask selects, then
     passed through a softmax. Rows whose length equals `self.seq_len` are
     skipped.
     """
     fill_mask = sequence_mask(self.lengths)
     weights = F.softmax(self.scores.masked_fill(fill_mask, -1e9), dim=1)
     for weight_row, row_len in zip(weights, self.lengths):
         if row_len.item() != self.seq_len:
             leading = weight_row[:row_len.item()]
             np.testing.assert_allclose(leading.data.numpy(), 0.0)
 def test_attention_mask_values(self):
     """Check that `attention_mask` writes `-value` past each row's length.

     Rows whose length equals `self.seq_len` have nothing past their length
     and are skipped.
     """
     value = 100000
     mask = sequence_mask(self.lengths)
     score_mask = attention_mask(self.scores, mask, value=value)
     for row, length in zip(score_mask, self.lengths):
         # Iterating a 1-D tensor yields 0-dim tensors, which cannot be indexed
         # with [0]; `.item()` extracts the Python number instead.
         n = length.item()
         if n == self.seq_len:
             continue
         masked = row[n:]
         np.testing.assert_allclose(masked.data.numpy(), -value)
Exemple #11
0
def test_mask_mxlen(lengths):
    """Check `sequence_mask` when an explicit, wider maximum length is given.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    n_rows, lens, n_cols = lengths
    # Widen the mask by a random amount drawn from [2, 11).
    padding = np.random.randint(2, 11)
    produced = sequence_mask(lens, n_cols + padding)
    expected = np.zeros((n_rows, n_cols + padding))
    # Set an entry to one when its column index is below that row's length.
    for row, col in np.ndindex(expected.shape):
        if col < lens.data[row]:
            expected[row, col] = 1
    np.testing.assert_allclose(produced.data.numpy(), expected)
def test_mask_mxlen(lengths):
    """Verify `sequence_mask` output when the requested width exceeds `seq_len`.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    batch, seq_lengths, width = lengths
    slack = np.random.randint(2, 11)  # random extra width in [2, 11)
    actual = sequence_mask(seq_lengths, width + slack)
    reference = np.zeros((batch, width + slack))
    for b, t in np.ndindex(reference.shape):
        # Columns below the row's length are the ones expected to be set.
        if t < seq_lengths.data[b]:
            reference[b, t] = 1
    np.testing.assert_allclose(actual.data.numpy(), reference)
def test_attention_masked_ignores_pad(lengths):
    """Check that the first `length` softmax weights of each row are zero.

    Random scores are filled with -1e9 wherever the sequence mask selects and
    then normalized with a softmax. Rows whose length equals the sequence
    length are skipped.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    batch, seq_lengths, width = lengths
    fill_mask = sequence_mask(seq_lengths)
    raw_scores = torch.rand(batch, width)
    weights = F.softmax(raw_scores.masked_fill(fill_mask, -1e9), dim=1)
    for weight_row, row_len in zip(weights, seq_lengths):
        if row_len.item() != width:
            leading = weight_row[:row_len.item()]
            np.testing.assert_allclose(leading.data.numpy(), 0.0)
def attn_values_seq_mask(attn, qkv):
    """Check that masked attention with zeroed queries averages the values.

    The queries are zeroed in place, random lengths in [1, T) are drawn, and
    the attention output at every (batch, head, time) position is compared
    with the mean of the first `lens[b]` value vectors of that batch and head.

    :param attn: Callable invoked as `attn(q, k, v, mask=mask)`; the first of
        its two return values is checked.
    :param qkv: A 3-tuple of query, key and value tensors; the query is
        unpacked as four-dimensional (B, H, T, _).
    """
    q, k, v = qkv
    n_batch, n_heads, n_steps, _ = q.shape
    q = q.zero_()
    lens = torch.from_numpy(np.random.randint(1, n_steps, size=n_batch))
    # Two singleton dims are inserted at position 1 of the sequence mask.
    mask = sequence_mask(lens, n_steps).unsqueeze(1).unsqueeze(1)
    attended, _ = attn(q, k, v, mask=mask)
    attended = attended.numpy()
    values = v.numpy()
    for b in range(n_batch):
        # Mean over the first lens[b] positions of the time axis.
        mean_values = np.mean(values[:, :, :lens[b], :], axis=2)
        for h in range(n_heads):
            for t in range(n_steps):
                np.testing.assert_allclose(
                    attended[b, h, t, :], mean_values[b, h, :], atol=1e-5
                )
Exemple #15
0
def test_seq_mask_valid_count(lengths):
    """Check that the sum of the mask entries equals the sum of the lengths.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length); only the lengths tensor is used.
    """
    _, seq_lengths, _ = lengths
    produced = sequence_mask(seq_lengths)
    expected_total = seq_lengths.sum()
    assert produced.sum() == expected_total.sum()
Exemple #16
0
def _make_src_mask(output, lengths):
    """Build the source mask for `output` from per-example lengths.

    :param output: Tensor whose dimension 1 gives the mask width and whose
        device the mask is moved to.
    :param lengths: Per-example lengths handed to `sequence_mask`.

    :return: The mask cast to the type of `lengths`, on the device of `output`.
    """
    T = output.shape[1]
    # The mask is derived from `lengths`, which may live on a different device
    # than `output`; moving it is a no-op when the devices already match.
    src_mask = sequence_mask(lengths, T).type_as(lengths.data).to(device=output.device)
    return src_mask
Exemple #17
0
def test_mask_shape(lengths):
    """Check that the mask has one row per example and `seq_len` columns.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    n_rows, lens, n_cols = lengths
    produced = sequence_mask(lens)
    assert produced.size(0) == n_rows
    assert produced.size(1) == n_cols
Exemple #18
0
 def forward(self, bth, lengths):
     """Project the input, run the transformer on it, and package the result.

     :param bth: Input tensor; the size of its dimension 1 is used as the
         mask width.
     :param lengths: Per-example lengths handed to `sequence_mask`.

     :return: A `TransformerEncoderOutput` carrying the transformer output and
         the un-expanded source mask.
     """
     T = bth.shape[1]
     # The mask is derived from `lengths`, which may live on a different device
     # than the input; moving it is a no-op when the devices already match.
     src_mask = sequence_mask(lengths, T).type_as(lengths.data).to(bth.device)
     bth = self.proj(bth)
     # Two singleton dims are inserted at position 1 before the mask is used.
     output = self.transformer(bth, src_mask.unsqueeze(1).unsqueeze(1))
     return TransformerEncoderOutput(output=output, src_mask=src_mask)
def test_mask_shape(lengths):
    """Verify the two dimensions of the mask returned by `sequence_mask`.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length).
    """
    batch, seq_lengths, width = lengths
    actual = sequence_mask(seq_lengths)
    # Dimension 0 is compared with the batch size, dimension 1 with the width.
    assert actual.size(0) == batch
    assert actual.size(1) == width
Exemple #20
0
def _make_src_mask(output, lengths):
    """Build the source mask for `output` from per-example lengths.

    :param output: Tensor whose dimension 1 gives the mask width and whose
        device the mask is moved to.
    :param lengths: Per-example lengths handed to `sequence_mask`.

    :return: The mask cast to the type of `lengths`, on the device of `output`.
    """
    width = output.shape[1]
    mask = sequence_mask(lengths, width)
    mask = mask.type_as(lengths.data)
    return mask.to(device=output.device)
Exemple #21
0
 def test_mask_shape(self):
     """Check that the mask is `batch_size` by `seq_len`."""
     produced = sequence_mask(self.lengths)
     # Dimension 0 is compared with the batch size, dimension 1 with seq_len.
     self.assertEqual(produced.size(0), self.batch_size)
     self.assertEqual(produced.size(1), self.seq_len)
 def test_attention_masked_valid_probs(self):
     """Check that every softmax row sums to one after `attention_mask`."""
     mask = sequence_mask(self.lengths)
     score_mask = attention_mask(self.scores, mask)
     attention_weights = F.softmax(score_mask, dim=1)
     for row in attention_weights:
         # `torch.sum(row)` is a 0-dim tensor, which cannot be indexed with
         # [0]; `.item()` extracts the Python float. The tolerance is relaxed
         # because a single-precision sum is not exact to the default 1e-7.
         np.testing.assert_allclose(torch.sum(row).item(), 1.0, rtol=1e-5)
def test_seq_mask_valid_count(lengths):
    """Verify that the mask entries add up to the total of all lengths.

    :param lengths: A 3-tuple unpacked as (batch size, lengths tensor,
        sequence length); only the lengths tensor is used.
    """
    _, lens, _ = lengths
    actual = sequence_mask(lens)
    total_length = lens.sum()
    assert actual.sum() == total_length.sum()
 def test_attention_mask_shape(self):
     """Check that `attention_mask` returns a tensor the same size as the mask."""
     seq_mask = sequence_mask(self.lengths)
     masked_scores = attention_mask(self.scores, seq_mask)
     expected_size = seq_mask.size()
     self.assertEqual(expected_size, masked_scores.size())
Exemple #25
0
 def test_attention_masked_valid_probs(self):
     """Check that every softmax row still sums to one after `masked_fill`."""
     fill_mask = sequence_mask(self.lengths)
     # Positions selected by the mask are set to -1e9 before the softmax.
     filled = self.scores.masked_fill(fill_mask, -1e9)
     weights = F.softmax(filled, dim=1)
     for weight_row in weights:
         total = torch.sum(weight_row).numpy()
         np.testing.assert_allclose(total, 1.0, rtol=1e-5)