Example #1
import logging
from typing import List, Union

import torch
from sentence_transformers import SentenceTransformer
from torch.nn import Identity
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer


class EmbExtractor:
    def __init__(self, model_name: str, sentence_transformer: bool, gpu: bool,
                 fp16: bool, pooling: str, without_encoding: bool,
                 use_mlm_head: bool, use_mlm_head_without_layernorm: bool):

        self._sentence_transformer = sentence_transformer
        self._gpu = gpu
        self._fp16 = fp16
        self._pooling = pooling
        self._without_encoding = without_encoding
        self._use_mlm_head = use_mlm_head
        self._use_mlm_head_without_layernorm = use_mlm_head_without_layernorm

        self._tokenizer = AutoTokenizer.from_pretrained(model_name)

        if self._sentence_transformer:
            # SentenceTransformer models handle tokenization and pooling
            # internally, so no manual pooling is needed later.
            self._model = SentenceTransformer(model_name)
        else:
            # Mask pooling and the MLM head need the masked-LM variant of the
            # model together with its hidden states.
            if self._pooling == "mask" or self._use_mlm_head:
                self._model = AutoModelForMaskedLM.from_pretrained(model_name)
                self._model.config.output_hidden_states = True
            else:
                self._model = AutoModel.from_pretrained(model_name)

        if self._gpu:
            self._model.cuda()
        if self._fp16:
            self._model.half()

    def extract_emb(self, lines: Union[str, List[str]]):

        if not isinstance(lines, list):
            lines = [lines]

        if self._sentence_transformer:
            # Shape: (batch_size, num_embs)
            sentence_embedding = self._model.encode(lines)

            return sentence_embedding
        else:
            encoded_input = self._tokenizer.batch_encode_plus(
                lines,
                truncation=True,
                padding=True,
                pad_to_multiple_of=8,
                return_tensors='pt',
                return_special_tokens_mask=True)
            if self._gpu:
                encoded_input = {k: v.cuda() for k, v in encoded_input.items()}

            # Invert the mask so regular tokens are 1 and special tokens are 0.
            # Shape: (batch_size, num_tokens, 1)
            special_tokens_mask = (
                1 -
                encoded_input.pop("special_tokens_mask").unsqueeze(dim=-1))

            if self._use_mlm_head:
                # Replace the vocabulary projection (and optionally the layer
                # norm) with identities so the MLM head outputs embeddings
                # instead of vocabulary logits.
                self._model.lm_head.decoder = Identity()
                if self._use_mlm_head_without_layernorm:
                    self._model.lm_head.lm_head_norm = Identity()

            with torch.no_grad():
                outputs = self._model(**encoded_input)

            # The MLM-head path always pools from the <mask> token.
            if self._use_mlm_head:
                self._pooling = "mask"

            if self._pooling == "mask":
                assert not self._without_encoding
                # Shape: (batch_size, num_tokens, num_embs)
                output = outputs["hidden_states"][-1]

                if self._use_mlm_head:
                    with torch.no_grad():
                        # Shape: (batch_size, num_tokens, num_embs)
                        output = self._model.lm_head(output)
                # Shape: (batch_size, num_embs) - <mask> is the 2nd token
                sentence_embedding = output[:, 1, :]
                # ...
            elif self._pooling == "cls":
                # Shape: (batch_size, num_tokens, num_embs)
                output = outputs["last_hidden_state"]
                # Shape: (batch_size, num_embs)
                sentence_embedding = output[:, 0, :]
            else:

                if self._without_encoding:
                    # Shape: (batch_size, num_embs)
                    output = outputs["last_hidden_state"][
                        0] * special_tokens_mask
                else:
                    # Shape: (batch_size, num_tokens, num_embs)
                    output = outputs["last_hidden_state"] * special_tokens_mask

                if self._pooling == 'avg':
                    # Shape: (batch_size, num_embs)
                    output_masked = torch.sum(output, dim=1)
                    # Shape: (batch_size, 1)
                    non_zeros_n = torch.sum(special_tokens_mask, dim=1)

                    # Shape: (batch_size, num_embs)
                    sentence_embedding = output_masked / non_zeros_n
                elif self._pooling == 'max':
                    # Max over the token dim returns a (values, indices) tuple.
                    output_masked = output.max(dim=1)

                    # Shape: (batch_size, num_embs)
                    sentence_embedding = output_masked.values
                else:
                    logging.critical(" - pooling method does not exist")
                    raise ValueError(
                        f"Unknown pooling method: {self._pooling}")

            return sentence_embedding.float().cpu().numpy()
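
A minimal usage sketch follows; the model name, flag values, and example sentences are illustrative assumptions rather than part of the original snippet.

# Usage sketch: assumed model name and flag values, shown for illustration.
extractor = EmbExtractor(
    model_name="bert-base-uncased",
    sentence_transformer=False,
    gpu=False,
    fp16=False,
    pooling="avg",
    without_encoding=False,
    use_mlm_head=False,
    use_mlm_head_without_layernorm=False)

# Returns a numpy array of sentence embeddings, one row per input line.
embeddings = extractor.extract_emb(["first sentence", "a second sentence"])
print(embeddings.shape)  # e.g. (2, 768) for a BERT-base-sized model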