Example #1
def save_model(epoch: int,
               model: LSTMTagger,
               optimizer: Adam,
               av_train_losses: List[float],
               av_eval_losses: List[float],
               model_file_name: str,
               word_to_ix: Union[BertTokenToIx, defaultdict],
               ix_to_word: Union[BertIxToToken, defaultdict],
               word_vocab: Optional[Vocab],
               tag_vocab: Vocab,
               char_to_ix: DefaultDict[str, int],
               models_folder: str,
               embedding_dim: int,
               char_embedding_dim: int,
               hidden_dim: int,
               char_hidden_dim: int,
               accuracy: float64,
               av_eval_loss: float,
               micro_precision: float64,
               micro_recall: float64,
               micro_f1: float64,
               weighted_macro_precision: float64,
               weighted_macro_recall: float64,
               weighted_macro_f1: float64,
               use_bert_cased: bool,
               use_bert_uncased: bool,
               use_bert_large: bool
               ) -> None:
    model_path = os.path.join("..", models_folder, model_file_name)
    try:
        os.remove(model_path)
    except FileNotFoundError:
        pass
    torch.save({
            'checkpoint_epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'av_train_losses': av_train_losses,
            'av_eval_losses': av_eval_losses,
            'word_to_ix': word_to_ix,
            'ix_to_word': ix_to_word,
            'word_vocab': word_vocab,
            'tag_vocab': tag_vocab,
            'char_to_ix': char_to_ix,
            'embedding_dim': embedding_dim,
            'char_embedding_dim': char_embedding_dim,
            'hidden_dim': hidden_dim,
            'char_hidden_dim': char_hidden_dim,
            'accuracy': accuracy,
            'av_eval_loss': av_eval_loss,
            'micro_precision': micro_precision,
            'micro_recall': micro_recall,
            'micro_f1': micro_f1,
            'weighted_macro_precision': weighted_macro_precision,
            'weighted_macro_recall': weighted_macro_recall,
            'weighted_macro_f1': weighted_macro_f1,
            'use_bert_cased': use_bert_cased,
            'use_bert_uncased': use_bert_uncased,
            'use_bert_large': use_bert_large
    }, model_path)
    print("Model with lowest average eval loss successfully saved as: " + model_path)
Example #2
def load_model():
    word_to_idx = get_word_index()
    tag_to_idx = get_tag_index()
    idx_to_tag = [
        tag for tag, _ in sorted(tag_to_idx.items(), key=lambda x: x[1])
    ]
    model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_idx),
                       len(tag_to_idx))
    print(model)
    load_checkpoint("x.epoch", model)
    return model, word_to_idx, tag_to_idx, idx_to_tag
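The idx_to_tag list above is just tag_to_idx inverted by sorting on the index values; a tiny self-contained illustration:

# invert a tag->index dict into an index->tag list
tag_to_idx = {'O': 0, 'B-PER': 1, 'I-PER': 2}
idx_to_tag = [tag for tag, _ in sorted(tag_to_idx.items(), key=lambda x: x[1])]
print(idx_to_tag)  # ['O', 'B-PER', 'I-PER']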
Example #3
def load_model(model: LSTMTagger,
               saved_model_path: str,
               optimizer: Optional[Adam] = None
               ) -> Optional[loadModelReturn]:
    print("Attempting to load saved model checkpoint from: " + saved_model_path)
    checkpoint = torch.load(saved_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Successfully loaded model..")
    if optimizer is None:
        return None
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return (checkpoint['av_train_losses'],
            checkpoint['av_eval_losses'],
            checkpoint['checkpoint_epoch'],
            checkpoint['accuracy'],
            checkpoint['av_eval_loss'],
            checkpoint['micro_precision'],
            checkpoint['micro_recall'],
            checkpoint['micro_f1'],
            checkpoint['weighted_macro_precision'],
            checkpoint['weighted_macro_recall'],
            checkpoint['weighted_macro_f1'])
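A hypothetical call pattern for load_model above; the model, optimizer, and checkpoint path are assumptions and must already exist:

# passing an optimizer makes load_model return the saved training state
metrics = load_model(model=model,
                     saved_model_path='../models/tagger.pt',
                     optimizer=optimizer)
if metrics is not None:
    av_train_losses, av_eval_losses, checkpoint_epoch, *best_scores = metrics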
Example #4
    def makeMultilabelModel(self, modelName, num_labels=10, root='', **kwargs):
        if modelName in [
                'distilbert-base-uncased', 'distilbert2/', 'distilbert3/'
        ]:
            print(root)
            tokenizer = DistilBertTokenizerFast.from_pretrained(
                'distilbert-base-uncased')
            model = DistilBertForTokenClassification.from_pretrained(
                root + modelName, num_labels=num_labels, **kwargs)
        elif modelName == 'bertweet':
            tokenizer = AutoTokenizer.from_pretrained('vinai/bertweet-base')
            model = AutoModelForTokenClassification.from_pretrained(
                root + "vinai/bertweet-base", num_labels=num_labels, **kwargs)
        elif modelName == 'distilroberta-base':
            tokenizer = AutoTokenizer.from_pretrained('distilroberta-base',
                                                      add_prefix_space=True)
            model = AutoModelForTokenClassification.from_pretrained(
                root + "distilroberta-base", num_labels=num_labels, **kwargs)
        elif modelName == 'lstm':
            tokenizer = AutoTokenizer.from_pretrained(
                'distilbert-base-uncased')
            model = LSTMTagger(128, 64, 2, tokenizer.vocab_size, num_labels)
        elif modelName == 'albert-base-v2':
            tokenizer = AutoTokenizer.from_pretrained('albert-base-v2',
                                                      add_prefix_space=True)
            model = AutoModelForTokenClassification.from_pretrained(
                root + "albert-base-v2", num_labels=num_labels, **kwargs)
        elif modelName == 'squeezebert/squeezebert-uncased':
            tokenizer = AutoTokenizer.from_pretrained(
                'squeezebert/squeezebert-uncased', add_prefix_space=True)
            model = AutoModelForTokenClassification.from_pretrained(
                root + "squeezebert/squeezebert-uncased",
                num_labels=num_labels,
                **kwargs)
        elif modelName == 'xlnet-base-cased':
            tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased',
                                                      add_prefix_space=True)
            model = AutoModelForTokenClassification.from_pretrained(
                root + "xlnet-base-cased", num_labels=num_labels, **kwargs)
        else:
            # no branch matched, so tokenizer/model would be unbound
            raise ValueError('unsupported modelName: ' + modelName)

        return tokenizer, model
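Each branch above pairs a tokenizer with a matching token-classification head. A self-contained sketch of the 'distilroberta-base' branch (downloads weights from the Hugging Face hub on first run):

from transformers import AutoTokenizer, AutoModelForTokenClassification

# RoBERTa-style tokenizers need add_prefix_space=True for pre-split words
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base',
                                          add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained('distilroberta-base',
                                                        num_labels=10)
print(model.config.num_labels)  # -> 10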
Example #5
def main(data_path: str, saved_model_path: str) -> None:
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
    use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    use_bert = use_bert_cased or use_bert_uncased
    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    if use_bert:
        test_iter = create_bert_datasets(data_path=data_path,
                                         mode=TEST,
                                         use_bert_cased=use_bert_cased,
                                         use_bert_uncased=use_bert_uncased,
                                         use_bert_large=use_bert_large,
                                         tag_to_ix=tag_to_ix)
        vocab_size = None
    else:
        test_iter = create_datasets(data_path=data_path,
                                    mode=TEST,
                                    word_to_ix=word_to_ix,
                                    word_vocab=word_vocab,
                                    tag_vocab=tag_vocab)
        vocab_size = len(word_to_ix)
    char_to_ix_original = copy.deepcopy(char_to_ix)
    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix_original),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    load_model(model=model, saved_model_path=saved_model_path)
    model.to(device)
    #torch.autograd.set_detect_anomaly(True)
    print("testing model: " + saved_model_path + '\n')
    model.eval()
    y_pred = []
    y_true = []
    av_test_losses = []
    with torch.no_grad():
        batch_num = 0
        test_losses = []
        for batch in test_iter:
            batch_num += 1
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                max_length = (attention_masks !=
                              0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [
                    sent.split() for sent in original_sentences
                ]
                word_batch_size = max(
                    len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = test_iter.sent_lengths[batch_num - 1]
            y_true += [ix_to_tag[ix.item()] for ix in targets]
            words_in = get_words_in(
                sentences_in=sentences_in,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            mask = targets != 1
            test_loss = loss_function(tag_logits, targets)
            test_loss /= mask.float().sum()
            test_losses.append(test_loss.item())
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            y_pred += pred
        av_test_loss = sum(test_losses) / len(test_losses)
        av_test_losses.append(av_test_loss)
        y_true, y_pred = remove_pads(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='micro')
        weighted_macro_precision, weighted_macro_recall, weighted_macro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted')
        print("Test accuracy: {:.2f}%".format(accuracy * 100))
        print("Average Test loss: {}".format(av_test_loss))
        print("Micro Precision: {}".format(micro_precision))
        print("Micro Recall: {}".format(micro_recall))
        print("Micro F1: {}".format(micro_f1))
        print("Weighted Macro Precision: {}".format(weighted_macro_precision))
        print("Weighted Macro Recall: {}".format(weighted_macro_recall))
        print("Weighted Macro F1: {}".format(weighted_macro_f1))
Example #6
def main(data_path: str, saved_model_path: str) -> None:
    """The main training function"""
    if saved_model_path:
        global embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, use_bert_cased, \
            use_bert_uncased, use_bert_large
        embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, use_bert_cased, use_bert_uncased, \
        use_bert_large = load_hyper_params(saved_model_path)
    use_bert = use_bert_uncased or use_bert_cased
    if use_bert:
        train_iter, \
        val_iter, \
        word_to_ix, \
        ix_to_word, \
        tag_vocab, \
        char_to_ix = create_bert_datasets(
            data_path=data_path,
            mode=TRAIN,
            use_bert_cased=use_bert_cased,
            use_bert_uncased=use_bert_uncased,
            use_bert_large=use_bert_large
        )
        vocab_size = None
        word_vocab = None
    else:
        train_iter, \
        val_iter, \
        word_vocab, \
        tag_vocab, \
        char_to_ix = create_datasets(data_path=data_path, mode=TRAIN)
        # char_to_ix is extended automatically with any new characters
        # (e.g. '<', '>') encountered during evaluation, but we want to keep
        # the original copy so that the char-embedding parameters can be
        # rebuilt, hence the deep copies below.
        word_to_ix, ix_to_word = word_vocab.stoi, word_vocab.itos
        vocab_size = len(word_to_ix)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    char_to_ix_original = copy.deepcopy(char_to_ix)
    word_vocab_original = copy.deepcopy(word_vocab)
    word_to_ix_original = copy.deepcopy(word_to_ix)
    ix_to_word_original = copy.deepcopy(ix_to_word)
    tag_vocab_original = copy.deepcopy(tag_vocab)
    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)
    if models_folder not in os.listdir(".."):
        os.mkdir(os.path.join("..", models_folder))
    if saved_model_path:
        av_train_losses, \
        av_eval_losses, \
        checkpoint_epoch, \
        best_accuracy, \
        lowest_av_eval_loss, \
        best_micro_precision, \
        best_micro_recall, \
        best_micro_f1, \
        best_weighted_macro_precision, \
        best_weighted_macro_recall, \
        best_weighted_macro_f1 = load_model(model=model,
                                            saved_model_path=saved_model_path,
                                            optimizer=optimizer)
        model_file_name = os.path.split(saved_model_path)[1]
    else:
        checkpoint_epoch = 0
        av_train_losses = []
        av_eval_losses = []
        lowest_av_eval_loss = float('inf')
        model_file_name = strftime("%Y_%m_%d_%H_%M_%S.pt")
    #torch.autograd.set_detect_anomaly(True)
    print("training..\n")
    model.train()
    start_epoch = checkpoint_epoch + 1
    end_epoch = checkpoint_epoch + num_epochs
    for epoch in range(start_epoch, end_epoch + 1):
        model.train()
        print('===============================')
        print('\n======== Epoch {} / {} ========'.format(epoch, end_epoch))
        batch_num = 0
        train_losses = []
        for batch in train_iter:
            batch_num += 1
            if batch_num % 20 == 0 or batch_num == 1:
                if batch_num != 1:
                    print(
                        "\nAverage Training loss for epoch {} at end of batch {}: {}"
                        .format(epoch, str(batch_num - 1),
                                sum(train_losses) / len(train_losses)))
                print('\n======== at batch {} / {} ========'.format(
                    batch_num, len(train_iter)))
            model.zero_grad()
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                max_length = (attention_masks !=
                              0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [
                    sent.split() for sent in original_sentences
                ]
                word_batch_size = max(
                    len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = train_iter.sent_lengths[batch_num - 1]
            words_in = get_words_in(
                sentences_in=sentences_in,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            mask = targets != 1
            loss = loss_function(tag_logits, targets)
            loss /= mask.float().sum()
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        av_train_losses.append(sum(train_losses) / len(train_losses))
        accuracy, av_eval_loss, micro_precision, micro_recall, micro_f1, weighted_macro_precision, \
        weighted_macro_recall, weighted_macro_f1 = eval_model(
            model=model,
            loss_function=loss_function,
            val_iter=val_iter,
            char_to_ix=char_to_ix,
            ix_to_word=ix_to_word,
            ix_to_tag=ix_to_tag,
            av_eval_losses=av_eval_losses,
            use_bert=use_bert
        )
        print_results(epoch, accuracy, av_eval_loss, micro_precision,
                      micro_recall, micro_f1, weighted_macro_precision,
                      weighted_macro_recall, weighted_macro_f1)
        if av_eval_losses[-1] < lowest_av_eval_loss:
            lowest_av_eval_loss = av_eval_losses[-1]
            best_accuracy, \
            best_micro_precision, \
            best_micro_recall, \
            best_micro_f1, \
            best_weighted_macro_precision, \
            best_weighted_macro_recall, \
            best_weighted_macro_f1 = accuracy, \
                                     micro_precision, \
                                     micro_recall, \
                                     micro_f1, \
                                     weighted_macro_precision, \
                                     weighted_macro_recall, \
                                     weighted_macro_f1
            checkpoint_epoch = epoch
            save_model(epoch=checkpoint_epoch,
                       model=model,
                       optimizer=optimizer,
                       av_train_losses=av_train_losses,
                       av_eval_losses=av_eval_losses,
                       model_file_name=model_file_name,
                       word_to_ix=word_to_ix_original,
                       ix_to_word=ix_to_word_original,
                       word_vocab=word_vocab_original,
                       tag_vocab=tag_vocab_original,
                       char_to_ix=char_to_ix_original,
                       models_folder=models_folder,
                       embedding_dim=embedding_dim,
                       char_embedding_dim=char_embedding_dim,
                       hidden_dim=hidden_dim,
                       char_hidden_dim=char_hidden_dim,
                       accuracy=best_accuracy,
                       av_eval_loss=lowest_av_eval_loss,
                       micro_precision=best_micro_precision,
                       micro_recall=best_micro_recall,
                       micro_f1=best_micro_f1,
                       weighted_macro_precision=best_weighted_macro_precision,
                       weighted_macro_recall=best_weighted_macro_recall,
                       weighted_macro_f1=best_weighted_macro_f1,
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    print_results(epoch=checkpoint_epoch,
                  accuracy=best_accuracy,
                  av_eval_loss=lowest_av_eval_loss,
                  micro_precision=best_micro_precision,
                  micro_recall=best_micro_recall,
                  micro_f1=best_micro_f1,
                  weighted_macro_precision=best_weighted_macro_precision,
                  weighted_macro_recall=best_weighted_macro_recall,
                  weighted_macro_f1=best_weighted_macro_f1,
                  final=True)
    plot_train_eval_loss(av_train_losses, av_eval_losses)
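The epoch loop above only writes a checkpoint when the average eval loss improves; the bookkeeping reduces to this toy pattern:

# keep only the checkpoint with the lowest eval loss seen so far
lowest_av_eval_loss = float('inf')
for epoch, av_eval_loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    if av_eval_loss < lowest_av_eval_loss:
        lowest_av_eval_loss = av_eval_loss
        print("epoch {}: new best eval loss {}, saving checkpoint".format(
            epoch, av_eval_loss))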
Example #7
def eval_model(model: LSTMTagger, loss_function: CrossEntropyLoss,
               val_iter: BucketIterator, char_to_ix: DefaultDict[str, int],
               ix_to_word: List[str], ix_to_tag: List[str],
               av_eval_losses: List[float], use_bert: bool) -> evalModelReturn:
    """
    Function for evaluating the model being trained.
    """
    model.eval()
    y_pred = []
    y_true = []
    print("\nEvaluating model...")
    with torch.no_grad():
        batch_num = 0
        eval_losses = []
        for batch in val_iter:
            batch_num += 1
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                max_length = (attention_masks !=
                              0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [
                    sent.split() for sent in original_sentences
                ]
                word_batch_size = max(
                    len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = val_iter.sent_lengths[batch_num - 1]
            y_true += [ix_to_tag[ix.item()] for ix in targets]
            words_in = get_words_in(
                sentences_in=sentences_in,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            eval_loss = loss_function(tag_logits, targets)
            mask = targets != 1
            eval_loss /= mask.float().sum()
            eval_losses.append(eval_loss.item())
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            y_pred += pred
        av_eval_losses.append(sum(eval_losses) / len(eval_losses))
        y_true, y_pred = remove_pads(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='micro')
        weighted_macro_precision, weighted_macro_recall, weighted_macro_f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted')
        av_eval_loss = sum(eval_losses) / len(eval_losses)

    return accuracy, av_eval_loss, micro_precision, micro_recall, micro_f1, weighted_macro_precision, \
           weighted_macro_recall, weighted_macro_f1
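Note the division of eval_loss by the number of non-pad targets: that normalisation only adds up if loss_function was constructed with reduction='sum', since the default reduction='mean' already averages over non-ignored targets. A toy sketch under that assumption, with the pad tag at index 1 as the mask implies:

import torch

logits = torch.randn(6, 4)                   # 6 tokens, 4 tags
targets = torch.tensor([0, 2, 1, 1, 3, 1])   # index 1 plays '<pad>'
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=1, reduction='sum')
loss = loss_fn(logits, targets)
mask = targets != 1
loss = loss / mask.float().sum()             # average over the 3 real tokens
print(loss.item())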
Example #8
def main(data_path: str, dest_path: str, saved_model_path: str) -> None:
    """
    The function for tagging the unseen data.
    """
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
    use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    if use_bert_uncased:
        if use_bert_large:
            tokenizer = BertTokenizer.from_pretrained('bert-large-uncased',
                                                      do_lower_case=True)
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                      do_lower_case=True)
    elif use_bert_cased:
        if use_bert_large:
            tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
        else:
            tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    else:
        tokenizer = None
    use_bert = use_bert_uncased or use_bert_cased
    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    if use_bert:
        vocab_size = None
    else:
        vocab_size = len(word_to_ix)
    sentences, sent_tensors = prepare_untagged_data(
        data_path=data_path,
        word_to_ix=copy.deepcopy(word_to_ix),
        device=device,
        tokenizer=tokenizer,
        use_bert=use_bert)
    dest_file = open(dest_path, 'w')
    model = LSTMTagger(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        vocab_size=vocab_size,
        tagset_size=len(tag_to_ix),
        char_embedding_dim=char_embedding_dim,
        char_hidden_dim=char_hidden_dim,
        char_vocab_size=len(char_to_ix),
        use_bert_cased=use_bert_cased,
        use_bert_uncased=use_bert_uncased,
        use_bert_large=use_bert_large
    )
    load_model(model=model, saved_model_path=saved_model_path)
    #torch.autograd.set_detect_anomaly(True)
    model.to(device)
    print("\ntagging sentences using model: " + saved_model_path + '\n')
    model.eval()
    with torch.no_grad():
        sent_num = 0
        for sent_tensor in sent_tensors:
            sentence = sentences[sent_num]
            sent_num += 1
            word_batch_size = len(sentence)
            sent_batch_size = 1
            if use_bert:
                subwords = list(map(tokenizer.tokenize, sentence))
                subword_lengths = list(map(len, subwords))
                token_start_idx = [list(np.cumsum([0] + subword_lengths))[1:]]
                sent_lengths = [len(token_start_idx[0])]
                attention_masks = torch.tensor(
                    [1, 1] +
                    [1 for x in token_start_idx[0]]).unsqueeze(0).to(device)
                original_sentences_split = [sentence]
            else:
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = [len(sentence)]
            words_in = get_words_in(
                sentences_in=sent_tensor,
                char_to_ix=char_to_ix,
                ix_to_word=ix_to_word,
                device=device,
                original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sent_tensor,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            dest_file.write(" ".join(sentence) + '\t' + " ".join(pred) + '\n')
    dest_file.close()
    print("tagging complete")