示例#1
0
def main():
    uid = uuid.uuid4().hex[:6]

    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             default='FastLSTM')
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=10,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=32,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size',
                             type=int,
                             default=256,
                             help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of decoder RNN')
    args_parser.add_argument('--char_num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN(Character Level)')
    args_parser.add_argument('--eojul_num_filters',
                             type=int,
                             default=100,
                             help='Number of filters in CNN(Eojul Level)')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--eojul',
                             action='store_true',
                             help='use eojul embedding and CNN.')
    args_parser.add_argument('--word_dim',
                             type=int,
                             default=100,
                             help='Dimension of Word embeddings')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm',
                             default='adam')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             default=[0.33, 0.33],
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth',
                             type=float,
                             default=1.0,
                             help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect',
                             action='store_true',
                             help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar',
                             action='store_true',
                             help='use grand parent.')
    args_parser.add_argument('--sibling',
                             action='store_true',
                             help='use sibling.')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument('--schedule',
                             type=int,
                             default=20,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument(
        '--word_embedding',
        choices=['random', 'word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'word2vec'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument('--pos_embedding',
                             choices=['random', 'word2vec'],
                             help='Embedding for part of speeches',
                             required=True)
    args_parser.add_argument('--pos_path',
                             help='path for part of speech embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    args_parser.add_argument('--use_gpu',
                             action='store_true',
                             help='use the gpu')

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = "{}_{}".format(str(uid), args.model_name)
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    char_num_filters = args.char_num_filters
    eojul_num_filters = args.eojul_num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    use_gpu = args.use_gpu
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path
    pos_embedding = args.pos_embedding
    pos_path = args.pos_path

    use_pos = args.pos

    if word_embedding != 'random':
        word_dict, word_dim = utils.load_embedding_dict(
            word_embedding, word_path)
    else:
        word_dict = {}
        word_dim = args.word_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    else:
        if use_char:
            char_dict = {}
            char_dim = args.char_dim
        else:
            char_dict = None
    if pos_embedding != 'random':
        pos_dict, pos_dim = utils.load_embedding_dict(pos_embedding, pos_path)
    else:
        if use_pos:
            pos_dict = {}
            pos_dim = args.pos_dim
        else:
            pos_dict = None

    use_eojul = args.eojul

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = use_gpu

    data_train = conllx_stacked_data.read_stacked_data_to_variable(
        train_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    num_data = sum(data_train[1])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(
        dev_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix for ``word_alphabet``.

        Each word's row is taken from ``word_dict``: an exact match first, then
        the lower-cased form. The UNK row and every word found in neither form
        get a fallback row. The fallback is all zeros when ``freeze`` is set.
        Otherwise it is drawn from U(-sqrt(3/word_dim), sqrt(3/word_dim)).

        Prints the number of out-of-vocabulary words.

        Returns:
            A float32 torch tensor of shape (word_alphabet.size(), word_dim).
        """
        bound = np.sqrt(3.0 / word_dim)

        def fallback_row():
            # --freeze disables fine-tuning of this table (per its help text),
            # so unknown rows are zero-filled instead of randomly initialised.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound,
                                     [1, word_dim]).astype(np.float32)

        weights = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        weights[conllx_stacked_data.UNK_ID, :] = fallback_row()
        missing = 0
        for token, row in word_alphabet.items():
            if token in word_dict:
                vector = word_dict[token]
            elif token.lower() in word_dict:
                vector = word_dict[token.lower()]
            else:
                vector = fallback_row()
                missing += 1
            weights[row, :] = vector
        print('word OOV: %d' % missing)
        return torch.from_numpy(weights)

    def construct_char_embedding_table():
        """Build the initial character-embedding matrix, or return None.

        Returns None when ``char_dict`` is None. That is the case when the
        character embedding is 'random' and --char was not given.

        Otherwise each character's row comes from ``char_dict`` when present.
        The UNK row and every character absent from ``char_dict`` are drawn
        from U(-sqrt(3/char_dim), sqrt(3/char_dim)). Prints the number of
        out-of-vocabulary characters.

        Returns:
            A float32 torch tensor of shape (num_chars, char_dim), or None.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)

        def random_row():
            return np.random.uniform(-bound, bound,
                                     [1, char_dim]).astype(np.float32)

        matrix = np.empty([num_chars, char_dim], dtype=np.float32)
        matrix[conllx_stacked_data.UNK_ID, :] = random_row()
        unseen = 0
        for symbol, slot in char_alphabet.items():
            # Use `in` / `[]` only: the embedding dict may not be a plain dict.
            known = symbol in char_dict
            if not known:
                unseen += 1
            matrix[slot, :] = char_dict[symbol] if known else random_row()
        print('character OOV: %d' % unseen)
        return torch.from_numpy(matrix)

    def construct_pos_embedding_table():
        """Build the initial POS-tag embedding matrix, or return None.

        Returns None when ``pos_dict`` is None. That is the case when the POS
        embedding is 'random' and --pos was not given.

        Otherwise each tag's row comes from ``pos_dict`` when present. The UNK
        row and every tag absent from ``pos_dict`` are drawn from
        U(-sqrt(3/pos_dim), sqrt(3/pos_dim)). Prints the number of
        out-of-vocabulary tags.

        Returns:
            A float32 torch tensor of shape (num_pos, pos_dim), or None.
        """
        if pos_dict is None:
            return None

        limit = np.sqrt(3.0 / pos_dim)
        table = np.empty([num_pos, pos_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -limit, limit, [1, pos_dim]).astype(np.float32)
        n_missing = 0
        for tag, idx in pos_alphabet.items():
            if tag not in pos_dict:
                n_missing += 1
                row = np.random.uniform(-limit, limit,
                                        [1, pos_dim]).astype(np.float32)
            else:
                row = pos_dict[tag]
            table[idx, :] = row
        print('pos OOV: %d' % n_missing)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()
    pos_table = construct_pos_embedding_table()

    char_window = 3
    eojul_window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          char_num_filters,
                          char_window,
                          eojul_num_filters,
                          eojul_window,
                          mode,
                          input_size_decoder,
                          hidden_size,
                          encoder_layers,
                          decoder_layers,
                          num_types,
                          arc_space,
                          type_space,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          embedd_pos=pos_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          pos=use_pos,
                          char=use_char,
                          eojul=use_eojul,
                          prior_order=prior_order,
                          skipConnect=skipConnect,
                          grandPar=grandPar,
                          sibling=sibling)

    def save_args():
        """Write the StackPtrNet constructor arguments to ``<model_name>.arg.json``.

        The JSON object has two keys:
          * ``'args'``: the positional arguments passed to ``StackPtrNet`` in
            ``main``, in the same order.
          * ``'kwargs'``: the keyword arguments passed there, excluding the
            embedding tables.

        Presumably this lets a later script rebuild the network before loading
        the saved state_dict. Confirm against the loading code.
        """
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            char_num_filters, char_window, eojul_num_filters, eojul_window,
            mode, input_size_decoder, hidden_size, encoder_layers,
            decoder_layers, num_types, arc_space, type_space
        ]
        # NOTE: `sibling` is read at call time. main later rebinds that name to
        # a batch tensor inside the training loop, so this must run before
        # training starts.
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'eojul': use_eojul,
            'prior_order': prior_order,
            'skipConnect': skipConnect,
            'grandPar': grandPar,
            'sibling': sibling
        }
        # Close the handle deterministically so the JSON is fully flushed to
        # disk. Previously the file object was left for the GC to close.
        with open(arg_path, 'w') as arg_file:
            json.dump({'args': arguments, 'kwargs': kwargs}, arg_file, indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        """Create the optimizer named ``opt`` over the trainable ``params``.

        Only parameters with ``requires_grad`` set are handed to the optimizer.
        ``betas``, ``eps``, ``momentum`` and the weight decay ``gamma`` come
        from the enclosing ``main``.

        Args:
            opt: One of 'adam', 'sgd' or 'adamax'.
            lr: Learning rate.
            params: Iterable of parameters, e.g. ``network.parameters()``.

        Returns:
            The configured torch optimizer. SGD uses Nesterov momentum.

        Raises:
            ValueError: If ``opt`` is not one of the supported names.
        """
        trainable = filter(lambda param: param.requires_grad, params)
        if opt == 'sgd':
            return SGD(trainable,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        if opt in ('adam', 'adamax'):
            # Adam and Adamax take the same keyword arguments here.
            optimizer_cls = Adam if opt == 'adam' else Adamax
            return optimizer_cls(trainable,
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=gamma,
                                 eps=eps)
        raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("Char CNN: filter=%d, kernel=%d" %
                (char_num_filters, char_window))
    logger.info("Eojul CNN: filter=%d, kernel=%d" %
                (eojul_num_filters, eojul_window))
    logger.info(
        "RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d"
        % (mode, encoder_layers, decoder_layers, input_size_decoder,
           hidden_size, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' %
                (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d, use_gpu: %s' %
                (skipConnect, beam, use_gpu))
    logger.info(opt_info)

    num_batches = num_data // batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    patient = 0
    decay = 0.
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d (%d, %d))): '
            % (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay,
               max_decay, double_schedule_decay))
        train_err_arc_leaf = 0.
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train,
                batch_size,
                unk_replace=unk_replace,
                use_gpu=use_gpu)
            word, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling, stacked_types, label_smooth,
                                                            skip_connect=skip_connect, mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.item()  ##180809 data[0] --> item()
            num_non_leaf = num_non_leaf.item()  ##180809 data[0] --> item()

            train_err_arc_leaf += loss_arc_leaf.item(
            ) * num_leaf  ##180809 data[0] --> item()
            train_err_arc_non_leaf += loss_arc_non_leaf.item(
            ) * num_non_leaf  ##180809 data[0] --> item()

            train_err_type_leaf += loss_type_leaf.item(
            ) * num_leaf  ##180809 data[0] --> item()
            train_err_type_non_leaf += loss_type_non_leaf.item(
            ) * num_non_leaf  ##180809 data[0] --> item()

            train_err_cov += loss_cov.item() * (num_leaf + num_non_leaf
                                                )  ##180809 data[0] --> item()

            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf +
                                           train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf,
                    err_arc_non_leaf, err_type, err_type_leaf,
                    err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        print(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % (num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf,
               err_type, err_type_leaf, err_type_non_leaf, err_cov,
               time.time() - start_time))

        #torch.save(network.state_dict(), model_name+"."+str(epoch))
        #continue

        # evaluate performance on dev data
        network.eval()
        tmp_root = 'tmp'
        if not os.path.isdir(tmp_root):
            logger.info('Creating temporary folder(%s)' % (tmp_root, ))
            os.makedirs(tmp_root)
        pred_filename = '%s/%spred_dev%d' % (tmp_root, str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = '%s/%sgold_dev%d' % (tmp_root, str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size, use_gpu=use_gpu):
            input_encoder, _, sentences = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(sentences,
                              word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(sentences,
                              word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data_test, batch_size, use_gpu=use_gpu):
                input_encoder, _, sentences = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(sentences,
                                  word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(sentences,
                                  word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        if test_total_inst != 0 or test_total != 0:
            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            print(
                'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (test_ucorrect, test_lcorrect, test_total, test_ucorrect *
                   100 / test_total, test_lcorrect * 100 / test_total,
                   test_ucomlpete_match * 100 / test_total_inst,
                   test_lcomplete_match * 100 / test_total_inst, best_epoch))
            print(
                'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                %
                (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                 test_ucorrect_nopunc * 100 / test_total_nopunc,
                 test_lcorrect_nopunc * 100 / test_total_nopunc,
                 test_ucomlpete_match_nopunc * 100 / test_total_inst,
                 test_lcomplete_match_nopunc * 100 / test_total_inst,
                 best_epoch))
            print(
                'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (test_root_correct, test_total_root,
                   test_root_correct * 100 / test_total_root, best_epoch))
            print(
                '============================================================================================================================'
            )

        if decay == max_decay:
            break
def main():
    """Decode a test set with a trained stack-pointer parser and report scores.

    Command-line options give the model directory and file name, the output
    directory, the test file, the beam size, the batch size and whether to run
    on the GPU.  The network's constructor arguments are read from
    ``<model_path>/<model_name>.arg.json`` and the alphabets from
    ``<model_path>/alphabets/``.  Predicted and gold trees are written to
    ``<output_path>/pred_test.txt`` and ``<output_path>/gold_test.txt``; the
    unlabeled/labeled attachment scores and root accuracy are printed at the
    end.  The process exits with status 1 when an alphabet size differs from
    the size recorded for the saved model.
    """
    args_parser = argparse.ArgumentParser(description='Testing with stack pointer parser')

    args_parser.add_argument('--model_path', help='path for parser model directory', required=True)
    args_parser.add_argument('--model_name', help='parser model file', required=True)
    args_parser.add_argument('--output_path', help='path for result with parser model', required=True)
    args_parser.add_argument('--test', required=True)
    args_parser.add_argument('--beam', type=int, default=1, help='Beam size for decoding')
    args_parser.add_argument('--use_gpu', action='store_true', help='use the gpu')
    args_parser.add_argument('--batch_size', type=int, default=32)

    args = args_parser.parse_args()

    logger = get_logger("PtrParser Decoding")
    model_path = args.model_path
    model_name = os.path.join(model_path, args.model_name)
    output_path = args.output_path
    beam = args.beam
    use_gpu = args.use_gpu
    test_path = args.test
    batch_size = args.batch_size

    def load_args():
        """Return the (positional, keyword) constructor arguments saved with the model."""
        with open("{}.arg.json".format(model_name)) as f:
            key_parameters = json.loads(f.read())

        return key_parameters['args'], key_parameters['kwargs']

    # Positional layout of the saved constructor arguments:
    # arguments = [word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, char_num_filters, char_window, eojul_num_filters, eojul_window,
    #              mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
    #              num_types, arc_space, type_space]
    # kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char, 'eojul': use_eojul, 'prior_order': prior_order,
    #           'skipConnect': skipConnect, 'grandPar': grandPar, 'sibling': sibling}
    arguments, kwarguments = load_args()
    mode = arguments[10]
    input_size_decoder = arguments[11]
    hidden_size = arguments[12]
    arc_space = arguments[16]
    type_space = arguments[17]
    encoder_layers = arguments[13]
    decoder_layers = arguments[14]
    char_num_filters = arguments[6]
    eojul_num_filters = arguments[8]
    p_rnn = kwarguments['p_rnn']
    p_in = kwarguments['p_in']
    p_out = kwarguments['p_out']
    prior_order = kwarguments['prior_order']
    skipConnect = kwarguments['skipConnect']
    grandPar = kwarguments['grandPar']
    sibling = kwarguments['sibling']
    use_char = kwarguments['char']
    use_pos = kwarguments['pos']
    use_eojul = kwarguments['eojul']

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.load_alphabets(alphabet_path)
    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")

    data_test = conllx_stacked_data.read_stacked_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, prior_order=prior_order)
    num_data = sum(data_test[1])

    word_table = None
    word_dim = arguments[0]
    char_table = None
    char_dim = arguments[2]
    pos_table = None
    pos_dim = arguments[4]

    char_window = arguments[7]
    eojul_window = arguments[9]

    # The model was built with the vocabulary sizes stored in ``arguments``;
    # refuse to run when the alphabets on disk disagree.  Each entry pairs the
    # index in ``arguments`` (see the layout comment above) with the size
    # measured here, so the reported value always comes from the checked slot.
    vocab_checks = (
        (1, num_words, "Mismatching number of word vocabulary({} != {})"),
        (3, num_chars, "Mismatching number of character vocabulary({} != {})"),
        (5, num_pos, "Mismatching number of part-of-speech vocabulary({} != {})"),
        # Index 15 is num_types; index 14 is decoder_layers.
        (15, num_types, "Mismatching number types of vocabulary({} != {})"),
    )
    for index, actual_size, message in vocab_checks:
        if arguments[index] != actual_size:
            print(message.format(arguments[index], actual_size))
            # Non-zero status so callers/scripts can detect the failure.
            sys.exit(1)

    network = StackPtrNet(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, char_num_filters, char_window, eojul_num_filters, eojul_window,
                          mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
                          num_types, arc_space, type_space,
                          embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                          biaffine=True, pos=use_pos, char=use_char, eojul=use_eojul, prior_order=prior_order,
                          skipConnect=skipConnect, grandPar=grandPar, sibling=sibling)

    if use_gpu:
        network.cuda()

    print("loading model: {}".format(model_name))
    if use_gpu:
        network.load_state_dict(torch.load(model_name))
    else:
        # Allow a checkpoint saved from GPU tensors to be loaded on a CPU-only host.
        network.load_state_dict(torch.load(model_name, map_location='cpu'))

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    logger.info("Embedding dim: word=%d, char=%d, pos=%d" % (word_dim, char_dim, pos_dim))
    logger.info("Char CNN: filter=%d, kernel=%d" % (char_num_filters, char_window))
    logger.info("Eojul CNN: filter=%d, kernel=%d" % (eojul_num_filters, eojul_window))
    logger.info("RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d" % (
    mode, encoder_layers, decoder_layers, input_size_decoder, hidden_size, arc_space, type_space))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' % (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d, use_gpu: %s' % (skipConnect, beam, use_gpu))

    network.eval()

    pred_filename = '%s/pred_test.txt' % (output_path, )
    pred_writer.start(pred_filename)
    gold_filename = '%s/gold_test.txt' % (output_path, )
    gold_writer.start(gold_filename)

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_total = 0

    test_total_inst = 0

    test_root_correct = 0.0
    test_total_root = 0
    # Length of the progress text last written, so it can be erased in place.
    num_back = 0
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_test, batch_size, use_gpu=use_gpu):
        input_encoder, _, sentences = batch
        word, char, pos, heads, types, masks, lengths = input_encoder

        heads_pred, types_pred, _, _ = network.decode(word, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(sentences, word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
        gold_writer.write(sentences, word, pos, heads, types, lengths, symbolic_root=True)

        stats, _, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=None,
                                                                symbolic_root=True)
        ucorr, lcorr, total, _, _ = stats
        corr_root, total_root = stats_root

        test_ucorrect += ucorr
        test_lcorrect += lcorr
        test_total += total

        test_root_correct += corr_root
        test_total_root += total_root

        test_total_inst += num_inst

        # Erase the previous progress text (backspace, blank, backspace).
        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)

        log_info = "({:.1f}%){}/{}".format(test_total_inst * 100 / num_data, test_total_inst, num_data)

        sys.stdout.write(log_info)
        sys.stdout.flush()
        num_back = len(log_info)

    pred_writer.close()
    gold_writer.close()

    sys.stdout.write("\b" * num_back)
    sys.stdout.write(" " * num_back)
    sys.stdout.write("\b" * num_back)

    print('----------------------------------------------------------------------------------------------------------------------------')
    print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%' % (
        test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total))
    print('best test Root: corr: %d, total: %d, acc: %.2f%%' % (test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
    print('============================================================================================================================')
示例#3
0
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    """Decode ``test_path`` with a saved StackPtrNet and write the predictions.

    Alphabets are loaded from ``<model_path>/alphabets/`` and the network's
    constructor arguments from ``<model_path>/<model_name>.arg.json``.  Every
    test sentence is decoded on its own (batch size 1) and the predicted trees
    are written to ``<model_path>/tmp/analyze_pred``.

    Args:
        model_path: directory containing the model file, its ``.arg.json``
            and the ``alphabets/`` sub-directory.
        model_name: file name of the saved ``state_dict`` inside ``model_path``.
        test_path: data file to decode.
        punct_set: not used by this function; kept so the signature is
            unchanged.
        use_gpu: run the network on CUDA when true, on the CPU otherwise.
        logger: receives alphabet sizes and configuration messages.
        args: parsed command-line options; ``pos_embedding``, ``beam`` and
            ``ordered`` are read from it.
    """
    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = \
        conllx_stacked_data.create_alphabets(
            alphabet_path, None, pos_embedding, data_paths=[None, None],
            max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered

    # Constructor arguments saved next to the model.  They get distinct names
    # so the ``args`` parameter (command-line options) is not shadowed, and the
    # file is closed deterministically.
    arg_path = model_name + '.arg.json'
    with open(arg_path, 'r') as arg_file:
        saved_arguments = json.load(arg_file)
    model_args = saved_arguments['args']
    model_kwargs = saved_arguments['kwargs']

    prior_order = model_kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        is_test=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    network = StackPtrNet(*model_args, **model_kwargs)
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can be
    # read on a CPU-only host; the network is moved to the target device below.
    network.load_state_dict(
        torch.load(model_name, map_location=lambda storage, loc: storage))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    # os.path.join works whether or not model_path ends with a separator.
    pred_writer.start(os.path.join(model_path, 'tmp', 'analyze_pred'))
    try:
        batches = conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1)
        for sent, batch in enumerate(batches):
            # Progress: index of the sentence about to be decoded.
            sys.stdout.write('%d, ' % sent)
            sys.stdout.flush()

            # Only the encoder inputs are needed for decoding.
            input_encoder, _ = batch
            word, char, pos, _, _, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                ordered=ordered,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
    finally:
        pred_writer.close()
示例#4
0
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    """Evaluate a saved StackPtrNet on ``test_path`` and analyse its errors.

    Every test sentence is decoded on its own (batch size 1).  Predicted and
    gold trees are written to ``tmp/analyze_pred_<uid>`` and
    ``tmp/analyze_gold_<uid>``.  The function prints:

    * attachment scores with and without punctuation, and root accuracy;
    * the accuracy of the decoder's stack steps, split into leaf steps
      (``children == stacked_heads``) and non-leaf steps.

    Finally, ``analyze`` re-reads both files.  Each sentence with a wrong
    stack step is classed as a type error (all arcs right), a model error
    (the prediction's loss is lower than the gold loss) or a search error.

    NOTE(review): ``uid`` is not defined in this function; it is presumably
    a module-level name — confirm.

    Args:
        model_path: directory containing the model file, its ``.arg.json``
            and the ``alphabets/`` sub-directory.
        model_name: file name of the saved ``state_dict`` inside ``model_path``.
        test_path: data file to evaluate on.
        punct_set: punctuation passed to ``parser.eval`` for the
            "Wo Punct" scores.
        use_gpu: run the network on CUDA when true, on the CPU otherwise.
        logger: receives alphabet sizes and configuration messages.
        args: parsed command-line options; ``beam``, ``ordered`` and
            ``display`` are read from it.
    """
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = \
        conllx_stacked_data.create_alphabets(
            alphabet_path, None, data_paths=[None, None],
            max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    display_inst = args.display

    # Constructor arguments saved next to the model.  They get distinct names
    # so the ``args`` parameter (command-line options) is not shadowed, and the
    # file is closed deterministically.
    arg_path = model_name + '.arg.json'
    with open(arg_path, 'r') as arg_file:
        saved_arguments = json.load(arg_file)
    model_args = saved_arguments['args']
    model_kwargs = saved_arguments['kwargs']

    prior_order = model_kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_tensor(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    logger.info('model: %s' % model_name)
    network = StackPtrNet(*model_args, **model_kwargs)
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can be
    # read on a CPU-only host; the network is moved to the target device below.
    network.load_state_dict(
        torch.load(model_name, map_location=lambda storage, loc: storage))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0
    test_total = 0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_total_nopunc = 0
    test_total_inst = 0

    test_root_correct = 0.0
    test_total_root = 0

    test_ucorrect_stack_leaf = 0.0
    test_ucorrect_stack_non_leaf = 0.0

    test_lcorrect_stack_leaf = 0.0
    test_lcorrect_stack_non_leaf = 0.0

    test_leaf = 0
    test_non_leaf = 0

    # Shared by the writers below and by analyze(), which re-reads the predictions.
    pred_path = 'tmp/analyze_pred_%s' % str(uid)
    gold_path = 'tmp/analyze_gold_%s' % str(uid)
    pred_writer.start(pred_path)
    gold_writer.start(gold_path)
    sent = 0
    start_time = time.time()
    for batch in conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1):
        sys.stdout.write('%d, ' % sent)
        sys.stdout.flush()
        sent += 1

        input_encoder, input_decoder = batch
        word, char, pos, heads, types, masks, lengths = input_encoder
        stacked_heads, children, siblings, stacked_types, skip_connect, mask_d, lengths_d = input_decoder
        heads_pred, types_pred, children_pred, stacked_types_pred = network.decode(
            word,
            char,
            pos,
            mask=masks,
            length=lengths,
            beam=beam,
            ordered=ordered,
            leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

        stacked_heads = stacked_heads.data
        children = children.data
        stacked_types = stacked_types.data
        children_pred = torch.from_numpy(children_pred).long()
        stacked_types_pred = torch.from_numpy(stacked_types_pred).long()
        if use_gpu:
            children_pred = children_pred.cuda()
            stacked_types_pred = stacked_types_pred.cuda()
        mask_d = mask_d.data
        # A decoder step is a "leaf" step when the selected child equals the
        # head on top of the stack; only valid steps (mask_d) are counted.
        mask_leaf = torch.eq(children, stacked_heads).float()
        mask_non_leaf = (1.0 - mask_leaf)
        mask_leaf = mask_leaf * mask_d
        mask_non_leaf = mask_non_leaf * mask_d
        num_leaf = mask_leaf.sum()
        num_non_leaf = mask_non_leaf.sum()

        ucorr_stack = torch.eq(children_pred, children).float()
        lcorr_stack = ucorr_stack * torch.eq(stacked_types_pred,
                                             stacked_types).float()
        ucorr_stack_leaf = (ucorr_stack * mask_leaf).sum()
        ucorr_stack_non_leaf = (ucorr_stack * mask_non_leaf).sum()

        lcorr_stack_leaf = (lcorr_stack * mask_leaf).sum()
        lcorr_stack_non_leaf = (lcorr_stack * mask_non_leaf).sum()

        test_ucorrect_stack_leaf += ucorr_stack_leaf
        test_ucorrect_stack_non_leaf += ucorr_stack_non_leaf
        test_lcorrect_stack_leaf += lcorr_stack_leaf
        test_lcorrect_stack_non_leaf += lcorr_stack_non_leaf

        test_leaf += num_leaf
        test_non_leaf += num_non_leaf

        # ------------------------------------------------------------------------------------------------

        word = word.data.cpu().numpy()
        pos = pos.data.cpu().numpy()
        lengths = lengths.cpu().numpy()
        heads = heads.data.cpu().numpy()
        types = types.data.cpu().numpy()

        pred_writer.write(word,
                          pos,
                          heads_pred,
                          types_pred,
                          lengths,
                          symbolic_root=True)
        gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

        stats, stats_nopunc, stats_root, num_inst = parser.eval(
            word,
            pos,
            heads_pred,
            types_pred,
            heads,
            types,
            word_alphabet,
            pos_alphabet,
            lengths,
            punct_set=punct_set,
            symbolic_root=True)
        ucorr, lcorr, total, ucm, lcm = stats
        ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
        corr_root, total_root = stats_root

        test_ucorrect += ucorr
        test_lcorrect += lcorr
        test_total += total
        test_ucomlpete_match += ucm
        test_lcomplete_match += lcm

        test_ucorrect_nopunc += ucorr_nopunc
        test_lcorrect_nopunc += lcorr_nopunc
        test_total_nopunc += total_nopunc
        test_ucomlpete_match_nopunc += ucm_nopunc
        test_lcomplete_match_nopunc += lcm_nopunc

        test_root_correct += corr_root
        test_total_root += total_root

        test_total_inst += num_inst

    pred_writer.close()
    gold_writer.close()

    print('\ntime: %.2fs' % (time.time() - start_time))

    print(
        'test W. Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
        %
        (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
         test_total, test_lcorrect * 100 / test_total, test_ucomlpete_match *
         100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst))
    print(
        'test Wo Punct:  ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
        %
        (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
         test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc *
         100 / test_total_nopunc, test_ucomlpete_match_nopunc * 100 /
         test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('test Root: corr: %d, total: %d, acc: %.2f%%' %
          (test_root_correct, test_total_root,
           test_root_correct * 100 / test_total_root))
    print(
        '============================================================================================================================'
    )

    print(
        'Stack leaf:     ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%'
        % (test_ucorrect_stack_leaf, test_lcorrect_stack_leaf, test_leaf,
           test_ucorrect_stack_leaf * 100 / test_leaf,
           test_lcorrect_stack_leaf * 100 / test_leaf))
    print(
        'Stack non_leaf: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%'
        % (test_ucorrect_stack_non_leaf, test_lcorrect_stack_non_leaf,
           test_non_leaf, test_ucorrect_stack_non_leaf * 100 / test_non_leaf,
           test_lcorrect_stack_non_leaf * 100 / test_non_leaf))
    print(
        '============================================================================================================================'
    )

    def analyze():
        """Classify sentences whose predicted stack steps differ from gold.

        The gold test file and the predicted file written above are read
        again.  For each sentence with at least one wrong labeled stack step,
        ``calc_loss`` scores both the predicted and the gold structure.  The
        sentence then counts as:

        * a type error if every stack step's child is right;
        * otherwise a model error if the prediction's loss is lower than the
          gold loss;
        * otherwise a search error.

        With ``display_inst`` set, each such sentence is printed and the
        function waits for Enter.
        """
        # ``raw_input`` only exists on Python 2; Python 3 renamed it ``input``.
        try:
            wait_for_enter = raw_input
        except NameError:
            wait_for_enter = input

        np.set_printoptions(linewidth=100000)
        data_gold = conllx_stacked_data.read_stacked_data_to_tensor(
            test_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=True,
            prior_order=prior_order)
        data_pred = conllx_stacked_data.read_stacked_data_to_tensor(
            pred_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=True,
            prior_order=prior_order)

        gold_iter = conllx_stacked_data.iterate_batch_stacked_variable(
            data_gold, 1)
        test_iter = conllx_stacked_data.iterate_batch_stacked_variable(
            data_pred, 1)
        model_err = 0
        search_err = 0
        type_err = 0
        for gold, pred in zip(gold_iter, test_iter):
            gold_encoder, gold_decoder = gold
            word, char, pos, gold_heads, gold_types, masks, lengths = gold_encoder
            gold_stacked_heads, gold_children, gold_siblings, gold_stacked_types, gold_skip_connect, gold_mask_d, gold_lengths_d = gold_decoder

            pred_encoder, pred_decoder = pred
            _, _, _, pred_heads, pred_types, _, _ = pred_encoder
            pred_stacked_heads, pred_children, pred_siblings, pred_stacked_types, pred_skip_connect, pred_mask_d, pred_lengths_d = pred_decoder

            assert gold_heads.size() == pred_heads.size(
            ), 'sentence dis-match.'

            ucorr_stack = torch.eq(pred_children, gold_children).float()
            lcorr_stack = ucorr_stack * torch.eq(pred_stacked_types,
                                                 gold_stacked_types).float()
            ucorr_stack = (ucorr_stack * gold_mask_d).data.sum()
            lcorr_stack = (lcorr_stack * gold_mask_d).data.sum()
            num_stack = gold_mask_d.data.sum()

            if lcorr_stack < num_stack:
                loss_pred, loss_pred_arc, loss_pred_type = calc_loss(
                    network, word, char, pos, pred_heads, pred_stacked_heads,
                    pred_children, pred_siblings, pred_stacked_types,
                    pred_skip_connect, masks, lengths, pred_mask_d,
                    pred_lengths_d)

                loss_gold, loss_gold_arc, loss_gold_type = calc_loss(
                    network, word, char, pos, gold_heads, gold_stacked_heads,
                    gold_children, gold_siblings, gold_stacked_types,
                    gold_skip_connect, masks, lengths, gold_mask_d,
                    gold_lengths_d)

                if display_inst:
                    print('%d, %d, %d' % (ucorr_stack, lcorr_stack, num_stack))
                    print(
                        'pred(arc, type): %.4f (%.4f, %.4f), gold(arc, type): %.4f (%.4f, %.4f)'
                        % (loss_pred, loss_pred_arc, loss_pred_type, loss_gold,
                           loss_gold_arc, loss_gold_type))
                    word = word[0].data.cpu().numpy()
                    pos = pos[0].data.cpu().numpy()
                    head_gold = gold_heads[0].data.cpu().numpy()
                    type_gold = gold_types[0].data.cpu().numpy()
                    head_pred = pred_heads[0].data.cpu().numpy()
                    type_pred = pred_types[0].data.cpu().numpy()
                    display(word, pos, head_gold, type_gold, head_pred,
                            type_pred, lengths[0], word_alphabet, pos_alphabet,
                            type_alphabet)

                    # Rows: stacked types, children, stacked heads; one column
                    # per decoder step.
                    length_dec = int(gold_lengths_d[0])
                    gold_display = np.empty([3, length_dec])
                    gold_display[0] = gold_stacked_types.data[0].cpu().numpy(
                    )[:length_dec]
                    gold_display[1] = gold_children.data[0].cpu().numpy(
                    )[:length_dec]
                    gold_display[2] = gold_stacked_heads.data[0].cpu().numpy(
                    )[:length_dec]
                    print(gold_display)
                    print(
                        '--------------------------------------------------------'
                    )
                    # Size the prediction table by the prediction's own number
                    # of decoder steps.  (Slicing the empty array with
                    # ``[:length_dec]`` cut rows, not columns, and mismatched
                    # shapes whenever the pred and gold lengths differed.)
                    pred_length_dec = int(pred_lengths_d[0])
                    pred_display = np.empty([3, pred_length_dec])
                    pred_display[0] = pred_stacked_types.data[0].cpu().numpy(
                    )[:pred_length_dec]
                    pred_display[1] = pred_children.data[0].cpu().numpy(
                    )[:pred_length_dec]
                    pred_display[2] = pred_stacked_heads.data[0].cpu().numpy(
                    )[:pred_length_dec]
                    print(pred_display)
                    print(
                        '========================================================'
                    )
                    wait_for_enter()

                if ucorr_stack == num_stack:
                    # Every arc is right; only labels are wrong.
                    type_err += 1
                elif loss_pred < loss_gold:
                    # The model scores its wrong prediction better than gold.
                    model_err += 1
                else:
                    # Gold scores better but decoding did not find it.
                    search_err += 1
        print('type   errors: %d' % type_err)
        print('model  errors: %d' % model_err)
        print('search errors: %d' % search_err)

    analyze()
示例#5
0
def train(args):
    """Train a dependency parser, evaluating on dev (and test) after each epoch.

    Loads word/char embedding dictionaries, builds alphabets under
    ``<model_path>/alphabets``, constructs the network described by the JSON
    file ``args.config`` (one of ``DeepBiAffine``, ``NeuroMST``, ``StackPtr``)
    and trains it for ``args.num_epochs`` epochs. After every epoch the dev set
    is evaluated; whenever the sum of unlabeled + labeled correct counts
    (punctuation excluded) improves, the weights are saved to
    ``<model_path>/model.pt`` and the test set is evaluated. After
    ``args.reset`` epochs without improvement the best weights are reloaded
    and the scheduler state is reset.

    Args:
        args: parsed command-line namespace. ``args.cuda`` is overwritten with
            ``torch.cuda.is_available()``.
    """
    logger = get_logger("Parsing")

    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    eps = args.eps
    betas = (args.beta1, args.beta2)
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip

    # 'token' normalizes the summed batch loss by the number of words;
    # any other value normalizes it by the number of sentences.
    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace
    freeze = args.freeze

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    print(args)

    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    # With 'random' character embeddings there is no pretrained table; in that
    # case char_dim is taken from the model config further down.
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    else:
        char_dict = None
        char_dim = None

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=word_dict,
        max_vocabulary_size=200000)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the word embedding matrix indexed by word_alphabet ids.

        Words found in word_dict (exact, then lower-cased) use the pretrained
        vector; the UNK row and out-of-vocabulary words get zeros when the
        embedding is frozen, otherwise uniform noise in [-scale, scale].
        """
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the character embedding matrix, or None without a char_dict.

        Characters missing from char_dict (and the UNK row) are initialized
        with uniform noise in [-scale, scale].
        """
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    def _percent(count, total):
        """Return count as a percentage of total, or 0.0 when total is 0.

        Totals stay 0 until the corresponding evaluation has been recorded
        (e.g. the test totals before the first dev improvement).
        """
        return count * 100 / total if total else 0.0

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    logger.info("constructing network...")

    with open(args.config, 'r') as config_file:
        hyps = json.load(config_file)
    # Store a copy of the config next to the model; parse() reads it back
    # from <model_path>/config.json to rebuild the network.
    with open(os.path.join(model_path, 'config.json'), 'w') as config_file:
        json.dump(hyps, config_file, indent=2)
    model_type = hyps['model']
    assert model_type in ['DeepBiAffine', 'NeuroMST', 'StackPtr']
    assert word_dim == hyps['word_dim']
    if char_dim is not None:
        assert char_dim == hyps['char_dim']
    else:
        char_dim = hyps['char_dim']
    use_pos = hyps['pos']
    pos_dim = hyps['pos_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    arc_space = hyps['arc_space']
    type_space = hyps['type_space']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']
    prior_order = None

    # StackPtr is a transition-based parser; the other two are graph-based.
    alg = 'transition' if model_type == 'StackPtr' else 'graph'
    if model_type == 'DeepBiAffine':
        num_layers = hyps['num_layers']
        network = DeepBiAffine(word_dim,
                               num_words,
                               char_dim,
                               num_chars,
                               pos_dim,
                               num_pos,
                               mode,
                               hidden_size,
                               num_layers,
                               num_types,
                               arc_space,
                               type_space,
                               embedd_word=word_table,
                               embedd_char=char_table,
                               p_in=p_in,
                               p_out=p_out,
                               p_rnn=p_rnn,
                               pos=use_pos,
                               activation=activation)
    elif model_type == 'NeuroMST':
        num_layers = hyps['num_layers']
        network = NeuroMST(word_dim,
                           num_words,
                           char_dim,
                           num_chars,
                           pos_dim,
                           num_pos,
                           mode,
                           hidden_size,
                           num_layers,
                           num_types,
                           arc_space,
                           type_space,
                           embedd_word=word_table,
                           embedd_char=char_table,
                           p_in=p_in,
                           p_out=p_out,
                           p_rnn=p_rnn,
                           pos=use_pos,
                           activation=activation)
    elif model_type == 'StackPtr':
        encoder_layers = hyps['encoder_layers']
        decoder_layers = hyps['decoder_layers']
        num_layers = (encoder_layers, decoder_layers)
        prior_order = hyps['prior_order']
        grandPar = hyps['grandPar']
        sibling = hyps['sibling']
        network = StackPtrNet(word_dim,
                              num_words,
                              char_dim,
                              num_chars,
                              pos_dim,
                              num_pos,
                              mode,
                              hidden_size,
                              encoder_layers,
                              decoder_layers,
                              num_types,
                              arc_space,
                              type_space,
                              embedd_word=word_table,
                              embedd_char=char_table,
                              prior_order=prior_order,
                              activation=activation,
                              p_in=p_in,
                              p_out=p_out,
                              p_rnn=p_rnn,
                              pos=use_pos,
                              grandPar=grandPar,
                              sibling=sibling)
    else:
        raise RuntimeError('Unknown model type: %s' % model_type)

    if freeze:
        freeze_embedding(network.word_embed)

    network = network.to(device)
    model = "{}-{}".format(model_type, mode)
    logger.info("Network: %s, num_layer=%s, hidden=%d, act=%s" %
                (model, num_layers, hidden_size, activation))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" %
                ('variational', p_in, p_out, p_rnn))
    logger.info('# of Parameters: %d' %
                (sum([param.numel() for param in network.parameters()])))

    logger.info("Reading Data")
    if alg == 'graph':
        data_train = conllx_data.read_bucketed_data(train_path,
                                                    word_alphabet,
                                                    char_alphabet,
                                                    pos_alphabet,
                                                    type_alphabet,
                                                    symbolic_root=True)
        data_dev = conllx_data.read_data(dev_path,
                                         word_alphabet,
                                         char_alphabet,
                                         pos_alphabet,
                                         type_alphabet,
                                         symbolic_root=True)
        data_test = conllx_data.read_data(test_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          type_alphabet,
                                          symbolic_root=True)
    else:
        data_train = conllx_stacked_data.read_bucketed_data(
            train_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            prior_order=prior_order)
        data_dev = conllx_stacked_data.read_data(dev_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 pos_alphabet,
                                                 type_alphabet,
                                                 prior_order=prior_order)
        data_test = conllx_stacked_data.read_data(test_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  type_alphabet,
                                                  prior_order=prior_order)
    # NOTE(review): assumes data_train[1] holds per-bucket sentence counts.
    num_data = sum(data_train[1])
    logger.info("training: #training data: %d, batch: %d, unk replace: %.2f" %
                (num_data, batch_size, unk_replace))

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    optimizer, scheduler = get_optimizer(network.parameters(), optim,
                                         learning_rate, lr_decay, betas, eps,
                                         amsgrad, weight_decay, warmup_steps)

    best_ucorrect = 0.0
    best_lcorrect = 0.0
    best_ucomlpete = 0.0
    best_lcomplete = 0.0

    best_ucorrect_nopunc = 0.0
    best_lcorrect_nopunc = 0.0
    best_ucomlpete_nopunc = 0.0
    best_lcomplete_nopunc = 0.0
    best_root_correct = 0.0
    best_total = 0
    best_total_nopunc = 0
    best_total_inst = 0
    best_total_root = 0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete = 0.0
    test_lcomplete = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_nopunc = 0.0
    test_lcomplete_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    # Epochs since the last dev improvement; triggers a reset at `reset`.
    patient = 0
    beam = args.beam
    reset = args.reset
    num_batches = num_data // batch_size + 1
    # Graph-based models are always fed the BERT sub-word inputs.
    use_bert = True
    if optim == 'adam':
        opt_info = 'adam, betas=(%.1f, %.3f), eps=%.1e, amsgrad=%s' % (
            betas[0], betas[1], eps, amsgrad)
    else:
        opt_info = 'sgd, momentum=0.9, nesterov=True'
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        train_arc_loss = 0.
        train_type_loss = 0.
        num_insts = 0
        # Words seen this epoch (distinct from the alphabet size above).
        num_tokens = 0
        # Length of the last progress line, used to erase it in place.
        num_back = 0
        num_nans = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print(
            'Epoch %d (%s, lr=%.6f, lr decay=%.6f, grad clip=%.1f, l2=%.1e): '
            % (epoch, opt_info, lr, lr_decay, grad_clip, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        # Anomaly detection is a debugging aid with noticeable overhead; note
        # that it also encloses the dev/test evaluation below.
        with torch.autograd.set_detect_anomaly(True):
            for step, data in enumerate(
                    iterate_data(data_train,
                                 batch_size,
                                 bucketed=True,
                                 unk_replace=unk_replace,
                                 shuffle=True)):
                optimizer.zero_grad()
                words = data['WORD'].to(device)
                chars = data['CHAR'].to(device)
                postags = data['POS'].to(device)
                heads = data['HEAD'].to(device)
                nbatch = words.size(0)
                if alg == 'graph':
                    types = data['TYPE'].to(device)
                    masks = data['MASK'].to(device)
                    # Subtract one symbolic root per sentence.
                    nwords = masks.sum() - nbatch
                    if use_bert:
                        bert_words = data["BERT_WORD"].to(device)
                        sub_word_idx = data["SUB_IDX"].to(device)
                        loss_arc, loss_type = network.loss(bert_words,
                                                           sub_word_idx,
                                                           words,
                                                           chars,
                                                           postags,
                                                           heads,
                                                           types,
                                                           mask=masks)
                    else:
                        loss_arc, loss_type = network.loss(words,
                                                           chars,
                                                           postags,
                                                           heads,
                                                           types,
                                                           mask=masks)
                else:
                    masks_enc = data['MASK_ENC'].to(device)
                    masks_dec = data['MASK_DEC'].to(device)
                    stacked_heads = data['STACK_HEAD'].to(device)
                    children = data['CHILD'].to(device)
                    siblings = data['SIBLING'].to(device)
                    stacked_types = data['STACK_TYPE'].to(device)
                    nwords = masks_enc.sum() - nbatch
                    loss_arc, loss_type = network.loss(words,
                                                       chars,
                                                       postags,
                                                       heads,
                                                       stacked_heads,
                                                       children,
                                                       siblings,
                                                       stacked_types,
                                                       mask_e=masks_enc,
                                                       mask_d=masks_dec)
                loss_arc = loss_arc.sum()
                loss_type = loss_type.sum()
                loss_total = loss_arc + loss_type

                if loss_ty_token:
                    loss = loss_total.div(nwords)
                else:
                    loss = loss_total.div(nbatch)
                loss.backward()
                if grad_clip > 0:
                    grad_norm = clip_grad_norm_(network.parameters(),
                                                grad_clip)
                else:
                    grad_norm = total_grad_norm(network.parameters())

                # Skip the update (and the statistics) for NaN gradients.
                if math.isnan(grad_norm):
                    num_nans += 1
                else:
                    optimizer.step()
                    scheduler.step()

                    with torch.no_grad():
                        num_insts += nbatch
                        # Accumulate a Python number rather than a device
                        # tensor so the counter does not live on the GPU.
                        num_tokens += nwords.item()
                        train_loss += loss_total.item()
                        train_arc_loss += loss_arc.item()
                        train_type_loss += loss_type.item()

                # update log
                if step % 100 == 0:
                    if args.cuda:
                        torch.cuda.empty_cache()
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    curr_lr = scheduler.get_lr()[0]
                    # Clamp only the denominators so the running counters
                    # themselves are never inflated.
                    denom_insts = max(num_insts, 1)
                    denom_tokens = max(num_tokens, 1)
                    log_info = '[%d/%d (%.0f%%) lr=%.6f (%d)] loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches, curr_lr,
                        num_nans, train_loss / denom_insts,
                        train_loss / denom_tokens, train_arc_loss / denom_insts,
                        train_arc_loss / denom_tokens, train_type_loss /
                        denom_insts, train_type_loss / denom_tokens)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)

            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            denom_insts = max(num_insts, 1)
            denom_tokens = max(num_tokens, 1)
            print(
                'total: %d (%d), loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f), time: %.2fs'
                % (num_insts, num_tokens, train_loss / denom_insts,
                   train_loss / denom_tokens, train_arc_loss / denom_insts,
                   train_arc_loss / denom_tokens, train_type_loss / denom_insts,
                   train_type_loss / denom_tokens, time.time() - start_time))
            print('-' * 125)

            # evaluate performance on dev data
            with torch.no_grad():
                pred_filename = os.path.join(result_path, 'pred_dev%d' % epoch)
                pred_writer.start(pred_filename)
                gold_filename = os.path.join(result_path, 'gold_dev%d' % epoch)
                gold_writer.start(gold_filename)

                print('Evaluating dev:')
                dev_stats, dev_stats_nopunct, dev_stats_root = eval(
                    alg,
                    data_dev,
                    network,
                    pred_writer,
                    gold_writer,
                    punct_set,
                    word_alphabet,
                    pos_alphabet,
                    device,
                    beam=beam)

                pred_writer.close()
                gold_writer.close()

                dev_ucorr, dev_lcorr, dev_ucomlpete, dev_lcomplete, dev_total = dev_stats
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_ucomlpete_nopunc, dev_lcomplete_nopunc, dev_total_nopunc = dev_stats_nopunct
                dev_root_corr, dev_total_root, dev_total_inst = dev_stats_root

                # Model selection criterion: UAS + LAS correct counts
                # without punctuation.
                if best_ucorrect_nopunc + best_lcorrect_nopunc < dev_ucorr_nopunc + dev_lcorr_nopunc:
                    best_ucorrect_nopunc = dev_ucorr_nopunc
                    best_lcorrect_nopunc = dev_lcorr_nopunc
                    best_ucomlpete_nopunc = dev_ucomlpete_nopunc
                    best_lcomplete_nopunc = dev_lcomplete_nopunc

                    best_ucorrect = dev_ucorr
                    best_lcorrect = dev_lcorr
                    best_ucomlpete = dev_ucomlpete
                    best_lcomplete = dev_lcomplete

                    best_root_correct = dev_root_corr
                    best_total = dev_total
                    best_total_nopunc = dev_total_nopunc
                    best_total_root = dev_total_root
                    best_total_inst = dev_total_inst

                    best_epoch = epoch
                    patient = 0
                    torch.save(network.state_dict(), model_name)

                    pred_filename = os.path.join(result_path,
                                                 'pred_test%d' % epoch)
                    pred_writer.start(pred_filename)
                    gold_filename = os.path.join(result_path,
                                                 'gold_test%d' % epoch)
                    gold_writer.start(gold_filename)

                    print('Evaluating test:')
                    test_stats, test_stats_nopunct, test_stats_root = eval(
                        alg,
                        data_test,
                        network,
                        pred_writer,
                        gold_writer,
                        punct_set,
                        word_alphabet,
                        pos_alphabet,
                        device,
                        beam=beam)

                    test_ucorrect, test_lcorrect, test_ucomlpete, test_lcomplete, test_total = test_stats
                    test_ucorrect_nopunc, test_lcorrect_nopunc, test_ucomlpete_nopunc, test_lcomplete_nopunc, test_total_nopunc = test_stats_nopunct
                    test_root_correct, test_total_root, test_total_inst = test_stats_root

                    pred_writer.close()
                    gold_writer.close()
                else:
                    patient += 1

                print('-' * 125)
                print(
                    'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (best_ucorrect, best_lcorrect, best_total,
                       _percent(best_ucorrect, best_total),
                       _percent(best_lcorrect, best_total),
                       _percent(best_ucomlpete, best_total_inst),
                       _percent(best_lcomplete, best_total_inst), best_epoch))
                print(
                    'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (best_ucorrect_nopunc, best_lcorrect_nopunc,
                       best_total_nopunc,
                       _percent(best_ucorrect_nopunc, best_total_nopunc),
                       _percent(best_lcorrect_nopunc, best_total_nopunc),
                       _percent(best_ucomlpete_nopunc, best_total_inst),
                       _percent(best_lcomplete_nopunc, best_total_inst),
                       best_epoch))
                print(
                    'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                    % (best_root_correct, best_total_root,
                       _percent(best_root_correct, best_total_root),
                       best_epoch))
                print('-' * 125)
                print(
                    'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (test_ucorrect, test_lcorrect, test_total,
                       _percent(test_ucorrect, test_total),
                       _percent(test_lcorrect, test_total),
                       _percent(test_ucomlpete, test_total_inst),
                       _percent(test_lcomplete, test_total_inst), best_epoch))
                print(
                    'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (test_ucorrect_nopunc, test_lcorrect_nopunc,
                       test_total_nopunc,
                       _percent(test_ucorrect_nopunc, test_total_nopunc),
                       _percent(test_lcorrect_nopunc, test_total_nopunc),
                       _percent(test_ucomlpete_nopunc, test_total_inst),
                       _percent(test_lcomplete_nopunc, test_total_inst),
                       best_epoch))
                print(
                    'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                    % (test_root_correct, test_total_root,
                       _percent(test_root_correct, test_total_root),
                       best_epoch))
                print('=' * 125)

                # No dev improvement for `reset` epochs: roll back to the best
                # saved weights and reset the scheduler state.
                if patient >= reset:
                    logger.info('reset optimizer momentums')
                    network.load_state_dict(
                        torch.load(model_name, map_location=device))
                    scheduler.reset_state()
                    patient = 0
