def reorder_encoder_states(
    self, encoder_states: Tuple[torch.Tensor, ...], indices: torch.LongTensor
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    List[List[Document]],
    torch.Tensor,
]:
    """
    Reorder the encoder states according to ``indices``.

    ``encoder_states`` is unpacked into ``(enc, mask, input_turns_cnt, docs,
    doc_probs)``. ``enc``, ``mask`` and ``doc_probs`` are re-selected along
    dim 0; ``docs`` is returned untouched.

    For RAG Turn Doc-Then-Turn (``self.turn_marginalize == 'doc_then_turn'``)
    the indices are first repeated per input according to
    ``input_turns_cnt``, and ``input_turns_cnt`` itself is reordered with the
    indices as originally passed in.
    """
    enc, mask, input_turns_cnt, docs, doc_probs = encoder_states

    if self.turn_marginalize == 'doc_then_turn':
        # Remember the caller's ordering: it is still needed for the turn
        # counts after `indices` has been expanded below.
        original_indices = indices.clone()
        per_input = indices.view(input_turns_cnt.size(0), -1)
        expanded = per_input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        indices = expanded.view(-1)
        input_turns_cnt = input_turns_cnt.index_select(0, original_indices)

    n_docs = doc_probs.shape[1]

    def _reselect(stacked: torch.Tensor) -> torch.Tensor:
        # NOTE(review): presumably _unstack_ctxt exposes a per-example leading
        # dim and _stack_ctxt undoes it -- confirm against their definitions.
        return _stack_ctxt(_unstack_ctxt(stacked, n_docs).index_select(0, indices))

    enc = _reselect(enc)
    mask = _reselect(mask)
    doc_probs = doc_probs.index_select(0, indices)

    return enc, mask, input_turns_cnt, docs, doc_probs  # type: ignore
def _pick_embed(embedding: nn.Embedding, indices: LongTensor, mask: torch.ByteTensor): indices = indices.clone() indices[~mask] = 0 embed = embedding.forward(indices) embed *= mask.to(dtype=torch.float).unsqueeze(-1).expand_as(embed) return embed
def predict_scores_all_heads(
    self,
    rt_batch: torch.LongTensor,
    slice_size: Optional[int] = None,
) -> torch.FloatTensor:
    """Forward pass using left side (head) prediction for obtaining scores of all possible heads.

    This method calculates the score for all possible heads for each (relation, tail) pair.

    Additionally, the model is set to evaluation mode.

    :param rt_batch: torch.Tensor, shape: (batch_size, 2), dtype: long
        The indices of (relation, tail) pairs. The tensor is not modified.
    :param slice_size: >0
        The divisor for the scoring function when using slicing. It is only
        forwarded to the scoring function when it is not None.

    :return: torch.Tensor, shape: (batch_size, num_entities), dtype: float
        For each r-t pair, the scores for all possible heads. The sigmoid is
        applied when ``self.predict_with_sigmoid`` is set.
    """
    # Enforce evaluation mode
    self.eval()

    # Only pass `slice_size` on when the caller supplied one.
    score_kwargs = {} if slice_size is None else {"slice_size": slice_size}

    if not self.triples_factory.create_inverse_triples:
        scores = self.score_h(rt_batch, **score_kwargs)
    else:
        # In case the model was trained using inverse triples, the scoring of
        # all heads is not handled by calculating the scores for all heads
        # based on a (relation, tail) pair, but instead all possible tails are
        # calculated for a (tail, inverse_relation) pair.
        #
        # Inverse relations are addressed as "id of native relation + 1":
        # if the native relation has the index 3, the inverse relation has
        # the index 4.
        inverse_batch = rt_batch.clone()  # clone() already stays on rt_batch's device
        inverse_batch[:, 0] = inverse_batch[:, 0] + 1
        # score_t requires (entity, relation) pairs instead of
        # (relation, entity) pairs.
        scores = self.score_t(inverse_batch.flip(1), **score_kwargs)

    if self.predict_with_sigmoid:
        scores = torch.sigmoid(scores)
    return scores
