Example #1
def main():
    args_parser = argparse.ArgumentParser(description='Neural MST-Parser')
    args_parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=10, help='Number of sentences in each batch')
    args_parser.add_argument('--num_units', type=int, default=100, help='Number of hidden units in LSTM')
    args_parser.add_argument('--depth', type=int, default=2, help='Depth of LSTM layer')
    args_parser.add_argument('--mlp', type=int, default=1, help='Depth of MLP layer')
    args_parser.add_argument('--num_filters', type=int, default=20, help='Number of filters in CNN')
    args_parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    args_parser.add_argument('--grad_clipping', type=float, default=0, help='Gradient clipping')
    args_parser.add_argument('--peepholes', action='store_true', help='Peepholes for LSTM')
    args_parser.add_argument('--max_norm', type=float, default=0, help='weight for max-norm regularization')
    args_parser.add_argument('--gamma', type=float, default=1e-6, help='weight for regularization')
    args_parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    args_parser.add_argument('--delta', type=float, default=0.0, help='weight for expectation-linear regularization')
    args_parser.add_argument('--regular', choices=['none', 'l2'], help='regularization for training', required=True)
    args_parser.add_argument('--opt', choices=['adam', 'momentum'], help='optimization algorithm', required=True)
    args_parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    args_parser.add_argument('--schedule', nargs='+', type=int, help='schedule for learning rate decay', required=True)
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay', required=True)
    args_parser.add_argument('--pos', action='store_true', help='using pos embedding')
    args_parser.add_argument('--char', action='store_true', help='using cnn for character embedding')
    args_parser.add_argument('--normalize_digits', action='store_true', help='normalize digits')
    args_parser.add_argument('--output_prediction', action='store_true', help='Output predictions to temp files')
    # args_parser.add_argument('--punctuation', default=None, help='List of punctuations separated by whitespace')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--train', help='path of training data')
    args_parser.add_argument('--dev', help='path of validation data')
    args_parser.add_argument('--test', help='path of test data')
    args_parser.add_argument('--embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words',
                             required=True)
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--embedding_dict', default='data/word2vec/GoogleNews-vectors-negative300.bin',
                             help='path for embedding dict')
    args_parser.add_argument('--char_dict', default='data/polyglot/polyglot-zh_char.pkl',
                             help='path for character embedding dict')
    args_parser.add_argument('--tmp', default='tmp', help='Directory for temp files.')

    args = args_parser.parse_args()

    logger = get_logger("Parsing")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    depth = args.depth
    mlp = args.mlp
    num_filters = args.num_filters
    regular = args.regular
    opt = args.opt
    grad_clipping = args.grad_clipping
    peepholes = args.peepholes
    gamma = args.gamma
    delta = args.delta
    max_norm = args.max_norm
    learning_rate = args.learning_rate
    momentum = 0.9
    beta1 = 0.9
    beta2 = args.beta2
    decay_rate = args.decay_rate
    schedule = args.schedule
    use_pos = args.pos
    use_char = args.char
    normalize_digits = args.normalize_digits
    output_predict = args.output_prediction
    dropout = args.dropout
    punctuation = args.punctuation
    tmp_dir = args.tmp
    embedding = args.embedding
    char_embedding = args.char_embedding
    embedding_path = args.embedding_dict
    char_path = args.char_dict

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    logger.info("Creating Alphabets: normalize_digits=%s" % normalize_digits)
    word_alphabet, char_alphabet, \
    pos_alphabet, type_alphabet = data_utils.create_alphabets("data/alphabets/", [train_path,],
                                                              60000, min_occurence=1, normalize_digits=normalize_digits)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                      normalize_digits=normalize_digits)
    data_dev = data_utils.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                    normalize_digits=normalize_digits)
    data_test = data_utils.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
                                     normalize_digits=normalize_digits)

    num_data = sum([len(bucket) for bucket in data_train])

    logger.info("constructing network...(pos embedding=%s, character embedding=%s)" % (use_pos, use_char))
    # create variables
    head_var = T.imatrix(name='heads')
    type_var = T.imatrix(name='types')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    pos_var = T.imatrix(name='pos-inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var, char_var, pos_var, mask_var, word_alphabet, char_alphabet, pos_alphabet,
                            depth, num_units, num_types, grad_clipping, num_filters,
                            p=dropout, mlp=mlp, peepholes=peepholes,
                            use_char=use_char, use_pos=use_pos, normalize_digits=normalize_digits,
                            embedding=embedding, embedding_path=embedding_path,
                            char_embedding=char_embedding, char_path=char_path)

    logger.info("Network: depth=%d, hidden=%d, peepholes=%s, filter=%d, dropout=%s, #mlp=%d" % (
        depth, num_units, peepholes, num_filters, dropout, mlp))
    # compute loss
    energies_train = lasagne.layers.get_output(network)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train = tree_crf_loss(energies_train, head_var, type_var, mask_var).mean()
    loss_eval = tree_crf_loss(energies_eval, head_var, type_var, mask_var).mean()
    # loss_train, E, D, L, lengths = tree_crf_loss(energies_train, head_var, type_var, mask_var)
    # loss_train = loss_train.mean()
    # loss_eval, _, _, _, _ = tree_crf_loss(energies_eval, head_var, type_var, mask_var)
    # loss_eval = loss_eval.mean()

    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    updates = create_updates(loss_train, network, opt, learning_rate, momentum, beta1, beta2)

    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function([word_var, char_var, pos_var, head_var, type_var, mask_var], loss_train, updates=updates,
                               on_unused_input='warn')
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function([word_var, char_var, pos_var, head_var, type_var, mask_var], [loss_eval, energies_eval],
                              on_unused_input='warn')

    # Finally, launch the training loop.
    logger.info("Start training: (#training data: %d, batch size: %d, clip: %.1f)..." % (
        num_data, batch_size, grad_clipping))

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucorrect_nopunct = 0.0
    dev_lcorrect_nopunct = 0.0
    best_epoch = 0
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucorrect_nopunct = 0.0
    test_lcorrect_nopunct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_inst = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print 'Epoch %d (learning rate=%.5f, decay rate=%.4f, beta1=%.3f, beta2=%.3f): ' % (
            epoch, lr, decay_rate, beta1, beta2)
        train_err = 0.0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in xrange(1, num_batches + 1):
            wids, cids, pids, hids, tids, masks = data_utils.get_batch(data_train, batch_size)
            err = train_fn(wids, cids, pids, hids, tids, masks)
            train_err += err * wids.shape[0]
            train_inst += wids.shape[0]
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, time left: %.2fs' % (
                batch, num_batches, train_err / train_inst, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        sys.stdout.write("\b" * num_back)
        print 'train: %d/%d loss: %.4f, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst, time.time() - start_time)

        # evaluate performance on dev data
        dev_err = 0.0
        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total = 0
        dev_total_nopunc = 0
        dev_inst = 0
        for batch in data_utils.iterate_batch(data_dev, batch_size):
            wids, cids, pids, hids, tids, masks = batch
            err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
            dev_err += err * wids.shape[0]
            pars_pred, types_pred = parser.decode_MST(energies, masks)
            ucorr, lcorr, total, ucorr_nopunc, \
            lcorr_nopunc, total_nopunc = parser.eval(wids, pids, pars_pred, types_pred, hids, tids, masks,
                                                     tmp_dir + '/dev_parse%d' % epoch, word_alphabet, pos_alphabet,
                                                     type_alphabet, punct_set=punct_set)
            dev_inst += wids.shape[0]

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
        print 'dev loss: %.4f' % (dev_err / dev_inst)
        print 'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total)
        print 'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc)

        if dev_ucorrect_nopunct <= dev_ucorr_nopunc:
            dev_ucorrect_nopunct = dev_ucorr_nopunc
            dev_lcorrect_nopunct = dev_lcorr_nopunc
            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            best_epoch = epoch

            test_err = 0.0
            test_ucorr = 0.0
            test_lcorr = 0.0
            test_ucorr_nopunc = 0.0
            test_lcorr_nopunc = 0.0
            test_total = 0
            test_total_nopunc = 0
            test_inst = 0
            for batch in data_utils.iterate_batch(data_test, batch_size):
                wids, cids, pids, hids, tids, masks = batch
                err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
                test_err += err * wids.shape[0]
                pars_pred, types_pred = parser.decode_MST(energies, masks)
                ucorr, lcorr, total, ucorr_nopunc, \
                lcorr_nopunc, total_nopunc = parser.eval(wids, pids, pars_pred, types_pred, hids, tids, masks,
                                                         tmp_dir + '/test_parse%d' % epoch, word_alphabet, pos_alphabet,
                                                         type_alphabet, punct_set=punct_set)
                test_inst += wids.shape[0]

                test_ucorr += ucorr
                test_lcorr += lcorr
                test_total += total

                test_ucorr_nopunc += ucorr_nopunc
                test_lcorr_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
            test_ucorrect = test_ucorr
            test_lcorrect = test_lcorr
            test_ucorrect_nopunct = test_ucorr_nopunc
            test_lcorrect_nopunct = test_lcorr_nopunc

        print 'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            best_epoch)
        print 'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunct, dev_lcorrect_nopunct, dev_total_nopunc, dev_ucorrect_nopunct * 100 / dev_total_nopunc,
            dev_lcorrect_nopunct * 100 / dev_total_nopunc, best_epoch)
        print 'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total,
            test_lcorrect * 100 / test_total, best_epoch)
        print 'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc,
            test_ucorrect_nopunct * 100 / test_total_nopunc, test_lcorrect_nopunct * 100 / test_total_nopunc,
            best_epoch)

        if epoch in schedule:
        # if epoch % schedule == 0:
            lr = lr * decay_rate
            # lr = learning_rate / (1.0 + epoch * decay_rate)
            updates = create_updates(loss_train, network, opt, lr, momentum, beta1, beta2)
            train_fn = theano.function([word_var, char_var, pos_var, head_var, type_var, mask_var], loss_train,
                                       updates=updates, on_unused_input='warn')
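
A quick illustration of the step-wise learning-rate schedule used above: whenever the current epoch appears in the --schedule list, the rate is multiplied by --decay_rate and the Theano training function is rebuilt. The framework-free sketch below reproduces only the schedule arithmetic; the base rate, decay rate, and schedule epochs are arbitrary placeholders, not values from the original run.

# Minimal sketch of the step-wise learning-rate schedule from the training loop above.
# The base rate, decay rate, and schedule epochs below are arbitrary placeholders.
def decayed_learning_rates(base_lr, decay_rate, schedule, num_epochs):
    lr = base_lr
    rates = []
    for epoch in range(1, num_epochs + 1):
        rates.append(lr)
        if epoch in schedule:  # decay after finishing a scheduled epoch
            lr = lr * decay_rate
    return rates

print(decayed_learning_rates(0.1, 0.1, schedule=[10, 30], num_epochs=40)[9:12])
# -> [0.1, ~0.01, ~0.01]: epoch 10 still trains at 0.1, epochs 11 onward at ~0.01
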
Example #2
def main():
    args_parser = argparse.ArgumentParser(description='Neural MST-Parser')
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=1000,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=10,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--num_units',
                             type=int,
                             default=100,
                             help='Number of hidden units in LSTM')
    args_parser.add_argument('--depth',
                             type=int,
                             default=2,
                             help='Depth of LSTM layer')
    args_parser.add_argument('--mlp',
                             type=int,
                             default=1,
                             help='Depth of MLP layer')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=20,
                             help='Number of filters in CNN')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.1,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.1,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--grad_clipping',
                             type=float,
                             default=0,
                             help='Gradient clipping')
    args_parser.add_argument('--peepholes',
                             action='store_true',
                             help='Peepholes for LSTM')
    args_parser.add_argument('--max_norm',
                             type=float,
                             default=0,
                             help='weight for max-norm regularization')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=1e-6,
                             help='weight for regularization')
    args_parser.add_argument('--beta2',
                             type=float,
                             default=0.9,
                             help='beta2 for adam')
    args_parser.add_argument(
        '--delta',
        type=float,
        default=0.0,
        help='weight for expectation-linear regularization')
    args_parser.add_argument('--regular',
                             choices=['none', 'l2'],
                             help='regularization for training',
                             required=True)
    args_parser.add_argument('--opt',
                             choices=['adam', 'momentum'],
                             help='optimization algorithm',
                             required=True)
    args_parser.add_argument('--dropout',
                             type=float,
                             default=0.5,
                             help='dropout rate')
    args_parser.add_argument('--schedule',
                             nargs='+',
                             type=int,
                             help='schedule for learning rate decay',
                             required=True)
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay', required=True)
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='using pos embedding')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='using cnn for character embedding')
    args_parser.add_argument('--normalize_digits',
                             action='store_true',
                             help='normalize digits')
    args_parser.add_argument('--output_prediction',
                             action='store_true',
                             help='Output predictions to temp files')
    # args_parser.add_argument('--punctuation', default=None, help='List of punctuations separated by whitespace')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--train', help='path of training data')
    args_parser.add_argument('--dev', help='path of validation data')
    args_parser.add_argument('--test', help='path of test data')
    args_parser.add_argument('--embedding',
                             choices=['glove', 'senna', 'sskip', 'polyglot'],
                             help='Embedding for words',
                             required=True)
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument(
        '--embedding_dict',
        default='data/word2vec/GoogleNews-vectors-negative300.bin',
        help='path for embedding dict')
    args_parser.add_argument('--char_dict',
                             default='data/polyglot/polyglot-zh_char.pkl',
                             help='path for character embedding dict')
    args_parser.add_argument('--tmp',
                             default='tmp',
                             help='Directory for temp files.')

    args = args_parser.parse_args()

    logger = get_logger("Parsing")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    depth = args.depth
    mlp = args.mlp
    num_filters = args.num_filters
    regular = args.regular
    opt = args.opt
    grad_clipping = args.grad_clipping
    peepholes = args.peepholes
    gamma = args.gamma
    delta = args.delta
    max_norm = args.max_norm
    learning_rate = args.learning_rate
    momentum = 0.9
    beta1 = 0.9
    beta2 = args.beta2
    decay_rate = args.decay_rate
    schedule = args.schedule
    use_pos = args.pos
    use_char = args.char
    normalize_digits = args.normalize_digits
    output_predict = args.output_prediction
    dropout = args.dropout
    punctuation = args.punctuation
    tmp_dir = args.tmp
    embedding = args.embedding
    char_embedding = args.char_embedding
    embedding_path = args.embedding_dict
    char_path = args.char_dict

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    logger.info("Creating Alphabets: normalize_digits=%s" % normalize_digits)
    word_alphabet, char_alphabet, \
    pos_alphabet, type_alphabet = data_utils.create_alphabets("data/alphabets/", [train_path,],
                                                              60000, min_occurence=1, normalize_digits=normalize_digits)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path,
                                      word_alphabet,
                                      char_alphabet,
                                      pos_alphabet,
                                      type_alphabet,
                                      normalize_digits=normalize_digits)
    data_dev = data_utils.read_data(dev_path,
                                    word_alphabet,
                                    char_alphabet,
                                    pos_alphabet,
                                    type_alphabet,
                                    normalize_digits=normalize_digits)
    data_test = data_utils.read_data(test_path,
                                     word_alphabet,
                                     char_alphabet,
                                     pos_alphabet,
                                     type_alphabet,
                                     normalize_digits=normalize_digits)

    num_data = sum([len(bucket) for bucket in data_train])

    logger.info(
        "constructing network...(pos embedding=%s, character embedding=%s)" %
        (use_pos, use_char))
    # create variables
    head_var = T.imatrix(name='heads')
    type_var = T.imatrix(name='types')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    pos_var = T.imatrix(name='pos-inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var,
                            char_var,
                            pos_var,
                            mask_var,
                            word_alphabet,
                            char_alphabet,
                            pos_alphabet,
                            depth,
                            num_units,
                            num_types,
                            grad_clipping,
                            num_filters,
                            p=dropout,
                            mlp=mlp,
                            peepholes=peepholes,
                            use_char=use_char,
                            use_pos=use_pos,
                            normalize_digits=normalize_digits,
                            embedding=embedding,
                            embedding_path=embedding_path,
                            char_embedding=char_embedding,
                            char_path=char_path)

    logger.info(
        "Network: depth=%d, hidden=%d, peepholes=%s, filter=%d, dropout=%s, #mlp=%d"
        % (depth, num_units, peepholes, num_filters, dropout, mlp))
    # compute loss
    energies_train = lasagne.layers.get_output(network)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train = tree_crf_loss(energies_train, head_var, type_var,
                               mask_var).mean()
    loss_eval = tree_crf_loss(energies_eval, head_var, type_var,
                              mask_var).mean()
    # loss_train, E, D, L, lengths = tree_crf_loss(energies_train, head_var, type_var, mask_var)
    # loss_train = loss_train.mean()
    # loss_eval, _, _, _, _ = tree_crf_loss(energies_eval, head_var, type_var, mask_var)
    # loss_eval = loss_eval.mean()

    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(
            network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    updates = create_updates(loss_train, network, opt, learning_rate, momentum,
                             beta1, beta2)

    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function(
        [word_var, char_var, pos_var, head_var, type_var, mask_var],
        loss_train,
        updates=updates,
        on_unused_input='warn')
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function(
        [word_var, char_var, pos_var, head_var, type_var, mask_var],
        [loss_eval, energies_eval],
        on_unused_input='warn')

    # Finally, launch the training loop.
    logger.info(
        "Start training: (#training data: %d, batch size: %d, clip: %.1f)..." %
        (num_data, batch_size, grad_clipping))

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucorrect_nopunct = 0.0
    dev_lcorrect_nopunct = 0.0
    best_epoch = 0
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucorrect_nopunct = 0.0
    test_lcorrect_nopunct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_inst = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print 'Epoch %d (learning rate=%.5f, decay rate=%.4f, beta1=%.3f, beta2=%.3f): ' % (
            epoch, lr, decay_rate, beta1, beta2)
        train_err = 0.0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in xrange(1, num_batches + 1):
            wids, cids, pids, hids, tids, masks = data_utils.get_batch(
                data_train, batch_size)
            err = train_fn(wids, cids, pids, hids, tids, masks)
            train_err += err * wids.shape[0]
            train_inst += wids.shape[0]
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, time left: %.2fs' % (
                batch, num_batches, train_err / train_inst, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        sys.stdout.write("\b" * num_back)
        print 'train: %d/%d loss: %.4f, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst,
            time.time() - start_time)

        # evaluate performance on dev data
        dev_err = 0.0
        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total = 0
        dev_total_nopunc = 0
        dev_inst = 0
        for batch in data_utils.iterate_batch(data_dev, batch_size):
            wids, cids, pids, hids, tids, masks = batch
            err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
            dev_err += err * wids.shape[0]
            pars_pred, types_pred = parser.decode_MST(energies, masks)
            ucorr, lcorr, total, ucorr_nopunc, \
            lcorr_nopunc, total_nopunc = parser.eval(wids, pids, pars_pred, types_pred, hids, tids, masks,
                                                     tmp_dir + '/dev_parse%d' % epoch, word_alphabet, pos_alphabet,
                                                     type_alphabet, punct_set=punct_set)
            dev_inst += wids.shape[0]

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
        print 'dev loss: %.4f' % (dev_err / dev_inst)
        print 'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
            dev_lcorr * 100 / dev_total)
        print 'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
            dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc)

        if dev_ucorrect_nopunct <= dev_ucorr_nopunc:
            dev_ucorrect_nopunct = dev_ucorr_nopunc
            dev_lcorrect_nopunct = dev_lcorr_nopunc
            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            best_epoch = epoch

            test_err = 0.0
            test_ucorr = 0.0
            test_lcorr = 0.0
            test_ucorr_nopunc = 0.0
            test_lcorr_nopunc = 0.0
            test_total = 0
            test_total_nopunc = 0
            test_inst = 0
            for batch in data_utils.iterate_batch(data_test, batch_size):
                wids, cids, pids, hids, tids, masks = batch
                err, energies = eval_fn(wids, cids, pids, hids, tids, masks)
                test_err += err * wids.shape[0]
                pars_pred, types_pred = parser.decode_MST(energies, masks)
                ucorr, lcorr, total, ucorr_nopunc, \
                lcorr_nopunc, total_nopunc = parser.eval(wids, pids, pars_pred, types_pred, hids, tids, masks,
                                                         tmp_dir + '/test_parse%d' % epoch, word_alphabet, pos_alphabet,
                                                         type_alphabet, punct_set=punct_set)
                test_inst += wids.shape[0]

                test_ucorr += ucorr
                test_lcorr += lcorr
                test_total += total

                test_ucorr_nopunc += ucorr_nopunc
                test_lcorr_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
            test_ucorrect = test_ucorr
            test_lcorrect = test_lcorr
            test_ucorrect_nopunct = test_ucorr_nopunc
            test_lcorrect_nopunct = test_lcorr_nopunc

        print 'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 /
            dev_total, dev_lcorrect * 100 / dev_total, best_epoch)
        print 'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunct, dev_lcorrect_nopunct, dev_total_nopunc,
            dev_ucorrect_nopunct * 100 / dev_total_nopunc,
            dev_lcorrect_nopunct * 100 / dev_total_nopunc, best_epoch)
        print 'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
            test_total, test_lcorrect * 100 / test_total, best_epoch)
        print 'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%% (epoch: %d)' % (
            test_ucorrect_nopunct, test_lcorrect_nopunct, test_total_nopunc,
            test_ucorrect_nopunct * 100 / test_total_nopunc,
            test_lcorrect_nopunct * 100 / test_total_nopunc, best_epoch)

        if epoch in schedule:
            # if epoch % schedule == 0:
            lr = lr * decay_rate
            # lr = learning_rate / (1.0 + epoch * decay_rate)
            updates = create_updates(loss_train, network, opt, lr, momentum,
                                     beta1, beta2)
            train_fn = theano.function(
                [word_var, char_var, pos_var, head_var, type_var, mask_var],
                loss_train,
                updates=updates,
                on_unused_input='warn')
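
Both versions above overwrite the per-batch progress message in place by writing backspace characters equal to the previous message's length before printing the new one. A standalone sketch of that console pattern follows; the batch count and sleep time are dummy values standing in for real training steps.

import sys
import time

# Standalone sketch of the backspace-based progress line used in the training loops above.
num_batches = 5          # dummy value standing in for the real batch count
num_back = 0             # length of the previously printed message
for batch in range(1, num_batches + 1):
    time.sleep(0.1)      # stand-in for one training step
    sys.stdout.write("\b" * num_back)   # move the cursor back over the previous message
    log_info = 'train: %d/%d' % (batch, num_batches)
    sys.stdout.write(log_info)          # overwrite it with the updated message
    sys.stdout.flush()
    num_back = len(log_info)
sys.stdout.write("\n")
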
Example #3
def main():
    parser = argparse.ArgumentParser(description='Tuning with bi-directional LSTM-CNN-CRF')
    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=10, help='Number of sentences in each batch')
    parser.add_argument('--num_units', type=int, default=100, help='Number of hidden units in LSTM')
    parser.add_argument('--num_filters', type=int, default=20, help='Number of filters in CNN')
    parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    parser.add_argument('--grad_clipping', type=float, default=0, help='Gradient clipping')
    parser.add_argument('--gamma', type=float, default=1e-6, help='weight for regularization')
    parser.add_argument('--delta', type=float, default=0.0, help='weight for expectation-linear regularization')
    parser.add_argument('--regular', choices=['none', 'l2'], help='regularization for training', required=True)
    parser.add_argument('--dropout', choices=['std', 'recurrent'], help='dropout pattern')
    parser.add_argument('--schedule', nargs='+', type=int, help='schedule for learning rate decay')
    parser.add_argument('--output_prediction', action='store_true', help='Output predictions to temp files')
    parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("Sequence Labeling")
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    num_filters = args.num_filters
    regular = args.regular
    grad_clipping = args.grad_clipping
    gamma = args.gamma
    delta = args.delta
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    schedule = args.schedule
    output_predict = args.output_prediction
    dropout = args.dropout
    p = 0.5

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = data_utils.create_alphabets("data/alphabets/",
                                                                                            [train_path, dev_path,
                                                                                             test_path],
                                                                                            40000)
    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    num_labels = pos_alphabet.size() - 1

    logger.info("Reading Data")
    data_train = data_utils.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    data_dev = data_utils.read_data(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    data_test = data_utils.read_data(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    num_data = sum([len(bucket) for bucket in data_train])

    logger.info("constructing network...")
    # create variables
    target_var = T.imatrix(name='targets')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    mask_nr_var = T.matrix(name='masks_nr', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')
    char_var = T.itensor3(name='char-inputs')

    network = build_network(word_var, char_var, mask_var, word_alphabet, char_alphabet, dropout, num_units, num_labels,
                            grad_clipping, num_filters, p)

    logger.info("Network structure: hidden=%d, filter=%d, dropout=%s" % (num_units, num_filters, dropout))
    # compute loss
    num_tokens = mask_var.sum(dtype=theano.config.floatX)
    num_tokens_nr = mask_nr_var.sum(dtype=theano.config.floatX)

    # get output of bi-lstm-cnn-crf shape [batch, length, num_labels, num_labels]
    energies_train = lasagne.layers.get_output(network)
    energies_train_det = lasagne.layers.get_output(network, deterministic=True)
    energies_eval = lasagne.layers.get_output(network, deterministic=True)

    loss_train_org = chain_crf_loss(energies_train, target_var, mask_var).mean()

    energy_shape = energies_train.shape
    # [batch, length, num_labels, num_labels] --> [batch*length, num_labels*num_labels]
    energies = T.reshape(energies_train, (energy_shape[0] * energy_shape[1], energy_shape[2] * energy_shape[3]))
    energies = nonlinearities.softmax(energies)
    energies_det = T.reshape(energies_train_det, (energy_shape[0] * energy_shape[1], energy_shape[2] * energy_shape[3]))
    energies_det = nonlinearities.softmax(energies_det)
    # [batch*length, num_labels*num_labels] --> [batch, length*num_labels*num_labels]
    energies = T.reshape(energies, (energy_shape[0], energy_shape[1] * energy_shape[2] * energy_shape[3]))
    energies_det = T.reshape(energies_det, (energy_shape[0], energy_shape[1] * energy_shape[2] * energy_shape[3]))

    loss_train_expect_linear = lasagne.objectives.squared_error(energies, energies_det)
    loss_train_expect_linear = loss_train_expect_linear.sum(axis=1)
    loss_train_expect_linear = loss_train_expect_linear.mean()

    loss_train = loss_train_org + delta * loss_train_expect_linear
    # l2 regularization?
    if regular == 'l2':
        l2_penalty = lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
        loss_train = loss_train + gamma * l2_penalty

    _, corr_train = chain_crf_accuracy(energies_train, target_var)
    corr_nr_train = (corr_train * mask_nr_var).sum(dtype=theano.config.floatX)
    corr_train = (corr_train * mask_var).sum(dtype=theano.config.floatX)
    prediction_eval, corr_eval = chain_crf_accuracy(energies_eval, target_var)
    corr_nr_eval = (corr_eval * mask_nr_var).sum(dtype=theano.config.floatX)
    corr_eval = (corr_eval * mask_var).sum(dtype=theano.config.floatX)

    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = nesterov_momentum(loss_train, params=params, learning_rate=learning_rate, momentum=momentum)

    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function([word_var, char_var, target_var, mask_var, mask_nr_var],
                               [loss_train, loss_train_org, loss_train_expect_linear,
                                corr_train, corr_nr_train, num_tokens, num_tokens_nr], updates=updates)
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function([word_var, char_var, target_var, mask_var, mask_nr_var],
                              [corr_eval, corr_nr_eval, num_tokens, num_tokens_nr, prediction_eval])

    # Finally, launch the training loop.
    logger.info(
        "Start training: regularization: %s(%f), dropout: %s, delta: %.2f (#training data: %d, batch size: %d, clip: %.1f)..." \
        % (regular, (0.0 if regular == 'none' else gamma), dropout, delta, num_data, batch_size, grad_clipping))

    num_batches = num_data / batch_size + 1
    dev_correct = 0.0
    dev_correct_nr = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_correct_nr = 0.0
    test_total = 0
    test_total_nr = 0
    test_inst = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print 'Epoch %d (learning rate=%.4f, decay rate=%.4f): ' % (epoch, lr, decay_rate)
        train_err = 0.0
        train_err_org = 0.0
        train_err_linear = 0.0
        train_corr = 0.0
        train_corr_nr = 0.0
        train_total = 0
        train_total_nr = 0
        train_inst = 0
        start_time = time.time()
        num_back = 0
        for batch in xrange(1, num_batches + 1):
            wids, cids, pids, _, _, masks = data_utils.get_batch(data_train, batch_size)
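            # copy the mask and zero the first position of each sentence so the
            # "no root" accuracy and token counts exclude that (root) token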
            masks_nr = np.copy(masks)
            masks_nr[:, 0] = 0
            err, err_org, err_linear, corr, corr_nr, num, num_nr = train_fn(wids, cids, pids, masks, masks_nr)
            train_err += err * wids.shape[0]
            train_err_org += err_org * wids.shape[0]
            train_err_linear += err_linear * wids.shape[0]
            train_corr += corr
            train_corr_nr += corr_nr
            train_total += num
            train_total_nr += num_nr
            train_inst += wids.shape[0]
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, loss_org: %.4f, loss_linear: %.4f, acc: %.2f%%, acc(no root): %.2f%%, time left (estimated): %.2fs' % (
                batch, num_batches, train_err / train_inst, train_err_org / train_inst, train_err_linear / train_inst,
                train_corr * 100 / train_total, train_corr_nr * 100 / train_total_nr, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_inst == num_batches * batch_size
        assert train_total == train_total_nr + train_inst
        sys.stdout.write("\b" * num_back)
        print 'train: %d/%d loss: %.4f, loss_org: %.4f, loss_linear: %.4f, acc: %.2f%%, acc(no root): %.2f%%, time: %.2fs' % (
            train_inst, train_inst, train_err / train_inst, train_err_org / train_inst, train_err_linear / train_inst,
            train_corr * 100 / train_total, train_corr_nr * 100 / train_total_nr, time.time() - start_time)

        # evaluate performance on dev data
        dev_corr = 0.0
        dev_corr_nr = 0.0
        dev_total = 0
        dev_total_nr = 0
        dev_inst = 0
        for batch in data_utils.iterate_batch(data_dev, batch_size):
            wids, cids, pids, _, _, masks = batch
            masks_nr = np.copy(masks)
            masks_nr[:, 0] = 0
            corr, corr_nr, num, num_nr, predictions = eval_fn(wids, cids, pids, masks, masks_nr)
            dev_corr += corr
            dev_corr_nr += corr_nr
            dev_total += num
            dev_total_nr += num_nr
            dev_inst += wids.shape[0]
        assert dev_total == dev_total_nr + dev_inst
        print 'dev corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%%' % (
            dev_corr, dev_total, dev_corr * 100 / dev_total, dev_corr_nr, dev_total_nr, dev_corr_nr * 100 / dev_total_nr)

        if dev_correct_nr < dev_corr_nr:
            dev_correct = dev_corr
            dev_correct_nr = dev_corr_nr
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_corr = 0.0
            test_corr_nr = 0.0
            test_total = 0
            test_total_nr = 0
            test_inst = 0
            for batch in data_utils.iterate_batch(data_test, batch_size):
                wids, cids, pids, _, _, masks = batch
                masks_nr = np.copy(masks)
                masks_nr[:, 0] = 0
                corr, corr_nr, num, num_nr, predictions = eval_fn(wids, cids, pids, masks, masks_nr)
                test_corr += corr
                test_corr_nr += corr_nr
                test_total += num
                test_total_nr += num_nr
                test_inst += wids.shape[0]
            assert test_total == test_total_nr + test_inst
            test_correct = test_corr
            test_correct_nr = test_corr_nr
        print "best dev  corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (
            dev_correct, dev_total, dev_correct * 100 / dev_total,
            dev_correct_nr, dev_total_nr, dev_correct_nr * 100 / dev_total_nr, best_epoch)
        print "best test corr: %d, total: %d, acc: %.2f%%, no root corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (
            test_correct, test_total, test_correct * 100 / test_total,
            test_correct_nr, test_total_nr, test_correct_nr * 100 / test_total_nr, best_epoch)

        if epoch in schedule:
            lr = lr * decay_rate
            updates = nesterov_momentum(loss_train, params=params, learning_rate=lr, momentum=momentum)
            train_fn = theano.function([word_var, char_var, target_var, mask_var, mask_nr_var],
                                       [loss_train, loss_train_org, loss_train_expect_linear,
                                        corr_train, corr_nr_train, num_tokens, num_tokens_nr], updates=updates)
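
Example #3's expectation-linear term penalizes the squared gap between the dropout (stochastic) softmax output and the deterministic softmax output, summed per sentence, averaged over the batch, and scaled by --delta. The NumPy sketch below reproduces just that penalty; the random arrays are placeholders standing in for the two flattened softmax outputs, not real model outputs.

import numpy as np

# NumPy sketch of the expectation-linear penalty from Example #3.
# p_stochastic / p_deterministic stand in for the flattened softmax outputs
# of shape [batch, length * num_labels * num_labels]; values here are random placeholders.
rng = np.random.RandomState(0)
batch, flat_dim = 4, 12
p_stochastic = rng.dirichlet(np.ones(flat_dim), size=batch)
p_deterministic = rng.dirichlet(np.ones(flat_dim), size=batch)

delta = 0.05  # weight of the penalty (the --delta flag); placeholder value
penalty = np.square(p_stochastic - p_deterministic).sum(axis=1).mean()
loss_term = delta * penalty  # added to the main CRF loss in the original code
print(loss_term)
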
Example #4
def main():
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional MAXRU-CNN')
    parser.add_argument('--architec',
                        choices=['sgru', 'lstm', 'gru0', 'gru1'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=1000,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--num_units',
                        type=int,
                        default=100,
                        help='Number of hidden units in TARU')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.1,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--grad_clipping',
                        type=float,
                        default=0,
                        help='Gradient clipping')
    parser.add_argument('--schedule',
                        nargs='+',
                        type=int,
                        help='schedule for learning rate decay')
    args = parser.parse_args()

    architec = args.architec
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_units = args.num_units
    learning_rate = args.learning_rate
    decay_rate = args.decay_rate
    schedule = args.schedule
    grad_clipping = args.grad_clipping
    logger = get_logger("Sentiment Classification (%s)" % (architec))

    def read_dataset(filename):
        data = [[] for _ in _buckets]
        print 'Reading data from %s' % filename
        counter = 0
        with open(filename, "r") as f:
            for line in f:
                counter += 1
                tag, words = line.lower().strip().split(" ||| ")
                words = words.split(" ")
                wids = [w2i[x] for x in words]
                tag = t2i[tag]
                length = len(words)
                for bucket_id, bucket_size in enumerate(_buckets):
                    if length < bucket_size:
                        data[bucket_id].append([words, wids, tag])
                        break

        print "Total number of data: %d" % counter
        return data

    def generate_random_embedding(scale, shape):
        return np.random.uniform(-scale, scale,
                                 shape).astype(theano.config.floatX)

    def construct_word_input_layer():
        # shape = [batch, n-step]
        layer_word_input = lasagne.layers.InputLayer(shape=(None, None),
                                                     input_var=word_var,
                                                     name='word_input')
        # shape = [batch, n-step, w_dim]
        layer_word_embedding = lasagne.layers.EmbeddingLayer(
            layer_word_input,
            input_size=vocab_size,
            output_size=WORD_DIM,
            W=word_table,
            name='word_embedd')
        return layer_word_embedding

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / WORD_DIM)
        table = np.empty([vocab_size, WORD_DIM], dtype=theano.config.floatX)
        table[UNK, :] = generate_random_embedding(scale, [1, WORD_DIM])
        for word, index in w2i.iteritems():
            if index == 0:
                continue
            ww = word.lower() if caseless else word
            embedding = embedd_dict[
                ww] if ww in embedd_dict else generate_random_embedding(
                    scale, [1, WORD_DIM])
            table[index, :] = embedding
        return table

    # Functions to read in the corpus
    w2i = defaultdict(lambda: len(w2i))
    t2i = defaultdict(lambda: len(t2i))
    UNK = w2i["<unk>"]

    data_train = read_dataset('data/sst1/train.txt')
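    # freeze the vocabulary built on the training set: unseen dev/test words fall back to <unk>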
    w2i = defaultdict(lambda: UNK, w2i)
    data_dev = read_dataset('data/sst1/dev.txt')
    data_test = read_dataset('data/sst1/test.txt')
    vocab_size = len(w2i)
    num_labels = len(t2i)

    embedd_dict, embedd_dim, caseless = utils.load_word_embedding_dict(
        'glove', "data/glove/glove.6B/glove.6B.100d.gz")
    assert embedd_dim == WORD_DIM

    num_data_train = sum([len(bucket) for bucket in data_train])
    num_data_dev = sum([len(bucket) for bucket in data_dev])
    num_data_test = sum([len(bucket) for bucket in data_test])

    logger.info("constructing network...")
    # create variables
    target_var = T.ivector(name='targets')
    mask_var = T.matrix(name='masks', dtype=theano.config.floatX)
    word_var = T.imatrix(name='inputs')

    word_table = construct_word_embedding_table()
    layer_word_input = construct_word_input_layer()
    layer_mask = lasagne.layers.InputLayer(shape=(None, None),
                                           input_var=mask_var,
                                           name='mask')

    layer_input = layer_word_input

    layer_input = lasagne.layers.DropoutLayer(layer_input, p=0.2)

    layer_rnn = build_RNN(architec, layer_input, layer_mask, num_units,
                          grad_clipping)
    layer_rnn = lasagne.layers.DropoutLayer(layer_rnn, p=0.5)

    network = lasagne.layers.DenseLayer(layer_rnn,
                                        num_units=num_labels,
                                        nonlinearity=nonlinearities.softmax,
                                        name='softmax')

    # get output of bi-taru-cnn shape=[batch * max_length, #label]
    prediction_train = lasagne.layers.get_output(network)
    prediction_eval = lasagne.layers.get_output(network, deterministic=True)
    final_prediction = T.argmax(prediction_eval, axis=1)

    loss_train = lasagne.objectives.categorical_crossentropy(
        prediction_train, target_var).mean()
    loss_eval = lasagne.objectives.categorical_crossentropy(
        prediction_eval, target_var).mean()

    corr_train = lasagne.objectives.categorical_accuracy(
        prediction_train, target_var).sum()
    corr_eval = lasagne.objectives.categorical_accuracy(
        prediction_eval, target_var).sum()

    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = adam(loss_train,
                   params=params,
                   learning_rate=learning_rate,
                   beta1=0.9,
                   beta2=0.9)

    # Compile a function performing a training step on a mini-batch
    train_fn = theano.function([word_var, target_var, mask_var],
                               [loss_train, corr_train],
                               updates=updates)
    # Compile a second function evaluating the loss and accuracy of network
    eval_fn = theano.function([word_var, target_var, mask_var],
                              [corr_eval, final_prediction])

    # Finally, launch the training loop.
    logger.info("%s: (#data: %d, batch size: %d, clip: %.1f)" %
                (architec, num_data_train, batch_size, grad_clipping))

    num_batches = num_data_train / batch_size + 1
    dev_correct = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_total = 0
    lr = learning_rate
    for epoch in range(1, num_epochs + 1):
        print 'Epoch %d (%s, learning rate=%.4f, decay rate=%.4f): ' % (
            epoch, architec, lr, decay_rate)
        train_err = 0.0
        train_corr = 0.0
        train_total = 0
        start_time = time.time()
        num_back = 0
        for batch in xrange(1, num_batches + 1):
            wids, tids, masks = get_batch(data_train, batch_size)
            num = wids.shape[0]
            err, corr = train_fn(wids, tids, masks)
            train_err += err * num
            train_corr += corr
            train_total += num
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            sys.stdout.write("\b" * num_back)
            log_info = 'train: %d/%d loss: %.4f, acc: %.2f%%, time left (estimated): %.2fs' % (
                batch, num_batches, train_err / train_total,
                train_corr * 100 / train_total, time_left)
            sys.stdout.write(log_info)
            num_back = len(log_info)
        # update training log after each epoch
        assert train_total == num_batches * batch_size
        sys.stdout.write("\b" * num_back)
        print 'train: %d loss: %.4f, acc: %.2f%%, time: %.2fs' % (
            train_total, train_err / train_total,
            train_corr * 100 / train_total, time.time() - start_time)

        # evaluate performance on dev data
        dev_corr = 0.0
        dev_total = 0
        for batch in iterate_batch(data_dev, batch_size):
            wids, tids, masks = batch
            num = wids.shape[0]
            corr, predictions = eval_fn(wids, tids, masks)
            dev_corr += corr
            dev_total += num

        assert dev_total == num_data_dev
        print 'dev corr: %d, total: %d, acc: %.2f%%' % (
            dev_corr, dev_total, dev_corr * 100 / dev_total)

        if dev_correct <= dev_corr:
            dev_correct = dev_corr
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_corr = 0.0
            test_total = 0
            for batch in iterate_batch(data_test, batch_size):
                wids, tids, masks = batch
                num = wids.shape[0]
                corr, predictions = eval_fn(wids, tids, masks)
                test_corr += corr
                test_total += num

            assert test_total == num_data_test
            test_correct = test_corr
        print "best dev  corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (
            dev_correct, dev_total, dev_correct * 100 / dev_total, best_epoch)
        print "best test corr: %d, total: %d, acc: %.2f%%(epoch: %d)" % (
            test_correct, test_total, test_correct * 100 / test_total,
            best_epoch)

        if epoch in schedule:
            lr = lr * decay_rate
            updates = adam(loss_train,
                           params=params,
                           learning_rate=lr,
                           beta1=0.9,
                           beta2=0.9)
            train_fn = theano.function([word_var, target_var, mask_var],
                                       [loss_train, corr_train],
                                       updates=updates)
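
construct_word_embedding_table in Example #4 fills rows of the embedding matrix from the pretrained dictionary and draws missing words uniformly from [-sqrt(3/WORD_DIM), sqrt(3/WORD_DIM)]. Below is a simplified NumPy sketch of that initialization with a toy vocabulary and a made-up pretrained dictionary in place of the GloVe file; it omits the caseless lowercasing and the skipped index 0 handled in the original.

import numpy as np

# Sketch of the embedding-table initialization from Example #4, with toy data
# standing in for the GloVe dictionary and the real vocabulary.
WORD_DIM = 4
w2i = {"<unk>": 0, "the": 1, "movie": 2, "zzzunseen": 3}          # toy vocabulary
embedd_dict = {"the": np.ones(WORD_DIM), "movie": np.full(WORD_DIM, 0.5)}  # fake pretrained vectors

scale = np.sqrt(3.0 / WORD_DIM)
table = np.empty((len(w2i), WORD_DIM), dtype=np.float32)
for word, index in w2i.items():
    if word in embedd_dict:
        table[index] = embedd_dict[word]                           # pretrained vector
    else:
        table[index] = np.random.uniform(-scale, scale, WORD_DIM)  # random init for OOV / <unk>
print(table.shape)
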