def main(data_path: str, saved_model_path: str) -> None:
    """The main training function."""
    if saved_model_path:
        global embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
            use_bert_cased, use_bert_uncased, use_bert_large
        embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
            use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)

    use_bert = use_bert_uncased or use_bert_cased

    if use_bert:
        train_iter, val_iter, word_to_ix, ix_to_word, tag_vocab, char_to_ix = \
            create_bert_datasets(data_path=data_path,
                                 mode=TRAIN,
                                 use_bert_cased=use_bert_cased,
                                 use_bert_uncased=use_bert_uncased,
                                 use_bert_large=use_bert_large)
        vocab_size = None
        word_vocab = None
    else:
        train_iter, val_iter, word_vocab, tag_vocab, char_to_ix = create_datasets(
            data_path=data_path, mode=TRAIN)
        word_to_ix, ix_to_word = word_vocab.stoi, word_vocab.itos
        vocab_size = len(word_to_ix)

    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos

    # char_to_ix grows automatically with any characters (e.g. '<', '>') encountered
    # during evaluation, but we want to save the vocabularies as they are now so that
    # the char embedding parameters can be sized consistently when the model is
    # reloaded, hence the deep copies here.
    char_to_ix_original = copy.deepcopy(char_to_ix)
    word_vocab_original = copy.deepcopy(word_vocab)
    word_to_ix_original = copy.deepcopy(word_to_ix)
    ix_to_word_original = copy.deepcopy(ix_to_word)
    tag_vocab_original = copy.deepcopy(tag_vocab)

    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if models_folder not in os.listdir(".."):
        os.mkdir(os.path.join("..", models_folder))

    if saved_model_path:
        # Resume training: restore the optimizer state, loss history, and best metrics.
        av_train_losses, av_eval_losses, checkpoint_epoch, best_accuracy, \
            lowest_av_eval_loss, best_micro_precision, best_micro_recall, best_micro_f1, \
            best_weighted_macro_precision, best_weighted_macro_recall, \
            best_weighted_macro_f1 = load_model(model=model,
                                                saved_model_path=saved_model_path,
                                                optimizer=optimizer)
        model_file_name = os.path.split(saved_model_path)[1]
    else:
        checkpoint_epoch = 0
        av_train_losses = []
        av_eval_losses = []
        lowest_av_eval_loss = float('inf')
        model_file_name = strftime("%Y_%m_%d_%H_%M_%S.pt")

    print("training..\n")
    model.train()
    start_epoch = checkpoint_epoch + 1
    end_epoch = checkpoint_epoch + num_epochs
    for epoch in range(start_epoch, end_epoch + 1):
        model.train()
        print('\n======== Epoch {} / {} ========'.format(epoch, end_epoch))
        batch_num = 0
        train_losses = []
        for batch in train_iter:
            batch_num += 1
            if batch_num % 20 == 0 or batch_num == 1:
                if batch_num != 1:
                    print("\nAverage training loss for epoch {} at end of batch {}: {}".format(
                        epoch, batch_num - 1, sum(train_losses) / len(train_losses)))
                print('\n======== at batch {} / {} ========'.format(batch_num, len(train_iter)))
            model.zero_grad()
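            # The two data pipelines deliver batches in different layouts: the BERT
            # iterator yields [batch, seq] subword-id tensors plus alignment info,
            # while the torchtext iterator yields [seq, batch] word-id tensors that
            # are permuted below.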
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                # Trim padding columns beyond the longest sequence in this batch.
                max_length = (attention_masks != 0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [sent.split() for sent in original_sentences]
                word_batch_size = max(len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = train_iter.sent_lengths[batch_num - 1]

            words_in = get_words_in(sentences_in=sentences_in,
                                    char_to_ix=char_to_ix,
                                    ix_to_word=ix_to_word,
                                    device=device,
                                    original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)

            # Normalise the loss by the number of non-pad targets (pad index is 1).
            mask = targets != 1
            loss = loss_function(tag_logits, targets)
            loss /= mask.float().sum()
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        av_train_losses.append(sum(train_losses) / len(train_losses))

        accuracy, av_eval_loss, micro_precision, micro_recall, micro_f1, \
            weighted_macro_precision, weighted_macro_recall, \
            weighted_macro_f1 = eval_model(model=model,
                                           loss_function=loss_function,
                                           val_iter=val_iter,
                                           char_to_ix=char_to_ix,
                                           ix_to_word=ix_to_word,
                                           ix_to_tag=ix_to_tag,
                                           av_eval_losses=av_eval_losses,
                                           use_bert=use_bert)
        print_results(epoch, accuracy, av_eval_loss, micro_precision, micro_recall,
                      micro_f1, weighted_macro_precision, weighted_macro_recall,
                      weighted_macro_f1)

        # Checkpoint whenever the average evaluation loss improves.
        if av_eval_losses[-1] < lowest_av_eval_loss:
            lowest_av_eval_loss = av_eval_losses[-1]
            best_accuracy = accuracy
            best_micro_precision = micro_precision
            best_micro_recall = micro_recall
            best_micro_f1 = micro_f1
            best_weighted_macro_precision = weighted_macro_precision
            best_weighted_macro_recall = weighted_macro_recall
            best_weighted_macro_f1 = weighted_macro_f1
            checkpoint_epoch = epoch
            save_model(epoch=checkpoint_epoch,
                       model=model,
                       optimizer=optimizer,
                       av_train_losses=av_train_losses,
                       av_eval_losses=av_eval_losses,
                       model_file_name=model_file_name,
                       word_to_ix=word_to_ix_original,
                       ix_to_word=ix_to_word_original,
                       word_vocab=word_vocab_original,
                       tag_vocab=tag_vocab_original,
                       char_to_ix=char_to_ix_original,
                       models_folder=models_folder,
                       embedding_dim=embedding_dim,
                       char_embedding_dim=char_embedding_dim,
                       hidden_dim=hidden_dim,
                       char_hidden_dim=char_hidden_dim,
                       accuracy=best_accuracy,
                       av_eval_loss=lowest_av_eval_loss,
                       micro_precision=best_micro_precision,
                       micro_recall=best_micro_recall,
                       micro_f1=best_micro_f1,
                       weighted_macro_precision=best_weighted_macro_precision,
                       weighted_macro_recall=best_weighted_macro_recall,
                       weighted_macro_f1=best_weighted_macro_f1,
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
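    # Training finished: report the metrics recorded at the best checkpoint and
    # plot the training/evaluation loss curves.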
    print_results(epoch=checkpoint_epoch,
                  accuracy=best_accuracy,
                  av_eval_loss=lowest_av_eval_loss,
                  micro_precision=best_micro_precision,
                  micro_recall=best_micro_recall,
                  micro_f1=best_micro_f1,
                  weighted_macro_precision=best_weighted_macro_precision,
                  weighted_macro_recall=best_weighted_macro_recall,
                  weighted_macro_f1=best_weighted_macro_f1,
                  final=True)
    plot_train_eval_loss(av_train_losses, av_eval_losses)
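# A minimal sketch of a command-line entry point for this training script. The
# flag names and defaults are illustrative assumptions, not part of the original
# code; learning_rate, num_epochs, etc. are module-level settings here.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Train the LSTM/BERT sequence tagger.')
    parser.add_argument('--data-path', required=True,
                        help='path to the training/validation data')
    parser.add_argument('--saved-model-path', default='',
                        help='optional checkpoint to resume training from')
    args = parser.parse_args()
    main(data_path=args.data_path, saved_model_path=args.saved_model_path)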
def main(data_path: str, saved_model_path: str) -> None:
    """The main testing function."""
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
        use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)
    use_bert = use_bert_cased or use_bert_uncased

    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos

    if use_bert:
        test_iter = create_bert_datasets(data_path=data_path,
                                         mode=TEST,
                                         use_bert_cased=use_bert_cased,
                                         use_bert_uncased=use_bert_uncased,
                                         use_bert_large=use_bert_large,
                                         tag_to_ix=tag_to_ix)
        vocab_size = None
    else:
        test_iter = create_datasets(data_path=data_path,
                                    mode=TEST,
                                    word_to_ix=word_to_ix,
                                    word_vocab=word_vocab,
                                    tag_vocab=tag_vocab)
        vocab_size = len(word_to_ix)

    # char_to_ix can grow during evaluation; size the char embeddings from a copy
    # of the vocabulary as it was saved.
    char_to_ix_original = copy.deepcopy(char_to_ix)

    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix_original),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=tag_to_ix['<pad>'])
    load_model(model=model, saved_model_path=saved_model_path)
    model.to(device)

    print("testing model: " + saved_model_path + '\n')
    model.eval()
    y_pred = []
    y_true = []
    av_test_losses = []
    with torch.no_grad():
        batch_num = 0
        test_losses = []
        for batch in test_iter:
            batch_num += 1
            if use_bert:
                sentences_in, attention_masks, token_start_idx, targets, original_sentences = batch
                sentences_in = sentences_in.to(device)
                attention_masks = attention_masks.to(device)
                targets = targets.to(device)
                # Trim padding columns beyond the longest sequence in this batch.
                max_length = (attention_masks != 0).max(0)[0].nonzero()[-1].item() + 1
                if max_length < sentences_in.shape[1]:
                    sentences_in = sentences_in[:, :max_length]
                    attention_masks = attention_masks[:, :max_length]
                sent_batch_size = sentences_in.shape[0]
                original_sentences_split = [sent.split() for sent in original_sentences]
                word_batch_size = max(len(sent) for sent in original_sentences_split)
                sent_lengths = list(map(len, token_start_idx))
            else:
                word_batch_size = batch.sentence.shape[0]
                sent_batch_size = batch.sentence.shape[1]
                sentences_in = batch.sentence.permute(1, 0).to(device)
                targets = batch.tags.permute(1, 0).reshape(
                    sent_batch_size * word_batch_size).to(device)
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = test_iter.sent_lengths[batch_num - 1]

            y_true += [ix_to_tag[ix.item()] for ix in targets]
            words_in = get_words_in(sentences_in=sentences_in,
                                    char_to_ix=char_to_ix,
                                    ix_to_word=ix_to_word,
                                    device=device,
                                    original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size, device=device)
            tag_logits = model(sentences=sentences_in,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)

            # Normalise the loss by the number of non-pad targets (pad index is 1).
            mask = targets != 1
            test_loss = loss_function(tag_logits, targets)
            test_loss /= mask.float().sum()
            test_losses.append(test_loss.item())
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            y_pred += pred
        av_test_losses.append(sum(test_losses) / len(test_losses))

    # Strip pad positions before scoring.
    y_true, y_pred = remove_pads(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
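    # 'micro' pools all token-level decisions before computing precision/recall/F1;
    # 'weighted' averages per-class scores weighted by class support.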
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='micro')
    weighted_macro_precision, weighted_macro_recall, weighted_macro_f1, _ = \
        precision_recall_fscore_support(y_true, y_pred, average='weighted')
    av_test_loss = sum(test_losses) / len(test_losses)

    print("Test accuracy: {:.2f}%".format(accuracy * 100))
    print("Average test loss: {}".format(av_test_loss))
    print("Micro precision: {}".format(micro_precision))
    print("Micro recall: {}".format(micro_recall))
    print("Micro F1: {}".format(micro_f1))
    print("Weighted macro precision: {}".format(weighted_macro_precision))
    print("Weighted macro recall: {}".format(weighted_macro_recall))
    print("Weighted macro F1: {}".format(weighted_macro_f1))
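# For reference, the two helpers used above behave roughly like the sketches
# below. These are illustrative re-implementations inferred from their call
# sites (hence the *_sketch names), not the repo's actual definitions.
def categories_from_output_sketch(tag_logits, ix_to_tag):
    """What categoriesFromOutput presumably does: argmax each row of logits
    and map the winning index to its tag string."""
    return [ix_to_tag[ix.item()] for ix in tag_logits.argmax(dim=1)]


def remove_pads_sketch(y_true, y_pred):
    """What remove_pads presumably does: drop positions whose gold label is
    the pad tag so they do not distort the scores."""
    kept = [(t, p) for t, p in zip(y_true, y_pred) if t != '<pad>']
    return [t for t, _ in kept], [p for _, p in kept]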
def main(data_path: str, dest_path: str, saved_model_path: str) -> None:
    """The function for tagging unseen data."""
    embedding_dim, char_embedding_dim, hidden_dim, char_hidden_dim, \
        use_bert_cased, use_bert_uncased, use_bert_large = load_hyper_params(saved_model_path)

    if use_bert_uncased:
        bert_model = 'bert-large-uncased' if use_bert_large else 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
    elif use_bert_cased:
        bert_model = 'bert-large-cased' if use_bert_large else 'bert-base-cased'
        tokenizer = BertTokenizer.from_pretrained(bert_model)
    else:
        tokenizer = None
    use_bert = use_bert_uncased or use_bert_cased

    word_to_ix, ix_to_word, word_vocab, tag_vocab, char_to_ix = load_vocab_and_char_to_ix(
        saved_model_path)
    tag_to_ix, ix_to_tag = tag_vocab.stoi, tag_vocab.itos
    vocab_size = None if use_bert else len(word_to_ix)

    sentences, sent_tensors = prepare_untagged_data(data_path=data_path,
                                                    word_to_ix=copy.deepcopy(word_to_ix),
                                                    device=device,
                                                    tokenizer=tokenizer,
                                                    use_bert=use_bert)

    model = LSTMTagger(embedding_dim=embedding_dim,
                       hidden_dim=hidden_dim,
                       vocab_size=vocab_size,
                       tagset_size=len(tag_to_ix),
                       char_embedding_dim=char_embedding_dim,
                       char_hidden_dim=char_hidden_dim,
                       char_vocab_size=len(char_to_ix),
                       use_bert_cased=use_bert_cased,
                       use_bert_uncased=use_bert_uncased,
                       use_bert_large=use_bert_large)
    load_model(model=model, saved_model_path=saved_model_path)
    model.to(device)

    print("\ntagging sentences using model: " + saved_model_path + '\n')
    model.eval()
    with open(dest_path, 'w') as dest_file, torch.no_grad():
        for sent_num, sent_tensor in enumerate(sent_tensors):
            sentence = sentences[sent_num]
            word_batch_size = len(sentence)
            sent_batch_size = 1
            if use_bert:
                # Reconstruct the word-to-subword alignment for this sentence.
                subwords = list(map(tokenizer.tokenize, sentence))
                subword_lengths = list(map(len, subwords))
                token_start_idx = [list(np.cumsum([0] + subword_lengths))[1:]]
                sent_lengths = [len(token_start_idx[0])]
                attention_masks = torch.tensor(
                    [1] * (len(token_start_idx[0]) + 2)).unsqueeze(0).to(device)
                original_sentences_split = [sentence]
            else:
                attention_masks = None
                token_start_idx = None
                original_sentences_split = None
                sent_lengths = [len(sentence)]

            words_in = get_words_in(sentences_in=sent_tensor,
                                    char_to_ix=char_to_ix,
                                    ix_to_word=ix_to_word,
                                    device=device,
                                    original_sentences_split=original_sentences_split)
            model.init_hidden(sent_batch_size=sent_batch_size, device=device)
            tag_logits = model(sentences=sent_tensor,
                               words=words_in,
                               char_hidden_dim=char_hidden_dim,
                               sent_lengths=sent_lengths,
                               word_batch_size=word_batch_size,
                               device=device,
                               attention_masks=attention_masks,
                               token_start_idx=token_start_idx)
            pred = categoriesFromOutput(tag_logits, ix_to_tag)
            dest_file.write(" ".join(sentence) + '\t' + " ".join(pred) + '\n')
    print("tagging complete")
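# Each output line pairs a sentence with its predicted tag sequence, e.g.
#   "The dog barks\tDET NOUN VERB" (tags invented here for illustration).
#
# A minimal sketch of a command-line entry point for this tagging script; the
# flag names are illustrative assumptions, not part of the original code.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Tag unseen sentences with a saved model.')
    parser.add_argument('--data-path', required=True, help='file of untagged sentences')
    parser.add_argument('--dest-path', required=True, help='where to write the tagged output')
    parser.add_argument('--saved-model-path', required=True, help='trained model checkpoint')
    args = parser.parse_args()
    main(data_path=args.data_path, dest_path=args.dest_path,
         saved_model_path=args.saved_model_path)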