Example #1
def test_sentence_pattern():
    pad = 0

    sentences = torch.tensor([
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 2, 3, 4, 5, 6, 0, 0],
        [1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 3, 4, 1, 2, 0, 0],
        [1, 1, 2, 3, 2, 3, 5, 0],
        [1, 1, 2, 3, 2, 3, 0, 0],
    ])

    max_len = sentences.size(1)
    pad_mask = sentences.eq(pad)

    # pattern: (batch_size, max_len, max_len)
    pattern = utils.sentence_pattern(sentences, pad_mask=pad_mask)

    golden_pattern_1 = torch.eye(max_len, dtype=torch.bool)

    golden_pattern_2 = torch.zeros(max_len, max_len, dtype=torch.bool)
    golden_pattern_2[range(6), range(6)] = 1

    golden_pattern_3 = torch.eye(max_len // 2, dtype=torch.bool)
    golden_pattern_3 = torch.cat([golden_pattern_3] * 2, dim=1)
    golden_pattern_3 = torch.cat([golden_pattern_3] * 2)

    golden_pattern_4 = torch.cat(
        [golden_pattern_3[:-2],
         torch.zeros_like(golden_pattern_3[-2:])])
    golden_pattern_4 = torch.cat(
        [golden_pattern_4[:, :-2],
         torch.zeros_like(golden_pattern_4[:, -2:])],
        dim=1)

    golden_pattern_5 = torch.tensor([
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
    ], dtype=torch.bool)

    golden_pattern_6 = golden_pattern_5.clone()
    golden_pattern_6[max_len - 2, max_len - 2] = 0

    assert golden_pattern_1.equal(pattern[0])
    assert golden_pattern_2.equal(pattern[1])
    assert golden_pattern_3.equal(pattern[2])
    assert golden_pattern_4.equal(pattern[3])
    assert golden_pattern_5.equal(pattern[4])
    assert golden_pattern_6.equal(pattern[5])
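The assertions above pin down what utils.sentence_pattern computes even though its implementation is not shown: entry (b, i, j) is True exactly when tokens i and j of sentence b are equal and neither position is padding. A minimal sketch consistent with all six golden patterns (the body is inferred from the test, not taken from the repo):

import torch

def sentence_pattern(sentences, pad_mask=None):
    # Pairwise token-equality matrix: (batch_size, max_len, max_len).
    pattern = sentences.unsqueeze(2).eq(sentences.unsqueeze(1))
    if pad_mask is not None:
        # Clear every row and column belonging to a pad position.
        keep = ~pad_mask
        pattern = pattern & keep.unsqueeze(2) & keep.unsqueeze(1)
    return pattern

Under this reading, golden_pattern_1 is the identity because all eight tokens are distinct, and golden_pattern_3 is the 4x4 identity tiled 2x2 because the second half of the sentence repeats the first.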
Example #2
    def test_predict_prob(self):
        self.model.eval()
        test_predictor = GoldenSeq2seqPredictor(self.model,
                                                self.predictor.vocab_size,
                                                self.tokenizer)

        sentence, length = sample_case1()
        eps = 1e-5

        pad_mask = sentence[:, 1:] == 0
        golden_pattern = utils.sentence_pattern(sentence[:, 1:],
                                                pad_mask=pad_mask)

        golden_output, golden_scores = test_predictor.predict_prob(sentence,
                                                                   length,
                                                                   beam_size=1)

        with torch.no_grad():
            output, scores = self.predictor(sentence, length, beam_size=1)

        assert torch.norm(scores - golden_scores) < eps
        assert torch.equal(output, golden_output)
        assert torch.equal(
            utils.sentence_pattern(output[:, 0, :], pad_mask=pad_mask),
            golden_pattern)

        golden_output, golden_scores = test_predictor.predict_prob(sentence,
                                                                   length,
                                                                   beam_size=3)

        with torch.no_grad():
            output, scores = self.predictor(sentence, length, beam_size=3)

        assert torch.norm(scores - golden_scores, p=float('inf')) < eps
        assert torch.equal(output, golden_output)
        for i in range(3):
            assert torch.equal(
                utils.sentence_pattern(output[:, i, :], pad_mask=pad_mask),
                golden_pattern)
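The pattern assertions enforce a structural invariant rather than token identity: the prediction must repeat tokens exactly where the source does. A minimal illustration with invented tensors (assuming sentence_pattern behaves as sketched under Example #1):

import torch

source = torch.tensor([[1, 2, 1, 3]])
prediction = torch.tensor([[7, 8, 7, 9]])  # different tokens, same repetition structure
assert utils.sentence_pattern(source).equal(utils.sentence_pattern(prediction))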
Example #3
    def _beam_search(self, source, length, beam_size=1):
        # pad_mask: (batch_size, total_len)
        # sent_pattern: (batch_size, total_len, total_len)
        pad_mask = source.eq(self.pad)
        sub_pad_mask = source[:, 1:].eq(self.pad)
        sent_pattern = utils.sentence_pattern(source[:, 1:],
                                              pad_mask=sub_pad_mask)

        # encode source
        # context: (batch_size, max_len, 2 * enc_rnn_units)
        # state: (layers, batch_size, 2 * enc_rnn_units)
        context, state = self.encoder(source,
                                      length,
                                      mask=pad_mask.unsqueeze(1))

        batch_size, fix_len = source.size()

        output = []
        scores = []

        for i in range(batch_size):
            context_i = context[i]
            state_i = state[:, i, :]
            length_i = length[i].item()
            sent_pattern_i = sent_pattern[i]
            pad_mask_i = pad_mask[i]

            output_i, scores_i = self._beam_search_single(
                context=context_i,
                state=state_i,
                src_len=length_i,
                fix_len=fix_len,
                beam_size=beam_size,
                sent_pattern=sent_pattern_i,
                attention_mask=pad_mask_i,
            )

            output += [output_i]
            scores += [scores_i]

        output = torch.stack(output, dim=0)
        scores = torch.stack(scores, dim=0)

        return output, scores
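A hypothetical call (variable names assumed, not from the repo) showing the shapes this method returns; torch.stack requires _beam_search_single to hand back fixed-size tensors for every batch item:

output, scores = model._beam_search(source, length, beam_size=4)
# output: (batch_size, 4, decoded_len), scores: (batch_size, 4)
best = output[:, 0]  # beams are assumed to come back sorted best-first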
Example #4
    def test_known_token(self):
        sentence = torch.tensor([
            [1, 2, 3, 4, 5, 0],
            [1, 2, 3, 3, 2, 0],
            [1, 2, 3, 2, 0, 0],
            [1, 1, 2, 2, 0, 0],
            [1, 2, 3, 0, 0, 0],
        ])

        length = torch.tensor([6, 6, 5, 5, 4])

        known_token = self.predictor._known_token(
            utils.sentence_pattern(sentence), length)
        golden_known_token = torch.tensor([
            [0, 0, 0, 0, 0, 1],
            [0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 1, 0],
            [0, 1, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0],
        ], dtype=torch.bool)

        for i, l in enumerate(length):
            assert torch.equal(known_token[i, :l], golden_known_token[i, :l])
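The golden values pin _known_token down fairly tightly: a position is known when its token already occurred earlier in the sentence, or when it is the final position at length - 1 (the EOS slot). A sketch consistent with every checked entry, written here as a free function; positions beyond length are never compared by the test, so their values are unspecified, and this is an inference rather than the repo's implementation:

import torch

def _known_token(sent_pattern, length):
    # Known via repetition: any hit in the strictly lower triangle of the
    # pairwise-equality pattern means the token appeared earlier.
    repeated = torch.tril(sent_pattern, diagonal=-1).any(dim=2)
    # The last real position is always known in advance.
    eos = torch.zeros_like(repeated)
    eos[torch.arange(length.size(0)), length - 1] = True
    return repeated | eos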
Example #5
    def forward(self,
                source: torch.Tensor,
                src_len: torch.Tensor,
                beam_size: int = 1,
                enforce_sorted: bool = False):
        assert beam_size > 0
        k = beam_size

        batch_size, fix_len = source.size()
        src_without_bos = source[:, 1:]
        fix_len_less_one = fix_len - 1
        tgt_len = src_len - 1

        # pad_mask: (batch_size, fix_len)
        pad_mask: Optional[torch.Tensor] = source.eq(self.pad_token_id) \
            if src_len.min().item() != fix_len else None
        sub_pad_mask: Optional[torch.Tensor] = src_without_bos.eq(self.pad_token_id) \
            if src_len.min().item() != fix_len else None

        # Initialize the scores; for the first step,
        # scores: (batch_size * k, 1)
        scores = torch.full([batch_size * k, 1],
                            fill_value=float('-inf'),
                            device=source.device)
        # Only the first beam of each batch item starts alive (score 0.0);
        # the rest start at -inf so the first top-k cannot pick them. The
        # index tensor must live on the same device as scores.
        scores.index_fill_(
            0, torch.arange(0, batch_size * k, k, device=source.device), 0.0)

        # output: (batch_size, k, fix_len-1)
        output = torch.full([batch_size, k, fix_len_less_one],
                            fill_value=self.pad_token_id).long().to(
                                source.device)
        output[torch.arange(batch_size).long(), :,
               tgt_len - 1] = torch.tensor(self.eos_token_id)

        # Initialize input variable of decoder
        # input_var: (batch_size * k, 1)
        input_var = torch.full([batch_size * k, 1],
                               self.bos_token_id).long().to(source.device)

        # ban_token_mask: (batch_size * k, vocab_size)
        # Special tokens are banned from the start; each step of the loop
        # below additionally bans the symbol a beam has just emitted.
        ban_token_mask = self.gen_token_mask(
            batch_size, k, torch.tensor(self.special_token_ids))

        fh: Optional[torch.Tensor] = None
        cnn_mem: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        # (batch_size, fix_len-1, fix_len-1)
        sent_pattern = utils.sentence_pattern(src_without_bos,
                                              pad_mask=sub_pad_mask)

        # known_token: (batch_size, fix_len-1)
        # known: (batch_size, k, fix_len-1)
        known_token = self._known_token(sent_pattern,
                                        tgt_len).to(source.device)
        known = torch.stack([known_token] * k, dim=1)
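        # multi flags positions whose token occurs more than once in the
        # source; when such a position is decoded, the chosen symbol is
        # copied to all of its repeats via _update_output in the loop below.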
        # multi: (batch_size, fix_len-1)
        multi = sent_pattern.sum(dim=2) > 1

        # encode source
        # context: (batch_size, fix_len, hidden_size)
        # state: (layers, batch_size, hidden_size)
        context, state = self.encode(source,
                                     src_len,
                                     mask=source.eq(
                                         self.pad_token_id).unsqueeze(1),
                                     enforce_sorted=enforce_sorted)

        context = torch.stack([context] * k, dim=1).flatten(0, 1)
        state = torch.stack([state] * k, dim=2).flatten(1, 2)

        # attention_mask: (batch_size * k, 1, fix_len)
        attention_mask: Optional[torch.Tensor] = torch.stack([pad_mask] * k, dim=1).flatten(0, 1).unsqueeze(1) \
            if pad_mask is not None else None

        end_indices_list: Optional[Tuple[torch.Tensor, ...]] = attention_mask.chunk(fix_len, dim=-1)[1:] \
            if attention_mask is not None else None
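        # end_indices_list[i] flags, per beam, sequences whose target is
        # already exhausted at step i; the loop freezes their scores, beam
        # choices, and symbols so shorter sentences stop cleanly.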

        for i in range(fix_len_less_one):
            log_prob, (state, fh, cnn_mem), attn_weights = self.decode(
                input_var,
                context,
                state,
                fh,
                attention_mask,
                cnn_mem,
                offset=i,
            )

            # update scores
            # scores: (batch_size * k, vocab_size)
            last_scores = scores
            scores = scores + log_prob.squeeze(1)

            # ban tokens
            if known_token[:, i].sum() == 0:
                token_mask = ban_token_mask
            else:
                token_mask = self._token_mask(ban_token_mask, known[:, :, i],
                                              output[:, :, i]).to(source.device)
            scores.masked_fill_(token_mask, float('-inf'))

            # top-k
            # scores: (batch_size, k)
            # candidates: (batch_size, k)
            scores, candidates = scores.view(batch_size, -1).topk(k, dim=1)

            scores = scores.view(batch_size * k, 1)
            if end_indices_list is not None:
                scores = torch.where(
                    end_indices_list[i].view(batch_size * k, -1), last_scores,
                    scores)

            # compute rank indices
            # candidates are k * vocab_size + offset
            batch_indices = torch.arange(batch_size).view(batch_size, 1).long()
            # Floor division: plain / on integer tensors yields floats in
            # current PyTorch and would break the indexing below.
            k_indices = torch.div(candidates, self.vocab_size,
                                  rounding_mode='floor')
            if end_indices_list is not None:
                k_indices = torch.where(
                    end_indices_list[i].view(batch_size, k),
                    torch.arange(k, device=candidates.device).view(1, k),
                    k_indices)
            # combine_indices: (batch_size * k,)
            combine_indices = (batch_indices * k + k_indices).view(-1)

            # re-rank
            output = output[batch_indices, k_indices, :]
            state = state[:, combine_indices, :]
            fh = fh[combine_indices, :]
            # noinspection PyTypeChecker
            cnn_mem = (cnn_mem[0][combine_indices, :, :],
                       cnn_mem[1][combine_indices, :, :])
            ban_token_mask = ban_token_mask[combine_indices, :]

            # decode symbol; update output; update input_var;
            # update ban_token_mask
            # symbol: (batch_size, k)
            symbol = candidates % self.vocab_size
            if end_indices_list is not None:
                symbol = torch.where(end_indices_list[i].view(batch_size, k),
                                     torch.tensor(self.pad_token_id), symbol)
            input_var = symbol.view(batch_size * k, 1)
            # multi: (batch_size, fix_len-1)
            if multi[:, i].sum() > 0:
                output = self._update_output(output, sent_pattern[:, i],
                                             symbol)
            else:
                output[:, :, i] = symbol
            ban_token_mask[torch.arange(batch_size * k),
                           symbol.view(-1)] = True

        return output.view(batch_size, k,
                           fix_len_less_one), scores.view(batch_size, k)
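The re-ranking bookkeeping rests on one indexing identity: with beams flattened to a (batch_size * k, ...) layout, row b * k + j holds beam j of batch item b, so surviving beams can be gathered with a single flat index. A standalone illustration with invented values:

import torch

batch_size, k, hidden = 2, 3, 4
state = torch.arange(batch_size * k * hidden).view(batch_size * k, hidden)

# Suppose top-k kept these beams for each batch item.
k_indices = torch.tensor([[2, 0, 1],
                          [1, 1, 0]])
batch_indices = torch.arange(batch_size).view(batch_size, 1)

# combine_indices: (batch_size * k,) flat rows to gather.
combine_indices = (batch_indices * k + k_indices).view(-1)
reranked = state[combine_indices]  # still (batch_size * k, hidden)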