def _compute_antecedent_gold_labels(relation_labels: torch.IntTensor,
                                    coref_labels: torch.IntTensor):
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    source_labels = relation_labels.unsqueeze(1)
    target_labels = relation_labels.unsqueeze(2)
    relation_indicator = (target_labels * source_labels).sum(-1).clamp(0, 1).float()

    source_labels = coref_labels.unsqueeze(1)
    target_labels = coref_labels.unsqueeze(2)
    coref_indicator = (target_labels * source_labels).sum(-1).clamp(0, 1).float()

    # relation_indicator is binary, so this equals
    # relation_indicator * (1 - coref_indicator): pairs that are related
    # but not coreferent.
    label = relation_indicator * (relation_indicator - coref_indicator)
    assert (label < 0).sum() == 0, "gold antecedent labels must be non-negative"
    return label
def mask_loc_logits(self, loc_logits, num_cands: torch.IntTensor):
    """
    Mask the padded candidates with a -inf score, so they will have
    likelihood 0 after softmax.

    Args:
        loc_logits - output scores for each candidate in each sentence,
                     size (batch, max_sents, max_cands)
        num_cands - total number of candidates in each instance of the
                    given batch, size (batch,)
    """
    assert torch.max(num_cands) == loc_logits.size(-1)
    assert loc_logits.size(0) == num_cands.size(0)

    batch_size = loc_logits.size(0)
    max_cands = loc_logits.size(-1)

    # First, create a mask tensor that masks all positions above the
    # num_cands limit.
    range_tensor = torch.arange(start=1, end=max_cands + 1)
    if self.use_cuda:
        range_tensor = range_tensor.cuda()
    range_tensor = range_tensor.unsqueeze(dim=0).expand(batch_size, max_cands)
    # Find the off-limit positions.
    bool_range = torch.gt(range_tensor, num_cands.unsqueeze(dim=-1))
    assert bool_range.size() == (batch_size, max_cands)
    bool_range = bool_range.unsqueeze(dim=-2).expand_as(loc_logits)

    # Use this bool tensor to mask loc_logits: padded positions become -inf.
    masked_loc_logits = loc_logits.masked_fill(bool_range, value=float('-inf'))
    assert masked_loc_logits.size() == loc_logits.size()
    return masked_loc_logits
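# --- Usage sketch (hypothetical, standalone): the same masking logic as
# mask_loc_logits above on CPU tensors, to show that padded candidates end
# up with probability exactly 0 after softmax. All names here are local to
# the demo, not part of the original class.
import torch

loc_logits = torch.randn(2, 3, 4)     # (batch=2, max_sents=3, max_cands=4)
num_cands = torch.tensor([4, 2])      # instance 0 has 4 candidates, instance 1 has 2

range_tensor = torch.arange(1, 5).unsqueeze(0).expand(2, 4)
bool_range = torch.gt(range_tensor, num_cands.unsqueeze(-1))
masked = loc_logits.masked_fill(
    bool_range.unsqueeze(-2).expand_as(loc_logits), float('-inf'))

probs = masked.softmax(dim=-1)
assert torch.all(probs[1, :, 2:] == 0)  # padded candidates of instance 1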
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
            new_xyz: torch.Tensor, fps_idx: torch.IntTensor) -> torch.Tensor:
    r"""
    Parameters
    ----------
    radius : float
        radius of the balls
    nsample : int
        maximum number of features in the balls
    xyz : torch.Tensor
        (B, N, 3) xyz coordinates of the features
    new_xyz : torch.Tensor
        (B, npoint, 3) centers of the ball query

    Returns
    -------
    torch.Tensor
        (B, npoint, nsample + 1) tensor with the indices of the features
        that form the query balls (the first column is fps_idx itself)
    """
    assert new_xyz.is_contiguous()
    assert xyz.is_contiguous()

    B, N, _ = xyz.size()
    npoint = new_xyz.size(1)
    idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
    pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample,
                                 new_xyz, xyz, fps_idx, idx)
    return torch.cat([fps_idx.unsqueeze(2), idx], dim=2)
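# --- Usage sketch: calling the ball query through a torch.autograd.Function.
# This assumes the forward() above lives in a Function subclass (called
# BallQuery here for illustration) and that the compiled `pointnet2` CUDA
# extension is available; both names are assumptions, not confirmed by the
# source.
import torch

B, N, npoint, nsample = 2, 1024, 128, 32
xyz = torch.randn(B, N, 3, device="cuda").contiguous()
new_xyz = xyz[:, :npoint, :].contiguous()  # e.g. FPS-selected centers
fps_idx = (torch.arange(npoint, device="cuda", dtype=torch.int32)
           .unsqueeze(0).expand(B, -1).contiguous())

grouped_idx = BallQuery.apply(0.2, nsample, xyz, new_xyz, fps_idx)
print(grouped_idx.shape)  # (B, npoint, nsample + 1): center index + neighbors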
def compute_antecedent_gold_labels(top_span_labels: torch.IntTensor,
                                   antecedent_labels: torch.IntTensor):
    """
    Generates a binary indicator for every pair of spans. This label is one
    if and only if the pair of spans belong to the same cluster. The labels
    are augmented with a dummy antecedent at the zeroth position, which
    represents the prediction that a span does not have any antecedent.

    Parameters
    ----------
    top_span_labels : ``torch.IntTensor``, required.
        The cluster id label for every span. The id is arbitrary,
        as we just care about the clustering. Has shape
        (batch_size, num_spans_to_keep).
    antecedent_labels : ``torch.IntTensor``, required.
        The cluster id label for every antecedent span. The id is arbitrary,
        as we just care about the clustering. Has shape
        (batch_size, num_spans_to_keep, max_antecedents).

    Returns
    -------
    pairwise_labels_with_dummy_label : ``torch.FloatTensor``
        A binary tensor representing whether a given pair of spans belong to
        the same cluster in the gold clustering.
        Has shape (batch_size, num_spans_to_keep, max_antecedents + 1).
    """
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    top_span_labels = top_span_labels.unsqueeze(0)
    antecedent_labels = antecedent_labels.unsqueeze(0)
    target_labels = top_span_labels.expand_as(antecedent_labels)
    same_cluster_indicator = (target_labels == antecedent_labels).float()
    non_dummy_indicator = (target_labels >= 0).float()
    pairwise_labels = same_cluster_indicator * non_dummy_indicator

    # Shape: (batch_size, num_spans_to_keep, 1)
    dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
    pairwise_labels_with_dummy_label = torch.cat([dummy_labels, pairwise_labels], -1)
    return pairwise_labels_with_dummy_label.squeeze(0)
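# --- Tiny worked example (unbatched inputs, as the unsqueeze(0)/squeeze(0)
# pair above suggests): three kept spans, two candidate antecedents each.
# -1 marks padded antecedent slots.
import torch

top_span_labels = torch.tensor([[0], [0], [1]])   # (num_spans_to_keep, 1)
antecedent_labels = torch.tensor([[-1, -1],       # span 0: only padding
                                  [0, -1],        # span 1: span 0 precedes it
                                  [0, 0]])        # span 2: spans 0 and 1

labels = compute_antecedent_gold_labels(top_span_labels, antecedent_labels)
# Spans 0 and 2 have no gold antecedent -> dummy column (index 0) is 1.
# Span 1 is coreferent with span 0 -> a 1 in the corresponding slot.
print(labels)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [1., 0., 0.]])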
def forward(self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    TODO(dwadden) Write documentation.
    """
    # Shape: (Batch size, Number of Spans, Span Embedding Size)
    # span_embeddings
    self._active_namespace = f"{metadata.dataset}__ner_labels"
    if self._active_namespace not in self._ner_scorers:
        return {"loss": 0}
    scorer = self._ner_scorers[self._active_namespace]
    ner_scores = scorer(span_embeddings)

    # Give large negative scores to masked-out elements.
    mask = span_mask.unsqueeze(-1)
    ner_scores = util.replace_masked_values(ner_scores, mask.bool(), -1e20)

    # The dummy_scores are the score for the null label.
    dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
    dummy_scores = ner_scores.new_zeros(*dummy_dims)
    ner_scores = torch.cat((dummy_scores, ner_scores), -1)
    _, predicted_ner = ner_scores.max(2)

    predictions = self.predict(ner_scores.detach().cpu(),
                               spans.detach().cpu(),
                               span_mask.detach().cpu(),
                               metadata)
    output_dict = {"predictions": predictions}

    if ner_labels is not None:
        metrics = self._ner_metrics[self._active_namespace]
        metrics(predicted_ner, ner_labels, span_mask)
        ner_scores_flat = ner_scores.view(-1, self._n_labels[self._active_namespace])
        ner_labels_flat = ner_labels.view(-1)
        mask_flat = span_mask.view(-1).bool()

        loss = self._loss(ner_scores_flat[mask_flat], ner_labels_flat[mask_flat])
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            text: Dict[str, Any],
            text_mask: torch.IntTensor,
            token_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            token_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    TODO(dwadden) Write documentation.
    """
    seq_scores = self._seq_scorer(token_embeddings)

    # Give large negative scores to masked-out elements.
    mask = text_mask.unsqueeze(-1)
    seq_scores = util.replace_masked_values(seq_scores, mask, -1e20)
    seq_scores[:, :, 0] *= text_mask
    _, predicted_seq = seq_scores.max(2)

    if self._label_scheme == 'flat':
        pred_spans = self._seq_metrics._decode_flat(predicted_seq, text_mask)
    elif self._label_scheme == 'stacked':
        pred_spans = self._seq_metrics._decode_stacked(predicted_seq, text_mask)
    else:
        raise RuntimeError("invalid label_scheme {}".format(self._label_scheme))

    output_dict = {
        "predicted_seq": predicted_seq,
        "predicted_seq_span": pred_spans
    }

    if token_labels is not None:
        self._seq_metrics(predicted_seq, token_labels, text_mask, self.training)
        seq_scores_flat = seq_scores.view(-1, self._n_labels)
        seq_labels_flat = token_labels.view(-1)
        mask_flat = text_mask.view(-1).bool()

        loss = self._loss(seq_scores_flat[mask_flat], seq_labels_flat[mask_flat])
        output_dict["loss"] = loss

    return output_dict
def forward(self, input_word_index: torch.IntTensor,
            h_state: torch.FloatTensor, c_state: torch.FloatTensor,
            enc_outputs: torch.FloatTensor, mask: torch.BoolTensor):
    """
    Pass inputs through the model.

    Args:
        input_word_index: torch.IntTensor[batch_size,]
        h_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
        c_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
        enc_outputs: torch.FloatTensor[seq_len, batch_size, hidden_size]
        mask: torch.BoolTensor[seq_len, batch_size, 1]

    Returns:
        logit: torch.FloatTensor[batch_size, vocab_size]
        h_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
        c_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
        attention_weights: torch.FloatTensor[seq_len, batch_size]
    """
    embedded = self.embedding(input_word_index.unsqueeze(0))
    embedded = F.dropout(embedded, p=self.embedding_dropout)
    output, (h_state, c_state) = self.lstm(embedded, (h_state, c_state))
    # output: [seq_len=1, batch_size, hidden_size]
    # h_state: [n_layers, batch_size, hidden_size]
    # c_state: [n_layers, batch_size, hidden_size]

    # Compute attention weights
    attention_weights = self.attention_layer(
        h_state=h_state, enc_outputs=enc_outputs, mask=mask)
    # attention_weights: [seq_len, batch_size, 1]

    # Compute the context vector
    context_vector = torch.bmm(
        enc_outputs.permute(1, 2, 0),        # [batch_size, hidden_size, seq_len]
        attention_weights.permute(1, 0, 2),  # [batch_size, seq_len, 1]
    ).permute(2, 0, 1)                       # [1, batch_size, hidden_size]

    # New input: concatenate context_vector with the hidden states
    new_input = torch.cat((context_vector, output), dim=2)
    # [1, batch_size, hidden_size * 2]

    # Get logit
    x = self.fc1(new_input.squeeze(0))  # [batch_size, hidden_size]
    x = F.leaky_relu(x)
    x = F.dropout(x, p=self.dropout)
    logit = self.fc2(x)  # [batch_size, vocab_size]

    return logit, (h_state, c_state, attention_weights.squeeze(2))
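# --- Hypothetical single decoding step for the attention decoder above.
# `decoder` stands in for an instance of the surrounding module and SOS_IDX
# for the start-of-sequence id; both are illustration-only names. Shapes
# follow the docstring (encoder outputs are seq-len major).
import torch

batch_size, seq_len, hidden_size, n_layers, SOS_IDX = 4, 10, 256, 2, 1

input_word_index = torch.full((batch_size,), SOS_IDX, dtype=torch.long)
h = torch.zeros(n_layers, batch_size, hidden_size)
c = torch.zeros(n_layers, batch_size, hidden_size)
enc_outputs = torch.randn(seq_len, batch_size, hidden_size)
mask = torch.ones(seq_len, batch_size, 1, dtype=torch.bool)

logit, (h, c, attn) = decoder(input_word_index, h, c, enc_outputs, mask)
next_word = logit.argmax(dim=-1)  # greedy choice, fed into the next step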
def forward(self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    TODO(dwadden) Write documentation.
    """
    # Shape: (Batch size, Number of Spans, Span Embedding Size)
    # span_embeddings
    ner_scores = self._ner_scorer(span_embeddings)

    # Give large negative scores to masked-out elements.
    mask = span_mask.unsqueeze(-1)
    ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)

    # The dummy_scores are the score for the null label.
    dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
    dummy_scores = ner_scores.new_zeros(*dummy_dims)
    ner_scores = torch.cat((dummy_scores, ner_scores), -1)
    _, predicted_ner = ner_scores.max(2)

    output_dict = {
        "spans": spans,
        "span_mask": span_mask,
        "ner_scores": ner_scores,
        "predicted_ner": predicted_ner
    }

    if ner_labels is not None:
        self._ner_metrics(predicted_ner, ner_labels, span_mask)
        ner_scores_flat = ner_scores.view(-1, self._n_labels)
        ner_labels_flat = ner_labels.view(-1)
        mask_flat = span_mask.view(-1).bool()

        loss = self._loss(ner_scores_flat[mask_flat], ner_labels_flat[mask_flat])
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict["document"] = [x["sentence"] for x in metadata]

    return output_dict
def forward(self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    TODO(dwadden) Write documentation.
    """
    # Shape: (Batch size, Number of Spans, Span Embedding Size)
    # span_embeddings
    span_scores = self._span_scorer(span_embeddings)

    # Give large negative scores to masked-out elements.
    mask = span_mask.unsqueeze(-1)
    span_scores = util.replace_masked_values(span_scores, mask, -1e20)
    span_scores[:, :, 0] *= span_mask
    _, predicted_span = span_scores.max(2)

    output_dict = {
        "spans": spans,
        "span_mask": span_mask,
        "span_scores": span_scores,
        "predicted_span": predicted_span
    }

    if span_labels is not None:
        self._span_metrics(predicted_span, span_labels, span_mask)
        span_scores_flat = span_scores.view(-1, self._n_labels)
        span_labels_flat = span_labels.view(-1)
        mask_flat = span_mask.view(-1).bool()

        loss = self._loss(span_scores_flat[mask_flat], span_labels_flat[mask_flat])
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict["document"] = [x["sentence"] for x in metadata]

    return output_dict
def forward(self, input_word_index: torch.IntTensor,
            h_state: torch.FloatTensor, c_state: torch.FloatTensor):
    """
    Pass inputs through the model.

    Args:
        input_word_index: torch.IntTensor[batch_size,]
        h_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
        c_state: torch.FloatTensor[n_layer, batch_size, hidden_size]

    Returns:
        logit: torch.FloatTensor[batch_size, vocab_size]
        h_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
        c_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
    """
    embedded = self.embedding(input_word_index.unsqueeze(0))
    embedded = F.dropout(embedded, p=self.embedding_dropout)
    output, (h_state, c_state) = self.lstm(embedded, (h_state, c_state))
    logit = self.fc(output.squeeze(0))
    return logit, (h_state, c_state)
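# --- A minimal greedy decoding loop over the decoder above. This is a
# sketch: `decoder`, SOS_IDX and EOS_IDX are hypothetical names for
# illustration, and batch_size is fixed to 1.
import torch

def greedy_decode(decoder, h, c, SOS_IDX, EOS_IDX, max_len=50):
    token = torch.tensor([SOS_IDX])  # batch_size = 1
    result = []
    for _ in range(max_len):
        logit, (h, c) = decoder(token, h, c)
        token = logit.argmax(dim=-1)  # pick the most likely next word
        if token.item() == EOS_IDX:
            break
        result.append(token.item())
    return result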
def forward(ctx, e1: float, e2: float, e3: float, nsample: int,
            xyz: torch.Tensor, new_xyz: torch.Tensor,
            fps_idx: torch.IntTensor) -> torch.Tensor:
    r"""
    Parameters
    ----------
    e1, e2, e3 : float
        the three semi-axes of the ellipsoid
    nsample : int
        maximum number of features in the balls
    xyz : torch.Tensor
        (B, N, 3) xyz coordinates of the features
    new_xyz : torch.Tensor
        (B, npoint, 3) centers of the ball query

    Returns
    -------
    torch.Tensor
        (B, npoint, nsample + 1) tensor with the indices of the features
        that form the query balls (the first column is fps_idx itself)
    torch.Tensor
        (B, npoint, 3) per-ball 3-vector ``d`` computed by the CUDA kernel
    """
    assert new_xyz.is_contiguous()
    assert xyz.is_contiguous()

    B, N, _ = xyz.size()
    npoint = new_xyz.size(1)

    # Output buffers filled in by the CUDA kernel.
    idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
    ingroup_pts_cnt = torch.cuda.IntTensor(B, npoint).zero_()
    ingroup_out = torch.cuda.FloatTensor(B, npoint, nsample, 3).zero_()
    ingroup_cva = torch.cuda.FloatTensor(B, npoint, 3 * 3).zero_()
    v = torch.cuda.FloatTensor(B, npoint, 3 * 3).zero_()
    d = torch.cuda.FloatTensor(B, npoint, 3).zero_()

    pointnet2.ellipsoid_query_wrapper(
        B, N, npoint, e1, e2, e3, nsample,
        new_xyz, xyz, fps_idx, idx,
        ingroup_pts_cnt, ingroup_out, ingroup_cva, v, d
    )
    return torch.cat([fps_idx.unsqueeze(2), idx], dim=2), d
def batched_gather(x: torch.Tensor, indices: torch.IntTensor, dim: int):
    """
    Similar to the gather method of :class:`torch.Tensor`.

    Args:
        x: the tensor to select.
        indices: the indices to choose.
        dim: the dimension to choose.

    Returns:
        A selected tensor
    """
    if indices.dim() == 1:
        return x[indices]
    elif indices.dim() == 2:
        if x.dim() > indices.dim():
            indices = indices.unsqueeze(-1).repeat_interleave(x.shape[-1], dim=-1)
        return x.gather(dim, indices)
    raise NotImplementedError(
        "Currently do not support more batch dimensions than 1!")
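# --- Example: gather per-row elements from a (batch, seq, feat) tensor.
# A small standalone demo of batched_gather above; the shapes and values
# are illustrative only.
import torch

x = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)  # (batch=2, seq=3, feat=4)
idx = torch.tensor([[0, 2], [1, 1]])                      # (batch=2, k=2)

# Selects rows 0 and 2 of the first batch element and row 1 twice of the
# second; the trailing feature dimension is broadcast automatically.
out = batched_gather(x, idx, dim=1)
print(out.shape)  # torch.Size([2, 2, 4])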
def rnnt_loss(log_probs: torch.FloatTensor,
              labels: torch.IntTensor,
              frames_lengths: torch.IntTensor,
              labels_lengths: torch.IntTensor,
              average_frames: bool = False,
              reduction: Optional[AnyStr] = None,
              blank: int = 0,
              gather: bool = False) -> torch.Tensor:
    """The CUDA-Warp RNN-Transducer loss.

    Args:
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels and V is
            the vocabulary of labels (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) representing
            the number of frames for each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) representing
            the length of the transcription for each sample in the minibatch.
        average_frames (bool, optional): Specifies whether the loss of each
            sample should be divided by its number of frames.
            Default: False.
        reduction (string, optional): Specifies the type of reduction.
            Default: None.
        blank (int, optional): label used to represent the blank symbol.
            Default: 0.
        gather (bool, optional): Reduce memory consumption.
            Default: False.
    """
    assert average_frames is None or isinstance(average_frames, bool)
    assert reduction is None or reduction in ("none", "mean", "sum")
    assert isinstance(blank, int)
    assert isinstance(gather, bool)

    assert not labels.requires_grad, "labels must not require gradients"
    assert not frames_lengths.requires_grad, "frames_lengths must not require gradients"
    assert not labels_lengths.requires_grad, "labels_lengths must not require gradients"

    if gather:
        N, T, U, V = log_probs.size()
        index = torch.full([N, T, U, 2], blank,
                           device=labels.device, dtype=torch.long)
        index[:, :, :U - 1, 1] = labels.unsqueeze(dim=1)
        log_probs = log_probs.gather(dim=3, index=index)
        blank = -1

    costs = RNNTLoss.apply(log_probs, labels,
                           frames_lengths, labels_lengths, blank)

    if average_frames:
        costs = costs / frames_lengths.to(log_probs)

    if reduction == "sum":
        return costs.sum()
    elif reduction == "mean":
        return costs.mean()
    return costs
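# --- Usage sketch for rnnt_loss on random inputs. Shapes follow the
# docstring; log_probs must be log-softmax normalized over V, and the
# tensors are placed on a GPU since the docstring names a CUDA-Warp backend.
import torch

N, T, U, V = 2, 20, 5, 30
logits = torch.randn(N, T, U, V, device="cuda", requires_grad=True)
log_probs = logits.log_softmax(dim=-1)
labels = torch.randint(1, V, (N, U - 1), dtype=torch.int32, device="cuda")
frames_lengths = torch.full((N,), T, dtype=torch.int32, device="cuda")
labels_lengths = torch.full((N,), U - 1, dtype=torch.int32, device="cuda")

loss = rnnt_loss(log_probs, labels, frames_lengths, labels_lengths,
                 average_frames=False, reduction="mean", blank=0)
loss.backward()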
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            metadata: List[Dict[str, Any]],
            doc_span_offsets: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            doc_truth_spans: torch.IntTensor = None,
            doc_spans_in_truth: torch.IntTensor = None,
            doc_relation_labels: torch.Tensor = None,
            truth_spans: List[Set[Tuple[int, int]]] = None,
            doc_relations=None,
            doc_ner_labels: torch.IntTensor = None,
            ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the
        inclusive start and end indices of candidate spans for mentions.
        Comes from a ``ListField[SpanField]`` of indices into the text of
        the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster
        ids of each span, or -1 for those which do not appear in any
        clusters.
    metadata : ``List[Dict[str, Any]]``, required.
        Metadata for each document in the batch, including the tokenized
        text (``doc_tokens``) and the original text (``original_text``).
    doc_ner_labels : ``torch.IntTensor``.
        A tensor of shape # TODO, ...
    doc_span_offsets : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1), ...
    doc_truth_spans : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, 1), ...
    doc_spans_in_truth : ``torch.IntTensor``.
        A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1), ...
    doc_relation_labels : ``torch.Tensor``.
        A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans), ...

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the
        pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)``
        representing for each top span the index (with respect to top_spans)
        of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing,
        for each top span, the index (with respect to antecedent_indices) of
        the most likely antecedent. -1 means there was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = len(spans)
    document_length = text_embeddings.size(1)
    max_sentence_length = max(len(sentence)
                              for document in metadata
                              for sentence in document['doc_tokens'])
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do some
    # comparisons based on span widths when we attend over the span
    # representations that we generate from these indices, we need them to
    # be <= 0. This is only relevant in edge cases where the number of spans
    # we consider after the pruning stage is >= the total number of spans,
    # because in this case, it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)

    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # TODO features dropout
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_relex_spans_to_keep = int(math.floor(self._relex_spans_per_word *
                                             max_sentence_length))

    # Shapes:
    # (batch_size, num_spans_to_keep, span_dim),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)

    # Shape: (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here we need to select
    # spans for each element in the batch. This reformats the indices to
    # take into account their index into the batch. We precompute this here
    # to make the multiple calls to util.batched_index_select below more
    # efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need
    # to compare span pairs to decide each span's antecedent. Each span can
    # only have prior spans as antecedents, and we only consider up to
    # max_antecedents prior spans. So the first thing we do is construct a
    # matrix mapping a span's index to the indices of its allowed
    # antecedents. Note that this is independent of the batch dimension -
    # it's just a function of the span's position in top_spans. The spans
    # are in document order, so we can just use the relative index of the
    # spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get
    # embeddings for all valid antecedents for each span. This gives us
    # variables with shapes like
    # (batch_size, num_spans_to_keep, max_antecedents, embedding_size),
    # which we can use to make coreference decisions between valid span
    # pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                         util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage, a
    # predicted antecedent. This implies a clustering if we group mentions
    # which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class, so
    # this makes the indices line up with actual spans if the prediction is
    # greater than -1.
    predicted_antecedents -= 1

    output_dict = dict()
    output_dict["top_spans"] = top_spans
    output_dict["antecedent_indices"] = valid_antecedent_indices
    output_dict["predicted_antecedents"] = predicted_antecedents

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]

    # Shape: (,)
    loss = 0

    # Shape: (batch_size, max_sentences, max_spans)
    doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
    # Shape: (batch_size, max_sentences, num_spans, span_dim)
    doc_span_embeddings = util.batched_index_select(
        span_embeddings,
        doc_span_offsets.squeeze(-1).long().clamp(min=0))

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep),
    # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    pruned = self._relex_mention_pruner(
        doc_span_embeddings, doc_span_mask,
        num_items_to_keep=num_relex_spans_to_keep,
        pass_through=['num_items_to_keep'])
    (top_relex_span_embeddings, top_relex_span_mask,
     top_relex_span_indices, top_relex_span_mention_scores) = pruned

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
    top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

    # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)
    # TODO do we need a mask here?
    doc_spans = util.batched_index_select(
        spans, doc_span_offsets.clamp(0).squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
    top_relex_spans = nd_batched_index_select(doc_spans, top_relex_span_indices)

    # Shapes:
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
    # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
    (relex_span_pair_embeddings,
     relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
         top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

    # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
    relex_scores = self._compute_relex_scores(relex_span_pair_embeddings,
                                              top_relex_span_mention_scores)
    output_dict['relex_scores'] = relex_scores
    output_dict['top_relex_spans'] = top_relex_spans

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels_ = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all
        # antecedent predictions that would be consistent with the data, in
        # the sense that we are minimising, for a given span, the negative
        # marginal log likelihood of all antecedents which are in the same
        # gold cluster as the span we are currently considering. Each span i
        # predicts a single antecedent j, but there might be several prior
        # mentions k in the same coreference cluster that would be valid
        # antecedents. Our loss is the sum of the probability assigned to
        # all valid antecedents. This is a valid objective for clustering as
        # we don't mind which antecedent is predicted, so long as they are
        # in the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores,
                                                        top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + \
            gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs)
        negative_marginal_log_likelihood *= top_span_mask.squeeze(-1).float()
        negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        coref_loss = negative_marginal_log_likelihood
        output_dict['coref_loss'] = coref_loss
        loss += self._loss_coref_weight * coref_loss

    if doc_relations is not None:
        # The adjacency matrix for relation extraction is very sparse. It is
        # not just sparse but row/column sparse: only a few rows and columns
        # are non-zero, and those rows/columns are themselves dense. We
        # implemented our own matrix representation for this case. Here we
        # have indices of truth spans and a mapping, with which we map the
        # prediction matrix onto the truth matrix.
        # TODO Add teacher forcing support.

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep)
        relative_indices = top_relex_span_indices
        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        compressed_indices = nd_batched_padded_index_select(
            doc_spans_in_truth, relative_indices)

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
        gold_pruned_rows = nd_batched_padded_index_select(
            doc_relation_labels,
            compressed_indices.squeeze(-1),
            padding_value=0)
        gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous()

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
        gold_pruned_matrices = nd_batched_padded_index_select(
            gold_pruned_rows,
            compressed_indices.squeeze(-1),
            padding_value=0)  # pad with epsilon
        gold_pruned_matrices = gold_pruned_matrices.permute(0, 1, 3, 2).contiguous()

        # TODO log_mask relex score before passing
        relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                  gold_pruned_matrices,
                                                  relex_span_pair_mask)
        output_dict['relex_loss'] = relex_loss

        self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                   truth_spans)
        self._compute_relex_metrics(output_dict, doc_relations)

        loss += self._loss_relex_weight * relex_loss

    if doc_ner_labels is not None:
        # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
        ner_scores = self._ner_scorer(doc_span_embeddings)
        output_dict['ner_scores'] = ner_scores

        ner_loss = nd_cross_entropy_with_logits(ner_scores,
                                                doc_ner_labels,
                                                doc_span_mask)
        output_dict['ner_loss'] = ner_loss
        loss += self._loss_ner_weight * ner_loss

    if not isinstance(loss, int):  # i.e. the loss was updated above
        output_dict["loss"] = loss

    return output_dict
def compute_representations(
        self,  # type: ignore
        span_embeddings,  # (1, Ns, E)
        coref_labels: torch.IntTensor,  # (1, Ns, C)
        type_to_cluster_ids: Dict[str, List[int]],
        relation_to_cluster_ids: Dict[int, List[int]] = None,
        metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    if coref_labels.sum() == 0:
        return {"loss": 0.0, "metadata": metadata}

    cluster_type_embeddings = self.map_cluster_to_type_embeddings(
        type_to_cluster_ids)  # (1, C, E)

    sum_embeddings = (span_embeddings.unsqueeze(2) *
                      coref_labels.float().unsqueeze(-1)).sum(1)
    length_embeddings = (coref_labels.unsqueeze(-1).sum(1) + 1e-5)
    cluster_span_embeddings = sum_embeddings / length_embeddings

    paragraph_cluster_mask = (coref_labels.sum(1) > 0).float().unsqueeze(-1)  # (P, C, 1)
    paragraph_cluster_embeddings = (
        cluster_span_embeddings * paragraph_cluster_mask +
        cluster_type_embeddings * (1 - paragraph_cluster_mask))  # (P, C, E)

    assert (paragraph_cluster_embeddings.shape[1] == coref_labels.shape[2]
            and paragraph_cluster_embeddings.shape[2] == span_embeddings.shape[-1])

    paragraph_cluster_embeddings = torch.cat(
        [
            paragraph_cluster_embeddings,
            self._bias_vectors.expand(paragraph_cluster_embeddings.shape[0], -1, -1)
        ],
        dim=1,
    )  # (P, C+4, E)

    n_true_clusters = coref_labels.shape[-1]
    candidate_relations, candidate_relations_labels, candidate_relations_types = \
        self.generate_product(
            type_to_clusters_map=type_to_cluster_ids,
            relation_to_clusters_map=relation_to_cluster_ids,
            n_true_clusters=n_true_clusters,
        )

    candidate_relations_tensor = torch.LongTensor(candidate_relations).to(
        span_embeddings.device)  # (R, 4)
    candidate_relations_labels_tensor = torch.LongTensor(
        candidate_relations_labels).to(span_embeddings.device)  # (R, )

    if len(candidate_relations) == 0:
        return {"loss": 0.0, "metadata": metadata}

    all_relation_embeddings = util.batched_index_select(
        paragraph_cluster_embeddings,
        candidate_relations_tensor.unsqueeze(0).expand(
            paragraph_cluster_embeddings.shape[0], -1, -1),
    )  # (P, R', n, E)

    relation_scores, relation_logits = self.get_relation_scores(
        all_relation_embeddings)  # (1, R')

    output_dict = {}
    output_dict["relations_candidates_list"] = candidate_relations
    output_dict["relation_labels"] = candidate_relations_labels
    output_dict["relation_types"] = candidate_relations_types
    output_dict["doc_id"] = metadata[0]["doc_id"]
    output_dict["metadata"] = metadata
    output_dict["relation_scores"] = relation_scores
    output_dict["relation_logits"] = relation_logits

    if relation_to_cluster_ids is not None:
        output_dict = self.predict_labels(relation_scores, relation_logits,
                                          candidate_relations_labels_tensor,
                                          output_dict)

    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            target_word: torch.IntTensor,
            gold_label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens:
    target_word: (batch_size, 2)
    gold_label: (batch_size)
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        metadata containing the original words in the sentence to be tagged
        under a 'words' key.

    Returns
    -------
    An output dictionary consisting of:
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, sentence_length, embedding_size)
    tokens_embeddings = self._lexical_dropout(self._text_field_embedder(tokens))

    # Shape: (batch_size, sentence_length)
    tokens_mask = util.get_text_field_mask(tokens).float()

    # Shape: (batch_size, sentence_length, encoding_dim)
    contextualized_embeddings = self._context_layer(tokens_embeddings, tokens_mask)

    # Shape: (batch_size, 2 * encoding_dim)
    target_word_embeddings = self._target_word_extractor(
        contextualized_embeddings, target_word)

    # Shape: (batch_size, 1)
    complex_word_logits = self._complex_word_scorer(target_word_embeddings)
    complex_word_predictions = complex_word_logits > 0.5

    output_dict = {
        "logits": complex_word_logits,
        "predictions": complex_word_predictions
    }

    if gold_label is not None:
        output_dict["loss"] = self._loss(complex_word_logits,
                                         gold_label.unsqueeze(-1).float())
        macro_F1 = metrics.f1_score(gold_label, complex_word_predictions,
                                    average='macro')
        self._metric(complex_word_predictions, gold_label)

    return output_dict
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do some
    # comparisons based on span widths when we attend over the span
    # representations that we generate from these indices, we need them to
    # be <= 0. This is only relevant in edge cases where the number of spans
    # we consider after the pruning stage is >= the total number of spans,
    # because in this case, it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)

    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)
    # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
    num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

    # Shapes:
    # (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep),
    # (batch_size, num_spans_to_keep, 1)
    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    # (batch_size, num_spans_to_keep, 1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here we need to select
    # spans for each element in the batch. This reformats the indices to
    # take into account their index into the batch. We precompute this here
    # to make the multiple calls to util.batched_index_select below more
    # efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Shape: (batch_size, num_spans_to_keep, class_num + 1)
    ne_scores = self._compute_named_entity_scores(top_span_embeddings)

    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_named_entities = ne_scores.max(2)

    output_dict = {
        "top_spans": top_spans,
        "predicted_named_entities": predicted_named_entities
    }

    if labels is not None:
        # Find the gold labels for the spans which we kept.
        # Shape: (batch_size, num_spans_to_keep, 1)
        pruned_gold_labels = util.batched_index_select(
            labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)

        negative_log_likelihood = F.cross_entropy(
            ne_scores.reshape(-1, self.class_num),
            pruned_gold_labels.reshape(-1))
        pruner_loss = F.binary_cross_entropy_with_logits(
            top_span_mention_scores.reshape(-1),
            (pruned_gold_labels.reshape(-1) != 0).float())
        loss = negative_log_likelihood + pruner_loss
        output_dict["loss"] = loss
        output_dict["pruner_loss"] = pruner_loss

        batch_size, _ = labels.shape
        all_scores = ne_scores.new_zeros([batch_size * num_spans, self.class_num])
        all_scores[:, 0] = 1
        all_scores[flat_top_span_indices] = ne_scores.reshape(-1, self.class_num)
        all_scores = all_scores.reshape([batch_size, num_spans, self.class_num])
        self._metric_all(all_scores, labels)
        self._metric_avg(all_scores, labels)

    return output_dict
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the
        inclusive start and end indices of candidate spans for mentions.
        Comes from a ``ListField[SpanField]`` of indices into the text of
        the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster
        ids of each span, or -1 for those which do not appear in any
        clusters.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the
        pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)``
        representing for each top span the index (with respect to top_spans)
        of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing,
        for each top span, the index (with respect to antecedent_indices) of
        the most likely antecedent. -1 means there was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do some
    # comparisons based on span widths when we attend over the span
    # representations that we generate from these indices, we need them to
    # be <= 0. This is only relevant in edge cases where the number of spans
    # we consider after the pruning stage is >= the total number of spans,
    # because in this case, it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(
        text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat(
        [endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(
         span_embeddings, span_mask, num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here we need to select
    # spans for each element in the batch. This reformats the indices to
    # take into account their index into the batch. We precompute this here
    # to make the multiple calls to util.batched_index_select below more
    # efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need
    # to compare span pairs to decide each span's antecedent. Each span can
    # only have prior spans as antecedents, and we only consider up to
    # max_antecedents prior spans. So the first thing we do is construct a
    # matrix mapping a span's index to the indices of its allowed
    # antecedents. Note that this is independent of the batch dimension -
    # it's just a function of the span's position in top_spans. The spans
    # are in document order, so we can just use the relative index of the
    # spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get
    # embeddings for all valid antecedents for each span. This gives us
    # variables with shapes like
    # (batch_size, num_spans_to_keep, max_antecedents, embedding_size),
    # which we can use to make coreference decisions between valid span
    # pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                         util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage, a
    # predicted antecedent. This implies a clustering if we group mentions
    # which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class, so
    # this makes the indices line up with actual spans if the prediction is
    # greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}

    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all
        # antecedent predictions that would be consistent with the data, in
        # the sense that we are minimising, for a given span, the negative
        # marginal log likelihood of all antecedents which are in the same
        # gold cluster as the span we are currently considering. Each span i
        # predicts a single antecedent j, but there might be several prior
        # mentions k in the same coreference cluster that would be valid
        # antecedents. Our loss is the sum of the probability assigned to
        # all valid antecedents. This is a valid objective for clustering as
        # we don't mind which antecedent is predicted, so long as they are
        # in the same coreference cluster.
        coreference_log_probs = util.last_dim_log_softmax(coreference_scores,
                                                          top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + \
            gold_antecedent_labels.log()
        negative_marginal_log_likelihood = \
            -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    return output_dict
def forward( self, # type: ignore para_id: int, participant_strings: List[str], paragraph: Dict[str, torch.LongTensor], sentences: Dict[str, torch.LongTensor], paragraph_sentence_indicators: torch.IntTensor, participants: Dict[str, torch.LongTensor], participant_indicators: torch.IntTensor, paragraph_participant_indicators: torch.IntTensor, verbs: torch.IntTensor, paragraph_verbs: torch.IntTensor, actions: torch.IntTensor = None, before_locations: torch.IntTensor = None, after_locations: torch.IntTensor = None, filename: List[str] = [], score: List[float] = 1.0 # instance_score ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- para_id: int The id of the paragraph participant_strings: List[str] The participants in the paragraph paragraph: Dict[str, torch.LongTensor] The token indices for the paragraph sentences: Dict[str, torch.LongTensor] The token indices batched by sentence. paragraph_sentence_indicators: torch.LongTensor Indicates before / inside / after for each sentence participants: Dict[str, torch.LongTensor] The token indices for the participant names participant_indicators: torch.IntTensor Indicates each participant in each sentence paragraph_participant_indicators: torch.IntTensor Indicates each participant in the paragraph verbs: torch.IntTensor Indicates the positions of verbs in the sentences paragraph_verbs: torch.IntTensor Indicates the positions of verbs in the paragraph actions: torch.IntTensor, optional (default = None) Indicates the actions taken per participant per sentence. before_locations: torch.IntTensor, optional (default = None) Indicates the span for the before location per participant per sentence after_locations: torch.IntTensor, optional (default = None) Indicates the span for the after location per participant per sentence filename: List[str], optional (default = '') The files from which the instances were read score: List[float], optional (default = 1.0) The score for each instance Returns ------- An output dictionary consisting of: action_probs : torch.FloatTensor A tensor of shape ``(batch_size, num_sentences, num_participants, num_action_types)`` representing a distribution of state change types per sentence, participant in each datapoint (paragraph). loss : torch.FloatTensor, optional A scalar loss to be optimised. 
""" self.filename = filename self.instance_score = score # original shape (batch_size, num_participants, num_sentences, sentence_length) participant_indicators = participant_indicators.transpose(1, 2) # new shape (batch_size, num_sentences, num_participants, sentence_length) batch_size, num_sentences, num_participants, sentence_length = participant_indicators.size( ) # (batch_size, num_sentences, sentence_length, embedding_size) embedded_sentences = self.text_field_embedder(sentences) # (batch_size, num_participants, description_length, embedding_size) embedded_participants = self.text_field_embedder(participants) batch_size, num_sentences, sentence_length, embedding_size = embedded_sentences.size( ) self.num_sentences = num_sentences # =========================================================================================================== # Layer 1: For each sentence, participant pair: create a Glove embedding for each token # (batch_size, num_sentences, num_participants, sentence_length, embedding_size) embedded_sentence_participant_pairs = embedded_sentences.unsqueeze(2).expand(batch_size, num_sentences, \ num_participants, sentence_length, embedding_size) # (batch_size, num_sentences, sentence_length) -> (batch_size, num_sentences, num_participants, sentence_length) mask = get_text_field_mask(sentences, num_wrapping_dims=1). \ unsqueeze(2).expand(batch_size, num_sentences, num_participants, sentence_length).float() # (batch_size, num_participants, num_sentences * sentence_length) participant_view = participant_indicators.transpose(1, 2). \ view(batch_size, num_participants, num_sentences * sentence_length) # participant_mask is used to mask out invalid sentence, participant pairs # (batch_size, num_sentences, num_participants, sentence_length) sent_participant_pair_mask = (participant_view.sum(dim=2) > 0). \ unsqueeze(-1).expand(batch_size, num_participants, num_sentences). \ unsqueeze(-1).expand(batch_size, num_participants, num_sentences, sentence_length). \ transpose(1, 2).float() # whether the sentence is masked or not (sent does not exist in paragraph). # this is either (batch_size, num_sentences, num_participants) # or if only one participant (batch_size, num_sentences) # TODO(joelgrus) why is there a squeeze here sentence_mask = (mask.sum(3) > 0).squeeze(-1).float() # (batch_size, num_sentences, num_participants, sentence_length) mask = mask * sent_participant_pair_mask # (batch_size, num_participants, num_sentences * sentence_length) # -> (batch_size, num_participants) # -> (batch_size, num_participants, num_sentences) # -> (batch_size, num_sentences, num_participants) participant_mask = (participant_view.sum(dim=2) > 0). \ unsqueeze(-1).expand(batch_size, num_participants, num_sentences). 
\ transpose(1, 2).float() # Example: 0.0 where action is -1 (padded) # action: [[[1, 0, 1], [3, 2, 3]], [[0, -1, -1], [-1, -1, -1]]] # action_mask: [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]] # (batch_size, num_sentences, num_participants) action_mask = participant_mask * sentence_mask # (batch_size, num_sentences, num_participants, sentence_length) verb_indicators = verbs.unsqueeze(2).expand(batch_size, num_sentences, num_participants, sentence_length).float() # =========================================================================================================== # Layer 2: Concatenate sentence embedding with verb and participant indicator bits # espp: (batch_size, num_sentences, num_participants, sentence_length, embedding_size) # vi: (batch_size, num_sentences, num_participants, sentence_length) # pi: (batch_size, num_sentences, num_participants, sentence_length) # # result: (batch_size, num_sentences, num_participants, sentence_length, embedding_size + 2) embedded_sentence_verb_entity = \ torch.cat([embedded_sentence_participant_pairs, verb_indicators.unsqueeze(-1).float(), participant_indicators.unsqueeze(-1).float()], dim=-1) # =========================================================================================================== # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence if self.use_attention: # (batch_size, num_sentences, num_participants, sentence_length, ) # contextual_seq_embedding: batch_size * num_sentences * # num_participants * sentence_length * (2*seq2seq_output_size) contextual_seq_embedding = self.time_distributed_seq2seq_encoder( embedded_sentence_verb_entity, mask) # Layer 3.5: Attention (Contextual embedding, BOW(verb span)) verb_weight_matrix = verb_indicators.float() / ( verb_indicators.float().sum(-1).unsqueeze(-1) + 1e-13) # (batch_size, num_sentences, num_participants, embedding_size) verb_vector = weighted_sum( contextual_seq_embedding * verb_indicators.float().unsqueeze(-1), verb_weight_matrix) # (batch_size, num_sentences, num_participants, sentence_length) participant_weight_matrix = participant_indicators.float() / ( participant_indicators.float().sum(-1).unsqueeze(-1) + 1e-13) # (batch_size, num_sentences, num_participants, embedding_size) participant_vector = weighted_sum( contextual_seq_embedding * participant_indicators.float().unsqueeze(-1), participant_weight_matrix) # (batch_size, num_sentences, num_participants, 2 * embedding_size) verb_participant_vector = torch.cat( [verb_vector, participant_vector], -1) batch_size, num_sentences, num_participants, sentence_length, verb_ind_size = verb_indicators.float( ).unsqueeze(-1).size() # attention weights for type prediction # (batch_size, num_sentences, num_participants) attention_weights_actions = self.time_distributed_attention_layer( verb_participant_vector, contextual_seq_embedding, mask) contextual_vec_embedding = weighted_sum(contextual_seq_embedding, attention_weights_actions) else: # batch_size * num_sentences * num_participants * sentence_length * embedding_size contextual_vec_embedding = self.time_distributed_seq2vec_encoder( embedded_sentence_verb_entity, mask) # (batch_size, num_participants, num_sentences, 1) -> (batch_size, nnum_sentences, num_participants, 1) if actions is not None: actions = actions.transpose(1, 2) # # =========================================================================================================== # # Layer 4 = Aggregate FeedForward to choose an action label per sentence, participant pair # 
(batch_size, num_sentences, num_participants, num_actions) action_logits = self.aggregate_feedforward(contextual_vec_embedding) action_probs = torch.nn.functional.softmax(action_logits, dim=-1) # (batch_size * num_sentences * num_participants, num_actions) action_probs_decode = action_probs.view( (batch_size * num_sentences * num_participants), self.num_actions) output_dict = {} if self.use_decoder_trainer: # (batch_size, num_participants, description_length, embedding_size) participants_list = embedded_participants.data.cpu().numpy() output_dict.update( DecoderTrainerHelper.pass_on_info_to_decoder_trainer( selfie=self, para_id_list=para_id, actions=actions, target_mask=action_mask, participants_list=participants_list, participant_strings=participant_strings, participant_indicators=participant_indicators.transpose( 1, 2), logit_tensor=action_logits)) # Compute type_accuracy based on best_final_states and actions best_decoded_state = output_dict['best_final_states'][0][0][0] best_decoded_action_seq = [] if best_decoded_state.action_history: for cur_step_action in best_decoded_state.action_history[0]: step_predictions = [] for step_action in list(cur_step_action): step_predictions.append(step_action) best_decoded_action_seq.append(step_predictions) best_decoded_tensor = torch.LongTensor( best_decoded_action_seq).unsqueeze(0) if actions is not None: flattened_gold = actions.long().contiguous().view(-1) self._type_accuracy( best_decoded_tensor.long().contiguous().view(-1), flattened_gold) output_dict['best_decoded_action_seq'] = [best_decoded_action_seq] else: # Create output dictionary for the trainer # Compute loss and epoch metrics output_dict["action_probs"] = action_probs output_dict["action_probs_decode"] = action_probs_decode action_loss = 0.0 location_loss = 0.0 if actions is not None: # (batch_size * num_sentences * num_participants, num_actions) flattened_predictions = action_logits.view( (batch_size * num_sentences * num_participants), self.num_actions) # Flattened_gold: contains the gold action index (Action enum in propara_dataset_reader) # Note: tensor is not a single block of memory, but a block with holes. # view can be only used with contiguous tensors, so if you need to use it here, just call .contiguous() before. 
            # (batch_size * num_sentences * num_participants)
            flattened_gold = actions.long().contiguous().view(-1)
            action_loss = self._loss(flattened_predictions, flattened_gold)
            flattened_probs = action_probs.view(
                (batch_size * num_sentences * num_participants), self.num_actions)
            evaluation_mask = (flattened_gold != -1)
            self._type_accuracy(flattened_probs, flattened_gold, mask=evaluation_mask)
            output_dict["loss"] = action_loss

        best_span_after, span_start_logits_after, span_end_logits_after = \
            self.compute_location_spans(
                contextual_seq_embedding=contextual_seq_embedding,
                embedded_sentence_verb_entity=embedded_sentence_verb_entity,
                mask=mask)
        output_dict["location_span_after"] = [best_span_after]

        not_in_test = (self.training or 'test' not in self.filename)
        if not_in_test and (before_locations is not None and after_locations is not None):
            after_locations = after_locations.transpose(1, 2)

            (bs, ns, np, sl) = span_start_logits_after.size()
            location_mask = (after_locations[:, :, :, 0] >= 0).float() \
                .unsqueeze(-1).expand(bs, ns, np, sl)

            start_after_log_predicted = util.masked_log_softmax(
                span_start_logits_after, location_mask)
            start_after_log_predicted_transpose = \
                start_after_log_predicted.transpose(2, 3).transpose(1, 2)
            start_after_gold = torch.clamp(
                after_locations[:, :, :, [0]].squeeze(-1), min=-1)
            location_loss = nll_loss(input=start_after_log_predicted_transpose,
                                     target=start_after_gold,
                                     ignore_index=-1)

            end_after_log_predicted = util.masked_log_softmax(
                span_end_logits_after, location_mask)
            end_after_log_predicted_transpose = \
                end_after_log_predicted.transpose(2, 3).transpose(1, 2)
            end_after_gold = torch.clamp(
                after_locations[:, :, :, [1]].squeeze(-1), min=-1)
            location_loss += nll_loss(input=end_after_log_predicted_transpose,
                                      target=end_after_gold,
                                      ignore_index=-1)

            output_dict["loss"] += location_loss

    output_dict['action_probs_decode'] = action_probs_decode
    output_dict['action_logits'] = action_logits
    return output_dict
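# The masked-softmax-plus-ignore_index pattern above is easy to get wrong, so here is
# a minimal, self-contained sketch of the same recipe on toy tensors. It is not part
# of the model: `masked_log_softmax_sketch` inlines the usual mask trick rather than
# calling allennlp.nn.util, and the 2-D shapes stand in for the model's 4-D ones.
import torch
import torch.nn.functional as F


def masked_log_softmax_sketch(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Adding log(mask + eps) sends masked positions to ~-inf before normalising.
    return F.log_softmax(logits + (mask + 1e-13).log(), dim=-1)


def location_loss_sketch() -> torch.Tensor:
    span_start_logits = torch.randn(2, 4)  # (batch, sentence_length)
    # The second example has no annotated location: mask it out and label it -1.
    location_mask = torch.tensor([[1., 1., 1., 1.], [0., 0., 0., 0.]])
    gold_starts = torch.tensor([2, -1])
    log_probs = masked_log_softmax_sketch(span_start_logits, location_mask)
    # ignore_index=-1 drops the unlabeled example from the average.
    return F.nll_loss(log_probs, gold_starts, ignore_index=-1)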
def create_attended_span_representations(
        max_span_width: int,
        head_scores: torch.FloatTensor,
        encoded_text: torch.FloatTensor,
        span_ends: torch.IntTensor,
        span_widths: torch.IntTensor) -> torch.FloatTensor:
    """
    Given a tensor of unnormalized attention scores for each word in the
    document, compute distributions over every span with respect to these
    scores by normalising the headedness scores for words inside the span.

    Given these headedness distributions over every span, weight the
    corresponding vector representations of the words in the span by this
    distribution, returning a weighted representation of each span.

    Parameters
    ----------
    max_span_width : ``int``, required.
        The maximum width of the spans under consideration.
    head_scores : ``torch.FloatTensor``, required.
        Unnormalized headedness scores for every word. This score is shared for every
        candidate. The only way in which the headedness scores differ over different
        spans is in the set of words over which they are normalized.
    encoded_text : ``torch.FloatTensor``, required.
        The embeddings with shape (batch_size, document_length, embedding_size)
        over which we are computing a weighted sum.
    span_ends : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans), representing the end indices
        of each span.
    span_widths : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the width of each
        candidate span.

    Returns
    -------
    attended_text_embeddings : ``torch.FloatTensor``
        A tensor of shape (batch_size, num_spans, embedding_dim) - the result of
        applying attention over all words within each candidate span.
    """
    # Shape: (1, 1, max_span_width)
    max_span_range_indices = util.get_range_vector(
        max_span_width, encoded_text.is_cuda).view(1, 1, -1)

    # Shape: (batch_size, num_spans, max_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    span_ends = span_ends.unsqueeze(-1)
    span_widths = span_widths.unsqueeze(-1)
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the document
    # are of a smaller width than max_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    # Replace the (masked) negative indices with zero so they are safe to gather with.
    span_indices = F.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, encoded_text.size(1))

    # Shape: (batch_size, num_spans, max_span_width, embedding_dim)
    span_text_embeddings = util.batched_index_select(
        encoded_text, span_indices, flat_span_indices)

    # Shape: (batch_size, num_spans, max_span_width)
    span_head_scores = util.batched_index_select(
        head_scores, span_indices, flat_span_indices).squeeze(-1)

    # Shape: (batch_size, num_spans, max_span_width)
    span_head_weights = util.last_dim_softmax(span_head_scores, span_mask)

    # Do a weighted sum of the embedded spans with
    # respect to the normalised head score distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_text_embeddings,
                                                 span_head_weights)
    return attended_text_embeddings
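# A toy, self-contained illustration of the indexing trick used in
# create_attended_span_representations: build a (num_spans, max_span_width) grid of
# word indices counting back from each span end, and mask slots that fall outside the
# span or before the document start. The literal values below are made up.
import torch

max_span_width = 4
span_ends = torch.tensor([[2, 5]])    # (batch=1, num_spans=2)
span_widths = torch.tensor([[2, 1]])  # width measured as (end - start), as above

range_indices = torch.arange(max_span_width).view(1, 1, -1)      # (1, 1, W)
span_mask = (range_indices <= span_widths.unsqueeze(-1)).float()
raw_indices = span_ends.unsqueeze(-1) - range_indices            # count backwards
span_mask = span_mask * (raw_indices >= 0).float()               # clip at position 0
span_indices = raw_indices.clamp(min=0)                          # safe to gather with

# span_indices: [[[2, 1, 0, 0], [5, 4, 3, 2]]]
# span_mask:    [[[1, 1, 1, 0], [1, 1, 0, 0]]]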
def forward(
        self,  # type: ignore
        spans: torch.IntTensor,
        span_mask: torch.IntTensor,
        span_embeddings: torch.IntTensor,
        sentence_lengths: torch.Tensor,
        ner_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        previous_step_output: Dict[str, Any] = None) -> Dict[str, torch.Tensor]:
    """
    TODO(dwadden) Write documentation.
    """
    # span_embeddings has shape (batch_size, num_spans, span_embedding_size)
    ner_scores = self._ner_scorer(span_embeddings)
    # Give large negative scores to masked-out elements.
    mask = span_mask.unsqueeze(-1)
    ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)
    dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
    dummy_scores = ner_scores.new_zeros(*dummy_dims)

    if previous_step_output is not None and "predicted_span" in previous_step_output \
            and not self.training:
        dummy_scores.masked_fill_(
            previous_step_output["predicted_span"].bool().unsqueeze(-1), -1e20)
        dummy_scores.masked_fill_(
            (1 - previous_step_output["predicted_span"]).bool().unsqueeze(-1), 1e20)

    ner_scores = torch.cat((dummy_scores, ner_scores), -1)

    if previous_step_output is not None and "predicted_seq_span" in previous_step_output \
            and not self.training:
        for row_idx, all_spans in enumerate(spans):
            pred_spans = previous_step_output["predicted_seq_span"][row_idx]
            pred_spans = all_spans.new_tensor(pred_spans)
            for col_idx, span in enumerate(all_spans):
                if span_mask[row_idx][col_idx] == 0:
                    continue
                found = any(span[0] == pred_span[0] and span[1] == pred_span[1]
                            for pred_span in pred_spans)
                if found:
                    # The previous step predicted this span: push the dummy score
                    # down so the NER scores decide the label.
                    ner_scores[row_idx, col_idx, 0] = -1e20
                else:
                    # The previous step did not predict this span: push the dummy
                    # score up so the span is labelled as a non-entity.
                    ner_scores[row_idx, col_idx, 0] = 1e20

    _, predicted_ner = ner_scores.max(2)

    output_dict = {
        "spans": spans,
        "span_mask": span_mask,
        "ner_scores": ner_scores,
        "predicted_ner": predicted_ner
    }

    if ner_labels is not None:
        self._ner_metrics(predicted_ner, ner_labels, span_mask)
        ner_scores_flat = ner_scores.view(-1, self._n_labels)
        ner_labels_flat = ner_labels.view(-1)
        mask_flat = span_mask.view(-1).bool()

        loss = self._loss(ner_scores_flat[mask_flat], ner_labels_flat[mask_flat])
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict["document"] = [x["sentence"] for x in metadata]

    return output_dict
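# A small standalone sketch of the masking idiom this NER head relies on: rather than
# boolean-indexing the scores, padded spans get a huge negative value so softmax or
# argmax can run over the full tensor, and a zero-score dummy column supplies the
# "no entity" class. All tensors here are toy values, not model outputs.
import torch

ner_scores = torch.randn(1, 3, 5)      # (batch, num_spans, num_labels)
span_mask = torch.tensor([[1, 1, 0]])  # the third span is padding
masked = ner_scores.masked_fill(span_mask.unsqueeze(-1) == 0, -1e20)
dummy = masked.new_zeros(1, 3, 1)      # score 0 for the "no entity" class
scores = torch.cat((dummy, masked), dim=-1)
_, predicted = scores.max(-1)          # the padded span falls back to class 0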
def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    TODO: add description

    Returns
    -------
    An output dictionary consisting of:
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    # ===========================================================================================================
    # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
    # Input: sentences
    # Output: embedded_sentences

    # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
    embedded_sentences = self.text_field_embedder(sentences)
    mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
    batch_size, num_sentences, _, _ = embedded_sentences.size()

    if self.use_sep:
        # The following code collects the vectors of the SEP tokens from all the
        # examples in the batch and arranges them in one list. It does the same
        # for the labels and confidences.
        # TODO: replace 103 with '[SEP]'
        sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
        embedded_sentences = embedded_sentences[sentences_mask]
        # given batch_size x num_sentences_per_example x sent_len x vector_len,
        # returns num_sentences_per_batch x vector_len
        assert embedded_sentences.dim() == 2
        num_sentences = embedded_sentences.shape[0]
        # For the rest of the code in this model to work, think of the data we
        # have as one example with that many sentences and a batch of size 1.
        batch_size = 1
        embedded_sentences = embedded_sentences.unsqueeze(dim=0)
        embedded_sentences = self.dropout(embedded_sentences)

        if labels is not None:
            if self.labels_are_scores:
                labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
            else:
                labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

            labels = labels[labels_mask]
            # given batch_size x num_sentences_per_example,
            # returns num_sentences_per_batch
            assert labels.dim() == 1

            if confidences is not None:
                confidences = confidences[labels_mask]
                assert confidences.dim() == 1
            if additional_features is not None:
                additional_features = additional_features[labels_mask]
                assert additional_features.dim() == 2

            num_labels = labels.shape[0]
            if num_labels != num_sentences:
                # BERT truncates long sentences, so some of the SEP tokens might be
                # gone; in that case there should be more labels than sentences.
                assert num_labels > num_sentences
                logger.warning(f'Found {num_labels} labels but {num_sentences} sentences')
                labels = labels[:num_sentences]
                # Ignore the extra labels. This is ok for training but bad for testing.
                # We are ignoring this problem for now.
                # TODO: fix, at least for testing

            # do the same for `confidences`
            if confidences is not None:
                num_confidences = confidences.shape[0]
                if num_confidences != num_sentences:
                    assert num_confidences > num_sentences
                    confidences = confidences[:num_sentences]

            # and for `additional_features`
            if additional_features is not None:
                num_additional_features = additional_features.shape[0]
                if num_additional_features != num_sentences:
                    assert num_additional_features > num_sentences
                    additional_features = additional_features[:num_sentences]

            # similar to `embedded_sentences`, add an additional dimension that
            # corresponds to batch_size=1
            labels = labels.unsqueeze(dim=0)
            if confidences is not None:
                confidences = confidences.unsqueeze(dim=0)
            if additional_features is not None:
                additional_features = additional_features.unsqueeze(dim=0)
    else:
        # ['CLS'] token
        embedded_sentences = embedded_sentences[:, :, 0, :]
        embedded_sentences = self.dropout(embedded_sentences)
        batch_size, num_sentences, _ = embedded_sentences.size()
        sent_mask = (mask.sum(dim=2) != 0)
        embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

    if additional_features is not None:
        embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1)

    label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences)
    # label_logits: batch_size, num_sentences, num_labels

    if self.labels_are_scores:
        label_probs = label_logits
    else:
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    # Create output dictionary for the trainer
    # Compute loss and epoch metrics
    output_dict = {"action_probs": label_probs}

    # =====================================================================
    if self.with_crf:
        # Layer 4 = CRF layer across labels of sentences in an abstract
        mask_sentences = (labels != -1)
        best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
        # Just get the tags and ignore the score.
        predicted_labels = [x for x, y in best_paths]

    label_loss = 0.0
    if labels is not None:
        # Compute cross entropy loss
        flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels)
        flattened_gold = labels.contiguous().view(-1)

        if not self.with_crf:
            label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
            if confidences is not None:
                label_loss = label_loss * confidences.type_as(label_loss).view(-1)
            label_loss = label_loss.mean()
            flattened_probs = torch.softmax(flattened_logits, dim=-1)
        else:
            clamped_labels = torch.clamp(labels, min=0)
            log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences)
            label_loss = -log_likelihood
            # compute categorical accuracy
            crf_label_probs = label_logits * 0.
for i, instance_labels in enumerate(predicted_labels): for j, label_id in enumerate(instance_labels): crf_label_probs[i, j, label_id] = 1 flattened_probs = crf_label_probs.view( (batch_size * num_sentences), self.num_labels) if not self.labels_are_scores: evaluation_mask = (flattened_gold != -1) self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask) self.all_f1_metrics(flattened_probs, flattened_gold, mask=evaluation_mask) # compute F1 per label for label_index in range(self.num_labels): label_name = self.vocab.get_token_from_index( namespace='labels', index=label_index) metric = self.label_f1_metrics[label_name] metric(flattened_probs, flattened_gold, mask=evaluation_mask) if labels is not None: output_dict["loss"] = label_loss output_dict['action_logits'] = label_logits return output_dict
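# The use_sep branch above hinges on one line: boolean-mask indexing with the [SEP]
# positions flattens (batch, sentences, tokens, dim) down to (num_sep_tokens, dim),
# one vector per sentence. A toy sketch, assuming token id 103 is BERT's [SEP]:
import torch

batch, sents, tokens, dim = 2, 2, 5, 3
token_ids = torch.randint(1000, 2000, (batch, sents, tokens))
token_ids[:, :, -1] = 103                  # pretend every sentence ends with [SEP]
embeddings = torch.randn(batch, sents, tokens, dim)

sep_mask = token_ids == 103                # (batch, sents, tokens)
sep_embeddings = embeddings[sep_mask]      # (batch * sents, dim)
assert sep_embeddings.shape == (batch * sents, dim)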
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be <= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
# This reformats the indices to take into account their # index into the batch. We precompute this here to make # the multiple calls to util.batched_index_select below more efficient. flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Now that we have our variables in terms of num_spans_to_keep, we need to # compare span pairs to decide each span's antecedent. Each span can only # have prior spans as antecedents, and we only consider up to max_antecedents # prior spans. So the first thing we do is construct a matrix mapping a span's # index to the indices of its allowed antecedents. Note that this is independent # of the batch dimension - it's just a function of the span's position in # top_spans. The spans are in document order, so we can just use the relative # index of the spans to know which other spans are allowed antecedents. # Once we have this matrix, we reformat our variables again to get embeddings # for all valid antecedents for each span. This gives us variables with shapes # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which # we can use to make coreference decisions between valid span pairs. # Shapes: # (num_spans_to_keep, max_antecedents), # (1, max_antecedents), # (1, num_spans_to_keep, max_antecedents) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask)) # Select tensors relating to the antecedent spans. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores(span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict = {"top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents} if span_labels is not None: # Find the gold labels for the spans which we kept. 
pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select(pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability assigned to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log() negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
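# A numeric sketch of the marginal log-likelihood loss computed above, on hand-made
# scores for a single span: log-softmax over [dummy + antecedents], keep only the
# entries whose gold indicator is 1, and logsumexp over them. Toy values throughout.
import torch

# One span, a dummy plus 3 antecedents; antecedents 1 and 3 share its gold cluster.
coreference_scores = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
gold_indicator = torch.tensor([[0.0, 1.0, 0.0, 1.0]])

log_probs = torch.log_softmax(coreference_scores, dim=-1)
correct_log_probs = log_probs + gold_indicator.log()  # -inf wherever not gold
loss = -torch.logsumexp(correct_log_probs, dim=-1).sum()
# loss is the negative log of the total probability mass on valid antecedents.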
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    if self._use_gold_mentions:
        device = text_embeddings.device
        # Stack all gold (start, end) pairs: (1, 1, num_gold_mentions, 2)
        s = [
            torch.as_tensor(pair, dtype=torch.long, device=device)
            for cluster in metadata[0]["clusters"] for pair in cluster
        ]
        gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1)

        # Keep only candidate spans whose both endpoints match a gold mention.
        span_mask = spans.unsqueeze(2) - gm
        span_mask = (span_mask[:, :, :, 0] == 0) + (span_mask[:, :, :, 1] == 0)
        span_mask, _ = (span_mask == 2).max(-1)
        num_spans = span_mask.sum().item()
        span_mask = span_mask.float()
    else:
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        num_spans = spans.size(1)

    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings( top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask, ) # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) predicted_antecedents -= 1 output_dict = { "top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents, } if span_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) coreference_log_probs = util.last_dim_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
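# The _use_gold_mentions branch above matches candidate spans against gold mention
# pairs by broadcast subtraction. A toy sketch of the same test (using `&`/`any` in
# place of the original's `+`/`== 2`/`max`, which compute the same thing):
import torch

spans = torch.tensor([[[0, 1], [2, 4], [5, 5]]])         # (1, num_spans, 2)
gold = torch.tensor([[2, 4], [5, 5]])                    # gold (start, end) pairs
diff = spans.unsqueeze(2) - gold.unsqueeze(0).unsqueeze(1)
both_match = (diff[..., 0] == 0) & (diff[..., 1] == 0)   # (1, num_spans, num_gold)
span_mask = both_match.any(-1).float()                   # tensor([[0., 1., 1.]])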
def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    text : `TextFieldTensors`, required.
        The output of a `TextField` representing the text of
        the document.
    spans : `torch.IntTensor`, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
        indices into the text of the document.
    span_labels : `torch.IntTensor`, optional (default = None).
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : `List[Dict[str, Any]]`, optional (default = None).
        A metadata dictionary for each instance in the batch. We use the "original_text"
        and "clusters" keys from this dictionary, which respectively have the original
        text and the annotated gold coreference clusters for that instance.

    # Returns

    An output dictionary consisting of:

    top_spans : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : `torch.IntTensor`
        A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    batch_size = spans.size(0)
    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text)

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be <= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
num_spans_to_keep = int(math.floor(self._spans_per_word * document_length)) num_spans_to_keep = min(num_spans_to_keep, num_spans) # Shape: (batch_size, num_spans) span_mention_scores = self._mention_scorer( self._mention_feedforward(span_embeddings) ).squeeze(-1) # Shape: (batch_size, num_spans) for all 3 tensors top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk( span_mention_scores, span_mask, num_spans_to_keep ) # Shape: (batch_size * num_spans_to_keep) # torch.index_select only accepts 1D indices, but here # we need to select spans for each element in the batch. # This reformats the indices to take into account their # index into the batch. We precompute this here to make # the multiple calls to util.batched_index_select below more efficient. flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Shape: (batch_size, num_spans_to_keep, embedding_size) top_span_embeddings = util.batched_index_select( span_embeddings, top_span_indices, flat_top_span_indices ) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Now that we have our variables in terms of num_spans_to_keep, we need to # compare span pairs to decide each span's antecedent. Each span can only # have prior spans as antecedents, and we only consider up to max_antecedents # prior spans. So the first thing we do is construct a matrix mapping a span's # index to the indices of its allowed antecedents. # Once we have this matrix, we reformat our variables again to get embeddings # for all valid antecedents for each span. This gives us variables with shapes # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which # we can use to make coreference decisions between valid span pairs. 
if self._coarse_to_fine: pruned_antecedents = self._coarse_to_fine_pruning( top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents ) else: pruned_antecedents = self._distance_pruning( top_span_embeddings, top_span_mention_scores, max_antecedents ) # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors ( top_partial_coreference_scores, top_antecedent_mask, top_antecedent_offsets, top_antecedent_indices, ) = pruned_antecedents flat_top_antecedent_indices = util.flatten_and_batch_shift_indices( top_antecedent_indices, num_spans_to_keep ) # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) top_antecedent_embeddings = util.batched_index_select( top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices ) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( top_span_embeddings, top_antecedent_embeddings, top_partial_coreference_scores, top_antecedent_mask, top_antecedent_offsets, ) for _ in range(self._inference_order - 1): dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,) top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) attention_weight = util.masked_softmax( coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True ) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size) top_antecedent_with_dummy_embeddings = torch.cat( [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2 ) # Shape: (batch_size, num_spans_to_keep, embedding_size) attended_embeddings = util.weighted_sum( top_antecedent_with_dummy_embeddings, attention_weight ) # Shape: (batch_size, num_spans_to_keep, embedding_size) top_span_embeddings = self._span_updating_gated_sum( top_span_embeddings, attended_embeddings ) # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) top_antecedent_embeddings = util.batched_index_select( top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices ) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( top_span_embeddings, top_antecedent_embeddings, top_partial_coreference_scores, top_antecedent_mask, top_antecedent_offsets, ) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict = { "top_spans": top_spans, "antecedent_indices": top_antecedent_indices, "predicted_antecedents": predicted_antecedents, } if span_labels is not None: # Find the gold labels for the spans which we kept. 
# Shape: (batch_size, num_spans_to_keep, 1) pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices ) # Shape: (batch_size, num_spans_to_keep, max_antecedents) antecedent_labels = util.batched_index_select( pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices ).squeeze(-1) antecedent_labels = util.replace_masked_values( antecedent_labels, top_antecedent_mask, -100 ) # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels ) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability assigned to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. coreference_log_probs = util.masked_log_softmax( coreference_scores, top_span_mask.unsqueeze(-1) ) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log() negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores( top_spans, top_antecedent_indices, predicted_antecedents, metadata ) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
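# A miniature sketch of one round of the higher-order refinement loop above: attend
# over each span's [self + antecedents] embeddings using the current coreference
# scores, then blend the attended vector back in with a gate. The random tensors and
# the plain sigmoid gate are stand-ins for the model's _span_updating_gated_sum.
import torch

num_spans, num_ants, dim = 3, 2, 4
top_span_emb = torch.randn(1, num_spans, dim)
antecedent_emb = torch.randn(1, num_spans, num_ants, dim)
scores = torch.randn(1, num_spans, 1 + num_ants)          # dummy + antecedents

weights = torch.softmax(scores, dim=-1)
with_dummy = torch.cat([top_span_emb.unsqueeze(2), antecedent_emb], dim=2)
attended = (weights.unsqueeze(-1) * with_dummy).sum(2)    # (1, num_spans, dim)

gate = torch.sigmoid(torch.randn(1, num_spans, dim))      # toy gate values
refined = gate * top_span_emb + (1 - gate) * attended     # next round's span reps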
def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    TODO: add description

    Returns
    -------
    An output dictionary consisting of:
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    # ===========================================================================================================
    # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
    # Input: sentences
    # Output: embedded_sentences

    # Record the raw token ids for embedding tracking.
    sentences_conv = {}
    for key, val in sentences.items():
        sentences_conv[key] = val.cpu().data.numpy().tolist()
    self.track_embedding["Transformation_0"] = {"sentences": sentences_conv}

    # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
    embedded_sentences = self.text_field_embedder(sentences)
    self.track_embedding["Transformation_1"] = {
        "size": list(embedded_sentences.size()),
        "dim": embedded_sentences.dim()
    }
    # Kacper: basically a padding mask for BERT
    mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
    batch_size, num_sentences, _, _ = list(embedded_sentences.size())

    if self.use_sep:
        # The following code collects the vectors of the SEP tokens from all the
        # examples in the batch and arranges them in one list. It does the same
        # for the labels and confidences.
        # TODO: replace 103 with '[SEP]'
        # Kacper: This is an important step where we get the SEP tokens to later
        # do sentence classification.
        # Kacper: We take the locations of the SEP tokens in the sentences as a mask
        sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
        # Kacper: We use this mask to pick the respective embeddings from the
        # output layer of BERT
        embedded_sentences = embedded_sentences[sentences_mask]
        # given batch_size x num_sentences_per_example x sent_len x vector_len,
        # returns num_sentences_per_batch x vector_len
        self.track_embedding["Transformation_2"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }
        # Kacper: boolean-mask indexing flattens the tensor, so it is now 2-D
        # (num_sep_tokens, vector_len); size() gives the shape, dim() the number of axes.
        assert embedded_sentences.dim() == 2
        num_sentences = embedded_sentences.shape[0]
        # Kacper: we can treat all sentences as one batch because we only need
        # a mean loss computed over all of them.
        # For the rest of the code in this model to work, think of the data we
        # have as one example with that many sentences and a batch of size 1.
        batch_size = 1
        embedded_sentences = embedded_sentences.unsqueeze(dim=0)
        # Kacper: we batch all sentences in one array
        self.track_embedding["Transformation_3"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }
        # Kacper: the dropout layer sits between the filtered embeddings and the linear layer
        embedded_sentences = self.dropout(embedded_sentences)
        self.track_embedding["Transformation_4"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }

        # Kacper: we provide the labels for training (one per sentence)
        if labels is not None:
            if self.labels_are_scores:
                labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
            else:
                labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

            labels = labels[labels_mask]
            # given batch_size x num_sentences_per_example,
            # returns num_sentences_per_batch
            assert labels.dim() == 1

            if confidences is not None:
                confidences = confidences[labels_mask]
                assert confidences.dim() == 1
            if additional_features is not None:
                additional_features = additional_features[labels_mask]
                assert additional_features.dim() == 2

            num_labels = labels.shape[0]
            # Kacper: this might be useful to consider in my code as well
            if num_labels != num_sentences:
                # BERT truncates long sentences, so some of the SEP tokens might be
                # gone; in that case there should be more labels than sentences.
                assert num_labels > num_sentences
                logger.warning(f'Found {num_labels} labels but {num_sentences} sentences')
                labels = labels[:num_sentences]
                # Ignore the extra labels. This is ok for training but bad for testing.
                # We are ignoring this problem for now.
                # TODO: fix, at least for testing

            # do the same for `confidences`
            if confidences is not None:
                num_confidences = confidences.shape[0]
                if num_confidences != num_sentences:
                    assert num_confidences > num_sentences
                    confidences = confidences[:num_sentences]

            # and for `additional_features`
            if additional_features is not None:
                num_additional_features = additional_features.shape[0]
                if num_additional_features != num_sentences:
                    assert num_additional_features > num_sentences
                    additional_features = additional_features[:num_sentences]

            # similar to `embedded_sentences`, add an additional dimension that
            # corresponds to batch_size=1
            labels = labels.unsqueeze(dim=0)
            if confidences is not None:
                confidences = confidences.unsqueeze(dim=0)
            if additional_features is not None:
                additional_features = additional_features.unsqueeze(dim=0)
    else:
        # ['CLS'] token
        # Kacper: this shouldn't be the case for our project
        embedded_sentences = embedded_sentences[:, :, 0, :]
        embedded_sentences = self.dropout(embedded_sentences)
        batch_size, num_sentences, _ = list(embedded_sentences.size())
        sent_mask = (mask.sum(dim=2) != 0)
        embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

    if additional_features is not None:
        embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1)

    # Kacper: we unwrap the time dimension of the tensor into the 1st (batch)
    # dimension, apply a linear layer, and wrap the time dimension back.
    # Kacper: I would suspect this happens only for embeddings related to the [SEP] tokens.
    label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences)
    # label_logits: batch_size, num_sentences, num_labels
    self.track_embedding["logits"] = {
        "size": list(label_logits.size()),
        "dim": label_logits.dim()
    }
    self.track_embedding_list.append(deepcopy(self.track_embedding))
    with open(path_json, 'w') as json_out:
        json.dump(self.track_embedding_list, json_out)

    if self.labels_are_scores:
        label_probs = label_logits
    else:
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    # Create output dictionary for the trainer
    # Compute loss and epoch metrics
    output_dict = {"action_probs": label_probs}

    # =====================================================================
    if self.with_crf:
        # Layer 4 = CRF layer across labels of sentences in an abstract
        mask_sentences = (labels != -1)
        best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
        # Just get the tags and ignore the score.
        predicted_labels = [x for x, y in best_paths]

    label_loss = 0.0
    if labels is not None:
        # Compute cross entropy loss
        # Kacper: reshape the logits into (batch_size * num_sentences, num_labels)
        flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels)
        # Make the labels contiguous in memory and flatten them to one dimension
        flattened_gold = labels.contiguous().view(-1)  # Kacper: true labels

        if not self.with_crf:
            # Kacper: We are only interested in this branch since we don't use a CRF.
            # Kacper: Compute the loss (MSE if sci_sum is True, otherwise cross-entropy)
            label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
            if confidences is not None:
                label_loss = label_loss * confidences.type_as(label_loss).view(-1)
            label_loss = label_loss.mean()  # Kacper: take the mean loss
            # Kacper: turn the logits into probabilities
            flattened_probs = torch.softmax(flattened_logits, dim=-1)
        else:
            # Kacper: this branch is not used in our project
            clamped_labels = torch.clamp(labels, min=0)
            log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences)
            label_loss = -log_likelihood
            # compute categorical accuracy
            crf_label_probs = label_logits * 0.
            for i, instance_labels in enumerate(predicted_labels):
                for j, label_id in enumerate(instance_labels):
                    crf_label_probs[i, j, label_id] = 1
            flattened_probs = crf_label_probs.view((batch_size * num_sentences), self.num_labels)

    if not self.labels_are_scores:
        # Kacper: this is our case as well, because the Pubmed labels are categorical
        evaluation_mask = (flattened_gold != -1)
        # Kacper: CategoricalAccuracy is computed in this case
        self.label_accuracy(flattened_probs.float().contiguous(),
                            flattened_gold.squeeze(-1),
                            mask=evaluation_mask)

        # compute F1 per label
        for label_index in range(self.num_labels):
            label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
            metric = self.label_f1_metrics[label_name]
            metric(flattened_probs, flattened_gold, mask=evaluation_mask)

    if labels is not None:
        output_dict["loss"] = label_loss
    output_dict['action_logits'] = label_logits

    return output_dict
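# A compact sketch of the confidence weighting used in the loss above: per-sentence
# cross-entropy with reduction='none', scaled by annotation confidence, then averaged.
# Toy logits and labels; in the model this is self.loss applied to flattened tensors.
import torch
import torch.nn as nn

logits = torch.randn(4, 3)  # (num_sentences, num_labels)
gold = torch.tensor([0, 2, 1, 2])
confidences = torch.tensor([1.0, 0.5, 1.0, 0.25])

per_sentence = nn.CrossEntropyLoss(reduction='none')(logits, gold)
weighted_loss = (per_sentence * confidences).mean()  # low-confidence labels count less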