def main():
    """Train and evaluate a bi-directional LSTM-CNN-CRF sequence labeler.

    Parses command-line options, builds alphabets and datasets from the
    --train/--dev/--test files, compiles Theano training and evaluation
    functions, and runs ``num_epochs`` epochs of Nesterov-momentum training.

    The training loss is the chain-CRF loss plus ``delta`` times an
    expectation-linear term, plus an optional L2 penalty.

    After every epoch the dev set is evaluated. Whenever dev accuracy
    excluding mask column 0 ("no root") improves, or on the first epoch,
    the test set is evaluated and its scores are kept as the best result.

    The learning rate is multiplied by ``decay_rate`` at every epoch listed
    in --schedule; when --schedule is omitted the rate is never decayed.
    """
    args_parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CNN-CRF')
    args_parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=10, help='Number of sentences in each batch')
    args_parser.add_argument('--num_units', type=int, default=100, help='Number of hidden units in LSTM')
    args_parser.add_argument('--num_filters', type=int, default=20, help='Number of filters in CNN')
    args_parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    args_parser.add_argument('--grad_clipping', type=float, default=0, help='Gradient clipping')
    args_parser.add_argument('--gamma', type=float, default=1e-6, help='weight for regularization')
    args_parser.add_argument('--delta', type=float, default=0.0, help='weight for expectation-linear regularization')
    args_parser.add_argument('--regular', choices=['none', 'l2'], help='regularization for training', required=True)
    args_parser.add_argument('--dropout', choices=['std', 'recurrent'], help='dropout patten')
    args_parser.add_argument('--schedule', nargs='+', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--output_prediction', action='store_true', help='Output predictions to temp files')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args = args_parser.parse_args()

    logger = get_logger("Sequence Labeling")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    num_filters = args.num_filters
    regular = args.regular
    grad_clipping = args.grad_clipping
    gamma = args.gamma
    delta = args.delta
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    # --schedule is optional. argparse leaves it as None when omitted, and
    # `epoch in None` would raise TypeError at the end of the first epoch.
    schedule = args.schedule or []
    dropout = args.dropout
    p = 0.5

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = data_utils.create_alphabets(
        "data/alphabets/", [train_path, dev_path, test_path], 40000)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    num_labels = pos_alphabet.size() - 1

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    data_dev = data_utils.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    data_test = data_utils.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    num_data = sum(len(bucket) for bucket in data_train)

    logger.info("constructing network...")
    # create variables
    target_var = T.imatrix(name='targets')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    mask_nr_var = T.matrix(name='masks_nr', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var, char_var, mask_var, word_alphabet, char_alphabet, dropout, num_units,
                            num_labels, grad_clipping, num_filters, p)

    logger.info("Network structure: hidden=%d, filter=%d, dropout=%s" % (num_units, num_filters, dropout))
    # compute loss
    num_tokens = mask_var.sum(dtype=theano.config.floatX)
    num_tokens_nr = mask_nr_var.sum(dtype=theano.config.floatX)

    # get output of bi-lstm-cnn-crf shape [batch, length, num_labels, num_labels]
    energies_train = lasagne.layers.get_output(network)
    energies_train_det = lasagne.layers.get_output(network, deterministic=True)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train_org = chain_crf_loss(energies_train, target_var, mask_var).mean()

    energy_shape = energies_train.shape
    # [batch, length, num_labels, num_labels] --> [batch*length, num_labels*num_labels]
    energies = T.reshape(energies_train, (energy_shape[0] * energy_shape[1], energy_shape[2] * energy_shape[3]))
    energies = nonlinearities.softmax(energies)
    energies_det = T.reshape(energies_train_det,
                             (energy_shape[0] * energy_shape[1], energy_shape[2] * energy_shape[3]))
    energies_det = nonlinearities.softmax(energies_det)
    # [batch*length, num_labels*num_labels] --> [batch, length*num_labels*num_labels]
    energies = T.reshape(energies, (energy_shape[0], energy_shape[1] * energy_shape[2] * energy_shape[3]))
    energies_det = T.reshape(energies_det, (energy_shape[0], energy_shape[1] * energy_shape[2] * energy_shape[3]))

    # Expectation-linear term: squared difference between the softmax of the
    # non-deterministic (training-mode) output and the softmax of the
    # deterministic output. It is summed per batch row, then averaged.
    loss_train_expect_linear = lasagne.objectives.squared_error(energies, energies_det)
    loss_train_expect_linear = loss_train_expect_linear.sum(axis=1)
    loss_train_expect_linear = loss_train_expect_linear.mean()

    loss_train = loss_train_org + delta * loss_train_expect_linear
    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    _, corr_train = chain_crf_accuracy(energies_train, target_var)
    corr_nr_train = (corr_train * mask_nr_var).sum(dtype=theano.config.floatX)
    corr_train = (corr_train * mask_var).sum(dtype=theano.config.floatX)
    prediction_eval, corr_eval = chain_crf_accuracy(energies_eval, target_var)
    corr_nr_eval = (corr_eval * mask_nr_var).sum(dtype=theano.config.floatX)
    corr_eval = (corr_eval * mask_var).sum(dtype=theano.config.floatX)

    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = nesterov_momentum(loss_train, params=params, learning_rate=learning_rate, momentum=momentum)

    fn_inputs = [word_var, char_var, target_var, mask_var, mask_nr_var]
    train_outputs = [loss_train, loss_train_org, loss_train_expect_linear, corr_train, corr_nr_train,
                     num_tokens, num_tokens_nr]
    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function(fn_inputs, train_outputs, updates=updates)
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function(fn_inputs, [corr_eval, corr_nr_eval, num_tokens, num_tokens_nr, prediction_eval])

    def _no_root_mask(masks):
        """Return a copy of ``masks`` with column 0 zeroed (the position excluded by the "no root" metrics)."""
        masks_nr = np.copy(masks)
        masks_nr[:, 0] = 0
        return masks_nr

    def _evaluate(data):
        """Run ``eval_fn`` over every batch of ``data``.

        Returns the summed ``(corr, corr_nr, total, total_nr, num_inst)``.
        """
        sum_corr = 0.0
        sum_corr_nr = 0.0
        sum_total = 0
        sum_total_nr = 0
        num_inst = 0
        for wids, cids, pids, _, _, masks in data_utils.iterate_batch(data, batch_size):
            corr, corr_nr, num, num_nr, _ = eval_fn(wids, cids, pids, masks, _no_root_mask(masks))
            sum_corr += corr
            sum_corr_nr += corr_nr
            sum_total += num
            sum_total_nr += num_nr
            num_inst += wids.shape[0]
        return sum_corr, sum_corr_nr, sum_total, sum_total_nr, num_inst

    # Finally, launch the training loop.
    logger.info(
        "Start training: regularization: %s(%f), dropout: %s, delta: %.2f (#training data: %d, batch size: %d, clip: %.1f)..."
        % (regular, (0.0 if regular == 'none' else gamma), dropout, delta, num_data, batch_size, grad_clipping))
    # Floor division keeps num_batches an int on both Python 2 and Python 3.
    num_batches = num_data // batch_size + 1
    dev_correct = 0.0
    dev_correct_nr = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_correct_nr = 0.0
    test_total = 0
    test_total_nr = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (learning rate=%.4f, decay rate=%.4f): ' % (epoch, lr, decay_rate))
        train_err = 0.0
        train_err_org = 0.0
        train_err_linear = 0.0
        train_corr = 0.0
        train_corr_nr = 0.0
        train_total = 0
        train_total_nr = 0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in range(1, num_batches + 1):
            wids, cids, pids, _, _, masks = data_utils.get_batch(data_train, batch_size)
            masks_nr = _no_root_mask(masks)
            err, err_org, err_linear, corr, corr_nr, num, num_nr = train_fn(wids, cids, pids, masks, masks_nr)
            batch_inst = wids.shape[0]
            # Losses are batch means; weight by batch size to average per sentence.
            train_err += err * batch_inst
            train_err_org += err_org * batch_inst
            train_err_linear += err_linear * batch_inst
            train_corr += corr
            train_corr_nr += corr_nr
            train_total += num
            train_total_nr += num_nr
            train_inst += batch_inst
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log (backspaces overwrite the previous progress line)
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, loss_org: %.4f, loss_linear: %.4f, acc: %.2f%%, acc(no root): %.2f%%, time left (estimated): %.2fs' % (
                batch, num_batches, train_err / train_inst, train_err_org / train_inst,
                train_err_linear / train_inst, train_corr * 100 / train_total,
                train_corr_nr * 100 / train_total_nr, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        assert train_total == train_total_nr + train_inst
        sys.stdout.write("\b" * num_back)
        print('train: %d/%d loss: %.4f, loss_org: %.4f, loss_linear: %.4f, acc: %.2f%%, acc(no root): %.2f%%, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst, train_err_org / train_inst,
            train_err_linear / train_inst, train_corr * 100 / train_total,
            train_corr_nr * 100 / train_total_nr, time.time() - start_time))

        # evaluate performance on dev data
        dev_corr, dev_corr_nr, dev_total, dev_total_nr, dev_inst = _evaluate(data_dev)
        assert dev_total == dev_total_nr + dev_inst
        print('dev corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%%' % (
            dev_corr, dev_total, dev_corr * 100 / dev_total,
            dev_corr_nr, dev_total_nr, dev_corr_nr * 100 / dev_total_nr))

        # Always accept the first epoch. Otherwise a 0-correct first epoch would
        # leave test_total at 0 and the "best test" report below would divide by zero.
        if best_epoch == 0 or dev_correct_nr < dev_corr_nr:
            dev_correct = dev_corr
            dev_correct_nr = dev_corr_nr
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_correct, test_correct_nr, test_total, test_total_nr, test_inst = _evaluate(data_test)
            # Same invariant as dev: exactly one masked-out position per sentence.
            assert test_total == test_total_nr + test_inst

        print("best dev corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (
            dev_correct, dev_total, dev_correct * 100 / dev_total,
            dev_correct_nr, dev_total_nr, dev_correct_nr * 100 / dev_total_nr, best_epoch))
        print("best test corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (
            test_correct, test_total, test_correct * 100 / test_total,
            test_correct_nr, test_total_nr, test_correct_nr * 100 / test_total_nr, best_epoch))

        if epoch in schedule:
            lr = lr * decay_rate
            # The learning rate is baked into the compiled updates, so the
            # training function must be rebuilt after each decay.
            updates = nesterov_momentum(loss_train, params=params, learning_rate=lr, momentum=momentum)
            train_fn = theano.function(fn_inputs, train_outputs, updates=updates)
