def _passage_span_module(self, passage_out, passage_mask):
    """Score passage tokens as span boundaries and select the best span.

    Two predictors produce one start score and one end score per token
    (the trailing singleton dimension is squeezed away). The scores are
    normalised with a masked log-softmax, and the best span is searched on
    the raw scores after masked positions have been set to ``-1e7``.

    Returns a tuple ``(start_log_probs, end_log_probs, best_span)``.
    According to the original shape comments the log-probabilities are
    ``(batch_size, passage_length)`` and the best span is ``(batch_size, 2)``.
    """
    # Per-token boundary scores.
    start_scores = self._passage_span_start_predictor(passage_out).squeeze(-1)
    end_scores = self._passage_span_end_predictor(passage_out).squeeze(-1)

    # Normalised distributions over (unmasked) passage positions.
    passage_span_start_log_probs = util.masked_log_softmax(start_scores, passage_mask)
    passage_span_end_log_probs = util.masked_log_softmax(end_scores, passage_mask)

    # Push masked positions far down so the span search never selects them.
    searchable_start = util.replace_masked_values(start_scores, passage_mask, -1e7)
    searchable_end = util.replace_masked_values(end_scores, passage_mask, -1e7)
    best_passage_span = get_best_span(searchable_start, searchable_end)

    return passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span
def forward(self, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Predict span start/end distributions and the best span.

    The encoded sequence and its mask are obtained from
    ``self.get_input_and_mask(kwargs)``.

    Returns a dictionary with keys ``'start_log_probs'``, ``'end_log_probs'``
    (masked log-softmax of the per-token scores) and ``'best_span'`` (result
    of ``get_best_span`` on the scores with masked positions set to ``-1e7``).
    """
    encoded, mask = self.get_input_and_mask(kwargs)

    # One score per token for each boundary; drop the trailing singleton dim.
    start_scores = self._start_output_layer(encoded).squeeze(-1)
    end_scores = self._end_output_layer(encoded).squeeze(-1)

    start_log_probs = masked_log_softmax(start_scores, mask)
    end_log_probs = masked_log_softmax(end_scores, mask)

    # Masked positions get a very low score so the span search ignores them.
    best_span = get_best_span(
        replace_masked_values(start_scores, mask, -1e7),
        replace_masked_values(end_scores, mask, -1e7),
    )

    return {
        'start_log_probs': start_log_probs,
        'end_log_probs': end_log_probs,
        'best_span': best_span,
    }
def _question_span_module(self, passage_vector, question_out, question_mask):
    """Score question tokens as span boundaries and select the best span.

    Every question token representation is concatenated (along the last
    dimension) with ``passage_vector``, which is repeated once per question
    position. The start/end predictors are applied to that concatenation.

    Returns a tuple ``(start_log_probs, end_log_probs, best_span)``; the
    original shape comments give ``(batch_size, question_length)`` for the
    log-probabilities and ``(batch_size, 2)`` for the best span.
    """
    # Tile the passage summary over the question length and attach it to each token.
    tiled_passage = passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)
    span_input = torch.cat([question_out, tiled_passage], -1)

    start_scores = self._question_span_start_predictor(span_input).squeeze(-1)
    end_scores = self._question_span_end_predictor(span_input).squeeze(-1)

    question_span_start_log_probs = util.masked_log_softmax(start_scores, question_mask)
    question_span_end_log_probs = util.masked_log_softmax(end_scores, question_mask)

    # Masked positions are pushed to -1e7 before searching for the best span.
    searchable_start = util.replace_masked_values(start_scores, question_mask, -1e7)
    searchable_end = util.replace_masked_values(end_scores, question_mask, -1e7)
    best_question_span = get_best_span(searchable_start, searchable_end)

    return question_span_start_log_probs, question_span_end_log_probs, best_question_span
def test_masked_log_softmax_masked(self):
    """Check ``util.masked_log_softmax`` on masked 1D inputs.

    The cases mirror test_softmax_masked: after exponentiating the output,
    unmasked entries must form the expected softmax and masked entries must
    come out as 0.
    """
    # (input vector, mask, expected probabilities after exp)
    cases = [
        # General masked 1D cases.
        ([[1.0, 2.0, 5.0]],
         [[1.0, 0.0, 1.0]],
         [[0.01798621, 0.0, 0.98201382]]),
        ([[0.0, 2.0, 3.0, 4.0]],
         [[1.0, 0.0, 1.0, 1.0]],
         [[0.01321289, 0.0, 0.26538793, 0.72139918]]),
        # Input is all 0s while the mask is not all 0s.
        ([[0.0, 0.0, 0.0, 0.0]],
         [[0.0, 0.0, 0.0, 1.0]],
         [[0., 0., 0., 1.]]),
    ]
    for values, mask, expected in cases:
        log_probs = util.masked_log_softmax(
            torch.FloatTensor(values), torch.FloatTensor(mask)).data.numpy()
        assert_array_almost_equal(numpy.exp(log_probs), numpy.array(expected))

    # Input is not all 0s but the mask is all 0s: the output is arbitrary,
    # the only requirement is that it contains no nan.
    fully_masked = util.masked_log_softmax(
        torch.FloatTensor([[0.0, 2.0, 3.0, 4.0]]),
        torch.FloatTensor([[0.0, 0.0, 0.0, 0.0]])).data.numpy()
    assert not numpy.isnan(fully_masked).any()
def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.Tensor,
        head_tag_temperature: Optional[float] = None,
        head_temperature: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the arc + tag negative log likelihood for gold heads and tags.

    Parameters
    ----------
    head_tag_representation, child_tag_representation : ``torch.Tensor``
        Passed unchanged to ``self._get_head_tags`` to obtain tag logits.
    attended_arcs : ``torch.Tensor``
        Arc scores; unpacked as (batch_size, sequence_length, _). This tensor
        is NOT modified, even when ``head_temperature`` is given.
    head_indices : ``torch.Tensor``
        Gold head index per token, used to index the arc log-probabilities.
    head_tags : ``torch.Tensor``
        Gold tag per token. Positions whose tag id is <= 1 contribute nothing
        to the tag loss (presumably padding/unknown ids - confirm against the
        vocabulary).
    mask : ``torch.Tensor``
        Token mask used for normalisation and for counting valid positions.
    head_tag_temperature, head_temperature : ``float``, optional
        When truthy, the tag logits / arc scores are divided by the value
        before normalisation. ``None`` and ``0`` both mean "no scaling".

    Returns
    -------
    ``(loss, normalised_arc_logits, normalised_head_tag_logits)`` where
    ``loss`` is ``arc_nll + tag_nll``.
    """
    float_mask = mask.float()
    tag_mask = self._get_unknown_tag_mask(mask, head_tags)

    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(
        batch_size, get_device_of(attended_arcs)).unsqueeze(1)

    if head_temperature:
        # Out-of-place division: the in-place `/=` used previously rescaled the
        # caller's tensor as a side effect.
        attended_arcs = attended_arcs / head_temperature
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = masked_log_softmax(
        attended_arcs, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation,
                                          child_tag_representation,
                                          head_indices)
    if head_tag_temperature:
        head_tag_logits = head_tag_logits / head_tag_temperature
    normalised_head_tag_logits = masked_log_softmax(
        head_tag_logits, tag_mask.unsqueeze(-1)) * tag_mask.float().unsqueeze(-1)

    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length,
                                      get_device_of(attended_arcs))
    child_index = timestep_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()

    # shape (batch_size, sequence_length): log-probability of each gold arc / tag.
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
    # Drop the tag loss wherever the gold tag id is 0 or 1.
    tag_loss = tag_loss * (head_tags > 1).float()

    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements
    # minus 1 per sequence in the batch, to account for the symbolic ROOT token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    loss = arc_nll + tag_nll
    return loss, normalised_arc_logits, normalised_head_tag_logits
def forward(self,
            query: torch.FloatTensor,
            key: torch.FloatTensor,
            edge_head_mask: torch.ByteTensor = None,
            gold_edge_heads: torch.Tensor = None) -> Dict:
    """
    Score edge heads and edge types for every query position.

    :param query: [batch_size, query_length, query_vector_dim]
    :param key: [batch_size, key_length, key_vector_dim]
    :param edge_head_mask: [batch_size, query_length, key_length]
            1 indicates a valid position; otherwise, 0.
    :param gold_edge_heads: None or [batch_size, query_length].
            head indices start from 1. Entries equal to -1 are rewritten to 0
            IN PLACE on the tensor that was passed in.
    :return:
        edge_heads: [batch_size, query_length] (greedy prediction).
        edge_types: [batch_size, query_length] (greedy prediction).
        edge_head_ll: [batch_size, query_length, key_length + 1(sentinel)].
        edge_type_ll: [batch_size, query_length, num_labels], computed for the
            gold heads when ``gold_edge_heads`` is given and for the predicted
            heads otherwise.
        edge_head_query, edge_head_key, edge_type_query, edge_type_key:
            the outputs of ``self._mlp``.
    """
    if gold_edge_heads is not None:
        # NOTE(review): this mutates the caller's tensor; kept because callers
        # may rely on seeing the sanitised indices afterwards.
        gold_edge_heads[gold_edge_heads == -1] = 0

    key, edge_head_mask = self._add_sentinel(query, key, edge_head_mask)
    edge_head_query, edge_head_key, edge_type_query, edge_type_key = self._mlp(
        query, key)

    # [batch_size, query_length, key_length + 1]
    edge_head_score = self._get_edge_head_score(edge_head_query, edge_head_key)
    edge_heads, edge_types = self._greedy_search(
        edge_type_query, edge_type_key, edge_head_score, edge_head_mask)

    # Without gold heads, fall back to the predicted ones.
    if gold_edge_heads is None:
        gold_edge_heads = edge_heads

    # [batch_size, query_length, num_labels]
    # Previously `edge_heads` was passed here, which left `gold_edge_heads`
    # unused and contradicted the documented "based on gold_edge_head".
    edge_type_score = self._get_edge_type_score(
        edge_type_query, edge_type_key, gold_edge_heads)

    return dict(
        # Note: head indices start from 1.
        edge_heads=edge_heads,
        edge_types=edge_types,
        # Log-Likelihood.
        edge_head_ll=masked_log_softmax(edge_head_score, edge_head_mask, dim=2),
        edge_type_ll=masked_log_softmax(edge_type_score, None, dim=2),
        edge_head_query=edge_head_query,
        edge_head_key=edge_head_key,
        edge_type_query=edge_type_query,
        edge_type_key=edge_type_key)
