def forward(self, input_sequences: torch.IntTensor, sequence_lengths: torch.IntTensor) -> torch.FloatTensor:
    """
    Feed the network.

    Embeds the padded input sequences, applies embedding dropout, packs them by
    length and runs them through ``self.gru``. Only the final hidden state is
    returned; the per-step outputs are discarded.

    :param input_sequences: The input sequences.
    :type input_sequences: torch.IntTensor[seq_len, batch_size].
    :param sequence_lengths: The lengths of each input sequence (any device).
    :type sequence_lengths: torch.IntTensor[batch_size,].
    :raise ValueError: if the batch_size of input_sequences and sequence_lengths doesn't match.
    :return: torch.FloatTensor[n_layers * num_directions, batch_size, hidden_size]
             (n_layers * 2 when ``self.gru`` is bidirectional).
    """
    if input_sequences.size(-1) != sequence_lengths.size(-1):
        raise ValueError(
            "The batch_size of input_sequences and sequence_lengths doesn't match."
        )

    embedded = self.embeddings(input_sequences)  # [seq_len, batch_size, embedding_size]
    embedded = self.embedding_dropout(embedded)  # [seq_len, batch_size, embedding_size]

    # pack_padded_sequence only accepts lengths that live on the CPU (as int64);
    # converting here lets callers keep sequence_lengths next to the inputs (e.g. on GPU).
    lengths = sequence_lengths.to(device="cpu", dtype=torch.int64)
    packed = pack_padded_sequence(embedded, lengths, enforce_sorted=False)

    # With enforce_sorted=False the hidden state comes back in the original batch order.
    _, h_state = self.gru(packed)  # [n_layers * num_directions, batch_size, hidden_size]
    return h_state
def __call__(self, inputs: torch.IntTensor) -> torch.IntTensor:
    """
    Reduce a label sequence to a fixed-size token sequence.

    Runs of equal consecutive labels are merged into a single label, label ``0``
    is dropped (presumably a CTC-style blank / padding id — confirm), and the result
    is truncated or right-padded with ``0`` to ``self.target_size`` entries.
    An empty ``inputs`` yields an all-zero result.

    :param inputs: 1-D tensor of labels.
    :return: torch.IntTensor[target_size] (int32, on the CPU).
    """
    # One label per run of equal consecutive labels, then discard the 0 runs.
    runs = torch.unique_consecutive(inputs.cpu())
    tokens = runs[runs != 0][: self.target_size]

    reduced = torch.zeros(self.target_size, dtype=torch.int32)
    reduced[: tokens.numel()] = tokens
    return reduced
def mask_loc_logits(self, loc_logits, num_cands: torch.IntTensor):
    """
    Mask the padded candidates with an -inf score, so they will have a likelihood = 0 after softmax.

    Args:
        loc_logits - output scores for each candidate in each sentence, size (batch, max_sents, max_cands)
        num_cands - total number of candidates in each instance of the given batch, size (batch,)
    Returns:
        A new tensor shaped like ``loc_logits`` in which, for instance ``b``, every
        candidate position ``>= num_cands[b]`` (in every sentence) is set to ``-inf``.
        ``loc_logits`` itself is not modified.
    """
    assert torch.max(num_cands) == loc_logits.size(-1)
    assert loc_logits.size(0) == num_cands.size(0)
    batch_size = loc_logits.size(0)
    max_cands = loc_logits.size(-1)

    # Candidate positions are 0-based, so position j of instance b is padding iff j >= num_cands[b].
    # The range is created on loc_logits' own device (rather than via self.use_cuda) so the mask
    # always lives next to the logits it is applied to.
    positions = torch.arange(max_cands, device=loc_logits.device)
    pad_mask = positions.unsqueeze(0) >= num_cands.to(loc_logits.device).unsqueeze(-1)
    assert pad_mask.size() == (batch_size, max_cands)

    # (batch, 1, max_cands) broadcasts over the sentence dimension of loc_logits.
    masked_loc_logits = loc_logits.masked_fill(pad_mask.unsqueeze(-2), float('-inf'))
    assert masked_loc_logits.size() == loc_logits.size()
    return masked_loc_logits
def get_pred_loc(loc_logits: torch.Tensor, gold_loc_seq: torch.IntTensor) -> List[List[int]]:
    """
    Turn raw location logits into per-instance predicted location sequences.

    The predicted location at each step is the arg-max candidate. Steps whose gold
    location equals ``PAD_LOC`` are treated as padding and left out of the result.
    ``loc_logits`` is expected to be masked already; ``gold_loc_seq`` must not be.

    Args:
        loc_logits - masked logits (padded candidates set to -inf), size (batch, max_sents, max_cands)
        gold_loc_seq - unmasked gold location sequence, size (batch, max_sents)
    Returns:
        One list of predicted candidate indices per instance.
    """
    assert gold_loc_seq.size() == (loc_logits.size(0), loc_logits.size(1))
    best_cands = loc_logits.argmax(dim=-1)
    assert best_cands.size() == gold_loc_seq.size()

    # Overwrite padded steps with PAD_LOC so the filter below removes them.
    is_padding = gold_loc_seq == PAD_LOC
    best_cands = best_cands.masked_fill(mask=is_padding, value=PAD_LOC)

    return [[loc for loc in row if loc != PAD_LOC] for row in best_cands.tolist()]
def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
):
    """
    Decompress char strings to tensors.

    Args:
        strings (list or tuple): compressed tensors, one entry per element of
            the first dimension of ``indexes``
        indexes (torch.IntTensor): tensors CDF indexes (at least 2-D)
        dtype (torch.dtype): type of dequantized output
        means (torch.Tensor, optional): optional tensor means; must match
            ``indexes`` on the first two dimensions and be 1 on any other
            dimension where the sizes differ

    Returns:
        The output of ``self.dequantize`` applied to the decoded symbols
        (shaped like ``indexes``), ``means`` and ``dtype``.

    Raises:
        ValueError: if ``strings`` is not a list/tuple, its length does not match
            ``indexes.size(0)``, ``indexes`` has fewer than 2 dimensions, or
            ``means`` has an incompatible size.
    """
    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")

    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    # The CDF tables do not depend on the string being decoded: convert them to
    # Python lists once instead of re-materialising them for every string.
    cdf_list = cdf.tolist()
    cdf_lengths = self._cdf_length.reshape(-1).int().tolist()
    offsets = self._offset.reshape(-1).int().tolist()

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf_list,
            cdf_lengths,
            offsets,
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())
    outputs = self.dequantize(outputs, means, dtype)
    return outputs
def compute_loc_accuracy(logits: torch.FloatTensor, gold: torch.IntTensor, pad_value: int) -> (int, int):
    """
    Given the generated location logits and the gold location sequence, compute the location prediction accuracy.

    Args:
        logits - size (batch, max_sents, max_cands)
        gold - size (batch, max_sents)
        pad_value - elements with this value will not count in accuracy
    Returns:
        (correct, total): the number of non-padded positions whose arg-max prediction
        equals the gold location, and the number of non-padded positions.
    """
    pred = torch.argmax(logits, dim=-1)
    assert pred.size() == gold.size()

    valid = gold != pad_value  # positions that count towards accuracy
    total_pred = torch.sum(valid)
    # Only count matches on valid positions: if pad_value is itself a reachable
    # argmax index (e.g. 0), padded positions would otherwise be counted as correct.
    correct_pred = torch.sum((pred == gold) & valid)

    return correct_pred.item(), total_pred.item()
def nd_batched_index_select(target: torch.Tensor,
                            indices: torch.IntTensor) -> torch.Tensor:
    """
    Multidimensional version of `util.batched_index_select`.

    Every dimension of ``target`` except its last two is treated as a batch axis.
    Those axes (and the same number of leading axes of ``indices``) are flattened
    into a single batch dimension, ``util.batched_index_select`` is applied, and the
    result is reshaped to ``(*indices.size(), -1)``.

    Assumes ``target`` is (*batch_axes, sequence_length, embedding_dim) and that
    ``indices`` starts with the same batch axes — TODO confirm against callers.
    Non-contiguous inputs (e.g. transposed tensors) are accepted.
    """
    batch_axes = target.size()[:-2]
    num_batch_axes = len(batch_axes)
    target_shape = target.size()
    indices_shape = indices.size()

    # reshape (unlike view) copies when needed, so non-contiguous inputs do not raise.
    target_reshaped = target.reshape(-1, *target_shape[num_batch_axes:])
    indices_reshaped = indices.reshape(-1, *indices_shape[num_batch_axes:])

    output_reshaped = util.batched_index_select(target_reshaped, indices_reshaped)

    return output_reshaped.reshape(*indices_shape, -1)
