import logging
from typing import List, Union

import torch
from torch.nn import Identity
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer


class EmbExtractor:
    def __init__(self,
                 model_name: str,
                 sentence_transformer: bool,
                 gpu: bool,
                 fp16: bool,
                 pooling: str,
                 without_encoding: bool,
                 use_mlm_head: bool,
                 use_mlm_head_without_layernorm: bool):
        self._sentence_transformer = sentence_transformer
        self._gpu = gpu
        self._fp16 = fp16
        self._pooling = pooling
        self._without_encoding = without_encoding
        self._use_mlm_head = use_mlm_head
        self._use_mlm_head_without_layernorm = use_mlm_head_without_layernorm
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self._sentence_transformer:
            self._model = SentenceTransformer(model_name)
        else:
            if self._pooling == "mask" or self._use_mlm_head:
                # Mask pooling and MLM-head pooling read per-layer hidden
                # states, so load the checkpoint with its MLM head attached.
                self._model = AutoModelForMaskedLM.from_pretrained(model_name)
                self._model.config.output_hidden_states = True
            else:
                self._model = AutoModel.from_pretrained(model_name)
                if self._without_encoding:
                    # The embedding layer's output is read from
                    # hidden_states, so they must be returned.
                    self._model.config.output_hidden_states = True
        if self._gpu:
            self._model.cuda()
        if self._fp16:
            self._model.half()

    def extract_emb(self, lines: Union[str, List[str]]):
        if not isinstance(lines, list):
            lines = [lines]

        if self._sentence_transformer:
            # Shape: (batch_size, num_embs)
            sentence_embedding = self._model.encode(lines)
            return sentence_embedding

        encoded_input = self._tokenizer(
            lines,
            truncation=True,
            padding=True,
            pad_to_multiple_of=8,
            return_tensors='pt',
            return_special_tokens_mask=True)
        if self._gpu:
            encoded_input = {k: v.cuda() for k, v in encoded_input.items()}

        # Invert the tokenizer's mask so special tokens become zeros.
        # Shape: (batch_size, num_tokens, 1)
        special_tokens_mask = (
            1 - encoded_input.pop("special_tokens_mask").unsqueeze(dim=-1))

        if self._use_mlm_head:
            # Drop the final projection to vocabulary logits; keep the MLM
            # head's dense + layer-norm transform (RoBERTa-style `lm_head`).
            self._model.lm_head.decoder = Identity()
        if self._use_mlm_head_without_layernorm:
            # Also drop the MLM head's layer norm. The attribute is named
            # `layer_norm` on RoBERTa-style heads.
            self._model.lm_head.layer_norm = Identity()

        with torch.no_grad():
            outputs = self._model(**encoded_input)

        pooling = "mask" if self._use_mlm_head else self._pooling
        if pooling == "mask":
            assert not self._without_encoding
            # Shape: (batch_size, num_tokens, num_embs)
            output = outputs["hidden_states"][-1]
            if self._use_mlm_head:
                with torch.no_grad():
                    # Shape: (batch_size, num_tokens, num_embs)
                    output = self._model.lm_head(output)
            # Shape: (batch_size, num_embs) - <mask> is the 2nd token
            sentence_embedding = output[:, 1, :]
            # ...
        elif pooling == "cls":
            # Shape: (batch_size, num_tokens, num_embs)
            output = outputs["last_hidden_state"]
            # Shape: (batch_size, num_embs)
            sentence_embedding = output[:, 0, :]
        else:
            if self._without_encoding:
                # Use the embedding layer's output, i.e. the hidden states
                # before any encoder layer is applied.
                # Shape: (batch_size, num_tokens, num_embs)
                output = outputs["hidden_states"][0] * special_tokens_mask
            else:
                # Shape: (batch_size, num_tokens, num_embs)
                output = outputs["last_hidden_state"] * special_tokens_mask
            if pooling == 'avg':
                # Masked mean: sum over tokens, then divide by the number
                # of non-special tokens per example.
                # Shape: (batch_size, num_embs)
                output_masked = torch.sum(output, dim=1)
                # Shape: (batch_size, 1)
                non_zeros_n = torch.sum(special_tokens_mask, dim=1)
                # Shape: (batch_size, num_embs)
                sentence_embedding = output_masked / non_zeros_n
            elif pooling == 'max':
                # Shape: (batch_size, num_embs)
                sentence_embedding = output.max(dim=1).values
            else:
                logging.critical(" - pooling method does not exist")
                raise ValueError(f"unknown pooling method: {pooling}")

        return sentence_embedding.float().cpu().numpy()
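
# A minimal usage sketch, not part of the original module: it assumes a
# RoBERTa-style checkpoint ("roberta-base" is an illustrative choice) and
# CPU execution; the flag combination shown is just one configuration.
if __name__ == "__main__":
    extractor = EmbExtractor(
        model_name="roberta-base",
        sentence_transformer=False,
        gpu=False,
        fp16=False,
        pooling="avg",
        without_encoding=False,
        use_mlm_head=False,
        use_mlm_head_without_layernorm=False)
    # extract_emb accepts a single string or a list of strings and returns
    # a float32 numpy array of shape (batch_size, num_embs).
    embeddings = extractor.extract_emb(["hello world", "sentence embeddings"])
    print(embeddings.shape)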