Example #1
def main(data_path: str, saved_model_path: str) -> None:
    """The main training function"""
    if saved_model_path:
        # Resuming from a checkpoint: override the module-level hyperparameters
        # with the values the model was saved with.
        global embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
            use_bert_cased, use_bert_uncased, use_bert_large
        embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, use_bert_cased, \
            use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    use_bert = use_bert_uncased or use_bert_cased
    if use_bert:
        train_iter, \
        val_iter, \
        word_to_ix, \
        ix_to_word, \
        tag_vocab, \
        char_to_ix = create_bert_datasets(
            data_path=data_path,
            mode=TRAIN,
            use_bert_cased=use_bert_cased,
            use_bert_uncased=use_bert_uncased,
            use_bert_large=use_bert_large
        )
        vocab_size = None
        word_vocab = None
    else:
        train_iter, \
        val_iter, \
        word_vocab, \
        tag_vocab, \
        char_to_ix = create_datasets(data_path=data_path, mode=TRAIN)
        # char_to_ix gets added to automatically with any characters (e.g. < and >)
        # encountered during evaluation, but we want to save the original copy so
        # that the char embedding parameters can be computed, hence the deepcopy
        # below.
        word_to_ix, ix_to_word = word_vocab.stoi, word_vocab.itos
        vocab_size = len(word_to_ix)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    char_to_ix_original = copy.deepcopy(char_to_ix)
    word_vocab_original = copy.deepcopy(word_vocab)
    word_to_ix_original = copy.deepcopy(word_to_ix)
    ix_to_word_original = copy.deepcopy(ix_to_word)
    tag_vocab_original = copy.deepcopy(tag_vocab)
    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)
    os.makedirs(os.path.join("..", models_folder), exist_ok=True)
    if saved_model_path:
        av_train_losses, \
        av_eval_losses, \
        checkpoint_epoch, \
        best_accuracy, \
        lowest_av_eval_loss, \
        best_micro_precision, \
        best_micro_recall, \
        best_micro_f1, \
        best_weighted_macro_precision, \
        best_weighted_macro_recall, \
        best_weighted_macro_f1 = load_model(model=model,
                                            saved_model_path=saved_model_path,
                                            optimizer=optimizer)
        model_file_name = os.path.split(saved_model_path)[1]
    else:
        checkpoint_epoch = 0
        av_train_losses = []
        av_eval_losses = []
        lowest_av_eval_loss = float('inf')
        model_file_name = strftime("%Y_%m_%d_%H_%M_%S.pt")
    #torch.autograd.set_detect_anomaly(True)
    print("training..\n")
    model.train()
    start_epoch = checkpoint_epoch + 1
    end_epoch = checkpoint_epoch + num_epochs
    for epoch in range(start_epoch, end_epoch + 1):
        model.train()
        print('===============================')
        print('\n======== Epoch {} / {} ========'.format(epoch, end_epoch))
        batch_num = 0
        train_losses = []
        for batch in train_iter:
            batch_num += 1
            if batch_num % 20 == 0 or batch_num == 1:
                if batch_num != 1:
                    print(
                        "\nAverage Training loss for epoch {} at end of batch {}: {}"
                        .format(epoch, str(batch_num - 1),
                                sum(train_losses) / len(train_losses)))
                print('\n======== at batch {} / {} ========'.format(
                    batch_num, len(train_iter)))
            model.zero_grad()
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                # Trim the batch to the longest non-padded sequence.
                max_length = (attention_masks != 0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [sent.split() for sent in original_sentences]
                word_batch_size = max(len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = train_iter.sent_lengths[batch_num - 1]
            words_in = get_words_in(
                sentences_in=sentences_in,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            # Normalize the loss by the number of non-<pad> tokens in the batch.
            mask = targets != tag_to_ix['<pad>']
            loss = loss_function(tag_logits, targets)
            loss /= mask.float().sum()
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        av_train_losses.append(sum(train_losses) / len(train_losses))
        accuracy, av_eval_loss, micro_precision, micro_recall, micro_f1, weighted_macro_precision, \
        weighted_macro_recall, weighted_macro_f1 = eval_model(
            model=model,
            loss_function=loss_function,
            val_iter=val_iter,
            char_to_ix=char_to_ix,
            ix_to_word=ix_to_word,
            ix_to_tag=ix_to_tag,
            av_eval_losses=av_eval_losses,
            use_bert=use_bert
        )
        print_results(epoch, accuracy, av_eval_loss, micro_precision,
                      micro_recall, micro_f1, weighted_macro_precision,
                      weighted_macro_recall, weighted_macro_f1)
        if av_eval_losses[-1] < lowest_av_eval_loss:
            lowest_av_eval_loss = av_eval_losses[-1]
            (best_accuracy, best_micro_precision, best_micro_recall, best_micro_f1,
             best_weighted_macro_precision, best_weighted_macro_recall,
             best_weighted_macro_f1) = (accuracy, micro_precision, micro_recall,
                                        micro_f1, weighted_macro_precision,
                                        weighted_macro_recall, weighted_macro_f1)
            checkpoint_epoch = epoch
            save_model(epoch=checkpoint_epoch,
                       model=model,
                       optimizer=optimizer,
                       av_train_losses=av_train_losses,
                       av_eval_losses=av_eval_losses,
                       model_file_name=model_file_name,
                       word_to_ix=word_to_ix_original,
                       ix_to_word=ix_to_word_original,
                       word_vocab=word_vocab_original,
                       tag_vocab=tag_vocab_original,
                       char_to_ix=char_to_ix_original,
                       models_folder=models_folder,
                       embedding_dim=embedding_dim,
                       char_embedding_dim=char_embedding_dim,
                       hidden_dim=hidden_dim,
                       char_hidden_dim=char_hidden_dim,
                       accuracy=best_accuracy,
                       av_eval_loss=lowest_av_eval_loss,
                       micro_precision=best_micro_precision,
                       micro_recall=best_micro_recall,
                       micro_f1=best_micro_f1,
                       weighted_macro_precision=best_weighted_macro_precision,
                       weighted_macro_recall=best_weighted_macro_recall,
                       weighted_macro_f1=best_weighted_macro_f1,
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    print_results(epoch=checkpoint_epoch,
                  accuracy=best_accuracy,
                  av_eval_loss=lowest_av_eval_loss,
                  micro_precision=best_micro_precision,
                  micro_recall=best_micro_recall,
                  micro_f1=best_micro_f1,
                  weighted_macro_precision=best_weighted_macro_precision,
                  weighted_macro_recall=best_weighted_macro_recall,
                  weighted_macro_f1=best_weighted_macro_f1,
                  final=True)
    plot_train_eval_loss(av_train_losses, av_eval_losses)
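For context, the padding-aware loss used in Example #1 can be exercised in isolation. Below is a minimal sketch with toy shapes and a hypothetical pad index of 0; note it passes reduction='sum' so the division yields a per-token mean directly, whereas the example above divides the default 'mean' reduction a second time.

import torch
from torch.nn import CrossEntropyLoss

# Toy batch: 8 token positions, 5 tags, pad tag index 0 (hypothetical).
tag_logits = torch.randn(8, 5)
targets = torch.tensor([1, 2, 0, 3, 0, 4, 1, 2])

loss_function = CrossEntropyLoss(ignore_index=0, reduction='sum')
mask = targets != 0                       # True for real (non-pad) tokens
loss = loss_function(tag_logits, targets) / mask.float().sum()
print(loss.item())                        # average loss per non-pad token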
Example #2
def main(data_path: str, saved_model_path: str) -> None:
    """The main testing function."""
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
    use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    use_bert = use_bert_cased or use_bert_uncased
    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    if use_bert:
        test_iter = create_bert_datasets(data_path=data_path,
                                         mode=TEST,
                                         use_bert_cased=use_bert_cased,
                                         use_bert_uncased=use_bert_uncased,
                                         use_bert_large=use_bert_large,
                                         tag_to_ix=tag_to_ix)
        vocab_size = None
    else:
        test_iter = create_datasets(data_path=data_path,
                                    mode=TEST,
                                    word_to_ix=word_to_ix,
                                    word_vocab=word_vocab,
                                    tag_vocab=tag_vocab)
        vocab_size = len(word_to_ix)
    char_to_ix_original = copy.deepcopy(char_to_ix)
    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix_original),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    load_model(model=model, saved_model_path=saved_model_path)
    model.to(device)
    #torch.autograd.set_detect_anomaly(True)
    print("testing model: " + saved_model_path + '\n')
    model.eval()
    y_pred = []
    y_true = []
    av_test_losses = []
    with torch.no_grad():
        batch_num = 0
        test_losses = []
        for batch in test_iter:
            batch_num += 1
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                # Trim the batch to the longest non-padded sequence.
                max_length = (attention_masks != 0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [sent.split() for sent in original_sentences]
                word_batch_size = max(len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = test_iter.sent_lengths[batch_num - 1]
            y_true += [ix_to_tag[ix.item()] for ix in targets]
            words_in = get_words_in(
                sentences_in=sentences_in,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            # Normalize the loss by the number of non-<pad> tokens in the batch.
            mask = targets != tag_to_ix['<pad>']
            test_loss = loss_function(tag_logits, targets)
            test_loss /= mask.float().sum()
            test_losses.append(test_loss.item())
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            y_pred += pred
        av_test_losses.append(sum(test_losses) / len(test_losses))
        y_true, y_pred = remove_pads(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        # With average='micro' the returned support is None, hence the underscore.
        micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='micro')
        weighted_macro_precision, weighted_macro_recall, weighted_macro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted')
        av_test_loss = sum(test_losses) / len(test_losses)
        print("Test accuracy: {:.2f}%".format(accuracy * 100))
        print("Average Test loss: {}".format(str(av_test_loss)))
        print("Micro Precision: {}".format(micro_precision))
        print("Micro Recall: {}".format(micro_recall))
        print("Micro F1: {}".format(micro_f1))
        print("Weighted Macro Precision: {}".format(weighted_macro_precision))
        print("Weighted Macro Recall: {}".format(weighted_macro_recall))
        print("Weighted Macro F1: {}".format(weighted_macro_f1))
Example #3
def main(data_path: str, dest_path: str, saved_model_path: str) -> None:
    """
    The function for tagging the unseen data.
    """
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
    use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    if use_bert_uncased:
        if use_bert_large:
            tokenizer = BertTokenizer.from_pretrained('bert-large-uncased',
                                                      do_lower_case=True)
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                      do_lower_case=True)
    elif use_bert_cased:
        if use_bert_large:
            tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    else:
        tokenizer = None
    use_bert = use_bert_uncased or use_bert_cased
    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    if use_bert:
        vocab_size = None
    else:
        vocab_size = len(word_to_ix)
    sentences, sent_tensors = prepare_untagged_data(
        data_path=data_path,
        word_to_ix=copy.deepcopy(word_to_ix),
        device=device,
        tokenizer=tokenizer,
        use_bert=use_bert)
    dest_file = open(dest_path, 'w')
    model = LSTMTagger(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        vocab_size=vocab_size,
        tagset_size=len(tag_to_ix),
        char_embedding_dim=char_embedding_dim,
        char_hidden_dim=char_hidden_dim,
        char_vocab_size=len(char_to_ix),
        use_bert_cased=use_bert_cased,
        use_bert_uncased=use_bert_uncased,
        use_bert_large=use_bert_large
    )
    load_model(model=model, saved_model_path=saved_model_path)
    #torch.autograd.set_detect_anomaly(True)
    model.to(device)
    print("\ntagging sentences using model: " + saved_model_path + '\n')
    model.eval()
    with torch.no_grad():
        sent_num = 0
        for sent_tensor in sent_tensors:
            sentence = sentences[sent_num]
            sent_num += 1
            word_batch_size = len(sentence)
            sent_batch_size = 1
            if use_bert:
                subwords = list(map(tokenizer.tokenize, sentence))
                subword_lengths = list(map(len, subwords))
                # Cumulative subword counts give one index per original word.
                token_start_idx = [list(np.cumsum([0] + subword_lengths))[1:]]
                sent_lengths = [len(token_start_idx[0])]
                # One mask position per word-aligned index, plus two (presumably
                # for the [CLS] and [SEP] special tokens).
                attention_masks = torch.tensor(
                    [1] * (len(token_start_idx[0]) + 2)).unsqueeze(0).to(device)
                original_sentences_split = [sentence]
            else:
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = [len(sentence)]
            words_in = get_words_in(
                sentences_in=sent_tensor,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sent_tensor,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            dest_file.write(" ".join(sentence) + '\t' + " ".join(pred) + '\n')
    dest_file.close()
    print("tagging complete")