def test_masked_softmax():
    tensor = torch.FloatTensor([
        [2, 3, 1, 4, 5],
        [4, 1, 6, 9, 10],
        [1, 5, 2, 4, 1],
    ])
    mask = torch.tensor([
        [1., 1., 1., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
    ])

    result = f.masked_softmax(tensor, mask)
    assert result.argmax(dim=-1).equal(torch.LongTensor([1, 0, 1]))
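# A minimal sketch of what `f.masked_softmax` is assumed to do, consistent with
# the test above: suppress padded positions with a large negative value before
# the softmax, so their probabilities become ~0 and argmax only lands on valid
# slots. The helper name is from this repo; the body below is an illustrative
# assumption, not necessarily the project's implementation.
def _masked_softmax_sketch(logits, mask, dim=-1):
    # mask: 1.0 for valid positions, 0.0 for padding
    masked_logits = logits.masked_fill(mask == 0, -1e7)  # same -1e7 convention used elsewhere here
    return torch.softmax(masked_logits, dim=dim)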
def _key2x(self, S, x, x_mask):
    attention = f.masked_softmax(S, x_mask)  # (B, C_L)
    key2x = f.weighted_sum(attention=attention, matrix=x)  # (B, 2d)
    return key2x.unsqueeze(1).expand(x.size())  # (B, C_L, 2d)
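# `f.weighted_sum` is assumed to be attention pooling: a batched weighted sum
# of the rows of `matrix`, using `attention` as the weights. A hypothetical
# sketch under that assumption (not the project's implementation):
def _weighted_sum_sketch(attention, matrix):
    # attention: (B, L), matrix: (B, L, D) -> (B, D)
    return torch.bmm(attention.unsqueeze(1), matrix).squeeze(1)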
def forward(self, features, labels=None):
    """
    * Args:
        features: feature dictionary like below.
            {"feature_name1": {"token_name1": tensor, "token_name2": tensor},
             "feature_name2": ...}

    * Kwargs:
        labels: label dictionary like below.
            {"label_name1": tensor, "label_name2": tensor}
            Loss is not calculated when labels are absent. (inference/predict mode)

    * Returns: output_dict (dict) consisting of
        - start_logits: representing unnormalized log probabilities of the span start position.
        - end_logits: representing unnormalized log probabilities of the span end position.
        - best_span: the string from the original passage that the model thinks is the best answer to the question.
        - answer_idx: the question id, mapped to its answer
        - loss: A scalar loss to be optimised.
    """

    context = features["context"]
    question = features["question"]

    # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
    context_seq_config = f.get_sorted_seq_config(context)
    query_seq_config = f.get_sorted_seq_config(question)

    # Embedding Layer (Char + Word -> Contextual)
    query_params = {"frequent_word": {"frequent_tuning": True}}
    context_embed, query_embed = self.token_embedder(
        context, question, query_params=query_params,
        query_align=self.aligned_query_embedding)

    context_mask = f.get_mask_from_tokens(context).float()
    query_mask = f.get_mask_from_tokens(question).float()

    B, C_L = context_embed.size(0), context_embed.size(1)

    context_embed = self.context_highway(context_embed)
    query_embed = self.query_highway(query_embed)

    context_encoded = f.forward_rnn_with_pack(
        self.context_contextual_rnn, context_embed, context_seq_config)
    context_encoded = self.dropout(context_encoded)

    query_encoded = f.forward_rnn_with_pack(
        self.query_contextual_rnn, query_embed, query_seq_config)
    query_encoded = self.dropout(query_encoded)

    # Attention Flow Layer
    attention_context_query = self.attention(
        context_encoded, context_mask, query_encoded, query_mask)

    # Modeling Layer
    modeled_context = f.forward_rnn_with_pack(
        self.modeling_rnn, attention_context_query, context_seq_config)
    modeled_context = self.dropout(modeled_context)
    M_D = modeled_context.size(-1)

    # Output Layer
    span_start_input = self.dropout(
        torch.cat([attention_context_query, modeled_context], dim=-1))  # (B, C_L, 10d)
    span_start_logits = self.span_start_linear(span_start_input).squeeze(-1)  # (B, C_L)
    span_start_probs = f.masked_softmax(span_start_logits, context_mask)

    span_start_representation = f.weighted_sum(
        attention=span_start_probs, matrix=modeled_context)
    tiled_span_start_representation = span_start_representation.unsqueeze(1).expand(B, C_L, M_D)

    span_end_representation = torch.cat(
        [
            attention_context_query,
            modeled_context,
            tiled_span_start_representation,
            modeled_context * tiled_span_start_representation,
        ],
        dim=-1,
    )
    encoded_span_end = f.forward_rnn_with_pack(
        self.output_end_rnn, span_end_representation, context_seq_config)
    encoded_span_end = self.dropout(encoded_span_end)

    span_end_input = self.dropout(
        torch.cat([attention_context_query, encoded_span_end], dim=-1))
    span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

    # Masked Value
    span_start_logits = f.add_masked_value(span_start_logits, context_mask, value=-1e7)
    span_end_logits = f.add_masked_value(span_end_logits, context_mask, value=-1e7)

    # No_Answer Bias
    bias = self.bias.expand(B, 1)
    span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
    span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

    output_dict = {
        "start_logits": span_start_logits,
        "end_logits": span_end_logits,
        "best_span": self.get_best_span(
            span_start_logits[:, :-1],  # exclude no_answer bias
            span_end_logits[:, :-1],
            answer_maxlen=self.answer_maxlen,
        ),
    }

    if labels:
        answer_idx = labels["answer_idx"]
        answer_start_idx = labels["answer_start_idx"]
        answer_end_idx = labels["answer_end_idx"]
        answerable = labels["answerable"]

        # No_Answer case: point unanswerable examples at the bias position (index C_L)
        C_L = context_mask.size(1)
        answer_start_idx = answer_start_idx.masked_fill(answerable.eq(0), C_L)
        answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

        output_dict["answer_idx"] = answer_idx

        # Loss
        loss = self.criterion(span_start_logits, answer_start_idx)
        loss += self.criterion(span_end_logits, answer_end_idx)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat error workaround

    return output_dict
def _query2context(self, S, c, c_mask):
    attention = f.masked_softmax(S, c_mask)  # (B, C_L)
    q2c = f.weighted_sum(attention=attention, matrix=c)  # (B, 2d)
    return q2c.unsqueeze(1).expand(c.size())  # (B, C_L, 2d)
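# Shape check for the unsqueeze(1).expand(...) tiling used in _key2x and
# _query2context (illustrative only, with hypothetical sizes B=4, C_L=30,
# 2d=200): a pooled (B, 2d) vector becomes a broadcast (B, C_L, 2d) view
# without copying memory.
def test_pooled_vector_tiling():
    pooled = torch.randn(4, 200)                    # (B, 2d)
    tiled = pooled.unsqueeze(1).expand(4, 30, 200)  # (B, C_L, 2d), broadcast view
    assert tiled.size() == torch.Size([4, 30, 200])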