def test_masked_max(self):
    # Testing the general masked 1D case.
    vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
    mask_1d = torch.FloatTensor([1.0, 0.0, 1.0])
    vector_1d_maxed = util.masked_max(vector_1d, mask_1d, dim=0).data.numpy()
    assert_array_almost_equal(vector_1d_maxed, 5.0)

    # Testing that an all-zeros mask gives an arbitrary, but non-NaN, output.
    vector_1d = torch.FloatTensor([1.0, 12.0, 5.0])
    mask_1d = torch.FloatTensor([0.0, 0.0, 0.0])
    vector_1d_maxed = util.masked_max(vector_1d, mask_1d, dim=0).data.numpy()
    assert not numpy.isnan(vector_1d_maxed).any()

    # Testing batch value and batch masks
    matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
    matrix_maxed = util.masked_max(matrix, mask, dim=-1).data.numpy()
    assert_array_almost_equal(matrix_maxed, numpy.array([5.0, -1.0]))

    # Testing keepdim for batch value and batch masks
    matrix = torch.FloatTensor([[1.0, 12.0, 5.0], [-1.0, -2.0, 3.0]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
    matrix_maxed = util.masked_max(matrix, mask, dim=-1, keepdim=True).data.numpy()
    assert_array_almost_equal(matrix_maxed, numpy.array([[5.0], [-1.0]]))

    # Testing broadcast
    matrix = torch.FloatTensor([[[1.0, 2.0], [12.0, 3.0], [5.0, -1.0]],
                                [[-1.0, -3.0], [-2.0, -0.5], [3.0, 8.0]]])
    mask = torch.FloatTensor([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]).unsqueeze(-1)
    matrix_maxed = util.masked_max(matrix, mask, dim=1).data.numpy()
    assert_array_almost_equal(matrix_maxed, numpy.array([[5.0, 2.0], [-1.0, -0.5]]))
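# For orientation, a minimal sketch of the semantics these tests pin down.
# The real implementation is allennlp.nn.util.masked_max; masked_max_sketch
# below is a hypothetical stand-in, not the library code.
import torch

def masked_max_sketch(vector: torch.Tensor,
                      mask: torch.Tensor,
                      dim: int,
                      keepdim: bool = False) -> torch.Tensor:
    # Fill masked-out positions with a large negative value so they can never
    # win the max; an all-zero mask therefore yields that fill value, not NaN.
    replaced = vector.masked_fill(~mask.bool(), -1e7)
    return replaced.max(dim=dim, keepdim=keepdim)[0]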
def forward(self, **kwargs) -> torch.FloatTensor:
    mask = kwargs['mask']
    embedded_text = kwargs['embedded_text']
    encoded_output = self._architecture(embedded_text, mask)

    encoded_repr = []
    for aggregation in self._aggregations:
        if aggregation == "meanpool":
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoded_output * broadcast_mask
            encoded_text = masked_mean(context_vectors,
                                       broadcast_mask,
                                       dim=1,
                                       keepdim=False)
        elif aggregation == 'maxpool':
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoded_output * broadcast_mask
            encoded_text = masked_max(context_vectors, broadcast_mask, dim=1)
        elif aggregation == 'final_state':
            is_bi = self._architecture.is_bidirectional()
            encoded_text = get_final_encoder_states(encoded_output, mask, is_bi)
        elif aggregation == 'attention':
            alpha = self._attention_layer(encoded_output)
            alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp()
            encoded_text = alpha * encoded_output
            encoded_text = encoded_text.sum(dim=1)
        else:
            raise ConfigurationError(f"{aggregation} aggregation not available.")
        encoded_repr.append(encoded_text)

    encoded_repr = torch.cat(encoded_repr, 1)
    return encoded_repr
def forward(self, document, query=None, rationale=None, metadata=None, label=None) -> Dict[str, Any]:
    input_ids = document["bert"]
    input_mask = (input_ids != 0).long()
    starting_offsets = document["bert-starting-offsets"]  # (B, T)

    last_hidden_states, _ = self._bert_model(
        input_ids,
        attention_mask=input_mask,
        position_ids=document["bert-position-ids"])

    token_embeddings, span_mask = generate_embeddings_for_pooling(
        last_hidden_states, starting_offsets, document["bert-ending-offsets"])

    token_embeddings = util.masked_max(token_embeddings, span_mask.unsqueeze(-1), dim=2)
    token_embeddings = token_embeddings * document["mask"].unsqueeze(-1)

    logits = self._classification_layer(self._dropout(token_embeddings))
    assert logits.shape[0:2] == starting_offsets.shape

    if self._use_crf:
        best_paths = self._crf.viterbi_tags(logits, mask=document["mask"])
        best_paths = [b[0] for b in best_paths]
        best_paths = [x + [0] * (logits.shape[1] - len(x)) for x in best_paths]
        best_paths = torch.Tensor(best_paths).to(logits.device) * document["mask"]
    else:
        best_paths = (logits[:, :, 1] > 0.5).long() * document["mask"]

    output_dict = {}
    output_dict["predicted_rationales"] = best_paths
    output_dict["mask"] = document["mask"]
    output_dict["metadata"] = metadata

    if rationale is not None:
        if self._use_crf:
            output_dict["loss"] = -self._crf(logits, rationale, document["mask"])
        else:
            output_dict["loss"] = ((F.cross_entropy(
                logits.view(-1, logits.shape[-1]),
                rationale.view(-1),
                reduction="none",
                weight=self._pos_weight,
            ) * document["mask"].view(-1)).sum(-1).mean())

        best_paths = best_paths.unsqueeze(-1)
        best_paths = torch.cat([1 - best_paths, best_paths], dim=-1)
        self._token_prf(best_paths, rationale, document["mask"])

    return output_dict
def forward(self, document, query=None, label=None, metadata=None, rationale=None, **kwargs) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    bert_document = self.combine_document_query(document, query)

    last_hidden_states, _ = self._bert_model(
        bert_document["bert"]["wordpiece-ids"],
        attention_mask=bert_document["bert"]["wordpiece-mask"],
        position_ids=bert_document["bert"]["position-ids"],
        token_type_ids=bert_document["bert"]["type-ids"],
    )

    token_embeddings, span_mask = generate_embeddings_for_pooling(
        last_hidden_states,
        bert_document["bert"]['document-starting-offsets'],
        bert_document["bert"]['document-ending-offsets'])

    token_embeddings = util.masked_max(token_embeddings, span_mask.unsqueeze(-1), dim=2)
    token_embeddings = token_embeddings * bert_document['bert']["mask"].unsqueeze(-1)

    logits = self._classification_layer(self._dropout(token_embeddings))
    probs = torch.sigmoid(logits)[:, :, 0]
    mask = bert_document['bert']['mask']

    output_dict = {}
    output_dict["probs"] = probs * mask
    output_dict['mask'] = mask

    predicted_rationale = (probs > 0.5).long()
    output_dict["predicted_rationale"] = predicted_rationale * mask
    output_dict["prob_z"] = probs * mask

    if rationale is not None:
        rat_mask = (rationale.sum(1) > 0)
        if rat_mask.sum().long() == 0:
            output_dict['loss'] = 0.0
        else:
            weight = torch.Tensor([1.0, self._pos_weight]).to(logits.device)
            loss = torch.nn.functional.cross_entropy(
                logits[rat_mask].transpose(1, 2),
                rationale[rat_mask],
                weight=weight)
            output_dict['loss'] = loss
            self._token_prf(logits[rat_mask], rationale[rat_mask],
                            bert_document['bert']["mask"][rat_mask])

    return output_dict
def _encode_definition(
        self, definition: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # [batch_size, seq_len]
    definition_mask = util.get_text_field_mask(definition)
    # [batch_size, seq_len, emb_dim]
    embedded_definition = self.text_embedder(definition)
    # either [batch_size, emb_dim] or [batch_size, seq_len, emb_dim]
    encoded_definition = self.definition_encoder(embedded_definition,
                                                 definition_mask)

    # The pooling below assumes the encoder returned a sequence,
    # i.e. [batch_size, seq_len, emb_dim].
    if self.definition_pooling == 'last':
        # [batch_size, emb_dim]
        encoded_definition = util.get_final_encoder_states(
            encoded_definition, definition_mask)
    elif self.definition_pooling == 'max':
        encoded_definition = util.masked_max(encoded_definition,
                                             definition_mask.unsqueeze(2),
                                             dim=1)
    elif self.definition_pooling == 'mean':
        encoded_definition = util.masked_mean(encoded_definition,
                                              definition_mask.unsqueeze(2),
                                              dim=1)
    elif self.definition_pooling == 'self-attentive':
        self_attentive_logits = self.self_attentive_pooling_projection(
            encoded_definition).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits,
                                           definition_mask)
        encoded_definition = util.weighted_sum(encoded_definition, self_weights)

    # [batch_size, emb_dim]
    definition_embedding = self.definition_feedforward(encoded_definition)
    # [batch_size, vocab_size (num_class)]
    definition_logits = self.definition_decoder(definition_embedding)
    # [batch_size, seq_len, vocab_size]
    sequence_definition_logits = definition_logits.unsqueeze(1).repeat(
        1, definition_mask.size(1), 1)

    # ``average`` can be None, "batch", or "token"; the loss for
    # ``average=None`` is a vector of shape (batch_size,), otherwise a scalar.
    targets = definition['tokens'].clone()
    if self.limited_word_vocab_size is not None:
        targets[targets >= self.limited_word_vocab_size] = self._oov_index
    cross_entropy_loss = util.sequence_cross_entropy_with_logits(
        sequence_definition_logits,
        targets,
        weights=definition_mask,
        average='token')

    return {
        "definition_embedding": definition_embedding,
        "cross_entropy_loss": cross_entropy_loss
    }
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenization of the text.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        unnormalised log probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        probabilities of the entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded = self._bert(tokens)

    first_token = embedded[:, 0, :]
    pooled_first = self._pooler(first_token)
    pooled_first = self._dropout(pooled_first)

    mask = tokens['mask'].float()
    encoded = self._encoder(embedded, mask)
    encoded = self._dropout(encoded)
    pooled_encoded = masked_max(encoded, mask.unsqueeze(-1), dim=1)

    concat = torch.cat([pooled_first, pooled_encoded], dim=-1)
    label_logits = self._classifier(concat)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits.view(-1, self._num_labels), label.view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict
def pool_graph(self, node_embs, node_emb_mask):
    """
    Parameters:
        node_embs: (bsz, n_nodes, graph_dim)
        node_emb_mask: (bsz, n_nodes)
    Returns:
        (bsz, graph_dim (*2))
    """
    node_emb_mask = node_emb_mask.unsqueeze(-1)
    output = masked_max(node_embs, node_emb_mask, 1)
    # A graph with no valid nodes would otherwise pool to masked_max's
    # min-value fill; zero those rows out instead.
    output = torch.where(node_emb_mask.any(1), output, torch.zeros_like(output))
    return output
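# Hedged usage sketch for pool_graph (`model` is a hypothetical instance, and
# we assume a boolean mask plus an allennlp masked_max that accepts one): a
# graph whose mask is entirely False pools to a zero vector rather than to the
# large-negative fill value that masked_max alone would return.
import torch

node_embs = torch.randn(2, 4, 16)
node_emb_mask = torch.tensor([[True, True, False, False],
                              [False, False, False, False]])
pooled = model.pool_graph(node_embs, node_emb_mask)  # (2, 16); second row is all zeros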
def seq2vec_seq_aggregate(seq_tensor, mask, aggregate, bidirectional, dim=1):
    """
    Takes the aggregation of a sequence tensor.

    :param seq_tensor: Batched sequence tensor of shape [batch, seq, hs]
    :param mask: Binary mask of shape [batch, seq]
    :param aggregate: One of "last", "max", "min", "sum", "avg"
    :param bidirectional: Whether the encoder that produced ``seq_tensor`` is
        bidirectional (only used when ``aggregate == "last"``)
    :param dim: The dimension to aggregate over; for [batch, seq, hs] it is 1
    :return: The aggregated tensor
    """
    seq_tensor_masked = seq_tensor * mask.unsqueeze(-1)
    if aggregate == "last":
        if seq_tensor.dim() > 3:
            seq = get_final_encoder_states_after_squashing(seq_tensor, mask, bidirectional)
        else:
            seq = get_final_encoder_states(seq_tensor, mask, bidirectional)
    elif aggregate == "max":
        seq = masked_max(seq_tensor, mask.unsqueeze(-1).expand_as(seq_tensor), dim=dim)
    elif aggregate == "min":
        # min(x) == -max(-x)
        seq = -masked_max(-seq_tensor, mask.unsqueeze(-1).expand_as(seq_tensor), dim=dim)
    elif aggregate == "sum":
        seq = torch.sum(seq_tensor_masked, dim=dim)
    elif aggregate == "avg":
        seq = torch.sum(seq_tensor_masked, dim=dim)
        seq_lens = torch.sum(mask, dim=dim)  # shape [batch]
        # Replace zero lengths with 1.0 to avoid division by zero for
        # all-masked sequences.
        masked_seq_lens = replace_masked_values(seq_lens, (seq_lens != 0).float(), 1.0)
        masked_seq_lens = masked_seq_lens.unsqueeze(dim=dim).expand_as(seq)
        seq = seq / masked_seq_lens
    return seq
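# Hedged sanity check for seq2vec_seq_aggregate (assuming the helper above is
# importable): with a fully unmasked batch the masked aggregates reduce to
# their plain torch counterparts, including the min(x) == -max(-x) identity
# that the "min" branch relies on.
import torch

seq = torch.randn(2, 5, 8)
mask = torch.ones(2, 5)
assert torch.allclose(seq2vec_seq_aggregate(seq, mask, "max", bidirectional=False),
                      seq.max(dim=1).values)
assert torch.allclose(seq2vec_seq_aggregate(seq, mask, "min", bidirectional=False),
                      seq.min(dim=1).values)
assert torch.allclose(seq2vec_seq_aggregate(seq, mask, "avg", bidirectional=False),
                      seq.mean(dim=1))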
def pool(vector: torch.Tensor,
         mask: torch.Tensor,
         dim: int,
         pooling: str,
         is_bidirectional: bool) -> torch.Tensor:
    if pooling == "max":
        return masked_max(vector, mask, dim)
    elif pooling == "mean":
        return masked_mean(vector, mask, dim)
    elif pooling == "sum":
        # Note: "sum" does not apply the mask; it assumes masked positions in
        # ``vector`` have already been zeroed out.
        return torch.sum(vector, dim)
    elif pooling == "final":
        return get_final_encoder_states(vector, mask, is_bidirectional)
    else:
        raise ValueError(f"'{pooling}' is not a valid pooling operation.")
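# Hedged usage sketch of pool() for the masked-max case; the mask shape is an
# assumption (it only needs to broadcast against `vector` inside masked_max).
import torch

vector = torch.randn(2, 5, 8)   # (batch, seq_len, hidden)
mask = torch.ones(2, 5, 1)      # broadcasts over the hidden dimension
pooled = pool(vector, mask, dim=1, pooling="max", is_bidirectional=False)
assert pooled.shape == (2, 8)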
def forward(self, document, query=None, label=None, metadata=None, rationale=None, **kwargs) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    bert_document = self.combine_document_query(document, query)

    last_hidden_states, _ = self._bert_model(
        bert_document["bert"]["wordpiece-ids"],
        attention_mask=bert_document["bert"]["wordpiece-mask"],
        position_ids=bert_document["bert"]["position-ids"],
        token_type_ids=bert_document["bert"]["type-ids"],
    )

    token_embeddings, span_mask = generate_embeddings_for_pooling(
        last_hidden_states,
        bert_document["bert"]['document-starting-offsets'],
        bert_document["bert"]['document-ending-offsets'])

    token_embeddings = util.masked_max(token_embeddings,
                                       span_mask.unsqueeze(-1) == 1,
                                       dim=2)
    token_embeddings = token_embeddings * bert_document['bert']["mask"].unsqueeze(-1)

    logits = torch.nn.functional.softplus(
        self._classification_layer(self._dropout(token_embeddings)))
    a, b = logits[:, :, 0], logits[:, :, 1]
    mask = bert_document['bert']['mask']

    output_dict = {}
    output_dict["a"] = a * mask
    output_dict["b"] = b * mask
    output_dict['mask'] = mask
    output_dict['wordpiece-to-token'] = bert_document['bert']['wordpiece-to-token']

    return output_dict
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenization of the premise and
        hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        unnormalised log probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        probabilities of the entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_hypothesis = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    s1_layer_1_out = self._encoder1(embedded_premise, premise_mask)
    s2_layer_1_out = self._encoder1(embedded_hypothesis, hypothesis_mask)

    s1_layer_2_out = self._encoder2(
        torch.cat([embedded_premise, s1_layer_1_out], dim=2), premise_mask)
    s2_layer_2_out = self._encoder2(
        torch.cat([embedded_hypothesis, s2_layer_1_out], dim=2), hypothesis_mask)

    s1_layer_3_out = self._encoder3(
        torch.cat([embedded_premise, s1_layer_1_out, s1_layer_2_out], dim=2),
        premise_mask)
    s2_layer_3_out = self._encoder3(
        torch.cat([embedded_hypothesis, s2_layer_1_out, s2_layer_2_out], dim=2),
        hypothesis_mask)

    # Max-pool over the sequence dimension (masked_max requires an explicit dim).
    premise_max = masked_max(s1_layer_3_out, premise_mask.unsqueeze(-1), dim=1)
    hypothesis_max = masked_max(s2_layer_3_out, hypothesis_mask.unsqueeze(-1), dim=1)

    features = torch.cat([
        premise_max, hypothesis_max,
        torch.abs(premise_max - hypothesis_max),
        premise_max * hypothesis_max
    ], dim=1)

    # the final MLP -- apply dropout to input, and MLP applies to output & hidden
    output_hidden1 = self._output_feedforward1(features)
    output_hidden2 = self._output_feedforward2(output_hidden1)
    label_logits = self._output_logit(output_hidden2)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,
            context_1: torch.Tensor,
            mask_1: torch.Tensor,
            context_2: torch.Tensor,
            mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # pylint: disable=arguments-differ
    """
    Given the forward (or backward) representations of sentence1 and sentence2,
    apply four bilateral matching functions between them in one direction.

    Parameters
    ----------
    context_1 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len1, hidden_dim) representing the
        encoding of the first sentence.
    mask_1 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len1), indicating which
        positions in the first sentence are padding (0) and which are not (1).
    context_2 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len2, hidden_dim) representing the
        encoding of the second sentence.
    mask_2 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len2), indicating which
        positions in the second sentence are padding (0) and which are not (1).

    Returns
    -------
    A tuple of matching vectors for the two sentences. Each of which is a list of
    matching vectors of shape (batch, seq_len, num_perspectives or 1)
    """
    assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
    assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim

    # (batch,)
    len_1 = get_lengths_from_binary_sequence_mask(mask_1)
    len_2 = get_lengths_from_binary_sequence_mask(mask_2)

    # (batch, seq_len*)
    mask_1, mask_2 = mask_1.float(), mask_2.float()

    # explicitly set masked weights to zero
    # (batch_size, seq_len*, hidden_dim)
    context_1 = context_1 * mask_1.unsqueeze(-1)
    context_2 = context_2 * mask_2.unsqueeze(-1)

    # array to keep the matching vectors for the two sentences
    matching_vector_1: List[torch.Tensor] = []
    matching_vector_2: List[torch.Tensor] = []

    # Step 0. unweighted cosine
    # First calculate the cosine similarities between each forward
    # (or backward) contextual embedding and every forward (or backward)
    # contextual embedding of the other sentence.
    # (batch, seq_len1, seq_len2)
    cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)

    # (batch, seq_len*, 1)
    cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
    cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
    cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
    cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)

    matching_vector_1.extend([cosine_max_1, cosine_mean_1])
    matching_vector_2.extend([cosine_max_2, cosine_mean_2])

    # Step 1. Full-Matching
    # Each time step of forward (or backward) contextual embedding of one sentence
    # is compared with the last time step of the forward (or backward)
    # contextual embedding of the other sentence
    if self.with_full_match:
        # (batch, 1, hidden_dim)
        if self.is_forward:
            last_position_1 = (len_1 - 1).clamp(min=0)
            last_position_1 = last_position_1.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
            last_position_2 = (len_2 - 1).clamp(min=0)
            last_position_2 = last_position_2.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)

            context_1_last = context_1.gather(1, last_position_1)
            context_2_last = context_2.gather(1, last_position_2)
        else:
            context_1_last = context_1[:, 0:1, :]
            context_2_last = context_2[:, 0:1, :]

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_full = multi_perspective_match(context_1, context_2_last,
                                                         self.full_match_weights)
        matching_vector_2_full = multi_perspective_match(context_2, context_1_last,
                                                         self.full_match_weights_reversed)

        matching_vector_1.extend(matching_vector_1_full)
        matching_vector_2.extend(matching_vector_2_full)

    # Step 2. Maxpooling-Matching
    # Each time step of forward (or backward) contextual embedding of one sentence
    # is compared with every time step of the forward (or backward)
    # contextual embedding of the other sentence, and only the max value of each
    # dimension is retained.
    if self.with_maxpool_match:
        # (batch, seq_len1, seq_len2, num_perspectives)
        matching_vector_max = multi_perspective_match_pairwise(context_1, context_2,
                                                               self.maxpool_match_weights)

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_max = masked_max(matching_vector_max,
                                           mask_2.unsqueeze(-2).unsqueeze(-1),
                                           dim=2)
        matching_vector_1_mean = masked_mean(matching_vector_max,
                                             mask_2.unsqueeze(-2).unsqueeze(-1),
                                             dim=2)
        matching_vector_2_max = masked_max(matching_vector_max.permute(0, 2, 1, 3),
                                           mask_1.unsqueeze(-2).unsqueeze(-1),
                                           dim=2)
        matching_vector_2_mean = masked_mean(matching_vector_max.permute(0, 2, 1, 3),
                                             mask_1.unsqueeze(-2).unsqueeze(-1),
                                             dim=2)

        matching_vector_1.extend([matching_vector_1_max, matching_vector_1_mean])
        matching_vector_2.extend([matching_vector_2_max, matching_vector_2_mean])

    # Step 3. Attentive-Matching
    # Each forward (or backward) similarity is taken as the weight
    # of the forward (or backward) contextual embedding, and calculate an
    # attentive vector for the sentence by weighted summing all its
    # contextual embeddings.
    # Finally match each forward (or backward) contextual embedding
    # with its corresponding attentive vector.

    # (batch, seq_len1, seq_len2, hidden_dim)
    att_2 = context_2.unsqueeze(-3) * cosine_sim.unsqueeze(-1)
    # (batch, seq_len1, seq_len2, hidden_dim)
    att_1 = context_1.unsqueeze(-2) * cosine_sim.unsqueeze(-1)

    if self.with_attentive_match:
        # (batch, seq_len*, hidden_dim)
        att_mean_2 = masked_softmax(att_2.sum(dim=2), mask_1.unsqueeze(-1))
        att_mean_1 = masked_softmax(att_1.sum(dim=1), mask_2.unsqueeze(-1))

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_att_mean = multi_perspective_match(context_1, att_mean_2,
                                                             self.attentive_match_weights)
        matching_vector_2_att_mean = multi_perspective_match(context_2, att_mean_1,
                                                             self.attentive_match_weights_reversed)
        matching_vector_1.extend(matching_vector_1_att_mean)
        matching_vector_2.extend(matching_vector_2_att_mean)

    # Step 4. Max-Attentive-Matching
    # Pick the contextual embeddings with the highest cosine similarity as the attentive
    # vector, and match each forward (or backward) contextual embedding with its
    # corresponding attentive vector.
    if self.with_max_attentive_match:
        # (batch, seq_len*, hidden_dim)
        att_max_2 = masked_max(att_2, mask_2.unsqueeze(-2).unsqueeze(-1), dim=2)
        att_max_1 = masked_max(att_1.permute(0, 2, 1, 3),
                               mask_1.unsqueeze(-2).unsqueeze(-1),
                               dim=2)

        # (batch, seq_len*, num_perspectives)
        matching_vector_1_att_max = multi_perspective_match(context_1, att_max_2,
                                                            self.max_attentive_match_weights)
        matching_vector_2_att_max = multi_perspective_match(context_2, att_max_1,
                                                            self.max_attentive_match_weights_reversed)

        matching_vector_1.extend(matching_vector_1_att_max)
        matching_vector_2.extend(matching_vector_2_att_max)

    return matching_vector_1, matching_vector_2
def forward(  # type: ignore
    self,
    question: Dict[str, torch.LongTensor],
    passage: Dict[str, torch.LongTensor],
    answer: torch.BoolTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model pools a span-start score over this
        passage to decide whether it answers the question.
    answer : ``torch.BoolTensor``, optional
        A yes/no answer label for each question. If this is given, we will
        compute a binary cross-entropy loss that gets included in the output
        dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens,
        original passage text, and token offsets into the passage for each
        instance in the batch. The length of this list should be the batch
        size, and each dictionary should have the keys ``question_tokens``,
        ``passage_tokens``, ``original_passage``, and ``token_offsets``.

    Returns
    -------
    An output dictionary consisting of:

    passage_question_attention : torch.FloatTensor
        The attention of each passage token over the question.
    prediction_bool_logits : torch.FloatTensor
        A tensor of shape ``(batch_size,)`` representing the unnormalized
        yes/no prediction, obtained by max-pooling the span-start logits over
        the passage.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat(
        [
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
        ],
        dim=-1,
    )

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size,)
    prediction_bool_logits = util.masked_max(span_start_logits, passage_mask, dim=1)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "prediction_bool_logits": prediction_bool_logits
    }

    # Compute the loss for training.
    if answer is not None:
        # BCE-with-logits expects a float target.
        loss = binary_cross_entropy_with_logits(prediction_bool_logits, answer.float())
        threshold = 0.5
        prediction_bool_logits = torch.where(
            torch.sigmoid(prediction_bool_logits) > threshold,
            torch.ones_like(prediction_bool_logits),
            torch.zeros_like(prediction_bool_logits))
        self._accuracy(prediction_bool_logits, answer)
        output_dict["loss"] = loss

    return output_dict
def esim_forward(  # type: ignore
    self,
    encoded_premise,
    encoded_hypothesis,
    premise_mask,
    hypothesis_mask,
    label: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

    # Shape: (batch_size, premise_length, hypothesis_length)
    p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
    # Shape: (batch_size, premise_length, embedding_dim)
    attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

    # Shape: (batch_size, hypothesis_length, premise_length)
    h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
    # Shape: (batch_size, hypothesis_length, embedding_dim)
    attended_premise = weighted_sum(encoded_premise, h2p_attention)

    # the "enhancement" layer
    premise_enhanced = torch.cat(
        [
            encoded_premise,
            attended_hypothesis,
            encoded_premise - attended_hypothesis,
            encoded_premise * attended_hypothesis,
        ],
        dim=-1,
    )
    hypothesis_enhanced = torch.cat(
        [
            encoded_hypothesis,
            attended_premise,
            encoded_hypothesis - attended_premise,
            encoded_hypothesis * attended_premise,
        ],
        dim=-1,
    )

    # The projection layer down to the model dimension. Dropout is not applied before
    # projection.
    projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
    projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

    # Run the inference layer
    if self.rnn_input_dropout:
        projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
        projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
    v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
    v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

    # The pooling layer -- max and avg pooling.
    # (batch_size, model_dim)
    v_a_max = masked_max(v_ai, premise_mask.unsqueeze(-1), dim=1)
    v_b_max = masked_max(v_bi, hypothesis_mask.unsqueeze(-1), dim=1)

    v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
        premise_mask, 1, keepdim=True
    )
    v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
        hypothesis_mask, 1, keepdim=True
    )

    # Now concat
    # (batch_size, model_dim * 2 * 4)
    v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

    # the final MLP -- apply dropout to input, and MLP applies to output & hidden
    if self.dropout:
        v_all = self.dropout(v_all)

    output_hidden = self._output_feedforward(v_all)
    label_logits = self._output_logit(output_hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    return output_dict
def forward(  # type: ignore
    self,
    sent1: TextFieldTensors,
    sent2: TextFieldTensors,
    label: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:
    with adv_utils.forward_context("sent1"):
        embedded_sent1 = self.word_embedders(sent1)
    with adv_utils.forward_context("sent2"):
        embedded_sent2 = self.word_embedders(sent2)
    sent1_mask = get_text_field_mask(sent1)
    sent2_mask = get_text_field_mask(sent2)

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_sent1 = self.rnn_input_dropout(embedded_sent1)
        embedded_sent2 = self.rnn_input_dropout(embedded_sent2)

    # encode sent1 and sent2
    encoded_sent1 = self._encoder(embedded_sent1, sent1_mask)
    encoded_sent2 = self._encoder(embedded_sent2, sent2_mask)

    # Shape: (batch_size, sent1_length, sent2_length)
    similarity_matrix = self._matrix_attention(encoded_sent1, encoded_sent2)

    # Shape: (batch_size, sent1_length, sent2_length)
    p2h_attention = masked_softmax(similarity_matrix, sent2_mask)
    # Shape: (batch_size, sent1_length, embedding_dim)
    attended_sent2 = weighted_sum(encoded_sent2, p2h_attention)

    # Shape: (batch_size, sent2_length, sent1_length)
    h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), sent1_mask)
    # Shape: (batch_size, sent2_length, embedding_dim)
    attended_sent1 = weighted_sum(encoded_sent1, h2p_attention)

    # the "enhancement" layer
    sent1_enhanced = torch.cat(
        [
            encoded_sent1,
            attended_sent2,
            encoded_sent1 - attended_sent2,
            encoded_sent1 * attended_sent2,
        ],
        dim=-1,
    )
    sent2_enhanced = torch.cat(
        [
            encoded_sent2,
            attended_sent1,
            encoded_sent2 - attended_sent1,
            encoded_sent2 * attended_sent1,
        ],
        dim=-1,
    )

    # The projection layer down to the model dimension. Dropout is not applied before
    # projection.
    projected_enhanced_sent1 = self._projection_feedforward(sent1_enhanced)
    projected_enhanced_sent2 = self._projection_feedforward(sent2_enhanced)

    # Run the inference layer
    if self.rnn_input_dropout:
        projected_enhanced_sent1 = self.rnn_input_dropout(projected_enhanced_sent1)
        projected_enhanced_sent2 = self.rnn_input_dropout(projected_enhanced_sent2)
    v_ai = self._inference_encoder(projected_enhanced_sent1, sent1_mask)
    v_bi = self._inference_encoder(projected_enhanced_sent2, sent2_mask)

    # The pooling layer -- max and avg pooling.
    # (batch_size, model_dim)
    v_a_max = masked_max(v_ai, sent1_mask.unsqueeze(-1), dim=1)
    v_b_max = masked_max(v_bi, sent2_mask.unsqueeze(-1), dim=1)

    v_a_avg = torch.sum(v_ai * sent1_mask.unsqueeze(-1), dim=1) / torch.sum(
        sent1_mask, 1, keepdim=True)
    v_b_avg = torch.sum(v_bi * sent2_mask.unsqueeze(-1), dim=1) / torch.sum(
        sent2_mask, 1, keepdim=True)

    # Now concat
    # (batch_size, model_dim * 2 * 4)
    v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

    # the final MLP -- apply dropout to input, and MLP applies to output & hidden
    if self.dropout:
        v_all = self.dropout(v_all)

    output_hidden = self._output_feedforward(v_all)
    label_logits = self._output_logit(output_hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"logits": label_logits, "probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # We need to know the maximum span width so we can
    # generate indices to extract the spans from the sequence tensor.
    # These indices will then get masked below, such that if the length
    # of a given span is smaller than the max, the rest of the values
    # are masked.
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(
        max_batch_span_width, util.get_device_of(sequence_tensor)).view(1, 1, -1)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices,
                                                flat_span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    masked_span_embeddings = span_embeddings * span_mask.unsqueeze(-1)

    batch_size, num_spans, max_batch_span_width, embedding_dim = masked_span_embeddings.size()

    # Shape: (batch_size * num_spans, embedding_dim, max_batch_span_width)
    masked_span_embeddings = masked_span_embeddings.view(
        batch_size * num_spans, max_batch_span_width, embedding_dim).transpose(1, 2)

    # Shape: (batch_size * num_spans, embedding_dim, max_batch_span_width)
    conv_span_embeddings = torch.nn.functional.relu(self._conv(masked_span_embeddings))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    conv_span_embeddings = conv_span_embeddings.transpose(1, 2).view(
        batch_size, num_spans, max_batch_span_width, embedding_dim)

    # Shape: (batch_size, num_spans, embedding_dim)
    span_embeddings = util.masked_max(conv_span_embeddings, span_mask.unsqueeze(-1), dim=2)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        span_width_embeddings = self._span_width_embedding(span_widths.squeeze(-1))
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    return span_embeddings
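# Hedged shape sketch for the conv + masked_max span extractor above
# (`extractor` is a hypothetical configured instance; span ends are inclusive,
# matching the comments in forward()).
import torch

sequence_tensor = torch.randn(2, 10, 64)        # (batch, seq_len, embedding_dim)
span_indices = torch.tensor([[[0, 2], [3, 3]],  # (batch, num_spans, 2)
                             [[1, 4], [5, 9]]])
span_embeddings = extractor(sequence_tensor, span_indices)
# -> (2, 2, 64), plus the width-embedding size when one is configured.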
def forward(  # type: ignore
    self,
    premise: TextFieldTensors,
    hypothesis: TextFieldTensors,
    label: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    premise : TextFieldTensors
        From a `TextField`
    hypothesis : TextFieldTensors
        From a `TextField`
    label : torch.IntTensor, optional (default = None)
        From a `LabelField`
    metadata : `List[Dict[str, Any]]`, optional, (default = None)
        Metadata containing the original tokenization of the premise and
        hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

    # Returns

    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape `(batch_size, num_labels)` representing unnormalised log
        probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape `(batch_size, num_labels)` representing probabilities of
        the entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_premise = self._text_field_embedder(premise)
    embedded_hypothesis = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise)
    hypothesis_mask = get_text_field_mask(hypothesis)

    # apply dropout for LSTM
    if self.rnn_input_dropout:
        embedded_premise = self.rnn_input_dropout(embedded_premise)
        embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

    # encode premise and hypothesis
    encoded_premise = self._encoder(embedded_premise, premise_mask)
    encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

    # Shape: (batch_size, premise_length, hypothesis_length)
    p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
    # Shape: (batch_size, premise_length, embedding_dim)
    attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

    # Shape: (batch_size, hypothesis_length, premise_length)
    h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
    # Shape: (batch_size, hypothesis_length, embedding_dim)
    attended_premise = weighted_sum(encoded_premise, h2p_attention)

    # the "enhancement" layer
    premise_enhanced = torch.cat(
        [
            encoded_premise,
            attended_hypothesis,
            encoded_premise - attended_hypothesis,
            encoded_premise * attended_hypothesis,
        ],
        dim=-1,
    )
    hypothesis_enhanced = torch.cat(
        [
            encoded_hypothesis,
            attended_premise,
            encoded_hypothesis - attended_premise,
            encoded_hypothesis * attended_premise,
        ],
        dim=-1,
    )

    # The projection layer down to the model dimension. Dropout is not applied before
    # projection.
    projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
    projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

    # Run the inference layer
    if self.rnn_input_dropout:
        projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
        projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
    v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
    v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

    # The pooling layer -- max and avg pooling.
    # (batch_size, model_dim)
    v_a_max = masked_max(v_ai, premise_mask.unsqueeze(-1), dim=1)
    v_b_max = masked_max(v_bi, hypothesis_mask.unsqueeze(-1), dim=1)

    v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
        premise_mask, 1, keepdim=True)
    v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
        hypothesis_mask, 1, keepdim=True)

    # Now concat
    # (batch_size, model_dim * 2 * 4)
    v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

    # the final MLP -- apply dropout to input, and MLP applies to output & hidden
    if self.dropout:
        v_all = self.dropout(v_all)

    output_hidden = self._output_feedforward(v_all)
    label_logits = self._output_logit(output_hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            passage: Dict[str, torch.LongTensor],
            all_qa: Dict[str, torch.LongTensor],
            candidate: Dict[str, torch.LongTensor],
            combined_source: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    if self._with_knowledge:
        embedded_passage = self._text_field_embedder(passage)  # B * T * d
        passage_len = embedded_passage.size(1)
    embedded_all_qa = self._text_field_embedder(all_qa)  # B * U * d
    embedded_choice = self._text_field_embedder(candidate)  # B * V * d

    if self._with_knowledge:
        embedded_passage = self._variational_dropout(embedded_passage)  # B * T * d
    embedded_all_qa = self._variational_dropout(embedded_all_qa)
    embedded_choice = self._variational_dropout(embedded_choice)  # B * V * d

    all_qa_mask = util.get_text_field_mask(all_qa)  # B * U
    choice_mask = util.get_text_field_mask(candidate)  # B * V

    # Encoding
    if self._with_knowledge:
        # B * T * H
        passage_mask = util.get_text_field_mask(passage)  # B * T
        encoded_passage = self._variational_dropout(
            self._pseqlevel_enc(embedded_passage, passage_mask))
    # B * U * H
    if self._shared_rnn:
        encoded_allqa = self._variational_dropout(
            self._pseqlevel_enc(embedded_all_qa, all_qa_mask))
    else:
        encoded_allqa = self._variational_dropout(
            self._qaseqlevel_enc(embedded_all_qa, all_qa_mask))

    if self._with_knowledge and self._is_qdep_penc:
        # similarity matrix
        _, normalized_attn_mat = self._cart_attn(encoded_passage,
                                                 encoded_allqa,
                                                 all_qa_mask)  # B * T * U
        # question dependent passage encoding
        q_aware_passage_rep = sequential_weighted_avg(encoded_allqa,
                                                      normalized_attn_mat)  # B * T * H

        q_dep_passage_enc_rnn_input = torch.cat(
            [encoded_passage, q_aware_passage_rep], 2)  # B * T * 2H

        # gated question dependent passage encoding
        gated_qaware_passage_rep = self._gate_qdep_penc(
            q_dep_passage_enc_rnn_input)  # B * T * 2H
        encoded_qdep_penc = self._qdep_penc_rnn(gated_qaware_passage_rep,
                                                passage_mask)  # B * T * H

    # multi factor attentive encoding
    if self._with_knowledge and self._is_mfa_enc:
        if self._is_qdep_penc:
            mfa_enc = self._multifactor_attn(encoded_qdep_penc, passage_mask)  # B * T * 2H
        else:
            mfa_enc = self._multifactor_attn(encoded_passage, passage_mask)  # B * T * 2H
        encoded_passage = self._mfarnn(mfa_enc, passage_mask)  # B * T * H

    # B * V * H
    if self._shared_rnn:
        encoded_choice = self._variational_dropout(
            self._pseqlevel_enc(embedded_choice, choice_mask))
    else:
        encoded_choice = self._variational_dropout(
            self._cseqlevel_enc(embedded_choice, choice_mask))

    if self._with_knowledge:
        attn_pq, _ = self._pqaattnmat(encoded_passage, encoded_allqa,
                                      all_qa_mask)  # B * T * U
        combined_pqa_mask = passage_mask.unsqueeze(-1) * \
            all_qa_mask.unsqueeze(1)  # B * T * U
        max_attn_pqa = masked_max(attn_pq, combined_pqa_mask, dim=1)  # B * U
        norm_attn_pqa = masked_softmax(max_attn_pqa, all_qa_mask, dim=-1)  # B * U
        agg_prev_qa = norm_attn_pqa.unsqueeze(1).bmm(encoded_allqa).squeeze(1)  # B * H

        attn_pc, _ = self._pcattnmat(encoded_passage, encoded_choice,
                                     choice_mask)  # B * T * V
        combined_pc_mask = passage_mask.unsqueeze(-1) * \
            choice_mask.unsqueeze(1)  # B * T * V
        max_attn_pc = masked_max(attn_pc, combined_pc_mask, dim=1)  # B * V
        norm_attn_pc = masked_softmax(max_attn_pc, choice_mask, dim=-1)  # B * V
        agg_c = norm_attn_pc.unsqueeze(1).bmm(encoded_choice)  # B * 1 * H

        choice_scores_wk = agg_c.bmm(agg_prev_qa.unsqueeze(-1)).squeeze(-1)  # B * 1

    if self._qac_ap:
        attn_qac, _ = self._cqaattnmat(encoded_allqa, encoded_choice,
                                       choice_mask)  # B * U * V
        combined_qac_mask = all_qa_mask.unsqueeze(-1) * \
            choice_mask.unsqueeze(1)  # B * U * V

        max_attn_c = masked_max(attn_qac, combined_qac_mask, dim=1)  # B * V
        max_attn_qa = masked_max(attn_qac, combined_qac_mask, dim=2)  # B * U
        norm_attn_c = masked_softmax(max_attn_c, choice_mask, dim=-1)  # B * V
        norm_attn_qa = masked_softmax(max_attn_qa, all_qa_mask, dim=-1)  # B * U
        agg_c_qa = norm_attn_c.unsqueeze(1).bmm(encoded_choice).squeeze(1)  # B * H
        agg_qa_c = norm_attn_qa.unsqueeze(1).bmm(encoded_allqa).squeeze(1)  # B * H

        choice_scores_nk = agg_c_qa.unsqueeze(1).bmm(
            agg_qa_c.unsqueeze(-1)).squeeze(-1)  # B * 1

    if self._with_knowledge and self._qac_ap:
        choice_score = choice_scores_wk + choice_scores_nk
    elif self._qac_ap:
        choice_score = choice_scores_nk
    elif self._with_knowledge:
        choice_score = choice_scores_wk
    else:
        raise NotImplementedError

    output = torch.sigmoid(choice_score).squeeze(-1)  # B

    output_dict = {
        "label_logits": choice_score.squeeze(-1),
        "label_probs": output,
        "metadata": metadata
    }

    if label is not None:
        label = label.long().view(-1)
        loss = self._loss(output, label.float())
        self._auc(output, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            qa_pairs: Dict[str, torch.LongTensor],
            answer_index: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    qa_pairs : Dict[str, torch.LongTensor]
        From a ``ListField``.
    answer_index : ``torch.IntTensor``, optional
        From an ``IndexField``. This is what we are trying to predict.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question ID, question and choices for each instance
        in the batch. The length of this list should be the batch size, and each dictionary
        should have the keys ``qid``, ``question``, ``choices``, ``question_tokens`` and
        ``choices_tokens``.

    Returns
    -------
    An output dictionary consisting of the following:

    qid : List[str]
        A list consisting of question ids.
    answer_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_options=5)`` representing unnormalised log
        probabilities of the choices.
    answer_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_options=5)`` representing probabilities of the
        choices.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded = self._text_field_embedder(qa_pairs, num_wrapping_dims=1)
    mask = qa_pairs['mask'].float()

    batch_size, choice_size, seq_len, hidden_size = embedded.size()
    embedded = embedded.view(-1, seq_len, hidden_size)
    mask = mask.view(-1, seq_len)

    if self.dropout:
        embedded = self.dropout(embedded)

    if self._encoder:
        embedded = self._encoder(embedded, mask)

    embedded = embedded.view(batch_size, choice_size, seq_len, -1)
    mask = mask.view(batch_size, choice_size, seq_len)
    embedded = masked_max(embedded, mask.unsqueeze(-1), dim=2)

    # the final MLP -- apply dropout to input, and MLP applies to hidden
    answer_logits = self._output_logit(embedded).squeeze(-1)
    answer_probs = torch.nn.functional.softmax(answer_logits, dim=-1)
    qids = [m['qid'] for m in metadata]

    output_dict = {
        "answer_logits": answer_logits,
        "answer_probs": answer_probs,
        "qid": qids
    }

    if answer_index is not None:
        answer_index = answer_index.squeeze(-1)
        loss = self._loss(answer_logits, answer_index)
        self._accuracy(answer_logits, answer_index)
        output_dict["loss"] = loss

    return output_dict
def forward(self, span_embeddings, span_children, span_children_mask):
    batch, sequence, children_num, _ = span_children.size()
    # (batch, sequence, children_num)
    span_children = span_children.squeeze(-1)

    for t in range(self._tree_prop):
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_children, span_embeddings.size(1))
        # (batch, sequence, children_num, span_emb_dim)
        children_span_embeddings = util.batched_index_select(
            span_embeddings, span_children, flat_span_indices)

        if self._tree_children == 'attention':
            # (batch, sequence, children_num)
            attention_scores = self._global_attention(children_span_embeddings).squeeze(-1)
            # (batch, sequence, children_num)
            attention_scores_softmax = util.masked_softmax(
                attention_scores, span_children_mask, dim=2)
            # (batch, sequence, span_emb_dim)
            children_span_embeddings_merged = util.weighted_sum(
                children_span_embeddings, attention_scores_softmax)
        elif self._tree_children == 'pooling':
            children_span_embeddings_merged = util.masked_max(
                children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
        elif self._tree_children == 'conv':
            masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)
            masked_children_span_embeddings = masked_children_span_embeddings.view(
                batch * sequence, children_num, -1).transpose(1, 2)
            conv_children_span_embeddings = torch.nn.functional.relu(
                self._conv(masked_children_span_embeddings))
            conv_children_span_embeddings = conv_children_span_embeddings.transpose(1, 2).view(
                batch, sequence, children_num, -1)
            children_span_embeddings_merged = util.masked_max(
                conv_children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
        elif self._tree_children == 'rnn':
            masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)
            masked_children_span_embeddings = masked_children_span_embeddings.view(
                batch * sequence, children_num, -1)
            try:
                rnn_children_span_embeddings = self._encoder(
                    masked_children_span_embeddings,
                    span_children_mask.view(batch * sequence, children_num))
            except Exception:
                # If no span in this batch has any children, the encoder
                # raises; fall back to the unencoded embeddings.
                rnn_children_span_embeddings = masked_children_span_embeddings
            rnn_children_span_embeddings = rnn_children_span_embeddings.view(
                batch, sequence, children_num, -1)
            forward_sequence, backward_sequence = rnn_children_span_embeddings.split(
                int(self._span_emb_dim / 2), dim=-1)
            children_span_embeddings_merged = torch.cat(
                [forward_sequence[:, :, -1, :], backward_sequence[:, :, 0, :]], dim=-1)
        else:
            raise RuntimeError(f"Unknown tree_children option: {self._tree_children}")

        # (batch, sequence, 2 * span_emb_dim)
        f_network_input = torch.cat([span_embeddings, children_span_embeddings_merged], dim=-1)
        # (batch, sequence, span_emb_dim)
        f_weights = self._f_network(f_network_input)
        # (batch, sequence, 1); f_weights_mask == 1 iff this span has at least one child
        f_weights_mask, _ = span_children_mask.max(dim=-1, keepdim=True)
        # (batch, sequence, span_emb_dim): set f_weights to 1 where f_weights_mask == 0,
        # so that childless spans keep their own embedding unchanged below.
        f_weights = util.replace_masked_values(f_weights, f_weights_mask, 1.0)

        # (batch, sequence, span_emb_dim)
        span_embeddings = f_weights * span_embeddings + \
            (1.0 - f_weights) * children_span_embeddings_merged

    span_embeddings = self._dropout(span_embeddings)
    return span_embeddings
def forward(self, document, query=None, label=None, metadata=None, rationale=None, **kwargs) -> Dict[str, Any]:
    # pylint: disable=arguments-differ
    bert_document = self.combine_document_query(document, query)

    last_hidden_states, _ = self._bert_model(
        bert_document["bert"]["wordpiece-ids"],
        attention_mask=bert_document["bert"]["wordpiece-mask"],
        position_ids=bert_document["bert"]["position-ids"],
        token_type_ids=bert_document["bert"]["type-ids"],
    )

    token_embeddings, span_mask = generate_embeddings_for_pooling(
        last_hidden_states,
        bert_document["bert"]["document-starting-offsets"],
        bert_document["bert"]["document-ending-offsets"],
    )

    token_embeddings = util.masked_max(token_embeddings,
                                       span_mask.unsqueeze(-1) == 1,
                                       dim=2)
    token_embeddings = token_embeddings * bert_document["bert"]["mask"].unsqueeze(-1)

    logits = self._classification_layer(self._dropout(token_embeddings))
    probs = torch.sigmoid(logits)[:, :, 0]
    mask = bert_document["bert"]["mask"]

    output_dict = {}
    output_dict["probs"] = probs * mask
    output_dict["mask"] = mask

    predicted_rationale = (probs > 0.5).long()
    output_dict["predicted_rationale"] = predicted_rationale * mask
    output_dict["prob_z"] = probs * mask

    if rationale is not None and self._supervise_rationale:
        rat_mask = rationale.sum(1) > 0
        if rat_mask.sum().long() == 0:
            output_dict["loss"] = 0.0
        else:
            rat_mask = rat_mask.bool()
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits[rat_mask].squeeze(-1),
                rationale[rat_mask],
                reduction="none",
                pos_weight=self._pos_weight.to(rationale.device),
            )
            loss = ((loss * mask[rat_mask]).sum(-1) / mask[rat_mask].sum(-1)).mean()
            output_dict["loss"] = loss
            self._token_prf(
                torch.cat([1 - probs[rat_mask].unsqueeze(-1),
                           probs[rat_mask].unsqueeze(-1)],
                          dim=-1),
                rationale[rat_mask].long(),
                mask[rat_mask] == 1,
            )

    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata to persist

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        unnormalized log probabilities of the label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        probabilities of the label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    embedded_text = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens).float()
    encoder_output = self._encoder(embedded_text, mask)

    encoded_repr = []
    for aggregation in self._aggregations:
        if aggregation == "meanpool":
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoder_output * broadcast_mask
            encoded_text = masked_mean(context_vectors,
                                       broadcast_mask,
                                       dim=1,
                                       keepdim=False)
        elif aggregation == 'maxpool':
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoder_output * broadcast_mask
            encoded_text = masked_max(context_vectors, broadcast_mask, dim=1)
        elif aggregation == 'final_state':
            is_bi = self._encoder.is_bidirectional()
            encoded_text = get_final_encoder_states(encoder_output, mask, is_bi)
        encoded_repr.append(encoded_text)

    encoded_repr = torch.cat(encoded_repr, 1)

    if self.dropout:
        encoded_repr = self.dropout(encoded_repr)

    output_hidden = self._output_feedforward(encoded_repr)
    label_logits = self._classification_layer(output_hidden)
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}

    if label is not None:
        loss = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

    return output_dict