def main():
    """Train and evaluate the neural MST dependency parser with a tree-CRF loss.

    Parses command-line options, builds alphabets from the training file,
    reads train/dev/test data, and builds the network. It compiles Theano
    training and evaluation functions, using the optimizer selected by --opt
    (via ``create_updates``).

    Each epoch trains on ``num_batches`` batches drawn with
    ``data_utils.get_batch``. It then decodes the dev set with
    ``parser.decode_MST`` and scores it with ``parser.eval``; parses are
    written under ``<tmp>/dev_parse<epoch>``.

    Whenever unlabeled dev accuracy without punctuation ties or improves on
    the best so far, the test set is decoded and scored the same way
    (``<tmp>/test_parse<epoch>``).

    The learning rate is multiplied by ``decay_rate`` at every epoch listed
    in --schedule.
    """
    args_parser = argparse.ArgumentParser(description='Neural MST-Parser')
    args_parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=10, help='Number of sentences in each batch')
    args_parser.add_argument('--num_units', type=int, default=100, help='Number of hidden units in LSTM')
    args_parser.add_argument('--depth', type=int, default=2, help='Depth of LSTM layer')
    args_parser.add_argument('--mlp', type=int, default=1, help='Depth of MLP layer')
    args_parser.add_argument('--num_filters', type=int, default=20, help='Number of filters in CNN')
    args_parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    args_parser.add_argument('--grad_clipping', type=float, default=0, help='Gradient clipping')
    args_parser.add_argument('--peepholes', action='store_true', help='Peepholes for LSTM')
    args_parser.add_argument('--max_norm', type=float, default=0, help='weight for max-norm regularization')
    args_parser.add_argument('--gamma', type=float, default=1e-6, help='weight for regularization')
    args_parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    args_parser.add_argument('--delta', type=float, default=0.0, help='weight for expectation-linear regularization')
    args_parser.add_argument('--regular', choices=['none', 'l2'], help='regularization for training', required=True)
    args_parser.add_argument('--opt', choices=['adam', 'momentum'], help='optimization algorithm', required=True)
    args_parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    args_parser.add_argument('--schedule', nargs='+', type=int, help='schedule for learning rate decay', required=True)
    args_parser.add_argument('--pos', action='store_true', help='using pos embedding')
    args_parser.add_argument('--char', action='store_true', help='using cnn for character embedding')
    args_parser.add_argument('--normalize_digits', action='store_true', help='normalize digits')
    args_parser.add_argument('--output_prediction', action='store_true', help='Output predictions to temp files')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--train', help='path of training data')
    args_parser.add_argument('--dev', help='path of validation data')
    args_parser.add_argument('--test', help='path of test data')
    args_parser.add_argument('--embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words',
                             required=True)
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--embedding_dict', default='data/word2vec/GoogleNews-vectors-negative300.bin',
                             help='path for embedding dict')
    args_parser.add_argument('--char_dict', default='data/polyglot/polyglot-zh_char.pkl',
                             help='path for character embedding dict')
    args_parser.add_argument('--tmp', default='tmp', help='Directory for temp files.')

    args = args_parser.parse_args()

    logger = get_logger("Parsing")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    depth = args.depth
    mlp = args.mlp
    num_filters = args.num_filters
    regular = args.regular
    opt = args.opt
    grad_clipping = args.grad_clipping
    peepholes = args.peepholes
    gamma = args.gamma
    # NOTE: --delta, --max_norm and --output_prediction are accepted but not
    # used anywhere in this training script.
    learning_rate = args.learning_rate
    momentum = 0.9
    beta1 = 0.9
    beta2 = args.beta2
    decay_rate = args.decay_rate
    schedule = args.schedule
    use_pos = args.pos
    use_char = args.char
    normalize_digits = args.normalize_digits
    dropout = args.dropout
    punctuation = args.punctuation
    tmp_dir = args.tmp
    embedding = args.embedding
    char_embedding = args.char_embedding
    embedding_path = args.embedding_dict
    char_path = args.char_dict

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    logger.info("Creating Alphabets: normalize_digits=%s" % normalize_digits)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = data_utils.create_alphabets(
        "data/alphabets/", [train_path, ], 60000, min_occurence=1, normalize_digits=normalize_digits)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    num_types = type_alphabet.size()

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                      normalize_digits=normalize_digits)
    data_dev = data_utils.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                    normalize_digits=normalize_digits)
    data_test = data_utils.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                     normalize_digits=normalize_digits)

    num_data = sum(len(bucket) for bucket in data_train)

    logger.info("constructing network...(pos embedding=%s, character embedding=%s)" % (use_pos, use_char))
    # create variables
    head_var = T.imatrix(name='heads')
    type_var = T.imatrix(name='types')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    pos_var = T.imatrix(name='pos-inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var, char_var, pos_var, mask_var, word_alphabet, char_alphabet, pos_alphabet,
                            depth, num_units, num_types, grad_clipping, num_filters,
                            p=dropout, mlp=mlp, peepholes=peepholes,
                            use_char=use_char, use_pos=use_pos, normalize_digits=normalize_digits,
                            embedding=embedding, embedding_path=embedding_path,
                            char_embedding=char_embedding, char_path=char_path)

    logger.info("Network: depth=%d, hidden=%d, peepholes=%s, filter=%d, dropout=%s, #mlp=%d" % (
        depth, num_units, peepholes, num_filters, dropout, mlp))
    # compute loss
    energies_train = lasagne.layers.get_output(network)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train = tree_crf_loss(energies_train, head_var, type_var, mask_var).mean()
    loss_eval = tree_crf_loss(energies_eval, head_var, type_var, mask_var).mean()

    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    updates = create_updates(loss_train, network, opt, learning_rate, momentum, beta1, beta2)

    fn_inputs = [word_var, char_var, pos_var, head_var, type_var, mask_var]
    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function(fn_inputs, loss_train, updates=updates, on_unused_input='warn')
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function(fn_inputs, [loss_eval, energies_eval], on_unused_input='warn')

    def _evaluate(data, output_path):
        """Decode every batch of ``data`` with MST and score it with ``parser.eval``.

        ``output_path`` is forwarded to ``parser.eval``.
        Returns ``(err, ucorr, lcorr, total, ucorr_nopunc, lcorr_nopunc,
        total_nopunc, num_inst)``. ``err`` is the batch-mean loss weighted by
        batch size, and the rest are sums over all batches.
        """
        sum_err = 0.0
        sum_ucorr = 0.0
        sum_lcorr = 0.0
        sum_ucorr_nopunc = 0.0
        sum_lcorr_nopunc = 0.0
        sum_total = 0
        sum_total_nopunc = 0
        num_inst = 0
        for wids, cids, pids, hids, tids, masks in data_utils.iterate_batch(data, batch_size):
            err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
            sum_err += err * wids.shape[0]
            pars_pred, types_pred = parser.decode_MST(energies, masks)
            ucorr, lcorr, total, ucorr_nopunc, lcorr_nopunc, total_nopunc = parser.eval(
                wids, pids, pars_pred, types_pred, hids, tids, masks, output_path,
                word_alphabet, pos_alphabet, type_alphabet, punct_set=punct_set)
            num_inst += wids.shape[0]
            sum_ucorr += ucorr
            sum_lcorr += lcorr
            sum_total += total
            sum_ucorr_nopunc += ucorr_nopunc
            sum_lcorr_nopunc += lcorr_nopunc
            sum_total_nopunc += total_nopunc
        return (sum_err, sum_ucorr, sum_lcorr, sum_total,
                sum_ucorr_nopunc, sum_lcorr_nopunc, sum_total_nopunc, num_inst)

    def _print_best(label, ucorr, lcorr, total, epoch_of_best):
        """Print one "best ..." summary line (UAS/LAS as percentages)."""
        print('best %s: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            label, ucorr, lcorr, total, ucorr * 100 / total, lcorr * 100 / total, epoch_of_best))

    # Finally, launch the training loop.
    logger.info("Start training: (#training data: %d, batch size: %d, clip: %.1f)..." % (
        num_data, batch_size, grad_clipping))
    # Floor division keeps num_batches an int on both Python 2 and Python 3.
    num_batches = num_data // batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucorrect_nopunct = 0.0
    dev_lcorrect_nopunct = 0.0
    best_epoch = 0
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucorrect_nopunct = 0.0
    test_lcorrect_nopunct = 0.0
    test_total = 0
    test_total_nopunc = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (learning rate=%.5f, decay rate=%.4f, beta1=%.3f, beta2=%.3f): ' % (
            epoch, lr, decay_rate, beta1, beta2))
        train_err = 0.0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in range(1, num_batches + 1):
            wids, cids, pids, hids, tids, masks = data_utils.get_batch(data_train, batch_size)
            err = train_fn(wids, cids, pids, hids, tids, masks)
            # err is a batch mean; weight by batch size to average per sentence.
            train_err += err * wids.shape[0]
            train_inst += wids.shape[0]
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log (backspaces overwrite the previous progress line)
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, time left: %.2fs' % (
                batch, num_batches, train_err / train_inst, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        sys.stdout.write("\b" * num_back)
        print('train: %d/%d loss: %.4f, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst, time.time() - start_time))

        # evaluate performance on dev data
        (dev_err, dev_ucorr, dev_lcorr, dev_total,
         dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_inst) = _evaluate(
            data_dev, tmp_dir + '/dev_parse%d' % epoch)
        print('dev loss: %.4f' % (dev_err / dev_inst))
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
            dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc * 100 / dev_total_nopunc))

        # Model selection on unlabeled accuracy without punctuation; ties prefer the later epoch.
        if dev_ucorrect_nopunct <= dev_ucorr_nopunc:
            dev_ucorrect_nopunct = dev_ucorr_nopunc
            dev_lcorrect_nopunct = dev_lcorr_nopunc
            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            best_epoch = epoch

            (_, test_ucorrect, test_lcorrect, test_total,
             test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc, _) = _evaluate(
                data_test, tmp_dir + '/test_parse%d' % epoch)

        _print_best('dev W. Punct', dev_ucorrect, dev_lcorrect, dev_total, best_epoch)
        _print_best('dev Wo Punct', dev_ucorrect_nopunct, dev_lcorrect_nopunct, dev_total_nopunc, best_epoch)
        _print_best('test W. Punct', test_ucorrect, test_lcorrect, test_total, best_epoch)
        _print_best('test Wo Punct', test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc, best_epoch)

        if epoch in schedule:
            lr = lr * decay_rate
            # The learning rate is baked into the compiled updates, so the
            # training function must be rebuilt after each decay.
            updates = create_updates(loss_train, network, opt, lr, momentum, beta1, beta2)
            train_fn = theano.function(fn_inputs, loss_train, updates=updates, on_unused_input='warn')
