def get_word_embedding(
    self,
    token_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    type_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
    """Get the [CLS] embedding and word-level embeddings."""
    # Shape: [batch_size, num_wordpieces, embedding_size].
    # embed_model = self.bert if self.arch == "bert" else self.roberta
    # embeddings = embed_model(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0]
    embeddings = self.bert(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0]

    # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    # span_mask: (batch_size, num_orig_tokens, max_span_length)
    span_embeddings, span_mask = allennlp_util.batched_span_select(embeddings, offsets)
    span_mask = span_mask.unsqueeze(-1)
    span_embeddings *= span_mask  # zero out paddings

    span_embeddings_sum = span_embeddings.sum(2)
    span_embeddings_len = span_mask.sum(2)
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

    # All the places where the span length is zero, write in zeros.
    orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

    return embeddings[:, 0, :], orig_embeddings
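# NOTE: Every snippet in this file uses the same wordpiece-to-word pooling recipe: gather the
# wordpieces belonging to each original token with `batched_span_select`, zero out padding
# positions, then average. The block below is a minimal, self-contained sketch of that recipe in
# plain PyTorch on toy tensors; `select_spans` and the `toy_*` names are illustrative stand-ins
# (assumed here, not part of any library) that mimic what `allennlp.nn.util.batched_span_select`
# returns.
import torch


def select_spans(target: torch.Tensor, spans: torch.LongTensor):
    """Return (batch, num_spans, max_width, dim) span embeddings plus a validity mask."""
    starts, ends = spans[..., 0], spans[..., 1]                  # span ends are inclusive
    max_width = int((ends - starts).max().item()) + 1
    indices = starts.unsqueeze(-1) + torch.arange(max_width)     # (batch, num_spans, max_width)
    valid = indices <= ends.unsqueeze(-1)                        # mask out positions past each span
    indices = indices * valid                                    # send invalid positions to index 0
    batch_idx = torch.arange(target.size(0)).view(-1, 1, 1)
    return target[batch_idx, indices], valid


toy_wordpieces = torch.randn(1, 6, 4)                            # 1 sentence, 6 wordpieces, dim 4
toy_offsets = torch.tensor([[[0, 0], [1, 2], [3, 5]]])           # 3 original tokens -> wordpiece spans
span_embeddings, span_mask = select_spans(toy_wordpieces, toy_offsets)

span_mask = span_mask.unsqueeze(-1)
span_embeddings = span_embeddings * span_mask                    # zero out padded wordpieces
word_embeddings = span_embeddings.sum(2) / torch.clamp_min(span_mask.sum(2), 1)
print(word_embeddings.shape)                                     # torch.Size([1, 3, 4])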
def _embed_spans(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.BoolTensor = None,
    span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
    # shape (batch_size, sequence_length, 1)
    global_attention_logits = self._global_attention(sequence_tensor)

    # shape (batch_size, sequence_length, embedding_dim + 1)
    concat_tensor = torch.cat([sequence_tensor, global_attention_logits], -1)

    concat_output, span_mask = util.batched_span_select(concat_tensor, span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = concat_output[:, :, :, :-1]
    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = concat_output[:, :, :, -1]

    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

    # Do a weighted sum of the embedded spans with respect to the
    # normalised attention distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

    return attended_text_embeddings
def _aggregate_token_embeddings(
    embeddings_list: List[torch.Tensor], token_offsets: List[torch.Tensor]
) -> List[numpy.ndarray]:
    if len(token_offsets) == 0:
        return [embeddings.numpy() for embeddings in embeddings_list]

    aggregated_embeddings = []
    # NOTE: This is assuming that embeddings and offsets come in the same order, which may not
    # be true. But the intersection of using multiple TextFields with mismatched indexers is
    # currently zero, so we'll delay handling this corner case until it actually causes a
    # problem. In practice, both of these lists will always be of size one at the moment.
    for embeddings, offsets in zip(embeddings_list, token_offsets):
        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        embeddings[(span_embeddings_len == 0).expand(embeddings.shape)] = 0

        aggregated_embeddings.append(embeddings.numpy())
    return aggregated_embeddings
def hook_layers(module, grad_in, grad_out):
    grads = grad_out[0]
    if self._token_offsets:
        # If you have a mismatched indexer with multiple TextFields, it's quite possible
        # that the order we deal with the gradients is wrong. We'll just take items from
        # the list one at a time and try to aggregate the gradients. If we got the order
        # wrong, we should crash, so you'll know about it. If you get an error because of
        # that, open an issue on GitHub, and we'll see what we can do. The intersection of
        # multiple TextFields and mismatched indexers is pretty small (currently empty,
        # that I know of), so we'll ignore this corner case until it's needed.
        offsets = self._token_offsets.pop(0)
        span_grads, span_mask = util.batched_span_select(grads.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_grads *= span_mask  # zero out paddings

        span_grads_sum = span_grads.sum(2)
        span_grads_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)

        # All the places where the span length is zero, write in zeros.
        grads[(span_grads_len == 0).expand(grads.shape)] = 0

    embedding_gradients.append(grads)
def forward( self, token_ids: torch.LongTensor, mask: torch.BoolTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, type_ids: Optional[torch.LongTensor] = None, segment_concat_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: # type: ignore """ # Parameters token_ids: `torch.LongTensor` Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`). mask: `torch.BoolTensor` Shape: [batch_size, num_orig_tokens]. offsets: `torch.LongTensor` Shape: [batch_size, num_orig_tokens, 2]. Maps indices for the original tokens, i.e. those given as input to the indexer, to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]` corresponds to the original j-th token from the i-th batch. wordpiece_mask: `torch.BoolTensor` Shape: [batch_size, num_wordpieces]. type_ids: `Optional[torch.LongTensor]` Shape: [batch_size, num_wordpieces]. segment_concat_mask: `Optional[torch.BoolTensor]` See `PretrainedTransformerEmbedder`. # Returns `torch.Tensor` Shape: [batch_size, num_orig_tokens, embedding_size]. """ # Shape: [batch_size, num_wordpieces, embedding_size]. if self.iter_norm: return self._matched_embedder.get_embeddings( token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask ) embeddings = self._matched_embedder( token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask ) # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size) # span_mask: (batch_size, num_orig_tokens, max_span_length) span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets) span_mask = span_mask.unsqueeze(-1) span_embeddings *= span_mask # zero out paddings span_embeddings_sum = span_embeddings.sum(2) span_embeddings_len = span_mask.sum(2) # Shape: (batch_size, num_orig_tokens, embedding_size) orig_embeddings = span_embeddings_sum / span_embeddings_len # All the places where the span length is zero, write in zeros. orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0 return orig_embeddings
def forward(
    self,
    token_ids: torch.LongTensor,
    mask: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.LongTensor,
    type_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:  # type: ignore
    """
    # Parameters

    token_ids: `torch.LongTensor`
        Shape: [batch_size, num_wordpieces].
    mask: `torch.LongTensor`
        Shape: [batch_size, num_orig_tokens].
    offsets: `torch.LongTensor`
        Shape: [batch_size, num_orig_tokens, 2].
        Maps indices for the original tokens, i.e. those given as input to the indexer,
        to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
        corresponds to the original j-th token from the i-th batch.
    wordpiece_mask: `torch.LongTensor`
        Shape: [batch_size, num_wordpieces].
    type_ids: `Optional[torch.LongTensor]`
        Shape: [batch_size, num_wordpieces].

    # Returns

    `torch.Tensor`
        Shape: [batch_size, num_orig_tokens, embedding_size].
    """
    # Shape: [batch_size, num_wordpieces, embedding_size].
    embeddings = self._matched_embedder(token_ids, wordpiece_mask, type_ids=type_ids)

    # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    # span_mask: (batch_size, num_orig_tokens, max_span_length)
    span_embeddings, span_mask = util.batched_span_select(embeddings, offsets)
    span_mask = span_mask.unsqueeze(-1)
    span_embeddings *= span_mask  # zero out paddings

    span_embeddings_sum = span_embeddings.sum(2)
    span_embeddings_len = span_mask.sum(2)
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    # Clamp the span length to 1 so that tokens with empty spans yield zeros instead of NaNs.
    orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

    return orig_embeddings
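# NOTE: The masked average above divides the per-word sum by the number of wordpieces in each
# span. When a token maps to an empty span that count is 0, and an unclamped division produces
# NaNs, which is why the divisor is clamped to 1 (the masked sum is already 0 there, so the
# result is simply 0). A tiny demonstration with toy values only:
import torch

sums = torch.tensor([[4.0, 6.0, 8.0], [0.0, 0.0, 0.0]])   # masked sums for two "words"
lengths = torch.tensor([[2.0], [0.0]])                     # the second word has no wordpieces

print(sums / lengths)                                      # second row is all NaN
print(sums / torch.clamp_min(lengths, 1))                  # second row is all zeros instead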
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
    # shape (batch_size, sequence_length, 1)
    global_attention_logits = torch.matmul(
        sequence_tensor,
        torch.zeros(self.input_dim, 1, device=sequence_tensor.device),
    )

    # shape (batch_size, sequence_length, embedding_dim + 1)
    concat_tensor = torch.cat([sequence_tensor, global_attention_logits], -1)

    concat_output, span_mask = util.batched_span_select(concat_tensor, span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = concat_output[:, :, :, :-1]
    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = concat_output[:, :, :, -1]

    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

    # Do a weighted sum of the embedded spans with respect to the
    # normalised attention distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

    if span_indices_mask is not None:
        # Above we were masking the widths of spans with respect to the max
        # span width in the batch. Here we are masking the spans which were
        # originally passed in as padding.
        return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

    return attended_text_embeddings
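# NOTE: Unlike the plain averaging elsewhere in this file, the two span extractors above pool
# each span with attention: per-position logits are masked-softmaxed over the span width and
# then used for a weighted sum. The block below is a small plain-PyTorch sketch of that pooling
# step with the same tensor shapes; the `toy_*` names and the local `masked_softmax` helper are
# illustrative stand-ins for `allennlp.nn.util.masked_softmax` / `weighted_sum`.
import torch


def masked_softmax(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Positions outside the span get a very negative logit so they receive ~zero weight.
    logits = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)
    return torch.softmax(logits, dim=-1)


toy_span_embeddings = torch.randn(2, 4, 3, 8)         # (batch, num_spans, max_span_width, dim)
toy_span_logits = torch.randn(2, 4, 3)                # one attention logit per span position
toy_span_mask = torch.ones(2, 4, 3, dtype=torch.bool)
toy_span_mask[..., -1] = False                        # pretend the last position is padding

weights = masked_softmax(toy_span_logits, toy_span_mask)               # (batch, num_spans, width)
attended = torch.einsum("bswd,bsw->bsd", toy_span_embeddings, weights)
print(attended.shape)                                                  # torch.Size([2, 4, 8])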
def _get_orig_token_embeddings(
    self, embeddings: torch.Tensor, offsets: torch.LongTensor
) -> torch.Tensor:
    # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    # span_mask: (batch_size, num_orig_tokens, max_span_length)
    span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
    span_mask = span_mask.unsqueeze(-1)
    span_embeddings *= span_mask  # zero out paddings

    span_embeddings_sum = span_embeddings.sum(2)
    span_embeddings_len = span_mask.sum(2)
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    # Clamp the divisor so zero-length spans do not produce NaNs.
    orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

    # All the places where the span length is zero, write in zeros.
    orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

    return orig_embeddings
def _embed(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:
    """
    This implementation is borrowed from `PretrainedTransformerMismatchedEmbedder` and uses
    average pooling to yield a de-wordpieced embedding for each original token.

    Returns both the wordpiece embeddings + mask and the original token embeddings + mask.
    """
    output = self.bert(
        input_ids=text["tokens"]["token_ids"],
        attention_mask=text["tokens"]["wordpiece_mask"],
        token_type_ids=text["tokens"]["type_ids"],
    )
    wordpiece_embeddings = output.last_hidden_state
    offsets = text["tokens"]["offsets"]

    # Assemble wordpiece embeddings into embeddings for each word using average pooling.
    span_embeddings, span_mask = util.batched_span_select(
        wordpiece_embeddings.contiguous(), offsets
    )  # type: ignore
    span_mask = span_mask.unsqueeze(-1)
    # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    span_embeddings *= span_mask  # zero out paddings

    # Sum over embeddings of all sub-tokens of a word.
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    span_embeddings_sum = span_embeddings.sum(2)
    # Shape (batch_size, num_orig_tokens)
    span_embeddings_len = span_mask.sum(2)
    # Find the average of sub-token embeddings by dividing `span_embeddings_sum` by `span_embeddings_len`.
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

    # All the places where the span length is zero, write in zeros.
    orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

    return {
        "wordpiece_mask": text["tokens"]["wordpiece_mask"],
        "wordpiece_embeddings": wordpiece_embeddings,
        "orig_mask": text["tokens"]["mask"],
        "orig_embeddings": orig_embeddings,
    }
def forward(
    self,
    token_ids: torch.LongTensor,
    mask: torch.BoolTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    type_ids: Optional[torch.LongTensor] = None,
    segment_concat_mask: Optional[torch.BoolTensor] = None,
    masked_lm: Optional[List[bool]] = None,
) -> torch.Tensor:  # type: ignore
    """
    # Parameters

    token_ids: `torch.LongTensor`
        Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
    mask: `torch.BoolTensor`
        Shape: [batch_size, num_orig_tokens].
    offsets: `torch.LongTensor`
        Shape: [batch_size, num_orig_tokens, 2].
        Maps indices for the original tokens, i.e. those given as input to the indexer,
        to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
        corresponds to the original j-th token from the i-th batch.
    wordpiece_mask: `torch.BoolTensor`
        Shape: [batch_size, num_wordpieces].
    type_ids: `Optional[torch.LongTensor]`
        Shape: [batch_size, num_wordpieces].
    segment_concat_mask: `Optional[torch.BoolTensor]`
        See `PretrainedTransformerEmbedder`.

    # Returns

    `torch.Tensor`
        Shape: [batch_size, num_orig_tokens, embedding_size].
    """
    # -100 marks wordpieces that the masked-LM loss should ignore.
    masked_lm_labels = -100 * torch.ones_like(token_ids)
    masked_token_ids = token_ids
    activate_masking = masked_lm is not None and any(masked_lm)
    if activate_masking:
        masked_lm = torch.tensor(masked_lm, dtype=torch.bool).to(token_ids.device)
        # Draw one probability per original token so that whole words are masked together.
        mask_probs = torch.rand(mask.shape, device=mask.device)
        mask_token_choices = (
            (mask_probs < self._mask_probability * self._mask_token_probability)
            & mask
            & masked_lm.unsqueeze(-1)
        )
        mask_random_choices = (
            (mask_probs >= self._mask_probability * self._mask_token_probability)
            & (
                mask_probs
                < self._mask_probability
                * (self._mask_token_probability + self._mask_random_probability)
            )
            & mask
            & masked_lm.unsqueeze(-1)
        )
        all_mask_choices = (mask_probs < self._mask_probability) & mask & masked_lm.unsqueeze(-1)
        mask_token_indices = mask_token_choices.nonzero()
        mask_random_indices = mask_random_choices.nonzero()
        mask_random_values = torch.randint(
            low=0,
            high=self._matched_embedder.transformer_model.config.vocab_size,
            size=token_ids.shape,
            device=mask.device,
        )
        all_mask_indices = all_mask_choices.nonzero()
        masked_token_ids = token_ids.clone()
        # Replace every wordpiece in a chosen word's span with the mask token id.
        for i in range(mask_token_indices.shape[0]):
            offset_start_end = offsets[mask_token_indices[i][0].item(), mask_token_indices[i][1].item(), :]
            masked_token_ids[
                mask_token_indices[i][0].item(),
                offset_start_end[0].item() : offset_start_end[1].item() + 1,
            ] = self._matched_embedder._mask_token_id
        # Replace every wordpiece in a chosen word's span with random vocabulary ids.
        for i in range(mask_random_indices.shape[0]):
            offset_start_end = offsets[mask_random_indices[i][0].item(), mask_random_indices[i][1].item(), :]
            masked_token_ids[
                mask_random_indices[i][0].item(),
                offset_start_end[0].item() : offset_start_end[1].item() + 1,
            ] = mask_random_values[
                mask_random_indices[i][0].item(),
                offset_start_end[0].item() : offset_start_end[1].item() + 1,
            ]
        # Record the original ids as labels for every selected word.
        for i in range(all_mask_indices.shape[0]):
            offset_start_end = offsets[all_mask_indices[i][0].item(), all_mask_indices[i][1].item(), :]
            masked_lm_labels[
                all_mask_indices[i][0],
                offset_start_end[0].item() : offset_start_end[1].item() + 1,
            ] = token_ids[
                all_mask_indices[i][0],
                offset_start_end[0].item() : offset_start_end[1].item() + 1,
            ]

    # Shape: [batch_size, num_wordpieces, embedding_size].
    embeddings, masked_lm_loss = self._matched_embedder(
        masked_token_ids,
        wordpiece_mask,
        type_ids=type_ids,
        segment_concat_mask=segment_concat_mask,
        masked_lm_labels=masked_lm_labels,
    )

    # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
    # span_mask: (batch_size, num_orig_tokens, max_span_length)
    span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
    span_mask = span_mask.unsqueeze(-1)
    span_embeddings *= span_mask  # zero out paddings

    span_embeddings_sum = span_embeddings.sum(2)
    span_embeddings_len = span_mask.sum(2)
    # Shape: (batch_size, num_orig_tokens, embedding_size)
    orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

    # All the places where the span length is zero, write in zeros.
    orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

    if activate_masking:
        return orig_embeddings, masked_lm_loss
    return orig_embeddings
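# NOTE: The masked-LM `forward` above selects words to corrupt at the *original token* level and
# then applies the choice to every wordpiece in that word's span (whole-word masking), following
# the usual BERT recipe: of the selected words, a fraction becomes the mask token, a fraction
# becomes random ids, and the rest are kept unchanged, while labels store the true ids only for
# selected words. A minimal sketch of that decision rule for a single sequence; all ids, names,
# and probabilities below are toy, illustrative values, not taken from the embedder above.
import torch

torch.manual_seed(0)
toy_token_ids = torch.tensor([101, 7592, 2088, 2571, 3899, 102])   # [CLS] w1 w2a w2b w3 [SEP]
toy_offsets = torch.tensor([[1, 1], [2, 3], [4, 4]])               # 3 words -> inclusive wordpiece spans
mask_token_id, vocab_size = 103, 30000
p_select, p_mask, p_random = 0.15, 0.8, 0.1

labels = torch.full_like(toy_token_ids, -100)                      # -100 = ignored by the LM loss
corrupted = toy_token_ids.clone()
word_probs = torch.rand(toy_offsets.size(0))                       # one draw per whole word
for word, prob in enumerate(word_probs):
    if prob >= p_select:
        continue                                                   # this word is left alone entirely
    start, end = toy_offsets[word].tolist()
    labels[start : end + 1] = toy_token_ids[start : end + 1]       # predict the original pieces
    inner = torch.rand(())                                         # 80/10/10 split among selected words
    if inner < p_mask:
        corrupted[start : end + 1] = mask_token_id
    elif inner < p_mask + p_random:
        corrupted[start : end + 1] = torch.randint(vocab_size, (end - start + 1,))
    # else: keep the original wordpieces (labels are still set, as in the forward pass above)

print(corrupted, labels)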
def forward( self, token_ids: torch.LongTensor, mask: torch.BoolTensor, offsets: torch.LongTensor, wordpiece_mask: torch.BoolTensor, type_ids: Optional[torch.LongTensor] = None, segment_concat_mask: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: # type: ignore """ # Parameters token_ids: `torch.LongTensor` Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`). mask: `torch.BoolTensor` Shape: [batch_size, num_orig_tokens]. offsets: `torch.LongTensor` Shape: [batch_size, num_orig_tokens, 2]. Maps indices for the original tokens, i.e. those given as input to the indexer, to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]` corresponds to the original j-th token from the i-th batch. wordpiece_mask: `torch.BoolTensor` Shape: [batch_size, num_wordpieces]. type_ids: `Optional[torch.LongTensor]` Shape: [batch_size, num_wordpieces]. segment_concat_mask: `Optional[torch.BoolTensor]` See `PretrainedTransformerEmbedder`. # Returns `torch.Tensor` Shape: [batch_size, num_orig_tokens, embedding_size]. """ # Shape: [batch_size, num_wordpieces, embedding_size]. embeddings = self._matched_embedder( token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask) # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size) # span_mask: (batch_size, num_orig_tokens, max_span_length) span_embeddings, span_mask = util.batched_span_select( embeddings.contiguous(), offsets) span_mask = span_mask.unsqueeze(-1) # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size) span_embeddings *= span_mask # zero out paddings # If "sub_token_mode" is set to "first", return the first sub-token embedding if self.sub_token_mode == "first": # Select first sub-token embeddings from span embeddings # Shape: (batch_size, num_orig_tokens, embedding_size) orig_embeddings = span_embeddings[:, :, 0, :] # If "sub_token_mode" is set to "avg", return the average of embeddings of all sub-tokens of a word elif self.sub_token_mode == "avg": # Sum over embeddings of all sub-tokens of a word # Shape: (batch_size, num_orig_tokens, embedding_size) span_embeddings_sum = span_embeddings.sum(2) # Shape (batch_size, num_orig_tokens) span_embeddings_len = span_mask.sum(2) # Find the average of sub-tokens embeddings by dividing `span_embedding_sum` by `span_embedding_len` # Shape: (batch_size, num_orig_tokens, embedding_size) orig_embeddings = span_embeddings_sum / torch.clamp_min( span_embeddings_len, 1) # All the places where the span length is zero, write in zeros. orig_embeddings[(span_embeddings_len == 0).expand( orig_embeddings.shape)] = 0 # If invalid "sub_token_mode" is provided, throw error else: raise ConfigurationError( f"Do not recognise 'sub_token_mode' {self.sub_token_mode}") return orig_embeddings
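# NOTE: The `forward` above supports two ways of turning each word's wordpiece span back into a
# single vector: "first" keeps only the first sub-token embedding, while "avg" is the masked
# average used throughout this file. A small toy comparison of the two modes on a hand-made span
# tensor (illustrative values only, no model involved):
import torch

span_embeddings = torch.tensor([[[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]]])  # (1 batch, 1 word, 3 pieces, dim 2)
span_mask = torch.tensor([[[True, True, False]]]).unsqueeze(-1)           # third piece is padding

span_embeddings = span_embeddings * span_mask                             # zero out the padded piece

first = span_embeddings[:, :, 0, :]                                       # "first" -> tensor([[[1., 1.]]])
avg = span_embeddings.sum(2) / torch.clamp_min(span_mask.sum(2), 1)       # "avg"   -> tensor([[[2., 2.]]])
print(first, avg)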