def construct_trees(
    self,
    all_spans: torch.LongTensor,
    sentences: List[List[str]],
    logits,
    target_tokens,
) -> List[Tree]:
    """
    Build one span tree per batch element by running the CKY search.

    ``all_spans`` is copied and the last component of every span is
    incremented by one (exclusive end) before being handed to
    ``self._cky.find_best_span_tree``. When ``target_tokens`` for an element
    is non-empty, its tokens are joined with single spaces and passed as
    ``gold_program``; otherwise ``gold_program`` is ``None``.

    Any exception raised for an element is printed and that element's result
    becomes ``(None, None)``.
    """
    # Switch to exclusive end indices on a copy, so the caller's spans stay intact.
    spans_exclusive = all_spans.clone()
    spans_exclusive[:, :, -1] += 1

    results: List[Tree] = []
    per_element = zip(spans_exclusive, sentences, logits, target_tokens)
    for spans, sentence, logit, target_token in per_element:
        try:
            gold = ' '.join(target_token) if target_token else None
            parsed = self._cky.find_best_span_tree(sentence, logit, spans, gold_program=gold)
        except Exception as e:
            # Best effort: a failed parse must not abort the whole batch.
            print('cannot parse to tree {}'.format(e))
            parsed = (None, None)
        results.append(parsed)
    return results
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
    """
    Run the model in ``features_only`` mode on ``tokens``.

    A 1-D ``tokens`` tensor is treated as a batch of one. The decoder input
    (``prev_output_tokens``) is ``tokens`` shifted right by one position, with
    the last non-pad token of each row placed in front.

    :param tokens: token ids, shape ``(length,)`` or ``(batch, length)``.
    :param return_all_hiddens: forwarded to the model; when true, the list in
        ``extra["inner_states"]`` is returned with dims 0 and 1 of every
        entry swapped, instead of the features.
    :return: the model's features, or the list of inner states (see above).
    :raises ValueError: if the sequence length exceeds
        ``min(self.model.max_positions())``.
    """
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    if tokens.size(-1) > min(self.model.max_positions()):
        raise ValueError("tokens exceeds maximum length: {} > {}".format(
            tokens.size(-1), self.model.max_positions()))

    # Tensor.to() is not in-place: the moved tensor must be rebound, otherwise
    # the model is fed tokens that still live on the original device.
    tokens = tokens.to(device=self.device)

    # Position of the last non-pad token in every row.
    pad_idx = self.task.source_dictionary.pad()
    last_token_pos = tokens.ne(pad_idx).sum(dim=1) - 1

    prev_output_tokens = tokens.clone()
    # squeeze(-1) only drops the gathered column, so a batch of one keeps its
    # batch dimension.
    prev_output_tokens[:, 0] = tokens.gather(1, last_token_pos.unsqueeze(-1)).squeeze(-1)
    prev_output_tokens[:, 1:] = tokens[:, :-1]

    features, extra = self.model(
        src_tokens=tokens,
        src_lengths=None,
        prev_output_tokens=prev_output_tokens,
        features_only=True,
        return_all_hiddens=return_all_hiddens,
    )
    if return_all_hiddens:
        # convert from T x B x C -> B x T x C
        return [inner_state.transpose(0, 1) for inner_state in extra["inner_states"]]
    return features  # just the last layer's features
def _map(self, batch: torch.LongTensor, index: int = 1, invert: bool = False) -> torch.LongTensor: # noqa: D102 batch = batch.clone() batch[:, index] *= 2 return batch
def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
    """Generate negative samples from the positive batch.

    The batch is tiled ``self.num_negs_per_pos`` times (when that is > 1).
    The first half of the resulting rows gets its column 0 (head) replaced,
    the remaining rows get their column 2 (tail) replaced. The replacement
    never equals the original value.
    """
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

    num_negs = positive_batch.shape[0]
    # Rows before this boundary corrupt the head, rows from it on the tail.
    split_idx = num_negs // 2

    # Work on a copy. No detach needed: these are indices, no gradients flow.
    negative_batch = positive_batch.clone()

    # Draw from [0, num_entities - 1); values >= the original are shifted up
    # by one below, which skips the original value while staying uniform
    # over the remaining entities.
    negative_entities = torch.randint(
        high=self.num_entities - 1,
        size=(num_negs,),
        device=positive_batch.device,
    )

    corruption_plan = (
        (slice(None, split_idx), 0),  # heads
        (slice(split_idx, None), 2),  # tails
    )
    for rows, column in corruption_plan:
        candidates = negative_entities[rows]
        shift = (candidates >= positive_batch[rows, column]).long()
        negative_batch[rows, column] = candidates + shift

    return negative_batch
def sample(
    self, positive_batch: torch.LongTensor
) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
    """Sample a negative batch based on the bern approach.

    For every (tiled) positive triple either the head (column 0) or the tail
    (column 2) is replaced. The head is chosen with the probability stored in
    ``self.corrupt_head_probability`` for the triple's relation (column 1).

    :param positive_batch: the positive triples; not modified.
    :return: ``(negative_batch, batch_filter)``. With ``self.filtered`` set,
        both values come from ``self.filter_negative_triples``; otherwise
        ``batch_filter`` is ``None`` and no replacement equals the original.
    """
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

    # Bind number of negatives to sample
    num_negs = positive_batch.shape[0]

    # Copy positive batch for corruption.
    # Do not detach, as no gradients should flow into the indices.
    negative_batch = positive_batch.clone()
    device = positive_batch.device

    # Decide whether to corrupt head or tail
    head_corruption_probability = self.corrupt_head_probability[positive_batch[:, 1]]
    head_mask = torch.rand(num_negs, device=device) < head_corruption_probability.to(device=device)
    # Tails are corrupted if heads are not corrupted
    tail_mask = ~head_mask

    # The "sample from one fewer, then shift" trick only applies when we
    # shift. With filtering there is no shift, so the full entity range must
    # be used -- otherwise the highest entity id could never be drawn.
    index_max = self.triples_factory.num_entities
    if not self.filtered:
        index_max -= 1

    negative_entities = torch.randint(index_max, size=(num_negs,), device=device)

    # Replace heads
    negative_batch[head_mask, 0] = negative_entities[head_mask]
    # Replace tails
    negative_batch[tail_mask, 2] = negative_entities[tail_mask]

    # If filtering is activated, all negative triples that are positive in
    # the training dataset will be removed
    if self.filtered:
        negative_batch, batch_filter = self.filter_negative_triples(negative_batch=negative_batch)
        return negative_batch, batch_filter

    # To make sure we don't replace the head/tail by the original value we
    # shift all values greater or equal than the original value by one up;
    # for that reason the random value was chosen from [0, num_entities - 1).
    negative_batch[head_mask, 0] += (negative_batch[head_mask, 0] >= positive_batch[head_mask, 0]).long()
    negative_batch[tail_mask, 2] += (negative_batch[tail_mask, 2] >= positive_batch[tail_mask, 2]).long()
    return negative_batch, None