def bidaf(curr_q: torch.Tensor,
          past_ctx_enc: torch.Tensor,
          past_ctx_emb: torch.Tensor,
          masks: torch.Tensor,
          att: MatrixAttention,
          men_scorer):
    """
    Attend from the current query over a past context.

    :param curr_q: batch of query representations
    :param past_ctx_enc: batch of encoded context, used to compute attention
        scores (and mention scores when ``men_scorer`` is given)
    :param past_ctx_emb: batch of context embeddings that are averaged with
        the attention weights
    :param masks: mask over context positions
    :param att: attention layer producing query-to-context scores; the
        original comments describe them as (B, M, N)
    :param men_scorer: optional callable scoring each context position; its
        output (trailing singleton squeezed) is added to every query row of
        the attention scores
    :return: tuple of
        q_hat - context aware query representation
        entropy - entropy of the attention distribution, summed over dim 2
        sm_att - masked softmax attention weights
        men_scores - raw output of ``men_scorer`` or None
    """
    scores = att(curr_q, past_ctx_enc)

    men_scores = None
    if men_scorer is not None:
        men_scores = men_scorer(past_ctx_enc)
        # Broadcast the per-context-position score across all query positions.
        scores = scores + men_scores.squeeze(-1).unsqueeze(1).expand_as(scores)

    sm_att = util.masked_softmax(scores, masks)
    log_sm_att = util.masked_log_softmax(scores, masks)

    # H = -sum(p * log p) over the context dimension.
    entropy = (-sm_att * log_sm_att).sum(2)

    q_hat = util.weighted_sum(past_ctx_emb, sm_att)
    return q_hat, entropy, sm_att, men_scores
def forward(self, **kwargs) -> torch.FloatTensor:
    """Encode the text and pool it with every configured aggregation.

    Reads ``kwargs['mask']`` and ``kwargs['embedded_text']``, runs
    ``self._architecture`` and then, for each entry of ``self._aggregations``
    (``"meanpool"``, ``'maxpool'``, ``'final_state'`` or ``'attention'``),
    reduces the encoder output over dim 1. The pooled vectors are
    concatenated along dim 1 in configuration order.

    Raises ``ConfigurationError`` for an unknown aggregation name.
    """
    mask = kwargs['mask']
    embedded_text = kwargs['embedded_text']
    encoded_output = self._architecture(embedded_text, mask)

    pooled = []
    for aggregation in self._aggregations:
        if aggregation in ("meanpool", 'maxpool'):
            # Zero out padded positions before pooling.
            broadcast_mask = mask.unsqueeze(-1).float()
            context_vectors = encoded_output * broadcast_mask
            if aggregation == "meanpool":
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            else:
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
        elif aggregation == 'final_state':
            is_bi = self._architecture.is_bidirectional()
            encoded_text = get_final_encoder_states(encoded_output, mask, is_bi)
        elif aggregation == 'attention':
            # Attention weights: masked softmax over the sequence dimension,
            # obtained by exponentiating the masked log-softmax.
            attention_scores = self._attention_layer(encoded_output)
            weights = masked_log_softmax(attention_scores,
                                         mask.unsqueeze(-1),
                                         dim=1).exp()
            encoded_text = (weights * encoded_output).sum(dim=1)
        else:
            raise ConfigurationError(f"{aggregation} aggregation not available.")
        pooled.append(encoded_text)

    return torch.cat(pooled, 1)
def forward(self, x, mask):
    """Return the batch-averaged entropy of the masked softmax of ``x``.

    ``x`` is assumed to be logits. For each batch element the entropy
    ``-sum(p * log p)`` is accumulated over all non-batch dimensions, and the
    mean over the batch is returned.
    """
    probs = util.masked_softmax(x, mask)
    log_probs = util.masked_log_softmax(x, mask)
    p_log_p = probs * log_probs

    # Every dimension except the leading (batch) one.
    reduce_dims = tuple(range(1, len(p_log_p.shape)))
    per_example_entropy = -1.0 * p_log_p.sum(dim=reduce_dims)
    return per_example_entropy.mean()
def _get_next_state_info_without_agenda(
        state: NlvrDecoderState,
        considered_actions: List[List[int]],
        action_logits: torch.Tensor,
        action_mask: torch.Tensor
) -> List[List[Tuple[int, torch.LongTensor]]]:
    """
    Collect, for every group element, the scores of its non-padding actions.

    The action logits are normalised with a masked log-softmax. For each
    group element the result is a list of ``(action_index, score + logprob)``
    pairs, where ``action_index`` is the position inside
    ``considered_actions[group_index]`` and ``score`` is the matching entry
    of ``state.score``. Positions whose considered action is ``-1`` are
    padding and are skipped.
    """
    considered_action_logprobs = nn_util.masked_log_softmax(
        action_logits, action_mask)

    all_action_logprobs: List[List[Tuple[int, torch.LongTensor]]] = []
    scored_groups = zip(state.score, considered_action_logprobs)
    for group_index, (score, considered_logprobs) in enumerate(scored_groups):
        group_actions = considered_actions[group_index]
        # Keep only real actions; -1 marks padding.
        all_action_logprobs.append([
            (action_index, score + logprob)
            for action_index, logprob in enumerate(considered_logprobs)
            if group_actions[action_index] != -1
        ])
    return all_action_logprobs
def _new_entity_loss(self,
                     encoded: torch.Tensor,
                     target_inds: torch.Tensor,
                     shortlist: torch.Tensor,
                     target_mask: torch.Tensor) -> torch.Tensor:
    """
    Negative log likelihood of the target entities, plus accuracy tracking.

    Parameters
    ==========
    encoded : ``torch.Tensor``
        Passed to ``self._new_entity_logits`` together with ``shortlist``.
    target_inds : ``torch.Tensor``
        Either the shortlist inds if using shortlist, otherwise the target entity ids.
        Index 0 is treated as "no target" and contributes nothing.
    shortlist : ``torch.Tensor``
        Shortlist field; its text-field mask restricts the softmax when
        ``self._use_shortlist`` is set.
    target_mask : ``torch.Tensor``
        Its sum (plus 1e-13) is the normaliser of the summed loss.
    """
    logits = self._new_entity_logits(encoded, shortlist)
    if self._use_shortlist:
        # Normalise only over real shortlist entries.
        log_probs = masked_log_softmax(logits, get_text_field_mask(shortlist))
    else:
        log_probs = F.log_softmax(logits, dim=-1)

    # Flatten everything but the category dimension, then pick the log
    # probability assigned to each target.
    log_probs = log_probs.view(-1, log_probs.shape[-1])
    target_inds = target_inds.view(-1)
    target_log_probs = log_probs.gather(-1, target_inds.unsqueeze(-1)).squeeze(-1)

    # Positions whose target index is 0 are excluded from loss and metrics.
    mask = ~target_inds.eq(0)
    target_log_probs[~mask] = 0

    if mask.any():
        kept_predictions = log_probs[mask]
        kept_labels = target_inds[mask]
        self._new_entity_accuracy(predictions=kept_predictions,
                                  gold_labels=kept_labels)
        self._new_entity_accuracy20(predictions=kept_predictions,
                                    gold_labels=kept_labels)

    return -target_log_probs.sum() / (target_mask.sum() + 1e-13)
