Code example #1
    def get_word_embedding(
        self,
        token_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore
        """get [CLS] embedding and word-level embedding"""
        # Shape: [batch_size, num_wordpieces, embedding_size].
        # embed_model = self.bert if self.arch == "bert" else self.roberta
        # embeddings = embed_model(token_ids, token_type_ids=type_ids, attention_mask=wordpiece_mask)[0]
        embeddings = self.bert(token_ids,
                               token_type_ids=type_ids,
                               attention_mask=wordpiece_mask)[0]

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = allennlp_util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(
            span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return embeddings[:, 0, :], orig_embeddings
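All of these snippets revolve around `allennlp.nn.util.batched_span_select`. Below is a minimal, self-contained sketch (not taken from any of the projects here) of what it returns and how the average pooling above works; it assumes AllenNLP is installed, and the shapes and values are illustrative only.

import torch
from allennlp.nn import util as allennlp_util

batch_size, num_wordpieces, dim = 1, 5, 4
embeddings = torch.randn(batch_size, num_wordpieces, dim)
# Two original tokens: the first maps to wordpiece span [0, 0] (one piece),
# the second to wordpiece span [1, 2] (two pieces); spans are inclusive.
offsets = torch.tensor([[[0, 0], [1, 2]]])

span_embeddings, span_mask = allennlp_util.batched_span_select(embeddings, offsets)
# span_embeddings: (1, 2, max_span_width, 4); span_mask: (1, 2, max_span_width)
span_mask = span_mask.unsqueeze(-1)
span_embeddings = span_embeddings * span_mask  # zero out padded positions
word_embeddings = span_embeddings.sum(2) / torch.clamp_min(span_mask.sum(2), 1)
print(word_embeddings.shape)  # torch.Size([1, 2, 4])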
Code example #2
    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: Optional[torch.BoolTensor] = None,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # shape (batch_size, sequence_length, embedding_dim + 1)
        concat_tensor = torch.cat([sequence_tensor, global_attention_logits], -1)

        concat_output, span_mask = util.batched_span_select(concat_tensor, span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        return attended_text_embeddings
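Example #2 pools with attention instead of averaging. A toy sketch of the `masked_softmax` / `weighted_sum` step, under the same assumption that AllenNLP is installed; values are illustrative:

import torch
from allennlp.nn import util

batch_size, num_spans, max_width, dim = 1, 2, 3, 4
span_embeddings = torch.randn(batch_size, num_spans, max_width, dim)
span_attention_logits = torch.randn(batch_size, num_spans, max_width)
# The second span is only two wordpieces long; its last slot is padding.
span_mask = torch.tensor([[[True, True, True], [True, True, False]]])

# Padded positions get ~zero probability, so they do not contribute to the sum.
weights = util.masked_softmax(span_attention_logits, span_mask)
attended = util.weighted_sum(span_embeddings, weights)
print(attended.shape)  # torch.Size([1, 2, 4])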
Code example #3
    def _aggregate_token_embeddings(
            embeddings_list: List[torch.Tensor],
            token_offsets: List[torch.Tensor]) -> List[numpy.ndarray]:
        if len(token_offsets) == 0:
            return [embeddings.numpy() for embeddings in embeddings_list]
        aggregated_embeddings = []
        # NOTE: This is assuming that embeddings and offsets come in the same order, which may not
        # be true.  But, the intersection of using multiple TextFields with mismatched indexers is
        # currently zero, so we'll delay handling this corner case until it actually causes a
        # problem.  In practice, both of these lists will always be of size one at the moment.
        for embeddings, offsets in zip(embeddings_list, token_offsets):
            span_embeddings, span_mask = util.batched_span_select(
                embeddings.contiguous(), offsets)
            span_mask = span_mask.unsqueeze(-1)
            span_embeddings *= span_mask  # zero out paddings

            span_embeddings_sum = span_embeddings.sum(2)
            span_embeddings_len = span_mask.sum(2)
            # Shape: (batch_size, num_orig_tokens, embedding_size)
            embeddings = span_embeddings_sum / torch.clamp_min(
                span_embeddings_len, 1)

            # All the places where the span length is zero, write in zeros.
            embeddings[(span_embeddings_len == 0).expand(embeddings.shape)] = 0
            aggregated_embeddings.append(embeddings.numpy())
        return aggregated_embeddings
Code example #4
        def hook_layers(module, grad_in, grad_out):
            grads = grad_out[0]
            if self._token_offsets:
                # If you have a mismatched indexer with multiple TextFields, it's quite possible
                # that the order we deal with the gradients is wrong.  We'll just take items from
                # the list one at a time, and try to aggregate the gradients.  If we got the order
                # wrong, we should crash, so you'll know about it.  If you get an error because of
                # that, open an issue on github, and we'll see what we can do.  The intersection of
                # multiple TextFields and mismatched indexers is pretty small (currently empty, that
                # I know of), so we'll ignore this corner case until it's needed.
                offsets = self._token_offsets.pop(0)
                span_grads, span_mask = util.batched_span_select(
                    grads.contiguous(), offsets)
                span_mask = span_mask.unsqueeze(-1)
                span_grads *= span_mask  # zero out paddings

                span_grads_sum = span_grads.sum(2)
                span_grads_len = span_mask.sum(2)
                # Shape: (batch_size, num_orig_tokens, embedding_size)
                grads = span_grads_sum / torch.clamp_min(span_grads_len, 1)

                # All the places where the span length is zero, write in zeros.
                grads[(span_grads_len == 0).expand(grads.shape)] = 0

            embedding_gradients.append(grads)
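This closure is meant to be registered as a backward hook on the model's embedding layer, with `embedding_gradients` captured from the enclosing scope. A minimal sketch of the registration pattern; `embedding_layer`, `model`, and `inputs` are stand-ins for the real objects, not names from the snippet:

embedding_gradients = []  # captured by hook_layers above
handle = embedding_layer.register_full_backward_hook(hook_layers)  # PyTorch >= 1.8
try:
    outputs = model(**inputs)    # hypothetical forward pass producing a loss
    outputs["loss"].backward()   # triggers hook_layers with the embedding grads
finally:
    handle.remove()              # always detach the hook when done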
Code example #5
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: `torch.LongTensor`
            Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
        mask: `torch.BoolTensor`
            Shape: [batch_size, num_orig_tokens].
        offsets: `torch.LongTensor`
            Shape: [batch_size, num_orig_tokens, 2].
            Maps indices for the original tokens, i.e. those given as input to the indexer,
            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
            corresponds to the original j-th token from the i-th batch.
        wordpiece_mask: `torch.BoolTensor`
            Shape: [batch_size, num_wordpieces].
        type_ids: `Optional[torch.LongTensor]`
            Shape: [batch_size, num_wordpieces].
        segment_concat_mask: `Optional[torch.BoolTensor]`
            See `PretrainedTransformerEmbedder`.

        # Returns

        `torch.Tensor`
            Shape: [batch_size, num_orig_tokens, embedding_size].
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        if self.iter_norm:
            return self._matched_embedder.get_embeddings(
                token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
            )

        embeddings = self._matched_embedder(
            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
        )

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

        return orig_embeddings
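The division above clamps the span lengths, matching the current AllenNLP implementation. Without the clamp, zero-length spans divide 0 by 0, and the resulting NaN can leak into gradients even though the forward values are later overwritten with zeros. A toy illustration:

import torch

span_sum = torch.zeros(1, 1, 4)   # a zero-length span sums to zero
span_len = torch.zeros(1, 1, 1)

print(span_sum / span_len)                      # all NaN
print(span_sum / torch.clamp_min(span_len, 1))  # all zeros, no NaN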
Code example #6
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.LongTensor,
        type_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: torch.LongTensor
            Shape: [batch_size, num_wordpieces].
        mask: torch.LongTensor
            Shape: [batch_size, num_orig_tokens].
        offsets: torch.LongTensor
            Shape: [batch_size, num_orig_tokens, 2].
            Maps indices for the original tokens, i.e. those given as input to the indexer,
            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
            corresponds to the original j-th token from the i-th batch.
        wordpiece_mask: torch.LongTensor
            Shape: [batch_size, num_wordpieces].
        type_ids: Optional[torch.LongTensor]
            Shape: [batch_size, num_wordpieces]

        # Returns

        Shape: [batch_size, num_orig_tokens, embedding_size].
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self._matched_embedder(token_ids,
                                            wordpiece_mask,
                                            type_ids=type_ids)

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(
            embeddings, offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        return orig_embeddings
Code example #7
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        # shape (batch_size, sequence_length, 1)
        # NOTE: a zero weight vector makes every logit zero, i.e. uniform
        # attention over the unmasked positions of each span.
        global_attention_logits = torch.matmul(
            sequence_tensor,
            torch.zeros(self.input_dim, 1).to(sequence_tensor.device))

        # shape (batch_size, sequence_length, embedding_dim + 1)
        concat_tensor = torch.cat([sequence_tensor, global_attention_logits],
                                  -1)

        concat_output, span_mask = util.batched_span_select(
            concat_tensor, span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits,
                                                     span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings,
                                                     span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

        return attended_text_embeddings
Code example #8
    def _get_orig_token_embeddings(self, embeddings: torch.Tensor,
                                   offsets: torch.LongTensor):
        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(
            embeddings.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(
            orig_embeddings.shape)] = 0

        return orig_embeddings
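The final masked write is worth unpacking: `span_embeddings_len == 0` has shape (batch_size, num_orig_tokens, 1), and `expand` broadcasts it across the embedding dimension so the whole vector is zeroed at once. A toy illustration:

import torch

lens = torch.tensor([[[2], [0]]])  # (1, 2, 1): the second token has no wordpieces
emb = torch.ones(1, 2, 3)
emb[(lens == 0).expand(emb.shape)] = 0
print(emb)  # the second token's vector is now all zeros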
Code example #9
File: bert_backbone.py  Project: lgessler/embur
    def _embed(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:
        """
        This implementation is borrowed from `PretrainedTransformerMismatchedEmbedder` and uses
        average pooling to yield a de-wordpieced embedding for each original token.
        Returns both wordpiece embeddings+mask as well as original token embeddings+mask
        """
        output = self.bert(
            input_ids=text['tokens']['token_ids'],
            attention_mask=text["tokens"]["wordpiece_mask"],
            token_type_ids=text['tokens']['type_ids'],
        )
        wordpiece_embeddings = output.last_hidden_state
        offsets = text['tokens']['offsets']

        # Assemble wordpiece embeddings into embeddings for each word using average pooling
        span_embeddings, span_mask = util.batched_span_select(wordpiece_embeddings.contiguous(), offsets)  # type: ignore
        span_mask = span_mask.unsqueeze(-1)
        # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        span_embeddings *= span_mask  # zero out paddings
        # return the average of embeddings of all sub-tokens of a word
        # Sum over embeddings of all sub-tokens of a word
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        span_embeddings_sum = span_embeddings.sum(2)
        # Shape (batch_size, num_orig_tokens)
        span_embeddings_len = span_mask.sum(2)
        # Find the average of sub-tokens embeddings by dividing `span_embedding_sum` by `span_embedding_len`
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)
        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

        return {
            "wordpiece_mask": text['tokens']['wordpiece_mask'],
            "wordpiece_embeddings": wordpiece_embeddings,
            "orig_mask": text['tokens']['mask'],
            "orig_embeddings": orig_embeddings
        }
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
        masked_lm: Optional[List[bool]] = None
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: `torch.LongTensor`
            Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
        mask: `torch.BoolTensor`
            Shape: [batch_size, num_orig_tokens].
        offsets: `torch.LongTensor`
            Shape: [batch_size, num_orig_tokens, 2].
            Maps indices for the original tokens, i.e. those given as input to the indexer,
            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
            corresponds to the original j-th token from the i-th batch.
        wordpiece_mask: `torch.BoolTensor`
            Shape: [batch_size, num_wordpieces].
        type_ids: `Optional[torch.LongTensor]`
            Shape: [batch_size, num_wordpieces].
        segment_concat_mask: `Optional[torch.BoolTensor]`
            See `PretrainedTransformerEmbedder`.

        # Returns

        `torch.Tensor`
            Shape: [batch_size, num_orig_tokens, embedding_size].
        """
        masked_lm_labels = torch.full_like(token_ids, -100)  # -100 is ignored by the masked LM loss
        masked_token_ids = token_ids
        activate_masking = masked_lm is not None and any(masked_lm)
        if activate_masking:
            batch_size, num_orig_tokens = mask.shape
            masked_lm = torch.tensor(masked_lm, dtype=torch.bool).to(token_ids.device)
            mask_probs = torch.rand(mask.shape, device=mask.device)
            eligible = mask & masked_lm.unsqueeze(-1)
            mask_token_p = self._mask_probability * self._mask_token_probability
            mask_random_p = self._mask_probability * (
                self._mask_token_probability + self._mask_random_probability)
            # Words to replace with the [MASK] wordpiece.
            mask_token_choices = (mask_probs < mask_token_p) & eligible
            # Words to replace with a random wordpiece.
            mask_random_choices = ((mask_probs >= mask_token_p)
                                   & (mask_probs < mask_random_p)
                                   & eligible)
            # All words selected for the masked LM objective (including kept ones).
            all_mask_choices = (mask_probs < self._mask_probability) & eligible
            mask_token_indices = mask_token_choices.nonzero()
            mask_random_indices = mask_random_choices.nonzero()
            mask_random_values = torch.randint(
                low=0,
                high=self._matched_embedder.transformer_model.config.vocab_size,
                size=token_ids.shape,
                device=mask.device)
            all_mask_indices = all_mask_choices.nonzero()
            masked_token_ids = token_ids.clone()
            # Expand each selected word to its full wordpiece span via `offsets`.
            for batch_idx, token_idx in mask_token_indices.tolist():
                start, end = offsets[batch_idx, token_idx].tolist()
                masked_token_ids[batch_idx, start:end + 1] = self._matched_embedder._mask_token_id
            for batch_idx, token_idx in mask_random_indices.tolist():
                start, end = offsets[batch_idx, token_idx].tolist()
                masked_token_ids[batch_idx, start:end + 1] = mask_random_values[batch_idx, start:end + 1]
            for batch_idx, token_idx in all_mask_indices.tolist():
                start, end = offsets[batch_idx, token_idx].tolist()
                masked_lm_labels[batch_idx, start:end + 1] = token_ids[batch_idx, start:end + 1]
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings, masked_lm_loss = self._matched_embedder(
            masked_token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask, masked_lm_labels=masked_lm_labels
        )

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
        span_mask = span_mask.unsqueeze(-1)
        span_embeddings *= span_mask  # zero out paddings

        span_embeddings_sum = span_embeddings.sum(2)
        span_embeddings_len = span_mask.sum(2)
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

        if activate_masking:
            return orig_embeddings, masked_lm_loss
        return orig_embeddings
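The three index loops above implement BERT-style whole-word masking: each word is selected with probability `_mask_probability`, and of the selected words a `_mask_token_probability` fraction is replaced by [MASK], a `_mask_random_probability` fraction by a random wordpiece, and the rest kept unchanged. A toy sketch of the same split on flat token ids, for intuition only; the real code additionally maps each selected word to its wordpiece span via `offsets`, and the names and constants here are illustrative:

import torch

mask_prob, mask_token_prob, rand_prob = 0.15, 0.8, 0.1
mask_token_id, vocab_size = 103, 30522  # bert-base-uncased values, as an example
token_ids = torch.randint(1000, 2000, (2, 8))

probs = torch.rand(token_ids.shape)
selected = probs < mask_prob                    # positions in the LM objective
use_mask = probs < mask_prob * mask_token_prob  # to be replaced by [MASK]
use_rand = selected & ~use_mask & (probs < mask_prob * (mask_token_prob + rand_prob))
# everything else in `selected` keeps its original id ("keep" case)

labels = torch.full_like(token_ids, -100)       # -100 = ignored by the loss
labels[selected] = token_ids[selected]
masked = token_ids.clone()
masked[use_mask] = mask_token_id
masked[use_rand] = torch.randint(0, vocab_size, token_ids.shape)[use_rand]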
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: `torch.LongTensor`
            Shape: [batch_size, num_wordpieces] (for exception see `PretrainedTransformerEmbedder`).
        mask: `torch.BoolTensor`
            Shape: [batch_size, num_orig_tokens].
        offsets: `torch.LongTensor`
            Shape: [batch_size, num_orig_tokens, 2].
            Maps indices for the original tokens, i.e. those given as input to the indexer,
            to a span in token_ids. `token_ids[i][offsets[i][j][0]:offsets[i][j][1] + 1]`
            corresponds to the original j-th token from the i-th batch.
        wordpiece_mask: `torch.BoolTensor`
            Shape: [batch_size, num_wordpieces].
        type_ids: `Optional[torch.LongTensor]`
            Shape: [batch_size, num_wordpieces].
        segment_concat_mask: `Optional[torch.BoolTensor]`
            See `PretrainedTransformerEmbedder`.

        # Returns

        `torch.Tensor`
            Shape: [batch_size, num_orig_tokens, embedding_size].
        """
        # Shape: [batch_size, num_wordpieces, embedding_size].
        embeddings = self._matched_embedder(
            token_ids,
            wordpiece_mask,
            type_ids=type_ids,
            segment_concat_mask=segment_concat_mask)

        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        # span_mask: (batch_size, num_orig_tokens, max_span_length)
        span_embeddings, span_mask = util.batched_span_select(
            embeddings.contiguous(), offsets)

        span_mask = span_mask.unsqueeze(-1)

        # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        span_embeddings *= span_mask  # zero out paddings

        # If "sub_token_mode" is set to "first", return the first sub-token embedding
        if self.sub_token_mode == "first":
            # Select first sub-token embeddings from span embeddings
            # Shape: (batch_size, num_orig_tokens, embedding_size)
            orig_embeddings = span_embeddings[:, :, 0, :]

        # If "sub_token_mode" is set to "avg", return the average of embeddings of all sub-tokens of a word
        elif self.sub_token_mode == "avg":
            # Sum over embeddings of all sub-tokens of a word
            # Shape: (batch_size, num_orig_tokens, embedding_size)
            span_embeddings_sum = span_embeddings.sum(2)

            # Shape (batch_size, num_orig_tokens)
            span_embeddings_len = span_mask.sum(2)

            # Find the average of sub-tokens embeddings by dividing `span_embedding_sum` by `span_embedding_len`
            # Shape: (batch_size, num_orig_tokens, embedding_size)
            orig_embeddings = span_embeddings_sum / torch.clamp_min(
                span_embeddings_len, 1)

            # All the places where the span length is zero, write in zeros.
            orig_embeddings[(span_embeddings_len == 0).expand(
                orig_embeddings.shape)] = 0

        # If invalid "sub_token_mode" is provided, throw error
        else:
            raise ConfigurationError(
                f"Do not recognise 'sub_token_mode' {self.sub_token_mode}")

        return orig_embeddings
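The two `sub_token_mode` settings differ only in how a multi-wordpiece word is collapsed: "first" takes the leading wordpiece, while "avg" mean-pools the whole span. A toy comparison with illustrative values:

import torch

# One word split into two wordpieces with embeddings [1, 1] and [3, 3].
span_embeddings = torch.tensor([[[[1., 1.], [3., 3.]]]])  # (1, 1, 2, 2)
span_mask = torch.ones(1, 1, 2, 1)

first = span_embeddings[:, :, 0, :]                       # tensor([[[1., 1.]]])
summed = (span_embeddings * span_mask).sum(2)
avg = summed / torch.clamp_min(span_mask.sum(2), 1)       # tensor([[[2., 2.]]])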