def _merge_overlapping(
    boxes: Boxes,
    classes: torch.LongTensor,
    relation_indexes: torch.LongTensor,
    nms_threshold: float,
):
    """
    Greedily merge boxes that overlap and share a class.

    Boxes are visited in order. Every not-yet-visited box is kept and absorbs
    all not-yet-visited boxes that are merge candidates with it; entries of
    ``relation_indexes`` referring to an absorbed box are rewritten to the
    position of the kept box in the output.

    Returns ``(boxes[keep], classes[keep], relation_indexes)`` where
    ``relation_indexes`` is a rewritten copy of the input.
    """
    # Boxes are candidates for merging if their IoU is above a threshold ...
    overlapping = pairwise_iou(boxes, boxes) > nms_threshold

    # ... and they belong to the same class. Here "person subj" and
    # "person obj" are treated as two separate classes, to avoid merging cases
    # of "person hugs person" where the two people have high overlap but must
    # remain separate. Class 0 is the one treated as person here; its
    # object-side boxes get the temporary class -1.
    object_indexes = relation_indexes[1]
    person_objects = object_indexes[classes[object_indexes] == 0]
    merge_classes = classes.clone()
    merge_classes[person_objects] = -1
    same_class = merge_classes[:, None] == merge_classes[None, :]

    candidates = overlapping & same_class

    num_boxes = len(boxes)
    visited = torch.full((num_boxes,), False, dtype=torch.bool)
    relation_indexes = relation_indexes.clone()
    keep = []
    for old_box_idx in range(num_boxes):
        # Already absorbed by (or identical to) an earlier kept box.
        if visited[old_box_idx]:
            continue

        new_box_idx = len(keep)
        keep.append(old_box_idx)

        unvisited_candidates = candidates[old_box_idx, :] & ~visited
        matches = torch.nonzero(unvisited_candidates, as_tuple=True)[0]
        visited[matches] = True

        # Redirect every relation endpoint that pointed at a merged box.
        rel_idx_to_fix = torch.any(
            relation_indexes[:, :, None] == matches[None, None, :], dim=2)
        relation_indexes[rel_idx_to_fix] = new_box_idx

    return boxes[keep], classes[keep], relation_indexes
