def compute_loss(lang_name, land_idx):
            word, char, pos, _, _, masks, lengths, bert_inputs = conllx_data.get_batch_variable(train_data[lang_name],
                                                                                                batch_size,
                                                                                                unk_replace=0.5)

            if use_gpu:
                word = word.cuda()
                char = char.cuda()
                pos = pos.cuda()
                masks = masks.cuda()
                lengths = lengths.cuda()
                if bert_inputs[0] is not None:
                    bert_inputs[0] = bert_inputs[0].cuda()
                    bert_inputs[1] = bert_inputs[1].cuda()
                    bert_inputs[2] = bert_inputs[2].cuda()

            output = network.forward(word, char, pos, input_bert=bert_inputs,
                                     mask=masks, length=lengths, hx=None)
            output = output['output'].detach()

            if args.train_level == 'word':
                output = classifier(output)
                output = output.contiguous().view(-1, output.size(2))
            else:
                output = torch.mean(output, dim=1)
                output = classifier(output)

            labels = torch.empty(output.size(0)).fill_(land_idx).type_as(output).long()
            loss = criterion(output, labels)
            return loss
Beispiel #2
0
def main():
    args_parser = argparse.ArgumentParser(description='Tuning with graph-based parsing')
    args_parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn', required=True)
    args_parser.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--type_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    args_parser.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
    args_parser.add_argument('--pos', action='store_true', help='use part-of-speech embedding.')
    args_parser.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
    args_parser.add_argument('--objective', choices=['cross_entropy', 'crf'], default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode', choices=['mst', 'greedy'], help='decoding algorithm', required=True)
    args_parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    args_parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    args_parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0.,
                             help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--word_embedding', choices=['glove', 'senna', 'sskip', 'polyglot'],
                             help='Embedding for words', required=True)
    args_parser.add_argument('--word_path', help='path for word embedding dict')
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path', help='path for saving model file.', required=True)

    args = args_parser.parse_args()

    print("*** Model UID: %s ***" % uid)

    logger = get_logger("GraphParser")

    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    betas = (0.9, 0.9)
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(alphabet_path, train_path, data_paths=[dev_path, test_path],
                                                                                             max_vocabulary_size=50000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_data.read_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, symbolic_root=True)
    # data_train = conllx_data.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    # num_data = sum([len(bucket) for bucket in data_train])
    num_data = sum(data_train[1])

    data_dev = conllx_data.read_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)
    data_test = conllx_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim, num_words,
                                          char_dim, num_chars,
                                          pos_dim, num_pos,
                                          num_filters, window,
                                          mode, hidden_size, num_layers,
                                          num_types, arc_space, type_space,
                                          embedd_word=word_table, embedd_char=char_table,
                                          p_in=p_in, p_out=p_out, p_rnn=p_rnn, biaffine=True, pos=use_pos)
    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    if use_gpu:
        network.cuda()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    adam_epochs = 50
    adam_rate = 0.001
    if adam_epochs > 0:
        lr = adam_rate
        opt = 'adam'
        optim = Adam(network.parameters(), lr=adam_rate, betas=betas, weight_decay=gamma)
    else:
        opt = 'sgd'
        lr = learning_rate
        optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)

    logger.info("Embedding dim: word=%d, char=%d, pos=%d (%s)" % (word_dim, char_dim, pos_dim, use_pos))
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, arc_space=%d, type_space=%d" % (
        mode, num_layers, hidden_size, num_filters, arc_space, type_space))
    logger.info("train: obj: %s, l2: %f, (#data: %d, batch: %d, dropout(in, out, rnn): (%.2f, %.2f, %s), unk replace: %.2f)" % (
        obj, gamma, num_data, batch_size, p_in, p_out, p_rnn, unk_replace))
    logger.info("decoding algorithm: %s" % decoding)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s, optim: %s, learning rate=%.4f, decay rate=%.4f (schedule=%d)): ' % (
            epoch, mode, opt, lr, decay_rate, schedule))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(data_train, batch_size,
                                                                                           unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word, char, pos, heads, types, mask=masks, length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            optim.step()

            num_inst = word.size(0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total,
                    train_err_arc / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' % (
            num_batches, train_err / train_total, train_err_arc / train_total, train_err_type / train_total,
            time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        print('[%s] Epoch %d complete' % (time.strftime("%Y-%m-%d %H:%M:%S"), epoch))

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
            gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
            dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %(
            dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

        if dev_ucorrect_nopunc <= dev_ucorr_nopunc:
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_data.iterate_batch_variable(data_test, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()

        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch))
        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
            test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst,
            best_epoch))
        print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst,
            best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
        print('============================================================================================================================')

        if epoch % schedule == 0:
            # lr = lr * decay_rate
            if epoch < adam_epochs:
                opt = 'adam'
                lr = adam_rate / (1.0 + epoch * decay_rate)
                optim = Adam(network.parameters(), lr=lr, betas=betas, weight_decay=gamma)
            else:
                opt = 'sgd'
                lr = learning_rate / (1.0 + (epoch - adam_epochs) * decay_rate)
                optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
