Example no. 1
# Imports for this example; the BERT classes come from the legacy
# pytorch-pretrained-bert package, whose API the code below targets.
import torch
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM, BertTokenizer


class Classifier(torch.nn.Module):
    def __init__(self, hidden_size=768, linear_out=2, batch_first=True):

        super(Classifier, self).__init__()

        # Load a fine-tuned BERT language model from the local "lm" directory.
        self.output_model_file = "lm/pytorch_model.bin"
        self.output_config_file = "lm/config.json"
        self.tokenizer = BertTokenizer.from_pretrained("lm",
                                                       do_lower_case=False)
        self.config = BertConfig.from_json_file(self.output_config_file)
        self.model = BertForMaskedLM(self.config)
        # Keep the device on the instance: the original bound it to a local
        # variable, leaving `device` undefined in the methods below.
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.state_dict = torch.load(self.output_model_file,
                                     map_location=self.device)
        self.model.load_state_dict(self.state_dict)
        self.model.to(self.device)
        # LSTM over the BERT token embeddings, then a linear classifier head.
        self.lstm = torch.nn.LSTM(hidden_size, 300)
        self.linear = torch.nn.Linear(300, linear_out)

    def get_embeddings(self, x_instance):
        indexed_tokens = x_instance.tolist()
        # 102 is the [SEP] token id: everything up to and including the first
        # [SEP] belongs to segment 0, the rest to segment 1.
        break_sentence = indexed_tokens.index(102)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_ids = [0] * (break_sentence + 1)
        segments_ids += [1] * (len(indexed_tokens) - break_sentence - 1)
        segments_tensors = torch.tensor([segments_ids])
        self.model.eval()
        with torch.no_grad():
            encoded_layers, _ = self.model.bert(
                tokens_tensor.to(self.device),
                segments_tensors.to(self.device))
        # Stack the per-layer outputs: [layers, batch, tokens, hidden].
        token_embeddings = torch.stack(encoded_layers, dim=0)
        # Drop the batch dimension and reorder to [tokens, layers, hidden].
        token_embeddings = torch.squeeze(token_embeddings, dim=1)
        token_embeddings = token_embeddings.permute(1, 0, 2)
        # Each token vector is the mean of the last four encoder layers.
        token_vecs = []
        for token in token_embeddings:
            last_four = torch.stack((token[-1], token[-2], token[-3], token[-4]))
            mean_vec = torch.mean(last_four, 0)
            token_vecs.append(mean_vec)
        token_vecs = torch.stack(token_vecs, dim=0)
        return token_vecs

    def embed_data(self, x):
        # Embed each instance in the batch and stack into a single tensor.
        entries = []
        for entry in x:
            emb = self.get_embeddings(entry.to(self.device))
            entries.append(emb)
        return torch.stack(entries)

    def forward(self, x):

        h = self.embed_data(x)
        # The LSTM was built with the default batch_first=False, so reorder
        # to [seq_len, batch, features] before the recurrence.
        h = h.permute(1, 0, 2)
        output, _ = self.lstm(h)
        pred = self.linear(output)
        # Back to [batch, seq_len, classes].
        pred = pred.permute(1, 0, 2)
        return pred
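
A minimal usage sketch for the classifier above, assuming the fine-tuned
checkpoint actually exists under lm/; the toy token ids are placeholders
(in practice they would come from self.tokenizer) and only illustrate the
expected input shape.

# Hypothetical driver: a batch of two pre-tokenized id sequences, each
# starting with [CLS] (101) and containing [SEP] (102) separators.
clf = Classifier()
batch = torch.tensor([
    [101, 7592, 2088, 102, 2009, 2003, 102],
    [101, 2026, 3899, 102, 2003, 2307, 102],
])
logits = clf(batch)
print(logits.shape)  # [batch, seq_len, linear_out] -> (2, 7, 2)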
Example no. 2
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids,
                              input_mask, sequence_labels, token_labels,
                              choice_labels):
    # Helper from a model-test harness: build a BertForMaskedLM, then run it
    # once with labels (returning the masked-LM loss) and once without
    # (returning per-token prediction scores over the vocabulary).
    model = BertForMaskedLM(config=config)
    model.eval()
    loss = model(input_ids, token_type_ids, input_mask, token_labels)
    prediction_scores = model(input_ids, token_type_ids, input_mask)
    outputs = {
        "loss": loss,
        "prediction_scores": prediction_scores,
    }
    return outputs
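
A sketch of how this helper could be driven standalone, assuming the legacy
pytorch-pretrained-bert BertConfig constructor; the tiny configuration and
random tensors below are invented purely so the call runs quickly on CPU.

import torch
from pytorch_pretrained_bert import BertConfig

# Hypothetical tiny config: 2 layers, 4 heads, hidden size 32, vocab 100.
config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=64)
input_ids = torch.randint(0, 100, (2, 7))
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)
token_labels = torch.randint(0, 100, (2, 7))
# self is unused by the helper, so None stands in for the test-class instance.
outputs = create_bert_for_masked_lm(None, config, input_ids, token_type_ids,
                                    input_mask, None, token_labels, None)
print(outputs["prediction_scores"].shape)  # (2, 7, 100)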
Example no. 3
import random

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_pretrained_bert import BertForMaskedLM, BertTokenizer


def get_words_for_blank_slow_decode(text: str, model: BertForMaskedLM, tokenizer: BertTokenizer):
    # Fill every '_' blank greedily: predict all remaining masks, commit the
    # single most confident prediction, and repeat until no masks are left.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    mask_positions = []
    tokenized_text = tokenizer.tokenize(text)
    top_words_all = []
    for i in range(len(tokenized_text)):
        if tokenized_text[i] == '_':
            tokenized_text[i] = '[MASK]'
            mask_positions.append(i)

    while mask_positions:
        top_words = []
        # Convert tokens to vocab indices
        token_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([token_ids])

        # Call BERT to calculate unnormalized scores for all positions
        model.eval()
        with torch.no_grad():
            predictions = model(tokens_tensor)

        candidates = []  # (word, score)
        for mask_pos in mask_positions:
            mask_preds = predictions[0, mask_pos, :]

            # Vocabulary indices sorted by descending score
            top_idxs = mask_preds.numpy().argsort()[::-1]
            top_idx = top_idxs[0]
            top_prob = mask_preds[top_idx]
            top_word = tokenizer.ids_to_tokens[top_idx]
            candidates.append((top_word, top_prob.item()))
            # Keep the 20 best words for this position, for inspection
            top_words_pos = []
            for i in top_idxs[:20]:
                top_words_pos.append((tokenizer.ids_to_tokens[i], mask_preds[i].item()))
            top_words.append(top_words_pos)

        # Commit the single most confident prediction and drop its mask
        best_candidate = max(candidates, key=lambda x: x[1])
        best_pos = mask_positions[candidates.index(best_candidate)]

        tokenized_text[best_pos] = best_candidate[0]
        mask_positions = [i for i in mask_positions if i != best_pos]

        top_words_all.append(top_words[candidates.index(best_candidate)])

    # Re-join WordPiece sub-tokens ('##') into whole words
    pred_sent = ' '.join(tokenized_text).replace(' ##', '')
    return (pred_sent, top_words_all)
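
A quick driver for the greedy decoder, assuming the legacy package's
downloadable bert-base-uncased weights; the example sentence is invented.

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

sentence, top_words = get_words_for_blank_slow_decode(
    "the _ sat on the _ .", model, tokenizer)
print(sentence)
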
def predict_word(text: str, model: BertForMaskedLM, tokenizer: BertTokenizer, tgt_word: str, tgt_pos: int):
    # Return the softmax probability of tgt_word at position tgt_pos, plus
    # the template sentence with every blank filled by its argmax token.
    mask_positions = []

    # Replace each '_' blank with the [MASK] token
    tokenized_text = tokenizer.tokenize(text)

    for i in range(len(tokenized_text)):
        if tokenized_text[i] == '_':
            tokenized_text[i] = '[MASK]'
            mask_positions.append(i)

    # Convert tokens to vocab indices
    token_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([token_ids])

    # Call BERT to calculate unnormalized scores for all positions
    model.eval()
    with torch.no_grad():
        predictions = model(tokens_tensor)

    # Normalize the scores over the vocabulary with a softmax
    predictions = F.softmax(predictions, dim=2)

    # For the target position, read off the probability of the target word
    normalized = predictions[0, tgt_pos, :]
    out_prob = normalized[tokenizer.vocab[tgt_word]].item()

    # Also fill in all blanks by max probability, for inspection
    for mask_pos in mask_positions:
        predicted_index = torch.argmax(predictions[0, mask_pos, :]).item()
        predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
        tokenized_text[mask_pos] = predicted_token

    # Mark the filled-in positions so they stand out in the output
    for mask_pos in mask_positions:
        tokenized_text[mask_pos] = "_" + tokenized_text[mask_pos] + "_"
    pred_sent = ' '.join(tokenized_text).replace(' ##', '')
    return out_prob, pred_sent
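
Reusing the model and tokenizer loaded above, predict_word can compare
candidate words for the same blank; the sentence, words, and position here
are again invented for illustration.

prob_dog, filled = predict_word("the _ barked at the mailman .",
                                model, tokenizer, tgt_word="dog", tgt_pos=1)
prob_cat, _ = predict_word("the _ barked at the mailman .",
                           model, tokenizer, tgt_word="cat", tgt_pos=1)
print(filled)
print("P(dog) = %.4f  P(cat) = %.4f" % (prob_dog, prob_cat))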