def forward(self, prems_indexes, prem_lens, hypos_indexes, hypo_lens):
    # Embed and L2-normalise the premise tokens, then project them.
    prems = self._embedding_layer(prems_indexes)
    prems = prems / (prems.norm(dim=-1, keepdim=True) + 1e-6)
    prems = self._linear_projection(prems)
    prems_mask = pwF.create_mask_from_length(prem_lens, prems.shape[1])
    prems_att_vectors = self._att_mlp(prems)

    # Same pipeline for the hypothesis tokens.
    hypos = self._embedding_layer(hypos_indexes)
    hypos = hypos / (hypos.norm(dim=-1, keepdim=True) + 1e-6)
    hypos = self._linear_projection(hypos)
    hypos_mask = pwF.create_mask_from_length(hypo_lens, hypos.shape[1])
    hypos_att_vectors = self._att_mlp(hypos)

    # Pairwise attention scores; padded positions are pushed to -1e9 so that
    # they receive (near-)zero probability after the softmax.
    scores = torch.matmul(prems_att_vectors, hypos_att_vectors.transpose(1, 2))
    scores = scores.masked_fill(prems_mask.unsqueeze(2) == 0, -1e9)
    scores = scores.masked_fill(hypos_mask.unsqueeze(1) == 0, -1e9)
    horizontal_softmaxed = F.softmax(scores, dim=2)
    vertical_softmaxed = F.softmax(scores, dim=1)

    # Attend, compare and aggregate: premise side.
    hypos_attended = torch.matmul(horizontal_softmaxed, hypos)
    prems_hypos_attended = torch.cat([prems, hypos_attended], dim=-1)
    prems_hypos_attended_compared = self._comp_mlp(prems_hypos_attended)
    prems_hypos_attended_compared = prems_hypos_attended_compared.masked_fill(
        prems_mask.unsqueeze(2) == 0, 0)
    prems_hypos_attended_aggregated = torch.sum(
        prems_hypos_attended_compared, dim=1)

    # Attend, compare and aggregate: hypothesis side.
    prems_attended = torch.matmul(vertical_softmaxed.transpose(1, 2), prems)
    hypos_prems_attended = torch.cat([hypos, prems_attended], dim=-1)
    hypos_prems_attended_compared = self._comp_mlp(hypos_prems_attended)
    hypos_prems_attended_compared = hypos_prems_attended_compared.masked_fill(
        hypos_mask.unsqueeze(2) == 0, 0)
    hypos_prems_attended_aggregated = torch.sum(
        hypos_prems_attended_compared, dim=1)

    encodings = torch.cat(
        [prems_hypos_attended_aggregated, hypos_prems_attended_aggregated],
        dim=-1)
    return self._out_mlp(encodings)
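The masking pattern used above is the key detail: filling padded positions with -1e9 before the softmax drives their attention weights to effectively zero. A minimal self-contained sketch of that step, with made-up tensor shapes:

import torch
import torch.nn.functional as F

scores = torch.randn(1, 3, 4)              # (batch, premise_len, hypothesis_len)
hypos_mask = torch.tensor([[1, 1, 0, 0]])  # last two hypothesis positions are padding
scores = scores.masked_fill(hypos_mask.unsqueeze(1) == 0, -1e9)
weights = F.softmax(scores, dim=2)         # padded columns get ~0 weight
print(weights[0, :, 2:].sum())             # ~0: padding receives no attention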
def test_1D_end_padded(self):
    length_tensor = torch.tensor([1, 3])
    mask_size = 5
    is_end_padded = True

    result = pwF.create_mask_from_length(length_tensor, mask_size, is_end_padded)
    correct = torch.tensor([[1, 0, 0, 0, 0],
                            [1, 1, 1, 0, 0]], dtype=torch.uint8)

    self.assertListEqual(result.tolist(), correct.tolist())
def test_2D_start_padded(self):
    length_tensor = torch.tensor([[1, 2], [3, 4]])
    mask_size = 5
    is_end_padded = False

    result = pwF.create_mask_from_length(length_tensor, mask_size, is_end_padded)
    correct = torch.tensor([[[0, 0, 0, 0, 1],
                             [0, 0, 0, 1, 1]],
                            [[0, 0, 1, 1, 1],
                             [0, 1, 1, 1, 1]]], dtype=torch.uint8)

    self.assertListEqual(result.tolist(), correct.tolist())
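Outside the test harness, the function is typically called with just the length tensor and the mask size (end-padded, as in the other snippets here), and the resulting mask is used to blank out padded positions. A minimal sketch, assuming pytorch_wrapper.functional is imported as pwF:

import torch
import pytorch_wrapper.functional as pwF

lens = torch.tensor([2, 4])
mask = pwF.create_mask_from_length(lens, 4)          # [[1, 1, 0, 0], [1, 1, 1, 1]]
embs = torch.randn(2, 4, 8)                          # (batch, seq_len, emb_size)
embs = embs.masked_fill(mask.unsqueeze(-1) == 0, 0)  # zero out padded positions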
def forward(self, batched_char_words, batched_char_words_len,
            batched_char_word_index, batched_tokens, batched_tokens_len,
            target=None):
    # Character-level token encodings; a zero row is prepended so that index 0
    # can act as padding in batched_char_word_index.
    char_tokens = self._char_embedding_layer(batched_char_words)
    token_encodings = self._char_token_encoder(char_tokens)
    token_encodings_z = torch.zeros((1, token_encodings.shape[1]),
                                    device=token_encodings.device)
    token_encodings = torch.cat([token_encodings_z, token_encodings], dim=0)
    token_encodings_indexed = torch.index_select(
        token_encodings, dim=0, index=batched_char_word_index.view(-1))
    token_encodings_indexed = token_encodings_indexed.view(
        batched_char_word_index.shape[0], batched_char_word_index.shape[1], -1)

    # Concatenate word embeddings with the character-level encodings and run
    # the RNN over the packed sequence.
    texts = self._embedding_layer(batched_tokens)
    texts = torch.cat([texts, token_encodings_indexed], -1)
    texts = self._embedding_dp(texts)
    texts = pack_padded_sequence(texts, batched_tokens_len, batch_first=True,
                                 enforce_sorted=False)
    texts = self._rnn(texts)[0]
    texts = pad_packed_sequence(texts, batch_first=True)[0]
    texts = self._rnn_top_layer_dp(texts)

    mlp_out = self._out_mlp(texts)
    mask = pwF.create_mask_from_length(batched_tokens_len, mlp_out.shape[1])

    if self.training:
        # Negative CRF log-likelihood, averaged over non-padded tokens.
        return -self._crf(mlp_out, target, mask=mask, reduction='token_mean')
    else:
        # Viterbi decoding; predictions are padded and returned one-hot over
        # the 17 tag classes.
        predictions = self._crf.decode(mlp_out, mask)
        predictions = torch.tensor(pad_to_max(predictions),
                                   dtype=torch.long).to(mlp_out.device)
        one_hot_pred = torch.eye(17).to(mlp_out.device)[predictions]
        return one_hot_pred
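The eval branch turns the decoded tag ids into one-hot vectors over the 17 tag classes by indexing into an identity matrix. A tiny self-contained illustration with made-up tag ids:

import torch

predictions = torch.tensor([[3, 0, 16]])   # hypothetical decoded tag ids, shape (1, 3)
one_hot = torch.eye(17)[predictions]       # shape (1, 3, 17)
print(one_hot[0, 0].argmax().item())       # 3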
def collate_fn(batch):
    batch_zipped = list(zip(*batch))
    input_zipped = list(zip(*batch_zipped[1]))

    ids = batch_zipped[0]

    # Flatten every word of the batch (as a character sequence) into one tensor.
    batched_char_words = torch.tensor(pad_to_max(
        list(itertools.chain.from_iterable(input_zipped[0]))), dtype=torch.long)
    batched_char_words_len = torch.tensor(list(
        itertools.chain.from_iterable(input_zipped[1])), dtype=torch.int)

    # For each example, the indices of its words in the flattened tensor above;
    # index 0 is reserved for padding.
    nbs_accumulated = list(itertools.accumulate([1] + list(input_zipped[3])))
    indices = [
        list(range(nbs_accumulated[i], nbs_accumulated[i + 1]))
        for i in range(len(nbs_accumulated) - 1)
    ]
    batched_char_word_index = torch.tensor(pad_to_max(indices), dtype=torch.long)

    batched_tokens = torch.tensor(pad_to_max(input_zipped[2]), dtype=torch.long)
    batched_tokens_len = torch.tensor(input_zipped[3], dtype=torch.int)

    with torch.no_grad():
        pred_mask = pwF.create_mask_from_length(
            batched_tokens_len, torch.max(batched_tokens_len).item())

    target = torch.tensor(pad_to_max(batch_zipped[2], pad_value=-1),
                          dtype=torch.long)

    return {
        'id': ids,
        'input': [
            batched_char_words, batched_char_words_len, batched_char_word_index,
            batched_tokens, batched_tokens_len, target
        ],
        'target': target,
        'mask': pred_mask
    }
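Both this collate_fn and the CRF forward above rely on a pad_to_max helper whose definition is not shown in these snippets. A minimal sketch of the behaviour it is assumed to have (right-padding nested lists to a common length), together with the typical DataLoader wiring; the helper body and the dataset object are assumptions, not the library's API:

def pad_to_max(sequences, pad_value=0):
    # Hypothetical helper: right-pads each sequence with pad_value up to the
    # length of the longest one.
    max_len = max(len(s) for s in sequences)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in sequences]

print(pad_to_max([[5, 6], [7]], pad_value=-1))  # [[5, 6], [7, -1]]

# Typical wiring (hypothetical dataset object):
# loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)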
def forward(self, text, text_len):
    attention_mask = pwF.create_mask_from_length(text_len, text.shape[1])
    # Classify using the hidden state of the first (typically [CLS]) token.
    return self._output_linear(
        self._dp(
            self._bert_model(text, attention_mask=attention_mask)[0][:, 0, :]))