def prepare_batch(
    self, sequences: torch.LongTensor, lengths: torch.LongTensor
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    """
    Build a masked-prediction batch from token sequences.

    :param sequences: token ids, shape ``(batch_size, max_seq_len)``.
    :param lengths: number of real tokens per row.
    :return: ``(sequences, attention_mask, target)`` where ``sequences`` has
        the selected positions replaced by a prediction token,
        ``attention_mask`` is true for positions below the row length, and
        ``target`` holds the original ids at selected positions and ``-100``
        everywhere else.
    """
    dataset = self.dataloader.dataset
    batch_size, max_seq_len = sequences.size()

    # compute the attention mask
    positions = torch.arange(max_seq_len, dtype=torch.long, device=lengths.device)
    attention_mask = positions < lengths[:, None]

    # prepare the target
    target = sequences.clone().detach()

    # token probabilities of every position in the (flattened) batch
    _token_probabilities = self.token_probabilities[sequences.view(-1)]

    # number of tokens for which a prediction needs to be made
    num_targets = math.ceil(self.pred_proportion * lengths.sum().item())

    # compute the prediction mask: sample positions without replacement ...
    sampling_weights = _token_probabilities / _token_probabilities.sum()
    target_idxs = torch.multinomial(sampling_weights, num_targets, replacement=False)
    flat_mask = torch.zeros(batch_size * max_seq_len, dtype=torch.bool, device=sequences.device)
    flat_mask[target_idxs] = 1
    pred_mask = flat_mask.view(batch_size, max_seq_len)
    # ... and never predict padding
    pred_mask[sequences == dataset.special_tokens_map['pad_token']] = 0

    # compute the prediction tokens: one candidate per strategy
    sequences_keep = sequences[pred_mask]
    sequences_rand = sequences_keep.clone().random_(dataset._tokenizer.get_vocab_size())
    sequences_mask = sequences_keep.clone().fill_(dataset.special_tokens_map['mask_token'])

    # strategy per position: 0 -> mask token, 1 -> keep, 2 -> random token
    pred_idxs = torch.multinomial(self.pred_probabilities, len(sequences_keep), replacement=True)
    use_mask = (pred_idxs == 0).long()
    use_keep = (pred_idxs == 1).long()
    use_rand = (pred_idxs == 2).long()
    pred_tokens = sequences_mask * use_mask + sequences_keep * use_keep + sequences_rand * use_rand

    # copy the prediction tokens into the sequences, given the prediction mask
    sequences = sequences.masked_scatter(pred_mask, pred_tokens)

    # ignore tokens that are not in the prediction mask
    target[~pred_mask] = -100

    return sequences, attention_mask, target
def sample(
    self, positive_batch: torch.LongTensor
) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
    """Generate negative samples from the positive batch.

    The (tiled) batch is cut into consecutive chunks, one per entry of
    ``self._corruption_indices``, and in each chunk that column is replaced
    by a random index.

    :param positive_batch: the positive triples; not modified.
    :return: ``(negative_batch, batch_filter)``. Without a filterer,
        ``batch_filter`` is ``None`` and no replaced value equals the
        original one. With a filterer, both values are whatever
        ``self.filterer(negative_batch=...)`` returns.
    """
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

    # Bind number of negatives to sample
    num_negs = positive_batch.shape[0]

    # Equally corrupt all sides. Use the *ceiling* of the division: with the
    # floor, the trailing `num_negs % n_sides` rows were never corrupted, and
    # fewer rows than sides gave a chunk size of 0 (range() step error).
    num_sides = len(self._corruption_indices)
    split_idx = max(1, -(-num_negs // num_sides))

    # Copy positive batch for corruption.
    # Do not detach, as no gradients should flow into the indices.
    negative_batch = positive_batch.clone()

    shift_past_original = self.filterer is None

    for index, start in zip(self._corruption_indices, range(0, num_negs, split_idx)):
        stop = min(start + split_idx, num_negs)

        # Relations have a different index maximum than entities
        index_max = self.num_relations if index == 1 else self.num_entities

        # If we do not use a filterer, we at least make sure to not replace
        # the triples by the original value: sample from one fewer value ...
        if shift_past_original:
            index_max -= 1

        negative_batch[start:stop, index] = torch.randint(
            high=index_max,
            size=(stop - start,),
            device=positive_batch.device,
        )

        # ... and shift all values greater or equal than the original value
        # up by one.
        if shift_past_original:
            negative_batch[start:stop, index] += (
                negative_batch[start:stop, index] >= positive_batch[start:stop, index]
            ).long()

    # If filtering is activated, all negative triples that are positive in
    # the training dataset will be removed
    if self.filterer is not None:
        negative_batch, batch_filter = self.filterer(negative_batch=negative_batch)
    else:
        batch_filter = None

    return negative_batch, batch_filter
def corrupt_batch(
    self, positive_batch: torch.LongTensor
) -> torch.LongTensor:  # noqa: D102
    # Bernoulli corruption: every positive triple is repeated
    # `num_negs_per_pos` times (consecutively), then for each copy either the
    # head (column 0) or the tail (column 2) is replaced by a random entity
    # different from the original one. The result is viewed as
    # (-1, num_negs_per_pos, 3).
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat_interleave(
            repeats=self.num_negs_per_pos, dim=0)

    num_negs = positive_batch.shape[0]
    device = positive_batch.device

    # Work on a copy. No detach needed: these are indices, no gradients flow.
    negative_batch = positive_batch.clone()

    # Per-triple probability of corrupting the head, looked up by relation id.
    head_probability = self.corrupt_head_probability[positive_batch[:, 1]]
    head_mask = torch.rand(num_negs, device=device) < head_probability.to(device=device)
    # Whatever is not head-corrupted gets its tail corrupted.
    tail_mask = ~head_mask

    # Sample from [0, num_entities - 1): values >= the original are shifted
    # up by one below, which excludes the original value.
    index_max = self.num_entities - 1
    negative_entities = torch.randint(
        index_max,
        size=(num_negs,),
        device=positive_batch.device,
    )

    new_heads = negative_entities[head_mask]
    new_heads = new_heads + (new_heads >= positive_batch[head_mask, 0]).long()
    negative_batch[head_mask, 0] = new_heads

    new_tails = negative_entities[tail_mask]
    new_tails = new_tails + (new_tails >= positive_batch[tail_mask, 2]).long()
    negative_batch[tail_mask, 2] = new_tails

    return negative_batch.view(-1, self.num_negs_per_pos, 3)
def _prepare_inverse_batch(self, batch: torch.LongTensor, index_relation: int) -> torch.LongTensor: if not self.triples_factory.create_inverse_triples: raise ValueError( "Your model is not configured to predict with inverse relations." " Set ``create_inverse_triples=True`` when creating the dataset/triples factory" " or using the pipeline().", ) batch_cloned = batch.clone() # The number of relations stored in the triples factory includes the number of inverse relations # Id of inverse relation: relation + 1 batch_cloned[:, index_relation] = batch_cloned[:, index_relation] + 1 return batch_cloned.flip(1)
def run_rnn(rnn, inputs, lens: torch.LongTensor):
    """
    Run ``rnn`` over a padded batch and return the padded outputs.

    :param rnn: a recurrent module accepting a ``PackedSequence``.
    :param inputs: batch-first padded inputs.
    :param lens: sequence lengths, as a tensor or anything
        ``torch.as_tensor`` accepts. The argument is never modified.
    :return: batch-first outputs padded with zeros up to ``inputs.size(1)``.
        Sequences of length 0 are processed as if they had length 1, so
        their first time step holds the output for ``inputs[:, 0]``.
    """
    if not isinstance(lens, torch.Tensor):
        lens = torch.as_tensor(lens, dtype=torch.long)

    # pack_padded_sequence requires the lengths to live on the CPU, whatever
    # device the inputs are on. `.to` may return the same tensor, so clone
    # before editing to keep the caller's tensor intact.
    lens = lens.to(device="cpu").clone()

    # pad sequence to avoid rnn error when seq_len = 0
    lens[lens == 0] = 1

    packed = nn.utils.rnn.pack_padded_sequence(
        inputs, lens, batch_first=True, enforce_sorted=False)
    outputs, _ = rnn(packed)
    padded, _ = nn.utils.rnn.pad_packed_sequence(
        outputs, batch_first=True, padding_value=0., total_length=inputs.size(1))
    return padded
def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
    """Sample a negative batch based on the bern approach.

    The batch is tiled ``self.num_negs_per_pos`` times (when that is > 1).
    Each row then has either its head (column 0) or its tail (column 2)
    replaced by a random entity that differs from the original; the head is
    chosen with the probability stored in ``self.corrupt_head_probability``
    for the row's relation (column 1).
    """
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

    num_negs = positive_batch.shape[0]
    device = positive_batch.device

    # Work on a copy. No detach needed: these are indices, no gradients flow.
    negative_batch = positive_batch.clone()

    # Decide per row whether the head is corrupted; otherwise the tail is.
    head_probability = self.corrupt_head_probability[positive_batch[:, 1]]
    head_mask = torch.rand(num_negs, device=device) < head_probability.to(device=device)
    tail_mask = ~head_mask

    # Candidates come from [0, num_entities - 1); a candidate >= the original
    # value is shifted up by one, so the original can never be reproduced.
    negative_entities = torch.randint(
        self.triples_factory.num_entities - 1,
        size=(num_negs,),
        device=positive_batch.device,
    )

    for column, row_mask in ((0, head_mask), (2, tail_mask)):
        candidates = negative_entities[row_mask]
        collides = candidates >= positive_batch[:, column][row_mask]
        negative_batch[:, column][row_mask] = candidates + collides.long()

    return negative_batch
def __init__(self,
             X: torch.FloatTensor,
             y: torch.LongTensor,
             transform=None,
             all_tasks: list = None,
             truncate_size: int = None):
    """
    Store copies of the full data and set up the task bookkeeping.

    :param X: the inputs; must be a ``torch.FloatTensor``. A clone is stored.
    :param y: the labels; must be a ``torch.LongTensor``. A clone is stored.
    :param transform: stored as ``self.transform``.
    :param all_tasks: list of tasks; when omitted, the five pairs
        ``(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)`` are used.
    :param truncate_size: stored as ``self.truncate_size``.
    """
    super().__init__()
    # Build the default per instance: a list literal in the signature would
    # be one object shared (and mutable) across all instances.
    if all_tasks is None:
        all_tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
    self.all_tasks = all_tasks
    self.current_task = None
    self.truncate_size = truncate_size

    assert isinstance(X, torch.FloatTensor)
    assert isinstance(y, torch.LongTensor)

    # Keep private copies so later changes by the caller do not leak in.
    self.all_y = y.clone()
    self.all_x = X.clone()
    # The "current" views initially cover the complete data.
    self.y = self.all_y
    self.X = self.all_x

    self.transform = transform
def accuracy_func(
    y_predict: Variable,
    y_truth: LongTensor,
) -> float:
    '''Accuracy over the non-zero entries of ``y_truth``.

    The predicted tag is the argmax of ``y_predict`` along dim 2. Entries of
    ``y_truth`` equal to 0 are excluded: they do not count towards the total
    and can never be a hit. ``y_truth`` is not modified.

    The comparison is done on the CPU, so the inputs may live on different
    devices. The value returned is the 0-dim tensor produced by dividing the
    hit count by the number of non-zero truth entries.
    '''
    total_num = len(torch.nonzero(y_truth))

    # Compare on the CPU and on a private copy: 0 is remapped to -1, which no
    # argmax index can equal.
    truth = y_truth.detach().cpu().clone()
    truth[truth == 0] = -1

    predicted = torch.max(y_predict, 2)[1].view(y_truth.size()).detach().cpu()
    hit_tags = (predicted == truth).sum()
    return hit_tags / total_num
def mlm_mask_tokens(
    inputs: torch.LongTensor, tokenizer, mlm_probability
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Prepare masked-LM inputs and labels (from HuggingFace).

    Works on CPU copies of ``inputs`` and moves both results back to the
    original device. Special tokens (and padding, when the tokenizer has a
    pad token) are never selected.

    :return: ``(inputs, labels)`` where ``labels`` holds the original ids at
        the selected positions and ``NON_MASKED_TOKEN_LABEL_ID`` elsewhere.
    """
    device = inputs.device
    inputs = inputs.cpu().clone()
    labels = inputs.clone()
    shape = labels.shape

    # We sample a few tokens in each sequence for masked-LM training, each
    # with probability `mlm_probability` ...
    probability_matrix = torch.full(shape, mlm_probability)

    # ... except special tokens ...
    special_tokens_mask = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
            for row in labels.tolist()
        ],
        dtype=torch.bool,
    )
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

    # ... and padding.
    # noinspection PyProtectedMember
    if tokenizer._pad_token is not None:
        probability_matrix.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)

    masked_indices = torch.bernoulli(probability_matrix).bool()
    # We only compute loss on masked tokens
    labels[~masked_indices] = NON_MASKED_TOKEN_LABEL_ID

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    replace_draw = torch.bernoulli(torch.full(shape, 0.8)).bool()
    indices_replaced = replace_draw & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time (half of the remaining 20%), we replace masked input
    # tokens with a random word
    random_draw = torch.bernoulli(torch.full(shape, 0.5)).bool()
    indices_random = random_draw & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    # noinspection PyTypeChecker
    return inputs.to(device), labels.to(device)
def corrupt_batch(
    self, positive_batch: torch.LongTensor
) -> torch.LongTensor:  # noqa: D102
    # Uniform corruption: every positive triple is repeated
    # `num_negs_per_pos` times (consecutively); the rows are then cut into
    # consecutive chunks, one per entry of `self._corruption_indices`, and in
    # each chunk that column is replaced by a random index different from the
    # original. The result is viewed as (-1, num_negs_per_pos, 3).
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat_interleave(
            repeats=self.num_negs_per_pos, dim=0)

    num_negs = positive_batch.shape[0]

    # Equally corrupt all sides: chunk size is the ceiling of rows / sides.
    split_idx = int(math.ceil(num_negs / len(self._corruption_indices)))

    # Work on a copy. No detach needed: these are indices, no gradients flow.
    negative_batch = positive_batch.clone()

    chunk_starts = range(0, num_negs, split_idx)
    for index, start in zip(self._corruption_indices, chunk_starts):
        stop = min(start + split_idx, num_negs)
        rows = slice(start, stop)

        # Relations have a different index maximum than entities. One value
        # fewer is sampled because the original value gets skipped below.
        if index == 1:
            index_max = self.num_relations - 1
        else:
            index_max = self.num_entities - 1

        replacement = torch.randint(
            high=index_max,
            size=(stop - start,),
            device=positive_batch.device,
        )
        # Shift everything >= the original value up by one, so the original
        # {head, relation, tail} is never reproduced.
        collides = replacement >= positive_batch[rows, index]
        negative_batch[rows, index] = replacement + collides.long()

    return negative_batch.view(-1, self.num_negs_per_pos, 3)
def corrupt_batch( self, positive_batch: torch.LongTensor ) -> torch.LongTensor: # noqa: D102 batch_shape = positive_batch.shape[:-1] # Copy positive batch for corruption. # Do not detach, as no gradients should flow into the indices. negative_batch = positive_batch.clone() negative_batch = negative_batch.unsqueeze(dim=-2).repeat( *(1 for _ in batch_shape), self.num_negs_per_pos, 1) corruption_index = torch.randint(self._n_corruptions, size=(*batch_shape, self.num_negs_per_pos)) # split_idx = int(math.ceil(num_negs / len(self._corruption_indices))) for index in self._corruption_indices: # Relations have a different index maximum than entities # At least make sure to not replace the triples by the original value index_max = (self.num_relations if index == 1 else self.num_entities) - 1 mask = corruption_index == index # To make sure we don't replace the {head, relation, tail} by the # original value we shift all values greater or equal than the original value by one up # for that reason we choose the random value from [0, num_{heads, relations, tails} -1] negative_indices = torch.randint( high=index_max, size=(mask.sum().item(), ), device=positive_batch.device, ) # determine shift *before* writing the negative indices shift = (negative_indices >= negative_batch[mask][:, index]).long() negative_indices += shift # write the negative indices negative_batch[mask.unsqueeze(dim=-1) & ( torch.arange(3) == index).view(*( 1 for _ in batch_shape), 1, 3)] = negative_indices return negative_batch
def accuracy_func_test(
    y_predict: Variable,
    y_truth: LongTensor,
) -> float:
    '''Top-10 accuracy over the non-zero entries of ``y_truth``.

    A position counts as a hit when its true tag is among the (up to) ten
    highest-scoring indices of ``y_predict`` along the last dimension.
    Positions whose truth is 0 are ignored.

    :param y_predict: scores with one trailing class dimension more than
        ``y_truth``.
    :param y_truth: the true tags; not modified.
    :return: hits divided by the number of non-zero truth entries, or ``0.0``
        when there is no non-zero truth entry.
    '''
    labelled = y_truth != 0
    total_num = int(labelled.sum())
    if total_num == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # Top-k of all positions at once instead of sorting position by position.
    k = min(10, y_predict.size(-1))
    top_indices = torch.topk(y_predict.detach(), k, dim=-1).indices.to(y_truth.device)

    in_top = (top_indices == y_truth.unsqueeze(-1)).any(dim=-1)
    hit_tags = int((in_top & labelled).sum())
    return hit_tags / total_num
def _window_lengths(self, lengths: torch.LongTensor) -> torch.LongTensor: lengths = lengths.clone() if lengths.ndim == 0: lengths.unsqueeze(0) lengths_win = [] while True: sub_lengths = torch.clamp(lengths, max=self.window_size) lengths = torch.clamp(lengths - self.step, min=0) if sub_lengths.max() <= self.window_size - self.step: # Missing frames were only in overlap so part of last chunk. break lengths_win.append(sub_lengths) if not self.zero_padding and lengths.max() < self.window_size: break if len(lengths_win) == 0: # Edge case where all lengths are in zeroes overlap. return sub_lengths else: return torch.stack(lengths_win, dim=1).view(-1)
def construct_trees(self,
                    predictions: torch.FloatTensor,
                    all_spans: torch.LongTensor,
                    num_spans: torch.LongTensor,
                    sentences: List[List[str]],
                    pos_tags: List[List[str]] = None) -> List[Tree]:
    """
    Decode one ``nltk.Tree`` per batch element by greedily nesting spans.

    The trees use exclusive end indices, unlike the inclusive spans used in
    the rest of the model.

    Parameters
    ----------
    predictions : ``torch.FloatTensor``, required.
        Shape ``(batch_size, num_spans, span_label_vocab_size)``; a
        distribution over the label classes for every span.
    all_spans : ``torch.LongTensor``, required.
        Shape ``(batch_size, num_spans, 2)``; the span indices that were
        scored. Not modified.
    num_spans : ``torch.LongTensor``, required.
        Shape ``(batch_size)``; the number of non-padded spans per element.
    sentences : ``List[List[str]]``, required.
        The tokens of the sentence for each batch element.
    pos_tags : ``List[List[str]]``, optional (default = None).
        The POS tags of each sentence's words.

    Returns
    -------
    A ``List[Tree]`` with the decoded tree of each batch element.
    """
    # Switch to using exclusive end spans (on a copy).
    exclusive_end_spans = all_spans.clone()
    exclusive_end_spans[:, :, -1] += 1
    no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

    trees: List[Tree] = []
    batch = zip(predictions, exclusive_end_spans, sentences)
    for batch_index, (scored_spans, spans, sentence) in enumerate(batch):
        # Only the leading, non-padded spans take part.
        span_count = num_spans[batch_index]

        selected_spans = []
        for prediction, span in zip(scored_spans[:span_count], spans[:span_count]):
            start, end = span
            no_label_prob = prediction[no_label_id]
            label_prob, label_index = torch.max(prediction, -1)

            # Skip spans predicted as NO-LABEL, unless the span is the root
            # node (it covers the whole sentence).
            if int(label_index) == no_label_id and not (start == 0 and end == len(sentence)):
                continue

            # TODO(Mark): Remove this once pylint sorts out named tuples.
            # https://github.com/PyCQA/pylint/issues/1418
            selected_spans.append(
                SpanInformation(
                    start=int(start),  # pylint: disable=no-value-for-parameter
                    end=int(end),
                    label_prob=float(label_prob),
                    no_label_prob=float(no_label_prob),
                    label_index=int(label_index)))

        # The selected spans might overlap, which causes problems when
        # constructing the tree as they won't nest properly.
        consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans)

        spans_to_labels = {}
        for span in consistent_spans:
            label = self.vocab.get_token_from_index(span.label_index, "labels")
            spans_to_labels[(span.start, span.end)] = label

        sentence_pos = None if pos_tags is None else pos_tags[batch_index]
        trees.append(
            self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

    return trees
def construct_trees(self,
                    predictions: torch.FloatTensor,
                    all_spans: torch.LongTensor,
                    num_spans: torch.LongTensor,
                    sentences: List[List[str]],
                    pos_tags: List[List[str]] = None) -> List[Tree]:
    """
    Build an ``nltk.Tree`` for every batch element by greedily nesting spans.

    Note that the trees use exclusive end indices, in contrast to how spans
    are represented in the rest of the model.

    Parameters
    ----------
    predictions : ``torch.FloatTensor``, required.
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        holding a distribution over the label classes per span.
    all_spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size, num_spans, 2)`` with the scored span
        indices. It is left unmodified.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size)`` with the count of non-padded
        spans for each element.
    sentences : ``List[List[str]]``, required.
        For each batch element, the list of tokens in its sentence.
    pos_tags : ``List[List[str]]``, optional (default = None).
        For each batch element, the POS tag of every word.

    Returns
    -------
    A ``List[Tree]`` containing one decoded tree per batch element.
    """
    no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

    # Work with exclusive ends from here on; the input tensor is copied first.
    exclusive_end_spans = all_spans.clone()
    exclusive_end_spans[:, :, -1] += 1

    def _is_candidate(label_index, start, end, sentence) -> bool:
        # A span is considered when it carries a real label, or when it is
        # the root node spanning the entire sentence.
        return int(label_index) != no_label_id or (start == 0 and end == len(sentence))

    trees: List[Tree] = []
    for batch_index, sentence in enumerate(sentences):
        # zip() in the original stops at the shortest input.
        if batch_index >= len(predictions) or batch_index >= len(exclusive_end_spans):
            break
        valid = num_spans[batch_index]
        scored_spans = predictions[batch_index][:valid]
        spans = exclusive_end_spans[batch_index][:valid]

        selected_spans = []
        for prediction, (start, end) in zip(scored_spans, spans):
            no_label_prob = prediction[no_label_id]
            label_prob, label_index = torch.max(prediction, -1)
            if _is_candidate(label_index, start, end, sentence):
                # TODO(Mark): Remove this once pylint sorts out named tuples.
                # https://github.com/PyCQA/pylint/issues/1418
                info = SpanInformation(start=int(start),  # pylint: disable=no-value-for-parameter
                                       end=int(end),
                                       label_prob=float(label_prob),
                                       no_label_prob=float(no_label_prob),
                                       label_index=int(label_index))
                selected_spans.append(info)

        # Overlapping spans would not nest properly in a tree, so conflicts
        # are resolved first.
        consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans)
        spans_to_labels = {
            (chosen.start, chosen.end): self.vocab.get_token_from_index(chosen.label_index, "labels")
            for chosen in consistent_spans
        }

        if pos_tags is not None:
            sentence_pos = pos_tags[batch_index]
        else:
            sentence_pos = None
        tree = self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos)
        trees.append(tree)

    return trees
def forward(
    self,
    token_ids: torch.LongTensor,
    mask: torch.BoolTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    type_ids: Optional[torch.LongTensor] = None,
    segment_concat_mask: Optional[torch.BoolTensor] = None,
    masked_lm: Optional[List[bool]] = None,
) -> torch.Tensor:  # type: ignore
    """
    Embed the original tokens by averaging their wordpiece embeddings,
    optionally applying masked-LM corruption to selected batch elements.

    # Parameters

    token_ids: `torch.LongTensor`
        Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
    mask: `torch.BoolTensor`
        Shape: [batch_size, num_orig_tokens].
    offsets: `torch.LongTensor`
        Shape: [batch_size, num_orig_tokens, 2].
        Maps indices for the original tokens, i.e. those given as input to the indexer,
        to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
        corresponds to the original j-th token from the i-th batch.
    wordpiece_mask: `torch.BoolTensor`
        Shape: [batch_size, num_wordpieces].
    type_ids: `Optional[torch.LongTensor]`
        Shape: [batch_size, num_wordpieces].
    segment_concat_mask: `Optional[torch.BoolTensor]`
        See `PretrainedTransformerEmbedder`.
    masked_lm: `Optional[List[bool]]`
        One flag per batch element. Masking is active when at least one flag
        is true; only tokens of flagged elements are corrupted.

    # Returns

    `torch.Tensor`
        Shape: [batch_size, num_orig_tokens, embedding_size].
        When masking is active, the tuple `(embeddings, masked_lm_loss)` is
        returned instead.
    """
    # -100 everywhere: no position carries a label unless it is selected below.
    masked_lm_labels = -100 * torch.ones_like(token_ids)
    masked_token_ids = token_ids
    activate_masking = masked_lm is not None and any(masked_lm)

    if activate_masking:
        batch_size, num_orig_tokens = mask.shape
        masked_lm = torch.tensor(masked_lm, dtype=torch.bool).to(token_ids.device)

        # One uniform draw per original token decides its fate.
        mask_probs = torch.rand(mask.shape, device=mask.device)
        token_threshold = self._mask_probability * self._mask_token_probability
        random_threshold = self._mask_probability * (
            self._mask_token_probability + self._mask_random_probability)

        # Only real tokens of batch elements flagged for masked LM qualify.
        eligible = mask & masked_lm.unsqueeze(-1)
        mask_token_choices = (mask_probs < token_threshold) & eligible
        mask_random_choices = (
            (mask_probs >= token_threshold) & (mask_probs < random_threshold) & eligible)
        all_mask_choices = (mask_probs < self._mask_probability) & eligible

        mask_random_values = torch.randint(
            low=0,
            high=self._matched_embedder.transformer_model.config.vocab_size,
            size=token_ids.shape,
            device=mask.device,
        )

        masked_token_ids = token_ids.clone()

        def _wordpieces(row: int, token: int) -> slice:
            # Wordpiece positions of an original token; offsets are inclusive.
            first, last = offsets[row, token, :].tolist()
            return slice(first, last + 1)

        # Tokens replaced by the mask token.
        for row, token in mask_token_choices.nonzero().tolist():
            masked_token_ids[row, _wordpieces(row, token)] = self._matched_embedder._mask_token_id

        # Tokens replaced by random wordpieces.
        for row, token in mask_random_choices.nonzero().tolist():
            pieces = _wordpieces(row, token)
            masked_token_ids[row, pieces] = mask_random_values[row, pieces]

        # Every selected token (replaced or kept) contributes its original
        # ids as labels.
        for row, token in all_mask_choices.nonzero().tolist():
            pieces = _wordpieces(row, token)
            masked_lm_labels[row, pieces] = token_ids[row, pieces]

    # Shape: [batch_size, num_wordpieces, embedding_size].
    embeddings, masked_lm_loss = self._matched_embedder(
        masked_token_ids,
        wordpiece_mask,
        type_ids=type_ids,
        segment_concat_mask=segment_concat_mask,
        masked_lm_labels=masked_lm_labels,
    )

    # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    # span_mask: (batch_size, num_orig_tokens, max_span_length)
    span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
    span_mask = span_mask.unsqueeze(-1)
    span_embeddings *= span_mask  # zero out paddings

    # Mean over the wordpieces of each original token; the clamp avoids
    # dividing by zero for empty spans.
    span_embeddings_len = span_mask.sum(2)
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    orig_embeddings = span_embeddings.sum(2) / torch.clamp_min(span_embeddings_len, 1)

    # All the places where the span length is zero, write in zeros.
    empty_spans = (span_embeddings_len == 0).expand(orig_embeddings.shape)
    orig_embeddings[empty_spans] = 0

    if not activate_masking:
        return orig_embeddings
    return orig_embeddings, masked_lm_loss