def memorize(self, states: torch.Tensor, actions: torch.IntTensor, next_states: torch.Tensor,
             rewards: torch.Tensor):
    """
    Memorizes a batch of exploration transitions (quadruples s, a, ns, r).

    The tensors are moved to ``self.device`` and appended (along dim 0) to the
    replay memory; the first call initialises the memory instead.

    :param states: Successive states encountered. Should have shape (number_of_states, state_dim + 1)
                   where the last column values are either 1 if the correspond state is final or 0 otherwise.
    :param actions: Successive actions decided by the agent. Should be a tensor of shape (number_of_states)
                    and is stored as int64.
    :param next_states: (number_of_states, state_dim) shaped tensor indicating the next states.
    :param rewards: (number_of_states, )-sized 1D tensor containing the rewards for the episode.
    :raise ValueError: if the total number of dimensions of states, actions and next_states is not 5
                       (2 + 1 + 2).
    """
    if len(states.size()) + len(actions.size()) + len(next_states.size()) != 5:
        raise ValueError("Wrong dimensions")

    # Tensor.to() is not in-place: it returns the moved tensor, which must be kept
    # for the memory to actually live on self.device.
    states = states.to(self.device)
    next_states = next_states.to(self.device)
    actions = actions.to(self.device).type(torch.int64)
    rewards = rewards.to(self.device)

    if self.need_init:
        self.state_mem = states
        self.action_mem = actions
        self.nstate_mem = next_states
        self.reward_mem = rewards
        self.need_init = False
    else:
        self.state_mem = torch.cat((self.state_mem, states), dim=0)
        self.action_mem = torch.cat((self.action_mem, actions))
        self.nstate_mem = torch.cat((self.nstate_mem, next_states), dim=0)
        self.reward_mem = torch.cat((self.reward_mem, rewards))
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        sentence_spans: torch.IntTensor = None,
        sent_labels: torch.IntTensor = None,
        evd_chain_labels: torch.IntTensor = None,
        q_type: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict an answer span / answer type together with a per-sentence evidence gate.

    Two BiDAF-style encoders run over the question and the passage. The ``*_sp``
    branch feeds ``self._span_gate``, which scores each sentence span. The gated
    sentence representations are added to the answer branch's passage
    representation before span-start, span-end and answer-type prediction.

    Parameters
    ----------
    question, passage : ``Dict[str, torch.LongTensor]``
        Text-field tensors, embedded with ``self._text_field_embedder``.
    span_start, span_end : ``torch.IntTensor``, optional
        Gold answer boundaries. The loss is computed only when ``span_start`` is given.
    sentence_spans : ``torch.IntTensor``
        Sentence spans passed to ``convert_sequence_to_spans``.
    sent_labels : ``torch.IntTensor``
        Per-sentence gate labels. Negative entries are treated as padding. Required:
        the gate supervision loss is always computed from it.
    evd_chain_labels : ``torch.IntTensor``, optional
        Evidence chains. Only used when ``self._sent_labels_src == 'chain'``; in that
        case the first chain replaces ``sent_labels`` as gate supervision.
    q_type : ``torch.IntTensor``, optional
        Gold answer type, used for the type loss.
    metadata : ``List[Dict[str, Any]]``, optional
        Per-instance metadata. When present, answer strings are decoded, the SQuAD and
        F1 metrics are updated, and tokens / ids are copied into the output.

    Returns
    -------
    A dict containing:

    - ``span_start_logits``, ``span_end_logits``, ``best_span``, ``pred_sent_labels``
      and ``gate_probs``;
    - the attention scores when ``self._output_att_scores`` is set;
    - ``loss`` when ``span_start`` is given and the loss computation succeeds;
    - the metadata-derived entries.
    """
    if self._sent_labels_src == 'chain':
        batch_size, num_spans = sent_labels.size()
        sent_labels_mask = (sent_labels >= 0).float()
        # we use the chain as the label to supervise the gate
        # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
        # right now the number of chains should only be one too.
        evd_chain_labels = evd_chain_labels[:, 0].long()
        # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans)
        sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
        sent_labels.scatter_(1, evd_chain_labels, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans)
        sent_labels = sent_labels[:, 1:].float()
        # make the padding be -1
        sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

    # word + char embedding
    embedded_question = self._text_field_embedder(question)
    embedded_passage = self._text_field_embedder(passage)
    # mask
    ques_mask = util.get_text_field_mask(question).float()
    context_mask = util.get_text_field_mask(passage).float()

    # BiDAF for answer prediction
    ques_output = self._dropout(self._phrase_layer(embedded_question, ques_mask))
    context_output = self._dropout(self._phrase_layer(embedded_passage, context_mask))

    modeled_passage, _, qc_score = self.qc_att(context_output, ques_output, ques_mask)

    modeled_passage = self._modeling_layer(modeled_passage, context_mask)

    # BiDAF for gate prediction
    ques_output_sp = self._dropout(self._phrase_layer_sp(embedded_question, ques_mask))
    context_output_sp = self._dropout(self._phrase_layer_sp(embedded_passage, context_mask))

    modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(context_output_sp, ques_output_sp, ques_mask)

    modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp, context_mask)

    # gate prediction
    # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
    # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
    spans_rep_sp, spans_mask = convert_sequence_to_spans(modeled_passage_sp, sentence_spans)
    spans_rep, _ = convert_sequence_to_spans(modeled_passage, sentence_spans)
    # Shape(gate_logit): (batch_size * num_spans, 2)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(pred_sent_probs): (batch_size * num_spans, 2)
    # Shape(gate_mask): (batch_size, num_spans)
    gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
        spans_rep_sp, spans_mask,
        self._gate_self_attention_layer,
        self._gate_sent_encoder)
    batch_size, num_spans, max_batch_span_width = spans_mask.size()

    # Gate supervision; sentences labelled -1 (padding) are ignored.
    strong_sup_loss = F.nll_loss(
        F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
        sent_labels.long().view(batch_size * num_spans),
        ignore_index=-1)

    # Hard gate: keep only sentences whose gate score reaches 0.3.
    gate = (gate >= 0.3).long()
    spans_rep = spans_rep * gate.unsqueeze(-1).float()
    attended_sent_embeddings = convert_span_to_sequence(modeled_passage_sp, spans_rep, spans_mask)

    modeled_passage = attended_sent_embeddings + modeled_passage

    self_att_passage = self._self_attention_layer(modeled_passage, mask=context_mask)
    modeled_passage = modeled_passage + self_att_passage[0]
    self_att_score = self_att_passage[2]

    # Padded passage positions get a -1e30 offset so they are never chosen as span boundaries.
    output_start = self._span_start_encoder(modeled_passage, context_mask)
    span_start_logits = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)
    output_end = torch.cat([modeled_passage, output_start], dim=2)
    output_end = self._span_end_encoder(output_end, context_mask)
    span_end_logits = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)

    output_type = torch.cat([modeled_passage, output_end, output_start], dim=2)
    output_type = torch.max(output_type, 1)[0]
    # output_type = torch.max(self.rnn_type(output_type, context_mask), 1)[0]
    predict_type = self.linear_type(output_type)
    type_predicts = torch.argmax(predict_type, 1)

    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_end_logits": span_end_logits,
        "best_span": best_span,
        "pred_sent_labels": gate.view(batch_size, num_spans),  # [B, num_span]
        "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
    }
    if self._output_att_scores:
        if qc_score is not None:
            output_dict['qc_score'] = qc_score
        if qc_score_sp is not None:
            output_dict['qc_score_sp'] = qc_score_sp
        if self_att_score is not None:
            output_dict['self_attention_score'] = self_att_score
        if g_att_score is not None:
            output_dict['evd_self_attention_score'] = g_att_score

    # Compute the loss for training.
    if span_start is not None:
        try:
            start_loss = nll_loss(util.masked_log_softmax(span_start_logits, None), span_start.squeeze(-1))
            end_loss = nll_loss(util.masked_log_softmax(span_end_logits, None), span_end.squeeze(-1))
            type_loss = nll_loss(util.masked_log_softmax(predict_type, None), q_type)
            loss = start_loss + end_loss + type_loss + strong_sup_loss
            self._loss_trackers['loss'](loss)
            self._loss_trackers['start_loss'](start_loss)
            self._loss_trackers['end_loss'](end_loss)
            self._loss_trackers['type_loss'](type_loss)
            self._loss_trackers['strong_sup_loss'](strong_sup_loss)
            output_dict["loss"] = loss
        except RuntimeError:
            # NOTE(review): the error is only reported, not re-raised, so "loss"
            # is absent from output_dict for this batch.
            print('\n meta_data:', metadata)
            print(span_start_logits.shape)

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        token_spans_sp = []
        token_spans_sent = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        ids = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            token_spans_sp.append(metadata[i]['token_spans_sp'])
            token_spans_sent.append(metadata[i]['token_spans_sent'])
            sent_labels_list.append(metadata[i]['sent_labels'])
            ids.append(metadata[i]['_id'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            # Predicted type 1 / 2 answer "yes" / "no"; otherwise the answer is the best span.
            if type_predicts[i] == 1:
                best_span_string = 'yes'
            elif type_predicts[i] == 2:
                best_span_string = 'no'
            else:
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]

            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            output_dict['answer_texts'].append(answer_texts)

            if answer_texts:
                self._squad_metrics(best_span_string.lower(), answer_texts)

            # shift sentence indices back
            evd_possible_chains.append([s_idx - 1 for s_idx in metadata[i]['evd_possible_chains'][0] if s_idx > 0])
            ans_sent_idxs.append([s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
        self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['token_spans_sp'] = token_spans_sp
        output_dict['token_spans_sent'] = token_spans_sent
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['_id'] = ids

    return output_dict
def compute_simple_span_representations(
        max_span_width: int,
        encoded_text: torch.FloatTensor,
        span_starts: torch.IntTensor,
        span_ends: torch.IntTensor,
        span_width_embedding: Embedding,
        head_scorer: TimeDistributed) -> torch.FloatTensor:
    """
    Computes an embedded representation of every candidate span. This is a concatenation
    of the contextualized endpoints of the span, an embedded representation of the width of
    the span and a representation of the span's predicted head.

    Parameters
    ----------
    max_span_width : ``int``, required.
        Number of candidate spans per token; the function asserts
        ``num_spans == sequence_length * max_span_width``.
    encoded_text : ``torch.FloatTensor``, required.
        The deeply embedded sentence of shape (batch_size, sequence_length, embedding_dim)
        over which we are computing a weighted sum.
    span_starts : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the start of each span candidate.
    span_ends : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the end of each span candidate.
    span_width_embedding : ``Embedding``, required.
        Embeds ``span_ends - span_starts`` (the span width minus one).
    head_scorer : ``TimeDistributed``, required.
        Applied to ``encoded_text`` to produce the head scores used for the attended
        span representation.

    Returns
    -------
    span_embeddings : ``torch.FloatTensor``
        An embedded representation of every candidate span with shape
        (batch_size, sequence_length, max_span_width, -1), where the last dimension is
        the concatenation of start, end, width and attended-head embeddings.
    """
    # TODO(Swabha): necessary to have this? is it going to mess with attention computation?
    # contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    _, sequence_length, _ = encoded_text.size()
    contextualized_embeddings = encoded_text

    # span_starts / span_ends are 2-D (the unpack below enforces it), so they are used
    # as-is: squeezing the last axis would drop the span axis when num_spans == 1.
    batch_size, num_spans = span_starts.size()
    assert num_spans == sequence_length * max_span_width

    # Shape: (batch_size, num_spans, encoding_dim)
    start_embeddings = util.batched_index_select(contextualized_embeddings, span_starts)
    end_embeddings = util.batched_index_select(contextualized_embeddings, span_ends)

    # Compute and embed the span_widths (strictly speaking the span_widths - 1)
    # Shape: (batch_size, num_spans)
    span_widths = span_ends - span_starts
    # Shape: (batch_size, num_spans, width_embedding_dim)
    span_width_embeddings = span_width_embedding(span_widths)

    # Shape: (batch_size, sequence_length, 1)
    head_scores = head_scorer(contextualized_embeddings)

    # Shape: (batch_size, num_spans, embedding_dim)
    # Note that we used the original text embeddings, not the contextual ones here.
    attended_text_embeddings = create_attended_span_representations(
        max_span_width, head_scores, encoded_text, span_ends, span_widths)

    # Shape: (batch_size, num_spans, 2 * encoding_dim + width_embedding_dim + embedding_dim)
    span_embeddings = torch.cat([
        start_embeddings, end_embeddings, span_width_embeddings,
        attended_text_embeddings
    ], -1)
    span_embeddings = span_embeddings.view(batch_size, sequence_length,
                                           max_span_width, -1)
    return span_embeddings
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Per-document metadata. It is passed to the mention-recall and CoNLL scorers when
        ``span_labels`` is given, and its ``"original_text"`` entries are copied to
        ``output_dict["document"]``.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # spans[:, :, 0] is already 2-D; squeezing the last axis here would wrongly drop the
    # span axis when num_spans == 1.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0 (the relu below clamps the -1 padding to 0).
    # This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep,
                                         max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:  # add matrix from datareader
    # pylint: disable=arguments-differ
    """
    Jointly score coreference, sentence-level relations and span NER for a batch.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans carry -1.
    metadata : ``List[Dict[str, Any]]``, required.
        One dict per document. This method reads ``document['doc_tokens']`` (a list of
        sentences, whose lengths bound the relation pruner) and ``document['original_text']``;
        the list is also forwarded to the coreference metrics.
    doc_span_offsets : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1) of indices
        into ``spans``; negative entries mark padding.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids of each span,
        or -1 for those which do not appear in any clusters. Enables the coreference loss.
    doc_truth_spans : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_truth_spans, 1). Not read here.
    doc_spans_in_truth : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1); presumably
        maps each sentence span to its row/column in ``doc_relation_labels`` — TODO confirm.
    doc_relation_labels : ``torch.Tensor``, optional (default = None)
        A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans).
    truth_spans : ``List[Set[Tuple[int, int]]]``, optional (default = None)
        Gold spans per document, passed to the relation-mention recall metric.
    doc_relations : optional (default = None)
        Gold relations passed to the relation metrics. Enables the relation loss.
    doc_ner_labels : ``torch.IntTensor``, optional (default = None)
        NER targets for each sentence span, scored against ``doc_span_mask``; presumably of
        shape (batch_size, max_sentences, max_spans_per_sentence) — TODO confirm.
        Enables the NER loss.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top
        span the index (with respect to top_spans) of the possible antecedents the model
        considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span,
        the index (with respect to antecedent_indices) of the most likely antecedent. -1 means
        there was no predicted link.
    document : ``list``
        ``original_text`` of each metadata entry.
    relex_scores : ``torch.FloatTensor``
        Shape ``(batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep,
        num_relation_labels)``.
    top_relex_spans : ``torch.LongTensor``
        Shape ``(batch_size, max_sentences, num_relex_spans_to_keep, 2)``.
    coref_loss / relex_loss / ner_scores + ner_loss : optional
        Present when ``span_labels`` / ``doc_relations`` / ``doc_ner_labels`` are given.
    loss : ``torch.FloatTensor``, optional
        Weighted sum of the computed losses; absent when none was computed.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = len(spans)
    document_length = text_embeddings.size(1)
    max_sentence_length = max(
        len(sentence)
        for document in metadata
        for sentence in document['doc_tokens'])
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # No ``.squeeze(-1)``: with a single candidate span it would drop the span
    # dimension and hand the pruner a mask of the wrong rank.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do some comparisons
    # based on span widths when we attend over the span representations that we
    # generate from these indices, we need them to be >= 0. This only matters when
    # the number of spans kept after pruning is >= the total number of spans, since
    # then a masked span might be considered.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # TODO features dropout
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_relex_spans_to_keep = int(
        math.floor(self._relex_spans_per_word * max_sentence_length))

    # Shapes:
    # (batch_size, num_spans_to_keep, span_dim),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, so shift the per-batch indices by
    # their batch offset once and reuse them in the batched_index_select calls below.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Each span may only take one of its (up to) max_antecedents preceding top spans as
    # antecedent. The index matrix depends only on positions within top_spans (which
    # are in document order), not on the batch.
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    (valid_antecedent_indices, valid_antecedent_offsets,
     valid_antecedent_log_mask) = self._generate_valid_antecedents(
         num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Index 0 is the "no antecedent" class; shift so -1 means "no link" and the
    # remaining values index into antecedent_indices.
    predicted_antecedents -= 1

    output_dict = dict()
    output_dict["top_spans"] = top_spans
    output_dict["antecedent_indices"] = valid_antecedent_indices
    output_dict["predicted_antecedents"] = predicted_antecedents
    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]

    # Int sentinel: stays 0 (an int) until one of the losses below is added.
    loss = 0

    # Shape: (batch_size, max_sentences, max_spans)
    doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
    # Shape: (batch_size, max_sentences, max_spans, span_dim)
    doc_span_embeddings = util.batched_index_select(
        span_embeddings, doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    pruned = self._relex_mention_pruner(
        doc_span_embeddings,
        doc_span_mask,
        num_items_to_keep=num_relex_spans_to_keep,
        pass_through=['num_items_to_keep'])
    (top_relex_span_embeddings, top_relex_span_mask, top_relex_span_indices,
     top_relex_span_mention_scores) = pruned

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

    # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)
    # TODO do we need for a mask?
    doc_spans = util.batched_index_select(
        spans, doc_span_offsets.clamp(0).squeeze(-1))
    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
    top_relex_spans = nd_batched_index_select(doc_spans, top_relex_span_indices)

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
    (relex_span_pair_embeddings,
     relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
         top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
    relex_scores = self._compute_relex_scores(relex_span_pair_embeddings,
                                              top_relex_span_mention_scores)
    output_dict['relex_scores'] = relex_scores
    output_dict['top_relex_spans'] = top_relex_spans

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels_ = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Negative marginal log-likelihood: for each span, the log of the summed
        # probability assigned to all antecedents in the same gold cluster. Any of
        # them is an acceptable prediction, since we only care that the span ends up
        # in the right cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores,
                                                        top_span_mask)
        correct_antecedent_log_probs = (coreference_log_probs +
                                        gold_antecedent_labels.log())
        # Masked top spans (padding) do not contribute to the loss.
        negative_marginal_log_likelihood = (
            -util.logsumexp(correct_antecedent_log_probs) *
            top_span_mask.squeeze(-1).float()).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        coref_loss = negative_marginal_log_likelihood
        output_dict['coref_loss'] = coref_loss
        loss += self._loss_coref_weight * coref_loss

    if doc_relations is not None:
        # The gold relation matrix is row/column sparse, so it is stored compressed
        # over truth spans only. Map the pruned spans onto that compressed matrix.
        # TODO Add teacher forcing support.

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
        relative_indices = top_relex_span_indices
        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
        compressed_indices = nd_batched_padded_index_select(
            doc_spans_in_truth, relative_indices)

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
        gold_pruned_rows = nd_batched_padded_index_select(
            doc_relation_labels,
            compressed_indices.squeeze(-1),
            padding_value=0)
        gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous()

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
        gold_pruned_matrices = nd_batched_padded_index_select(
            gold_pruned_rows,
            compressed_indices.squeeze(-1),
            padding_value=0)  # pad with epsilon
        gold_pruned_matrices = gold_pruned_matrices.permute(0, 1, 3,
                                                            2).contiguous()

        # TODO log_mask relex score before passing
        relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                  gold_pruned_matrices,
                                                  relex_span_pair_mask)
        output_dict['relex_loss'] = relex_loss

        self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                   truth_spans)
        self._compute_relex_metrics(output_dict, doc_relations)

        loss += self._loss_relex_weight * relex_loss

    if doc_ner_labels is not None:
        # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
        ner_scores = self._ner_scorer(doc_span_embeddings)
        output_dict['ner_scores'] = ner_scores

        ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                doc_span_mask)
        output_dict['ner_loss'] = ner_loss
        loss += self._loss_ner_weight * ner_loss

    if not isinstance(loss, int):  # At least one loss term was added.
        output_dict["loss"] = loss

    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score coreference links between pruned candidate spans.

    When ``self._use_gold_mentions`` is set, only candidate spans whose (start, end)
    pair equals a mention listed in ``metadata[0]["clusters"]`` are left unmasked
    before pruning.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``
        Shape (batch_size, num_spans, 2): inclusive start/end indices of candidate
        spans; padding spans carry -1.
    span_labels : ``torch.IntTensor``, optional
        Shape (batch_size, num_spans): cluster id per span, or -1. Enables the loss.
    metadata : ``List[Dict[str, Any]]``, optional
        Must provide ``"clusters"`` (first entry only) when gold mentions are used, and
        ``"original_text"`` for every entry when given.

    Returns
    -------
    Dict with ``top_spans``, ``antecedent_indices``, ``predicted_antecedents`` (-1
    means no link), ``loss`` when ``span_labels`` is given and ``document`` when
    ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))
    document_length = text_embeddings.size(1)
    # flatten_and_batch_shift_indices needs the size of the dimension the pruned
    # indices point into, i.e. spans.size(1), no matter how many spans are masked.
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    if self._use_gold_mentions:
        # NOTE(review): only the first document's clusters are read, so this path
        # assumes batch_size == 1.
        gold_pairs = [
            pair for cluster in metadata[0]["clusters"] for pair in cluster
        ]
        if gold_pairs:
            # Shape: (1, 1, num_gold_mentions, 2)
            gold_mentions = torch.as_tensor(
                gold_pairs, dtype=torch.long, device=spans.device).view(1, 1, -1, 2)
            # Shape: (batch_size, num_spans, num_gold_mentions, 2)
            differences = spans.unsqueeze(2).long() - gold_mentions
            # A span is a gold mention when BOTH endpoints match the same gold pair.
            # (Adding two bool tensors is a logical OR in PyTorch, so the old
            # ``a + b == 2`` test never matched.)
            span_mask = (differences == 0).all(dim=-1).any(dim=-1)
        else:
            # No gold mentions: every span is masked out.
            span_mask = spans.new_zeros(spans.shape[:2], dtype=torch.bool)
        span_mask = span_mask.float()
    else:
        # No ``.squeeze(-1)``: with a single candidate span it would drop the span
        # dimension and give the pruner a mask of the wrong rank.
        span_mask = (spans[:, :, 0] >= 0).float()

    # SpanFields use -1 as padding; clamp to 0 so width-based computations stay valid.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    (valid_antecedent_indices, valid_antecedent_offsets,
     valid_antecedent_log_mask) = self._generate_valid_antecedents(
         num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings,
        top_span_mention_scores,
        candidate_antecedent_mention_scores,
        valid_antecedent_log_mask,
    )

    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Index 0 is the "no antecedent" class; shift so that -1 means "no link".
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Negative marginal log-likelihood over all antecedents in the gold cluster.
        coreference_log_probs = util.last_dim_log_softmax(
            coreference_scores, top_span_mask)
        correct_antecedent_log_probs = (coreference_log_probs +
                                        gold_antecedent_labels.log())
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(
        self,  # type: ignore
        source_spans: torch.IntTensor,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing the entire target sequence.

    Parameters
    ----------
    source_spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for source sentence representation. Comes from a
        ``ListField[SpanField]`` of indices into the source sentence. Padding spans carry -1.
    source_tokens : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
       passed through a ``TextFieldEmbedder`` and then through an encoder.
    target_tokens : Dict[str, torch.LongTensor], optional (default = None)
       Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
       target tokens are also represented as a ``TextField``. Required for teacher forcing
       and for the loss.

    Returns
    -------
    Dict with ``logits`` and ``class_probabilities`` of shape
    (batch_size, num_decoding_steps, num_classes), ``predictions`` of shape
    (batch_size, num_decoding_steps), ``top_spans`` (batch_size, num_spans_to_keep, 2),
    ``attention_matrix`` (None when the decode step produced no attention weights),
    ``top_spans_scores``, and ``loss`` when ``target_tokens`` is given.
    """
    # (batch_size, input_sequence_length, encoder_output_dim)
    embedded_input = self._source_embedder(source_tokens)
    num_spans = source_spans.size(1)
    source_length = embedded_input.size(1)
    batch_size, _, _ = embedded_input.size()

    # (batch_size, source_length)
    source_mask = get_text_field_mask(source_tokens)

    # Shape: (batch_size, num_spans)
    # No ``.squeeze(-1)``: with a single span it would drop the span dimension.
    span_mask = (source_spans[:, :, 0] >= 0).float()
    # Padding spans carry -1; clamp them to 0 so they index valid positions.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(source_spans.float()).long()

    # Contextualized word embeddings; Shape: (batch_size, source_length, embedding_dim)
    contextualized_word_embeddings = self._encoder(embedded_input, source_mask)

    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    span_embeddings = self._span_extractor(contextualized_word_embeddings, spans)

    # Prune based on feedforward scorer
    num_spans_to_keep = int(math.floor(self._spans_per_word * source_length))

    # Shape: see return section of SpanPruner docs
    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_scores) = self._span_pruner(span_embeddings, span_mask,
                                          num_spans_to_keep)

    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = batched_index_select(spans, top_span_indices,
                                     flat_top_span_indices)

    # The decoder's first hidden state is initialised from the last encoder position.
    summary_of_encoded_source = contextualized_word_embeddings[:, -1]  # (batch_size, encoder_output_dim)

    if target_tokens:
        targets = target_tokens["tokens"]
        target_sequence_length = targets.size()[1]
        # The last input from the target is either padding or the end symbol. Either
        # way, we don't have to process it.
        num_decoding_steps = target_sequence_length - 1
    else:
        num_decoding_steps = self._max_decoding_steps

    # Condition decoder on encoder: append one extra feature (the sum) so the
    # dimensions match the decoder's hidden size.
    # Shape: (batch_size, encoder_output_dim + 1)
    decoder_hidden = torch.cat(
        (summary_of_encoded_source,
         summary_of_encoded_source.sum(1).unsqueeze(1)), 1)
    decoder_context = Variable(top_span_embeddings.data.new().resize_(
        batch_size, self._decoder_output_dim).fill_(0))

    # Span scores are appended to the span features so the pruner receives gradient
    # through the attention. Loop-invariant, so computed once.
    # Shape: (batch_size, num_spans_to_keep, span_embedding_dim + 1)
    top_span_embeddings_scores = torch.cat(
        (top_span_embeddings, top_span_scores), 2)

    last_predictions = None
    step_logits = []
    step_probabilities = []
    step_predictions = []
    step_attention_weights = []
    for timestep in range(num_decoding_steps):
        # Teacher forcing is only possible when targets exist; otherwise fall back to
        # the model's own predictions instead of reading an undefined ``targets``.
        use_teacher_forcing = (
            self.training and bool(target_tokens)
            and torch.rand(1).item() >= self._scheduled_sampling_ratio)
        if use_teacher_forcing:
            input_choices = targets[:, timestep]
        elif timestep == 0:
            # For the first timestep, when we do not have targets, we input start symbols.
            # (batch_size,)
            input_choices = Variable(source_mask.data.new().resize_(
                batch_size).fill_(self._start_index))
        else:
            input_choices = last_predictions

        # Shape: (batch_size, decoder_input_dim)
        decoder_input, attention_weights = self._prepare_decode_step_input(
            input_choices, decoder_hidden, top_span_embeddings_scores,
            top_span_mask)
        if attention_weights is not None:
            step_attention_weights.append(attention_weights)

        # Shape: both (batch_size, decoder_output_dim),
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        # (batch_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        # list of (batch_size, 1, num_classes)
        step_logits.append(output_projections.unsqueeze(1))
        class_probabilities = F.softmax(output_projections, dim=-1)
        _, predicted_classes = torch.max(class_probabilities, 1)
        step_probabilities.append(class_probabilities.unsqueeze(1))
        last_predictions = predicted_classes
        # (batch_size, 1)
        step_predictions.append(last_predictions.unsqueeze(1))

    # (batch_size, num_decoding_steps, num_classes)
    logits = torch.cat(step_logits, 1)
    class_probabilities = torch.cat(step_probabilities, 1)
    all_predictions = torch.cat(step_predictions, 1)

    # Bound unconditionally so the output dict never references an unset name.
    attention_matrix = None
    if step_attention_weights:
        # NOTE(review): each entry is documented as (batch_size, num_encoder_outputs);
        # concatenating on dim 0 and prepending a unit dim gives
        # (1, num_decoding_steps * batch_size, num_encoder_outputs), which equals
        # (batch_size, num_decoding_steps, num_encoder_outputs) only for batch_size == 1.
        attention_matrix = torch.cat(step_attention_weights, 0).unsqueeze(0)

    output_dict = {
        "logits": logits,
        "class_probabilities": class_probabilities,
        "predictions": all_predictions,
        "top_spans": top_spans,
        "attention_matrix": attention_matrix,
        "top_spans_scores": top_span_scores
    }

    if target_tokens:
        target_mask = get_text_field_mask(target_tokens)
        loss = self._get_loss(logits, targets, target_mask)
        output_dict["loss"] = loss

    return output_dict
def forward(self, token_ids: torch.Tensor, entity_mask: torch.IntTensor,
            verb_mask: torch.IntTensor, loc_mask: torch.IntTensor,
            gold_loc_seq: torch.IntTensor, gold_state_seq: torch.IntTensor,
            num_cands: torch.IntTensor, sentence_mask: torch.IntTensor,
            cpnet_triples: List, state_rel_labels: torch.IntTensor,
            loc_rel_labels: torch.IntTensor):
    """
    Predict entity state changes and locations, and compute their losses.

    Args:
        token_ids: previously documented as (batch * max_wiki, max_ctx_tokens); the
            assertion on ``token_rep`` below requires the encoders to yield
            (batch, max_tokens, ...), i.e. effectively (batch, max_tokens) — TODO confirm.
        entity_mask, verb_mask, sentence_mask: size (batch, max_sents, max_tokens)
        loc_mask: size (batch, max_cands, max_sents + 1, max_tokens), +1 for location 0
        gold_loc_seq: size (batch, max_sents + 1), as required by the shape assertion
        gold_state_seq: size (batch, max_sents)
        num_cands: size (batch,)
        cpnet_triples: passed to the ConceptNet encoder, state tracker and location predictor
        state_rel_labels: size (batch, max_sents, max_cpnet)
        loc_rel_labels: size (batch, max_sents, max_cpnet)

    Returns:
        If ``self.is_test``: (pred_state_seq, pred_loc_seq, correct_state_pred,
        total_state_pred, correct_loc_pred, total_loc_pred).
        Otherwise: (state_loss, loc_loss, attn_loss, correct_state_pred,
        total_state_pred, correct_loc_pred, total_loc_pred, total_attn_pred).
    """
    assert entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(-2) - 1 \
        == gold_state_seq.size(-1) == gold_loc_seq.size(-1) - 1
    assert entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(-1)
    batch_size = entity_mask.size(0)
    max_tokens = entity_mask.size(-1)
    max_sents = gold_state_seq.size(-1)
    max_cands = loc_mask.size(-3)

    attention_mask = (token_ids != self.plm_tokenizer.pad_token_id).to(torch.int)
    plm_outputs = self.embed_encoder(token_ids, attention_mask=attention_mask)
    # hidden states at the last layer, (batch, max_tokens, plm_hidden_size)
    embeddings = plm_outputs[0]
    token_rep, _ = self.TokenEncoder(embeddings)  # (batch, max_tokens, 2*hidden_size)
    token_rep = self.Dropout(token_rep)
    assert token_rep.size() == (batch_size, max_tokens, 2 * self.hidden_size)

    cpnet_rep = self.CpnetEncoder(cpnet_triples,
                                  tokenizer=self.plm_tokenizer,
                                  encoder=self.cpnet_encoder)

    # state change prediction
    # size (batch, max_sents, NUM_STATES)
    tag_logits, state_attn_probs = self.StateTracker(
        encoder_out=token_rep,
        entity_mask=entity_mask,
        verb_mask=verb_mask,
        sentence_mask=sentence_mask,
        cpnet_triples=cpnet_triples,
        cpnet_rep=cpnet_rep)
    # mask the padded part so they won't count in loss
    tag_mask = (gold_state_seq != PAD_STATE)
    log_likelihood = self.CRFLayer(emissions=tag_logits,
                                   tags=gold_state_seq.long(),
                                   mask=tag_mask,
                                   reduction='token_mean')

    # State classification loss is negative log likelihood
    state_loss = -log_likelihood
    pred_state_seq = self.CRFLayer.decode(emissions=tag_logits, mask=tag_mask)
    assert len(pred_state_seq) == batch_size
    correct_state_pred, total_state_pred = compute_state_accuracy(
        pred=pred_state_seq,
        gold=gold_state_seq.tolist(),
        pad_value=PAD_STATE)

    # location prediction
    # Prepend an all-zero row so entity_mask covers location 0 (max_sents + 1 rows).
    # Taking dtype/device from entity_mask keeps torch.cat valid on any device.
    empty_mask = entity_mask.new_zeros((batch_size, 1, max_tokens))
    entity_mask = torch.cat([empty_mask, entity_mask], dim=1)
    # size (batch, max_cands, max_sents + 1)
    loc_logits, loc_attn_probs = self.LocationPredictor(
        encoder_out=token_rep,
        entity_mask=entity_mask,
        loc_mask=loc_mask,
        sentence_mask=sentence_mask,
        cpnet_triples=cpnet_triples,
        cpnet_rep=cpnet_rep)
    loc_logits = loc_logits.transpose(-1, -2)  # size (batch, max_sents + 1, max_cands)
    masked_loc_logits = self.mask_loc_logits(
        loc_logits=loc_logits, num_cands=num_cands)  # (batch, max_sents + 1, max_cands)
    masked_gold_loc_seq = self.mask_undefined_loc(
        gold_loc_seq=gold_loc_seq, mask_value=PAD_LOC)  # (batch, max_sents + 1)
    # The class dimension is max_cands (mask_loc_logits asserts this); the previous
    # ``max_cands + 1`` could never match the tensor's element count. ``reshape`` is
    # used because the logits come from a transpose and need not be contiguous.
    loc_loss = self.CrossEntropy(
        input=masked_loc_logits.reshape(batch_size * (max_sents + 1), max_cands),
        target=masked_gold_loc_seq.reshape(batch_size * (max_sents + 1)).long())
    correct_loc_pred, total_loc_pred = compute_loc_accuracy(
        logits=masked_loc_logits, gold=masked_gold_loc_seq, pad_value=PAD_LOC)

    if loc_attn_probs is not None:
        loc_attn_probs = self.get_gold_attn_probs(loc_attn_probs, gold_loc_seq)
    attn_loss, total_attn_pred = self.get_attn_loss(state_attn_probs,
                                                    loc_attn_probs,
                                                    state_rel_labels,
                                                    loc_rel_labels)

    if self.is_test:  # inference
        pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits,
                                    gold_loc_seq=gold_loc_seq)
        return pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred, correct_loc_pred, total_loc_pred

    return state_loss, loc_loss, attn_loss, correct_state_pred, total_state_pred, \
        correct_loc_pred, total_loc_pred, total_attn_pred
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        **kwargs) -> Dict[str, torch.Tensor]:
    """
    Classify pruned candidate spans into named-entity classes.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``
        Shape (batch_size, num_spans, 2): inclusive start/end indices of candidate
        spans; padding spans carry -1.
    labels : ``torch.IntTensor``, optional
        Shape (batch_size, num_spans): gold class per span, 0 meaning "not an entity"
        (the pruner target treats non-zero labels as mentions). Enables losses/metrics.
    metadata, **kwargs :
        Accepted for interface compatibility; not read here.

    Returns
    -------
    Dict with ``top_spans`` (batch_size, num_spans_to_keep, 2) and
    ``predicted_named_entities`` (batch_size, num_spans_to_keep); plus ``loss``
    (classification + pruner loss) and ``pruner_loss`` when ``labels`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # No ``.squeeze(-1)``: with a single candidate span it would drop the span
    # dimension and give the pruner a mask of the wrong rank.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. Clamp to 0 so that the
    # width-based computations in the span extractors stay valid even if a masked
    # span survives pruning.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores, never keeping more spans than exist.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

    # Shapes:
    # (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
    # (batch_size, num_spans_to_keep)
    # (batch_size, num_spans_to_keep)
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask, top_span_indices,
     top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, so shift the per-batch indices by
    # their batch offset once and reuse them below.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Shape: (batch_size, num_spans_to_keep, class_num) — the reshapes below require
    # the last dimension to be exactly self.class_num.
    ne_scores = self._compute_named_entity_scores(top_span_embeddings)
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_named_entities = ne_scores.max(2)

    output_dict = {
        "top_spans": top_spans,
        "predicted_named_entities": predicted_named_entities
    }
    if labels is not None:
        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep)
        pruned_gold_labels = util.batched_index_select(
            labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        # cross_entropy requires int64 class targets.
        flat_gold_labels = pruned_gold_labels.reshape(-1).long()
        negative_log_likelihood = F.cross_entropy(
            ne_scores.reshape(-1, self.class_num), flat_gold_labels)
        # The pruner is trained to score any non-zero (entity) span as a mention.
        pruner_loss = F.binary_cross_entropy_with_logits(
            top_span_mention_scores.reshape(-1),
            (flat_gold_labels != 0).float())
        loss = negative_log_likelihood + pruner_loss
        output_dict["loss"] = loss
        output_dict["pruner_loss"] = pruner_loss

        # Scatter the kept spans' scores back over all spans for the metrics; spans
        # that were pruned away are scored as class 0.
        batch_size, _ = labels.shape
        all_scores = ne_scores.new_zeros(
            [batch_size * num_spans, self.class_num])
        all_scores[:, 0] = 1
        all_scores[flat_top_span_indices] = ne_scores.reshape(
            -1, self.class_num)
        all_scores = all_scores.reshape(
            [batch_size, num_spans, self.class_num])
        self._metric_all(all_scores, labels)
        self._metric_avg(all_scores, labels)

    return output_dict
def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        evidence: torch.IntTensor = None,
        pad_idx=-1,
        max_select=5,
        gamma=0.95,
        teacher_forcing_ratio=1,
        features=None,
        metadata=None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Select evidence sentences with a pointer network, classify entailment on
    the selection, and assemble the training losses.

    For every instance ``self._ptr_extract_summ`` proposes up to ``max_select``
    premise-sentence indices. The selected sentences (or, when
    ``self._use_decoder_states`` is set, the extractor's decoder states) are
    scored by ``self._entailment_esim``. Unless
    ``self._fix_sentence_extraction_params`` is set, a policy-gradient loss
    with a learned baseline is built from the reward returned by
    ``self._fever``.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``. Its mask is built with ``num_wrapping_dims=1``,
        i.e. it holds several candidate sentences per instance.
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``
    label : torch.IntTensor, optional (default = None)
        From a ``LabelField``. Despite the default, it is read unconditionally.
    evidence : torch.IntTensor, optional (default = None)
        From a ``ListField``. Gold evidence sentence indices padded with
        ``pad_idx``. Despite the default, it is read unconditionally.
    pad_idx : int
        Padding value used in ``evidence`` and passed to the sequence losses.
    max_select : int
        Maximum number of sentences selected per instance.
    gamma : float
        Per-step discount applied to the reward of each selection.
    teacher_forcing_ratio : float
        Gold evidence is fed to the extractor only when
        ``random.random() > teacher_forcing_ratio`` (and the label is not the
        NEI label and gold evidence exists). With the default of 1 this never
        happens.
    features : optional
        Extra features forwarded to the ESIM modules.
    metadata : optional
        Forwarded to ``self._fever``.

    Returns
    -------
    An output dictionary consisting of:

    predicted_sentences : torch.LongTensor
        Shape ``(batch_size, max_select)``. The selected sentence indices,
        padded with -1.
    label_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing unnormalised
        log probabilities of the entailment label.
    label_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_labels)`` representing
        probabilities of the entailment label.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    label_sequence_logits, advantage, mse : optional
        Present only on the decoder-state and policy-gradient paths.
    """
    premise_mask = get_text_field_mask(premise, num_wrapping_dims=1).float()

    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # One aggregated vector per (instance, candidate sentence) pair.
    aggregated_input = self._sentence_selection_esim(premise,
                                                     hypothesis,
                                                     premise_mask,
                                                     hypothesis_mask,
                                                     wrap_output=True,
                                                     features=features)

    batch_size, num_evidence, max_premise_length = premise_mask.shape
    aggregated_input = aggregated_input.view(batch_size, num_evidence, -1)
    # A candidate sentence is "real" if it has at least one unmasked token.
    evidence_mask = premise_mask.sum(dim=-1).gt(0)
    evidence_len = evidence_mask.view(batch_size, -1).sum(dim=-1)

    # Per-instance outputs of the extractor, collected in the loop below.
    valid_indices = []
    indices = []
    probs = []
    baselines = []
    states = []
    selected_evidence_lengths = []
    for i in range(evidence.size(0)):
        gold_evidence = None
        # Scalar access: fall back to .item() when .data[0] raises IndexError
        # (e.g. on a 0-dim tensor).
        try:
            curr_label = label[i].data[0]
        except IndexError:
            curr_label = label[i].item()
        # Teacher forcing: feed the gold evidence to the extractor.
        if random.random(
        ) > teacher_forcing_ratio and curr_label != self._nei_label and float(
                evidence[i].ne(pad_idx).sum()) > 0:
            gold_evidence = evidence[i]

        output = self._ptr_extract_summ(aggregated_input[i],
                                        max_select,
                                        evidence_mask[i],
                                        gold_evidence,
                                        beam_size=self._beam_size)
        states.append(output.get('states', []))

        valid_idx = []
        try:
            curr_evidence_len = evidence_len[i].data[0]
        except IndexError:
            curr_evidence_len = evidence_len[i].item()
        for idx in output['idxs'][:min(max_select, curr_evidence_len)]:
            try:
                curr_idx = idx.view(-1).data[0]
            except IndexError:
                curr_idx = idx.view(-1).item()
            # Index == num_evidence is treated as the "stop selecting" signal.
            if curr_idx == num_evidence:
                break
            valid_idx.append(curr_idx)

            # Out-of-range picks are redirected to sentence 0.
            if valid_idx[-1] >= curr_evidence_len:
                valid_idx[-1] = 0

        #TODO: if it selects none, use the first one?

        selected_evidence_lengths.append(len(valid_idx))
        indices.append(valid_idx)
        # Keep only the baselines / action distributions of the kept steps.
        if 'scores' in output:
            baselines.append(output['scores'][:len(valid_idx)])
        if 'probs' in output:
            probs.append(output['probs'][:len(valid_idx)])

        # Pad the selection to max_select with -1.
        valid_indices.append(torch.LongTensor(valid_idx + [-1] * (max_select - len(valid_idx))))

    ''' for q in range(label.size(0)): if selected_evidence_lengths[q] >= 5: continue print(label[q]) print(evidence[q]) print(valid_indices[q]) if len(baselines): print(probs[q][0].probs) print(baselines[q]) '''

    output_dict = {'predicted_sentences': torch.stack(valid_indices)}

    predictions = torch.autograd.Variable(torch.stack(valid_indices))
    selected_premise = {}
    index = predictions.unsqueeze(2).expand(batch_size, max_select,
                                            max_premise_length)
    #B x num_selected
    # NOTE(review): len_mask is defined elsewhere; presumably it returns a
    # (batch, max_select) 0/1 mask of the selected positions. Confirm.
    l = torch.autograd.Variable(
        len_mask(selected_evidence_lengths,
                 max_len=max_select,
                 dtype=torch.FloatTensor))

    # Zero out the -1 padding indices so the gathers below stay in range
    # (relies on l being 0 at padded positions).
    index = index * l.long().unsqueeze(-1)
    if torch.cuda.is_available() and premise_mask.is_cuda:
        idx = premise_mask.get_device()
        index = index.cuda(idx)
        l = l.cuda(idx)
        predictions = predictions.cuda(idx)

    if self._use_decoder_states:
        # Classify from every decoder state and train with a sequence loss;
        # the final label comes from the last step.
        states = torch.cat(states, dim=0)
        label_sequence = make_label_sequence(predictions,
                                             evidence,
                                             label,
                                             pad_idx=pad_idx,
                                             nei_label=self._nei_label)
        batch_size, max_length, _ = states.shape
        label_logits = self._entailment_esim(
            features=states.view(batch_size * max_length, 1, -1))
        if 'loss' not in output_dict:
            output_dict['loss'] = 0
        output_dict['loss'] += sequence_loss(label_logits.view(
            batch_size, max_length, -1),
                                             label_sequence,
                                             self._evidence_loss,
                                             pad_idx=pad_idx)
        output_dict['label_sequence_logits'] = label_logits.view(
            batch_size, max_length, -1)
        label_logits = output_dict['label_sequence_logits'][:, -1, :]
    else:
        # Gather the selected sentences (tokens and mask) for the entailment model.
        for key in premise:
            selected_premise[key] = torch.gather(premise[key],
                                                 dim=1,
                                                 index=index)

        selected_mask = torch.gather(premise_mask, dim=1, index=index)

        # Blank out padded selections (which were redirected to index 0).
        selected_mask = selected_mask * l.unsqueeze(-1)

        selected_features = None
        if features is not None:
            index = predictions.unsqueeze(2).expand(
                batch_size, max_select, features.size(-1))
            index = index * l.long().unsqueeze(-1)
            selected_features = torch.gather(features, dim=1, index=index)

            #UNDO!!!!!
            # NOTE(review): hard-coded truncation to the first 200 feature dims.
            selected_features = selected_features[:, :, :200]

        label_logits = self._entailment_esim(selected_premise,
                                             hypothesis,
                                             premise_mask=selected_mask,
                                             features=selected_features)

    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
    ''' key = 'tokens' for q in range(premise[key].size(0)): print(index[q,:,0]) print([int(i.data[0]) for i in hypothesis[key][q]]) print([self.vocab._index_to_token[key][i.data[0]] for i in hypothesis[key][q]]) print([int(i.data[0]) for i in premise[key][q,0]]) print([self.vocab._index_to_token[key][i.data[0]] for i in premise[key][q,0]]) print([self.vocab._index_to_token[key][i.data[0]] for i in premise[key][q,index[q,0,0].data[0]]]) print([self.vocab._index_to_token[key][i.data[0]] for i in selected_premise[key][q,0]]) print([int(i.data[0]) for i in premise_mask[q,0]]) print(l[q]) print([int(i.data[0]) for i in premise_mask[q,index[q,0,0].data[0]]]) for z in range(5): print([int(i.data[0]) for i in selected_mask[q,z]]) print(label[q], label_probs[q]) '''

    output_dict.update({
        "label_logits": label_logits,
        "label_probs": label_probs
    })

    #get fever score, recall, and accuracy

    if len(label.shape) > 1:
        self._accuracy(label_logits, label.squeeze(-1))
    else:
        self._accuracy(label_logits, label)

    # Per-instance reward; its exact scale/encoding is defined by self._fever.
    fever_reward = self._fever(label_logits,
                               label.squeeze(-1),
                               predictions,
                               evidence,
                               indices=True,
                               pad_idx=pad_idx,
                               metadata=metadata)

    if not self._fix_sentence_extraction_params:
        #multiply the reward for the support/refute labels by a constant so that the model selects the correct evidence instead of just trying to predict the not enough info labels
        fever_reward = fever_reward * label.squeeze(-1).ne(
            self._nei_label
        ) * self._ei_reward_weight + fever_reward * label.squeeze(-1).eq(
            self._nei_label)

        #compute discounted reward
        # Step t of instance i receives gamma**t * reward_i.
        rewards = []
        avg_reward = 0
        for i in range(evidence.size(0)):
            avg_reward += float(fever_reward[i])
            rewards.append(
                gamma**torch.arange(selected_evidence_lengths[i]).float() *
                fever_reward[i].float())
        reward = torch.autograd.Variable(torch.cat(rewards),
                                         requires_grad=False)
        if torch.cuda.is_available() and fever_reward.is_cuda:
            idx = fever_reward.get_device()
            reward = reward.cuda(idx)

        if len(baselines):
            # Flatten per-instance lists so every selection step is one sample.
            indices = list(itertools.chain(*indices))
            probs = list(itertools.chain(*probs))
            baselines = list(itertools.chain(*baselines))

            # standardize rewards
            reward = (reward - reward.mean()) / (
                reward.std() + float(np.finfo(np.float32).eps))

            baseline = torch.cat(baselines).squeeze()
            avg_advantage = 0
            losses = []
            for action, p, r, b in zip(indices, probs, reward, baseline):
                action = torch.autograd.Variable(torch.LongTensor([action
                                                                   ]))
                if torch.cuda.is_available() and r.is_cuda:
                    idx = r.get_device()
                    action = action.cuda(idx)

                advantage = r - b
                avg_advantage += advantage
                # p is used via p.log_prob(...); presumably a
                # torch.distributions object. Confirm.
                losses.append(-p.log_prob(action) *
                              (advantage / len(indices)))  # divide by T*B

            # Critic (baseline) regression toward the standardized reward.
            critic_loss = F.mse_loss(baseline, reward)

            # NOTE(review): plain assignment overwrites any loss accumulated
            # above (e.g. the decoder-state sequence loss). Confirm intended.
            output_dict['loss'] = critic_loss + sum(losses)

            try:
                output_dict['advantage'] = avg_advantage.data[0] / len(
                    indices)
                output_dict['mse'] = critic_loss.data[0]
            except IndexError:
                output_dict['advantage'] = avg_advantage.item() / len(
                    indices)
                output_dict['mse'] = critic_loss.item()

        #output_dict['reward'] = avg_reward / evidence.size(0)

    # Supervised extractor loss on gold evidence (skipped when every entry is
    # -1, i.e. the whole batch is padding).
    if self.training and self._train_gold_evidence:
        if 'loss' not in output_dict:
            output_dict['loss'] = 0
        if evidence.sum() != -1 * torch.numel(evidence):
            if len(evidence.shape) > 2:
                evidence = evidence.squeeze(-1)
            output = self._ptr_extract_summ(
                aggregated_input, None, None, evidence,
                evidence_len.long().data.cpu().numpy().tolist())
            loss = sequence_loss(output['scores'][:, :-1, :],
                                 evidence,
                                 self._evidence_loss,
                                 pad_idx=pad_idx)
            output_dict['loss'] += self.lambda_weight * loss

    if not self._fix_entailment_params:
        if self._use_decoder_states:
            if self.training:
                # NOTE(review): `output` here is the gold-evidence pass only if
                # the block above ran; otherwise it is the last per-instance
                # extractor output from the selection loop. Confirm.
                label_sequence = make_label_sequence(
                    evidence,
                    evidence,
                    label,
                    pad_idx=pad_idx,
                    nei_label=self._nei_label)
                batch_size, max_length, _ = output['states'].shape
                label_logits = self._entailment_esim(
                    features=output['states'][:, 1:, :].contiguous().view(
                        batch_size * (max_length - 1), 1, -1))
                if 'loss' not in output_dict:
                    output_dict['loss'] = 0
                output_dict['loss'] += sequence_loss(label_logits.view(
                    batch_size, max_length - 1, -1),
                                                     label_sequence,
                                                     self._evidence_loss,
                                                     pad_idx=pad_idx)
        else:
            #TODO: only update classifier if we have correct evidence
            #evidence_reward = self._fever_evidence_only(label_logits, label.squeeze(-1),
            #                                            predictions, evidence,
            #                                            indices=True, pad_idx=pad_idx)
            #mask = evidence_reward > 0
            #target = mask * label.byte() + mask.eq(0) * self._nei_label
            # NOTE(review): 2**7 looks like a sentinel reward value that
            # excludes instances from the classifier loss; its meaning is
            # defined by self._fever. Confirm.
            mask = fever_reward != 2**7
            target = label.view(-1).masked_select(mask)

            mask = fever_reward != 2**7
            logit = label_logits.masked_select(
                mask.unsqueeze(1).expand_as(
                    label_logits)).contiguous().view(
                        -1, label_logits.size(-1))
            loss = self._loss(
                logit,
                target.long())  #label_logits, label.long().view(-1))
            if 'loss' in output_dict:
                output_dict["loss"] += self.lambda_weight * loss
            else:
                output_dict["loss"] = self.lambda_weight * loss

    return output_dict
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score candidate spans, prune them, and predict an antecedent for each kept span.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document. Padding spans have a start index of -1.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Passed to the mention-recall and CoNLL coreference metrics; needed whenever
        ``span_labels`` is given.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # The comparison is already 2-D. A ``.squeeze(-1)`` here would wrongly drop the
    # span dimension when ``num_spans == 1``.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores. Never ask the pruner to keep more spans than
    # exist, because a top-k over num_spans items cannot return more than num_spans.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        # NOTE(review): adding the (long-cast) log mask is meant to push invalid
        # antecedents away from any real cluster id; confirm what
        # _generate_valid_antecedents puts in the mask.
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood
    return output_dict
def forward(
        self,  # type: ignore
        characters: Dict[str, torch.LongTensor],
        character_spans: torch.IntTensor,  ## Shape: batch x num_spans x 2
        pos_tags: torch.LongTensor,
        metadata: List[Dict[str, Any]],
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    """
    Dependency-parse a batch whose tokens are character spans.

    Each span is represented by concatenating an endpoint-based and an
    attention-based summary of the character embeddings it covers. A POS-tag
    embedding is appended when the model has one. The span sequence is then
    handed to ``self._parse``.

    Parameters
    ----------
    characters : Dict[str, torch.LongTensor]
        Character-level text field input.
    character_spans : torch.IntTensor
        Shape ``(batch_size, num_spans, 2)``. Padding spans have a negative
        start index.
    pos_tags : torch.LongTensor
        Required whenever the model owns a POS-tag embedding.
    metadata : List[Dict[str, Any]]
        Each entry must provide ``"words"`` and ``"pos"``.
    head_tags, head_indices : torch.LongTensor, optional
        Gold annotations. When both are given, attachment scores are updated.

    Returns
    -------
    Dict with ``heads``, ``head_tags``, ``arc_loss``, ``tag_loss``, ``loss``,
    ``mask``, ``words`` and ``pos``.

    Raises
    ------
    ConfigurationError
        If the model has a POS-tag embedding but ``pos_tags`` is None.
    """
    char_embeddings = self.text_field_embedder(characters)
    char_mask = util.get_text_field_mask(characters).float()
    # Zero out embeddings at padded character positions.
    char_embeddings = char_embeddings * char_mask.unsqueeze(-1)

    # Shape: (batch_size, num_spans); 1 for real spans, 0 for padding.
    span_mask = (character_spans[:, :, 0] >= 0).long()
    # Padding spans carry -1 indices; clamp them to 0 for the extractors.
    clamped_spans = F.relu(character_spans.float()).long()

    # Shape: (batch_size, num_spans, 2 * encoding_dim)
    endpoint_repr = self._endpoint_span_extractor(char_embeddings, clamped_spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_repr = self._attentive_span_extractor(char_embeddings, clamped_spans)
    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim)
    span_features = torch.cat([endpoint_repr, attended_repr], -1)

    if self._pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")
        span_features = torch.cat(
            [span_features, self._pos_tag_embedding(pos_tags)], -1)

    (predicted_heads, predicted_head_tags, mask,
     arc_nll, tag_nll) = self._parse(span_features, span_mask, head_tags, head_indices)

    loss = arc_nll + tag_nll

    if head_indices is not None and head_tags is not None:
        # Position 0 is the symbolic ROOT token, so scoring starts at index 1.
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices,
                                head_tags,
                                evaluation_mask)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
        "words": [entry["words"] for entry in metadata],
        "pos": [entry["pos"] for entry in metadata],
    }
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict (possibly several) answer spans in ``passage`` for ``question``.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        Gold `inclusive` start token indices. After ``squeeze(-1)`` it is treated as
        ``(batch_size, max_answer_len)``, one column per gold answer; padded answer slots
        are expected to hold -1 and are ignored by the loss. If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        Gold `inclusive` end token indices, laid out like ``span_start``.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, each dictionary must have ``answer_texts`` (used to size the answer
        mask and passed to ``get_best_span``), ``question_tokens``, ``passage_tokens``,
        ``original_passage`` and ``token_offsets``. The length of this list should be the
        batch size.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position (masked positions set to -1e7).
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The output of ``self.get_best_span``. Each row of ``best_span[i]`` is read as a
        ``(start, end)`` token-index pair, and a start of -1 ends the list of predictions.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[List[str]]
        If metadata was provided, the predicted answer strings for each instance.
    question_tokens, passage_tokens : List
        Copied from metadata, if provided.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    # ``max`` already drops the question dimension. A trailing ``.squeeze(-1)`` would
    # also drop the passage dimension when ``passage_length == 1``.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                               passage_length,
                                                                               modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

    # Number of gold answers per instance; used to mask padded answer slots.
    # NOTE(review): this requires 'answer_texts' in every metadata entry, whereas the
    # SQuAD-metric code below treats it as optional. Confirm prediction-time metadata.
    answer_len = [len(elem['answer_texts']) for elem in metadata] if metadata is not None else []
    if answer_len:
        # Shape: (batch_size, max_answer_len, 2); 1 for real answer slots, 0 for padding.
        # Allocated on the model's device so it can be combined with GPU tensors.
        mask = span_start_logits.new_zeros((batch_size, max(answer_len), 2), dtype=torch.long)
        for index, length in enumerate(answer_len):
            mask[index, :length] = 1
    else:
        mask = None

    best_span = self.get_best_span(span_start_logits, span_end_logits, answer_len)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        span_start = span_start.squeeze(-1)  # batch X max_answer_L
        span_end = span_end.squeeze(-1)  # batch X max_answer_L
        # Loop-invariant: compute the log-probabilities once, not once per answer slot.
        span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)
        loss = span_start_logits.new_zeros(())
        # Sum the loss over every gold answer. Padded slots (-1) are ignored, so an
        # instance whose first slot is padding no longer crashes nll_loss.
        for step in range(span_start.size(1)):
            step_start = span_start[:, step]  # (batch_size,)
            step_end = span_end[:, step]  # (batch_size,)
            loss = loss + nll_loss(span_start_log_probs, step_start, ignore_index=-1)
            # TODO: padded (-1) targets are still counted by the accuracy metrics.
            self._span_start_accuracy(span_start_logits, step_start)
            loss = loss + nll_loss(span_end_log_probs, step_end, ignore_index=-1)
            self._span_end_accuracy(span_end_logits, step_end)
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1), mask)
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            best_span_strings = []
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_spans = tuple(best_span[i].data.cpu().numpy())
            for predicted_span in predicted_spans:
                # A start of -1 marks the end of the predicted spans.
                if predicted_span[0] == -1:
                    break
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                best_span_strings.append(best_span_string)

            output_dict['best_span_str'].append(best_span_strings)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_strings, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        sentence_spans: torch.IntTensor = None,
        sent_labels: torch.IntTensor = None,
        evd_chain_labels: torch.IntTensor = None,
        q_type: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict an evidence chain (a sequence of passage sentences) for each instance.

    The passage embedding is cut into sentence spans with ``sentence_spans`` and
    ``self._span_gate`` decodes a chain over those spans. When ``evd_chain_labels`` is
    given, the first gold chain of every instance is used for supervision: chain metrics
    are updated and an RL-style loss ``-mean(sum(logprob * predict_mask * instance_mask))``
    is returned under ``"loss"``. Without chain labels the method only predicts.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        Text-field tensors for the question and the passage.
    span_start, span_end, q_type : torch.IntTensor, optional
        Accepted for interface compatibility; not used by this method.
    sentence_spans : torch.IntTensor
        Sentence boundaries handed to ``convert_sequence_to_spans``.
    sent_labels : torch.IntTensor, optional
        Per-sentence labels used for the sentence-level F1 metric (skipped when None).
    evd_chain_labels : torch.IntTensor, optional
        Gold chains; only ``evd_chain_labels[:, 0]`` (the first chain) is used. Instances
        whose first chain step is 0 have no usable chain and are masked out of the loss.
    metadata : List[Dict[str, Any]], optional
        Per-instance metadata whose fields are copied into the output dictionary.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``pred_sent_labels`` and ``gate_probs`` of shape [B, num_spans],
        ``pred_sent_orders`` of shape [B, K, num_spans], optional attention scores,
        metadata-derived fields, and ``loss`` when ``evd_chain_labels`` is given.
    """
    # Only the first chain in ``evd_chain_labels`` is used for supervision.
    if evd_chain_labels is not None:
        evd_chain_labels = evd_chain_labels[:, 0]
        # Some instances have no evidence chain (first step == 0); the mask removes
        # them from the loss and the chain metric.
        evd_instance_mask = (evd_chain_labels[:, 0] != 0).float()
    else:
        evd_instance_mask = None

    # bert embedding for answer prediction
    # shape: [batch_size, max_q_len, emb_size]
    embedded_question = self._text_field_embedder(question)
    # shape: [batch_size, max_passage_len, embedding_dim]
    embedded_passage = self._text_field_embedder(passage)

    ques_mask = util.get_text_field_mask(question, num_wrapping_dims=0).float()

    # Shape(spans_rep_sp): (batch_size * num_spans, max_batch_span_width, embedding_dim)
    # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
    spans_rep_sp, spans_mask = convert_sequence_to_spans(embedded_passage, sentence_spans)

    # chain prediction
    # Shape(all_predictions): (batch_size, num_decoding_steps)
    # Shape(all_logprobs): (batch_size, num_decoding_steps)
    # Shape(seq_logprobs): (batch_size,)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(gate_probs): (batch_size * num_spans, 1)
    # Shape(gate_mask): (batch_size, num_spans)
    # Shape(g_att_score): (batch_size, num_heads, num_spans, num_spans)
    # Shape(orders): (batch_size, K, num_spans)
    (all_predictions,
     all_logprobs,
     seq_logprobs,
     gate,
     gate_probs,
     gate_mask,
     g_att_score,
     orders) = self._span_gate(spans_rep_sp, spans_mask,
                               embedded_question, ques_mask,
                               evd_chain_labels,
                               self._gate_self_attention_layer,
                               self._gate_sent_encoder)
    batch_size, num_spans, max_batch_span_width = spans_mask.size()

    pred_sent_labels = gate.squeeze(1).view(batch_size, num_spans)  # [B, num_span]
    output_dict = {
        "pred_sent_labels": pred_sent_labels,  # [B, num_span]
        "gate_probs": gate_probs.squeeze(1).view(batch_size, num_spans),  # [B, num_span]
        "pred_sent_orders": orders,  # [B, K, num_span]
    }
    if self._output_att_scores and g_att_score is not None:
        output_dict['evd_self_attention_score'] = g_att_score

    # Sentence-level F1 needs gold sentence labels; skip it at inference time.
    if sent_labels is not None:
        self._f1_metrics(
            pred_sent_labels,
            sent_labels,
            mask=gate_mask,
            instance_mask=evd_instance_mask if self.training else None,
            sum=False)

    # Chain metric and RL loss need gold chains.
    rl_loss = None
    if evd_chain_labels is not None:
        # default to take multiple predicted chains, so unsqueeze dim 1
        predict_mask = get_evd_prediction_mask(all_predictions.unsqueeze(1), eos_idx=0)[0]
        gold_mask = get_evd_prediction_mask(evd_chain_labels, eos_idx=0)[0]
        self.evd_sup_acc_metric(predictions=all_predictions.unsqueeze(1),
                                gold_labels=evd_chain_labels,
                                predict_mask=predict_mask,
                                gold_mask=gold_mask,
                                instance_mask=evd_instance_mask)
        predict_mask = predict_mask.float().squeeze(1)
        rl_loss = -torch.mean(
            torch.sum(all_logprobs * predict_mask * evd_instance_mask[:, None], dim=1))

    # Copy metadata into the output and measure whether predicted chains cover the
    # sentences that contain the answer.
    if metadata is not None:
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        token_spans_sent = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        pred_chains_include_ans = []
        beam_pred_chains_include_ans = []
        ids = []
        for i in range(batch_size):
            meta = metadata[i]
            question_tokens.append(meta['question_tokens'])
            passage_tokens.append(meta['passage_tokens'])
            token_spans_sent.append(meta['token_spans_sent'])
            sent_labels_list.append(meta['sent_labels'])
            ids.append(meta['_id'])
            output_dict['answer_texts'].append(meta.get('answer_texts', []))
            # shift sentence indices back by one (entries <= 0 are dropped from chains)
            evd_possible_chains.append([
                s_idx - 1 for s_idx in meta['evd_possible_chains'][0] if s_idx > 0
            ])
            ans_sent_idxs.append([s_idx - 1 for s_idx in meta['ans_sent_idxs']])
            if len(meta['ans_sent_idxs']) > 0:
                # Shape: (K, num_spans). NOTE(review): an entry >= 0 presumably means the
                # sentence appears in that beam's chain — confirm against _span_gate.
                pred_sent_orders = orders[i].detach().cpu().numpy()
                top_beam_hit = any(pred_sent_orders[0][s_idx - 1] >= 0
                                   for s_idx in meta['ans_sent_idxs'])
                self.evd_ans_metric(1 if top_beam_hit else 0)
                pred_chains_include_ans.append(1 if top_beam_hit else 0)

                any_beam_hit = any(pred_sent_orders[beam][s_idx - 1] >= 0
                                   for beam in range(len(pred_sent_orders))
                                   for s_idx in meta['ans_sent_idxs'])
                self.evd_beam_ans_metric(1 if any_beam_hit else 0)
                beam_pred_chains_include_ans.append(1 if any_beam_hit else 0)

        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
        output_dict['token_spans_sent'] = token_spans_sent
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['pred_chains_include_ans'] = pred_chains_include_ans
        output_dict['beam_pred_chains_include_ans'] = beam_pred_chains_include_ans
        output_dict['_id'] = ids

    # Compute the loss for training.
    if rl_loss is not None:
        try:
            self._loss_trackers['loss'](rl_loss)
            self._loss_trackers['rl_loss'](rl_loss)
        except RuntimeError:
            # Log the offending batch, then propagate instead of silently dropping the loss.
            print('\n meta_data:', metadata)
            print(output_dict.get('_id'))
            raise
        output_dict["loss"] = rl_loss
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict, for every passage sentence, whether it is an evidence sentence.

    The passage is embedded with one wrapping dimension (one sequence per sentence) and
    ``self._span_gate`` scores each sentence. A sentence is predicted as evidence when its
    gate value is >= 0.3. When ``self._sent_labels_src == 'chain'``, the sentence labels
    are rebuilt from the first gold chain in ``evd_chain_labels``.

    Parameters
    ----------
    question, passage : Dict[str, torch.LongTensor]
        Text-field tensors; the passage carries one extra wrapping dimension.
    span_start, span_end, sentence_spans, q_type : torch.IntTensor, optional
        Accepted for interface compatibility; not used by this method.
    sent_labels : torch.IntTensor, optional
        Per-sentence labels; -1 entries are padding and are ignored by the loss.
        Without labels, no loss and no sentence F1 are computed.
    evd_chain_labels : torch.IntTensor, optional
        Gold chains; only used when ``_sent_labels_src == 'chain'``.
    metadata : List[Dict[str, Any]], optional
        Per-instance metadata copied into the output dictionary.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``pred_sent_labels`` (0/1, [B, num_spans]), ``gate_probs`` ([B, num_spans]),
        optional attention scores, metadata-derived fields and ``loss`` when labels exist.
    """
    if (self._sent_labels_src == 'chain'
            and sent_labels is not None and evd_chain_labels is not None):
        batch_size, num_spans = sent_labels.size()
        sent_labels_mask = (sent_labels >= 0).float()
        # we use the chain as the label to supervise the gate.
        # Only the first chain in ``evd_chain_labels`` is used for supervision.
        evd_chain_labels = evd_chain_labels[:, 0].long()
        # build the gate labels. The dim is 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans)
        sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
        sent_labels.scatter_(1, evd_chain_labels, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans)
        sent_labels = sent_labels[:, 1:].float()
        # make the padding be -1
        sent_labels = sent_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)

    # shape: [batch_size, num_sent, max_sent_len+q_len, embedding_dim]
    embedded_passage = self._text_field_embedder(passage, num_wrapping_dims=1)
    context_mask = util.get_text_field_mask(passage, num_wrapping_dims=1).float()

    # gate prediction
    # Shape(gate_logit): (batch_size * num_spans, 2)
    # Shape(gate): (batch_size * num_spans, 1)
    # Shape(pred_sent_probs): (batch_size * num_spans, 2)
    # Shape(gate_mask): (batch_size, num_spans)
    gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
        embedded_passage, context_mask,
        self._gate_self_attention_layer, self._gate_sent_encoder)
    batch_size, num_spans, max_batch_span_width = context_mask.size()

    # Two-class NLL over the gate; padded sentences (label -1) are ignored.
    loss = None
    if sent_labels is not None:
        loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

    gate = (gate >= 0.3).long().view(batch_size, num_spans)
    output_dict = {
        "pred_sent_labels": gate,  # [B, num_span]
        "gate_probs": pred_sent_probs[:, 1].view(batch_size, num_spans),  # [B, num_span]
    }
    if self._output_att_scores and g_att_score is not None:
        output_dict['evd_self_attention_score'] = g_att_score

    # Compute the loss for training.
    if loss is not None:
        try:
            self._loss_trackers['loss'](loss)
        except RuntimeError:
            # Log the offending batch, then propagate instead of silently dropping the loss.
            print('\n meta_data:', metadata)
            raise
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict['answer_texts'] = []
        question_tokens = []
        passage_tokens = []
        sent_labels_list = []
        evd_possible_chains = []
        ans_sent_idxs = []
        ids = []
        for i in range(batch_size):
            meta = metadata[i]
            question_tokens.append(meta['question_tokens'])
            passage_tokens.append(meta['passage_sent_tokens'])
            sent_labels_list.append(meta['sent_labels'])
            ids.append(meta['_id'])
            output_dict['answer_texts'].append(meta.get('answer_texts', []))
            # shift sentence indices back by one (entries <= 0 are dropped from chains)
            evd_possible_chains.append(
                [s_idx - 1 for s_idx in meta['evd_possible_chains'][0] if s_idx > 0])
            ans_sent_idxs.append([s_idx - 1 for s_idx in meta['ans_sent_idxs']])
            if len(meta['ans_sent_idxs']) > 0:
                pred_sent_gate = gate[i].detach().cpu().numpy()
                hit = any(pred_sent_gate[s_idx - 1] > 0 for s_idx in meta['ans_sent_idxs'])
                self.evd_ans_metric(1 if hit else 0)
        if sent_labels is not None:
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1), gate_mask.view(-1))
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_sent_tokens'] = passage_tokens
        output_dict['sent_labels'] = sent_labels_list
        output_dict['evd_possible_chains'] = evd_possible_chains
        output_dict['ans_sent_idxs'] = ans_sent_idxs
        output_dict['_id'] = ids
    return output_dict
def forward(
        self,  # type: ignore
        para_id: int,
        participant_strings: List[str],
        paragraph: Dict[str, torch.LongTensor],
        sentences: Dict[str, torch.LongTensor],
        paragraph_sentence_indicators: torch.IntTensor,
        participants: Dict[str, torch.LongTensor],
        participant_indicators: torch.IntTensor,
        paragraph_participant_indicators: torch.IntTensor,
        verbs: torch.IntTensor,
        paragraph_verbs: torch.IntTensor,
        actions: torch.IntTensor = None,
        before_locations: torch.IntTensor = None,
        after_locations: torch.IntTensor = None,
        filename: List[str] = None,
        score: List[float] = 1.0  # instance_score
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    para_id: int
        The id of the paragraph
    participant_strings: List[str]
        The participants in the paragraph
    paragraph: Dict[str, torch.LongTensor]
        The token indices for the paragraph (not used by this implementation)
    sentences: Dict[str, torch.LongTensor]
        The token indices batched by sentence.
    paragraph_sentence_indicators: torch.LongTensor
        Indicates before / inside / after for each sentence (not used by this implementation)
    participants: Dict[str, torch.LongTensor]
        The token indices for the participant names
    participant_indicators: torch.IntTensor
        Indicates each participant in each sentence, shape
        (batch_size, num_participants, num_sentences, sentence_length)
    paragraph_participant_indicators: torch.IntTensor
        Indicates each participant in the paragraph (not used by this implementation)
    verbs: torch.IntTensor
        Indicates the positions of verbs in the sentences
    paragraph_verbs: torch.IntTensor
        Indicates the positions of verbs in the paragraph (not used by this implementation)
    actions: torch.IntTensor, optional (default = None)
        Indicates the actions taken per participant per sentence; -1 marks padding.
    before_locations: torch.IntTensor, optional (default = None)
        Indicates the span for the before location per participant per sentence
    after_locations: torch.IntTensor, optional (default = None)
        Indicates the span for the after location per participant per sentence
    filename: List[str], optional (default = None, treated as an empty list)
        The files from which the instances were read
    score: List[float], optional (default = 1.0)
        The score for each instance

    Returns
    -------
    An output dictionary consisting of:
    action_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_sentences, num_participants, num_action_types)``
        representing a distribution of state change types per sentence, participant in
        each datapoint (paragraph). Only present when the decoder trainer is not used.
    action_probs_decode, action_logits : torch.FloatTensor
        Flattened probabilities and raw logits for the action types.
    location_span_after : list
        Predicted after-location spans (non-decoder-trainer path only).
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    self.filename = filename if filename is not None else []
    self.instance_score = score

    # original shape (batch_size, num_participants, num_sentences, sentence_length)
    participant_indicators = participant_indicators.transpose(1, 2)
    # new shape (batch_size, num_sentences, num_participants, sentence_length)
    batch_size, num_sentences, num_participants, sentence_length = participant_indicators.size()

    # (batch_size, num_sentences, sentence_length, embedding_size)
    embedded_sentences = self.text_field_embedder(sentences)
    # (batch_size, num_participants, description_length, embedding_size)
    embedded_participants = self.text_field_embedder(participants)

    batch_size, num_sentences, sentence_length, embedding_size = embedded_sentences.size()
    self.num_sentences = num_sentences

    # ==========================================================================================
    # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
    # (batch_size, num_sentences, num_participants, sentence_length, embedding_size)
    embedded_sentence_participant_pairs = embedded_sentences.unsqueeze(2).expand(
        batch_size, num_sentences, num_participants, sentence_length, embedding_size)

    # (batch_size, num_sentences, sentence_length)
    # -> (batch_size, num_sentences, num_participants, sentence_length)
    mask = (get_text_field_mask(sentences, num_wrapping_dims=1)
            .unsqueeze(2)
            .expand(batch_size, num_sentences, num_participants, sentence_length)
            .float())

    # (batch_size, num_participants, num_sentences * sentence_length)
    participant_view = participant_indicators.transpose(1, 2).view(
        batch_size, num_participants, num_sentences * sentence_length)

    # (batch_size, num_participants): does the participant occur anywhere in the paragraph?
    has_participant = participant_view.sum(dim=2) > 0

    # used to mask out invalid sentence, participant pairs
    # (batch_size, num_sentences, num_participants, sentence_length)
    sent_participant_pair_mask = (has_participant
                                  .unsqueeze(-1).expand(batch_size, num_participants, num_sentences)
                                  .unsqueeze(-1).expand(batch_size, num_participants,
                                                        num_sentences, sentence_length)
                                  .transpose(1, 2)
                                  .float())

    # whether the sentence is masked or not (sent does not exist in paragraph).
    # this is either (batch_size, num_sentences, num_participants)
    # or if only one participant (batch_size, num_sentences)
    sentence_mask = (mask.sum(3) > 0).squeeze(-1).float()

    # (batch_size, num_sentences, num_participants, sentence_length)
    mask = mask * sent_participant_pair_mask

    # (batch_size, num_participants) -> (batch_size, num_participants, num_sentences)
    # -> (batch_size, num_sentences, num_participants)
    participant_mask = (has_participant
                        .unsqueeze(-1).expand(batch_size, num_participants, num_sentences)
                        .transpose(1, 2)
                        .float())

    # Example: 0.0 where action is -1 (padded)
    # action:      [[[1, 0, 1], [3, 2, 3]], [[0, -1, -1], [-1, -1, -1]]]
    # action_mask: [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
    # (batch_size, num_sentences, num_participants)
    action_mask = participant_mask * sentence_mask

    # (batch_size, num_sentences, num_participants, sentence_length)
    verb_indicators = verbs.unsqueeze(2).expand(
        batch_size, num_sentences, num_participants, sentence_length).float()

    # ==========================================================================================
    # Layer 2: Concatenate sentence embedding with verb and participant indicator bits
    # result: (batch_size, num_sentences, num_participants, sentence_length, embedding_size + 2)
    embedded_sentence_verb_entity = torch.cat(
        [embedded_sentence_participant_pairs,
         verb_indicators.unsqueeze(-1).float(),
         participant_indicators.unsqueeze(-1).float()],
        dim=-1)

    # ==========================================================================================
    # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence
    if self.use_attention:
        # batch_size * num_sentences * num_participants * sentence_length * (2*seq2seq_output_size)
        contextual_seq_embedding = self.time_distributed_seq2seq_encoder(
            embedded_sentence_verb_entity, mask)

        # Layer 3.5: Attention (Contextual embedding, BOW(verb span))
        verb_weight_matrix = verb_indicators.float() / (
            verb_indicators.float().sum(-1).unsqueeze(-1) + 1e-13)
        # (batch_size, num_sentences, num_participants, embedding_size)
        verb_vector = weighted_sum(
            contextual_seq_embedding * verb_indicators.float().unsqueeze(-1),
            verb_weight_matrix)

        # (batch_size, num_sentences, num_participants, sentence_length)
        participant_weight_matrix = participant_indicators.float() / (
            participant_indicators.float().sum(-1).unsqueeze(-1) + 1e-13)
        # (batch_size, num_sentences, num_participants, embedding_size)
        participant_vector = weighted_sum(
            contextual_seq_embedding * participant_indicators.float().unsqueeze(-1),
            participant_weight_matrix)

        # (batch_size, num_sentences, num_participants, 2 * embedding_size)
        verb_participant_vector = torch.cat([verb_vector, participant_vector], -1)

        # attention weights for type prediction
        # (batch_size, num_sentences, num_participants)
        attention_weights_actions = self.time_distributed_attention_layer(
            verb_participant_vector, contextual_seq_embedding, mask)
        contextual_vec_embedding = weighted_sum(contextual_seq_embedding,
                                                attention_weights_actions)
    else:
        # NOTE(review): ``contextual_seq_embedding`` is not defined on this path, but the
        # non-decoder-trainer branch below passes it to ``compute_location_spans``.
        contextual_vec_embedding = self.time_distributed_seq2vec_encoder(
            embedded_sentence_verb_entity, mask)

    # (batch_size, num_participants, num_sentences) -> (batch_size, num_sentences, num_participants)
    if actions is not None:
        actions = actions.transpose(1, 2)

    # ==========================================================================================
    # Layer 4 = Aggregate FeedForward to choose an action label per sentence, participant pair
    # (batch_size, num_sentences, num_participants, num_actions)
    action_logits = self.aggregate_feedforward(contextual_vec_embedding)
    action_probs = torch.nn.functional.softmax(action_logits, dim=-1)

    # (batch_size * num_sentences * num_participants, num_actions)
    action_probs_decode = action_probs.view(
        (batch_size * num_sentences * num_participants), self.num_actions)

    output_dict = {}
    if self.use_decoder_trainer:
        # (batch_size, num_participants, description_length, embedding_size)
        participants_list = embedded_participants.data.cpu().numpy()
        output_dict.update(
            DecoderTrainerHelper.pass_on_info_to_decoder_trainer(
                selfie=self,
                para_id_list=para_id,
                actions=actions,
                target_mask=action_mask,
                participants_list=participants_list,
                participant_strings=participant_strings,
                participant_indicators=participant_indicators.transpose(1, 2),
                logit_tensor=action_logits))

        # Compute type_accuracy based on best_final_states and actions
        best_decoded_state = output_dict['best_final_states'][0][0][0]
        best_decoded_action_seq = []
        if best_decoded_state.action_history:
            for cur_step_action in best_decoded_state.action_history[0]:
                best_decoded_action_seq.append(list(cur_step_action))
            best_decoded_tensor = torch.LongTensor(best_decoded_action_seq).unsqueeze(0)

            if actions is not None:
                flattened_gold = actions.long().contiguous().view(-1)
                self._type_accuracy(best_decoded_tensor.long().contiguous().view(-1),
                                    flattened_gold)
        output_dict['best_decoded_action_seq'] = [best_decoded_action_seq]
    else:
        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict["action_probs"] = action_probs
        output_dict["action_probs_decode"] = action_probs_decode

        if actions is not None:
            # (batch_size * num_sentences * num_participants, num_actions)
            flattened_predictions = action_logits.view(
                (batch_size * num_sentences * num_participants), self.num_actions)
            # gold action index per (sentence, participant); .contiguous() is needed because
            # ``actions`` was transposed above and ``view`` requires contiguous memory.
            # (batch_size * num_sentences * num_participants)
            flattened_gold = actions.long().contiguous().view(-1)
            action_loss = self._loss(flattened_predictions, flattened_gold)
            flattened_probs = action_probs.view(
                (batch_size * num_sentences * num_participants), self.num_actions)
            evaluation_mask = (flattened_gold != -1)
            self._type_accuracy(flattened_probs, flattened_gold, mask=evaluation_mask)
            output_dict["loss"] = action_loss

        best_span_after, span_start_logits_after, span_end_logits_after = \
            self.compute_location_spans(contextual_seq_embedding=contextual_seq_embedding,
                                        embedded_sentence_verb_entity=embedded_sentence_verb_entity,
                                        mask=mask)
        output_dict["location_span_after"] = [best_span_after]

        # NOTE(review): ``'test' not in self.filename`` is list membership (exact match of
        # an element), not a substring test on file paths — confirm this is intended.
        not_in_test = (self.training or 'test' not in self.filename)

        if not_in_test and (before_locations is not None and after_locations is not None):
            after_locations = after_locations.transpose(1, 2)
            # ``n_parts`` (not ``np``) so numpy's conventional alias is never shadowed.
            (bs, ns, n_parts, sl) = span_start_logits_after.size()
            location_mask = (after_locations[:, :, :, 0] >= 0).float().unsqueeze(-1).expand(
                bs, ns, n_parts, sl)

            start_after_log_predicted = util.masked_log_softmax(span_start_logits_after,
                                                                location_mask)
            start_after_log_predicted_transpose = start_after_log_predicted.transpose(
                2, 3).transpose(1, 2)
            start_after_gold = torch.clamp(after_locations[:, :, :, [0]].squeeze(-1), min=-1)
            location_loss = nll_loss(input=start_after_log_predicted_transpose,
                                     target=start_after_gold, ignore_index=-1)

            end_after_log_predicted = util.masked_log_softmax(span_end_logits_after,
                                                              location_mask)
            end_after_log_predicted_transpose = end_after_log_predicted.transpose(
                2, 3).transpose(1, 2)
            end_after_gold = torch.clamp(after_locations[:, :, :, [1]].squeeze(-1), min=-1)
            location_loss += nll_loss(input=end_after_log_predicted_transpose,
                                      target=end_after_gold, ignore_index=-1)

            # No action loss exists when ``actions`` is None; start from 0 in that case.
            output_dict["loss"] = output_dict.get("loss", 0.0) + location_loss

    output_dict['action_probs_decode'] = action_probs_decode
    output_dict['action_logits'] = action_logits
    return output_dict
def forward(
    self,  # type: ignore
    text: TextFieldTensors,
    spans: torch.IntTensor,
    span_labels: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    text : `TextFieldTensors`, required.
        The output of a `TextField` representing the text of
        the document.
    spans : `torch.IntTensor`, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
        indices into the text of the document. Padding spans have a start index of -1.
    span_labels : `torch.IntTensor`, optional (default = None).
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : `List[Dict[str, Any]]`, optional (default = None).
        A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters"
        keys from this dictionary, which respectively have the original text and the annotated gold
        coreference clusters for that instance.

    # Returns

    An output dictionary consisting of:

    top_spans : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, max_antecedents)` representing for
        each top span the index (with respect to top_spans) of the possible antecedents the
        model considered.
    predicted_antecedents : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Present only when `span_labels` is given.
    document : list, optional
        The "original_text" entry of each metadata dictionary, when `metadata` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = spans.size(0)
    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text)

    # Shape: (batch_size, num_spans)
    # ``spans[:, :, 0]`` is already 2-D. An extra ``.squeeze(-1)`` would drop the span
    # dimension whenever num_spans == 1, and the mask would then no longer match
    # ``span_mention_scores`` in ``masked_topk``.
    span_mask = spans[:, :, 0] >= 0
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, num_spans)

    # Shape: (batch_size, num_spans)
    span_mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings)
    ).squeeze(-1)
    # Shape: (batch_size, num_spans) for all 3 tensors
    top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, num_spans_to_keep
    )

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices
    )

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    if self._coarse_to_fine:
        pruned_antecedents = self._coarse_to_fine_pruning(
            top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
        )
    else:
        pruned_antecedents = self._distance_pruning(
            top_span_embeddings, top_span_mention_scores, max_antecedents
        )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
    (
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
        top_antecedent_indices,
    ) = pruned_antecedents

    flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
        top_antecedent_indices, num_spans_to_keep
    )

    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    top_antecedent_embeddings = util.batched_index_select(
        top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
    )
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        top_span_embeddings,
        top_antecedent_embeddings,
        top_partial_coreference_scores,
        top_antecedent_mask,
        top_antecedent_offsets,
    )

    # Higher-order inference: refine span embeddings by attending over their
    # antecedents (plus the span itself in the dummy slot), then rescore.
    for _ in range(self._inference_order - 1):
        dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
        top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        attention_weight = util.masked_softmax(
            coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
        top_antecedent_with_dummy_embeddings = torch.cat(
            [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        attended_embeddings = util.weighted_sum(
            top_antecedent_with_dummy_embeddings, attention_weight
        )
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = self._span_updating_gated_sum(
            top_span_embeddings, attended_embeddings
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": top_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep, 1)
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        antecedent_labels = util.batched_index_select(
            pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
        ).squeeze(-1)
        antecedent_labels = util.replace_masked_values(
            antecedent_labels, top_antecedent_mask, -100
        )

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels
        )
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(
            coreference_scores, top_span_mask.unsqueeze(-1)
        )
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(
            top_spans, top_antecedent_indices, predicted_antecedents, metadata
        )

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(self, char_paragraph: torch.Tensor, entity_mask: torch.IntTensor,
            verb_mask: torch.IntTensor, loc_mask: torch.IntTensor,
            gold_loc_seq: torch.IntTensor, gold_state_seq: torch.IntTensor,
            num_cands: torch.IntTensor):
    """
    Encode a batch of paragraphs, then run state-change tagging and location prediction.

    Args:
        char_paragraph: paragraph input fed to ``self.EmbeddingLayer``; dim 1 (and dim -2)
            is the token axis of size max_tokens.
            NOTE(review): trailing dim presumably holds per-token characters — confirm.
        entity_mask: mask whose last two dims are (max_sents, max_tokens).
        verb_mask: mask whose last two dims are (max_sents, max_tokens).
        loc_mask: mask whose last three dims are (max_cands, max_sents, max_tokens).
        gold_loc_seq: gold location indices, size (batch, max_sents).
        gold_state_seq: gold state tags, size (batch, max_sents); entries equal to
            PAD_STATE are padding and are excluded from the state loss/decoding.
        num_cands: number of location candidates per instance, size (batch,).

    Returns:
        If ``self.is_test`` is truthy (inference):
            (pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred,
             correct_loc_pred, total_loc_pred)
        Otherwise (training/validation):
            (state_loss, loc_loss, correct_state_pred, total_state_pred,
             correct_loc_pred, total_loc_pred)
    """
    # Every per-sentence input must agree on max_sents ...
    assert (entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(-2)
            == gold_state_seq.size(-1) == gold_loc_seq.size(-1))
    # ... and every per-token mask must agree on max_tokens.
    assert (entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(-1)
            == char_paragraph.size(-2))

    batch_size = char_paragraph.size(0)
    max_tokens = char_paragraph.size(1)
    max_sents = gold_state_seq.size(-1)
    max_cands = loc_mask.size(-3)

    # ---- token encoding ----
    embeddings = self.EmbeddingLayer(char_paragraph, verb_mask)  # (batch, max_tokens, embed_size)
    token_rep, _ = self.TokenEncoder(embeddings)  # (batch, max_tokens, 2 * hidden_size)
    token_rep = self.Dropout(token_rep)
    assert token_rep.size() == (batch_size, max_tokens, 2 * self.hidden_size)

    # ---- state change prediction ----
    # size (batch, max_sents, NUM_STATES)
    tag_logits = self.StateTracker(encoder_out=token_rep,
                                   entity_mask=entity_mask,
                                   verb_mask=verb_mask)
    # Mask the padded sentences so they don't count in the loss or decoding.
    tag_mask = (gold_state_seq != PAD_STATE)
    log_likelihood = self.CRFLayer(emissions=tag_logits,
                                   tags=gold_state_seq.long(),
                                   mask=tag_mask,
                                   reduction='token_mean')
    # State classification loss is the negative log-likelihood.
    state_loss = -log_likelihood
    pred_state_seq = self.CRFLayer.decode(emissions=tag_logits, mask=tag_mask)
    assert len(pred_state_seq) == batch_size
    correct_state_pred, total_state_pred = compute_state_accuracy(
        pred=pred_state_seq, gold=gold_state_seq.tolist(), pad_value=PAD_STATE)

    # ---- location prediction ----
    # size (batch, max_cands, max_sents)
    loc_logits = self.LocationPredictor(encoder_out=token_rep,
                                        entity_mask=entity_mask,
                                        loc_mask=loc_mask)
    loc_logits = loc_logits.transpose(-1, -2)  # (batch, max_sents, max_cands)
    masked_loc_logits = self.mask_loc_logits(
        loc_logits=loc_logits, num_cands=num_cands)  # (batch, max_sents, max_cands)
    masked_gold_loc_seq = self.mask_undefined_loc(
        gold_loc_seq=gold_loc_seq, mask_value=PAD_LOC)  # (batch, max_sents)
    # transpose() above returns a view that may be non-contiguous. reshape(), unlike
    # view(), is correct whether or not the masking helpers return contiguous tensors,
    # and it only copies when a copy is actually required.
    loc_loss = self.CrossEntropy(
        input=masked_loc_logits.reshape(batch_size * max_sents, max_cands),
        target=masked_gold_loc_seq.reshape(batch_size * max_sents).long())
    correct_loc_pred, total_loc_pred = compute_loc_accuracy(
        logits=masked_loc_logits, gold=masked_gold_loc_seq, pad_value=PAD_LOC)

    if self.is_test:  # inference: return predictions instead of losses
        pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits,
                                    gold_loc_seq=gold_loc_seq)
        return (pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred,
                correct_loc_pred, total_loc_pred)

    return (state_loss, loc_loss, correct_state_pred, total_state_pred,
            correct_loc_pred, total_loc_pred)