示例#6
0
def parse(args):
    """Parse ``args.test`` with a model saved under ``args.model_path``.

    Rebuilds the network from ``<model_path>/config.json``, loads the weights
    from ``<model_path>/model.pt`` and the alphabets from
    ``<model_path>/alphabets`` (which must already exist), then writes the
    predicted trees to ``<model_path>/tmp/pred.txt`` and the gold trees to
    ``<model_path>/tmp/gold.txt``.

    Args:
        args: parsed command-line namespace. ``args.cuda`` is overwritten with
            ``torch.cuda.is_available()``.
    """
    logger = get_logger("Parsing")
    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    test_path = args.test

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    punctuation = args.punctuation
    print(args)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    assert os.path.exists(alphabet_path)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path, None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    logger.info("loading network...")
    with open(os.path.join(model_path, 'config.json'), 'r') as config_file:
        hyps = json.load(config_file)
    model_type = hyps['model']
    assert model_type in ['DeepBiAffine', 'NeuroMST', 'StackPtr']
    word_dim = hyps['word_dim']
    char_dim = hyps['char_dim']
    use_pos = hyps['pos']
    pos_dim = hyps['pos_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    arc_space = hyps['arc_space']
    type_space = hyps['type_space']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']
    prior_order = None

    # StackPtr is a transition-based parser; the other two are graph-based.
    # No pretrained embedding tables are passed: the saved state dict
    # supplies all weights.
    alg = 'transition' if model_type == 'StackPtr' else 'graph'
    if model_type == 'DeepBiAffine':
        num_layers = hyps['num_layers']
        network = DeepBiAffine(word_dim,
                               num_words,
                               char_dim,
                               num_chars,
                               pos_dim,
                               num_pos,
                               mode,
                               hidden_size,
                               num_layers,
                               num_types,
                               arc_space,
                               type_space,
                               p_in=p_in,
                               p_out=p_out,
                               p_rnn=p_rnn,
                               pos=use_pos,
                               activation=activation)
    elif model_type == 'NeuroMST':
        num_layers = hyps['num_layers']
        network = NeuroMST(word_dim,
                           num_words,
                           char_dim,
                           num_chars,
                           pos_dim,
                           num_pos,
                           mode,
                           hidden_size,
                           num_layers,
                           num_types,
                           arc_space,
                           type_space,
                           p_in=p_in,
                           p_out=p_out,
                           p_rnn=p_rnn,
                           pos=use_pos,
                           activation=activation)
    elif model_type == 'StackPtr':
        encoder_layers = hyps['encoder_layers']
        decoder_layers = hyps['decoder_layers']
        num_layers = (encoder_layers, decoder_layers)
        prior_order = hyps['prior_order']
        grandPar = hyps['grandPar']
        sibling = hyps['sibling']
        network = StackPtrNet(word_dim,
                              num_words,
                              char_dim,
                              num_chars,
                              pos_dim,
                              num_pos,
                              mode,
                              hidden_size,
                              encoder_layers,
                              decoder_layers,
                              num_types,
                              arc_space,
                              type_space,
                              prior_order=prior_order,
                              activation=activation,
                              p_in=p_in,
                              p_out=p_out,
                              p_rnn=p_rnn,
                              pos=use_pos,
                              grandPar=grandPar,
                              sibling=sibling)
    else:
        raise RuntimeError('Unknown model type: %s' % model_type)

    network = network.to(device)
    network.load_state_dict(torch.load(model_name, map_location=device))
    model = "{}-{}".format(model_type, mode)
    logger.info("Network: %s, num_layer=%s, hidden=%d, act=%s" %
                (model, num_layers, hidden_size, activation))

    logger.info("Reading Data")
    if alg == 'graph':
        data_test = conllx_data.read_data(test_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          type_alphabet,
                                          symbolic_root=True)
    else:
        data_test = conllx_stacked_data.read_data(test_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  type_alphabet,
                                                  prior_order=prior_order)

    beam = args.beam
    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    pred_filename = os.path.join(result_path, 'pred.txt')
    pred_writer.start(pred_filename)
    gold_filename = os.path.join(result_path, 'gold.txt')
    gold_writer.start(gold_filename)

    # Close both writers even if evaluation raises, so the output files are
    # not left open.
    try:
        with torch.no_grad():
            print('Parsing...')
            start_time = time.time()
            eval(alg,
                 data_test,
                 network,
                 pred_writer,
                 gold_writer,
                 punct_set,
                 word_alphabet,
                 pos_alphabet,
                 device,
                 beam,
                 batch_size=args.test_batch_size)
            print('Time: %.2fs' % (time.time() - start_time))
    finally:
        pred_writer.close()
        gold_writer.close()
示例#7
0
    data_path = sys.argv[2]
    embedding_path = sys.argv[3]

    arg_path = os.path.join(path, 'network.pt.arg.json')
    model_path = os.path.join(path, 'network.pt')
    alphabet_path = os.path.join(path, 'alphabets/')
    embedding_path = embedding_path

    #'data/embedding/Korean_POS_lap.nnlm.c10.neg30.w3.h100.wc1e-4.i5.a0.025.bin.txt'


    def load_model_arguments_from_json():
        arguments = json.load(open(arg_path, 'r'))
        return arguments['args'], arguments['kwargs']

    args, kwargs = load_model_arguments_from_json()
    model = StackPtrNet(*args, **kwargs)

    # words from test set
    word_set = make_word_set(data_path)
    # words from train and dev set
    word_alphabet, _, _, _ = conllx_stacked_data.create_alphabets(
        alphabet_path, None, pos_embedding=4, data_paths=[None, None])

    # get 'TARGET' words
    unseen_emb_dict = get_unseen_embedding(word_set, word_alphabet,
                                           embedding_path)

    # save to a text file
    save_to_file(word_alphabet, model, unseen_emb_dict)
def main():
    """Train a stack-pointer dependency parser and report dev/test scores.

    All configuration comes from command-line flags (parsed from
    ``sys.argv``).

    Each epoch does the following:

    * trains on ``--train``;
    * decodes ``--dev`` and prints its attachment scores;
    * if the unlabeled, punctuation-excluded dev count is at least the best
      seen so far, saves the whole network to
      ``os.path.join(model_path, model_name)`` and decodes ``--test``;
    * otherwise increments a patience counter. Once that counter has already
      reached ``--schedule``, it reloads the saved network, multiplies the
      learning rate by ``--decay_rate`` and rebuilds the optimizer.

    Predicted and gold CoNLL-X files are written under ``tmp/``.

    NOTE(review): file names use ``uid``, which is not defined in this
    function. It is presumably a module-level constant; confirm. The
    ``tmp/`` directory is assumed to already exist.
    """
    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=200,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    # Required: without a value, generate_optimizer() raises ValueError, but
    # only after every dataset has already been loaded.
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adadelta'],
                             help='optimization algorithm',
                             required=True)
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.5,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    # Required: the epoch banner formats it with %d and the decay logic
    # compares it with an int; both fail on None.
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay',
                             required=True)
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument('--word_embedding',
                             choices=['glove', 'senna', 'sskip', 'polyglot'],
                             help='Embedding for words',
                             required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    rho = 0.9
    eps = 1e-6
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    beam = args.beam
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    logger.info("Creating Alphabets")

    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_stacked_data.read_stacked_data_to_variable(
        train_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    num_data = sum(data_train[1])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(
        dev_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)
    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the word embedding matrix from the pretrained dict.

        Words missing from ``word_dict`` (also tried lower-cased) get
        uniform random vectors in [-sqrt(3/dim), sqrt(3/dim)].
        """
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the character embedding matrix.

        Returns None when characters use random embeddings (no dict loaded).
        """
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          num_filters,
                          window,
                          mode,
                          hidden_size,
                          num_layers,
                          num_types,
                          arc_space,
                          type_space,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          prior_order=prior_order)

    if use_gpu:
        network.cuda()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        """Create the optimizer named by ``opt``.

        Raises:
            ValueError: if ``opt`` is not 'adam', 'sgd' or 'adadelta'.
        """
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adadelta':
            return Adadelta(params,
                            lr=lr,
                            rho=rho,
                            weight_decay=gamma,
                            eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adadelta':
        opt_info += 'rho=%.2f, eps=%.1e' % (rho, eps)

    logger.info("Embedding dim: word=%d, char=%d, pos=%d" %
                (word_dim, char_dim, pos_dim))
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, arc_space=%d, type_space=%d"
        % (mode, num_layers, hidden_size, num_filters, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, dropout(in, out, rnn): (%.2f, %.2f, %s), unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, p_in, p_out, p_rnn, unk_replace))
    logger.info('prior order: %s, beam: %d' % (prior_order, beam))
    logger.info(opt_info)

    # Integer division: on Python 3 `/` yields a float, and range() below
    # would raise TypeError.
    num_batches = num_data // batch_size + 1

    def erase_line(num_chars):
        """Blank out the last ``num_chars`` characters of the progress line."""
        sys.stdout.write("\b" * num_chars)
        sys.stdout.write(" " * num_chars)
        sys.stdout.write("\b" * num_chars)

    def average_train_losses(sums):
        """Convert summed, count-weighted training losses into averages.

        ``sums`` holds the running totals accumulated in the training loop.
        Returns ``(err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type,
        err_type_leaf, err_type_non_leaf, err_cov)``.
        """
        err_arc_leaf = sums['arc_leaf'] / sums['leaf']
        err_arc_non_leaf = sums['arc_non_leaf'] / sums['non_leaf']
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = sums['type_leaf'] / sums['leaf']
        err_type_non_leaf = sums['type_non_leaf'] / sums['non_leaf']
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = sums['cov'] / (sums['leaf'] + sums['non_leaf'])

        err = err_arc + err_type + cov * err_cov
        return (err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type,
                err_type_leaf, err_type_non_leaf, err_cov)

    def evaluate(data, pred_filename, gold_filename):
        """Decode ``data`` with the current ``network`` and accumulate scores.

        Predicted and gold trees are written to the two given files.

        Returns ``(stats, stats_nopunc, stats_root, num_inst)`` summed over
        all batches. ``stats`` and ``stats_nopunc`` are
        ``[ucorr, lcorr, total, ucm, lcm]`` and ``stats_root`` is
        ``[corr_root, total_root]``, matching the per-batch unpacking of
        ``parser.eval`` below.
        """
        network.eval()
        pred_writer.start(pred_filename)
        gold_writer.start(gold_filename)

        stats_sum = [0.0, 0.0, 0, 0.0, 0.0]
        stats_nopunc_sum = [0.0, 0.0, 0, 0.0, 0.0]
        stats_root_sum = [0.0, 0.0]
        num_inst_sum = 0.0
        # Decoding only; no autograd graph is needed.
        with torch.no_grad():
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data, batch_size):
                input_encoder, _ = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                for i, value in enumerate(stats):
                    stats_sum[i] += value
                for i, value in enumerate(stats_nopunc):
                    stats_nopunc_sum[i] += value
                for i, value in enumerate(stats_root):
                    stats_root_sum[i] += value
                num_inst_sum += num_inst

        pred_writer.close()
        gold_writer.close()
        return stats_sum, stats_nopunc_sum, stats_root_sum, num_inst_sum

    def report(label, scores, epoch=None):
        """Print the punct / no-punct / root score lines for ``scores``.

        ``label`` is prepended to each line. When ``epoch`` is given,
        ' (epoch: N)' is appended to each line.
        """
        stats, stats_nopunc, stats_root, num_inst = scores
        ucorr, lcorr, total, ucm, lcm = stats
        ucorr_np, lcorr_np, total_np, ucm_np, lcm_np = stats_nopunc
        corr_root, total_root = stats_root
        tail = '' if epoch is None else ' (epoch: %d)' % epoch
        print(label +
              'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
              % (ucorr, lcorr, total, ucorr * 100 / total,
                 lcorr * 100 / total, ucm * 100 / num_inst,
                 lcm * 100 / num_inst) + tail)
        print(label +
              'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
              % (ucorr_np, lcorr_np, total_np, ucorr_np * 100 / total_np,
                 lcorr_np * 100 / total_np, ucm_np * 100 / num_inst,
                 lcm_np * 100 / num_inst) + tail)
        print(label + 'Root: corr: %d, total: %d, acc: %.2f%%' %
              (corr_root, total_root, corr_root * 100 / total_root) + tail)

    best_epoch = 0
    best_dev_scores = None
    best_test_scores = None

    patient = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, decay rate=%.2f (schedule=%d, patient=%d)): '
            % (epoch, mode, opt, lr, decay_rate, schedule, patient))
        sums = dict.fromkeys(('arc_leaf', 'arc_non_leaf', 'type_leaf',
                              'type_non_leaf', 'cov', 'leaf', 'non_leaf'), 0.)
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train, batch_size, unk_replace=unk_replace)
            word, char, pos, _, _, masks_e, lengths_e = input_encoder
            stacked_heads, children, stacked_types, masks_d, lengths_d = input_decoder
            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, stacked_heads, children, stacked_types,
                                                            mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            # .item() returns the Python number. Indexing a 0-dim tensor via
            # .data[0] raises IndexError on PyTorch >= 0.4.
            num_leaf = num_leaf.item()
            num_non_leaf = num_non_leaf.item()

            # Weight each mean loss by its count so epoch averages are
            # per-item rather than per-batch.
            sums['arc_leaf'] += loss_arc_leaf.item() * num_leaf
            sums['arc_non_leaf'] += loss_arc_non_leaf.item() * num_non_leaf
            sums['type_leaf'] += loss_type_leaf.item() * num_leaf
            sums['type_non_leaf'] += loss_type_non_leaf.item() * num_non_leaf
            sums['cov'] += loss_cov.item() * (num_leaf + num_non_leaf)
            sums['leaf'] += num_leaf
            sums['non_leaf'] += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                erase_line(num_back)
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    (batch, num_batches) + average_train_losses(sums) +
                    (time_left,))
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        erase_line(num_back)
        print(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % ((num_batches,) + average_train_losses(sums) +
               (time.time() - start_time,)))

        # evaluate performance on dev data
        dev_scores = evaluate(data_dev,
                              'tmp/%spred_dev%d' % (str(uid), epoch),
                              'tmp/%sgold_dev%d' % (str(uid), epoch))
        report('', dev_scores)

        # Model selection uses the unlabeled, punctuation-excluded correct
        # count. Ties count as an improvement.
        if best_dev_scores is None or best_dev_scores[1][0] <= dev_scores[1][0]:
            best_dev_scores = dev_scores
            best_epoch = epoch
            patient = 0
            torch.save(network, model_name)
            best_test_scores = evaluate(
                data_test, 'tmp/%spred_test%d' % (str(uid), epoch),
                'tmp/%sgold_test%d' % (str(uid), epoch))
        elif patient < schedule:
            patient += 1
        else:
            # Patience exhausted: restart from the best saved model with a
            # decayed learning rate.
            network = torch.load(model_name)
            lr = lr * decay_rate
            optim = generate_optimizer(opt, lr, network.parameters())
            patient = 0

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        report('best dev  ', best_dev_scores, best_epoch)
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        report('best test ', best_test_scores, best_epoch)
        print(
            '============================================================================================================================'
        )