def predict_labels_doc(self, output_dict):
    """Decode antecedents for one document and, given gold labels, add the loss.

    Reads the scoring results stored in ``output_dict``, writes
    ``output_dict["predicted_antecedents"]`` and, when
    ``output_dict["coref_labels"]`` is not None, also updates the coreference
    metrics and writes ``output_dict["loss"]`` (negative marginal log
    likelihood). Returns the same ``output_dict``.
    """
    coref_labels = output_dict["coref_labels"]
    coreference_scores = output_dict["coreference_scores"]

    # Index 0 is the "no antecedent" class, so shifting by one lines the
    # indices up with actual spans whenever the prediction is greater than -1.
    predicted_antecedents = coreference_scores.max(2)[1]
    predicted_antecedents -= 1
    output_dict["predicted_antecedents"] = predicted_antecedents

    top_span_indices = output_dict["top_span_indices"]
    flat_top_span_indices = output_dict["flat_top_span_indices"]
    valid_antecedent_indices = output_dict["antecedent_indices"]
    valid_antecedent_log_mask = output_dict["valid_antecedent_log_mask"]
    top_spans = output_dict["top_spans"]
    top_span_mask = output_dict["top_span_mask"]
    metadata = output_dict["metadata"]
    sentence_lengths = output_dict["sentence_lengths"]

    if coref_labels is not None:
        # Gold labels of the spans that survived pruning.
        pruned_gold_labels = util.batched_index_select(
            coref_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        # There's an integer wrap-around happening here. It occurs in the original code.
        antecedent_labels += valid_antecedent_log_mask.long()

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Negative marginal log-likelihood over all gold-consistent antecedents.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask)
        correct_antecedent_log_probs = (
            coreference_log_probs + gold_antecedent_labels.log())
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs).sum()

        # Need to get cluster data in same form as for original AllenNLP coref
        # code so that the evaluation code works.
        evaluation_metadata = self._make_evaluation_metadata(
            metadata, sentence_lengths)
        self._mention_recall(top_spans, evaluation_metadata)

        # TODO(dwadden) Shouldnt need to do the unsqueeze here; figure out what's happening.
        self._conll_coref_scores(top_spans,
                                 valid_antecedent_indices.unsqueeze(0),
                                 predicted_antecedents,
                                 evaluation_metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    return output_dict
def forward(
        self,  # type: ignore
        question_passage: Dict[str, torch.LongTensor] = None,
        option: Dict[str, torch.LongTensor] = None,
        answer: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Score every answer option against the question/passage and compute the loss.

    Each option is embedded separately; the representation at position 0 of
    each option is compared with the representation at position 0 of the
    question/passage through a dot product, yielding one logit per option.

    Parameters
    ----------
    question_passage : text-field tensors for the question + passage.
    option : text-field tensors for the options. ``option['token_characters']``
        must be 4-dimensional (batch_size, option_num, option_length, _).
        The dictionary passed in is left untouched.
    answer : gold option index per instance, used as the ``nll_loss`` target.
    metadata : optional per-instance dicts providing ``'question_text'``,
        ``'original_passage'`` and (optionally) ``'answer_text'``.

    Returns
    -------
    A dictionary with ``"loss"`` and, when ``metadata`` is given,
    ``'answer_texts'``, ``'question_tokens'``, ``'passage_tokens'`` and
    ``'_id'`` (currently always an empty list).
    """
    batch_size, option_num, option_length, _ = option['token_characters'].size()

    # Fold the option dimension into the batch dimension. A new dict is built
    # so the caller's batch is not rewritten in place (a second call with the
    # same batch would otherwise fail the 4-way size unpack above).
    flat_option = {}
    for field_name, tensor in option.items():
        if field_name != 'token_characters':
            flat_option[field_name] = tensor.view(batch_size * option_num, -1)
        else:
            flat_option[field_name] = tensor.view(
                batch_size * option_num, option_length, -1)

    question_bert = self._text_field_embedder(question_passage)
    option_bert = self._text_field_embedder(flat_option)

    # Representation at position 0, keeping that dimension. Slicing stays on
    # the embeddings' device; the previous `torch.tensor([0])` index was always
    # created on the CPU and broke `index_select` for GPU inputs.
    question_cls_rep = question_bert[:, :1]
    option_cls_rep = option_bert[:, :1].view(batch_size, option_num, -1)

    # (batch_size, option_num): dot product between question and each option.
    opt_logits = torch.sum(option_cls_rep * question_cls_rep, dim=-1)

    predict_loss = nll_loss(util.masked_log_softmax(opt_logits, None), answer)
    loss = predict_loss
    self._loss_trackers['loss'](loss)
    output_dict = {"loss": loss}

    # Update the accuracy metric and echo the raw texts back to the caller.
    if metadata is not None:
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        ids = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_text'])
            passage_tokens.append(metadata[i]['original_passage'])
            output_dict['answer_texts'].append(metadata[i].get("answer_text"))
            self._categorical_acc(opt_logits[i], answer[i])
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['_id'] = ids
    return output_dict
def forward(self,
            input_ids,
            token_type_ids=None,
            attention_mask=None,
            context_span=None,
            gt_span=None,
            max_context_length=0,
            mode=ForwardMode.TRAIN):
    """Predict answer-span logits restricted to the context part of the input.

    ``max_context_length`` must be precomputed by the caller: the same value
    has to be shared across GPUs, so computing it dynamically is not feasible.

    In ``TRAIN`` mode ``gt_span`` ([B, 2]: start and end index) is required
    and the summed start + end negative log likelihood is returned.
    In any other mode the tuple ``(start_logits, end_logits, context_length)``
    is returned, with masked positions of the logits set to ``-1e18``.
    """
    sequence_output, _ = self.bert_encoder(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
    joint_seq_logits = self.qa_outputs(sequence_output)

    # Keep only the context span of every sequence (context_logits: B, T, 2).
    context_logits, context_length = span_util.span_select(
        joint_seq_logits, context_span, max_context_length)
    context_mask = allen_util.get_mask_from_sequence_lengths(
        context_length, max_context_length)

    # The following lines are from AllenNLP bidaf.
    start_logits = allen_util.replace_masked_values(
        context_logits[:, :, 0], context_mask, -1e18)
    end_logits = allen_util.replace_masked_values(
        context_logits[:, :, 1], context_mask, -1e18)

    if mode == BertSpan.ForwardMode.TRAIN:
        assert gt_span is not None
        # nll_loss needs targets of shape [B]. `squeeze(-1)` collapsed a [1]
        # target to a 0-d tensor when the batch size was 1; `reshape(-1)`
        # keeps [B] as is and still flattens a [B, 1] target.
        gt_start = gt_span[:, 0].reshape(-1)
        gt_end = gt_span[:, 1].reshape(-1)

        start_loss = nll_loss(
            allen_util.masked_log_softmax(start_logits, context_mask), gt_start)
        end_loss = nll_loss(
            allen_util.masked_log_softmax(end_logits, context_mask), gt_end)
        return start_loss + end_loss

    return start_logits, end_logits, context_length
def forward(self, x, y, x_mask, y_mask):
    """Multi-hop pointer over ``x`` producing start, end and yes/no outputs.

    The initial start state is the last time step of ``y``. In each of
    ``self.hop`` iterations the start scores are computed from ``x`` and the
    current state, ``x`` is summarised with the (masked-softmax) attention,
    and the state is updated through an SFU; the same is then done for the
    end pointer. ``y_mask`` is not used.

    Returns ``(p_s, p_e, p_yesno)``. With ``self.normalize`` these are
    (masked) log-softmax outputs suitable for NLL; otherwise they are the
    exponentiated raw scores.
    """
    z_s = y[:, -1, :].unsqueeze(1)  # [B, 1, I]
    z_e = None
    s = None
    e = None
    p_s = None
    p_e = None

    seq_len = x.size(1)
    for hop in range(self.hop):
        # Start pointer: score every position against the start state.
        z_s_ = z_s.repeat(1, seq_len, 1)  # [B, S, I]
        s = self.FFNs_start[hop](torch.cat([x, z_s_, x * z_s_], 2)).squeeze(2)
        start_attention = util.masked_softmax(s, x_mask, dim=1)  # [B, S]
        u_s = start_attention.unsqueeze(1).bmm(x)  # [B, 1, I]
        z_e = self.SFUs_start[hop](z_s, u_s)  # [B, 1, I]

        # End pointer: same procedure with the end state.
        z_e_ = z_e.repeat(1, seq_len, 1)  # [B, S, I]
        e = self.FFNs_end[hop](torch.cat([x, z_e_, x * z_e_], 2)).squeeze(2)
        end_attention = util.masked_softmax(e, x_mask, dim=1)  # [B, S]
        u_e = end_attention.unsqueeze(1).bmm(x)  # [B, 1, I]
        z_s = self.SFUs_end[hop](z_e, u_e)

    # Yes/no scores are derived from the end state of the final hop.
    yesno = self.yesno_predictor(torch.cat([x, z_e_, x * z_e_], 2))

    if self.normalize:
        # Log-probabilities, as needed for an NLL loss.
        p_s = util.masked_log_softmax(s, x_mask, dim=1)  # [B, S]
        p_e = util.masked_log_softmax(e, x_mask, dim=1)  # [B, S]
        p_yesno = F.log_softmax(yesno, dim=2)
    else:
        p_s = s.exp()
        p_e = e.exp()
        p_yesno = yesno.exp()
    return p_s, p_e, p_yesno
def loss(self,
         edge_scores: torch.Tensor,
         head_indices: torch.Tensor,
         mask: torch.Tensor) -> torch.Tensor:
    """
    Computes the edge loss for a sequence given gold head indices.

    Parameters
    ----------
    edge_scores : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``, required.
        The negative log likelihood from the arc loss, divided by the number
        of valid positions when ``self.normalize_wrt_seq_len`` is set.
    """
    batch_size, sequence_length, _ = edge_scores.size()
    device = get_device_of(edge_scores)
    float_mask = mask.float()

    # shape (batch_size, 1)
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)

    # shape (batch_size, sequence_length, sequence_length); rows and columns
    # of padded tokens are zeroed out.
    pairwise_mask = float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
    normalised_arc_logits = masked_log_softmax(
        edge_scores, mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
    del pairwise_mask  # illustrative only; the product above keeps the original operation order

    # index matrix with shape (batch, sequence_length)
    child_index = get_range_vector(sequence_length, device).view(
        1, sequence_length).expand(batch_size, sequence_length).long()

    # shape (batch_size, sequence_length): log-probability of every gold arc.
    gold_arc_log_probs = normalised_arc_logits[batch_index, child_index, head_indices]

    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    gold_arc_log_probs = gold_arc_log_probs[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements
    # minus 1 per sequence in the batch, to account for the symbolic ROOT token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -gold_arc_log_probs.sum()
    if self.normalize_wrt_seq_len:
        arc_nll /= valid_positions.float()
    return arc_nll
def _parent_log_probs(self,
                      encoded_head: torch.Tensor,
                      entity_ids: torch.Tensor,
                      parent_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability assigned to each gold parent among the recent entities.

    Candidates are obtained from ``self._recent_entities(entity_ids)``. The
    (dropout-regularised) hidden state is scored against the candidate
    embeddings with a batched matrix product and normalised with a masked
    log-softmax. For every gold parent the log-probability of the matching,
    non-null (id != 0) candidate is gathered; when no candidate matches, the
    gathered value carries the ``log(1e-45)`` penalty.
    """
    # Lookup recent entities (which are candidates for parents) and embed them.
    candidate_ids, candidate_mask = self._recent_entities(entity_ids)
    logger.debug('Candidate ids shape: %s', candidate_ids.shape)
    embedding_dropout = self._dropoute if self.training else 0
    candidate_embeddings = embedded_dropout(self._entity_embedder,
                                            words=candidate_ids,
                                            dropout=embedding_dropout)

    # Similarity between the projected hidden state and each candidate.
    encoded = self._locked_dropout(encoded_head, self._dropout)
    selection_logits = torch.bmm(encoded, candidate_embeddings.transpose(1, 2))

    # shape: (batch_size, sequence_length, num_candidates)
    log_probs = masked_log_softmax(selection_logits, candidate_mask)

    # Align candidates with parent ids through a broadcast equality test.
    batch_size, num_candidates = candidate_ids.shape
    # shape: (batch_size, 1, 1, num_candidates)
    candidates = candidate_ids.view(batch_size, 1, 1, num_candidates)
    # shape: (batch_size, sequence_length, num_parents, num_candidates)
    is_parent = parent_ids.unsqueeze(-1).eq(candidates)
    # Null candidates (id 0) must never be selected.
    non_null = ~candidates.eq(0)
    mask = is_parent & non_null

    # Multiplication is addition in log-space, so the mask is applied by adding
    # its log (plus a small constant for numerical stability).
    log_mask = (mask.float() + 1e-45).log()
    masked_log_probs = log_probs.unsqueeze(2) + log_mask
    logger.debug('Masked log probs shape: %s', masked_log_probs.shape)

    # The candidate dimension is essentially a delta function, so instead of
    # marginalising it out we gather the single selected entry.
    _, index = torch.max(mask, dim=-1, keepdim=True)
    return torch.gather(masked_log_probs, dim=-1, index=index).squeeze(-1)
def score_spans_if_labels(
        self,
        output_dict,
        span_labels,
        metadata,
        top_span_indices,
        flat_top_span_indices,
        top_span_mask,
        top_spans,
        valid_antecedent_indices,
        valid_antecedent_log_mask,
        coreference_scores,
        predicted_antecedents,
):
    """Add the coreference loss (and the source documents) to ``output_dict``.

    When ``span_labels`` is given, the negative marginal log likelihood of
    all gold-consistent antecedents is stored under ``"loss"`` and the
    mention-recall and CoNLL metrics are updated. When ``metadata`` is given,
    the ``"original_text"`` of each entry is stored under ``"document"``.
    Returns ``output_dict``.
    """
    if span_labels is not None:
        # Gold labels of the spans that were kept after pruning.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Each span predicts a single antecedent, but several earlier mentions
        # of the same gold cluster may be valid. The objective therefore
        # marginalises over every antecedent that is consistent with the gold
        # clustering: adding log(gold label) keeps the valid antecedents and
        # sends the others to -inf before the logsumexp.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask)
        correct_antecedent_log_probs = (
            coreference_log_probs + gold_antecedent_labels.log())
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [entry["original_text"] for entry in metadata]
    return output_dict
def forward(self,
            input_ids,
            token_type_ids=None,
            attention_mask=None,
            gt_span=None,
            mode=ForwardMode.TRAIN):
    """Predict answer-span logits over the whole (joint) input sequence.

    In ``TRAIN`` mode ``gt_span`` ([B, 2]: start and end index) is required
    and the summed start + end negative log likelihood is returned.
    In any other mode the tuple ``(start_logits, end_logits, joint_length)``
    is returned, with masked positions of the logits set to ``-1e18``.
    """
    sequence_output, _ = self.bert_encoder(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
    joint_length = allen_util.get_lengths_from_binary_sequence_mask(
        attention_mask)

    # joint_seq_logits: B, T, 2 (channel 0 = start, channel 1 = end).
    joint_seq_logits = self.qa_outputs(sequence_output)

    # The following lines are from AllenNLP bidaf.
    start_logits = allen_util.replace_masked_values(
        joint_seq_logits[:, :, 0], attention_mask, -1e18)
    end_logits = allen_util.replace_masked_values(
        joint_seq_logits[:, :, 1], attention_mask, -1e18)

    if mode != BertSpan.ForwardMode.TRAIN:
        return start_logits, end_logits, joint_length

    assert gt_span is not None
    # gt_span: [B, 2] -> two targets of shape [B]. They are deliberately not
    # squeezed: squeezing causes problems when the batch size is 1.
    gt_start = gt_span[:, 0]
    gt_end = gt_span[:, 1]

    start_loss = nll_loss(
        allen_util.masked_log_softmax(start_logits, attention_mask), gt_start)
    end_loss = nll_loss(
        allen_util.masked_log_softmax(end_logits, attention_mask), gt_end)
    return start_loss + end_loss
def take_step(
        self,
        last_predictions: torch.Tensor,
        state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Take a decoding step. This is called by the beam search class.

    Parameters
    ----------
    last_predictions : ``torch.Tensor``
        A tensor of shape ``(group_size,)`` holding the indices predicted at
        the previous time step.
    state : ``Dict[str, torch.Tensor]``
        Tensors describing the current decoding state (encoder outputs,
        source mask, decoder hidden state and context), each of shape
        ``(group_size, *)``. It must also provide ``"allowed_token_ids"`` and
        ``"allowed_mask"``, which restrict the output classes.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        ``(log_probabilities, updated_state)``. ``log_probabilities`` has
        shape ``(group_size, num_classes)`` and is normalised only over the
        classes that were not filled by the target adjustment.

    Notes
    -----
    The inputs are treated as a batch even though ``group_size`` need not
    equal ``batch_size``: the group may hold several states per source
    sentence.
    """
    fill_value = -1e32

    # shape: (group_size, num_classes)
    output_projections, state = self._prepare_output_projections(
        last_predictions, state)
    output_projections = self._adjust_target_outputs(
        output_projections,
        state["allowed_token_ids"],
        state["allowed_mask"],
        fill_value=fill_value)

    # Every entry that still differs from the fill value takes part in the
    # softmax. shape: (group_size, num_classes)
    allowed = (output_projections != fill_value).float()
    class_log_probabilities = util.masked_log_softmax(
        output_projections, allowed, dim=-1)

    return class_log_probabilities, state
def _arithmetic_module(self, arithmetic_passage_vector, passage_out, number_indices, number_mask):
    """
    Score every number in the passage for every slot of every arithmetic template.

    Shape notes below are carried over from the original author's annotations.

    Parameters
    ----------
    arithmetic_passage_vector
        Passage summary that is tiled and concatenated onto each number encoding.
    passage_out
        Encoded passage, indexed along dimension 1 by ``number_indices``.
    number_indices
        Word-piece positions of each number, padded with ``-1``
        (batch_size, # of numbers, # of pieces).
    number_mask
        Mask over the numbers; extended with ones for the special numbers
        when ``self.num_special_numbers > 0``.

    Returns
    -------
    ``(log_probs, best_slots, number_mask)`` where ``log_probs`` is reshaped to
    (batch_size, #templates, #slots, #numbers), ``best_slots`` is its argmax
    over the last dimension, and ``number_mask`` is the (possibly extended) mask.
    """
    if self.number_rep in ['average', 'attention']:
        # Redirect the -1 padding entries to position 0 so they are valid
        # scatter indices; position 0 is cleared again right after.
        # Shape: (batch_size, # of numbers, # of pieces)
        number_indices = util.replace_masked_values(
            number_indices, number_indices != -1, 0).long()
        batch_size, num_numbers = number_indices.shape[0], number_indices.shape[1]
        seqlen = passage_out.shape[1]

        # Mark, per number, the passage positions occupied by its pieces.
        # Shape : (batch_size, # of numbers, seqlen)
        piece_mask = torch.zeros(
            (batch_size, num_numbers, seqlen), device=number_indices.device).long()
        piece_mask = piece_mask.scatter(
            2,
            number_indices,
            torch.ones(number_indices.shape, device=number_indices.device).long())
        piece_mask[:, :, 0] = 0

        # Shape : (batch_size, # of numbers, seqlen, bert_dim)
        tiled_passage = passage_out.unsqueeze(1).repeat(1, num_numbers, 1, 1)
        # Shape : (batch_size, # of numbers, bert_dim)
        encoded_numbers = self.summary_vector(tiled_passage, piece_mask, "numbers")
    else:
        # Represent each number by its first word piece only.
        number_indices = number_indices[:, :, 0].long()
        clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
        gather_index = clamped_number_indices.unsqueeze(-1).expand(
            -1, -1, passage_out.size(-1))
        encoded_numbers = torch.gather(passage_out, 1, gather_index)

    if self.num_special_numbers > 0:
        # Prepend the learned special-number embeddings (always unmasked).
        special_ids = torch.arange(self.num_special_numbers, device=number_indices.device)
        special_numbers = self.special_embedding(special_ids)
        special_numbers = special_numbers.expand(number_indices.shape[0], -1, -1)
        encoded_numbers = torch.cat([special_numbers, encoded_numbers], 1)

        special_mask = torch.ones(
            (number_indices.shape[0], self.num_special_numbers),
            device=number_indices.device).long()
        number_mask = torch.cat([special_mask, number_mask], -1)

    # Shape: (batch_size, # of numbers, 2*bert_dim)
    tiled_passage_vector = arithmetic_passage_vector.unsqueeze(1).repeat(
        1, encoded_numbers.size(1), 1)
    encoded_numbers = torch.cat([encoded_numbers, tiled_passage_vector], -1)

    # Numbers are moved to the last dimension so the softmax runs over them.
    slot_logits = self._arithmetic_template_slot_predictor(encoded_numbers).transpose(1, 2)
    slot_log_probs = util.masked_log_softmax(slot_logits, number_mask)
    # Shape: (batch_size, #templates, #slots, #numbers)
    slot_log_probs = slot_log_probs.reshape(
        number_mask.shape[0],
        self.num_arithmetic_templates,
        self.num_template_slots,
        number_mask.shape[-1])

    # Shape: (batch_size, #templates, #slots)
    best_slots = slot_log_probs.argmax(-1)
    return slot_log_probs, best_slots, number_mask
def loss(self,
         edge_label_logits: torch.Tensor,
         mask: torch.Tensor,
         head_tags: torch.Tensor) -> torch.Tensor:
    """
    Computes the edge label (tag) loss for a sequence given the gold tags.

    Parameters
    ----------
    edge_label_logits : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, num_head_tags),
        that contains raw predictions for incoming edge labels.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.

    Returns
    -------
    tag_nll : ``torch.Tensor``
        The negative log likelihood of the gold edge labels, summed over all
        positions except the first one of each sequence. When
        ``self.normalize_wrt_seq_len`` is true the sum is divided by the
        number of unmasked elements minus one per sequence.
    """
    batch_size, sequence_length, _ = edge_label_logits.size()
    device = get_device_of(edge_label_logits)

    # shape (batch_size, sequence_length, num_head_tags)
    # Padded positions are zeroed so they contribute nothing to the sum.
    keep = mask.float().unsqueeze(-1)
    log_probs = masked_log_softmax(edge_label_logits, mask.unsqueeze(-1)) * keep

    # shape (batch_size, 1)
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)
    # index matrix with shape (batch, sequence_length)
    position_index = get_range_vector(sequence_length, device)
    position_index = position_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()

    # shape (batch_size, sequence_length): log-probability of each gold tag.
    gold_log_probs = log_probs[batch_index, position_index, head_tags]
    # The symbolic ROOT token's head is not a real prediction, so the first
    # position is dropped from the loss.
    gold_log_probs = gold_log_probs[:, 1:]
    total_nll = -gold_log_probs.sum()

    if not self.normalize_wrt_seq_len:
        return total_nll

    # Unmasked elements minus one per sequence (the symbolic HEAD token).
    valid_positions = mask.sum() - batch_size
    return total_nll / valid_positions.float()
def _construct_loss(
        self,
        head_tag_representation: torch.Tensor,
        child_tag_representation: torch.Tensor,
        attended_arcs: torch.Tensor,
        head_indices: torch.Tensor,
        head_tags: torch.Tensor,
        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag negative log likelihoods given gold heads and tags.

    Parameters
    ----------
    head_tag_representation, child_tag_representation : ``torch.Tensor``
        Passed to ``self._get_head_tags`` together with ``head_indices`` to
        obtain the tag logits.
    attended_arcs : ``torch.Tensor``
        Three-dimensional arc scores; the comments below follow the original
        annotation of (batch_size, sequence_length, sequence_length).
    head_indices : ``torch.Tensor``
        Gold head index for every word.
    head_tags : ``torch.Tensor``
        Gold dependency label for every word.
    mask : ``torch.Tensor``
        Mask over the sequence, denoting unpadded elements.

    Returns
    -------
    arc_nll, tag_nll : ``torch.Tensor``
        Both are normalised by the number of unmasked elements minus one per
        sequence.
    """
    float_mask = mask.float()

    # Push every arc that touches a padded position (as child or as head)
    # towards a very negative score before normalising.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    batch_size, sequence_length, _ = attended_arcs.size()
    device = get_device_of(attended_arcs)

    # shape (batch_size, sequence_length, sequence_length)
    arc_log_probs = masked_log_softmax(attended_arcs, mask)
    arc_log_probs = arc_log_probs * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(
        head_tag_representation, child_tag_representation, head_indices)
    tag_log_probs = torch.nn.functional.log_softmax(head_tag_logits, dim=-1)
    tag_log_probs = tag_log_probs * float_mask.unsqueeze(-1)

    # shape (batch_size, 1)
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)
    # index matrix with shape (batch, sequence_length)
    position_index = get_range_vector(sequence_length, device)
    position_index = position_index.view(1, sequence_length).expand(
        batch_size, sequence_length).long()

    # shape (batch_size, sequence_length); the first position is the symbolic
    # ROOT token, whose head is not a real prediction, so it is dropped.
    arc_loss = arc_log_probs[batch_index, position_index, head_indices][:, 1:]
    tag_loss = tag_log_probs[batch_index, position_index, head_tags][:, 1:]

    # Unmasked elements minus one per sequence (the symbolic HEAD token).
    valid_positions = mask.sum() - batch_size
    normaliser = valid_positions.float()

    arc_nll = -arc_loss.sum() / normaliser
    tag_nll = -tag_loss.sum() / normaliser
    return arc_nll, tag_nll
def forward(self, frames: torch.FloatTensor, frame_lens: torch.LongTensor):
    """
    Encode a padded batch of frame sequences with ``self.rnn``.

    Parameters
    ----------
    frames
        (batch_size, seq_len, num_lmks, lmk_dim). When
        ``self.frame_processing == 'flatten'`` everything after the first two
        dimensions is flattened into one feature dimension.
    frame_lens
        (batch_size, ) unpadded length of every sequence.

    Returns
    -------
    ``(hidden_states, final_state)`` in the original batch order, or
    ``(output_log_probs, hidden_states, final_state)`` when
    ``self.enable_ctc`` is set. ``final_state`` is a tuple when the RNN
    returns one (LSTM) and a single tensor otherwise.
    """
    if self.frame_processing == 'flatten':
        frames = frames.reshape(frames.shape[0], frames.shape[1], -1)

    # Packing needs the batch sorted by decreasing unpadded length; keep the
    # indices that undo the sort.
    sorted_frames, sorted_frame_lens, restoration_indices, _ = \
        sort_batch_by_length(frames, frame_lens)

    # The lengths are handed to the packer as a host-side numpy array.
    lengths = sorted_frame_lens.data
    if sorted_frame_lens.is_cuda:
        lengths = lengths.cpu()
    packed_frames = nn.utils.rnn.pack_padded_sequence(
        sorted_frames, lengths.numpy(), batch_first=True)

    # Encoder: feed frames to the model, output hidden states.
    # final_state: (num_layers * num_dir, batch_size, hidden_size) (*2 if LSTM)
    packed_hidden_states, final_state = self.rnn(packed_frames)

    # Back to a padded tensor: (batch_size, seq_len, num_dir * hidden_size)
    hidden_states, _ = nn.utils.rnn.pad_packed_sequence(
        packed_hidden_states, batch_first=True)

    # (num_layers, batch_size, hidden_size * num_dir) (*2 if LSTM)
    if self.bidirectional:
        final_state = self._cat_directions(final_state)

    # Undo the length sort: batch is dim 0 of the outputs, dim 1 of the states.
    hidden_states = hidden_states.index_select(0, restoration_indices)
    if isinstance(final_state, tuple):  # LSTM
        hidden, cell = final_state[0], final_state[1]
        final_state = (hidden.index_select(1, restoration_indices),
                       cell.index_select(1, restoration_indices))
    else:
        final_state = final_state.index_select(1, restoration_indices)

    if not self.enable_ctc:
        return hidden_states, final_state

    output_logits = self.output_proj(hidden_states)
    vocab_mask = self.output_mask.expand(output_logits.shape[0], self.adj_vocab_size)
    output_log_probs = masked_log_softmax(output_logits, vocab_mask, dim=-1)
    return output_log_probs, hidden_states, final_state
def _compute_loss(self,
                  logits: torch.Tensor,
                  mask: torch.Tensor,
                  targets: torch.Tensor) -> torch.Tensor:
    """
    Compute the batch-averaged negative log likelihood of the target labels.

    :param logits: unnormalised scores; a log-softmax is taken over the last
        dimension (presumably (batch_size, seq_len, num_labels) -- confirm
        against the caller).
    :param mask: two-dimensional mask (batch_size, seq_len); zero entries mark
        padded positions, which do not contribute to the loss.
    :param targets: gold label indices, gathered along the last dimension of
        the normalised scores; expected to have the same shape as ``mask``.
    :return: the summed negative log likelihood of the targets at unmasked
        positions, divided by the batch size.
    """
    # Unpacking also enforces that the mask is two-dimensional.
    batch_size, _ = mask.shape

    normalised_emission = masked_log_softmax(logits, mask.unsqueeze(-1), dim=-1)
    target_log_probs = normalised_emission.gather(
        dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

    # The mask above is constant across the label dimension, so at a padded
    # position it only shifts all scores equally and the log-softmax still
    # returns an ordinary distribution. Without this explicit zeroing, padded
    # timesteps would add to the loss.
    target_log_probs = target_log_probs.masked_fill(mask == 0, 0.0)

    return -1 * target_log_probs.sum() / batch_size
def take_step(self,
              last_predictions: torch.Tensor,
              state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Run one decoding step over the target classes plus the source positions.

    This is the step function handed to the beam search.

    Parameters
    ----------
    last_predictions : ``torch.Tensor``
        Shape ``(group_size,)``; the indices predicted at the previous step.
    state : ``Dict[str, torch.Tensor]``
        Decoder state (encoder outputs, source mask, decoder hidden state and
        context, ...). Every tensor has shape ``(group_size, *)``. The key
        ``'source_mask'`` is read from the state returned by
        ``self._prepare_output_projections``.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        ``(log_probabilities, updated_state)``; per the shape notes below the
        log probabilities span ``num_classes + max_input_sequence_length``
        entries for each item in the group.

    Notes
    -----
    ``group_size`` need not equal ``batch_size``: the group may hold several
    beam states per source sentence, and all of them are treated as a batch.
    """
    # shape: (group_size, num_classes + max_input_sequence_length)
    logits, state = self._prepare_output_projections(last_predictions, state)

    source_mask = state['source_mask']
    # Every regular class is always allowed; source positions follow the
    # source mask.
    class_mask = source_mask.new_ones((source_mask.size(0), self._num_classes))
    # shape: (group_size, num_classes + max_input_sequence_length)
    normalization_mask = torch.cat([class_mask, source_mask], dim=-1)

    # shape: (group_size, num_classes + max_input_sequence_length)
    log_probs = util.masked_log_softmax(logits, normalization_mask, dim=-1)
    return log_probs, state
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
            answer_choice=None, sentence_span_list=None, sentence_ids=None):
    """
    Predict a yes/no answer from a question-weighted mix of sentence vectors.

    The encoder output is split into per-sentence document tokens and question
    tokens; each is pooled by self-attention, sentences are scored against the
    question, and the softmax-weighted sentence summary is concatenated with
    the question vector to predict the answer.

    Parameters
    ----------
    input_ids, token_type_ids, attention_mask
        Passed to ``self.bert``; the latter two are also used to split the
        sequence into document sentences and question.
    answer_choice
        Optional gold answer labels; ``-1`` entries are ignored in the loss.
    sentence_span_list
        Sentence spans forwarded to ``layers.split_doc_sen_que``.
    sentence_ids
        Optional gold evidence-sentence indices; ``-1`` entries are ignored.

    Returns
    -------
    dict with ``yesno_logits``, ``sentence_logits`` (the softmax-normalised
    sentence scores), ``max_weight_index``, ``max_weight`` and ``loss``.
    ``loss`` stays the integer ``0`` when neither label tensor is given.
    """
    sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                   output_all_encoded_layers=False)

    # Masks returned here use 1 for masked values and 0 for true values.
    doc_sen, que, doc_sen_mask, que_mask, sentence_mask = layers.split_doc_sen_que(
        sequence_output, token_type_ids, attention_mask, sentence_span_list)
    batch, max_sen, doc_len = doc_sen_mask.size()

    # Question pooled to one vector per example: (batch, 1, hidden).
    que_weights = self.que_self_attn(que, que_mask)
    que_vec = layers.weighted_avg(que, que_weights).view(batch, 1, -1)

    # Sentences are folded into the batch dimension, pooled, and unfolded again.
    doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
    doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
    doc_weights = self.doc_sen_self_attn(doc, doc_mask)
    doc_vecs = layers.weighted_avg(doc, doc_weights).view(batch, max_sen, -1)

    sentence_sim = self.vector_similarity(que_vec, doc_vecs)
    # The softmax helpers expect the opposite mask polarity.
    keep_sentences = 1 - sentence_mask

    sentence_hidden = masked_softmax(sentence_sim, keep_sentences).bmm(doc_vecs).squeeze(1)
    yesno_input = torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)
    yesno_logits = self.yesno_predictor(yesno_input)

    sentence_scores = masked_softmax(sentence_sim, keep_sentences, dim=-1).squeeze(1)
    max_weight, max_weight_index = sentence_scores.max(dim=1)
    output_dict = {'yesno_logits': yesno_logits,
                   'sentence_logits': sentence_scores,
                   'max_weight_index': max_weight_index,
                   'max_weight': max_weight}

    loss = 0
    if answer_choice is not None:
        loss += F.cross_entropy(yesno_logits, answer_choice, ignore_index=-1)
    if sentence_ids is not None:
        log_sentence_sim = masked_log_softmax(sentence_sim.squeeze(1), keep_sentences, dim=-1)
        loss += self.evidence_lam * F.nll_loss(log_sentence_sim, sentence_ids, ignore_index=-1)
    output_dict['loss'] = loss
    return output_dict
def _get_next_state_info_without_agenda(state: NlvrDecoderState,
                                        considered_actions: List[List[int]],
                                        action_logits: torch.Tensor,
                                        action_mask: torch.Tensor) -> List[List[Tuple[int, torch.LongTensor]]]:
    """
    Collect, for every group element, the scored non-padding actions.

    The action logits are normalised with a masked log-softmax and each
    resulting log probability is added to the running score of its group
    element. Entries whose considered action is ``-1`` (padding) are skipped.
    Used in the training scenario where target action sequences are available.

    Returns one list per group element containing
    ``(index into considered actions, score + log probability)`` pairs.
    """
    action_logprobs = nn_util.masked_log_softmax(action_logits, action_mask)
    return [
        [(action_index, score + logprob)
         for action_index, logprob in enumerate(group_logprobs)
         # -1 in the original action list marks padding.
         if considered_actions[group_index][action_index] != -1]
        for group_index, (score, group_logprobs)
        in enumerate(zip(state.score, action_logprobs))
    ]
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Run the (experimental) BiDAF variant and predict the answer span.

    In this variant the question is NOT passed through the phrase layer: the
    embedded question only goes through dropout ("v5" below), while the
    passage is still encoded by ``self._phrase_layer``.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        Only read when ``span_start`` is given.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch.  When present, each dictionary must have the
        keys ``question_tokens``, ``passage_tokens``, ``original_passage`` and
        ``token_offsets``; the optional key ``answer_texts`` enables the SQuAD metrics for
        that instance.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        The passage-to-question attention, shape
        ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position, with masked positions replaced by ``-1e7``.
    span_start_probs : torch.FloatTensor
        The masked softmax of the span start logits.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive), with masked positions replaced by
        ``-1e7``.
    span_end_probs : torch.FloatTensor
        The masked softmax of the span end logits.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Only present when ``span_start`` is given.
    best_span_str : List[str]
        Only present when ``metadata`` is given: the string from the original passage that
        the model thinks is the best answer to the question.
    question_tokens, passage_tokens : List
        Only present when ``metadata`` is given: copied from the metadata.
    """
    embedded_question = self._highway_layer(
        self._text_field_embedder(question))
    embedded_passage = self._highway_layer(
        self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    # NOTE(review): question_lstm_mask is only referenced by the disabled
    # phrase-layer line below, so it is currently unused.
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # Disabled original question encoding:
    # encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    # v5 (active): skip the phrase layer for the question.
    # Remember to set token embeddings in the CONFIG JSON.
    encoded_question = self._dropout(embedded_question)
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
    similarity_matrix = self._matrix_attention(encoded_passage,
                                               encoded_question)

    # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
    passage_question_attention = util.last_dim_softmax(
        similarity_matrix, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)

    # Our custom query2context.
    # NOTE(review): q2c_attention / q2c_vecs are only consumed by the disabled
    # variants v1-v4 below; with the "original variant" active they are
    # computed but never used.
    q2c_attention = util.masked_softmax(similarity_matrix,
                                        question_mask,
                                        dim=1).transpose(-1, -2)
    q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

    # Now we try the various variants (all disabled; kept as experiment notes).
    # v1:
    # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)
    # v2:
    # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
    # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)
    # v3:
    # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
    # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)
    # v4:
    # Re-application of query2context attention
    # new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
    # masked_similarity = util.replace_masked_values(new_similarity_matrix,
    #                                                question_mask.unsqueeze(1),
    #                                                -1e7)
    # # Shape: (batch_size, passage_length)
    # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # # Shape: (batch_size, passage_length)
    # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # # Shape: (batch_size, encoding_dim)
    # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # # Shape: (batch_size, passage_length, encoding_dim)
    # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
    #                                                                             passage_length,
    #                                                                             encoding_dim)

    # ------- Original variant (active)
    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(
        similarity_matrix, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(
        dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size, passage_length, encoding_dim)
    # ------- END

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    # original beta combination function
    final_merged_passage = torch.cat([
        encoded_passage, passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ],
                                     dim=-1)

    # Disabled alternative combination functions (experiment notes):
    # # v6:
    # final_merged_passage = torch.cat([tiled_question_passage_vector],
    #                                  dim=-1)
    #
    # # v7:
    # final_merged_passage = torch.cat([passage_question_vectors],
    #                                  dim=-1)
    #
    # # v8:
    # final_merged_passage = torch.cat([passage_question_vectors,
    #                                   tiled_question_passage_vector],
    #                                  dim=-1)
    #
    # # v9:
    # final_merged_passage = torch.cat([encoded_passage,
    #                                   passage_question_vectors,
    #                                   encoded_passage * passage_question_vectors],
    #                                  dim=-1)

    modeled_passage = self._dropout(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(
        torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(
        span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # The span-end prediction is conditioned on the expected start
    # representation under the start distribution.
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage,
                                                  span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(
        1).expand(batch_size, passage_length, modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([
        final_merged_passage, modeled_passage, tiled_start_representation,
        modeled_passage * tiled_start_representation
    ],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(
        torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Padded positions get a very negative score so that they can never be
    # selected by the best-span search below.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits,
                                  span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span,
                            torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            # Map the predicted (inclusive) token span back to character
            # offsets: start of the first token to end of the last token.
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            # Metrics are only updated for instances that carry gold answers.
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score candidate mention spans, prune them, and predict an antecedent for
    every surviving span.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans are marked with negative indices.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
        When given, the loss is computed and the metrics are updated.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        One dictionary per instance. When given, the ``"original_text"`` entry of each
        dictionary is copied into the output. It is also passed to the metrics.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised. Only present when ``span_labels`` is given.
    document : ``List``, optional
        The ``"original_text"`` of every instance. Only present when ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0, which the relu below guarantees. This is
    # only relevant in edge cases where the number of spans we consider
    # after the pruning stage is >= the total number of spans, because
    # in this case, it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores. The number of spans kept scales with
    # the document length.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        # Adding log(label) keeps the log-probability of gold antecedents
        # (log 1 = 0) and sends every other antecedent to -inf (log 0).
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        # NOTE(review): the metrics receive ``metadata`` even when it is None;
        # confirm they tolerate that or that callers always pass metadata
        # together with span_labels.
        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc loss and the tag loss of a batch, given the gold head
    indices and the gold head tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to
        generate a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        The negative log likelihood from the arc loss.
    tag_nll : ``torch.Tensor``
        The negative log likelihood from the arc tag loss.
    """
    device = get_device_of(attended_arcs)
    batch_size, sequence_length, _ = attended_arcs.size()
    float_mask = mask.float()

    # Index helpers used to pick one gold entry per (instance, child word).
    # shape (batch_size, 1)
    batch_index = get_range_vector(batch_size, device).unsqueeze(1)
    # shape (batch_size, sequence_length)
    word_positions = get_range_vector(sequence_length, device)
    child_index = word_positions.view(1, sequence_length).expand(batch_size, sequence_length).long()

    # shape (batch_size, sequence_length, sequence_length)
    # Entries belonging to padded rows or padded columns are zeroed out.
    normalised_arcs = masked_log_softmax(attended_arcs, mask)
    normalised_arcs = normalised_arcs * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation,
                                          child_tag_representation,
                                          head_indices)
    normalised_tags = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1))
    normalised_tags = normalised_tags * float_mask.unsqueeze(-1)

    # Select the gold entries, then drop the first column: we don't care about
    # predictions for the symbolic ROOT token's head.
    gold_arc_scores = normalised_arcs[batch_index, child_index, head_indices][:, 1:]
    gold_tag_scores = normalised_tags[batch_index, child_index, head_tags][:, 1:]

    # The number of valid positions is the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    num_valid_positions = (mask.sum() - batch_size).float()
    arc_nll = -gold_arc_scores.sum() / num_valid_positions
    tag_nll = -gold_tag_scores.sum() / num_valid_positions
    return arc_nll, tag_nll
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Runs bidirectional attention between question and passage and predicts the
    start and end positions of the answer span inside the passage.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary
        (``span_end`` must then be given as well).
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        It is only used when ``span_start`` is given.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per instance in the batch.  When present, the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage`` and ``token_offsets``
        are read for every instance; ``answer_texts`` is optional and, when non-empty, is
        used to update the SQuAD EM / F1 metrics.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        The attention of every passage token over the question tokens, with shape
        ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (masked positions are set to -1e7).
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive; masked positions are set to -1e7).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Only present when ``span_start`` is given.
    best_span_str : List[str]
        Only present when ``metadata`` is given: the string from the original passage that
        the model thinks is the best answer to the question.
    question_tokens, passage_tokens : List
        Only present when ``metadata`` is given: copied from the metadata of each instance.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    # ``max`` already reduces the question axis.  A trailing ``squeeze(-1)`` here would
    # additionally drop the passage axis for single-token passages and break the
    # masked softmax below, so it is deliberately not applied.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, span_end_encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Push masked positions far down so they can never be picked as a span boundary.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            instance_metadata = metadata[i]
            question_tokens.append(instance_metadata['question_tokens'])
            passage_tokens.append(instance_metadata['passage_tokens'])
            passage_str = instance_metadata['original_passage']
            offsets = instance_metadata['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Character offsets: start of the first token up to the end of the last token.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = instance_metadata.get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
# Pick the most probable (start, end) pair from the two logit tensors.
best_span = BidirectionalAttentionFlow_1.get_best_span(span_start_logits, span_end_logits)
print("best_spans", best_span)

# ------------------------------ GET LOSSES AND ACCURACIES -----------------------------------
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                               span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss

    # Update each metric with this batch, then read the current values back.
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    # The original reported the start accuracy twice and never the end accuracy.
    print("span_end_accuracy: ", span_end_accuracy)
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to
        generate a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        The negative log likelihood from the arc loss.
    tag_nll : ``torch.Tensor``
        The negative log likelihood from the arc tag loss.
    """
    batch_size, sequence_length, _ = attended_arcs.size()
    float_mask = mask.float()
    row_mask = float_mask.unsqueeze(2)
    column_mask = float_mask.unsqueeze(1)

    # shape (batch_size, 1)
    range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
    # index matrix with shape (batch, sequence_length)
    timesteps = get_range_vector(sequence_length, get_device_of(attended_arcs))
    child_index = timesteps.view(1, sequence_length).expand(batch_size, sequence_length).long()

    def gold_scores(scores: torch.Tensor, gold: torch.Tensor) -> torch.Tensor:
        # One entry per (instance, word) at the gold index.  The first column is
        # dropped because we don't care about predictions for the symbolic ROOT
        # token's head.
        return scores[range_vector, child_index, gold][:, 1:]

    # shape (batch_size, sequence_length, sequence_length)
    arc_scores = masked_log_softmax(attended_arcs, mask) * row_mask * column_mask

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation,
                                          child_tag_representation,
                                          head_indices)
    tag_scores = masked_log_softmax(head_tag_logits, mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)

    # shape (batch_size, sequence_length - 1)
    arc_loss = gold_scores(arc_scores, head_indices)
    tag_loss = gold_scores(tag_scores, head_tags)

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size
    normaliser = valid_positions.float()

    return -arc_loss.sum() / normaliser, -tag_loss.sum() / normaliser
def _take_first_step(self,
                     state: WikiTablesDecoderState,
                     allowed_actions: List[Set[int]] = None) -> List[WikiTablesDecoderState]:
    """
    Scores the start-type actions for every group element of ``state`` and returns one
    new single-element state per emitted (group element, action) pair.

    The score is a projection of the current hidden state onto the start types,
    normalised with a log-softmax.  When ``allowed_actions`` is given, only actions
    contained in ``allowed_actions[group_index]`` produce a new state.  The RNN state
    is passed through unchanged.

    Raises ``RuntimeError`` when the number of considered actions for the first group
    element differs from ``self._num_start_types``.
    """
    # (group_size, hidden_dim)
    stacked_hidden = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
    # (group_size, num_start_type)
    start_logits = self._start_type_predictor(stacked_hidden)
    log_probs = util.masked_log_softmax(start_logits, None)
    sorted_log_probs, sorted_indices = log_probs.sort(dim=-1, descending=True)
    sorted_indices = sorted_indices.detach().cpu().numpy().tolist()

    keep_debug_info = state.debug_info is not None
    if keep_debug_info:
        probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()

    # state.get_valid_actions() will return a list that is consistently sorted, so as long
    # as the set of valid start actions never changes, we can just match up the log prob
    # indices above with the position of each considered action, and we're good.
    considered_actions, _, _ = self._get_actions_to_consider(state)
    if len(considered_actions[0]) != self._num_start_types:
        raise RuntimeError("Calculated wrong number of initial actions.  Expected "
                           f"{self._num_start_types}, found {len(considered_actions[0])}.")

    # batch index -> [(group index, rank in the sorted log probs, action id)]
    candidates: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, (batch_index, ranked_indices) in enumerate(zip(state.batch_indices,
                                                                    sorted_indices)):
        for rank, log_prob_index in enumerate(ranked_indices):
            # `log_prob_index` is a position in `log_probs`, not an action ID; the ID
            # comes from `considered_actions`.
            action_id = considered_actions[group_index][log_prob_index]
            # The decoder trainer may restrict us to certain actions (likely the gold
            # ones).  Disallowed actions are skipped before any state is built, because
            # constructing the new state can be expensive.
            if allowed_actions is None or action_id in allowed_actions[group_index]:
                candidates[batch_index].append((group_index, rank, action_id))

    new_states = []
    for batch_index, batch_candidates in sorted(candidates.items()):
        for group_index, rank, action_id in batch_candidates:
            # Every emitted state has a `group_size` of 1, so the learning algorithm can
            # decide how many to keep and regroup them afterwards.
            batch_index = state.batch_indices[group_index]
            history = state.action_history[group_index] + [action_id]
            score = state.score[group_index] + sorted_log_probs[group_index, rank]
            production_rule = state.possible_actions[batch_index][action_id][0]
            grammar_state = state.grammar_state[group_index].take_action(production_rule)
            checklist_state = [state.checklist_state[group_index]]

            new_debug_info = None
            if keep_debug_info:
                step_debug_info = {
                        'considered_actions': considered_actions[group_index],
                        'probabilities': probs_cpu[group_index],
                        }
                new_debug_info = [state.debug_info[group_index] + [step_debug_info]]

            # This part is different from `_compute_new_states` - we're just passing
            # through the previous RNN state, as predicting the start type wasn't
            # included in the decoder RNN in the original model.
            rnn_state = state.rnn_state[group_index]

            new_states.append(
                    WikiTablesDecoderState(batch_indices=[batch_index],
                                           action_history=[history],
                                           score=[score],
                                           rnn_state=[rnn_state],
                                           grammar_state=[grammar_state],
                                           action_embeddings=state.action_embeddings,
                                           output_action_embeddings=state.output_action_embeddings,
                                           action_biases=state.action_biases,
                                           action_indices=state.action_indices,
                                           possible_actions=state.possible_actions,
                                           flattened_linking_scores=state.flattened_linking_scores,
                                           actions_to_entities=state.actions_to_entities,
                                           entity_types=state.entity_types,
                                           world=state.world,
                                           example_lisp_string=state.example_lisp_string,
                                           checklist_state=checklist_state,
                                           debug_info=new_debug_info))
    return new_states
def take_step(self,
              state: WikiTablesDecoderState,
              max_actions: int = None,
              allowed_actions: List[Set[int]] = None) -> List[WikiTablesDecoderState]:
    """
    Advances the decoder by one action and returns the resulting states.

    On the very first step (empty action history) the work is delegated to
    ``_take_first_step`` and ``max_actions`` is ignored.  Otherwise the decoder input is
    built from the previous attention over the question and the previous action
    embedding, the decoder cell is updated, the question is re-attended, and the
    concatenation of the new hidden state and the new attention (optionally adjusted by
    the checklist balance) is used to score every action that can be taken.  The total
    scores (previous score plus current log probability) are handed to
    ``_compute_new_states``.
    """
    if not state.action_history[0]:
        # The wikitables parser did something different when predicting the start type,
        # which is our first action, so that case lives in its own function.  We ignore
        # max_actions there, assuming there aren't that many start types.
        return self._take_first_step(state, allowed_actions)

    rnn_states = state.rnn_state
    attended_question = torch.stack([rnn_state.attended_input for rnn_state in rnn_states])
    hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in rnn_states])
    memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in rnn_states])
    previous_action_embedding = torch.stack([rnn_state.previous_action_embedding
                                             for rnn_state in rnn_states])

    # The scores from all prior state transitions until now.  Shape: (group_size, 1).
    scores_so_far = torch.stack(state.score).unsqueeze(-1)

    # (group_size, decoder_input_dim)
    decoder_input_features = torch.cat([attended_question, previous_action_embedding], -1)
    projected_input = self._input_projection_layer(decoder_input_features)
    decoder_input = torch.nn.functional.relu(projected_input)

    hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))
    hidden_state = self._dropout(hidden_state)

    # The encoder outputs are read from the first RNN state and indexed by batch index,
    # giving one row per group element.
    first_rnn_state = state.rnn_state[0]
    encoder_outputs = torch.stack([first_rnn_state.encoder_outputs[i]
                                   for i in state.batch_indices])
    encoder_output_mask = torch.stack([first_rnn_state.encoder_output_mask[i]
                                       for i in state.batch_indices])
    attended_question, attention_weights = self.attend_on_question(hidden_state,
                                                                   encoder_outputs,
                                                                   encoder_output_mask)
    action_query = torch.cat([hidden_state, attended_question], dim=-1)

    considered_actions, actions_to_embed, actions_to_link = self._get_actions_to_consider(state)

    # action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
    # output_action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
    # embedded_action_mask: (group_size, num_embedded_actions)
    embedding_info = self._get_action_embeddings(state, actions_to_embed)
    action_embeddings, output_action_embeddings, action_biases, embedded_action_mask = embedding_info

    # (group_size, action_embedding_dim)
    projected_query = torch.nn.functional.relu(self._output_projection_layer(action_query))
    predicted_action_embedding = self._dropout(projected_query)

    linked_balance = None
    if state.checklist_state[0] is not None:
        linked_balance, unlinked_balance = self._get_checklist_balance(state,
                                                                       self._unlinked_terminal_indices,
                                                                       actions_to_link)
        embedding_addition = self._get_predicted_embedding_addition(state,
                                                                    self._unlinked_terminal_indices,
                                                                    unlinked_balance)
        addition = embedding_addition * self._unlinked_checklist_multiplier
        predicted_action_embedding = predicted_action_embedding + addition

    # Batch dot product: `dot(predicted_action_embedding, action_embedding)` for each
    # `action_embedding`, done with `bmm` and some squeezing.
    # Shape: (group_size, num_embedded_actions)
    query_column = predicted_action_embedding.unsqueeze(-1)
    embedded_action_logits = action_embeddings.bmm(query_column).squeeze(-1)
    embedded_action_logits = embedded_action_logits + action_biases.squeeze(-1)

    if actions_to_link:
        # entity_action_logits: (group_size, num_entity_actions)
        # entity_action_mask: (group_size, num_entity_actions)
        entity_action_logits, entity_action_mask, entity_type_embeddings = \
                self._get_entity_action_logits(state,
                                               actions_to_link,
                                               attention_weights,
                                               linked_balance)

        # The `output_action_embeddings` tensor gets used later as the input to the next
        # decoder step.  For linked actions, we don't have any action embedding, so we
        # use the entity type instead.
        output_action_embeddings = torch.cat([output_action_embeddings, entity_type_embeddings],
                                             dim=1)

        if self._mixture_feedforward is None:
            # One softmax jointly over the embedded and the linked actions.
            joint_logits = torch.cat([embedded_action_logits, entity_action_logits], dim=1)
            joint_mask = torch.cat([embedded_action_mask, entity_action_mask], dim=1).float()
            current_log_probs = util.masked_log_softmax(joint_logits, joint_mask)
        else:
            # The entity and action logits are combined with a mixture weight to prevent
            # the entity_action_logits from dominating the embedded_action_logits if a
            # softmax was applied on both together.
            mixture_weight = self._mixture_feedforward(hidden_state)
            mix1 = torch.log(mixture_weight)
            mix2 = torch.log(1 - mixture_weight)

            entity_action_probs = util.masked_log_softmax(entity_action_logits,
                                                          entity_action_mask.float()) + mix1
            embedded_action_probs = util.masked_log_softmax(embedded_action_logits,
                                                            embedded_action_mask.float()) + mix2
            current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=1)
    else:
        # No linked actions: normalise the embedded actions on their own.
        current_log_probs = util.masked_log_softmax(embedded_action_logits,
                                                    embedded_action_mask.float())

    # current_log_probs is shape (group_size, num_actions).  We're broadcasting an
    # addition here with scores_so_far, which has shape (group_size, 1).  This is now the
    # total score for each state after taking each action.  `_compute_new_states` sorts
    # by this, so it must be the total score, not just the score for the current action.
    log_probs = scores_so_far + current_log_probs

    return self._compute_new_states(state,
                                    log_probs,
                                    hidden_state,
                                    memory_cell,
                                    output_action_embeddings,
                                    attended_question,
                                    attention_weights,
                                    considered_actions,
                                    allowed_actions,
                                    max_actions)
def forward(
    self,  # type: ignore
    text: TextFieldTensors,
    spans: torch.IntTensor,
    span_labels: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Scores candidate mention spans, prunes them, scores antecedents for the surviving
    spans and, when gold labels are given, computes the coreference loss.

    # Parameters

    text : `TextFieldTensors`, required.
        The output of a `TextField` representing the text of
        the document.
    spans : `torch.IntTensor`, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
        indices into the text of the document.
    span_labels : `torch.IntTensor`, optional (default = None).
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : `List[Dict[str, Any]]`, optional (default = None).
        A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
        from this dictionary, which respectively have the original text and the annotated gold coreference
        clusters for that instance.

    # Returns

    An output dictionary consisting of:
    top_spans : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, max_antecedents)` representing for
        each top span the index (with respect to top_spans) of the possible antecedents the
        model considered.
    predicted_antecedents : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Only present when `span_labels` is given.
    document : `List`, optional
        The "original_text" entry of every metadata dictionary. Only present when
        `metadata` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = spans.size(0)
    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text)

    # Shape: (batch_size, num_spans)
    # Indexing with `0` already removes the last axis.  A further `squeeze(-1)` would also
    # remove the span axis whenever there is exactly one candidate span, so none is applied.
    span_mask = spans[:, :, 0] >= 0
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    # Shape: (batch_size, num_spans)
    span_mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings)
    ).squeeze(-1)
    # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
    top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, num_spans_to_keep
    )

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices
    )

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.
    if self._coarse_to_fine:
        pruned_antecedents = self._coarse_to_fine_pruning(
            top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
        )
    else:
        pruned_antecedents = self._distance_pruning(
            top_span_embeddings, top_span_mention_scores, max_antecedents
        )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
    (
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
        top_antecedent_indices,
    ) = pruned_antecedents

    flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
        top_antecedent_indices, num_spans_to_keep
    )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    top_antecedent_embeddings = util.batched_index_select(
        top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
    )
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        top_span_embeddings,
        top_antecedent_embeddings,
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
    )

    # Higher-order inference: each extra round updates the span embeddings with an
    # attention-weighted sum over their antecedents and then rescores.
    for _ in range(self._inference_order - 1):
        dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
        top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        attention_weight = util.masked_softmax(
            coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
        top_antecedent_with_dummy_embeddings = torch.cat(
            [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        attended_embeddings = util.weighted_sum(
            top_antecedent_with_dummy_embeddings, attention_weight
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = self._span_updating_gated_sum(
            top_span_embeddings, attended_embeddings
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": top_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep, 1)
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        antecedent_labels = util.batched_index_select(
            pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
        ).squeeze(-1)
        antecedent_labels = util.replace_masked_values(
            antecedent_labels, top_antecedent_mask, -100
        )

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels
        )
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask.unsqueeze(-1)
        )
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(
            top_spans, top_antecedent_indices, predicted_antecedents, metadata
        )

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Run the dialog QA model over a batch of dialogs and, when gold labels are given,
    compute the training loss and update the metrics.

    Every dialog holds up to ``max_qa_count`` questions about one passage.  Internally the
    ``(batch_size, max_qa_count)`` axes are flattened into a single axis of size
    ``total_qa_count = batch_size * max_qa_count`` and the passage is repeated once per
    question.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.  ``question['token_characters']`` must be a 4-D tensor whose
        first three dimensions are ``(batch_size, max_qa_count, max_q_len)``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer within the passage.  This is an `inclusive` token
        index.  If this is given, we will compute a loss that gets included in the output
        dictionary; ``span_end``, ``yesno_list`` and ``followup_list`` are then required too.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer within the passage.  This is an `inclusive` token
        index.
    p1_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 0.
        This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
        Most passage tokens are assigned 'O', except the passage tokens that belong to the
        previous answer in the dialog, which are assigned labels such as
        <1_start>, <1_in>, <1_end>.
        For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
    p2_answer_marker :  ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 1.
        It is similar to p1_answer_marker, but marks the answer two turns back.
    p3_answer_marker :  ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 2.
        It is similar to p1_answer_marker, but marks the answer three turns back.
    yesno_list :  ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (the yes/no/not a yes no question).
    followup_list :  ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (followup / maybe followup / don't followup).
        Entries below zero are excluded from the accuracy metrics.
    metadata : ``List[Dict[str, Any]]``, optional
        One dictionary per dialog in the batch.  When present, each dictionary must provide
        ``original_passage``, ``token_offsets``, ``instance_id`` and ``answer_texts_list``;
        they are used to recover the predicted answer strings and to compute F1.  When
        omitted, the per-dialog output lists are returned empty and F1 is not updated.

    Returns
    -------
    An output dictionary consisting of the following entries.
    Each of them is a nested list because it first iterates over dialogs, then over the
    questions in a dialog.

    qid : List[List[str]]
        A list of lists, consisting of question ids.
    followup : List[List[int]]
        A list of lists, consisting of continuation marker prediction indices.
        (y :yes, m: maybe follow up, n: don't follow up)
    yesno : List[List[int]]
        A list of lists, consisting of affirmation marker prediction indices.
        (y :yes, x: not a yes/no question, n: no)
    best_span_str : List[List[str]]
        If sufficient metadata was provided for the instances in the batch, we also return
        the string from the original passage that the model thinks is the best answer to
        the question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Only present when ``span_start`` is given.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    total_qa_count = batch_size * max_qa_count

    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    # One copy of the passage mask per question of the dialog.
    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        # Shape: (batch_size * max_qa_count, passage_length, word_embed_dim)
        repeated_embedded_passage = (
            embedded_passage.unsqueeze(1)
            .repeat(1, max_qa_count, 1, 1)
            .view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        )
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)

        repeated_encoded_passage = self._variational_dropout(
            self._phrase_layer(repeated_embedded_passage, repeated_passage_mask))
    else:
        # No answer markers: encode the passage once, then repeat the encoding per question.
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)

    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
    # Shape: (batch_size * max_qa_count, encoding_dim)
    question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                passage_length,
                                                                                self._encoding_dim)

    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([repeated_encoded_passage,
                                      passage_question_vectors,
                                      repeated_encoded_passage * passage_question_vectors,
                                      repeated_encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                      repeated_passage_mask))
    self_attention_matrix = self._self_attention(residual_layer, residual_layer)

    # Pairwise passage mask with the diagonal zeroed out, so that a token never attends to
    # itself in the self-attention below.
    mask = (repeated_passage_mask.reshape(total_qa_count, passage_length, 1)
            * repeated_passage_mask.reshape(total_qa_count, 1, passage_length))
    self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
    self_mask = self_mask.reshape(1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

    # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
    self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
    self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                     residual_layer * self_attention_vecs],
                                    dim=-1)
    residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

    final_merged_passage = final_merged_passage + residual_layer
    # batch_size * maxqa_pair_len * max_passage_len * 200
    final_merged_passage = self._variational_dropout(final_merged_passage)
    start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

    # The end encoder also sees the start representation.
    end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                     repeated_passage_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

    span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
    span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    # batch_size * maxqa_len_pair, max_document_len
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    # Columns 0 and 1 of ``best_span`` are used below as start / end token indices and
    # columns 2 and 3 as the yes/no and follow-up predictions.
    best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                   span_yesno_logits, span_followup_logits,
                                                   self._max_span_length)

    output_dict: Dict[str, Any] = {}

    # Compute the loss.
    if span_start is not None:
        # Only the metric updates need this mask, so it is built here rather than up front:
        # ``followup_list`` may be None when the model is called without gold labels.
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)

        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                        span_start.view(-1), ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                         span_end.view(-1), ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())

        # add a select for the right span to compute loss
        # The index arithmetic below addresses the yes/no and follow-up logits as a flat
        # (total_qa_count, passage_length, 3) tensor: three consecutive entries per token.
        # ``view(total_qa_count)`` is already 1-D; an extra ``squeeze()`` would turn a
        # single-question batch into a 0-d array that cannot be indexed.
        gold_end = span_end.view(total_qa_count).detach().cpu().numpy()
        gold_span_end_loc = []
        for i in range(total_qa_count):
            gold_base = gold_end[i] * 3 + i * passage_length * 3
            for class_offset in range(3):
                # Negative positions (e.g. from ignored end indices) are clamped to 0.
                gold_span_end_loc.append(max(gold_base + class_offset, 0))
        gold_span_end_loc = span_start.new(gold_span_end_loc)

        pred_span_end_loc = []
        for i in range(total_qa_count):
            pred_base = best_span[i][1] * 3 + i * passage_length * 3
            for class_offset in range(3):
                pred_span_end_loc.append(max(pred_base + class_offset, 0))
        predicted_end = span_start.new(pred_span_end_loc)

        # The loss reads the classification logits at the gold end token ...
        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        # ... while the accuracy metrics read them at the predicted end token.
        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and preparing the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []

    # ``metadata`` is optional; without it there is no passage text to slice, so the lists
    # above are returned empty.
    if metadata is not None:
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            dialog_metadata = metadata[i]
            passage_str = dialog_metadata['original_passage']
            offsets = dialog_metadata['token_offsets']
            # NOTE(review): initialised once per dialog, so a question without reference
            # answers reports the F1 of the preceding question - confirm this is intended.
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(dialog_metadata["instance_id"], dialog_metadata["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                # Map the predicted token span back to character offsets in the passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                per_dialog_yesno_list.append(predicted_span[2])
                per_dialog_followup_list.append(predicted_span[3])
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and averages the scores.
                        for answer_index in range(len(answer_texts)):
                            refs = [text for ref_index, text in enumerate(answer_texts)
                                    if ref_index != answer_index]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
    return output_dict