Example #1
# Assumes `tokenizer` (a BertTokenizer), `vocab` (an AllenNLP Vocabulary sharing the
# tokenizer's vocabulary) and `tensors` (a list of TextFieldTensors instances) already exist.
import torch
from transformers import DataCollatorForWholeWordMask

collator = DataCollatorForWholeWordMask(tokenizer=tokenizer)

# Stack the wordpiece ids of two instances into a (2, seq_len) batch.
ids = torch.cat([tensors[0]['tokens']['tokens']['token_ids'].unsqueeze(0),
                 tensors[1]['tokens']['tokens']['token_ids'].unsqueeze(0)], dim=0)
ids.shape

# `_whole_word_mask` expects the wordpiece strings of a single sequence,
# so build the whole-word mask one instance at a time.
wwms = []
for row in ids:
    tokens = [vocab.get_token_from_index(wp_id.item()) for wp_id in row]
    wwm = torch.tensor(collator._whole_word_mask(tokens)).unsqueeze(0)
    wwms.append(wwm)
wwms = torch.cat(wwms, dim=0)
wwms

# Apply the MLM masking; positions that were not selected get label -100.
# (Newer transformers releases call this method `torch_mask_tokens`.)
masked_ids, labels = collator.mask_tokens(ids, wwms)
masked_ids
labels
print([vocab.get_token_from_index(wp_id.item()) for wp_id in masked_ids[0]])

tensors[0]


# Toy example: recover which positions were actually masked from an MLM labels tensor.
import torch

labels = torch.tensor([[-100, 1, -100, -100], [-100, -100, 2, 0]])
not_modified_mask = (labels == -100)  # positions the collator left untouched
padding_mask = (labels == 0)          # positions whose label is the [PAD] id (0)
padding_mask
not_modified_mask
# Positions that were genuinely masked: neither padding nor untouched.
mask = ~(padding_mask | not_modified_mask)
mask
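A minimal sketch (reusing the toy `labels` and `mask` tensors above) of how such a mask can be used to keep only the targets of the positions that were actually masked:

masked_targets = labels[mask]   # tensor([1, 2]) for the toy labels above
num_masked = mask.sum().item()  # 2 masked positions in this batch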
Example #2
# Imports assumed for this example (AllenNLP 2.x and Hugging Face transformers).
from typing import Any, Dict

import torch
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.modules.backbones import Backbone
from allennlp.nn import util
from overrides import overrides
from transformers import BertConfig, BertModel, BertTokenizer, DataCollatorForWholeWordMask


class BertBackbone(Backbone):
    def __init__(
        self,
        vocab: Vocabulary,
        embedding_dim: int,
        feedforward_dim: int,
        num_layers: int,
        num_attention_heads: int,
        position_embedding_dim: int,
        tokenizer_path: str,
        position_embedding_type: str = "absolute",
        activation: str = "gelu",
        hidden_dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # TODO:
        # - Need to apply corrections in pretrained_transformer_mismatched_embedder

        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        vocab.add_transformer_vocab(tokenizer, "tokens")
        # "tokens" is padded by default--undo that
        del vocab._token_to_index["tokens"]["@@PADDING@@"]
        del vocab._token_to_index["tokens"]["@@UNKNOWN@@"]
        assert len(vocab._token_to_index["tokens"]) == len(vocab._index_to_token["tokens"])

        cfg = BertConfig(
            vocab_size=vocab.get_vocab_size("tokens"),
            hidden_size=embedding_dim,
            num_hidden_layers=num_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=feedforward_dim,
            hidden_act=activation,
            hidden_dropout_prob=hidden_dropout,
            max_position_embeddings=position_embedding_dim,
            position_embedding_type=position_embedding_type,
            use_cache=True,
        )
        self.cfg = cfg
        self._vocab = vocab
        self._namespace = "tokens"
        self.bert = BertModel(cfg)
        self.masking_collator = DataCollatorForWholeWordMask(
            tokenizer=tokenizer, mlm=True, mlm_probability=0.15
        )

    def _embed(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:
        """
        This implementation is borrowed from `PretrainedTransformerMismatchedEmbedder` and uses
        average pooling to yield a de-wordpieced embedding for each original token.
        Returns the wordpiece embeddings and mask as well as the original-token embeddings and mask.
        """
        output = self.bert(
            input_ids=text['tokens']['token_ids'],
            attention_mask=text["tokens"]["wordpiece_mask"],
            token_type_ids=text['tokens']['type_ids'],
        )
        wordpiece_embeddings = output.last_hidden_state
        offsets = text['tokens']['offsets']

        # Assemble wordpiece embeddings into embeddings for each word using average pooling
        span_embeddings, span_mask = util.batched_span_select(wordpiece_embeddings.contiguous(), offsets)  # type: ignore
        span_mask = span_mask.unsqueeze(-1)
        # Shape: (batch_size, num_orig_tokens, max_span_length, embedding_size)
        span_embeddings *= span_mask  # zero out paddings
        # Sum over the embeddings of all sub-tokens of a word
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        span_embeddings_sum = span_embeddings.sum(2)
        # Shape: (batch_size, num_orig_tokens, 1)
        span_embeddings_len = span_mask.sum(2)
        # Average the sub-token embeddings by dividing `span_embeddings_sum` by `span_embeddings_len`
        # Shape: (batch_size, num_orig_tokens, embedding_size)
        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)
        # All the places where the span length is zero, write in zeros.
        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0

        return {
            "wordpiece_mask": text['tokens']['wordpiece_mask'],
            "wordpiece_embeddings": wordpiece_embeddings,
            "orig_mask": text['tokens']['mask'],
            "orig_embeddings": orig_embeddings
        }

    def forward(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:  # type: ignore
        bert_output = self._embed(text)

        outputs = {
            "encoded_text": bert_output['orig_embeddings'],
            "encoded_text_mask": bert_output['orig_mask'],
            "wordpiece_encoded_text": bert_output['wordpiece_embeddings'],
            "wordpiece_encoded_text_mask": bert_output['wordpiece_mask'],
            "token_ids": util.get_token_ids_from_text_field_tensors(text),
        }

        self._extend_with_masked_text(outputs, text)
        return outputs

    def _extend_with_masked_text(self, outputs: Dict[str, Any], text: TextFieldTensors) -> None:
        input_ids = text['tokens']['token_ids']

        # Build the binary whole-word mask that tells us which wordpieces to mask;
        # it is sampled randomly on every forward pass.
        wwms = []
        for row in input_ids:
            tokens = [self._vocab.get_token_from_index(wp_id.item()) for wp_id in row]
            wwm = torch.tensor(self.masking_collator._whole_word_mask(tokens)).unsqueeze(0)
            wwms.append(wwm)
        wwms = torch.cat(wwms, dim=0)

        # Run the masking on CPU and move the results back to the input device.
        masked_ids, labels = self.masking_collator.mask_tokens(input_ids.to('cpu'), wwms.to('cpu'))
        masked_ids = masked_ids.to(input_ids.device)
        labels = labels.to(input_ids.device)
        bert_output = self.bert(
            input_ids=masked_ids,
            attention_mask=text["tokens"]["wordpiece_mask"],
            token_type_ids=text['tokens']['type_ids'],
        )
        outputs["encoded_masked_text"] = bert_output.last_hidden_state
        outputs["masked_text_labels"] = labels

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        tokens = []
        for instance_tokens in output_dict["token_ids"]:
            tokens.append(
                [
                    self._vocab.get_token_from_index(token_id.item(), namespace=self._namespace)
                    for token_id in instance_tokens
                ]
            )
        output_dict["tokens"] = tokens
        del output_dict["token_ids"]
        del output_dict["encoded_text"]
        del output_dict["encoded_text_mask"]
        del output_dict["encoded_masked_text"]
        return output_dict
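
A minimal usage sketch of how this backbone might be wired up; the constructor values and the `bert-base-uncased` tokenizer path are illustrative assumptions, and `text` stands for a `TextFieldTensors` batch produced by an AllenNLP data loader:

from allennlp.data import Vocabulary

vocab = Vocabulary()
backbone = BertBackbone(
    vocab=vocab,
    embedding_dim=256,
    feedforward_dim=1024,
    num_layers=4,
    num_attention_heads=4,
    position_embedding_dim=512,          # becomes max_position_embeddings
    tokenizer_path="bert-base-uncased",  # any BertTokenizer directory or model name
)

# `text` would come from a data loader whose reader indexes the field with a
# PretrainedTransformerMismatchedIndexer, so that `token_ids`, `mask`, `type_ids`,
# `offsets` and `wordpiece_mask` are all present.
# outputs = backbone(text)
# outputs["encoded_text"]          # (batch, num_orig_tokens, embedding_dim)
# outputs["encoded_masked_text"]   # (batch, num_wordpieces, embedding_dim)
# outputs["masked_text_labels"]    # (batch, num_wordpieces); -100 at unmasked positions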