示例#9
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size',
                             type=int,
                             default=256,
                             help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of decoder RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--multi_head_attn',
                             action='store_true',
                             help='use multi-head-attention.')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    args_parser.add_argument(
        '--pool_type',
        default='mean',
        choices=['max', 'mean', 'weight'],
        help='pool type to form fixed length vector from word embeddings')
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')
    args_parser.add_argument('--no_word',
                             action='store_true',
                             help='do not use word embedding.')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--no_CoRNN',
                             action='store_true',
                             help='do not use context RNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth',
                             type=float,
                             default=1.0,
                             help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect',
                             action='store_true',
                             help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar',
                             action='store_true',
                             help='use grand parent.')
    args_parser.add_argument('--sibling',
                             action='store_true',
                             help='use sibling.')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=2000,
                             help='Number of training epochs')

    # lrate schedule with warmup in the first iter.
    args_parser.add_argument('--use_warmup_schedule',
                             action='store_true',
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    #
    # about decoder's bi-attention scoring with features (default is not using any)
    args_parser.add_argument(
        '--dec_max_dist',
        type=int,
        default=0,
        help=
        "The clamp range of decoder's distance feature, 0 means turning off.")
    args_parser.add_argument('--dec_dim_feature',
                             type=int,
                             default=10,
                             help="Dim for feature embed.")
    args_parser.add_argument(
        '--dec_use_neg_dist',
        action='store_true',
        help="Use negative distance for dec's distance feature.")
    args_parser.add_argument(
        '--dec_use_encoder_pos',
        action='store_true',
        help="Use pos feature combined with distance feature for child nodes.")
    args_parser.add_argument(
        '--dec_use_decoder_pos',
        action='store_true',
        help="Use pos feature combined with distance feature for head nodes.")
    args_parser.add_argument('--dec_drop_f_embed',
                             type=float,
                             default=0.2,
                             help="Dropout for dec feature embeddings.")
    #
    # about relation-aware self attention for the transformer encoder (default is not using any)
    # args_parser.add_argument('--rel_aware', action='store_true',
    #                          help="Enable relation-aware self-attention (multi_head_attn flag needs to be set).")
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    #
    # other options about how to combine multiple input features (have to make some dims fit if not concat)
    args_parser.add_argument('--input_concat_embeds',
                             action='store_true',
                             help="Concat input embeddings, otherwise add.")
    args_parser.add_argument('--input_concat_position',
                             action='store_true',
                             help="Concat position embeddings, otherwise add.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=300,
                             help='Dimension of Position embeddings.')
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    args = args_parser.parse_args()

    # =====
    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    # =====

    # if output directory doesn't exist, create it
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    logger = get_logger("PtrParser", args.model_path + 'log.txt')

    logger.info('\ncommand-line params : {0}\n'.format(sys.argv[1:]))
    logger.info('{0}\n'.format(args))

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    use_word_emb = not args.no_word
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_con_rnn = not args.no_CoRNN

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(
        word_embedding, word_path) if use_word_emb else (None, 0)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)

    # todo(warn): should build vocabs previously
    assert os.path.isdir(alphabet_path), "should have build vocabs previously"
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)
    # word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = create_alphabets(alphabet_path,
    #     train_path, data_paths=[dev_path, test_path], max_vocabulary_size=50000, embedd_dict=word_dict)
    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # ===== the reading
    def _read_one(path, is_train):
        """Read one CoNLL-X split into stacked encoder/decoder variables.

        The language id is guessed from ``path`` and forwarded to the reader.
        For the training split, ``volatile`` is False and sentences longer
        than ``args.train_len_thresh`` are discarded.  For dev/test,
        ``volatile`` is True (presumably no autograd tracking under the old
        PyTorch Variable API) and the length cap is effectively unlimited.

        Args:
            path: file to read.
            is_train: whether ``path`` is the training split.

        Returns:
            Whatever ``conllx_stacked_data.read_stacked_data_to_variable``
            returns for this file.
        """
        guessed_lang = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, guessed_lang))
        # Only training data is length-filtered; evaluation keeps everything.
        length_cap = args.train_len_thresh if is_train else 100000
        return conllx_stacked_data.read_stacked_data_to_variable(
            path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=not is_train,
            prior_order=prior_order,
            lang_id=guessed_lang,
            len_thresh=length_cap)

    data_train = _read_one(train_path, True)
    num_data = sum(data_train[1])

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    # =====

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix for ``word_alphabet``.

        Each alphabet entry takes its vector from ``word_dict``, trying the
        exact form first and then the lower-cased form.  The UNK row, and
        every word found in neither form, gets a fallback vector:
        - all zeros when ``freeze`` is set;
        - otherwise uniform noise in ``[-sqrt(3/word_dim), sqrt(3/word_dim))``.

        Returns:
            A float32 torch tensor of shape (word_alphabet.size(), word_dim).
        """
        bound = np.sqrt(3.0 / word_dim)

        def _fallback_vector():
            # With --freeze the table is not fine-tuned, so unknown words are
            # zero vectors instead of random noise.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound,
                                     [1, word_dim]).astype(np.float32)

        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = _fallback_vector()
        num_oov = 0
        for token, row in word_alphabet.items():
            if token in word_dict:
                vector = word_dict[token]
            elif token.lower() in word_dict:
                vector = word_dict[token.lower()]
            else:
                vector = _fallback_vector()
                num_oov += 1
            table[row, :] = vector
        logger.info('word OOV: %d' % num_oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the initial character-embedding matrix, or return None.

        Returns None when no pretrained character dictionary was loaded
        (``char_embedding == 'random'``).  Presumably the network then
        initialises its own table.  Otherwise each character in
        ``char_alphabet`` takes its vector from ``char_dict``.  Characters
        missing from the dictionary, and the UNK row, get uniform noise in
        ``[-sqrt(3/char_dim), sqrt(3/char_dim))``.

        Returns:
            A float32 torch tensor of shape (num_chars, char_dim), or None.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -bound, bound, [1, char_dim]).astype(np.float32)
        num_oov = 0
        for symbol, row in char_alphabet.items():
            if symbol in char_dict:
                table[row, :] = char_dict[symbol]
                continue
            # Unknown character: random init (char embeddings are never frozen).
            table[row, :] = np.random.uniform(
                -bound, bound, [1, char_dim]).astype(np.float32)
            num_oov += 1
        logger.info('character OOV: %d' % num_oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table() if use_word_emb else None
    char_table = construct_char_embedding_table()

    window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          num_filters,
                          window,
                          mode,
                          input_size_decoder,
                          hidden_size,
                          encoder_layers,
                          decoder_layers,
                          num_types,
                          arc_space,
                          type_space,
                          args.pool_type,
                          args.multi_head_attn,
                          args.num_head,
                          max_sent_length,
                          args.trans_hid_size,
                          args.d_k,
                          args.d_v,
                          train_position=args.train_position,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          use_word_emb=use_word_emb,
                          pos=use_pos,
                          char=use_char,
                          prior_order=prior_order,
                          use_con_rnn=use_con_rnn,
                          skipConnect=skipConnect,
                          grandPar=grandPar,
                          sibling=sibling,
                          use_gpu=use_gpu,
                          dec_max_dist=args.dec_max_dist,
                          dec_use_neg_dist=args.dec_use_neg_dist,
                          dec_use_encoder_pos=args.dec_use_encoder_pos,
                          dec_use_decoder_pos=args.dec_use_decoder_pos,
                          dec_dim_feature=args.dec_dim_feature,
                          dec_drop_f_embed=args.dec_drop_f_embed,
                          enc_clip_dist=args.enc_clip_dist,
                          enc_use_neg_dist=args.enc_use_neg_dist,
                          input_concat_embeds=args.input_concat_embeds,
                          input_concat_position=args.input_concat_position,
                          position_dim=args.position_dim)

    def save_args():
        """Write the StackPtrNet constructor arguments to ``<model_name>.arg.json``.

        The JSON object has two keys:
        - ``args``: the positional arguments, in the same order they are
          passed to ``StackPtrNet`` when ``network`` is built.
        - ``kwargs``: the keyword arguments.

        Presumably this lets the model be rebuilt at load time.  The
        pretrained embedding tables (``embedd_word`` / ``embedd_char``) and
        ``use_gpu`` are not stored.  ``p_rnn`` is a tuple, so it is written
        as a JSON list.
        """
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, input_size_decoder, hidden_size,
            encoder_layers, decoder_layers, num_types, arc_space, type_space,
            args.pool_type, args.multi_head_attn, args.num_head,
            max_sent_length, args.trans_hid_size, args.d_k, args.d_v
        ]
        kwargs = {
            'train_position': args.train_position,
            'use_word_emb': use_word_emb,
            'use_con_rnn': use_con_rnn,
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'prior_order': prior_order,
            'skipConnect': skipConnect,
            'grandPar': grandPar,
            'sibling': sibling,
            'dec_max_dist': args.dec_max_dist,
            'dec_use_neg_dist': args.dec_use_neg_dist,
            'dec_use_encoder_pos': args.dec_use_encoder_pos,
            'dec_use_decoder_pos': args.dec_use_decoder_pos,
            'dec_dim_feature': args.dec_dim_feature,
            'dec_drop_f_embed': args.dec_drop_f_embed,
            'enc_clip_dist': args.enc_clip_dist,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'input_concat_embeds': args.input_concat_embeds,
            'input_concat_position': args.input_concat_position,
            'position_dim': args.position_dim
        }
        # The original passed a bare open() handle to json.dump and never
        # closed it, leaving flush/close to garbage collection.  The context
        # manager closes the file deterministically, even if dump raises.
        with open(arg_path, 'w') as arg_file:
            json.dump({'args': arguments, 'kwargs': kwargs}, arg_file, indent=4)

    if use_word_emb and freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        """Create the optimizer named by ``opt`` over the trainable ``params``.

        Parameters with ``requires_grad`` False (e.g. a frozen word
        embedding) are skipped.  Hyper-parameters come from the enclosing
        scope:
        - Adam / Adamax use ``betas``, ``eps`` and ``weight_decay=gamma``.
        - SGD uses Nesterov momentum with ``momentum`` and
          ``weight_decay=gamma``.

        Args:
            opt: one of 'adam', 'sgd', 'adamax'.
            lr: learning rate.
            params: iterable of model parameters.

        Raises:
            ValueError: if ``opt`` is not a known name (e.g. None when
                ``--opt`` was not given).
        """
        trainable = (p for p in params if p.requires_grad)
        if opt == 'sgd':
            return SGD(trainable,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        if opt in ('adam', 'adamax'):
            # Adam and Adamax take the identical set of arguments here.
            optimizer_cls = Adam if opt == 'adam' else Adamax
            return optimizer_cls(trainable,
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=gamma,
                                 eps=eps)
        raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d"
        % (mode, encoder_layers, decoder_layers, input_size_decoder,
           hidden_size, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' %
                (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d' % (skipConnect, beam))
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    # lrate decay
    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule
    warmup_factor = (lr + 0.) / num_batches

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))
    #
    for epoch in range(1, num_epochs + 1):
        logger.info(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f '
            '(schedule=%d, patient=%d, decay=%d (%d, %d))): ' %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay,
             max_decay, double_schedule_decay))
        train_err_arc_leaf = 0.
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate

            # train
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train, batch_size, unk_replace=unk_replace)
            word, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling,
                                                            stacked_types, label_smooth,
                                                            skip_connect=skip_connect, mask_e=masks_e,
                                                            length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.data[0]
            num_non_leaf = num_non_leaf.data[0]

            train_err_arc_leaf += loss_arc_leaf.data[0] * num_leaf
            train_err_arc_non_leaf += loss_arc_non_leaf.data[0] * num_non_leaf

            train_err_type_leaf += loss_type_leaf.data[0] * num_leaf
            train_err_type_non_leaf += loss_type_non_leaf.data[0] * num_non_leaf

            train_err_cov += loss_cov.data[0] * (num_leaf + num_non_leaf)

            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf +
                                           train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf,
                    err_arc_non_leaf, err_type, err_type_leaf,
                    err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        logger.info(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % (num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf,
               err_type, err_type_leaf, err_type_non_leaf, err_cov,
               time.time() - start_time))

        ################################################################################################
        if epoch % args.check_dev != 0:
            continue

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size):
            input_encoder, _ = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data_test, batch_size):
                input_encoder, _ = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        logger.info(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        logger.info(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        logger.info(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        logger.info(
            'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
            (dev_root_correct, dev_total_root,
             dev_root_correct * 100 / dev_total_root, best_epoch))
        logger.info(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        logger.info(
            'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
               test_total, test_lcorrect * 100 / test_total,
               test_ucomlpete_match * 100 / test_total_inst,
               test_lcomplete_match * 100 / test_total_inst, best_epoch))
        logger.info(
            'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            %
            (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
             test_ucorrect_nopunc * 100 / test_total_nopunc,
             test_lcorrect_nopunc * 100 / test_total_nopunc,
             test_ucomlpete_match_nopunc * 100 / test_total_inst,
             test_lcomplete_match_nopunc * 100 / test_total_inst, best_epoch))
        logger.info(
            'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
            (test_root_correct, test_total_root,
             test_root_correct * 100 / test_total_root, best_epoch))
        logger.info(
            '============================================================================================================================'
        )

        if decay == max_decay:
            break
示例#10
0
def main():
    args_parser = argparse.ArgumentParser(description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'], help='architecture of rnn', required=True)
    args_parser.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size', type=int, default=256, help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--type_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers', type=int, default=1, help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers', type=int, default=1, help='Number of layers of decoder RNN')
    args_parser.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
    # NOTE: action='store_true' is just to set ON
    args_parser.add_argument('--pos', action='store_true', help='use part-of-speech embedding.')
    args_parser.add_argument('--char', action='store_true', help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
    # NOTE: arg MUST be one of choices(when specified)
    args_parser.add_argument('--opt', choices=['adam', 'sgd', 'adamax'], help='optimization algorithm')
    args_parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.75, help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay', type=int, default=9, help='Number of decays before stop')
    args_parser.add_argument('--double_schedule_decay', type=int, default=5, help='Number of decays to double schedule')
    args_parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
    args_parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    args_parser.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage', type=float, default=0.0, help='weight for coverage loss')
    args_parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    args_parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth', type=float, default=1.0, help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect', action='store_true', help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar', action='store_true', help='use grand parent.')
    args_parser.add_argument('--sibling', action='store_true', help='use sibling.')
    args_parser.add_argument('--prior_order', choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'], help='prior order of children.', required=True)
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--beam', type=int, default=1, help='Beam size for decoding')
    args_parser.add_argument('--word_embedding', choices=['glove', 'senna', 'sskip', 'polyglot', 'NNLM'], help='Embedding for words', required=True)
    args_parser.add_argument('--word_path', help='path for word embedding dict')
    args_parser.add_argument('--freeze', action='store_true', help='frozen the word embedding (disable fine-tuning).')    
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters', required=True)
    args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path', help='path for saving model file.', required=True)
    args_parser.add_argument('--model_name', help='name for saving model file.', required=True)
    # TODO: to include in logging process
    args_parser.add_argument('--pos_embedding', choices=[1,2,4], type=int, help='Embedding method for korean POS tag', default=2)
    args_parser.add_argument('--pos_path', help='path for pos embedding dict')
    args_parser.add_argument('--elmo', action='store_true', help='use elmo embedding.')
    args_parser.add_argument('--elmo_path', help='path for elmo embedding model.')
    args_parser.add_argument('--elmo_dim', type=int, help='dimension for elmo embedding model')
    #args_parser.add_argument('--fine_tune_path', help='fine tune starting from this state_dict')
    args_parser.add_argument('--model_version', help='previous model version to load')
    #bert2020_boychaboy
    args_parser.add_argument('--bert', action='store_true', help='use elmo embedding.')  # true if use bert(hoon)
    args_parser.add_argument('--etri_train', help='path for etri data of bert')  # etri train path(hoon)
    args_parser.add_argument('--etri_dev', help='path for etri data of bert')  # etri dev path(hoon)
    args_parser.add_argument('--bert_path', help='path for bert embedding model.')
    args_parser.add_argument('--bert_feature_dim', type=int, help='dimension for bert feature embedding')

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path + uid + '/'   # for numerous experiments
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    # QUESTION: pretrained vector for char?
    char_path = args.char_path

    use_pos = args.pos    
    pos_embedding = args.pos_embedding
    pos_path = args.pos_path
    pos_dict = None
    pos_dim = args.pos_dim    # NOTE pretrain 있을 경우 pos_dim은 그거 따라감
    if pos_path is not None:
        pos_dict, pos_dim = utils.load_embedding_dict(word_embedding, pos_path)  # NOTE 임시적으로 word_embedding(NNLM)이랑 같은 형식
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)    

    use_elmo = args.elmo
    elmo_path = args.elmo_path
    elmo_dim = args.elmo_dim
    #fine_tune_path = args.fine_tune_path

    #bert 2020(boychaboy)
    use_bert = args.bert
    bert_path = args.bert_path
    bert_feature_dim = args.bert_feature_dim

    if use_bert:
        etri_train_path = args.etri_train
        etri_dev_path = args.etri_dev
    else:
        etri_train_path = None
        etri_dev_path = None

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    # min_occurence=1
    data_paths = [dev_path, test_path] if test_path else [dev_path]
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(alphabet_path, train_path, data_paths=data_paths,
                                                                                                      max_vocabulary_size=50000, pos_embedding=pos_embedding, embedd_dict=word_dict)

    num_words = word_alphabet.size() # 30268
    num_chars = char_alphabet.size() # 3545
    num_pos = pos_alphabet.size() # 46
    num_types = type_alphabet.size()  # 39

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # data is a list of tuple containing tensors, etc ...
    data_train = conllx_stacked_data.read_stacked_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding, use_gpu=1, prior_order=prior_order, elmo = use_elmo, bert=use_bert, etri_path=etri_train_path)
    num_data = sum(data_train[2])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding, use_gpu=use_gpu, volatile=True, prior_order=prior_order, elmo = use_elmo, bert=use_bert, etri_path=etri_dev_path)
    if test_path:
        data_test = conllx_stacked_data.read_stacked_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding, use_gpu=use_gpu, volatile=True, prior_order=prior_order, elmo = use_elmo)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix from the pretrained ``word_dict``.

        Row ``i`` holds the vector of the word whose alphabet index is ``i``. A word
        is looked up verbatim first, then lower-cased. The UNK row and every word
        found in neither form get a fallback vector: all zeros when ``freeze`` is
        set, otherwise a draw from U(-sqrt(3/word_dim), sqrt(3/word_dim)).
        The number of words without a pretrained vector is printed.

        Returns:
            float32 torch tensor of shape (word_alphabet.size(), word_dim).
        """
        bound = np.sqrt(3.0 / word_dim)

        def fallback_vector():
            # With --freeze the embedding is not fine-tuned, so unknown words stay at zero.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound, [1, word_dim]).astype(np.float32)

        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = fallback_vector()
        missing = 0
        for token, row in word_alphabet.items():
            if token in word_dict:
                vector = word_dict[token]
            elif token.lower() in word_dict:
                vector = word_dict[token.lower()]
            else:
                vector = fallback_vector()
                missing += 1
            table[row, :] = vector
        print('word OOV: %d' % missing)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the initial character-embedding matrix from the pretrained ``char_dict``.

        Returns None when no pretrained character embeddings were loaded (random
        character embedding mode). Otherwise row ``i`` holds the vector of the
        character whose alphabet index is ``i``. The UNK row and every character
        absent from ``char_dict`` get a draw from U(-sqrt(3/char_dim), sqrt(3/char_dim)).
        The number of characters without a pretrained vector is printed.

        Returns:
            float32 torch tensor of shape (num_chars, char_dim), or None.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)

        def random_vector():
            return np.random.uniform(-bound, bound, [1, char_dim]).astype(np.float32)

        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = random_vector()
        missing = 0
        for symbol, row in char_alphabet.items():
            if symbol not in char_dict:
                missing += 1
                table[row, :] = random_vector()
            else:
                table[row, :] = char_dict[symbol]
        print('character OOV: %d' % missing)
        return torch.from_numpy(table)

    def construct_pos_embedding_table():
        """Build the initial POS-embedding matrix from the pretrained ``pos_dict``.

        Returns None when no pretrained POS embeddings were loaded. Otherwise row
        ``i`` holds the vector of the tag whose alphabet index is ``i``. Tags absent
        from ``pos_dict`` get a draw from U(-sqrt(3/pos_dim), sqrt(3/pos_dim)).

        NOTE(review): ``np.empty`` leaves any row not covered by
        ``pos_alphabet.items()`` uninitialised. This assumes every index in
        [0, num_pos) appears there; confirm against the Alphabet implementation.

        Returns:
            float32 torch tensor of shape (num_pos, pos_dim), or None.
        """
        if pos_dict is None:
            return None

        # Both the init scale and the fallback vector width must come from pos_dim:
        # every row of `table` is pos_dim wide. A char_dim-wide vector fails to
        # broadcast whenever char_dim != pos_dim.
        scale = np.sqrt(3.0 / pos_dim)
        table = np.empty([num_pos, pos_dim], dtype=np.float32)
        for tag, index in pos_alphabet.items():
            if tag in pos_dict:
                embedding = pos_dict[tag]
            else:
                embedding = np.random.uniform(-scale, scale, [1, pos_dim]).astype(np.float32)
            table[index, :] = embedding
        return torch.from_numpy(table)
    
    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()
    pos_table = construct_pos_embedding_table()

    window = 3
    network = StackPtrNet(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                          mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
                          num_types, arc_space, type_space, pos_embedding,
                          embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table, p_in=p_in, p_out=p_out,
                          p_rnn=p_rnn, biaffine=True, pos=use_pos, char=use_char, elmo=use_elmo, prior_order=prior_order,
                          skipConnect=skipConnect, grandPar=grandPar, sibling=sibling, elmo_path=elmo_path, elmo_dim=elmo_dim,
                          bert=use_bert, bert_path=bert_path, bert_feature_dim=bert_feature_dim)


    # if fine_tune_path is not None:
    #     pretrained_dict = torch.load(fine_tune_path)
    #     model_dict = network.state_dict()
    #     # select
    #     #model_dict['pos_embedd.weight'] = pretrained_dict['pos_embedd.weight']
    #     model_dict['word_embedd.weight'] = pretrained_dict['word_embedd.weight']
    #     #model_dict['char_embedd.weight'] = pretrained_dict['char_embedd.weight']
    #     network.load_state_dict(model_dict)

    model_ver = args.model_version
    if model_ver is not None:
        savePath = args.model_path + model_ver + 'network.pt'
        network.load_state_dict(torch.load(savePath))
        logger.info('Load model: %s' % (model_ver))

    def save_args():
        """Persist the model's constructor configuration next to the saved weights.

        Writes two files:
          * ``<model_name>.arg.json``: JSON with ``args`` (the positional arguments
            passed to ``StackPtrNet``) and ``kwargs`` (a subset of its keyword
            arguments; the embedding tables are not included).
          * ``<model_name>.arg.json.raw_args``: ``str(args)`` of the parsed CLI
            namespace.

        NOTE(review): ``elmo_path``, ``elmo_dim``, ``bert``, ``bert_path`` and
        ``bert_feature_dim`` are passed to ``StackPtrNet`` but are not recorded in
        ``kwargs``. They survive only in the ``.raw_args`` dump; confirm whether
        the loader needs them.
        """
        arg_path = model_name + '.arg.json'
        arguments = [word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                     mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
                     num_types, arc_space, type_space, pos_embedding]
        kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char, 'elmo': use_elmo, 'prior_order': prior_order,
                  'skipConnect': skipConnect, 'grandPar': grandPar, 'sibling': sibling}
        # Use a context manager so the JSON is flushed and the handle closed
        # deterministically, rather than whenever the file object is collected.
        with open(arg_path, 'w', encoding="utf-8") as f:
            json.dump({'args': arguments, 'kwargs': kwargs}, f, indent=4)

        with open(arg_path + '.raw_args', 'w', encoding="utf-8") as f:
            f.write(str(args))

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding)

    def generate_optimizer(opt, lr, params):
        """Instantiate the optimizer named by ``opt`` over the trainable parameters.

        Only parameters with ``requires_grad`` set are handed to the optimizer.
        Adam and Adamax share ``betas``/``eps``; SGD uses Nesterov momentum. All
        three apply ``gamma`` as weight decay.

        Raises:
            ValueError: if ``opt`` is not 'adam', 'sgd' or 'adamax'.
        """
        trainable = [weight for weight in params if weight.requires_grad]
        if opt in ('adam', 'adamax'):
            optimizer_cls = Adam if opt == 'adam' else Adamax
            return optimizer_cls(trainable, lr=lr, betas=betas, weight_decay=gamma, eps=eps)
        if opt == 'sgd':
            return SGD(trainable, lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
        raise ValueError('Unknown optimization algorithm: %s' % opt)

    def generate_differentlr_bert_optimizer(lr, bert_lr, model):
        """Build a BertAdam optimizer over every parameter outside ``bert_model``.

        Parameters whose qualified name contains ``'bert_model'`` are left out, so
        they are not handed to this optimizer. Each parameter group is printed for
        inspection.

        Args:
            lr: learning rate for the single parameter group.
            bert_lr: currently unused; kept so existing call sites keep working
                (no separate BERT learning-rate group is configured).
            model: module whose ``named_parameters()`` are filtered.

        Returns:
            A ``BertAdam`` optimizer with ``e=1e-8``.
        """
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if 'bert_model' not in n]},
        ]
        for group in optimizer_grouped_parameters:
            print(group)
        return BertAdam(optimizer_grouped_parameters, lr=lr, e=1e-8)

    def generate_old_bert_optimizer(t_total, bert_lr, model):
        """Build a BertAdam optimizer that applies weight decay selectively.

        Parameters whose name contains ``'bias'`` or ``'LayerNorm.weight'`` get
        ``weight_decay=0.0``; all others use ``gamma``.

        Args:
            t_total: unused (the warm-up scheduler that would consume it is
                commented out).
            bert_lr: learning rate shared by both parameter groups.
            model: module whose ``named_parameters()`` are partitioned.

        Returns:
            A ``BertAdam`` optimizer with ``e=1e-8``.
        """
        exempt_markers = ('bias', 'LayerNorm.weight')
        decayed, undecayed = [], []
        for name, param in model.named_parameters():
            if any(marker in name for marker in exempt_markers):
                undecayed.append(param)
            else:
                decayed.append(param)
        groups = [
            {'params': decayed, 'weight_decay': gamma},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
        return BertAdam(groups, lr=bert_lr, e=1e-8)

    lr = learning_rate
    # bert_lr = learning_rate
    #  optim = generate_optimizer(opt, lr, network.parameters())
    if use_bert:
        # optim =generate_differentlr_bert_optimizer(lr, lr, network)
        optim = generate_old_bert_optimizer(len(data_train) * num_epochs, lr, network)

    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info("Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" % (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info("RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, encoder_layers, decoder_layers, input_size_decoder, hidden_size, arc_space, type_space))
    logger.info("train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)" % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' % (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d' % (skipConnect, beam))
    logger.info(opt_info)

    num_batches = int(num_data / batch_size + 1)  # kwon
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d (%d, %d))): ' % (
            epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay, max_decay, double_schedule_decay))
        train_err_arc_leaf = 0.    # QUESTION: leaf and non-leaf?
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0

        network.train()
        for batch in range(1, num_batches + 1):
            # load data #bert2020 [boychaboy]
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(data_train, batch_size, pos_embedding, unk_replace=unk_replace, elmo = use_elmo, bert=use_bert)

            if use_elmo:
                word, char, pos, heads, types, masks_e, lengths_e, word_elmo, word_bert = input_encoder
            else:
                word, char, pos, heads, types, masks_e, lengths_e, word_bert = input_encoder

            stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()

            if use_elmo:
                loss_arc_leaf, loss_arc_non_leaf, \
                loss_type_leaf, loss_type_non_leaf, \
                loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling, stacked_types, label_smooth, skip_connect=skip_connect, mask_e=masks_e, \
                                                                length_e=lengths_e, mask_d=masks_d, length_d=lengths_d, input_word_elmo = word_elmo, input_word_bert = word_bert)
            else:
                loss_arc_leaf, loss_arc_non_leaf, \
                loss_type_leaf, loss_type_non_leaf, \
                loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling, stacked_types, label_smooth, \
                                                            skip_connect=skip_connect, mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d, input_word_bert=word_bert)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov    # cov is set to 0 by default
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.item()
            num_non_leaf = num_non_leaf.item()
            train_err_arc_leaf += loss_arc_leaf.item() * num_leaf
            train_err_arc_non_leaf += loss_arc_non_leaf.item() * num_non_leaf

            train_err_type_leaf += loss_type_leaf.item() * num_leaf
            train_err_type_non_leaf += loss_type_non_leaf.item() * num_non_leaf

            train_err_cov += loss_cov.item() * (num_leaf + num_non_leaf)
            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type, err_type_leaf, err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        print('train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs' % (
            num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type, err_type_leaf, err_type_non_leaf, err_cov, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = model_path + 'tmp/pred_dev%d' % (epoch)
        pred_writer.start(pred_filename)
        gold_filename = model_path + 'tmp/gold_dev%d' % (epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_dev, batch_size, pos_embedding, type='dev', elmo=use_elmo, bert=use_bert):
            input_encoder, _ = batch
            if use_elmo:
                word, char, pos, heads, types, masks, lengths, word_elmo, word_bert = input_encoder
                heads_pred, types_pred, _, _ = network.decode(word, char, pos, input_word_elmo=word_elmo, mask=masks,
                                                              length=lengths, beam=beam,
                                                              leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS, input_word_bert=word_bert)
            else:
                word, char, pos, heads, types, masks, lengths, word_bert = input_encoder
                heads_pred, types_pred, _, _ = network.decode(word, char, pos, mask=masks, length=lengths, beam=beam,
                                                              leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS, input_word_bert=word_bert)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
            gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' % (dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

        if dev_ucorrect_nopunc * 1.5 + dev_lcorrect_nopunc < dev_ucorr_nopunc * 1.5 + dev_lcorr_nopunc:
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)
            # save embedding to txt
            # FIXME format!
            #with open(model_path + 'embedding.txt', 'w') as f:
            #    for word, idx in word_alphabet.items():
            #        embedding = network.word_embedd.weight[idx, :]
            #        f.write('{}\t{}\n'.format(word, embedding))

            if test_path:
                pred_filename = model_path + 'tmp/%spred_test%d' % (str(uid), epoch)
                pred_writer.start(pred_filename)
                gold_filename = model_path + 'tmp/%sgold_test%d' % (str(uid), epoch)
                gold_writer.start(gold_filename)

                test_ucorrect = 0.0
                test_lcorrect = 0.0
                test_ucomlpete_match = 0.0
                test_lcomplete_match = 0.0
                test_total = 0

                test_ucorrect_nopunc = 0.0
                test_lcorrect_nopunc = 0.0
                test_ucomlpete_match_nopunc = 0.0
                test_lcomplete_match_nopunc = 0.0
                test_total_nopunc = 0
                test_total_inst = 0

                test_root_correct = 0.0
                test_total_root = 0
                for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_test, batch_size, pos_embedding, type='dev'):
                    input_encoder, _ = batch
                    word, char, pos, heads, types, masks, lengths = input_encoder
                    heads_pred, types_pred, _, _ = network.decode(word, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                    word = word.data.cpu().numpy()
                    pos = pos.data.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.data.cpu().numpy()
                    types = types.data.cpu().numpy()

                    pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                    gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                    stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    test_ucorrect += ucorr
                    test_lcorrect += lcorr
                    test_total += total
                    test_ucomlpete_match += ucm
                    test_lcomplete_match += lcm

                    test_ucorrect_nopunc += ucorr_nopunc
                    test_lcorrect_nopunc += lcorr_nopunc
                    test_total_nopunc += total_nopunc
                    test_ucomlpete_match_nopunc += ucm_nopunc
                    test_lcomplete_match_nopunc += lcm_nopunc

                    test_root_correct += corr_root
                    test_total_root += total_root

                    test_total_inst += num_inst

                pred_writer.close()
                gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                if use_bert:
                    # optim = generate_differentlr_bert_optimizer(lr, lr, network)
                    optim = generate_old_bert_optimizer(len(data_train) * num_epochs, lr, network)
                else:
                    optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch))
        print('----------------------------------------------------------------------------------------------------------------------------')
        if test_path:
            print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
                test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst,
                best_epoch))
            print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
                test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst,
                best_epoch))
            print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
            print('============================================================================================================================')

        if decay == max_decay:
            break

    def save_result():
        """Write the best dev-set scores to ``<model_name>.result.txt``.

        Uses the best-epoch counters and dev totals from the enclosing
        ``main`` scope. The three lines written use the same format as the
        "best dev" lines printed at the end of every epoch.
        """
        result_path = model_name + '.result.txt'
        best_dev_Punc = 'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
            best_epoch)
        best_dev_noPunc = 'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
            best_epoch)
        best_dev_Root = 'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch)
        # Context manager guarantees the file is closed even if a write fails.
        with open(result_path, 'w', encoding="utf-8") as f:
            f.write(best_dev_Punc + '\n')
            f.write(best_dev_noPunc + '\n')
            f.write(best_dev_Root)

    save_result()