def main():
    """Train and evaluate the neural MST dependency parser with a tree-CRF loss.

    Parses command-line options, builds alphabets from the training file,
    reads train/dev/test data, and builds the network. It compiles Theano
    training and evaluation functions, using the optimizer selected by --opt
    (via ``create_updates``).

    Each epoch trains on ``num_batches`` batches drawn with
    ``data_utils.get_batch``. It then decodes the dev set with
    ``parser.decode_MST`` and scores it with ``parser.eval``; parses are
    written under ``<tmp>/dev_parse<epoch>``.

    Whenever unlabeled dev accuracy without punctuation ties or improves on
    the best so far, the test set is decoded and scored the same way
    (``<tmp>/test_parse<epoch>``).

    The learning rate is multiplied by ``decay_rate`` at every epoch listed
    in --schedule.
    """
    args_parser = argparse.ArgumentParser(description='Neural MST-Parser')
    args_parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=10, help='Number of sentences in each batch')
    args_parser.add_argument('--num_units', type=int, default=100, help='Number of hidden units in LSTM')
    args_parser.add_argument('--depth', type=int, default=2, help='Depth of LSTM layer')
    args_parser.add_argument('--mlp', type=int, default=1, help='Depth of MLP layer')
    args_parser.add_argument('--num_filters', type=int, default=20, help='Number of filters in CNN')
    args_parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    args_parser.add_argument('--grad_clipping', type=float, default=0, help='Gradient clipping')
    args_parser.add_argument('--peepholes', action='store_true', help='Peepholes for LSTM')
    args_parser.add_argument('--max_norm', type=float, default=0, help='weight for max-norm regularization')
    args_parser.add_argument('--gamma', type=float, default=1e-6, help='weight for regularization')
    args_parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    args_parser.add_argument('--delta', type=float, default=0.0, help='weight for expectation-linear regularization')
    args_parser.add_argument('--regular', choices=['none', 'l2'], help='regularization for training', required=True)
    args_parser.add_argument('--opt', choices=['adam', 'momentum'], help='optimization algorithm', required=True)
    args_parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    args_parser.add_argument('--schedule', nargs='+', type=int, help='schedule for learning rate decay', required=True)
    args_parser.add_argument('--pos', action='store_true', help='using pos embedding')
    args_parser.add_argument('--char', action='store_true', help='using cnn for character embedding')
    args_parser.add_argument('--normalize_digits', action='store_true', help='normalize digits')
    args_parser.add_argument('--output_prediction', action='store_true', help='Output predictions to temp files')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--train', help='path of training data')
    args_parser.add_argument('--dev', help='path of validation data')
    args_parser.add_argument('--test', help='path of test data')
    args_parser.add_argument('--embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words',
                             required=True)
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--embedding_dict', default='data/word2vec/GoogleNews-vectors-negative300.bin',
                             help='path for embedding dict')
    args_parser.add_argument('--char_dict', default='data/polyglot/polyglot-zh_char.pkl',
                             help='path for character embedding dict')
    args_parser.add_argument('--tmp', default='tmp', help='Directory for temp files.')

    args = args_parser.parse_args()

    logger = get_logger("Parsing")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    depth = args.depth
    mlp = args.mlp
    num_filters = args.num_filters
    regular = args.regular
    opt = args.opt
    grad_clipping = args.grad_clipping
    peepholes = args.peepholes
    gamma = args.gamma
    # NOTE: --delta, --max_norm and --output_prediction are accepted but not
    # used anywhere in this training script.
    learning_rate = args.learning_rate
    momentum = 0.9
    beta1 = 0.9
    beta2 = args.beta2
    decay_rate = args.decay_rate
    schedule = args.schedule
    use_pos = args.pos
    use_char = args.char
    normalize_digits = args.normalize_digits
    dropout = args.dropout
    punctuation = args.punctuation
    tmp_dir = args.tmp
    embedding = args.embedding
    char_embedding = args.char_embedding
    embedding_path = args.embedding_dict
    char_path = args.char_dict

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    logger.info("Creating Alphabets: normalize_digits=%s" % normalize_digits)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = data_utils.create_alphabets(
        "data/alphabets/", [train_path, ], 60000, min_occurence=1, normalize_digits=normalize_digits)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    num_types = type_alphabet.size()

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                      normalize_digits=normalize_digits)
    data_dev = data_utils.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                    normalize_digits=normalize_digits)
    data_test = data_utils.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                     normalize_digits=normalize_digits)

    num_data = sum(len(bucket) for bucket in data_train)

    logger.info("constructing network...(pos embedding=%s, character embedding=%s)" % (use_pos, use_char))
    # create variables
    head_var = T.imatrix(name='heads')
    type_var = T.imatrix(name='types')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    pos_var = T.imatrix(name='pos-inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var, char_var, pos_var, mask_var, word_alphabet, char_alphabet, pos_alphabet,
                            depth, num_units, num_types, grad_clipping, num_filters,
                            p=dropout, mlp=mlp, peepholes=peepholes,
                            use_char=use_char, use_pos=use_pos, normalize_digits=normalize_digits,
                            embedding=embedding, embedding_path=embedding_path,
                            char_embedding=char_embedding, char_path=char_path)

    logger.info("Network: depth=%d, hidden=%d, peepholes=%s, filter=%d, dropout=%s, #mlp=%d" % (
        depth, num_units, peepholes, num_filters, dropout, mlp))
    # compute loss
    energies_train = lasagne.layers.get_output(network)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train = tree_crf_loss(energies_train, head_var, type_var, mask_var).mean()
    loss_eval = tree_crf_loss(energies_eval, head_var, type_var, mask_var).mean()

    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    updates = create_updates(loss_train, network, opt, learning_rate, momentum, beta1, beta2)

    fn_inputs = [word_var, char_var, pos_var, head_var, type_var, mask_var]
    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function(fn_inputs, loss_train, updates=updates, on_unused_input='warn')
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function(fn_inputs, [loss_eval, energies_eval], on_unused_input='warn')

    def _evaluate(data, output_path):
        """Decode every batch of ``data`` with MST and score it with ``parser.eval``.

        ``output_path`` is forwarded to ``parser.eval``.
        Returns ``(err, ucorr, lcorr, total, ucorr_nopunc, lcorr_nopunc,
        total_nopunc, num_inst)``. ``err`` is the batch-mean loss weighted by
        batch size, and the rest are sums over all batches.
        """
        sum_err = 0.0
        sum_ucorr = 0.0
        sum_lcorr = 0.0
        sum_ucorr_nopunc = 0.0
        sum_lcorr_nopunc = 0.0
        sum_total = 0
        sum_total_nopunc = 0
        num_inst = 0
        for wids, cids, pids, hids, tids, masks in data_utils.iterate_batch(data, batch_size):
            err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
            sum_err += err * wids.shape[0]
            pars_pred, types_pred = parser.decode_MST(energies, masks)
            ucorr, lcorr, total, ucorr_nopunc, lcorr_nopunc, total_nopunc = parser.eval(
                wids, pids, pars_pred, types_pred, hids, tids, masks, output_path,
                word_alphabet, pos_alphabet, type_alphabet, punct_set=punct_set)
            num_inst += wids.shape[0]
            sum_ucorr += ucorr
            sum_lcorr += lcorr
            sum_total += total
            sum_ucorr_nopunc += ucorr_nopunc
            sum_lcorr_nopunc += lcorr_nopunc
            sum_total_nopunc += total_nopunc
        return (sum_err, sum_ucorr, sum_lcorr, sum_total,
                sum_ucorr_nopunc, sum_lcorr_nopunc, sum_total_nopunc, num_inst)

    def _print_best(label, ucorr, lcorr, total, epoch_of_best):
        """Print one "best ..." summary line (UAS/LAS as percentages)."""
        print('best %s: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            label, ucorr, lcorr, total, ucorr * 100 / total, lcorr * 100 / total, epoch_of_best))

    # Finally, launch the training loop.
    logger.info("Start training: (#training data: %d, batch size: %d, clip: %.1f)..." % (
        num_data, batch_size, grad_clipping))
    # Floor division keeps num_batches an int on both Python 2 and Python 3.
    num_batches = num_data // batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucorrect_nopunct = 0.0
    dev_lcorrect_nopunct = 0.0
    best_epoch = 0
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucorrect_nopunct = 0.0
    test_lcorrect_nopunct = 0.0
    test_total = 0
    test_total_nopunc = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (learning rate=%.5f, decay rate=%.4f, beta1=%.3f, beta2=%.3f): ' % (
            epoch, lr, decay_rate, beta1, beta2))
        train_err = 0.0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in range(1, num_batches + 1):
            wids, cids, pids, hids, tids, masks = data_utils.get_batch(data_train, batch_size)
            err = train_fn(wids, cids, pids, hids, tids, masks)
            # err is a batch mean; weight by batch size to average per sentence.
            train_err += err * wids.shape[0]
            train_inst += wids.shape[0]
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log (backspaces overwrite the previous progress line)
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, time left: %.2fs' % (
                batch, num_batches, train_err / train_inst, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        sys.stdout.write("\b" * num_back)
        print('train: %d/%d loss: %.4f, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst, time.time() - start_time))

        # evaluate performance on dev data
        (dev_err, dev_ucorr, dev_lcorr, dev_total,
         dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_inst) = _evaluate(
            data_dev, tmp_dir + '/dev_parse%d' % epoch)
        print('dev loss: %.4f' % (dev_err / dev_inst))
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
            dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc * 100 / dev_total_nopunc))

        # Model selection on unlabeled accuracy without punctuation; ties prefer the later epoch.
        if dev_ucorrect_nopunct <= dev_ucorr_nopunc:
            dev_ucorrect_nopunct = dev_ucorr_nopunc
            dev_lcorrect_nopunct = dev_lcorr_nopunc
            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            best_epoch = epoch

            (_, test_ucorrect, test_lcorrect, test_total,
             test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc, _) = _evaluate(
                data_test, tmp_dir + '/test_parse%d' % epoch)

        _print_best('dev W. Punct', dev_ucorrect, dev_lcorrect, dev_total, best_epoch)
        _print_best('dev Wo Punct', dev_ucorrect_nopunct, dev_lcorrect_nopunct, dev_total_nopunc, best_epoch)
        _print_best('test W. Punct', test_ucorrect, test_lcorrect, test_total, best_epoch)
        _print_best('test Wo Punct', test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc, best_epoch)

        if epoch in schedule:
            lr = lr * decay_rate
            # The learning rate is baked into the compiled updates, so the
            # training function must be rebuilt after each decay.
            updates = create_updates(loss_train, network, opt, lr, momentum, beta1, beta2)
            train_fn = theano.function(fn_inputs, loss_train, updates=updates, on_unused_input='warn')