import sys
import json

import torch
from tqdm import tqdm
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# KoBertTokenizer (providing encode_batch/decode_batch/token2idx) is a
# project-specific wrapper assumed to be importable from the surrounding repo.


def inference():
    step = sys.argv[1]
    encoder_config = BertConfig.from_pretrained("monologg/kobert")
    decoder_config = BertConfig.from_pretrained("monologg/kobert")
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        encoder_config, decoder_config)

    tokenizer = KoBertTokenizer()
    model = EncoderDecoderModel(config=config)
    ckpt = "model.pt"
    device = "cuda"

    model.load_state_dict(
        torch.load(f"saved/{ckpt}.{step}", map_location="cuda"),
        strict=True,
    )

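    # Half-precision weights on the GPU for faster, lower-memory inference.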
    model = model.half().eval().to(device)
    test_data = open("dataset/abstractive_test_v2.jsonl",
                     "r").read().splitlines()
    submission = open(f"submission_{step}.csv", "w")

    test_set = []
    for data in test_data:
        data = json.loads(data)
        article_original = data["article_original"]
        article_original = " ".join(article_original)
        news_id = data["id"]
        test_set.append((news_id, article_original))

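    # Summarize each article with beam search and write "<id>,<summary>" rows.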
    for i, (news_id, text) in tqdm(enumerate(test_set)):
        tokens = tokenizer.encode_batch([text], max_length=512)
        generated = model.generate(
            input_ids=tokens["input_ids"].to(device),
            attention_mask=tokens["attention_mask"].to(device),
            use_cache=True,
            bos_token_id=tokenizer.token2idx["[CLS]"],
            eos_token_id=tokenizer.token2idx["[SEP]"],
            pad_token_id=tokenizer.token2idx["[PAD]"],
            num_beams=12,
            do_sample=False,
            temperature=1.0,
            no_repeat_ngram_size=3,
            bad_words_ids=[[tokenizer.token2idx["[UNK]"]]],
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=512,
        )

        output = tokenizer.decode_batch(generated.tolist())[0]
        submission.write(f"{news_id},{output}\n")
        print(news_id, output)
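
# Hypothetical invocation, assuming a wrapper script calls inference():
#   python inference.py 10000   # loads saved/model.pt.10000, writes submission_10000.csv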


import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import (BertConfig, BertTokenizer, EncoderDecoderConfig,
                          EncoderDecoderModel)

# top_k_logits, calculate_metrics, cal_entropy and cal_length are helper
# functions assumed to be defined elsewhere in the surrounding project.


def sample_generate(top_k=50,
                    temperature=1.0,
                    model_path='/content/BERT checkpoints/model-9.pth',
                    gpu_id=0):
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    model.eval()

    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TEST DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=1)
    # ------------------------END LOAD TEST DATA--------------

    # ------------------------START GENERATING-------------------
    update_count = 0

    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    sentences = []
    meteor_scores = 0

    print('start generating....')
    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]

            encoder_input, decoder_input, mask_encoder_input, _ = batch

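            # Encode the source once; its hidden states act as cross-attention
            # memory ("past") for the decoder during sampling.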
            past, _ = encoder(encoder_input, mask_encoder_input)

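            # Seed the output with the first reference token (presumably [CLS]).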
            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop
            for i in range(100):
                logits = decoder(sentence, encoder_hidden_states=past)

                logits = logits[0][:, -1]
                logits = logits.squeeze(1) / temperature

                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
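                # 102 is the [SEP] id in the bert-base-uncased vocab: stop there.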
                if prev_pred[0][0] == 102:
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())

            encoder_input = encoder_input.squeeze(dim=0)
            encoder_input_num = (encoder_input != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(
                encoder_input[:encoder_input_num].tolist())

            decoder_input = decoder_input.squeeze(dim=0)
            decoder_input_num = (decoder_input != 0).sum()

            reference = tokenizer.convert_ids_to_tokens(
                decoder_input[:decoder_input_num].tolist())
            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {' '.join(inputs)}")
            print(f"output: {' '.join(reference)}")
            print(f"predict: {' '.join(predict)}")

            (temp_bleu_2, temp_bleu_4, temp_nist_2, temp_nist_4,
             temp_meteor_scores) = calculate_metrics(predict[1:-1],
                                                     reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4

            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
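
# Hypothetical invocation with the Colab-style defaults from the signature:
#   sample_generate(top_k=50, temperature=1.0,
#                   model_path='/content/BERT checkpoints/model-9.pth', gpu_id=0)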


import numpy as np
import torch
from transformers import (BertConfig, BertLMHeadModel, BertModel,
                          EncoderDecoderModel)
from Bio import pairwise2  # Biopython pairwise alignment


class PhonetizerModel:

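    # French phoneme symbols (SAMPA-like notation) mapped to ids; ids below 7
    # are left free for special tokens (0=start, 1=pad, 2=end in generate()).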
    phon_tokenizer = {
        'e': 7,
        'i': 8,
        'R': 9,
        'a': 10,
        'o': 11,
        't': 12,
        's': 13,
        'l': 14,
        'k': 15,
        'p': 16,
        'm': 17,
        'n': 18,
        'd': 19,
        'y': 20,
        '@': 21,
        'f': 22,
        'z': 23,
        'b': 24,
        '§': 25,
        'v': 26,
        '2': 27,
        '1': 28,
        'Z': 29,
        'g': 30,
        'u': 31,
        'S': 32
    }
    phon_untokenizer = {v: k for k, v in phon_tokenizer.items()}
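    # Orthographic characters mapped to ids for the encoder input.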
    char_tokenizer = {
        'e': 7,
        'i': 8,
        'a': 9,
        'r': 10,
        'o': 11,
        's': 12,
        't': 13,
        'n': 14,
        'l': 15,
        'é': 16,
        'c': 17,
        'p': 18,
        'u': 19,
        'm': 20,
        'd': 21,
        '-': 22,
        'h': 23,
        'g': 24,
        'b': 25,
        'v': 26,
        'f': 27,
        'k': 28,
        'y': 29,
        'x': 30,
        'è': 31,
        'ï': 32,
        'j': 33,
        'z': 34,
        'w': 35,
        'q': 36
    }

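    # Small character-to-phoneme BERT encoder-decoder (4 layers, 4 heads,
    # hidden size 128); `model` is an optional path to a saved state_dict.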
    def __init__(self, device='cpu', model=None):
        vocabsize = 37
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        vocabsize = 33
        max_length = 50
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attention=True,
                                    is_decoder=True)
        decoder = BertLMHeadModel(config=decoder_config)

        # Define encoder decoder model
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model, map_location=device))

    def phonetize(self, word):
        word = word.replace('à', 'a')
        word = word.replace('û', 'u')
        word = word.replace('ù', 'u')
        word = word.replace('î', 'i')
        word = word.replace('ç', 'ss')
        word = word.replace('ô', 'o')
        word = word.replace('â', 'a')
        word = word.replace('qu', 'k')
        word = word.replace('ê', 'e')
        assert set(word).issubset(set(PhonetizerModel.char_tokenizer.keys()))
        encoded = torch.tensor(
            [0] + [PhonetizerModel.char_tokenizer[p] for p in word] + [2])
        output = self.model.generate(
            encoded.unsqueeze(0).to(self.device),
            max_length=50,
            decoder_start_token_id=0,
            eos_token_id=2,
            pad_token_id=1,
        ).detach().cpu().numpy()[0]
        bound = np.where(output == 2)[0][0] if 2 in output else 1000
        phon_pred = ''.join([
            PhonetizerModel.phon_untokenizer[c] for c in output[:bound]
            if c > 6
        ])
        return phon_pred

    def check_phonetization_error(self, word, phon):
        prediction = self.phonetize(word)[:5]
        # globalms args: match=2, mismatch=-1, gap open=-1, gap extend=-0.5
        score = pairwise2.align.globalms(
            list(phon[:5]), list(prediction),
            2, -1, -1, -.5,
            score_only=True, gap_char=['-']) / len(phon[:5])
        return score
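

# Minimal usage sketch: with model=None the weights are randomly initialised,
# so the output is meaningless, but it exercises the call path; pass a trained
# state_dict path via `model=` for real predictions.
if __name__ == '__main__':
    phonetizer = PhonetizerModel(device='cpu')
    print(phonetizer.phonetize('bonjour'))
    # With a trained model one can also score against a reference, e.g.:
    # print(phonetizer.check_phonetization_error('bonjour', 'b§ZuR'))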