示例#11
0
    def __init__(self, device, word_table, char_table, model_id, args):
        """Rebuild a ``StackPtrNet`` from a saved checkpoint and load its weights.

        The checkpoint directory is selected from ``args.treebank``. The
        network's constructor arguments are read from
        ``<dir>/network.pt.arg.json``. The weights are read from
        ``<dir>/network.pt``.

        Args:
            device: device the network is moved to; also used as
                ``map_location`` when loading the weights.
            word_table: word embedding table, passed to ``StackPtrNet`` as
                ``embedd_word``.
            char_table: character embedding table, passed as ``embedd_char``.
            model_id: unused; kept for interface compatibility.
            args: namespace; only ``args.treebank`` ('ptb' or 'ctb') is read.

        Raises:
            ValueError: if ``args.treebank`` is neither 'ptb' nor 'ctb'.
        """
        super(third_party_parser, self).__init__()
        # Checkpoint directory per treebank (hard-coded instead of args.model_path).
        treebank_model_paths = {
            'ptb': "models/parsing/stack_ptr/",
            'ctb': "ctb_models/parsing/stack_ptr/",
        }
        if args.treebank not in treebank_model_paths:
            raise ValueError("unsupported treebank %r; expected one of %s"
                             % (args.treebank, sorted(treebank_model_paths)))
        model_path = treebank_model_paths[args.treebank]
        model_name = os.path.join(model_path, 'network.pt')

        # The JSON sidecar holds the positional ('args') and keyword
        # ('kwargs') constructor arguments; read it once and close it.
        arg_path = model_name + '.arg.json'
        with open(arg_path, "r") as arg_file:
            saved_arguments = json.load(arg_file)
        [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, input_size_decoder, hidden_size,
            encoder_layers, decoder_layers, num_types, arc_space, type_space
        ] = saved_arguments['args']
        parameters = saved_arguments['kwargs']
        p_in = parameters['p_in']
        p_out = parameters['p_out']
        p_rnn = parameters['p_rnn']
        # Taken from the saved config so the rebuilt network matches the checkpoint.
        biaffine = parameters['biaffine']
        use_pos = False  # saved parameters['pos'] is deliberately ignored
        use_char = False  # saved parameters['char'] is deliberately ignored
        prior_order = parameters['prior_order']
        skipConnect = parameters['skipConnect']
        grandPar = parameters['grandPar']
        sibling = parameters['sibling']

        # NOTE(review): overrides the `window` value saved in the JSON config;
        # confirm this is intended.
        window = 3
        self.network = StackPtrNet(word_dim,
                                   num_words,
                                   char_dim,
                                   num_chars,
                                   pos_dim,
                                   num_pos,
                                   num_filters,
                                   window,
                                   mode,
                                   input_size_decoder,
                                   hidden_size,
                                   encoder_layers,
                                   decoder_layers,
                                   num_types,
                                   arc_space,
                                   type_space,
                                   embedd_word=word_table,
                                   embedd_char=char_table,
                                   p_in=p_in,
                                   p_out=p_out,
                                   p_rnn=p_rnn,
                                   biaffine=biaffine,
                                   pos=use_pos,
                                   char=use_char,
                                   prior_order=prior_order,
                                   skipConnect=skipConnect,
                                   grandPar=grandPar,
                                   sibling=sibling)

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.network.load_state_dict(torch.load(model_name, map_location=device))
        self.network = self.network.to(device)
        self.network.eval()