def main():
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=1000,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("POSCRFTagger")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p = args.p
    unk_replace = args.unk_replace
    bigram = args.bigram

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        'glove', "data/glove/glove.6B/glove.6B.100d.gz")
    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    type_alphabet = conllx_data.create_alphabets("data/alphabets/pos_crf/", train_path,
                                                 data_paths=[dev_path, test_path],
                                                 max_vocabulary_size=50000, embedd_dict=embedd_dict)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_data.read_data_to_variable(train_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   type_alphabet,
                                                   use_gpu=use_gpu)
    # data_train = conllx_data.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    # num_data = sum([len(bucket) for bucket in data_train])
    num_data = sum(data_train[1])
    num_labels = pos_alphabet.size()

    data_dev = conllx_data.read_data_to_variable(dev_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 pos_alphabet,
                                                 type_alphabet,
                                                 use_gpu=use_gpu,
                                                 volatile=True)
    data_test = conllx_data.read_data_to_variable(test_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  type_alphabet,
                                                  use_gpu=use_gpu,
                                                  volatile=True)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = 1
    if args.dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     num_filters,
                                     window,
                                     mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     embedd_word=word_table,
                                     p_rnn=p,
                                     bigram=bigram)
    else:
        raise NotImplementedError

    if use_gpu:
        network.cuda()

    lr = learning_rate
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma)
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, crf=%s" %
                (mode, num_layers, hidden_size, num_filters,
                 'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, dropout: %.2f, unk replace: %.2f)"
        % (gamma, num_data, batch_size, p, unk_replace))

    num_batches = num_data / batch_size + 1
    dev_correct = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_total = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, labels, _, _, masks, lengths = conllx_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        dev_corr = 0.0
        dev_total = 0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, labels, _, _, masks, lengths = batch
            preds, corr = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            num_tokens = masks.data.sum()
            dev_corr += corr
            dev_total += num_tokens
        print('dev corr: %d, total: %d, acc: %.2f%%' %
              (dev_corr, dev_total, dev_corr * 100 / dev_total))

        if dev_correct < dev_corr:
            dev_correct = dev_corr
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_corr = 0.0
            test_total = 0
            for batch in conllx_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, labels, _, _, masks, lengths = batch
                preds, corr = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                num_tokens = masks.data.sum()
                test_corr += corr
                test_total += num_tokens
            test_correct = test_corr
        print("best dev  corr: %d, total: %d, acc: %.2f%% (epoch: %d)" %
              (dev_correct, dev_total, dev_correct * 100 / dev_total,
               best_epoch))
        print("best test corr: %d, total: %d, acc: %.2f%% (epoch: %d)" %
              (test_correct, test_total, test_correct * 100 / test_total,
               best_epoch))

        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.register('type', 'bool', str2bool)

    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=1000,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder.')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             default='mst',
                             help='decoding algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    # args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument('--data_dir', help='Data directory path')
    args_parser.add_argument(
        '--src_lang',
        required=True,
        help='Src language to train dependency parsing model')
    args_parser.add_argument('--aux_lang',
                             nargs='+',
                             help='Language names for adversarial training')
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    #
    args_parser.add_argument('--attn_on_rnn',
                             action='store_true',
                             help='use self-attention on top of context RNN.')
    args_parser.add_argument('--no_word',
                             type='bool',
                             default=False,
                             help='do not use word embedding.')
    args_parser.add_argument('--use_bert',
                             type='bool',
                             default=False,
                             help='use multilingual BERT.')
    #
    # lrate schedule with warmup in the first iter.
    args_parser.add_argument('--use_warmup_schedule',
                             type='bool',
                             default=False,
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    # encoder selection
    args_parser.add_argument('--encoder_type',
                             choices=['Transformer', 'RNN', 'SelfAttn'],
                             default='RNN',
                             help='do not use context RNN.')
    args_parser.add_argument(
        '--pool_type',
        default='mean',
        choices=['max', 'mean', 'weight'],
        help='pool type to form fixed length vector from word embeddings')
    # Tansformer encoder
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    args_parser.add_argument(
        '--use_all_encoder_layers',
        type='bool',
        default=False,
        help='Use a weighted representations of all encoder layers')
    # - positional
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=50,
                             help='Dimension of Position embeddings.')
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')

    args_parser.add_argument('--input_concat_embeds',
                             action='store_true',
                             help="Concat input embeddings, otherwise add.")
    args_parser.add_argument('--input_concat_position',
                             action='store_true',
                             help="Concat position embeddings, otherwise add.")
    args_parser.add_argument(
        '--partitioned',
        type='bool',
        default=False,
        help=
        "Partition the content and positional attention for multi-head attention."
    )
    args_parser.add_argument(
        '--partition_type',
        choices=['content-position', 'lexical-delexical'],
        default='content-position',
        help="How to apply partition in the self-attention.")
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    #
    # regarding adversarial training
    args_parser.add_argument('--pre_model_path',
                             type=str,
                             default=None,
                             help='Path of the pretrained model.')
    args_parser.add_argument('--pre_model_name',
                             type=str,
                             default=None,
                             help='Name of the pretrained model.')
    args_parser.add_argument('--adv_training',
                             type='bool',
                             default=False,
                             help='Use adversarial training.')
    args_parser.add_argument(
        '--lambdaG',
        type=float,
        default=0.001,
        help='Scaling parameter to control generator loss.')
    args_parser.add_argument('--discriminator',
                             choices=['weak', 'not-so-weak', 'strong'],
                             default='weak',
                             help='architecture of the discriminator')
    args_parser.add_argument(
        '--delay',
        type=int,
        default=0,
        help='Number of epochs to be run first for the source task')
    args_parser.add_argument(
        '--n_critic',
        type=int,
        default=5,
        help='Number of training steps for discriminator per iter')
    args_parser.add_argument(
        '--clip_disc',
        type=float,
        default=5.0,
        help='Lower and upper clip value for disc. weights')
    args_parser.add_argument('--debug',
                             type='bool',
                             default=False,
                             help='Use debug portion of the training data')
    args_parser.add_argument('--train_level',
                             type=str,
                             default='word',
                             choices=['word', 'sent'],
                             help='Use X-level adversarial training')
    args_parser.add_argument('--train_type',
                             type=str,
                             default='GAN',
                             choices=['GR', 'GAN', 'WGAN'],
                             help='Type of adversarial training')
    #
    # regarding motivational training
    args_parser.add_argument(
        '--motivate',
        type='bool',
        default=False,
        help='This is opposite of the adversarial training')

    #
    args = args_parser.parse_args()

    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    # if output directory doesn't exist, create it
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    logger = get_logger("GraphParser")

    logger.info('\ncommand-line params : {0}\n'.format(sys.argv[1:]))
    logger.info('{0}\n'.format(args))

    logger.info("Visible GPUs: %s", str(os.environ["CUDA_VISIBLE_DEVICES"]))
    args.parallel = False
    if torch.cuda.device_count() > 1:
        args.parallel = True

    mode = args.mode
    obj = args.objective
    decoding = args.decode

    train_path = args.data_dir + args.src_lang + "_train.debug.1_10.conllu" \
        if args.debug else args.data_dir + args.src_lang + '_train.conllu'
    dev_path = args.data_dir + args.src_lang + "_dev.conllu"
    test_path = args.data_dir + args.src_lang + "_test.conllu"

    #
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path
    model_path = args.model_path
    model_name = args.model_name

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    use_word_emb = not args.no_word
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    attn_on_rnn = args.attn_on_rnn
    encoder_type = args.encoder_type
    if attn_on_rnn:
        assert encoder_type == 'RNN'

    t_types = (args.adv_training, args.motivate)
    t_count = sum(1 for tt in t_types if tt)
    if t_count > 1:
        assert False, "Only one of: adv_training or motivate can be true"

    # ------------------- Loading/initializing embeddings -------------------- #

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)

    # TODO (WARNING): must build vocabs previously
    assert os.path.isdir(alphabet_path), "should have build vocabs previously"
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)
    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    # ------------------------------------------------------------------------- #
    # --------------------- Loading/building the model ------------------------ #

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table() if use_word_emb else None
    char_table = construct_char_embedding_table() if use_char else None

    def load_model_arguments_from_json():
        arguments = json.load(open(pre_model_path, 'r'))
        return arguments['args'], arguments['kwargs']

    window = 3
    if obj == 'cross_entropy':
        if args.pre_model_path and args.pre_model_name:
            pre_model_name = os.path.join(args.pre_model_path,
                                          args.pre_model_name)
            pre_model_path = pre_model_name + '.arg.json'
            model_args, kwargs = load_model_arguments_from_json()

            network = BiRecurrentConvBiAffine(use_gpu=use_gpu,
                                              *model_args,
                                              **kwargs)
            network.load_state_dict(torch.load(pre_model_name))
            logger.info("Model reloaded from %s" % pre_model_path)

            # Adjust the word embedding layer
            if network.embedder.word_embedd is not None:
                network.embedder.word_embedd = nn.Embedding(num_words,
                                                            word_dim,
                                                            _weight=word_table)

        else:
            network = BiRecurrentConvBiAffine(
                word_dim,
                num_words,
                char_dim,
                num_chars,
                pos_dim,
                num_pos,
                num_filters,
                window,
                mode,
                hidden_size,
                num_layers,
                num_types,
                arc_space,
                type_space,
                embedd_word=word_table,
                embedd_char=char_table,
                p_in=p_in,
                p_out=p_out,
                p_rnn=p_rnn,
                biaffine=True,
                pos=use_pos,
                char=use_char,
                train_position=args.train_position,
                encoder_type=encoder_type,
                trans_hid_size=args.trans_hid_size,
                d_k=args.d_k,
                d_v=args.d_v,
                num_head=args.num_head,
                enc_use_neg_dist=args.enc_use_neg_dist,
                enc_clip_dist=args.enc_clip_dist,
                position_dim=args.position_dim,
                max_sent_length=max_sent_length,
                use_gpu=use_gpu,
                use_word_emb=use_word_emb,
                input_concat_embeds=args.input_concat_embeds,
                input_concat_position=args.input_concat_position,
                attn_on_rnn=attn_on_rnn,
                partitioned=args.partitioned,
                partition_type=args.partition_type,
                use_all_encoder_layers=args.use_all_encoder_layers,
                use_bert=args.use_bert)

    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    # ------------------------------------------------------------------------- #
    # --------------------- Loading data -------------------------------------- #

    train_data = dict()
    dev_data = dict()
    test_data = dict()
    num_data = dict()
    lang_ids = dict()
    reverse_lang_ids = dict()

    # ===== the reading =============================================
    def _read_one(path, is_train):
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, lang_id))
        one_data = conllx_data.read_data_to_variable(
            path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=False,
            volatile=(not is_train),
            symbolic_root=True,
            lang_id=lang_id,
            use_bert=args.use_bert,
            len_thresh=(args.train_len_thresh if is_train else 100000))
        return one_data

    data_train = _read_one(train_path, True)
    train_data[args.src_lang] = data_train
    num_data[args.src_lang] = sum(data_train[1])
    lang_ids[args.src_lang] = len(lang_ids)
    reverse_lang_ids[lang_ids[args.src_lang]] = args.src_lang

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    dev_data[args.src_lang] = data_dev
    test_data[args.src_lang] = data_test

    # ===============================================================

    # ===== reading data for adversarial training ===================
    if t_count > 0:
        for language in args.aux_lang:
            aux_train_path = args.data_dir + language + "_train.debug.1_10.conllu" \
                if args.debug else args.data_dir + language + '_train.conllu'
            aux_train_data = _read_one(aux_train_path, True)
            num_data[language] = sum(aux_train_data[1])
            train_data[language] = aux_train_data
            lang_ids[language] = len(lang_ids)
            reverse_lang_ids[lang_ids[language]] = language
    # ===============================================================

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, num_layers, num_types,
            arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'train_position': args.train_position,
            'encoder_type': args.encoder_type,
            'trans_hid_size': args.trans_hid_size,
            'd_k': args.d_k,
            'd_v': args.d_v,
            'num_head': args.num_head,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'enc_clip_dist': args.enc_clip_dist,
            'position_dim': args.position_dim,
            'max_sent_length': max_sent_length,
            'use_word_emb': use_word_emb,
            'input_concat_embeds': args.input_concat_embeds,
            'input_concat_position': args.input_concat_position,
            'attn_on_rnn': attn_on_rnn,
            'partitioned': args.partitioned,
            'partition_type': args.partition_type,
            'use_all_encoder_layers': args.use_all_encoder_layers,
            'use_bert': args.use_bert
        }
        json.dump({
            'args': arguments,
            'kwargs': kwargs
        },
                  open(arg_path, 'w'),
                  indent=4)

    if use_word_emb and freeze:
        freeze_embedding(network.embedder.word_embedd)

    if args.parallel:
        network = torch.nn.DataParallel(network)

    if use_gpu:
        network = network.cuda()

    save_args()

    param_dict = {}
    encoder = network.module.encoder if args.parallel else network.encoder
    for name, param in encoder.named_parameters():
        if param.requires_grad:
            param_dict[name] = np.prod(param.size())

    total_params = np.sum(list(param_dict.values()))
    logger.info('Total Encoder Parameters = %d' % total_params)

    # ------------------------------------------------------------------------- #

    # =============================================
    if args.adv_training:
        disc_feat_size = network.module.encoder.output_dim if args.parallel else network.encoder.output_dim
        reverse_grad = args.train_type == 'GR'
        nclass = len(lang_ids) if args.train_type == 'GR' else 1

        kwargs = {
            'input_size': disc_feat_size,
            'disc_type': args.discriminator,
            'train_level': args.train_level,
            'train_type': args.train_type,
            'reverse_grad': reverse_grad,
            'soft_label': True,
            'nclass': nclass,
            'scale': args.lambdaG,
            'use_gpu': use_gpu,
            'opt': 'adam',
            'lr': 0.001,
            'betas': (0.9, 0.999),
            'gamma': 0,
            'eps': 1e-8,
            'momentum': 0,
            'clip_disc': args.clip_disc
        }
        AdvAgent = Adversarial(**kwargs)
        if use_gpu:
            AdvAgent.cuda()

    elif args.motivate:
        disc_feat_size = network.module.encoder.output_dim if args.parallel else network.encoder.output_dim
        nclass = len(lang_ids)

        kwargs = {
            'input_size': disc_feat_size,
            'disc_type': args.discriminator,
            'train_level': args.train_level,
            'nclass': nclass,
            'scale': args.lambdaG,
            'use_gpu': use_gpu,
            'opt': 'adam',
            'lr': 0.001,
            'betas': (0.9, 0.999),
            'gamma': 0,
            'eps': 1e-8,
            'momentum': 0,
            'clip_disc': args.clip_disc
        }
        MtvAgent = Motivator(**kwargs)
        if use_gpu:
            MtvAgent.cuda()

    # =============================================

    # --------------------- Initializing the optimizer ------------------------ #

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters(), betas, gamma,
                               eps, momentum)
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    # =============================================

    total_data = min(num_data.values())

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" %
        (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info(
        "train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)"
        % (obj, gamma, total_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    # ------------------------------------------------------------------------- #
    # --------------------- Form the mini-batches ----------------------------- #
    num_batches = total_data // batch_size + 1
    aux_lang = []
    if t_count > 0:
        for language in args.aux_lang:
            aux_lang.extend([language] * num_data[language])

        assert num_data[args.src_lang] <= len(aux_lang)
    # ------------------------------------------------------------------------- #

    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    if decoding == 'greedy':
        decode = network.module.decode if args.parallel else network.decode
    elif decoding == 'mst':
        decode = network.module.decode_mst if args.parallel else network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))

    skip_adv_tuning = 0
    loss_fn = network.module.loss if args.parallel else network.loss
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0

        skip_adv_tuning += 1
        loss_d_real, loss_d_fake = [], []
        acc_d_real, acc_d_fake, = [], []
        gen_loss, parsing_loss = [], []
        disent_loss = []

        if t_count > 0 and skip_adv_tuning > args.delay:
            batch_size = args.batch_size // 2
            num_batches = total_data // batch_size + 1

        # ---------------------- Sample the mini-batches -------------------------- #
        if t_count > 0:
            sampled_aux_lang = random.sample(aux_lang, num_batches)
            lang_in_batch = [(args.src_lang, sampled_aux_lang[k])
                             for k in range(num_batches)]
        else:
            lang_in_batch = [(args.src_lang, None) for _ in range(num_batches)]
        assert len(lang_in_batch) == num_batches
        # ------------------------------------------------------------------------- #

        network.train()
        warmup_factor = (lr + 0.) / num_batches
        for batch in range(1, num_batches + 1):
            update_generator = True
            update_discriminator = False

            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate

            # considering source language as real and auxiliary languages as fake
            real_lang, fake_lang = lang_in_batch[batch - 1]
            real_idx, fake_idx = lang_ids.get(real_lang), lang_ids.get(
                fake_lang, -1)

            #
            word, char, pos, heads, types, masks, lengths, bert_inputs = conllx_data.get_batch_variable(
                train_data[real_lang], batch_size, unk_replace=unk_replace)

            if use_gpu:
                word = word.cuda()
                char = char.cuda()
                pos = pos.cuda()
                heads = heads.cuda()
                types = types.cuda()
                masks = masks.cuda()
                lengths = lengths.cuda()
                if bert_inputs[0] is not None:
                    bert_inputs[0] = bert_inputs[0].cuda()
                    bert_inputs[1] = bert_inputs[1].cuda()
                    bert_inputs[2] = bert_inputs[2].cuda()

            real_enc = network(word,
                               char,
                               pos,
                               input_bert=bert_inputs,
                               mask=masks,
                               length=lengths,
                               hx=None)

            # ========== Update the discriminator ==========
            if t_count > 0 and skip_adv_tuning > args.delay:
                # fake examples = 0
                word_f, char_f, pos_f, heads_f, types_f, masks_f, lengths_f, bert_inputs = conllx_data.get_batch_variable(
                    train_data[fake_lang], batch_size, unk_replace=unk_replace)

                if use_gpu:
                    word_f = word_f.cuda()
                    char_f = char_f.cuda()
                    pos_f = pos_f.cuda()
                    heads_f = heads_f.cuda()
                    types_f = types_f.cuda()
                    masks_f = masks_f.cuda()
                    lengths_f = lengths_f.cuda()
                    if bert_inputs[0] is not None:
                        bert_inputs[0] = bert_inputs[0].cuda()
                        bert_inputs[1] = bert_inputs[1].cuda()
                        bert_inputs[2] = bert_inputs[2].cuda()

                fake_enc = network(word_f,
                                   char_f,
                                   pos_f,
                                   input_bert=bert_inputs,
                                   mask=masks_f,
                                   length=lengths_f,
                                   hx=None)

                # TODO: temporary crack
                if t_count > 0 and skip_adv_tuning > args.delay:
                    # skip discriminator training for '|n_critic|' iterations if 'n_critic' < 0
                    if args.n_critic > 0 or (batch - 1) % (-1 *
                                                           args.n_critic) == 0:
                        update_discriminator = True

            if update_discriminator:
                if args.adv_training:
                    real_loss, fake_loss, real_acc, fake_acc = AdvAgent.update(
                        real_enc['output'].detach(),
                        fake_enc['output'].detach(), real_idx, fake_idx)

                    loss_d_real.append(real_loss)
                    loss_d_fake.append(fake_loss)
                    acc_d_real.append(real_acc)
                    acc_d_fake.append(fake_acc)

                elif args.motivate:
                    real_loss, fake_loss, real_acc, fake_acc = MtvAgent.update(
                        real_enc['output'].detach(),
                        fake_enc['output'].detach(), real_idx, fake_idx)

                    loss_d_real.append(real_loss)
                    loss_d_fake.append(fake_loss)
                    acc_d_real.append(real_acc)
                    acc_d_fake.append(fake_acc)

                else:
                    raise NotImplementedError()

                if args.n_critic > 0 and (batch - 1) % args.n_critic != 0:
                    update_generator = False

            # ==============================================

            # =========== Update the generator =============
            if update_generator:
                others_loss = None
                if args.adv_training and skip_adv_tuning > args.delay:
                    # for GAN: L_G= L_parsing - (lambda_G * L_D)
                    # for GR : L_G= L_parsing +  L_D
                    others_loss = AdvAgent.gen_loss(real_enc['output'],
                                                    fake_enc['output'],
                                                    real_idx, fake_idx)
                    gen_loss.append(others_loss.item())

                elif args.motivate and skip_adv_tuning > args.delay:
                    others_loss = MtvAgent.gen_loss(real_enc['output'],
                                                    fake_enc['output'],
                                                    real_idx, fake_idx)
                    gen_loss.append(others_loss.item())

                optim.zero_grad()

                loss_arc, loss_type = loss_fn(real_enc['output'],
                                              heads,
                                              types,
                                              mask=masks,
                                              length=lengths)
                loss = loss_arc + loss_type

                num_inst = word.size(
                    0) if obj == 'crf' else masks.sum() - word.size(0)
                train_err += loss.item() * num_inst
                train_err_arc += loss_arc.item() * num_inst
                train_err_type += loss_type.item() * num_inst
                train_total += num_inst
                parsing_loss.append(loss.item())

                if others_loss is not None:
                    loss = loss + others_loss

                loss.backward()
                clip_grad_norm_(network.parameters(), clip)
                optim.step()

                time_ave = (time.time() - start_time) / batch
                time_left = (num_batches - batch) * time_ave

        if (args.adv_training
                or args.motivate) and skip_adv_tuning > args.delay:
            logger.info(
                'epoch: %d train: %d loss: %.4f, arc: %.4f, type: %.4f, dis_loss: (%.2f, %.2f), dis_acc: (%.2f, %.2f), '
                'gen_loss: %.2f, time: %.2fs' %
                (epoch, num_batches, train_err / train_total,
                 train_err_arc / train_total, train_err_type / train_total,
                 sum(loss_d_real) / len(loss_d_real), sum(loss_d_fake) /
                 len(loss_d_fake), sum(acc_d_real) / len(acc_d_real),
                 sum(acc_d_fake) / len(acc_d_fake),
                 sum(gen_loss) / len(gen_loss), time.time() - start_time))
        else:
            logger.info(
                'epoch: %d train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs'
                % (epoch, num_batches, train_err / train_total,
                   train_err_arc / train_total, train_err_type / train_total,
                   time.time() - start_time))

        ################# Validation on Dependency Parsing Only #################################
        if epoch % args.check_dev != 0:
            continue

        with torch.no_grad():
            # evaluate performance on dev data
            network.eval()

            dev_ucorr = 0.0
            dev_lcorr = 0.0
            dev_total = 0
            dev_ucomlpete = 0.0
            dev_lcomplete = 0.0
            dev_ucorr_nopunc = 0.0
            dev_lcorr_nopunc = 0.0
            dev_total_nopunc = 0
            dev_ucomlpete_nopunc = 0.0
            dev_lcomplete_nopunc = 0.0
            dev_root_corr = 0.0
            dev_total_root = 0.0
            dev_total_inst = 0.0

            for lang, data_dev in dev_data.items():
                for batch in conllx_data.iterate_batch_variable(
                        data_dev, batch_size):
                    word, char, pos, heads, types, masks, lengths, bert_inputs = batch

                    if use_gpu:
                        word = word.cuda()
                        char = char.cuda()
                        pos = pos.cuda()
                        heads = heads.cuda()
                        types = types.cuda()
                        masks = masks.cuda()
                        lengths = lengths.cuda()
                        if bert_inputs[0] is not None:
                            bert_inputs[0] = bert_inputs[0].cuda()
                            bert_inputs[1] = bert_inputs[1].cuda()
                            bert_inputs[2] = bert_inputs[2].cuda()

                    heads_pred, types_pred = decode(
                        word,
                        char,
                        pos,
                        input_bert=bert_inputs,
                        mask=masks,
                        length=lengths,
                        leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                    word = word.cpu().numpy()
                    pos = pos.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.cpu().numpy()
                    types = types.cpu().numpy()

                    stats, stats_nopunc, stats_root, num_inst = parser.eval(
                        word,
                        pos,
                        heads_pred,
                        types_pred,
                        heads,
                        types,
                        word_alphabet,
                        pos_alphabet,
                        lengths,
                        punct_set=punct_set,
                        symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    dev_ucorr += ucorr
                    dev_lcorr += lcorr
                    dev_total += total
                    dev_ucomlpete += ucm
                    dev_lcomplete += lcm

                    dev_ucorr_nopunc += ucorr_nopunc
                    dev_lcorr_nopunc += lcorr_nopunc
                    dev_total_nopunc += total_nopunc
                    dev_ucomlpete_nopunc += ucm_nopunc
                    dev_lcomplete_nopunc += lcm_nopunc

                    dev_root_corr += corr_root
                    dev_total_root += total_root
                    dev_total_inst += num_inst

            print(
                'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
                % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 /
                   dev_total, dev_lcorr * 100 / dev_total, dev_ucomlpete *
                   100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
            print(
                'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
                %
                (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                 dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
                 100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
                 dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
            print('Root: corr: %d, total: %d, acc: %.2f%%' %
                  (dev_root_corr, dev_total_root,
                   dev_root_corr * 100 / dev_total_root))

            if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                    dev_lcorrect_nopunc == dev_lcorr_nopunc
                    and dev_ucorrect_nopunc < dev_ucorr_nopunc):
                dev_ucorrect_nopunc = dev_ucorr_nopunc
                dev_lcorrect_nopunc = dev_lcorr_nopunc
                dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
                dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

                dev_ucorrect = dev_ucorr
                dev_lcorrect = dev_lcorr
                dev_ucomlpete_match = dev_ucomlpete
                dev_lcomplete_match = dev_lcomplete

                dev_root_correct = dev_root_corr

                best_epoch = epoch
                patient = 0

                state_dict = network.module.state_dict(
                ) if args.parallel else network.state_dict()
                torch.save(state_dict, model_name)

            else:
                if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                    state_dict = torch.load(model_name)
                    if args.parallel:
                        network.module.load_state_dict(state_dict)
                    else:
                        network.load_state_dict(state_dict)

                    lr = lr * decay_rate
                    optim = generate_optimizer(opt, lr, network.parameters(),
                                               betas, gamma, eps, momentum)

                    if decoding == 'greedy':
                        decode = network.module.decode if args.parallel else network.decode
                    elif decoding == 'mst':
                        decode = network.module.decode_mst if args.parallel else network.decode_mst
                    else:
                        raise ValueError('Unknown decoding algorithm: %s' %
                                         decoding)

                    patient = 0
                    decay += 1
                    if decay % double_schedule_decay == 0:
                        schedule *= 2
                else:
                    patient += 1

            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            print(
                'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 /
                   dev_total, dev_lcorrect * 100 / dev_total,
                   dev_ucomlpete_match * 100 / dev_total_inst,
                   dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
            print(
                'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
                   dev_ucorrect_nopunc * 100 / dev_total_nopunc,
                   dev_lcorrect_nopunc * 100 / dev_total_nopunc,
                   dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
                   dev_lcomplete_match_nopunc * 100 / dev_total_inst,
                   best_epoch))
            print(
                'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (dev_root_correct, dev_total_root,
                   dev_root_correct * 100 / dev_total_root, best_epoch))
            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            if decay == max_decay:
                break

        torch.cuda.empty_cache()  # release memory that can be released
Beispiel #5
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=200,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             help='decoding algorithm',
                             required=True)
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.05,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument(
        '--word_embedding',
        choices=['glove', 'senna', 'sskip', 'polyglot', 'NNLM'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)

    args_parser.add_argument('--pos_embedding',
                             choices=[1, 2, 4],
                             type=int,
                             help='Embedding method for korean POS tag',
                             default=2)

    args = args_parser.parse_args()

    logger = get_logger("GraphParser")

    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_embedding = args.pos_embedding
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    data_paths = [dev_path, test_path] if test_path else [dev_path]
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=data_paths,
        max_vocabulary_size=50000,
        pos_embedding=pos_embedding,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_data.read_data_to_variable(train_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   type_alphabet,
                                                   pos_embedding,
                                                   use_gpu=use_gpu,
                                                   symbolic_root=True)
    # data_train = conllx_data.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    # num_data = sum([len(bucket) for bucket in data_train])
    num_data = sum(data_train[1])

    data_dev = conllx_data.read_data_to_variable(dev_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 pos_alphabet,
                                                 type_alphabet,
                                                 pos_embedding,
                                                 use_gpu=use_gpu,
                                                 volatile=True,
                                                 symbolic_root=True)
    if test_path:
        data_test = conllx_data.read_data_to_variable(test_path,
                                                      word_alphabet,
                                                      char_alphabet,
                                                      pos_alphabet,
                                                      type_alphabet,
                                                      pos_embedding,
                                                      use_gpu=use_gpu,
                                                      volatile=True,
                                                      symbolic_root=True)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in list(word_alphabet.items()):
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in list(char_alphabet.items()):
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim,
                                          num_words,
                                          char_dim,
                                          num_chars,
                                          pos_dim,
                                          num_pos,
                                          num_filters,
                                          window,
                                          mode,
                                          hidden_size,
                                          num_layers,
                                          num_types,
                                          arc_space,
                                          type_space,
                                          embedd_word=word_table,
                                          embedd_char=char_table,
                                          p_in=p_in,
                                          p_out=p_out,
                                          p_rnn=p_rnn,
                                          biaffine=True,
                                          pos=use_pos,
                                          char=use_char)
    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, num_layers, num_types,
            arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char
        }
        json.dump({
            'args': arguments,
            'kwargs': kwargs
        },
                  open(arg_path, 'w'),
                  indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    def generate_optimizer(opt, lr, params):
        params = [param for param in params if param.requires_grad]
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adamax':
            return Adamax(params,
                          lr=lr,
                          betas=betas,
                          weight_decay=gamma,
                          eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" %
        (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info(
        "train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)"
        % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = 9
    double_schedule_decay = 5
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(
                data_train, batch_size, pos_embedding, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word,
                                               char,
                                               pos,
                                               heads,
                                               types,
                                               mask=masks,
                                               length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()

            num_inst = word.size(
                0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (
                    batch, num_batches, train_err / train_total, train_err_arc
                    / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print(
            'train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' %
            (num_batches, train_err / train_total, train_err_arc / train_total,
             train_err_type / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_data.iterate_batch_variable(
                data_dev, batch_size, pos_embedding):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser_bpe.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            if test_path:
                pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
                pred_writer.start(pred_filename)
                gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
                gold_writer.start(gold_filename)

                test_ucorrect = 0.0
                test_lcorrect = 0.0
                test_ucomlpete_match = 0.0
                test_lcomplete_match = 0.0
                test_total = 0

                test_ucorrect_nopunc = 0.0
                test_lcorrect_nopunc = 0.0
                test_ucomlpete_match_nopunc = 0.0
                test_lcomplete_match_nopunc = 0.0
                test_total_nopunc = 0
                test_total_inst = 0

                test_root_correct = 0.0
                test_total_root = 0
                for batch in conllx_data.iterate_batch_variable(
                        data_test, batch_size, pos_embedding):
                    word, char, pos, heads, types, masks, lengths = batch
                    heads_pred, types_pred = decode(
                        word,
                        char,
                        pos,
                        mask=masks,
                        length=lengths,
                        leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                    word = word.data.cpu().numpy()
                    pos = pos.data.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.data.cpu().numpy()
                    types = types.data.cpu().numpy()

                    pred_writer.write(word,
                                      pos,
                                      heads_pred,
                                      types_pred,
                                      lengths,
                                      symbolic_root=True)
                    gold_writer.write(word,
                                      pos,
                                      heads,
                                      types,
                                      lengths,
                                      symbolic_root=True)

                    stats, stats_nopunc, stats_root, num_inst = parser_bpe.eval(
                        word,
                        pos,
                        heads_pred,
                        types_pred,
                        heads,
                        types,
                        word_alphabet,
                        pos_alphabet,
                        lengths,
                        punct_set=punct_set,
                        symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    test_ucorrect += ucorr
                    test_lcorrect += lcorr
                    test_total += total
                    test_ucomlpete_match += ucm
                    test_lcomplete_match += lcm

                    test_ucorrect_nopunc += ucorr_nopunc
                    test_lcorrect_nopunc += lcorr_nopunc
                    test_total_nopunc += total_nopunc
                    test_ucomlpete_match_nopunc += ucm_nopunc
                    test_lcomplete_match_nopunc += lcm_nopunc

                    test_root_correct += corr_root
                    test_total_root += total_root

                    test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())

                if decoding == 'greedy':
                    decode = network.decode
                elif decoding == 'mst':
                    decode = network.decode_mst
                else:
                    raise ValueError('Unknown decoding algorithm: %s' %
                                     decoding)

                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        if test_path:
            print(
                'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (test_ucorrect, test_lcorrect, test_total, test_ucorrect *
                   100 / test_total, test_lcorrect * 100 / test_total,
                   test_ucomlpete_match * 100 / test_total_inst,
                   test_lcomplete_match * 100 / test_total_inst, best_epoch))
            print(
                'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                %
                (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                 test_ucorrect_nopunc * 100 / test_total_nopunc,
                 test_lcorrect_nopunc * 100 / test_total_nopunc,
                 test_ucomlpete_match_nopunc * 100 / test_total_inst,
                 test_lcomplete_match_nopunc * 100 / test_total_inst,
                 best_epoch))
            print(
                'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (test_root_correct, test_total_root,
                   test_root_correct * 100 / test_total_root, best_epoch))
            print(
                '============================================================================================================================'
            )

        if decay == max_decay:
            break

    def save_result():
        result_path = model_name + '.result.txt'
        best_dev_Punc = 'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total,
            dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst,
            dev_lcomplete_match * 100 / dev_total_inst, best_epoch)
        best_dev_noPunc = 'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc,
            dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
            dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch)
        best_dev_Root = 'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root,
            dev_root_correct * 100 / dev_total_root, best_epoch)
        f = open(result_path, 'w')
        f.write(best_dev_Punc.encode('utf-8') + '\n')
        f.write(best_dev_noPunc.encode('utf-8') + '\n')
        f.write(best_dev_Root.encode('utf-8'))
        f.close()

    save_result()
Beispiel #6
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=1000,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder.')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             default='mst',
                             help='decoding algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    # args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    #
    args_parser.add_argument('--no_word',
                             action='store_true',
                             help='do not use word embedding.')
    #
    # lrate schedule with warmup in the first iter.
    args_parser.add_argument('--use_warmup_schedule',
                             action='store_true',
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    # Tansformer encoder
    args_parser.add_argument('--no_CoRNN',
                             action='store_true',
                             help='do not use context RNN.')
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--multi_head_attn',
                             action='store_true',
                             help='use multi-head-attention.')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    # - positional
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=50,
                             help='Dimension of Position embeddings.')
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    #
    args = args_parser.parse_args()

    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    logger = get_logger("GraphParser")

    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    #
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    # todo(warn): exactly same for loading vocabs
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # ===== the reading
    def _read_one(path, is_train):
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, lang_id))
        one_data = conllx_data.read_data_to_variable(
            path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=(not is_train),
            symbolic_root=True,
            lang_id=lang_id,
            len_thresh=(args.train_len_thresh if is_train else 100000))
        return one_data

    data_train = _read_one(train_path, True)
    num_data = sum(data_train[1])

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    # =====

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(
            word_dim,
            num_words,
            char_dim,
            num_chars,
            pos_dim,
            num_pos,
            num_filters,
            window,
            mode,
            hidden_size,
            num_layers,
            num_types,
            arc_space,
            type_space,
            embedd_word=word_table,
            embedd_char=char_table,
            p_in=p_in,
            p_out=p_out,
            p_rnn=p_rnn,
            biaffine=True,
            pos=use_pos,
            char=use_char,
            train_position=args.train_position,
            use_con_rnn=(not args.no_CoRNN),
            trans_hid_size=args.trans_hid_size,
            d_k=args.d_k,
            d_v=args.d_v,
            multi_head_attn=args.multi_head_attn,
            num_head=args.num_head,
            enc_use_neg_dist=args.enc_use_neg_dist,
            enc_clip_dist=args.enc_clip_dist,
            position_dim=args.position_dim,
            max_sent_length=max_sent_length,
            use_gpu=use_gpu,
            no_word=args.no_word)

    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, num_layers, num_types,
            arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'train_position': args.train_position,
            'use_con_rnn': (not args.no_CoRNN),
            'trans_hid_size': args.trans_hid_size,
            'd_k': args.d_k,
            'd_v': args.d_v,
            'multi_head_attn': args.multi_head_attn,
            'num_head': args.num_head,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'enc_clip_dist': args.enc_clip_dist,
            'position_dim': args.position_dim,
            'max_sent_length': max_sent_length,
            'no_word': args.no_word
        }
        json.dump({
            'args': arguments,
            'kwargs': kwargs
        },
                  open(arg_path, 'w'),
                  indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adamax':
            return Adamax(params,
                          lr=lr,
                          betas=betas,
                          weight_decay=gamma,
                          eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" %
        (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info(
        "train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)"
        % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule
    warmup_factor = (lr + 0.) / num_batches

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))
    #

    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate
            #
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word,
                                               char,
                                               pos,
                                               heads,
                                               types,
                                               mask=masks,
                                               length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_inst = word.size(
                0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (
                    batch, num_batches, train_err / train_total, train_err_arc
                    / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print(
            'train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' %
            (num_batches, train_err / train_total, train_err_arc / train_total,
             train_err_type / train_total, time.time() - start_time))

        ################################################################################################
        if epoch % args.check_dev != 0:
            continue

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())

                if decoding == 'greedy':
                    decode = network.decode
                elif decoding == 'mst':
                    decode = network.decode_mst
                else:
                    raise ValueError('Unknown decoding algorithm: %s' %
                                     decoding)

                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
               test_total, test_lcorrect * 100 / test_total,
               test_ucomlpete_match * 100 / test_total_inst,
               test_lcomplete_match * 100 / test_total_inst, best_epoch))
        print(
            'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            %
            (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
             test_ucorrect_nopunc * 100 / test_total_nopunc,
             test_lcorrect_nopunc * 100 / test_total_nopunc,
             test_ucomlpete_match_nopunc * 100 / test_total_inst,
             test_lcomplete_match_nopunc * 100 / test_total_inst, best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (test_root_correct, test_total_root,
               test_root_correct * 100 / test_total_root, best_epoch))
        print(
            '============================================================================================================================'
        )

        if decay == max_decay:
            break
Beispiel #7
0
def main():
    args_parser = argparse.ArgumentParser(description='Tuning with graph-based parsing')
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--freeze', action='store_true', help='frozen the word embedding (disable fine-tuning).')
    

    args = args_parser.parse_args()

    logger = get_logger("GraphParser")

    mode = "FastLSTM" #fast lstm here
    obj = "cross_entropy"
    decoding = "mst" #mst decode here 
    train_path = "data/train.stanford.conll"
    dev_path = "data/dev.stanford.conll"
    test_path = "data/test.stanford.conll"
    model_path = "models/parsing/biaffine/"
    model_name = 'network.pt'
    num_epochs = 80
    batch_size = 32
    hidden_size = 512
    arc_space = 512
    type_space = 128
    num_layers = 10
    num_filters = 1
    learning_rate = 0.001
    opt = "adam" #default adam
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = 1e-4
    decay_rate = 0.75
    clip = 5 #what is clip
    gamma = 0
    schedule = 10 #?What is this?
    p_rnn = (0.05,0.05)
    p_in = 0.33
    p_out = 0.33
    unk_replace = args.unk_replace# ?what is this?
    punctuation = ['.','``', "''", ':', ',']

    freeze = args.freeze
    word_embedding = 'glove'
    word_path = "data/glove.6B.100d.txt"

    use_char = False
    char_embedding = None
    #char_path = args.char_path

    use_pos = True
    pos_dim = 100
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = 0
    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(alphabet_path, train_path, data_paths=[dev_path, test_path],
                                                                                             max_vocabulary_size=50000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()
    #print(word_alphabet.instance2index)

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()
    print(use_gpu)

    data_train = conllx_data.read_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, symbolic_root=True)
    # data_train = conllx_data.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    # num_data = sum([len(bucket) for bucket in data_train])
    num_data = sum(data_train[1])
    """
    print("bucket_size")
    print(data_train[1])
    print("___________________________________data_train")
    print(data_train[0])
	"""
    data_dev = conllx_data.read_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)
    data_test = conllx_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                                          mode, hidden_size, num_layers, num_types, arc_space, type_space,
                                          embedd_word=word_table, embedd_char=None,
                                          p_in=p_in, p_out=p_out, p_rnn=p_rnn, biaffine=True, pos=use_pos, char=use_char)
    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                     mode, hidden_size, num_layers, num_types, arc_space, type_space]
        kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char}
        json.dump({'args': arguments, 'kwargs': kwargs}, open(arg_path, 'w'), indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()
    
    #pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    #gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    
    ##print parameters:
    print("number of parameters")

    num_param = sum([param.nelement() for param in network.parameters()])
    print(num_param)

    def generate_optimizer(opt, lr, params):
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params, lr=lr, betas=betas, weight_decay=gamma, eps=eps)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info("Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" % (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info("RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info("train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)" % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    #logger.info("Attention")

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = 9
    double_schedule_decay = 5

    f = open("testout.csv", "wt")
    writer = csv.writer(f)
    writer.writerow(('train', 'dev'))

    for epoch in range(1, num_epochs + 1):
        print(epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay)
        print('Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): ' % (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word, char, pos, heads, types, mask=masks, length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_inst = word.size(0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst
            #bp()

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (batch, num_batches, train_err / train_total,
                                                                                                 train_err_arc / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' % (num_batches, train_err / train_total, train_err_arc / train_total, train_err_type / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        t_ucorr = 0.0
        t_lcorr = 0.0
        t_total = 0
        t_ucomlpete = 0.0
        t_lcomplete = 0.0
        t_ucorr_nopunc = 0.0
        t_lcorr_nopunc = 0.0
        t_total_nopunc = 0
        t_ucomlpete_nopunc = 0.0
        t_lcomplete_nopunc = 0.0
        t_root_corr = 0.0
        t_total_root = 0.0
        t_total_inst = 0.0

        list_iter = iter(conllx_data.iterate_batch_variable(data_train, batch_size))
        for batch in list_iter:
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root
            #print(t_ucorr)
            t_ucorr += ucorr
            t_lcorr += lcorr
            t_total += total
            t_ucomlpete += ucm
            t_lcomplete += lcm

            t_ucorr_nopunc += ucorr_nopunc
            t_lcorr_nopunc += lcorr_nopunc
            t_total_nopunc += total_nopunc
            t_ucomlpete_nopunc += ucm_nopunc
            t_lcomplete_nopunc += lcm_nopunc

            t_root_corr += corr_root
            t_total_root += total_root

            t_total_inst += num_inst
            for _ in range(10):   
                next(list_iter, None)

        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        writer.writerow((t_ucorr_nopunc*100/t_total_nopunc,dev_ucorr_nopunc*100/dev_total_nopunc)) 
        f.flush()
        #pred_writer.close()
        #gold_writer.close()
        print('Train Wo Punct:%.2f%%'% (t_ucorr_nopunc*100/t_total_nopunc))
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %(dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc< dev_lcorr_nopunc or (dev_lcorrect_nopunc == dev_lcorr_nopunc and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            #pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            #pred_writer.start(pred_filename)
            #gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            #gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_data.iterate_batch_variable(data_test, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                #pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                #gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            #pred_writer.close()
            #gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())

                if decoding == 'greedy':
                    decode = network.decode
                elif decoding == 'mst':
                    decode = network.decode_mst
                else:
                    raise ValueError('Unknown decoding algorithm: %s' % decoding)

                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch))
        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
            test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst,
            best_epoch))
        print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst,
            best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
        print('============================================================================================================================')
        


        if decay == max_decay:
            break