def train(args):
    logger = get_logger("Parsing")

    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')

    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    eps = args.eps
    betas = (args.beta1, args.beta2)
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip
    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace
    freeze = args.freeze

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    print(args)

    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)
    else:
        char_dict = None
        char_dim = None

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path, train_path, data_paths=[dev_path, test_path],
        embedd_dict=word_dict, max_vocabulary_size=200000)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(np.float32) if freeze \
            else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(np.float32) if freeze \
                    else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None
        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    logger.info("constructing network...")

    hyps = json.load(open(args.config, 'r'))
    json.dump(hyps, open(os.path.join(model_path, 'config.json'), 'w'), indent=2)
    model_type = hyps['model']
    assert model_type in ['DeepBiAffine', 'NeuroMST', 'StackPtr']
    assert word_dim == hyps['word_dim']
    if char_dim is not None:
        assert char_dim == hyps['char_dim']
    else:
        char_dim = hyps['char_dim']
    use_pos = hyps['pos']
    pos_dim = hyps['pos_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    arc_space = hyps['arc_space']
    type_space = hyps['type_space']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']
    prior_order = None

    alg = 'transition' if model_type == 'StackPtr' else 'graph'
    if model_type == 'DeepBiAffine':
        num_layers = hyps['num_layers']
        network = DeepBiAffine(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
                               mode, hidden_size, num_layers, num_types, arc_space, type_space,
                               embedd_word=word_table, embedd_char=char_table,
                               p_in=p_in, p_out=p_out, p_rnn=p_rnn, pos=use_pos, activation=activation)
    elif model_type == 'NeuroMST':
        num_layers = hyps['num_layers']
        network = NeuroMST(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
                           mode, hidden_size, num_layers, num_types, arc_space, type_space,
                           embedd_word=word_table, embedd_char=char_table,
                           p_in=p_in, p_out=p_out, p_rnn=p_rnn, pos=use_pos, activation=activation)
    elif model_type == 'StackPtr':
        encoder_layers = hyps['encoder_layers']
        decoder_layers = hyps['decoder_layers']
        num_layers = (encoder_layers, decoder_layers)
        prior_order = hyps['prior_order']
        grandPar = hyps['grandPar']
        sibling = hyps['sibling']
        network = StackPtrNet(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
                              mode, hidden_size, encoder_layers, decoder_layers, num_types,
                              arc_space, type_space,
                              embedd_word=word_table, embedd_char=char_table, prior_order=prior_order,
                              activation=activation, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                              pos=use_pos, grandPar=grandPar, sibling=sibling)
    else:
        raise RuntimeError('Unknown model type: %s' % model_type)

    if freeze:
        freeze_embedding(network.word_embed)

    network = network.to(device)

    model = "{}-{}".format(model_type, mode)
    logger.info("Network: %s, num_layer=%s, hidden=%d, act=%s" % (model, num_layers, hidden_size, activation))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" % ('variational', p_in, p_out, p_rnn))
    logger.info('# of Parameters: %d' % (sum([param.numel() for param in network.parameters()])))

    logger.info("Reading Data")
    if alg == 'graph':
        data_train = conllx_data.read_bucketed_data(train_path, word_alphabet, char_alphabet,
                                                    pos_alphabet, type_alphabet, symbolic_root=True)
        data_dev = conllx_data.read_data(dev_path, word_alphabet, char_alphabet,
                                         pos_alphabet, type_alphabet, symbolic_root=True)
        data_test = conllx_data.read_data(test_path, word_alphabet, char_alphabet,
                                          pos_alphabet, type_alphabet, symbolic_root=True)
    else:
        data_train = conllx_stacked_data.read_bucketed_data(train_path, word_alphabet, char_alphabet,
                                                            pos_alphabet, type_alphabet, prior_order=prior_order)
        data_dev = conllx_stacked_data.read_data(dev_path, word_alphabet, char_alphabet,
                                                 pos_alphabet, type_alphabet, prior_order=prior_order)
        data_test = conllx_stacked_data.read_data(test_path, word_alphabet, char_alphabet,
                                                  pos_alphabet, type_alphabet, prior_order=prior_order)
    num_data = sum(data_train[1])
    logger.info("training: #training data: %d, batch: %d, unk replace: %.2f" % (num_data, batch_size, unk_replace))

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    optimizer, scheduler = get_optimizer(network.parameters(), optim, learning_rate, lr_decay,
                                         betas, eps, amsgrad, weight_decay, warmup_steps)

    best_ucorrect = 0.0
    best_lcorrect = 0.0
    best_ucomplete = 0.0
    best_lcomplete = 0.0

    best_ucorrect_nopunc = 0.0
    best_lcorrect_nopunc = 0.0
    best_ucomplete_nopunc = 0.0
    best_lcomplete_nopunc = 0.0
    best_root_correct = 0.0
    best_total = 0
    best_total_nopunc = 0
    best_total_inst = 0
    best_total_root = 0
    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomplete = 0.0
    test_lcomplete = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomplete_nopunc = 0.0
    test_lcomplete_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    patient = 0
    beam = args.beam
    reset = args.reset
    num_batches = num_data // batch_size + 1
    if optim == 'adam':
        opt_info = 'adam, betas=(%.1f, %.3f), eps=%.1e, amsgrad=%s' % (betas[0], betas[1], eps, amsgrad)
    else:
        opt_info = 'sgd, momentum=0.9, nesterov=True'

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        train_arc_loss = 0.
        train_type_loss = 0.
        num_insts = 0
        num_words = 0
        num_back = 0
        num_nans = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print('Epoch %d (%s, lr=%.6f, lr decay=%.6f, grad clip=%.1f, l2=%.1e): ' % (
            epoch, opt_info, lr, lr_decay, grad_clip, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        with torch.autograd.set_detect_anomaly(True):
            for step, data in enumerate(iterate_data(data_train, batch_size, bucketed=True,
                                                     unk_replace=unk_replace, shuffle=True)):
                optimizer.zero_grad()
                bert_words = data["BERT_WORD"].to(device)
                sub_word_idx = data["SUB_IDX"].to(device)
                words = data['WORD'].to(device)
                chars = data['CHAR'].to(device)
                postags = data['POS'].to(device)
                heads = data['HEAD'].to(device)
                nbatch = words.size(0)
                if alg == 'graph':
                    types = data['TYPE'].to(device)
                    masks = data['MASK'].to(device)
                    nwords = masks.sum() - nbatch
                    BERT = True
                    if BERT:
                        loss_arc, loss_type = network.loss(bert_words, sub_word_idx, words, chars,
                                                           postags, heads, types, mask=masks)
                    else:
                        loss_arc, loss_type = network.loss(words, chars, postags, heads, types, mask=masks)
                else:
                    masks_enc = data['MASK_ENC'].to(device)
                    masks_dec = data['MASK_DEC'].to(device)
                    stacked_heads = data['STACK_HEAD'].to(device)
                    children = data['CHILD'].to(device)
                    siblings = data['SIBLING'].to(device)
                    stacked_types = data['STACK_TYPE'].to(device)
                    nwords = masks_enc.sum() - nbatch
                    loss_arc, loss_type = network.loss(words, chars, postags, heads, stacked_heads,
                                                       children, siblings, stacked_types,
                                                       mask_e=masks_enc, mask_d=masks_dec)
                loss_arc = loss_arc.sum()
                loss_type = loss_type.sum()
                loss_total = loss_arc + loss_type
                # print("loss", loss_arc, loss_type, loss_total)
                if loss_ty_token:
                    loss = loss_total.div(nwords)
                else:
                    loss = loss_total.div(nbatch)
                loss.backward()
                if grad_clip > 0:
                    grad_norm = clip_grad_norm_(network.parameters(), grad_clip)
                else:
                    grad_norm = total_grad_norm(network.parameters())

                if math.isnan(grad_norm):
                    num_nans += 1
                else:
                    optimizer.step()
                    scheduler.step()

                    with torch.no_grad():
                        num_insts += nbatch
                        num_words += nwords
                        train_loss += loss_total.item()
                        train_arc_loss += loss_arc.item()
                        train_type_loss += loss_type.item()

                # update log
                if step % 100 == 0:
                    torch.cuda.empty_cache()
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    curr_lr = scheduler.get_lr()[0]
                    num_insts = max(num_insts, 1)
                    num_words = max(num_words, 1)
                    log_info = '[%d/%d (%.0f%%) lr=%.6f (%d)] loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches, curr_lr, num_nans,
                        train_loss / num_insts, train_loss / num_words,
                        train_arc_loss / num_insts, train_arc_loss / num_words,
                        train_type_loss / num_insts, train_type_loss / num_words)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('total: %d (%d), loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f), time: %.2fs' % (
            num_insts, num_words, train_loss / num_insts, train_loss / num_words,
            train_arc_loss / num_insts, train_arc_loss / num_words,
            train_type_loss / num_insts, train_type_loss / num_words, time.time() - start_time))
        print('-' * 125)

        # evaluate performance on dev data
        with torch.no_grad():
            pred_filename = os.path.join(result_path, 'pred_dev%d' % epoch)
            pred_writer.start(pred_filename)
            gold_filename = os.path.join(result_path, 'gold_dev%d' % epoch)
            gold_writer.start(gold_filename)
            print('Evaluating dev:')
            dev_stats, dev_stats_nopunct, dev_stats_root = eval(alg, data_dev, network, pred_writer, gold_writer,
                                                                punct_set, word_alphabet, pos_alphabet, device, beam=beam)
            pred_writer.close()
            gold_writer.close()

            dev_ucorr, dev_lcorr, dev_ucomplete, dev_lcomplete, dev_total = dev_stats
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_ucomplete_nopunc, dev_lcomplete_nopunc, dev_total_nopunc = dev_stats_nopunct
            dev_root_corr, dev_total_root, dev_total_inst = dev_stats_root

            if best_ucorrect_nopunc + best_lcorrect_nopunc < dev_ucorr_nopunc + dev_lcorr_nopunc:
                best_ucorrect_nopunc = dev_ucorr_nopunc
                best_lcorrect_nopunc = dev_lcorr_nopunc
                best_ucomplete_nopunc = dev_ucomplete_nopunc
                best_lcomplete_nopunc = dev_lcomplete_nopunc

                best_ucorrect = dev_ucorr
                best_lcorrect = dev_lcorr
                best_ucomplete = dev_ucomplete
                best_lcomplete = dev_lcomplete

                best_root_correct = dev_root_corr
                best_total = dev_total
                best_total_nopunc = dev_total_nopunc
                best_total_root = dev_total_root
                best_total_inst = dev_total_inst

                best_epoch = epoch
                patient = 0
                torch.save(network.state_dict(), model_name)

                pred_filename = os.path.join(result_path, 'pred_test%d' % epoch)
                pred_writer.start(pred_filename)
                gold_filename = os.path.join(result_path, 'gold_test%d' % epoch)
                gold_writer.start(gold_filename)
                print('Evaluating test:')
                test_stats, test_stats_nopunct, test_stats_root = eval(alg, data_test, network, pred_writer, gold_writer,
                                                                       punct_set, word_alphabet, pos_alphabet, device, beam=beam)

                test_ucorrect, test_lcorrect, test_ucomplete, test_lcomplete, test_total = test_stats
                test_ucorrect_nopunc, test_lcorrect_nopunc, test_ucomplete_nopunc, test_lcomplete_nopunc, test_total_nopunc = test_stats_nopunct
                test_root_correct, test_total_root, test_total_inst = test_stats_root

                pred_writer.close()
                gold_writer.close()
            else:
                patient += 1

            print('-' * 125)
            print('best dev W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                best_ucorrect, best_lcorrect, best_total,
                best_ucorrect * 100 / best_total, best_lcorrect * 100 / best_total,
                best_ucomplete * 100 / best_total_inst, best_lcomplete * 100 / best_total_inst, best_epoch))
            print('best dev Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                best_ucorrect_nopunc, best_lcorrect_nopunc, best_total_nopunc,
                best_ucorrect_nopunc * 100 / best_total_nopunc, best_lcorrect_nopunc * 100 / best_total_nopunc,
                best_ucomplete_nopunc * 100 / best_total_inst, best_lcomplete_nopunc * 100 / best_total_inst, best_epoch))
            print('best dev Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
                best_root_correct, best_total_root, best_root_correct * 100 / best_total_root, best_epoch))
            print('-' * 125)
            print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect, test_lcorrect, test_total,
                test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
                test_ucomplete * 100 / test_total_inst, test_lcomplete * 100 / test_total_inst, best_epoch))
            print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
                test_ucomplete_nopunc * 100 / test_total_inst, test_lcomplete_nopunc * 100 / test_total_inst, best_epoch))
            print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
                test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
            print('=' * 125)

            if patient >= reset:
                logger.info('reset optimizer momentums')
                network.load_state_dict(torch.load(model_name, map_location=device))
                scheduler.reset_state()
                patient = 0
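# ---------------------------------------------------------------------------
# Illustrative hyper-parameter config consumed by train() via `--config`.
# The key names mirror the `hyps[...]` lookups above; the values are
# assumptions chosen for illustration only, not settings from any experiment.
# A StackPtr config would additionally need 'encoder_layers', 'decoder_layers',
# 'prior_order', 'grandPar' and 'sibling' instead of 'num_layers'.
# ---------------------------------------------------------------------------
EXAMPLE_PARSER_CONFIG = {
    "model": "DeepBiAffine",   # one of: DeepBiAffine, NeuroMST, StackPtr
    "word_dim": 100,           # must equal the dimension of the pretrained word embeddings
    "char_dim": 100,
    "pos": True,
    "pos_dim": 100,
    "rnn_mode": "LSTM",
    "num_layers": 3,
    "hidden_size": 512,
    "arc_space": 512,
    "type_space": 128,
    "p_in": 0.33,
    "p_out": 0.33,
    "p_rnn": [0.33, 0.33],
    "activation": "elu",
}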
def main():
    parser = argparse.ArgumentParser(description='POS tagging with bi-directional RNN-CNN')
    parser.add_argument('--config', type=str, help='config file', required=True)
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Number of sentences in each batch')
    parser.add_argument('--loss_type', choices=['sentence', 'token'], default='sentence',
                        help='loss type (default: sentence)')
    parser.add_argument('--optim', choices=['sgd', 'adam'], help='type of optimizer', required=True)
    parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--lr_decay', type=float, default=0.999995, help='Decay rate of learning rate')
    parser.add_argument('--amsgrad', action='store_true', help='AMS Grad')
    parser.add_argument('--grad_clip', type=float, default=0, help='max norm for gradient clip (default 0: no clip)')
    parser.add_argument('--warmup_steps', type=int, default=0, metavar='N',
                        help='number of steps to warm up (default: 0)')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight for l2 norm decay')
    parser.add_argument('--unk_replace', type=float, default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding', choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words', required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--train', help='path for training file.', required=True)
    parser.add_argument('--dev', help='path for dev file.', required=True)
    parser.add_argument('--test', help='path for test file.', required=True)
    parser.add_argument('--model_path', help='path for saving model file.', required=True)
    args = parser.parse_args()

    logger = get_logger("POS")

    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip
    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    embedding = args.embedding
    embedding_path = args.embedding_dict

    print(args)

    embedd_dict, embedd_dim = utils.load_embedding_dict(embedding, embedding_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path, train_path, data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    logger.info("Reading Data")

    data_train = conllx_data.read_bucketed_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    num_data = sum(data_train[1])
    num_labels = pos_alphabet.size()

    data_dev = conllx_data.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    data_test = conllx_data.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    logger.info("constructing network...")

    hyps = json.load(open(args.config, 'r'))
    json.dump(hyps, open(os.path.join(model_path, 'config.json'), 'w'), indent=2)
    dropout = hyps['dropout']
    crf = hyps['crf']
    bigram = hyps['bigram']
    assert embedd_dim == hyps['embedd_dim']
    char_dim = hyps['char_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    out_features = hyps['out_features']
    num_layers = hyps['num_layers']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']

    if dropout == 'std':
        if crf:
            network = BiRecurrentConvCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                         mode, hidden_size, out_features, num_layers, num_labels,
                                         embedd_word=word_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                                         bigram=bigram, activation=activation)
        else:
            network = BiRecurrentConv(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                      mode, hidden_size, out_features, num_layers, num_labels,
                                      embedd_word=word_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                                      activation=activation)
    elif dropout == 'variational':
        if crf:
            network = BiVarRecurrentConvCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                            mode, hidden_size, out_features, num_layers, num_labels,
                                            embedd_word=word_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                                            bigram=bigram, activation=activation)
        else:
            network = BiVarRecurrentConv(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                         mode, hidden_size, out_features, num_layers, num_labels,
                                         embedd_word=word_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                                         activation=activation)
    else:
        raise ValueError('Unknown dropout type: {}'.format(dropout))

    network = network.to(device)

    optimizer, scheduler = get_optimizer(network.parameters(), optim, learning_rate, lr_decay,
                                         amsgrad, weight_decay, warmup_steps)
    model = "{}-CNN{}".format(mode, "-CRF" if crf else "")
    logger.info("Network: %s, num_layer=%d, hidden=%d, act=%s" % (model, num_layers, hidden_size, activation))
    logger.info("training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)" % (
        weight_decay, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" % (dropout, p_in, p_out, p_rnn))
    print('# of Parameters: %d' % (sum([param.numel() for param in network.parameters()])))

    best_corr = 0.0
    best_total = 0.0
    test_corr = 0.0
    test_total = 0.0
    best_epoch = 0
    patient = 0
    num_batches = num_data // batch_size + 1
    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        num_insts = 0
        num_words = 0
        num_back = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print('Epoch %d (%s, lr=%.6f, lr decay=%.6f, amsgrad=%s, l2=%.1e): ' % (
            epoch, optim, lr, lr_decay, amsgrad, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        for step, data in enumerate(iterate_data(data_train, batch_size, bucketed=True,
                                                 unk_replace=unk_replace, shuffle=True)):
            optimizer.zero_grad()
            words = data['WORD'].to(device)
            chars = data['CHAR'].to(device)
            labels = data['POS'].to(device)
            masks = data['MASK'].to(device)

            nbatch = words.size(0)
            nwords = masks.sum().item()

            loss_total = network.loss(words, chars, labels, mask=masks).sum()
            if loss_ty_token:
                loss = loss_total.div(nwords)
            else:
                loss = loss_total.div(nbatch)
            loss.backward()
            if grad_clip > 0:
                clip_grad_norm_(network.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                num_insts += nbatch
                num_words += nwords
                train_loss += loss_total.item()

            # update log
            if step % 100 == 0:
                torch.cuda.empty_cache()
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                curr_lr = scheduler.get_lr()[0]
                log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                    step, num_batches, 100. * step / num_batches, curr_lr,
                    train_loss / num_insts, train_loss / num_words)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('total: %d (%d), loss: %.4f (%.4f), time: %.2fs' % (
            num_insts, num_words, train_loss / num_insts, train_loss / num_words, time.time() - start_time))
        print('-' * 100)

        # evaluate performance on dev data
        with torch.no_grad():
            dev_corr, dev_total = eval(data_dev, network, device)
            print('Dev corr: %d, total: %d, acc: %.2f%%' % (dev_corr, dev_total, dev_corr * 100 / dev_total))
            if best_corr < dev_corr:
                torch.save(network.state_dict(), model_name)
                best_corr = dev_corr
                best_total = dev_total
                best_epoch = epoch

                # evaluate on test data when better performance detected
                test_corr, test_total = eval(data_test, network, device)
                print('test corr: %d, total: %d, acc: %.2f%%' % (test_corr, test_total, test_corr * 100 / test_total))
                patient = 0
            else:
                patient += 1

            print('-' * 100)
            print("Best dev corr: %d, total: %d, acc: %.2f%% (epoch: %d (%d))" % (
                best_corr, best_total, best_corr * 100 / best_total, best_epoch, patient))
            print("Best test corr: %d, total: %d, acc: %.2f%% (epoch: %d (%d))" % (
                test_corr, test_total, test_corr * 100 / test_total, best_epoch, patient))
            print('=' * 100)

            if patient > 4:
                logger.info('reset optimizer momentums')
                scheduler.reset_state()
                patient = 0
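# ---------------------------------------------------------------------------
# Example invocation of the tagger entry point above. The script filename and
# all paths/values are placeholders for illustration, not settings from any
# experiment; only the flags themselves come from the argparse definition in
# main().
#
#   python pos_tagging.py --config configs/pos.json --optim sgd \
#       --learning_rate 0.1 --lr_decay 0.999995 --batch_size 16 --num_epochs 100 \
#       --embedding glove --embedding_dict glove.txt \
#       --train data/train.conllx --dev data/dev.conllx --test data/test.conllx \
#       --model_path models/pos/
# ---------------------------------------------------------------------------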