示例#12
0
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    """Parse ``test_path`` with a saved stack-pointer model and write predictions.

    Loads the alphabets and constructor arguments saved under ``model_path``.
    Rebuilds ``StackPtrNet`` with the word embedding table read from
    ``model_path/embedding.txt``, then loads every checkpoint weight except
    ``word_embedd.weight``. Finally decodes the test data one sentence at a
    time and writes the predicted trees to ``model_path + '/tmp/inference.txt'``
    (``inference_ordered_temp.txt`` when ``args.ordered`` is set).

    Args:
        model_path: directory containing ``alphabets/``, ``embedding.txt``,
            ``tmp/`` and the checkpoint.
        model_name: checkpoint file name relative to ``model_path``; its
            constructor arguments are read from ``<model_name>.arg.json``.
        test_path: data file to parse.
        punct_set: unused (no evaluation is performed here).
        use_gpu: move the network to CUDA when true.
        logger: object with an ``info`` method used for progress messages.
        args: command-line namespace; reads ``pos_embedding``, ``beam`` and
            ``ordered``.
    """
    pos_embedding = args.pos_embedding
    beam = args.beam
    ordered = args.ordered

    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        None,
        pos_embedding,
        data_paths=[None, None],
        max_vocabulary_size=50000,
        embedd_dict=None)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    # Constructor arguments saved at training time. Kept under new names so
    # the command-line ``args`` namespace is not shadowed.
    arg_path = model_name + '.arg.json'
    with open(arg_path, 'r') as arg_file:
        saved_arguments = json.load(arg_file)
    net_args, net_kwargs = saved_arguments['args'], saved_arguments['kwargs']

    prior_order = net_kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        is_test=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    # Replace the checkpoint's word embedding with the table in embedding.txt.
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    # Zero-initialised so a row whose vector cannot be stored (reported by
    # printing the word) does not end up holding uninitialised memory.
    word_table = np.zeros([len(word_dict), word_dim])
    for idx, (word, embedding) in enumerate(word_dict.items()):
        try:
            word_table[idx, :] = embedding
        except (ValueError, TypeError):
            print(word)

    net_kwargs['embedd_word'] = torch.from_numpy(word_table)
    net_args[1] = len(word_dict)  # num_words: vocabulary size of the new table
    network = StackPtrNet(*net_args, **net_kwargs)

    # Load every pretrained weight except the word embedding set above.
    # The lambda maps all storages to CPU (works on CPU-only hosts).
    model_dict = network.state_dict()
    pretrained_dict = torch.load(model_name,
                                 map_location=lambda storage, loc: storage)
    model_dict.update({
        k: v
        for k, v in pretrained_dict.items() if k != 'word_embedd.weight'
    })
    network.load_state_dict(model_dict)

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    output_name = 'inference_ordered_temp.txt' if ordered else 'inference.txt'
    pred_writer.start(model_path + '/tmp/' + output_name)
    sent = 0
    try:
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_test, 1, pos_embedding, type='dev'):
            sys.stdout.write('%d, ' % sent)
            sys.stdout.flush()
            sent += 1

            input_encoder, _ = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                ordered=ordered,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            pred_writer.write(word.data.cpu().numpy(),
                              pos.data.cpu().numpy(),
                              heads_pred,
                              types_pred,
                              lengths.cpu().numpy(),
                              symbolic_root=True)
    finally:
        # Always flush/close the prediction file, even if decoding fails.
        pred_writer.close()
