Example #1
    def forward(self, prems_indexes, prem_lens, hypos_indexes, hypo_lens):
        # Embed, L2-normalize and linearly project the premise tokens
        prems = self._embedding_layer(prems_indexes)
        prems = prems / (prems.norm(dim=-1, keepdim=True) + 1e-6)
        prems = self._linear_projection(prems)
        # Binary mask separating real premise tokens from padding
        prems_mask = pwF.create_mask_from_length(prem_lens, prems.shape[1])
        prems_att_vectors = self._att_mlp(prems)

        # Same pipeline for the hypothesis tokens
        hypos = self._embedding_layer(hypos_indexes)
        hypos = hypos / (hypos.norm(dim=-1, keepdim=True) + 1e-6)
        hypos = self._linear_projection(hypos)
        hypos_mask = pwF.create_mask_from_length(hypo_lens, hypos.shape[1])
        hypos_att_vectors = self._att_mlp(hypos)

        # Pairwise attention scores between premise and hypothesis tokens
        scores = torch.matmul(prems_att_vectors,
                              hypos_att_vectors.transpose(1, 2))
        # Padded positions get a large negative score so softmax gives them ~0 weight
        scores = scores.masked_fill(prems_mask.unsqueeze(2) == 0, -1e9)
        scores = scores.masked_fill(hypos_mask.unsqueeze(1) == 0, -1e9)

        # Normalize over hypothesis tokens (dim=2) and over premise tokens (dim=1)
        horizontal_softmaxed = F.softmax(scores, dim=2)
        vertical_softmaxed = F.softmax(scores, dim=1)

        # Soft-align hypothesis tokens to each premise token, compare the pair,
        # zero out padded positions and sum-pool over the premise
        hypos_attended = torch.matmul(horizontal_softmaxed, hypos)
        prems_hypos_attended = torch.cat([prems, hypos_attended], dim=-1)
        prems_hypos_attended_compared = self._comp_mlp(prems_hypos_attended)
        prems_hypos_attended_compared = prems_hypos_attended_compared.masked_fill(
            prems_mask.unsqueeze(2) == 0, 0)
        prems_hypos_attended_aggregated = torch.sum(
            prems_hypos_attended_compared, dim=1)

        # Mirror direction: align premise tokens to each hypothesis token
        prems_attended = torch.matmul(vertical_softmaxed.transpose(1, 2),
                                      prems)
        hypos_prems_attended = torch.cat([hypos, prems_attended], dim=-1)
        hypos_prems_attended_compared = self._comp_mlp(hypos_prems_attended)
        hypos_prems_attended_compared = hypos_prems_attended_compared.masked_fill(
            hypos_mask.unsqueeze(2) == 0, 0)
        hypos_prems_attended_aggregated = torch.sum(
            hypos_prems_attended_compared, dim=1)

        encodings = torch.cat(
            [prems_hypos_attended_aggregated, hypos_prems_attended_aggregated],
            dim=-1)

        return self._out_mlp(encodings)
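
A minimal, self-contained illustration (values invented) of the masking trick used above: filling padded positions with -1e9 before the softmax drives their attention weight to effectively zero.

import torch
import torch.nn.functional as F

# Scores of one premise token against a hypothesis padded to length 4,
# where only the first two hypothesis positions are real tokens
scores = torch.tensor([[2.0, 1.0, 3.0, 0.5]])
hypo_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.uint8)

masked_scores = scores.masked_fill(hypo_mask == 0, -1e9)
weights = F.softmax(masked_scores, dim=-1)
# weights is approximately [[0.73, 0.27, 0.00, 0.00]]:
# the padded positions receive essentially no attention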
Example #2
    def test_1D_end_padded(self):
        length_tensor = torch.tensor([1, 3])
        mask_size = 5
        is_end_padded = True

        result = pwF.create_mask_from_length(length_tensor, mask_size,
                                             is_end_padded)
        correct = torch.tensor([[1, 0, 0, 0, 0], [1, 1, 1, 0, 0]],
                               dtype=torch.uint8)

        self.assertListEqual(result.tolist(), correct.tolist())
Example #3
    def test_2D_start_padded(self):
        length_tensor = torch.tensor([[1, 2], [3, 4]])
        mask_size = 5
        is_end_padded = False

        result = pwF.create_mask_from_length(length_tensor, mask_size,
                                             is_end_padded)
        correct = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 0, 1, 1]],
                                [[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]]],
                               dtype=torch.uint8)

        self.assertListEqual(result.tolist(), correct.tolist())
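
Based on the two tests above, a minimal sketch of what pwF.create_mask_from_length computes (the parameter name is_end_padded mirrors the test variable; the actual pytorch-wrapper signature and implementation may differ):

import torch

def create_mask_from_length(length_tensor, mask_size, is_end_padded=True):
    # Positions 0..mask_size-1, broadcast against the (possibly batched) lengths
    positions = torch.arange(mask_size, device=length_tensor.device)
    lengths = length_tensor.unsqueeze(-1)
    if is_end_padded:
        # Real tokens occupy the first `length` positions of each sequence
        mask = positions < lengths
    else:
        # Real tokens occupy the last `length` positions of each sequence
        mask = positions >= (mask_size - lengths)
    return mask.to(torch.uint8)

# Reproduces the expectations above:
# create_mask_from_length(torch.tensor([1, 3]), 5, True)
#   -> [[1, 0, 0, 0, 0], [1, 1, 1, 0, 0]]
# create_mask_from_length(torch.tensor([[1, 2], [3, 4]]), 5, False)
#   -> [[[0, 0, 0, 0, 1], [0, 0, 0, 1, 1]], [[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]]]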
Example #4
File: model.py Project: nlpaueb/greek-bert
    def forward(self,
                batched_char_words,
                batched_char_words_len,
                batched_char_word_index,
                batched_tokens,
                batched_tokens_len,
                target=None):

        # Character-level encoding of every word in the batch
        char_tokens = self._char_embedding_layer(batched_char_words)
        token_encodings = self._char_token_encoder(char_tokens)

        # Prepend a zero row so that index 0 selects an all-zero encoding,
        # which is what padded slots of batched_char_word_index point to
        token_encodings_z = torch.zeros((1, token_encodings.shape[1]),
                                        device=token_encodings.device)
        token_encodings = torch.cat([token_encodings_z, token_encodings],
                                    dim=0)

        # Gather each sentence's word encodings back into (batch, seq_len, dim)
        token_encodings_indexed = torch.index_select(
            token_encodings, dim=0, index=batched_char_word_index.view(-1))
        token_encodings_indexed = token_encodings_indexed.view(
            batched_char_word_index.shape[0], batched_char_word_index.shape[1],
            -1)

        texts = self._embedding_layer(batched_tokens)
        texts = torch.cat([texts, token_encodings_indexed], -1)
        texts = self._embedding_dp(texts)

        texts = pack_padded_sequence(texts,
                                     batched_tokens_len,
                                     batch_first=True,
                                     enforce_sorted=False)
        texts = self._rnn(texts)[0]
        texts = pad_packed_sequence(texts, batch_first=True)[0]
        texts = self._rnn_top_layer_dp(texts)
        mlp_out = self._out_mlp(texts)

        mask = pwF.create_mask_from_length(batched_tokens_len,
                                           mlp_out.shape[1])

        if self.training:
            # Negative CRF log-likelihood as the training loss
            return -self._crf(
                mlp_out, target, mask=mask, reduction='token_mean')
        else:
            # Viterbi-decode, pad the variable-length predictions and one-hot
            # encode them over the 17 tag classes
            predictions = self._crf.decode(mlp_out, mask)
            predictions = torch.tensor(pad_to_max(predictions),
                                       dtype=torch.long).to(mlp_out.device)
            one_hot_pred = torch.eye(17).to(mlp_out.device)[predictions]
            return one_hot_pred
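
The zero-row plus index_select pattern above, shown in isolation with made-up shapes: prepending a row of zeros lets index 0 stand for "no word", so padded slots of batched_char_word_index gather an all-zero encoding.

import torch

# Three word encodings (dim 4) for the whole batch, in flattened order
token_encodings = torch.arange(12, dtype=torch.float).view(3, 4)
# Prepend a zero row so index 0 selects an all-zero vector
token_encodings = torch.cat([torch.zeros(1, 4), token_encodings], dim=0)

# Two sentences: the first has two words, the second has one word plus padding
batched_char_word_index = torch.tensor([[1, 2], [3, 0]])

gathered = torch.index_select(
    token_encodings, dim=0, index=batched_char_word_index.view(-1))
gathered = gathered.view(2, 2, -1)
# gathered[1, 1] is all zeros, i.e. the padded slot contributes nothing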
Example #5
    def collate_fn(batch):
        batch_zipped = list(zip(*batch))
        input_zipped = list(zip(*batch_zipped[1]))

        ids = batch_zipped[0]

        # All char-level words in the batch, flattened and padded to the longest word
        batched_char_words = torch.tensor(pad_to_max(
            list(itertools.chain.from_iterable(input_zipped[0]))),
                                          dtype=torch.long)
        batched_char_words_len = torch.tensor(list(
            itertools.chain.from_iterable(input_zipped[1])),
                                              dtype=torch.int)

        # For each sentence, the 1-based range of its words inside the flattened
        # char-word batch (index 0 is reserved for the model's zero/padding row)
        nbs_accumulated = list(
            itertools.accumulate([1] + list(input_zipped[3])))
        indices = [
            list(range(nbs_accumulated[i], nbs_accumulated[i + 1]))
            for i in range(len(nbs_accumulated) - 1)
        ]
        batched_char_word_index = torch.tensor(pad_to_max(indices),
                                               dtype=torch.long)

        batched_tokens = torch.tensor(pad_to_max(input_zipped[2]),
                                      dtype=torch.long)
        batched_tokens_len = torch.tensor(input_zipped[3], dtype=torch.int)

        with torch.no_grad():
            # Mask separating real tokens from padding in each sentence
            pred_mask = pwF.create_mask_from_length(
                batched_tokens_len,
                torch.max(batched_tokens_len).item())

        # Targets padded with -1 at the positions the mask marks as padding
        target = torch.tensor(pad_to_max(batch_zipped[2], pad_value=-1),
                              dtype=torch.long)

        return {
            'id': ids,
            'input': [
                batched_char_words, batched_char_words_len,
                batched_char_word_index, batched_tokens, batched_tokens_len,
                target
            ],
            'target': target,
            'mask': pred_mask
        }
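
pad_to_max is a project-local helper that is not shown here; judging from how it is called (lists of variable-length lists plus an optional pad_value), it presumably behaves like this sketch:

def pad_to_max(sequences, pad_value=0):
    # Pad each inner list with pad_value up to the length of the longest one
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]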
Example #6
    def forward(self, text, text_len):
        # Mask real tokens vs. padding, used as the BERT attention mask
        attention_mask = pwF.create_mask_from_length(text_len, text.shape[1])
        # [0] is the last hidden state; [:, 0, :] is the [CLS] token representation
        return self._output_linear(
            self._dp(
                self._bert_model(text,
                                 attention_mask=attention_mask)[0][:, 0, :]))