示例#13
0
def _format_nopunc_stats(ucorr, lcorr, total, ucomplete, lcomplete, total_inst):
    """Format the punctuation-excluded evaluation counters as one report line.

    Args:
        ucorr: accumulated unlabeled-correct counter (normalized by ``total``).
        lcorr: accumulated labeled-correct counter (normalized by ``total``).
        total: number of scored tokens.
        ucomplete: accumulated unlabeled complete-match counter
            (normalized by ``total_inst``).
        lcomplete: accumulated labeled complete-match counter
            (normalized by ``total_inst``).
        total_inst: number of scored instances.

    Returns:
        The report string. A percentage whose denominator is zero is shown as
        0.00 instead of raising ZeroDivisionError (e.g. empty test set).
    """
    def pct(num, den):
        return num * 100 / den if den else 0.0

    return ('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
        ucorr, lcorr, total, pct(ucorr, total), pct(lcorr, total),
        pct(ucomplete, total_inst), pct(lcomplete, total_inst)))


def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger, args):
    """Decode a test file with a trained StackPtrNet and report attachment scores.

    Loads the alphabets and the saved model from ``model_path``. It then decodes
    ``test_path`` one sentence per batch and writes predictions through
    ``CoNLLXWriter.test_write``. Punctuation-excluded scores are printed every
    100 sentences and once at the end.

    Args:
        model_path: directory containing ``alphabets/``, ``embedding.txt`` and
            the model files; predictions are written here as well.
        model_name: file name of the saved state dict inside ``model_path``;
            ``<model_name>.arg.json`` must hold the constructor ``args``/``kwargs``.
        test_path: path of the data file to decode.
        punct_set: punctuation set forwarded to ``parser.eval``.
        use_gpu: run on CUDA when true, otherwise on CPU.
        logger: logger receiving configuration messages.
        args: parsed options; reads ``pos_embedding``, ``beam``, ``ordered``,
            ``bert``, ``bert_path``, ``bert_feature_dim`` and, when ``bert`` is
            set, ``etri_test``.
    """
    pos_embedding = args.pos_embedding
    beam = args.beam
    ordered = args.ordered
    use_bert = args.bert
    bert_path = args.bert_path
    bert_feature_dim = args.bert_feature_dim
    etri_test_path = args.etri_test if use_bert else None

    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path, None, pos_embedding, data_paths=[None, None],
        max_vocabulary_size=50000, embedd_dict=None)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Type Alphabet Size: %d" % type_alphabet.size())

    # Constructor arguments saved at training time. They get distinct names so
    # they do not shadow the ``args`` parameter, and the file is closed properly.
    with open(model_name + '.arg.json', 'r') as arg_file:
        arguments = json.load(arg_file)
    model_args, model_kwargs = arguments['args'], arguments['kwargs']

    prior_order = model_kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' % (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding,
        use_gpu=use_gpu, volatile=True, prior_order=prior_order, is_test=False,
        bert=use_bert, etri_path=etri_test_path)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    def construct_word_embedding_table():
        """Build a (vocab, word_dim) float32 table; OOV rows are random-uniform."""
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in list(word_alphabet.items()):
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    # NOTE(review): this table is not passed to the network (the weights come
    # from the checkpoint). It is still built so the OOV report and the NumPy RNG
    # consumption stay as before. Pass it as model_kwargs['embedd_word'] to use it.
    word_table = construct_word_embedding_table()

    network = StackPtrNet(*model_args, **model_kwargs, bert=use_bert, bert_path=bert_path,
                          bert_feature_dim=bert_feature_dim)
    # Load onto the CPU when no GPU is requested, so checkpoints saved from a
    # CUDA run can still be used on CPU-only machines.
    network.load_state_dict(torch.load(model_name, map_location=None if use_gpu else 'cpu'))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    out_file = 'RL_B[test].txt' if ordered else 'inference.txt'
    pred_writer.start(os.path.join(model_path, out_file))
    sent = 1

    # Punctuation-excluded accumulators summed over all sentences.
    ucorr_sum = 0.0
    lcorr_sum = 0.0
    token_sum = 0
    ucomplete_sum = 0.0
    lcomplete_sum = 0.0
    inst_sum = 0.0
    sys.stdout.write('Start!\n')
    start_time = time.time()
    try:
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_test, 1, pos_embedding, type='dev', bert=use_bert):
            if sent % 100 == 0:
                print(_format_nopunc_stats(ucorr_sum, lcorr_sum, token_sum,
                                           ucomplete_sum, lcomplete_sum, inst_sum))
                sys.stdout.write('[%d/%d]\n' % (sent, int(data_test[2][0])))
            sys.stdout.flush()
            sent += 1

            input_encoder, input_decoder = batch
            word, char, pos, heads, types, masks_e, lengths, word_bert = input_encoder
            stacked_heads, _, _, _, _, previous, nexts, masks_d, _ = input_decoder
            heads_pred, types_pred, _, _ = network.decode(
                word, char, pos, previous, nexts, stacked_heads, mask_e=masks_e, mask_d=masks_d,
                length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS,
                input_word_bert=word_bert)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.test_write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)

            _, stats_nopunc, _, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types,
                                                       word_alphabet, pos_alphabet, lengths,
                                                       punct_set=punct_set, symbolic_root=True)

            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            ucorr_sum += ucorr_nopunc
            lcorr_sum += lcorr_nopunc
            token_sum += total_nopunc
            ucomplete_sum += ucm_nopunc
            lcomplete_sum += lcm_nopunc
            inst_sum += num_inst
        end_time = time.time()
    finally:
        # Always release the prediction file, even if decoding raised.
        pred_writer.close()

    print('\nFINISHED!!\n', end_time - start_time)
    print(_format_nopunc_stats(ucorr_sum, lcorr_sum, token_sum,
                               ucomplete_sum, lcomplete_sum, inst_sum))