Ejemplo n.º 1
0
def augment_with_extra_embedding(the_alphabet, extra_embed_file,
                                 extra_embed_src_file, test_file, logger):
    """Add test-file words that have an extra embedding to the vocabulary.

    Every word of every sentence yielded by ``iter_file(test_file)`` is
    normalised; if the normalised form is not yet in
    ``the_alphabet.instance2index`` and an embedding for the raw word (or its
    lower-cased form) exists in the relevant embedding dict, the normalised
    form is added to the alphabet and the embedding is collected.

    Words starting with ``"!en_"`` are looked up (without that prefix) in the
    dict loaded from ``extra_embed_src_file``; all other words are looked up
    in the dict loaded from ``extra_embed_file``.

    Args:
        the_alphabet: Vocabulary object providing ``open()``, ``close()``,
            ``add()``, ``size()`` and an ``instance2index`` container.
        extra_embed_file: Path of the extra embedding file, or ``None`` to
            skip augmentation entirely.
        extra_embed_src_file: Path of the source-language embedding file used
            for ``"!en_"`` words, or ``None``. When ``None``, ``"!en_"`` words
            are treated as having no embedding and are skipped.
        test_file: File whose words are candidates for augmentation.
        logger: Object with an ``info(message)`` method.

    Returns:
        list: The collected embeddings, in the order the corresponding words
        were added to the alphabet.
    """
    extra_embeds_arr = []
    if extra_embed_file is not None:
        # reopen the vocab
        the_alphabet.open()
        try:
            # read the embed
            extra_word_dict, _ = load_embedding_dict('word2vec',
                                                     extra_embed_file)
            # Default to an empty dict so that "!en_" words met without a
            # source embedding file are skipped instead of hitting an
            # unbound local variable.
            src_extra_word_dict = {}
            if extra_embed_src_file is not None:
                src_extra_word_dict, _ = load_embedding_dict(
                    'word2vec', extra_embed_src_file)
            lang_id = guess_language_id(test_file)
            for one_sent in iter_file(test_file):
                for w in one_sent["word"]:
                    already_spec = w.startswith("!en_")
                    if already_spec:
                        normed_word = w
                    else:
                        # NOTE(review): the replacement is a bytes literal
                        # while `w` is used as text above; under Python 3
                        # this would need to be "0" -- confirm the target
                        # Python version / DIGIT_RE pattern type.
                        normed_word = DIGIT_RE.sub(b"0", w)
                        normed_word = lang_specific_word(normed_word,
                                                         lang_id=lang_id)
                    # Already known (possibly added earlier in this loop).
                    if normed_word in the_alphabet.instance2index:
                        continue
                    # TODO: assume english is the source for run-translate
                    if already_spec:
                        # Drop the 4-character "!en_" marker for the lookup.
                        w = w[4:]
                        check_dict = src_extra_word_dict
                    else:
                        check_dict = extra_word_dict
                    # Exact match first, then the lower-cased form.
                    if w in check_dict:
                        new_embed_arr = check_dict[w]
                    elif w.lower() in check_dict:
                        new_embed_arr = check_dict[w.lower()]
                    else:
                        new_embed_arr = None
                    if new_embed_arr is not None:
                        extra_embeds_arr.append(new_embed_arr)
                        the_alphabet.add(normed_word)
        finally:
            # close the vocab, even if loading or iteration failed
            the_alphabet.close()
    logger.info("Augmenting the vocab with new words of %s, now vocab is %s." %
                (len(extra_embeds_arr), the_alphabet.size()))
    return extra_embeds_arr
Ejemplo n.º 2
0
    def __init__(self, train, test, embeddings_filename, batch_size=1):
        """Store fixed settings, then load embeddings, alphabets and test data.

        Args:
            train: Training file path; kept as ``self.train_path`` and passed
                to ``conll03_data.create_alphabets``.
            test: Test file path; kept as ``self.test_path``, passed to
                ``create_alphabets`` as the only ``data_paths`` entry and read
                through ``conll03_data.read_data_to_variable``.
            embeddings_filename: File given to ``utils.load_embedding_dict``
                together with the embedding type ``'glove'``.
            batch_size: Kept as ``self.batch_size``.
        """
        # Data locations and batching.
        self.train_path, self.test_path = train, test
        self.batch_size = batch_size

        # Model settings.
        self.mode = 'LSTM'
        self.bigram = True
        self.hidden_size = 256
        self.num_filters = 30

        # Optimisation settings.
        self.num_epochs = 1
        self.learning_rate = 0.01
        self.momentum = 0.9
        self.decay_rate = 0.05
        self.gamma = 0.0
        self.schedule = 1

        # Dropout / regularisation settings.
        self.dropout = 'std'
        self.p_rnn = (0.33, 0.5)
        self.p_in = 0.33
        self.p_out = 0.5
        self.unk_replace = 0.0

        self.embedding = 'glove'
        self.logger = get_logger("NERCRF")

        self.char_dim = 30
        self.window = 3
        self.num_layers = 1
        self.tag_space = 128
        self.initializer = nn.init.xavier_uniform

        self.use_gpu = torch.cuda.is_available()

        # Pre-trained embeddings first: the alphabets are built against them.
        loaded = utils.load_embedding_dict(self.embedding, embeddings_filename)
        self.embedd_dict, self.embedd_dim = loaded

        alphabets = conll03_data.create_alphabets(
            "data/alphabets/ner_crf/",
            self.train_path,
            data_paths=[self.test_path],
            embedd_dict=self.embedd_dict,
            max_vocabulary_size=50000)
        (self.word_alphabet, self.char_alphabet, self.pos_alphabet,
         self.chunk_alphabet, self.ner_alphabet) = alphabets
        self.word_table = self.construct_word_embedding_table()

        # Report the size of every alphabet.
        log = self.logger.info
        log("Word Alphabet Size: %d" % self.word_alphabet.size())
        log("Character Alphabet Size: %d" % self.char_alphabet.size())
        log("POS Alphabet Size: %d" % self.pos_alphabet.size())
        log("Chunk Alphabet Size: %d" % self.chunk_alphabet.size())
        log("NER Alphabet Size: %d" % self.ner_alphabet.size())
        self.num_labels = self.ner_alphabet.size()

        # Test data is read with volatile=True.
        self.data_test = conll03_data.read_data_to_variable(
            self.test_path,
            self.word_alphabet,
            self.char_alphabet,
            self.pos_alphabet,
            self.chunk_alphabet,
            self.ner_alphabet,
            use_gpu=self.use_gpu,
            volatile=True)
        self.writer = CoNLL03Writer(
            self.word_alphabet,
            self.char_alphabet,
            self.pos_alphabet,
            self.chunk_alphabet,
            self.ner_alphabet)
Ejemplo n.º 3
0
def get_unseen_embedding(word_set, word_alphabet, embedding_path):
    """Return embeddings for the words of ``word_set`` not in the alphabet.

    Args:
        word_set: Set of candidate words.
        word_alphabet: Alphabet whose ``enumerate_items(1)`` yields
            ``(something, word)`` pairs; only the word is used.
            NOTE(review): the ``1`` is presumably a start index -- confirm.
        embedding_path: File passed to ``utils.load_embedding_dict`` with the
            embedding type ``'NNLM'``.

    Returns:
        Whatever ``get_embedding`` produces for the unseen words and the
        loaded embedding dict.
    """
    known_words = {word for _, word in word_alphabet.enumerate_items(1)}
    missing_words = word_set - known_words

    embedd_dict, _ = utils.load_embedding_dict('NNLM', embedding_path)
    return get_embedding(missing_words, embedd_dict)
Ejemplo n.º 4
0
def main(a=None):
    """Load the word embeddings and create the alphabets for a model dir.

    Args:
        a: Argument list handed to ``parse_cmd``. When ``None``,
            ``sys.argv[1:]`` is used.

    Raises:
        AssertionError: If the embedding files report different dimensions,
            or if ``<model_path>/alphabets/`` already exists.
    """
    if a is None:
        a = sys.argv[1:]
    args = parse_cmd(a)
    # if output directory doesn't exist, create it
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    logger = get_logger("VocabBuilder", args.model_path + '/vocab.log.txt')
    # Log the arguments that were actually parsed; they differ from
    # sys.argv[1:] whenever the caller passed `a` explicitly.
    logger.info('\ncommand-line params : {0}\n'.format(a))
    logger.info('{0}\n'.format(args))
    # load embeds
    logger.info("Load embeddings")
    word_dicts = []
    word_dim = None
    for one in args.word_paths:
        one_word_dict, one_word_dim = utils.load_embedding_dict(
            args.word_embedding, one)
        # All embedding files must share one dimension.
        assert word_dim is None or word_dim == one_word_dim, "Embedding size not matched!"
        word_dicts.append(one_word_dict)
        word_dim = one_word_dim
    # combine embeds
    combined_word_dict, count_ins, count_repeats = combine_embeds(word_dicts)
    logger.info("Final embeddings size: %d." % len(combined_word_dict))
    for one_fname, one_count_ins, one_count_repeats in zip(
            args.word_paths, count_ins, count_repeats):
        logger.info("For embed-file %s, count-in: %d, repeat-discard: %d." %
                    (one_fname, one_count_ins, one_count_repeats))
    # create vocabs
    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(args.model_path, 'alphabets/')
    # Refuse to build on top of an existing alphabet directory.
    assert not os.path.exists(
        alphabet_path), "Alphabet path exists, please build with a new path."
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_stacked_data.create_alphabets(
        alphabet_path,
        args.train,
        data_paths=args.extra,
        max_vocabulary_size=100000,
        embedd_dict=combined_word_dict)
    # printing info
    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()
    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)
Ejemplo n.º 5
0
def main():
    uid = uuid.uuid4().hex[:6]

    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             default='FastLSTM')
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=10,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=32,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size',
                             type=int,
                             default=256,
                             help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of decoder RNN')
    args_parser.add_argument('--char_num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN(Character Level)')
    args_parser.add_argument('--eojul_num_filters',
                             type=int,
                             default=100,
                             help='Number of filters in CNN(Eojul Level)')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--eojul',
                             action='store_true',
                             help='use eojul embedding and CNN.')
    args_parser.add_argument('--word_dim',
                             type=int,
                             default=100,
                             help='Dimension of Word embeddings')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm',
                             default='adam')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             default=[0.33, 0.33],
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth',
                             type=float,
                             default=1.0,
                             help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect',
                             action='store_true',
                             help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar',
                             action='store_true',
                             help='use grand parent.')
    args_parser.add_argument('--sibling',
                             action='store_true',
                             help='use sibling.')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument('--schedule',
                             type=int,
                             default=20,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument(
        '--word_embedding',
        choices=['random', 'word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'word2vec'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument('--pos_embedding',
                             choices=['random', 'word2vec'],
                             help='Embedding for part of speeches',
                             required=True)
    args_parser.add_argument('--pos_path',
                             help='path for part of speech embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    args_parser.add_argument('--use_gpu',
                             action='store_true',
                             help='use the gpu')

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = "{}_{}".format(str(uid), args.model_name)
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    char_num_filters = args.char_num_filters
    eojul_num_filters = args.eojul_num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    use_gpu = args.use_gpu
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path
    pos_embedding = args.pos_embedding
    pos_path = args.pos_path

    use_pos = args.pos

    if word_embedding != 'random':
        word_dict, word_dim = utils.load_embedding_dict(
            word_embedding, word_path)
    else:
        word_dict = {}
        word_dim = args.word_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    else:
        if use_char:
            char_dict = {}
            char_dim = args.char_dim
        else:
            char_dict = None
    if pos_embedding != 'random':
        pos_dict, pos_dim = utils.load_embedding_dict(pos_embedding, pos_path)
    else:
        if use_pos:
            pos_dict = {}
            pos_dim = args.pos_dim
        else:
            pos_dict = None

    use_eojul = args.eojul

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = use_gpu

    data_train = conllx_stacked_data.read_stacked_data_to_variable(
        train_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    num_data = sum(data_train[1])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(
        dev_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_pos_embedding_table():
        if pos_dict is None:
            return None

        scale = np.sqrt(3.0 / pos_dim)
        table = np.empty([num_pos, pos_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, pos_dim]).astype(np.float32)
        oov = 0
        for pos, index in pos_alphabet.items():
            if pos in pos_dict:
                embedding = pos_dict[pos]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, pos_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('pos OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()
    pos_table = construct_pos_embedding_table()

    char_window = 3
    eojul_window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          char_num_filters,
                          char_window,
                          eojul_num_filters,
                          eojul_window,
                          mode,
                          input_size_decoder,
                          hidden_size,
                          encoder_layers,
                          decoder_layers,
                          num_types,
                          arc_space,
                          type_space,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          embedd_pos=pos_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          pos=use_pos,
                          char=use_char,
                          eojul=use_eojul,
                          prior_order=prior_order,
                          skipConnect=skipConnect,
                          grandPar=grandPar,
                          sibling=sibling)

    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            char_num_filters, char_window, eojul_num_filters, eojul_window,
            mode, input_size_decoder, hidden_size, encoder_layers,
            decoder_layers, num_types, arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'eojul': use_eojul,
            'prior_order': prior_order,
            'skipConnect': skipConnect,
            'grandPar': grandPar,
            'sibling': sibling
        }
        json.dump({
            'args': arguments,
            'kwargs': kwargs
        },
                  open(arg_path, 'w'),
                  indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adamax':
            return Adamax(params,
                          lr=lr,
                          betas=betas,
                          weight_decay=gamma,
                          eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("Char CNN: filter=%d, kernel=%d" %
                (char_num_filters, char_window))
    logger.info("Eojul CNN: filter=%d, kernel=%d" %
                (eojul_num_filters, eojul_window))
    logger.info(
        "RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d"
        % (mode, encoder_layers, decoder_layers, input_size_decoder,
           hidden_size, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' %
                (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d, use_gpu: %s' %
                (skipConnect, beam, use_gpu))
    logger.info(opt_info)

    num_batches = num_data // batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    patient = 0
    decay = 0.
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d (%d, %d))): '
            % (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay,
               max_decay, double_schedule_decay))
        train_err_arc_leaf = 0.
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train,
                batch_size,
                unk_replace=unk_replace,
                use_gpu=use_gpu)
            word, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling, stacked_types, label_smooth,
                                                            skip_connect=skip_connect, mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.item()  ##180809 data[0] --> item()
            num_non_leaf = num_non_leaf.item()  ##180809 data[0] --> item()

            train_err_arc_leaf += loss_arc_leaf.item(
            ) * num_leaf  ##180809 data[0] --> item()
            train_err_arc_non_leaf += loss_arc_non_leaf.item(
            ) * num_non_leaf  ##180809 data[0] --> item()

            train_err_type_leaf += loss_type_leaf.item(
            ) * num_leaf  ##180809 data[0] --> item()
            train_err_type_non_leaf += loss_type_non_leaf.item(
            ) * num_non_leaf  ##180809 data[0] --> item()

            train_err_cov += loss_cov.item() * (num_leaf + num_non_leaf
                                                )  ##180809 data[0] --> item()

            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf +
                                           train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf,
                    err_arc_non_leaf, err_type, err_type_leaf,
                    err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        print(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % (num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf,
               err_type, err_type_leaf, err_type_non_leaf, err_cov,
               time.time() - start_time))

        #torch.save(network.state_dict(), model_name+"."+str(epoch))
        #continue

        # evaluate performance on dev data
        network.eval()
        tmp_root = 'tmp'
        if not os.path.isdir(tmp_root):
            logger.info('Creating temporary folder(%s)' % (tmp_root, ))
            os.makedirs(tmp_root)
        pred_filename = '%s/%spred_dev%d' % (tmp_root, str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = '%s/%sgold_dev%d' % (tmp_root, str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size, use_gpu=use_gpu):
            input_encoder, _, sentences = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(sentences,
                              word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(sentences,
                              word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data_test, batch_size, use_gpu=use_gpu):
                input_encoder, _, sentences = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(sentences,
                                  word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(sentences,
                                  word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        if test_total_inst != 0 or test_total != 0:
            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            print(
                'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (test_ucorrect, test_lcorrect, test_total, test_ucorrect *
                   100 / test_total, test_lcorrect * 100 / test_total,
                   test_ucomlpete_match * 100 / test_total_inst,
                   test_lcomplete_match * 100 / test_total_inst, best_epoch))
            print(
                'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                %
                (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                 test_ucorrect_nopunc * 100 / test_total_nopunc,
                 test_lcorrect_nopunc * 100 / test_total_nopunc,
                 test_ucomlpete_match_nopunc * 100 / test_total_inst,
                 test_lcomplete_match_nopunc * 100 / test_total_inst,
                 best_epoch))
            print(
                'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (test_root_correct, test_total_root,
                   test_root_correct * 100 / test_total_root, best_epoch))
            print(
                '============================================================================================================================'
            )

        if decay == max_decay:
            break
Ejemplo n.º 6
0
def main():
    """Train and evaluate a bi-directional RNN-CNN-CRF tagger from the CLI.

    Parses hyper-parameters and data paths, builds the alphabets and the word
    embedding table, constructs the network selected by ``--dropout`` and
    trains it with Nesterov-momentum SGD.  After every epoch the dev set is
    decoded and scored; whenever the dev F1 improves, the test set is decoded
    and scored as well.  Every ``--schedule`` epochs the learning rate is
    recomputed and the optimizer is rebuilt.
    """
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='Number of layers of RNN')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.015,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p_rnn',
                        nargs=2,
                        type=float,
                        required=True,
                        help='dropout rate for RNN')
    parser.add_argument('--p_in',
                        type=float,
                        default=0.33,
                        help='dropout rate for input embeddings')
    parser.add_argument('--p_out',
                        type=float,
                        default=0.33,
                        help='dropout rate for output layer')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    # Required: the value is formatted with %d and used as a modulus in the
    # epoch loop, so a missing value would only crash after all data is loaded.
    parser.add_argument('--schedule',
                        type=int,
                        required=True,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("NERCRF")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    chunk_alphabet, ner_alphabet = conll03_data.create_alphabets("data/alphabets/ner_crf/", train_path, data_paths=[dev_path, test_path],
                                                                 embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conll03_data.read_data_to_variable(train_path,
                                                    word_alphabet,
                                                    char_alphabet,
                                                    pos_alphabet,
                                                    chunk_alphabet,
                                                    ner_alphabet,
                                                    use_gpu=use_gpu)
    # NOTE(review): data_train[1] is presumably the per-bucket instance
    # counts -- confirm against conll03_data.read_data_to_variable.
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data_to_variable(dev_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  chunk_alphabet,
                                                  ner_alphabet,
                                                  use_gpu=use_gpu,
                                                  volatile=True)
    data_test = conll03_data.read_data_to_variable(test_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   chunk_alphabet,
                                                   ner_alphabet,
                                                   use_gpu=use_gpu,
                                                   volatile=True)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the (vocab_size, embedd_dim) float32 word embedding tensor.

        Words found in ``embedd_dict`` (exact match first, then lower-cased)
        take their pretrained vector; all others get a uniform random vector
        in [-scale, scale] and are counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                vector = embedd_dict[word]
            elif word.lower() in embedd_dict:
                vector = embedd_dict[word.lower()]
            else:
                vector = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = vector
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform
    # Both network classes are constructed with the same arguments; only the
    # class differs between the two dropout modes.
    if args.dropout == 'std':
        network_cls = BiRecurrentConvCRF
    else:
        network_cls = BiVarRecurrentConvCRF
    network = network_cls(embedd_dim,
                          word_alphabet.size(),
                          char_dim,
                          char_alphabet.size(),
                          num_filters,
                          window,
                          mode,
                          hidden_size,
                          num_layers,
                          num_labels,
                          tag_space=tag_space,
                          embedd_word=word_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          bigram=bigram,
                          initializer=initializer)

    if use_gpu:
        network.cuda()

    lr = learning_rate
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma,
                nesterov=True)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s"
        % (mode, num_layers, hidden_size, num_filters, tag_space,
           'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))

    def erase_progress(width):
        """Blank out the last ``width`` characters written to stdout."""
        sys.stdout.write("\b" * width)
        sys.stdout.write(" " * width)
        sys.stdout.write("\b" * width)

    def decode_and_score(data, tmp_filename):
        """Decode ``data``, dump predictions to ``tmp_filename`` and score it.

        Returns whatever ``evaluate`` returns for the written file, which the
        callers unpack as (acc, precision, recall, f1).
        """
        # NOTE(review): the 'tmp/' directory is never created here; presumably
        # writer.start opens the path for writing -- confirm it exists.
        writer.start(tmp_filename)
        for eval_batch in conll03_data.iterate_batch_variable(
                data, batch_size):
            word, char, pos, chunk, labels, masks, lengths = eval_batch
            preds, _ = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()
        return evaluate(tmp_filename)

    # Floor division keeps the batch count an int on Python 3 as well, so it
    # can be handed to range() below (true division would yield a float).
    num_batches = num_data // batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths = conll03_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            num_inst = word.size(0)
            # Indexing a 0-dim loss with data[0] fails on newer PyTorch
            # releases; use .item() when available and keep the legacy form
            # only as a fallback for releases that lack it.
            if hasattr(loss, 'item'):
                loss_value = loss.item()
            else:
                loss_value = loss.data[0]
            train_err += loss_value * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                erase_progress(num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        erase_progress(num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        acc, precision, recall, f1 = decode_and_score(
            data_dev, 'tmp/%s_dev%d' % (str(uid), epoch))
        print(
            'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
            (acc, precision, recall, f1))

        if dev_f1 < f1:
            dev_f1 = f1
            dev_acc = acc
            dev_precision = precision
            dev_recall = recall
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_acc, test_precision, test_recall, test_f1 = decode_and_score(
                data_test, 'tmp/%s_test%d' % (str(uid), epoch))

        print(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        print(
            "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (test_acc, test_precision, test_recall, test_f1, best_epoch))

        # Every `schedule` epochs, recompute the learning rate from the
        # initial rate and rebuild the optimizer with it.
        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)
Ejemplo n.º 7
0
def main():
    """Train and evaluate a bi-directional RNN-CNN-CRF POS tagger.

    Command-line entry point. The function:

    1. parses hyper-parameters and the train/dev/test paths from ``sys.argv``;
    2. loads GloVe embeddings from a hard-coded path and builds the word,
       character, POS and type alphabets;
    3. reads the three data splits and builds the initial word-embedding
       table;
    4. constructs a ``BiRecurrentConvCRF`` network (only ``--dropout std``
       is implemented) and trains it with SGD for ``--num_epochs`` epochs;
    5. after every epoch scores the dev split, and re-scores the test split
       whenever the number of correct dev tokens improves;
    6. every ``--schedule`` epochs recomputes the learning rate and rebuilds
       the optimizer.

    Progress is written to stdout and through the ``POSCRFTagger`` logger.

    Raises:
        NotImplementedError: if ``--dropout variational`` is selected.
    """
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=1000,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    # NOTE: --schedule has no default, but the epoch banner formats it with
    # %d and the decay check computes `epoch % schedule`, so in practice it
    # must always be supplied (and be non-zero).
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("POSCRFTagger")

    # Unpack the parsed arguments into locals used throughout the function.
    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9  # fixed; not exposed on the command line
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p = args.p
    unk_replace = args.unk_replace
    bigram = args.bigram

    # The embedding file location is hard-coded relative to the working dir.
    embedd_dict, embedd_dim = utils.load_embedding_dict(
        'glove', "data/glove/glove.6B/glove.6B.100d.gz")
    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    type_alphabet = conllx_data.create_alphabets("data/alphabets/pos_crf/", train_path,
                                                 data_paths=[dev_path, test_path],
                                                 max_vocabulary_size=50000, embedd_dict=embedd_dict)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_data.read_data_to_variable(train_path,
                                                   word_alphabet,
                                                   char_alphabet,
                                                   pos_alphabet,
                                                   type_alphabet,
                                                   use_gpu=use_gpu)
    # presumably data_train[1] holds the per-bucket instance counts, so the
    # sum is the number of training sentences — TODO confirm against
    # conllx_data.read_data_to_variable.
    num_data = sum(data_train[1])
    num_labels = pos_alphabet.size()

    # NOTE(review): `volatile=True` (and `loss.data[0]` below) look like
    # pre-0.4 PyTorch idioms — confirm the targeted version before changing.
    data_dev = conllx_data.read_data_to_variable(dev_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 pos_alphabet,
                                                 type_alphabet,
                                                 use_gpu=use_gpu,
                                                 volatile=True)
    data_test = conllx_data.read_data_to_variable(test_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  type_alphabet,
                                                  use_gpu=use_gpu,
                                                  volatile=True)

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix as a torch tensor.

        Each alphabet entry takes its pre-trained vector (exact match first,
        then lower-cased); entries found in neither form are drawn uniformly
        from ``[-scale, scale]`` and counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        # The UNK row is always random.
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = 1
    if args.dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     num_filters,
                                     window,
                                     mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     embedd_word=word_table,
                                     p_rnn=p,
                                     bigram=bigram)
    else:
        # 'variational' is accepted by the parser but has no implementation.
        raise NotImplementedError

    if use_gpu:
        network.cuda()

    lr = learning_rate
    # NOTE: the initial optimizer is plain momentum SGD, whereas the one
    # rebuilt at each decay step below enables Nesterov momentum.
    optim = SGD(network.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=gamma)
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, crf=%s" %
                (mode, num_layers, hidden_size, num_filters,
                 'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, dropout: %.2f, unk replace: %.2f)"
        % (gamma, num_data, batch_size, p, unk_replace))

    # Floor division keeps num_batches an int: with true division (Python 3,
    # or `from __future__ import division`) the value would be a float and
    # the `range(1, num_batches + 1)` below would raise TypeError.
    num_batches = num_data // batch_size + 1
    dev_correct = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_total = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        # Length of the last progress line, so it can be erased in place.
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, labels, _, _, masks, lengths = conllx_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks)
            loss.backward()
            optim.step()

            # Accumulate the loss weighted by the number of instances in the
            # batch so that train_err / train_total is a per-instance mean.
            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                # Backspace over the previous progress line, blank it out,
                # then backspace again before writing the new one.
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        # Erase the last progress line before printing the epoch summary.
        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        dev_corr = 0.0
        dev_total = 0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, labels, _, _, masks, lengths = batch
            preds, corr = network.decode(
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            # The mask sum is used as the token count for this batch.
            num_tokens = masks.data.sum()
            dev_corr += corr
            dev_total += num_tokens
        print('dev corr: %d, total: %d, acc: %.2f%%' %
              (dev_corr, dev_total, dev_corr * 100 / dev_total))

        if dev_correct < dev_corr:
            dev_correct = dev_corr
            best_epoch = epoch

            # evaluate on test data when better performance detected
            test_corr = 0.0
            test_total = 0
            for batch in conllx_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, labels, _, _, masks, lengths = batch
                preds, corr = network.decode(
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                num_tokens = masks.data.sum()
                test_corr += corr
                test_total += num_tokens
            test_correct = test_corr
        # NOTE: test_total stays 0 until the dev score first improves, in
        # which case the test ratio below divides by zero.
        print("best dev  corr: %d, total: %d, acc: %.2f%% (epoch: %d)" %
              (dev_correct, dev_total, dev_correct * 100 / dev_total,
               best_epoch))
        print("best test corr: %d, total: %d, acc: %.2f%% (epoch: %d)" %
              (test_correct, test_total, test_correct * 100 / test_total,
               best_epoch))

        # Every `schedule` epochs, decay the learning rate and rebuild the
        # optimizer with it (this also discards the old optimizer's state).
        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=gamma,
                        nesterov=True)
Ejemplo n.º 8
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=200,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adadelta'],
                             help='optimization algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.5,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument('--word_embedding',
                             choices=['glove', 'senna', 'sskip', 'polyglot'],
                             help='Embedding for words',
                             required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    rho = 0.9
    eps = 1e-6
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    beam = args.beam
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    logger.info("Creating Alphabets")

    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_stacked_data.read_stacked_data_to_variable(
        train_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        prior_order=prior_order)
    num_data = sum(data_train[1])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(
        dev_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)
    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          num_filters,
                          window,
                          mode,
                          hidden_size,
                          num_layers,
                          num_types,
                          arc_space,
                          type_space,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          prior_order=prior_order)

    if use_gpu:
        network.cuda()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adadelta':
            return Adadelta(params,
                            lr=lr,
                            rho=rho,
                            weight_decay=gamma,
                            eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adadelta':
        opt_info += 'rho=%.2f, eps=%.1e' % (rho, eps)

    logger.info("Embedding dim: word=%d, char=%d, pos=%d" %
                (word_dim, char_dim, pos_dim))
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, arc_space=%d, type_space=%d"
        % (mode, num_layers, hidden_size, num_filters, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, dropout(in, out, rnn): (%.2f, %.2f, %s), unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, p_in, p_out, p_rnn, unk_replace))
    logger.info('prior order: %s, beam: %d' % (prior_order, beam))
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    patient = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, decay rate=%.2f (schedule=%d, patient=%d)): '
            % (epoch, mode, opt, lr, decay_rate, schedule, patient))
        train_err_arc_leaf = 0.
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train, batch_size, unk_replace=unk_replace)
            word, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, stacked_types, masks_d, lengths_d = input_decoder
            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, stacked_heads, children, stacked_types,
                                                            mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.data[0]
            num_non_leaf = num_non_leaf.data[0]

            train_err_arc_leaf += loss_arc_leaf.data[0] * num_leaf
            train_err_arc_non_leaf += loss_arc_non_leaf.data[0] * num_non_leaf

            train_err_type_leaf += loss_type_leaf.data[0] * num_leaf
            train_err_type_non_leaf += loss_type_non_leaf.data[0] * num_non_leaf

            train_err_cov += loss_cov.data[0] * (num_leaf + num_non_leaf)

            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf +
                                           train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf,
                    err_arc_non_leaf, err_type, err_type_leaf,
                    err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        print(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % (num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf,
               err_type, err_type_leaf, err_type_non_leaf, err_cov,
               time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size):
            input_encoder, _ = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_ucorrect_nopunc <= dev_ucorr_nopunc:
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            torch.save(network, model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data_test, batch_size):
                input_encoder, _ = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if patient < schedule:
                patient += 1
            else:
                network = torch.load(model_name)
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
               test_total, test_lcorrect * 100 / test_total,
               test_ucomlpete_match * 100 / test_total_inst,
               test_lcomplete_match * 100 / test_total_inst, best_epoch))
        print(
            'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            %
            (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
             test_ucorrect_nopunc * 100 / test_total_nopunc,
             test_lcorrect_nopunc * 100 / test_total_nopunc,
             test_ucomlpete_match_nopunc * 100 / test_total_inst,
             test_lcomplete_match_nopunc * 100 / test_total_inst, best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (test_root_correct, test_total_root,
               test_root_correct * 100 / test_total_root, best_epoch))
        print(
            '============================================================================================================================'
        )
Ejemplo n.º 9
0
def main():
    """Train a bi-directional RNN-CNN-CRF tagger and report dev/test scores.

    Reads all settings from the command line, builds the alphabets and the
    word-embedding table, then trains the network with RMSprop. After every
    epoch the model is evaluated on the dev set; the test set is evaluated
    only when the dev F1 improves. Every ``--schedule`` epochs the learning
    rate is multiplied by ``--lr_decay``.

    Prediction files are written under ``tmp/`` and are named with the
    module-level ``uid``.

    Raises:
        NotImplementedError: if ``--dropout`` is anything other than 'std'.
        SystemExit: raised by argparse on invalid or missing options
            (including a missing or zero ``--schedule``).
    """
    parser = argparse.ArgumentParser(description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn', required=True)
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Number of sentences in each batch')
    parser.add_argument('--hidden_size', type=int, default=128, help='Number of hidden units in RNN')
    parser.add_argument('--tag_space', type=int, default=0, help='Dimension of tag space')
    parser.add_argument('--num_filters', type=int, default=30, help='Number of filters in CNN')
    parser.add_argument('--char_dim', type=int, default=30, help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate', type=float, default=0.015, help='Learning rate')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha of rmsprop')
    parser.add_argument('--momentum', type=float, default=0, help='momentum')
    parser.add_argument('--lr_decay', type=float, default=0, help='Decay rate of learning rate')
    parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    parser.add_argument('--dropout', choices=['std', 'variational'], help='type of dropout', required=True)
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--bigram', action='store_true', help='bi-gram parameter for CRF')
    parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding', choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot', 'elmo'], help='Embedding for words', required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--elmo_option', help='path for ELMo option file')
    parser.add_argument('--elmo_weight', help='path for ELMo weight file')
    # Parsed as int with a default so that omitting the flag no longer crashes
    # on int(None). NOTE(review): -1 is assumed to be ElmoEmbedder's "run on
    # CPU" value -- confirm against the installed allennlp version.
    parser.add_argument('--elmo_cuda', type=int, default=-1, help='assign GPU for ELMo embedding task')
    parser.add_argument('--attention', choices=['none', 'mlp', 'fine'], help='attetion mode', required=True)
    parser.add_argument('--data_reduce', type=float, default=1.0, help='data size reduce, value is keeping rate')
    parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    # The epoch loop formats `schedule` with %d and computes `epoch % schedule`,
    # so a missing (None) or zero value would only blow up after all data has
    # been loaded. Reject it up front instead.
    if not args.schedule:
        parser.error('--schedule must be a non-zero integer '
                     '(number of epochs between learning rate decays)')

    logger = get_logger("NERCRF")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    alpha = args.alpha
    momentum = args.momentum
    lr_decay = args.lr_decay
    gamma = args.gamma
    schedule = args.schedule
    p = args.p
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict
    elmo_option = args.elmo_option
    elmo_weight = args.elmo_weight
    elmo_cuda = args.elmo_cuda
    attention_mode = args.attention
    data_reduce = args.data_reduce

    embedd_dict, embedd_dim = utils.load_embedding_dict(embedding, embedding_path)

    logger.info("Creating Alphabets")
    # Alphabets are stored in an "alphabets" directory next to the training file.
    alphabet_dir = os.path.join(Path(train_path).parent.abspath(), "alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
        chunk_alphabet, ner_alphabet = bionlp_data.create_alphabets(
            alphabet_dir, train_path, data_paths=[dev_path, test_path],
            embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    if embedding == 'elmo':
        logger.info("Loading ELMo Embedder")
        ee = ElmoEmbedder(options_file=elmo_option, weight_file=elmo_weight, cuda_device=elmo_cuda)
    else:
        ee = None

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # Only the training split is sub-sampled by `data_reduce`.
    data_train = bionlp_data.read_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet,
                                                    chunk_alphabet, ner_alphabet, use_gpu=use_gpu,
                                                    elmo_ee=ee, data_reduce=data_reduce)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = bionlp_data.read_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet,
                                                  chunk_alphabet, ner_alphabet, use_gpu=use_gpu, volatile=True,
                                                  elmo_ee=ee)

    data_test = bionlp_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet,
                                                   chunk_alphabet, ner_alphabet, use_gpu=use_gpu, volatile=True,
                                                   elmo_ee=ee)

    writer = BioNLPWriter(word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix as a torch tensor.

        Each alphabet word is looked up in `embedd_dict` (exact form first,
        then lower-cased); words not found get a uniform random vector and
        are counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[bionlp_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if embedd_dict is not None and word in embedd_dict:
                vector = embedd_dict[word]
            elif embedd_dict is not None and word.lower() in embedd_dict:
                vector = embedd_dict[word.lower()]
            else:
                vector = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = vector
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = 1
    tag_space = args.tag_space
    if args.dropout == 'std':
        if attention_mode == 'none':
            network = BiRecurrentConvCRF(embedd_dim, word_alphabet.size(),
                                         char_dim, char_alphabet.size(),
                                         num_filters, window,
                                         mode, hidden_size, num_layers, num_labels,
                                         tag_space=tag_space, embedd_word=word_table, p_in=p, p_rnn=p, bigram=bigram,
                                         elmo=(embedding == 'elmo'))
        else:
            network = BiRecurrentConvAttentionCRF(embedd_dim, word_alphabet.size(),
                                                  char_dim, char_alphabet.size(),
                                                  num_filters, window,
                                                  mode, hidden_size, num_layers, num_labels,
                                                  tag_space=tag_space, embedd_word=word_table, p_in=p, p_rnn=p, bigram=bigram,
                                                  elmo=(embedding == 'elmo'), attention_mode=attention_mode)

    else:
        # 'variational' is accepted by the argument parser but has no network here.
        raise NotImplementedError

    if use_gpu:
        network.cuda()

    lr = learning_rate
    # optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
    optim = RMSprop(network.parameters(), lr=lr, alpha=alpha, momentum=momentum, weight_decay=gamma)
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s" % (
        mode, num_layers, hidden_size, num_filters, tag_space, 'bigram' if bigram else 'unigram'))
    logger.info("training: l2: %f, (#training data: %d, batch: %d, dropout: %.2f, unk replace: %.2f)" % (
        gamma, num_data, batch_size, p, unk_replace))

    num_batches = num_data // batch_size + 1
    # Best-so-far dev scores, and the test scores measured at that same epoch.
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): ' % (
            epoch, mode, args.dropout, lr, lr_decay, schedule))
        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0  # length of the progress line currently shown on stdout
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, _, _, labels, masks, lengths, elmo_embedding = bionlp_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(word, char, labels, mask=masks, elmo_word=elmo_embedding)
            loss.backward()
            clip_grad_norm(network.parameters(), 5.0)
            optim.step()

            # Weight the batch loss by batch size so the running value is a
            # per-sentence average.
            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 100 == 0:
                # Erase the previous progress line in place before rewriting it.
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' % (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        tmp_filename = 'tmp/%s_dev%d' % (str(uid), epoch)
        writer.start(tmp_filename)

        for batch in bionlp_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, chunk, labels, masks, lengths, elmo_embedding = batch
            preds, _ = network.decode(word, char, target=labels, mask=masks,
                                      leading_symbolic=bionlp_data.NUM_SYMBOLIC_TAGS, elmo_word=elmo_embedding)
            writer.write(word.data.cpu().numpy(), pos.data.cpu().numpy(), chunk.data.cpu().numpy(),
                         preds.cpu().numpy(), labels.data.cpu().numpy(), lengths.cpu().numpy())
        writer.close()
        acc, precision, recall, f1 = evaluate(tmp_filename)
        print('dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' % (acc, precision, recall, f1))

        if dev_f1 < f1:
            dev_f1 = f1
            dev_acc = acc
            dev_precision = precision
            dev_recall = recall
            best_epoch = epoch

            # evaluate on test data when better performance detected
            tmp_filename = 'tmp/%s_test%d' % (str(uid), epoch)
            writer.start(tmp_filename)

            for batch in bionlp_data.iterate_batch_variable(data_test, batch_size):
                word, char, pos, chunk, labels, masks, lengths, elmo_embedding = batch
                preds, _ = network.decode(word, char, target=labels, mask=masks,
                                          leading_symbolic=bionlp_data.NUM_SYMBOLIC_TAGS, elmo_word=elmo_embedding)
                writer.write(word.data.cpu().numpy(), pos.data.cpu().numpy(), chunk.data.cpu().numpy(),
                             preds.cpu().numpy(), labels.data.cpu().numpy(), lengths.cpu().numpy())
            writer.close()
            test_acc, test_precision, test_recall, test_f1 = evaluate(tmp_filename)

        print("best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
            dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        print("best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
            test_acc, test_precision, test_recall, test_f1, best_epoch))

        if epoch % schedule == 0:
            # lr = learning_rate / (1.0 + epoch * lr_decay)
            # optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
            # Multiplicative decay, applied in place to the optimizer's first
            # parameter group.
            lr = lr * lr_decay
            optim.param_groups[0]['lr'] = lr
Ejemplo n.º 10
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=200,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    #args_parser.add_argument('--decoder_input_size', type=int, default=256, help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder RNN')
    #args_parser.add_argument('--decoder_layers', type=int, default=1, help='Number of layers of decoder RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    # NOTE: action='store_true' is just to set ON
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    # NOTE: arg MUST be one of choices(when specified)
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument('--clip',
                             type=float,
                             default=1.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth',
                             type=float,
                             default=1.0,
                             help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect',
                             action='store_true',
                             help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar',
                             action='store_true',
                             help='use grand parent.')
    args_parser.add_argument('--sibling',
                             action='store_true',
                             help='use sibling.')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument(
        '--word_embedding',
        choices=['glove', 'senna', 'sskip', 'polyglot', 'NNLM'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    # TODO: to include in logging process
    args_parser.add_argument('--pos_embedding',
                             choices=[1, 2, 4],
                             type=int,
                             help='Embedding method for korean POS tag',
                             default=2)
    args_parser.add_argument('--pos_path', help='path for pos embedding dict')
    args_parser.add_argument('--elmo',
                             action='store_true',
                             help='use elmo embedding.')
    args_parser.add_argument('--elmo_path',
                             help='path for elmo embedding model.')
    args_parser.add_argument('--elmo_dim',
                             type=int,
                             help='dimension for elmo embedding model')
    #args_parser.add_argument('--fine_tune_path', help='fine tune starting from this state_dict')
    args_parser.add_argument('--model_version',
                             help='previous model version to load')

    #hoon : bert
    args_parser.add_argument(
        '--bert', action='store_true',
        help='use elmo embedding.')  # true if use bert(hoon)
    args_parser.add_argument(
        '--etri_train',
        help='path for etri data of bert')  # etri train path(hoon)
    args_parser.add_argument(
        '--etri_dev', help='path for etri data of bert')  # etri dev path(hoon)
    args_parser.add_argument('--bert_path',
                             help='path for bert embedding model.')  # yjyj
    args_parser.add_argument('--bert_dim',
                             type=int,
                             help='dimension for bert embedding model')  # yjyj
    args_parser.add_argument('--bert_learning_rate',
                             type=float,
                             default=5e-5,
                             help='Bert Learning rate')

    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             help='decoding algorithm',
                             required=True)  #yj
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path + uid + '/'  # for numerous experiments
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    #input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    #decoder_layers = args.decoder_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    # QUESTION: pretrained vector for char?
    char_path = args.char_path

    use_pos = False
    pos_embedding = args.pos_embedding
    pos_path = args.pos_path
    pos_dict = None
    pos_dim = args.pos_dim  # NOTE pretrain 있을 경우 pos_dim은 그거 따라감
    if pos_path is not None:
        pos_dict, pos_dim = utils.load_embedding_dict(
            word_embedding,
            pos_path)  # NOTE 임시적으로 word_embedding(NNLM)이랑 같은 형식
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    use_elmo = args.elmo
    elmo_path = args.elmo_path
    elmo_dim = args.elmo_dim
    #fine_tune_path = args.fine_tune_path

    #bert(hoon)
    use_bert = args.bert
    #bert yj
    bert_path = args.bert_path
    bert_dim = args.bert_dim
    bert_lr = args.bert_learning_rate

    etri_train_path = args.etri_train
    etri_dev_path = args.etri_dev

    obj = args.objective
    decoding = args.decode

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    # min_occurence=1
    data_paths = [dev_path, test_path] if test_path else [dev_path]
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=data_paths,
        max_vocabulary_size=50000,
        pos_embedding=pos_embedding,
        embedd_dict=word_dict)

    num_words = word_alphabet.size()  # 30268
    num_chars = char_alphabet.size()  # 3545
    num_pos = pos_alphabet.size()  # 46
    num_types = type_alphabet.size()  # 39

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # data is a list of tuple containing tensors, etc ...
    data_train = conllx_stacked_data.read_stacked_data_to_variable(
        train_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=1,
        prior_order=prior_order,
        elmo=use_elmo,
        bert=use_bert,
        etri_path=etri_train_path)
    num_data = sum(data_train[2])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(
        dev_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        elmo=use_elmo,
        bert=use_bert,
        etri_path=etri_dev_path)
    if test_path:
        data_test = conllx_stacked_data.read_stacked_data_to_variable(
            test_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            pos_embedding,
            use_gpu=use_gpu,
            volatile=True,
            prior_order=prior_order,
            elmo=use_elmo)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Assemble the initial word-embedding matrix as a torch tensor.

        Each row index comes from ``word_alphabet``. A word takes its vector
        from ``word_dict`` (exact match first, then lower-cased match).
        Words found in neither form are counted as OOV and receive a zero row
        when ``freeze`` is set, otherwise a uniform sample in
        ``[-sqrt(3 / word_dim), sqrt(3 / word_dim)]``. The UNK row is
        initialised with the same fallback rule before the loop runs.

        Returns:
            A float32 tensor of shape ``(word_alphabet.size(), word_dim)``.
        """
        bound = np.sqrt(3.0 / word_dim)

        def fallback_row():
            # Frozen embeddings cannot be learned later, so unknown words get
            # zeros; trainable ones start from a small uniform sample.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound,
                                     [1, word_dim]).astype(np.float32)

        # np.empty leaves rows uninitialised; only the UNK row and the rows
        # visited in the loop below are written.
        matrix = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        # NOTE: UNK handling!
        matrix[conllx_stacked_data.UNK_ID, :] = fallback_row()

        missing = 0
        for token, row in word_alphabet.items():
            # Prefer the surface form, fall back to its lower-cased variant.
            key = token if token in word_dict else token.lower()
            if key in word_dict:
                vector = word_dict[key]
            else:
                # NOTE: words not in pretrained are set to random (or zero
                # when frozen)
                vector = fallback_row()
                missing += 1
            matrix[row, :] = vector
        print('word OOV: %d' % missing)
        return torch.from_numpy(matrix)

    def construct_char_embedding_table():
        """Assemble the initial character-embedding matrix, if pretrained
        character vectors were loaded.

        Returns:
            ``None`` when ``char_dict`` is ``None``; otherwise a float32
            tensor of shape ``(num_chars, char_dim)``. Characters present in
            ``char_dict`` use their pretrained vector, all others (and the
            UNK row) are sampled uniformly from
            ``[-sqrt(3 / char_dim), sqrt(3 / char_dim)]``.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)

        def random_row():
            return np.random.uniform(-bound, bound,
                                     [1, char_dim]).astype(np.float32)

        # np.empty leaves rows uninitialised; only the UNK row and the rows
        # visited in the loop below are written.
        matrix = np.empty([num_chars, char_dim], dtype=np.float32)
        matrix[conllx_stacked_data.UNK_ID, :] = random_row()

        missing = 0
        for symbol, row in char_alphabet.items():
            if symbol in char_dict:
                matrix[row, :] = char_dict[symbol]
                continue
            # Character has no pretrained vector: random init and count it.
            matrix[row, :] = random_row()
            missing += 1
        print('character OOV: %d' % missing)
        return torch.from_numpy(matrix)

    def construct_pos_embedding_table():
        """Assemble the initial POS-embedding matrix, if pretrained POS
        vectors were loaded.

        Returns:
            ``None`` when ``pos_dict`` is ``None``; otherwise a float32
            tensor of shape ``(num_pos, pos_dim)``. Tags present in
            ``pos_dict`` use their pretrained vector; any other tag is
            sampled uniformly from
            ``[-sqrt(3 / pos_dim), sqrt(3 / pos_dim)]``.
        """
        if pos_dict is None:
            return None

        # The table has pos_dim columns, so both the init bound and the
        # fallback row must be sized by pos_dim. Using char_dim here (as the
        # code previously did) breaks the row assignment whenever the two
        # dimensions differ and a tag is missing from pos_dict.
        scale = np.sqrt(3.0 / pos_dim)
        table = np.empty([num_pos, pos_dim], dtype=np.float32)
        for tag, index in pos_alphabet.items():
            if tag in pos_dict:
                embedding = pos_dict[tag]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, pos_dim]).astype(np.float32)
            table[index, :] = embedding
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()
    pos_table = construct_pos_embedding_table()

    window = 3

    # yj 수정
    # network = StackPtrNet(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
    #                       mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
    #                       num_types, arc_space, type_space, pos_embedding,
    #                       embedd_word=word_table, embedd_char=char_table, embedd_pos=pos_table, p_in=p_in, p_out=p_out,
    #                       p_rnn=p_rnn, biaffine=True, pos=use_pos, char=use_char, elmo=use_elmo, prior_order=prior_order,
    #                       skipConnect=skipConnect, grandPar=grandPar, sibling=sibling, elmo_path=elmo_path, elmo_dim=elmo_dim,
    #                       bert = use_bert, bert_path=bert_path, bert_dim=bert_dim)

    network = BiRecurrentConvBiAffine(word_dim,
                                      num_words,
                                      char_dim,
                                      num_chars,
                                      pos_dim,
                                      num_pos,
                                      num_filters,
                                      window,
                                      mode,
                                      hidden_size,
                                      encoder_layers,
                                      num_types,
                                      arc_space,
                                      type_space,
                                      embedd_word=word_table,
                                      embedd_char=char_table,
                                      embedd_pos=pos_table,
                                      p_in=p_in,
                                      p_out=p_out,
                                      p_rnn=p_rnn,
                                      biaffine=True,
                                      pos=use_pos,
                                      char=use_char,
                                      elmo=use_elmo,
                                      elmo_path=elmo_path,
                                      elmo_dim=elmo_dim,
                                      bert=use_bert,
                                      bert_path=bert_path,
                                      bert_dim=bert_dim)

    # if fine_tune_path is not None:
    #     pretrained_dict = torch.load(fine_tune_path)
    #     model_dict = network.state_dict()
    #     # select
    #     #model_dict['pos_embedd.weight'] = pretrained_dict['pos_embedd.weight']
    #     model_dict['word_embedd.weight'] = pretrained_dict['word_embedd.weight']
    #     #model_dict['char_embedd.weight'] = pretrained_dict['char_embedd.weight']
    #     network.load_state_dict(model_dict)

    model_ver = args.model_version
    if model_ver is not None:
        savePath = args.model_path + model_ver + 'network.pt'
        network.load_state_dict(torch.load(savePath))
        logger.info('Load model: %s' % (model_ver))

    def save_args():
        """Persist the model hyper-parameters next to the checkpoint.

        Writes two files derived from ``model_name``:

        * ``<model_name>.arg.json`` -- a JSON object with ``args`` (list of
          positional values) and ``kwargs`` (keyword values).
        * ``<model_name>.arg.json.raw_args`` -- ``str(args)`` of the parsed
          command line, for reference.
        """
        arg_path = model_name + '.arg.json'
        # NOTE(review): this list ends with pos_embedding, which the
        # BiRecurrentConvBiAffine call in this function does not pass
        # positionally -- confirm against whatever reloads this file.
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, encoder_layers, num_types,
            arc_space, type_space, pos_embedding
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'elmo': use_elmo,
            'bert': use_bert
        }
        # Use a context manager so the JSON file is flushed and closed
        # deterministically instead of relying on garbage collection.
        with open(arg_path, 'w', encoding="utf-8") as f:
            json.dump({'args': arguments, 'kwargs': kwargs}, f, indent=4)

        with open(arg_path + '.raw_args', 'w', encoding="utf-8") as f:
            f.write(str(args))

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    def generate_optimizer(opt, lr, params):
        """Create an optimizer over the given named parameters.

        Args:
            opt: optimizer name ('adam', 'sgd' or 'adamax'). Currently
                ignored because AdamW is forced (see below).
            lr: learning rate handed to the optimizer.
            params: iterable of ``(name, parameter)`` pairs; only the
                parameters are used.

        Returns:
            An ``AdamW`` instance configured with the enclosing ``betas``
            and ``gamma`` (as weight decay).

        Raises:
            ValueError: for an unknown ``opt`` -- only reachable once the
                AdamW override is switched off.
        """
        # params = [param for name, param in params if param.requires_grad]
        tensors = [tensor for _, tensor in params]

        # AdamW is hard-wired for now; the name-based selection below is
        # kept for reference and only runs if this switch is turned off.
        force_adamw = True
        if force_adamw:
            return AdamW(tensors, lr=lr, betas=betas, weight_decay=gamma)

        adaptive_kwargs = dict(lr=lr,
                               betas=betas,
                               weight_decay=gamma,
                               eps=eps)
        if opt == 'adam':
            return Adam(tensors, **adaptive_kwargs)
        if opt == 'sgd':
            return SGD(tensors,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        if opt == 'adamax':
            return Adamax(tensors, **adaptive_kwargs)
        raise ValueError('Unknown optimization algorithm: %s' % opt)

    # For now, switched to the default Hugging Face BERT optimizer options
    def generate_bert_optimizer(t_total, bert_lr, model):
        """Create an AdamW optimizer and a linear schedule for ``model``.

        Parameters whose name contains ``'bias'`` or ``'LayerNorm.weight'``
        are put in a group with ``weight_decay=0.0``; every other parameter
        uses the enclosing ``gamma`` as weight decay.

        Args:
            t_total: forwarded to ``WarmupLinearSchedule`` as ``t_total``
                (presumably the total number of optimisation steps --
                confirm against the caller).
            bert_lr: learning rate for AdamW.
            model: module exposing ``named_parameters()``.

        Returns:
            A ``(scheduler, optimizer)`` tuple -- note the order.
        """
        decay_exempt = ('bias', 'LayerNorm.weight')

        # Split the parameters in a single pass, keeping their original
        # relative order inside each group.
        decayed, exempt = [], []
        for param_name, tensor in model.named_parameters():
            if any(marker in param_name for marker in decay_exempt):
                exempt.append(tensor)
            else:
                decayed.append(tensor)

        param_groups = [
            {'params': decayed, 'weight_decay': gamma},
            {'params': exempt, 'weight_decay': 0.0},
        ]
        optimizer = AdamW(param_groups, lr=bert_lr, eps=1e-8)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=0,
                                         t_total=t_total)
        return scheduler, optimizer

    lr = learning_rate
    if use_bert:
        scheduler, optim = generate_bert_optimizer(
            len(data_train) * num_epochs, lr, network)
    #optim = generate_optimizer(opt, lr, network.named_parameters())

    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    #logger.info("RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, encoder_layers, decoder_layers, input_size_decoder, hidden_size, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' %
                (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d' % (skipConnect, beam))
    logger.info(opt_info)

    num_batches = int(num_data / batch_size + 1)  # kwon
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0

        network.train()
        for batch in range(1, num_batches + 1):
            # load data
            input_encoder, _ = conllx_stacked_data.get_batch_stacked_variable(
                data_train,
                batch_size,
                pos_embedding,
                unk_replace=unk_replace,
                elmo=use_elmo,
                bert=use_bert)

            word_elmo = None
            if use_elmo:
                word, char, pos, heads, types, masks, lengths, word_elmo, word_bert = input_encoder
            else:
                word, char, pos, heads, types, masks, lengths, word_bert = input_encoder

            #stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()

            # yjyj
            loss_arc, loss_type, bert_word_feature_ids, bert_morp_feature_ids = network.loss(
                word,
                char,
                pos,
                heads,
                types,
                mask=masks,
                length=lengths,
                input_word_bert=word_bert)

            # loss_arc_leaf, loss_arc_non_leaf, \
            # loss_type_leaf, loss_type_non_leaf, \
            # loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling, stacked_types, label_smooth, skip_connect=skip_connect, mask_e=masks_e, \
            #                                                 length_e=lengths_e, mask_d=masks_d, length_d=lengths_d, input_word_elmo = word_elmo, input_word_bert = word_bert)

            # loss_arc = loss_arc_leaf + loss_arc_non_leaf
            # loss_type = loss_type_leaf + loss_type_non_leaf
            # loss = loss_arc + loss_type + cov * loss_cov    # cov is set to 0 by default
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()
            if use_bert:
                pass
                #bert_optim.step()
                #scheduler.step()

            num_inst = word.size(
                0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.item() * num_inst
            train_err_arc += loss_arc.item() * num_inst
            train_err_type += loss_type.item() * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # yjyj
            # num_leaf = num_leaf.item()
            # num_non_leaf = num_non_leaf.item()
            # train_err_arc_leaf += loss_arc_leaf.item() * num_leaf
            # train_err_arc_non_leaf += loss_arc_non_leaf.item() * num_non_leaf
            #
            # train_err_type_leaf += loss_type_leaf.item() * num_leaf
            # train_err_type_non_leaf += loss_type_non_leaf.item() * num_non_leaf
            #
            # train_err_cov += loss_cov.item() * (num_leaf + num_non_leaf)
            # train_total_leaf += num_leaf
            # train_total_non_leaf += num_non_leaf
            #
            # time_ave = (time.time() - start_time) / batch
            # time_left = (num_batches - batch) * time_ave

            # update log

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (
                    batch, num_batches, train_err / train_total, train_err_arc
                    / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print(
            'train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' %
            (num_batches, train_err / train_total, train_err_arc / train_total,
             train_err_type / train_total, time.time() - start_time))
        # yjyj
        #     if batch % 10 == 0:
        #         sys.stdout.write("\b" * num_back)
        #         sys.stdout.write(" " * num_back)
        #         sys.stdout.write("\b" * num_back)
        #         err_arc_leaf = train_err_arc_leaf / train_total_leaf
        #         err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        #         err_arc = err_arc_leaf + err_arc_non_leaf
        #
        #         err_type_leaf = train_err_type_leaf / train_total_leaf
        #         err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        #         err_type = err_type_leaf + err_type_non_leaf
        #
        #         err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)
        #
        #         err = err_arc + err_type + cov * err_cov
        #         log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
        #             batch, num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type, err_type_leaf, err_type_non_leaf, err_cov, time_left)
        #         sys.stdout.write(log_info)
        #         sys.stdout.flush()
        #         num_back = len(log_info)
        #
        # sys.stdout.write("\b" * num_back)
        # sys.stdout.write(" " * num_back)
        # sys.stdout.write("\b" * num_back)
        # err_arc_leaf = train_err_arc_leaf / train_total_leaf
        # err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        # err_arc = err_arc_leaf + err_arc_non_leaf
        #
        # err_type_leaf = train_err_type_leaf / train_total_leaf
        # err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        # err_type = err_type_leaf + err_type_non_leaf
        #
        # err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)
        #
        # err = err_arc + err_type + cov * err_cov
        # print('train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs' % (
        #     num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf, err_type, err_type_leaf, err_type_non_leaf, err_cov, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = model_path + 'tmp/pred_dev%d' % (epoch)
        pred_writer.start(pred_filename)
        gold_filename = model_path + 'tmp/gold_dev%d' % (epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size, pos_embedding, type='dev',
                elmo=use_elmo):
            input_encoder, _ = batch
            #@TODO 여기 input word elmo랑 input word bert 처리
            if use_elmo:
                word, char, pos, heads, types, masks, lengths, word_elmo, word_bert = input_encoder
                heads_pred, types_pred = decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)
                # heads_pred, types_pred, _, _ = network.decode(word, char, pos, input_word_elmo=word_elmo, mask=masks,
                #                                               length=lengths, beam=beam,
                #                                               leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS, input_word_bert=word_bert)
            else:
                word, char, pos, heads, types, masks, lengths, word_bert = input_encoder
                heads_pred, types_pred, bert_word_feature_ids, bert_morp_feature_ids = decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS,
                    input_word_bert=word_bert)
                # heads_pred, types_pred, _, _ = network.decode(word, char, pos, mask=masks, length=lengths, beam=beam,
                #                                               leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS, input_word_bert=word_bert)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser_bpe.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True,
                bert_word_feature_ids=bert_word_feature_ids,
                bert_morp_feature_ids=bert_morp_feature_ids)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_ucorrect_nopunc * 1.5 + dev_lcorrect_nopunc < dev_ucorr_nopunc * 1.5 + dev_lcorr_nopunc:
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)
            # save embedding to txt
            # FIXME format!
            #with open(model_path + 'embedding.txt', 'w') as f:
            #    for word, idx in word_alphabet.items():
            #        embedding = network.word_embedd.weight[idx, :]
            #        f.write('{}\t{}\n'.format(word, embedding))

            if test_path:
                pred_filename = model_path + 'tmp/%spred_test%d' % (str(uid),
                                                                    epoch)
                pred_writer.start(pred_filename)
                gold_filename = model_path + 'tmp/%sgold_test%d' % (str(uid),
                                                                    epoch)
                gold_writer.start(gold_filename)

                test_ucorrect = 0.0
                test_lcorrect = 0.0
                test_ucomlpete_match = 0.0
                test_lcomplete_match = 0.0
                test_total = 0

                test_ucorrect_nopunc = 0.0
                test_lcorrect_nopunc = 0.0
                test_ucomlpete_match_nopunc = 0.0
                test_lcomplete_match_nopunc = 0.0
                test_total_nopunc = 0
                test_total_inst = 0

                test_root_correct = 0.0
                test_total_root = 0
                for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                        data_test, batch_size, pos_embedding, type='dev'):
                    input_encoder, _ = batch
                    word, char, pos, heads, types, masks, lengths = input_encoder

                    # yjyj
                    # heads_pred, types_pred, _, _ = network.decode(word, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)
                    heads_pred, types_pred = decode(
                        word,
                        char,
                        pos,
                        mask=masks,
                        length=lengths,
                        leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)
                    word = word.data.cpu().numpy()
                    pos = pos.data.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.data.cpu().numpy()
                    types = types.data.cpu().numpy()

                    pred_writer.write(word,
                                      pos,
                                      heads_pred,
                                      types_pred,
                                      lengths,
                                      symbolic_root=True)
                    gold_writer.write(word,
                                      pos,
                                      heads,
                                      types,
                                      lengths,
                                      symbolic_root=True)

                    stats, stats_nopunc, stats_root, num_inst = parser_bpe.eval(
                        word,
                        pos,
                        heads_pred,
                        types_pred,
                        heads,
                        types,
                        word_alphabet,
                        pos_alphabet,
                        lengths,
                        punct_set=punct_set,
                        symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    test_ucorrect += ucorr
                    test_lcorrect += lcorr
                    test_total += total
                    test_ucomlpete_match += ucm
                    test_lcomplete_match += lcm

                    test_ucorrect_nopunc += ucorr_nopunc
                    test_lcorrect_nopunc += lcorr_nopunc
                    test_total_nopunc += total_nopunc
                    test_ucomlpete_match_nopunc += ucm_nopunc
                    test_lcomplete_match_nopunc += lcm_nopunc

                    test_root_correct += corr_root
                    test_total_root += total_root

                    test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                # = generate_optimizer(opt, lr, network.named_parameters())
                optim = generate_bert_optimizer(opt, lr, network)
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        if test_path:
            print(
                'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (test_ucorrect, test_lcorrect, test_total, test_ucorrect *
                   100 / test_total, test_lcorrect * 100 / test_total,
                   test_ucomlpete_match * 100 / test_total_inst,
                   test_lcomplete_match * 100 / test_total_inst, best_epoch))
            print(
                'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                %
                (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                 test_ucorrect_nopunc * 100 / test_total_nopunc,
                 test_lcorrect_nopunc * 100 / test_total_nopunc,
                 test_ucomlpete_match_nopunc * 100 / test_total_inst,
                 test_lcomplete_match_nopunc * 100 / test_total_inst,
                 best_epoch))
            print(
                'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (test_root_correct, test_total_root,
                   test_root_correct * 100 / test_total_root, best_epoch))
            print(
                '============================================================================================================================'
            )

        if decay == max_decay:
            break

    def save_result():
        """Write the best dev-set scores to ``<model_name>.result.txt``.

        Three lines are written: scores with punctuation, scores without
        punctuation, and root accuracy. All values are read from the
        enclosing scope (the ``dev_*`` counters and ``best_epoch``).
        """
        result_path = model_name + '.result.txt'
        best_dev_Punc = 'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total,
            dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst,
            dev_lcomplete_match * 100 / dev_total_inst, best_epoch)
        best_dev_noPunc = 'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc,
            dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
            dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch)
        best_dev_Root = 'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root,
            dev_root_correct * 100 / dev_total_root, best_epoch)
        # The summary lines are plain ASCII, so they are written as text
        # directly. Wrapping them in str(x.encode('utf-8')) would emit the
        # bytes repr ("b'...'") under Python 3 instead of the text itself.
        # The context manager closes the file even if a write fails.
        with open(result_path, 'w') as result_file:
            result_file.write(best_dev_Punc + '\n')
            result_file.write(best_dev_noPunc + '\n')
            result_file.write(best_dev_Root)

    save_result()
Ejemplo n.º 11
0
def biaffine(model_path, model_name, pre_model_path, pre_model_name, use_gpu, logger, args):
    """Train or evaluate a language classifier on top of a frozen parser.

    A saved ``BiRecurrentConvBiAffine`` network is loaded from
    ``pre_model_path``, its word embeddings are extended with extra
    vectors, and all of its parameters are frozen. A two-layer linear
    classifier is then applied to the network's ``'output'`` tensor.

    * If ``args.test_lang`` is set, classifier weights are loaded from
      ``model_path/model_name`` and accuracy on that language is logged.
    * Otherwise the classifier is trained on ``args.langs`` with Adam,
      the checkpoint with the best average dev accuracy is saved to
      ``model_path/model_name``, training stops after 5 epochs without
      improvement, and per-language test accuracy is printed.

    Args:
        model_path: Directory of the classifier checkpoint.
        model_name: File name of the classifier checkpoint.
        pre_model_path: Directory holding the parser and its ``alphabets/``.
        pre_model_name: File name of the saved parser state dict.
        use_gpu: Move models and batches to CUDA when true.
        logger: Object providing ``info`` for progress messages.
        args: Options namespace; reads ``test_lang``, ``langs``,
            ``embed_dir``, ``data_dir``, ``nclass``, ``batch_size``,
            ``train_level``, ``src_lang``, ``model_path`` and
            ``num_epochs``, and sets ``use_bert``.
    """
    alphabet_path = os.path.join(pre_model_path, 'alphabets/')
    logger.info("Alphabet Path: %s" % alphabet_path)
    pre_model_name = os.path.join(pre_model_path, pre_model_name)
    model_name = os.path.join(model_path, model_name)

    # Load pre-created alphabets
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path, None, data_paths=[None, None], max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info('use gpu: %s' % (use_gpu))

    # Extend the word alphabet with words found in the data files and
    # collect their embedding vectors, in the order they were added.
    if args.test_lang:
        extra_embed = args.embed_dir + ("wiki.multi.%s.vec" % args.test_lang)
        extra_word_dict, _ = load_embedding_dict('word2vec', extra_embed)
        test_path = args.data_dir + args.test_lang + '_test.conllu'
        extra_embeds_arr = augment_with_extra_embedding(word_alphabet, extra_word_dict, test_path, logger)
    else:
        extra_embeds_arr = []
        for language in args.langs:
            extra_embed = args.embed_dir + ("wiki.multi.%s.vec" % language)
            extra_word_dict, _ = load_embedding_dict('word2vec', extra_embed)
            # Same order as before: train, then dev, then test.
            for suffix in ('_train.conllu', '_dev.conllu', '_test.conllu'):
                split_path = args.data_dir + language + suffix
                extra_embeds_arr.extend(
                    augment_with_extra_embedding(word_alphabet, extra_word_dict, split_path, logger))

    # ------------------------------------------------------------------------- #
    # --------------------- Loading model ------------------------------------- #

    arg_path = pre_model_name + '.arg.json'
    # Use a context manager so the JSON file handle is always released.
    with open(arg_path, 'r') as arg_file:
        arguments = json.load(arg_file)
    margs, kwargs = arguments['args'], arguments['kwargs']

    network = BiRecurrentConvBiAffine(use_gpu=use_gpu, *margs, **kwargs)
    network.load_state_dict(torch.load(pre_model_name))
    args.use_bert = kwargs.get('use_bert', False)

    # Grow the embedding table to cover the words added to the alphabet.
    augment_network_embed(word_alphabet.size(), network, extra_embeds_arr)

    network.eval()
    logger.info('model: %s' % pre_model_name)

    # Freeze the network
    for p in network.parameters():
        p.requires_grad = False

    nclass = args.nclass
    classifier = nn.Sequential(
        nn.Linear(network.encoder.output_dim, 512),
        nn.Linear(512, nclass)
    )

    if use_gpu:
        network.cuda()
        classifier.cuda()
    else:
        network.cpu()
        classifier.cpu()

    batch_size = args.batch_size

    # ===== the reading
    def _read_one(path, is_train=False, max_size=None):
        """Read one CoNLL-U file into the variable format used for batching."""
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." % (path, lang_id))
        one_data = conllx_data.read_data_to_variable(path, word_alphabet, char_alphabet, pos_alphabet,
                                                     type_alphabet, use_gpu=use_gpu, volatile=(not is_train),
                                                     use_bert=args.use_bert, symbolic_root=True, lang_id=lang_id,
                                                     max_size=max_size)
        return one_data

    def _classify_batch(batch):
        """Run the frozen network and the classifier on one batch.

        Returns a 2-D tensor of class scores. With
        ``args.train_level == 'word'`` the classifier output is flattened
        to one row per position; otherwise the network output is averaged
        over dim 1 (presumably the token axis -- TODO confirm) before
        classification.
        """
        word, char, pos, _, _, masks, lengths, bert_inputs = batch
        if use_gpu:
            word = word.cuda()
            char = char.cuda()
            pos = pos.cuda()
            masks = masks.cuda()
            lengths = lengths.cuda()
            if bert_inputs[0] is not None:
                bert_inputs[0] = bert_inputs[0].cuda()
                bert_inputs[1] = bert_inputs[1].cuda()
                bert_inputs[2] = bert_inputs[2].cuda()

        output = network.forward(word, char, pos, input_bert=bert_inputs,
                                 mask=masks, length=lengths, hx=None)
        # The parser is frozen; detach so no graph is kept through it.
        output = output['output'].detach()

        if args.train_level == 'word':
            output = classifier(output)
            output = output.contiguous().view(-1, output.size(2))
        else:
            output = torch.mean(output, dim=1)
            output = classifier(output)
        return output

    def compute_accuracy(data, lang_idx):
        """Count classifier predictions equal to ``lang_idx`` over ``data``.

        Returns a dict with ``'total_corr'`` and ``'total'``.
        """
        total_corr, total = 0, 0
        classifier.eval()
        with torch.no_grad():
            for batch in conllx_data.iterate_batch_variable(data, batch_size):
                output = _classify_batch(batch)

                preds = output.max(1)[1].cpu()
                # Every row of the batch carries the same gold language label.
                labels = torch.LongTensor([lang_idx])
                labels = labels.expand(*preds.size())
                n_correct = preds.eq(labels).sum().item()
                total_corr += n_correct
                total += output.size(0)

            return {'total_corr': total_corr, 'total': total}

    if args.test_lang:
        classifier.load_state_dict(torch.load(model_name))
        # NOTE(review): evaluation reads the '_train.conllu' split although
        # the vocabulary above was augmented from '_test.conllu' -- confirm
        # this is intended.
        path = args.data_dir + args.test_lang + '_train.conllu'
        test_data = _read_one(path)

        # TODO: fixed indexing is not GOOD
        lang_idx = 0 if args.test_lang == args.src_lang else 1
        result = compute_accuracy(test_data, lang_idx)
        accuracy = (result['total_corr'] * 100.0) / result['total']
        logger.info('[Classifier performance] Language: %s || accuracy: %.2f%%' % (args.test_lang, accuracy))

    else:
        # if output directory doesn't exist, create it
        if not os.path.exists(args.model_path):
            os.makedirs(args.model_path)

        # --------------------- Loading data -------------------------------------- #
        train_data = dict()
        dev_data = dict()
        test_data = dict()
        num_data = dict()
        lang_ids = dict()
        reverse_lang_ids = dict()

        # loading language data
        for language in args.langs:
            # Class indices follow the order of args.langs.
            lang_ids[language] = len(lang_ids)
            reverse_lang_ids[lang_ids[language]] = language

            train_path = args.data_dir + language + '_train.conllu'
            # Utilize at most 10000 examples
            tmp_data = _read_one(train_path, max_size=10000)
            num_data[language] = sum(tmp_data[1])
            train_data[language] = tmp_data

            dev_path = args.data_dir + language + '_dev.conllu'
            dev_data[language] = _read_one(dev_path)

            test_path = args.data_dir + language + '_test.conllu'
            test_data[language] = _read_one(test_path)

        # ------------------------------------------------------------------------- #

        optim = torch.optim.Adam(classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        def compute_loss(lang_name, lang_idx):
            """Cross-entropy loss of one training batch drawn from ``lang_name``."""
            batch = conllx_data.get_batch_variable(train_data[lang_name],
                                                   batch_size,
                                                   unk_replace=0.5)
            output = _classify_batch(batch)
            # Every row of the batch carries the same gold language label.
            labels = torch.empty(output.size(0)).fill_(lang_idx).type_as(output).long()
            loss = criterion(output, labels)
            return loss

        # ---------------------- Form the mini-batches -------------------------- #
        num_batches = 0
        batch_lang_labels = []
        for lang in args.langs:
            nbatches = num_data[lang] // batch_size + 1
            batch_lang_labels.extend([lang] * nbatches)
            num_batches += nbatches

        assert len(batch_lang_labels) == num_batches
        # ------------------------------------------------------------------------- #

        best_dev_accuracy = 0
        patience = 0
        for epoch in range(1, args.num_epochs + 1):
            # shuffling the data
            lang_in_batch = copy.copy(batch_lang_labels)
            random.shuffle(lang_in_batch)

            classifier.train()
            for lang_name in lang_in_batch:
                lang_id = lang_ids.get(lang_name)

                # Clear gradients left by the previous step; otherwise
                # backward() keeps summing into .grad and each update uses
                # the accumulated gradient of all earlier batches.
                optim.zero_grad()
                loss = compute_loss(lang_name, lang_id)
                loss.backward()
                optim.step()

            # Validation
            avg_acc = dict()
            for dev_lang in dev_data.keys():
                lang_idx = lang_ids.get(dev_lang)
                result = compute_accuracy(dev_data[dev_lang], lang_idx)
                accuracy = (result['total_corr'] * 100.0) / result['total']
                avg_acc[dev_lang] = accuracy

            acc = ', '.join('%s: %.2f' % (key, val) for (key, val) in avg_acc.items())
            logger.info('Epoch: %d, Performance[%s]' % (epoch, acc))

            avg_acc = sum(avg_acc.values()) / len(avg_acc)
            if best_dev_accuracy < avg_acc:
                best_dev_accuracy = avg_acc
                patience = 0
                state_dict = classifier.state_dict()
                torch.save(state_dict, model_name)
            else:
                patience += 1

            # Early stopping: 5 consecutive epochs without dev improvement.
            if patience >= 5:
                break

        # Testing
        logger.info('Testing model %s' % pre_model_name)
        total_corr, total = 0, 0
        for test_lang in UD_languages:
            if test_lang in test_data:
                lang_idx = lang_ids.get(test_lang)
                result = compute_accuracy(test_data[test_lang], lang_idx)
                accuracy = (result['total_corr'] * 100.0) / result['total']
                print('[LANG]: %s, [ACC]: %.2f' % (test_lang.upper(), accuracy))
                total_corr += result['total_corr']
                total += result['total']
        if total:
            print('[Avg. Performance]: %.2f' % ((total_corr * 100.0) / total))
        else:
            # None of the trained languages appear in UD_languages, so
            # there is nothing to average (avoids a ZeroDivisionError).
            logger.info('No test language was evaluated; average performance is undefined.')
Ejemplo n.º 12
0
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger,
             args):
    """Decode ``test_path`` with a saved StackPtrNet and write the predictions.

    Alphabets are loaded from ``model_path/alphabets/`` and the network's
    constructor arguments from ``<model>.arg.json``. The word-embedding
    table is rebuilt from ``model_path/embedding.txt`` rather than taken
    from the checkpoint. Sentences are decoded one per batch and written
    to ``model_path/tmp/inference.txt``, or to
    ``model_path/tmp/inference_ordered_temp.txt`` when ``args.ordered``
    is set.

    Args:
        model_path: Directory with the model, alphabets and embeddings.
        model_name: File name of the saved state dict inside ``model_path``.
        test_path: Data file to decode.
        punct_set: Unused here; kept for interface compatibility.
        use_gpu: Run the network on CUDA when true.
        logger: Object providing ``info`` for progress messages.
        args: Options namespace; reads ``pos_embedding``, ``beam`` and
            ``ordered``.
    """
    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path,
        None,
        pos_embedding,
        data_paths=[None, None],
        max_vocabulary_size=50000,
        embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered

    # Constructor arguments saved with the model. They get their own name
    # so the ``args`` options namespace is not shadowed.
    arg_path = model_name + '.arg.json'
    with open(arg_path, 'r') as arg_file:
        arguments = json.load(arg_file)
    model_args, kwargs = arguments['args'], arguments['kwargs']

    prior_order = kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' %
                (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path,
        word_alphabet,
        char_alphabet,
        pos_alphabet,
        type_alphabet,
        pos_embedding,
        use_gpu=use_gpu,
        volatile=True,
        prior_order=prior_order,
        is_test=True)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    # Supply a freshly built embedding table to the network via kwargs.
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    def get_embedding_table():
        """Stack the vectors of ``word_dict`` into a tensor, in dict order."""
        # Zero-initialise so a row whose vector cannot be stored is
        # well-defined rather than uninitialised memory (np.empty).
        table = np.zeros([len(word_dict), word_dim])
        for idx, (word, embedding) in enumerate(word_dict.items()):
            try:
                table[idx, :] = embedding
            except (ValueError, TypeError):
                # Vector with an unexpected shape or type: report the word
                # and leave its row as zeros.
                print(word)
        return torch.from_numpy(table)

    word_table = get_embedding_table()
    kwargs['embedd_word'] = word_table
    # Second positional argument is overridden with the number of entries
    # in the embedding file (not the embedding dimension).
    model_args[1] = len(word_dict)
    network = StackPtrNet(*model_args, **kwargs)
    # Load every saved weight except the word embeddings, which come from
    # the table built above.
    model_dict = network.state_dict()
    pretrained_dict = torch.load(model_name)
    model_dict.update({
        k: v
        for k, v in pretrained_dict.items() if k != 'word_embedd.weight'
    })

    network.load_state_dict(model_dict)

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    if not ordered:
        pred_writer.start(model_path + '/tmp/inference.txt')
    else:
        pred_writer.start(model_path + '/tmp/inference_ordered_temp.txt')
    try:
        batches = conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1, pos_embedding, type='dev')
        for sent, batch in enumerate(batches):
            # Progress indicator: index of the sentence being decoded.
            sys.stdout.write('%d, ' % sent)
            sys.stdout.flush()

            # Only the encoder inputs are needed for prediction.
            input_encoder, _ = batch
            word, char, pos, _, _, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                ordered=ordered,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
    finally:
        # Close the output file even if decoding fails part-way.
        pred_writer.close()
Ejemplo n.º 13
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=1000,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder.')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             default='mst',
                             help='decoding algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    # args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    #
    args_parser.add_argument('--no_word',
                             action='store_true',
                             help='do not use word embedding.')
    #
    # lrate schedule with warmup in the first iter.
    args_parser.add_argument('--use_warmup_schedule',
                             action='store_true',
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    # Tansformer encoder
    args_parser.add_argument('--no_CoRNN',
                             action='store_true',
                             help='do not use context RNN.')
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--multi_head_attn',
                             action='store_true',
                             help='use multi-head-attention.')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    # - positional
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=50,
                             help='Dimension of Position embeddings.')
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    #
    args = args_parser.parse_args()

    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    logger = get_logger("GraphParser")

    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    #
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    # todo(warn): exactly same for loading vocabs
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)

    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # ===== the reading
    def _read_one(path, is_train):
        lang_id = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, lang_id))
        one_data = conllx_data.read_data_to_variable(
            path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=use_gpu,
            volatile=(not is_train),
            symbolic_root=True,
            lang_id=lang_id,
            len_thresh=(args.train_len_thresh if is_train else 100000))
        return one_data

    data_train = _read_one(train_path, True)
    num_data = sum(data_train[1])

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    # =====

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(
            word_dim,
            num_words,
            char_dim,
            num_chars,
            pos_dim,
            num_pos,
            num_filters,
            window,
            mode,
            hidden_size,
            num_layers,
            num_types,
            arc_space,
            type_space,
            embedd_word=word_table,
            embedd_char=char_table,
            p_in=p_in,
            p_out=p_out,
            p_rnn=p_rnn,
            biaffine=True,
            pos=use_pos,
            char=use_char,
            train_position=args.train_position,
            use_con_rnn=(not args.no_CoRNN),
            trans_hid_size=args.trans_hid_size,
            d_k=args.d_k,
            d_v=args.d_v,
            multi_head_attn=args.multi_head_attn,
            num_head=args.num_head,
            enc_use_neg_dist=args.enc_use_neg_dist,
            enc_clip_dist=args.enc_clip_dist,
            position_dim=args.position_dim,
            max_sent_length=max_sent_length,
            use_gpu=use_gpu,
            no_word=args.no_word)

    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, num_layers, num_types,
            arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'train_position': args.train_position,
            'use_con_rnn': (not args.no_CoRNN),
            'trans_hid_size': args.trans_hid_size,
            'd_k': args.d_k,
            'd_v': args.d_v,
            'multi_head_attn': args.multi_head_attn,
            'num_head': args.num_head,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'enc_clip_dist': args.enc_clip_dist,
            'position_dim': args.position_dim,
            'max_sent_length': max_sent_length,
            'no_word': args.no_word
        }
        json.dump({
            'args': arguments,
            'kwargs': kwargs
        },
                  open(arg_path, 'w'),
                  indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params,
                        lr=lr,
                        betas=betas,
                        weight_decay=gamma,
                        eps=eps)
        elif opt == 'sgd':
            return SGD(params,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        elif opt == 'adamax':
            return Adamax(params,
                          lr=lr,
                          betas=betas,
                          weight_decay=gamma,
                          eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" %
        (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info(
        "train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)"
        % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule
    warmup_factor = (lr + 0.) / num_batches

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))
    #

    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate
            #
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word,
                                               char,
                                               pos,
                                               heads,
                                               types,
                                               mask=masks,
                                               length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_inst = word.size(
                0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (
                    batch, num_batches, train_err / train_total, train_err_arc
                    / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print(
            'train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' %
            (num_batches, train_err / train_total, train_err_arc / train_total,
             train_err_type / train_total, time.time() - start_time))

        ################################################################################################
        if epoch % args.check_dev != 0:
            continue

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_data.iterate_batch_variable(
                    data_test, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())

                if decoding == 'greedy':
                    decode = network.decode
                elif decoding == 'mst':
                    decode = network.decode_mst
                else:
                    raise ValueError('Unknown decoding algorithm: %s' %
                                     decoding)

                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        print(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (dev_root_correct, dev_total_root,
               dev_root_correct * 100 / dev_total_root, best_epoch))
        print(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        print(
            'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
               test_total, test_lcorrect * 100 / test_total,
               test_ucomlpete_match * 100 / test_total_inst,
               test_lcomplete_match * 100 / test_total_inst, best_epoch))
        print(
            'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            %
            (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
             test_ucorrect_nopunc * 100 / test_total_nopunc,
             test_lcorrect_nopunc * 100 / test_total_nopunc,
             test_ucomlpete_match_nopunc * 100 / test_total_inst,
             test_lcomplete_match_nopunc * 100 / test_total_inst, best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
              (test_root_correct, test_total_root,
               test_root_correct * 100 / test_total_root, best_epoch))
        print(
            '============================================================================================================================'
        )

        if decay == max_decay:
            break
Ejemplo n.º 14
0
def stackptr(model_path, model_name, test_path, punct_set, use_gpu, logger, args):
    """Run a saved StackPtrNet parser over ``test_path`` and report its scores.

    The alphabets, the ``<model_name>.arg.json`` constructor arguments and the
    weights are all read from ``model_path``. Predictions are written to
    ``inference.txt`` (or ``RL_B[test].txt`` when ``args.ordered`` is set)
    inside ``model_path``. The ``stats_nopunc`` figures returned by
    ``parser.eval`` are accumulated and printed every 100 sentences and once
    more at the end, together with the elapsed decoding time.

    Args:
        model_path: Directory holding ``alphabets/``, ``embedding.txt``, the
            weights file and its ``.arg.json`` companion.
        model_name: File name of the weights, relative to ``model_path``.
        test_path: Data file to parse.
        punct_set: Forwarded to ``parser.eval`` as ``punct_set``.
        use_gpu: Move the network to CUDA when true, keep it on CPU otherwise.
        logger: Logger used for the informational messages.
        args: Parsed command-line namespace; reads ``pos_embedding``, ``beam``,
            ``ordered``, ``bert``, ``bert_path``, ``bert_feature_dim`` and,
            only when BERT is enabled, ``etri_test``.
    """
    score_format = ('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, '
                    'las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%')

    def _pct(count, total):
        # The totals stay at zero for an empty test set; report 0% rather
        # than raising ZeroDivisionError.
        return count * 100 / total if total else 0.0

    def _format_scores(ucorr, lcorr, total, ucm, lcm, num_sentences):
        # Token-level scores are normalised by the token total, the
        # complete-match scores by the number of evaluated instances.
        return score_format % (ucorr, lcorr, total,
                               _pct(ucorr, total), _pct(lcorr, total),
                               _pct(ucm, num_sentences), _pct(lcm, num_sentences))

    pos_embedding = args.pos_embedding
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_stacked_data.create_alphabets(
        alphabet_path, None, pos_embedding, data_paths=[None, None],
        max_vocabulary_size=50000, embedd_dict=None)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    beam = args.beam
    ordered = args.ordered
    use_bert = args.bert
    bert_path = args.bert_path
    bert_feature_dim = args.bert_feature_dim
    # ``etri_test`` is only looked up when BERT features are requested.
    etri_test_path = args.etri_test if use_bert else None

    # Constructor arguments stored next to the weights. They get their own
    # names so the ``args`` namespace parameter is not shadowed.
    arg_path = model_name + '.arg.json'
    with open(arg_path, 'r') as arg_file:
        saved_arguments = json.load(arg_file)
    model_args = saved_arguments['args']
    model_kwargs = saved_arguments['kwargs']

    prior_order = model_kwargs['prior_order']
    logger.info('use gpu: %s, beam: %d, order: %s (%s)' % (use_gpu, beam, prior_order, ordered))

    data_test = conllx_stacked_data.read_stacked_data_to_variable(
        test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding,
        use_gpu=use_gpu, volatile=True, prior_order=prior_order, is_test=False,
        bert=use_bert, etri_path=etri_test_path)

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet, pos_embedding)

    logger.info('model: %s' % model_name)
    word_path = os.path.join(model_path, 'embedding.txt')
    word_dict, word_dim = utils.load_embedding_dict('NNLM', word_path)

    def construct_word_embedding_table():
        # One row per alphabet entry: the embedding found under the exact or
        # lower-cased word, otherwise a uniform random vector (counted as OOV).
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    # Called for its 'word OOV' report only: the resulting table is not handed
    # to the network, whose weights all come from the checkpoint loaded below.
    construct_word_embedding_table()

    network = StackPtrNet(*model_args, **model_kwargs, bert=use_bert, bert_path=bert_path,
                          bert_feature_dim=bert_feature_dim)
    # When running on CPU, remap every storage to CPU so a checkpoint that was
    # saved from a CUDA model can still be deserialised.
    map_location = None if use_gpu else (lambda storage, location: storage)
    network.load_state_dict(torch.load(model_name, map_location=map_location))

    if use_gpu:
        network.cuda()
    else:
        network.cpu()

    network.eval()

    if not ordered:
        pred_writer.start(model_path + '/inference.txt')
    else:
        pred_writer.start(model_path + '/RL_B[test].txt')

    dev_ucorr_nopunc = 0.0
    dev_lcorr_nopunc = 0.0
    dev_total_nopunc = 0
    dev_ucomlpete_nopunc = 0.0
    dev_lcomplete_nopunc = 0.0
    dev_total_inst = 0.0
    sys.stdout.write('Start!\n')
    start_time = time.time()
    try:
        batches = conllx_stacked_data.iterate_batch_stacked_variable(
            data_test, 1, pos_embedding, type='dev', bert=use_bert)
        for sent, batch in enumerate(batches, 1):
            if sent % 100 == 0:
                # Running scores over the sentences evaluated so far.
                print(_format_scores(dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                                     dev_ucomlpete_nopunc, dev_lcomplete_nopunc, dev_total_inst))
                sys.stdout.write('[%d/%d]\n' % (sent, int(data_test[2][0])))
            sys.stdout.flush()

            input_encoder, input_decoder = batch
            word, char, pos, heads, types, masks_e, lengths, word_bert = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, previous, nexts, masks_d, lengths_d = input_decoder
            heads_pred, types_pred, _, _ = network.decode(
                word, char, pos, previous, nexts, stacked_heads, mask_e=masks_e, mask_d=masks_d,
                length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS,
                input_word_bert=word_bert)

            # The writer and the evaluator are fed numpy arrays.
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.test_write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)

            stats, stats_nopunc, _, num_inst = parser.eval(
                word, pos, heads_pred, types_pred, heads, types,
                word_alphabet, pos_alphabet, lengths,
                punct_set=punct_set, symbolic_root=True)

            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_total_inst += num_inst
        end_time = time.time()
    finally:
        # Close the prediction file even when decoding fails part-way.
        pred_writer.close()

    print('\nFINISHED!!\n', end_time - start_time)
    print(_format_scores(dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                         dev_ucomlpete_nopunc, dev_lcomplete_nopunc, dev_total_inst))
Ejemplo n.º 15
0
def main():
    """Train and evaluate a bi-directional RNN-CNN POS tagger from the command line.

    Parses hyper-parameters and data paths, loads the pretrained word
    embeddings, builds the alphabets and the train/dev/test tensors, constructs
    a ``BiRecurrentConv`` (``--dropout std``) or ``BiVarRecurrentConv``
    (otherwise) network and trains it with Nesterov SGD.  After every epoch the
    network is scored on the dev data; whenever the dev score improves it is
    also scored on the test data.  Progress and results go to stdout.
    """
    parser = argparse.ArgumentParser(description='Tuning with bi-directional RNN-CNN')
    parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn', required=True)
    parser.add_argument('--cuda', action='store_true', help='using GPU')
    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Number of sentences in each batch')
    parser.add_argument('--hidden_size', type=int, default=128, help='Number of hidden units in RNN')
    parser.add_argument('--tag_space', type=int, default=0, help='Dimension of tag space')
    parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    parser.add_argument('--num_filters', type=int, default=30, help='Number of filters in CNN')
    parser.add_argument('--char_dim', type=int, default=30, help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.1, help='Decay rate of learning rate')
    parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    parser.add_argument('--dropout', choices=['std', 'variational'], help='type of dropout', required=True)
    parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    # --schedule is formatted with '%d' in the epoch banner and used in
    # `epoch % schedule`; without a value the run used to die with a TypeError
    # only after all data had been loaded, so reject a missing value up front.
    parser.add_argument('--schedule', type=int, required=True, help='schedule for learning rate decay')
    parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words', required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    logger = get_logger("POSTagger")

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace

    embedding = args.embedding
    embedding_path = args.embedding_dict

    embedd_dict, embedd_dim = utils.load_embedding_dict(embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, pos_alphabet, \
    type_alphabet = conllx_data.create_alphabets("data/alphabets/pos/", train_path, data_paths=[dev_path, test_path],
                                                 max_vocabulary_size=50000, embedd_dict=embedd_dict)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')

    data_train = conllx_data.read_data_to_tensor(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, device=device)
    # NOTE(review): data_train[1] is presumably the list of per-bucket sizes,
    # so the sum is the number of training sentences -- confirm in conllx_data.
    num_data = sum(data_train[1])
    num_labels = pos_alphabet.size()

    data_dev = conllx_data.read_data_to_tensor(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, device=device)
    data_test = conllx_data.read_data_to_tensor(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, device=device)

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix as a float32 torch tensor.

        Each alphabet entry gets its pretrained vector (exact match first, then
        lower-cased match); entries without one are drawn uniformly from
        [-scale, scale] and counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform_
    if args.dropout == 'std':
        network = BiRecurrentConv(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(), num_filters, window, mode, hidden_size, num_layers, num_labels,
                                  tag_space=tag_space, embedd_word=word_table,  p_in=p_in, p_out=p_out, p_rnn=p_rnn, initializer=initializer)
    else:
        network = BiVarRecurrentConv(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(), num_filters, window, mode, hidden_size, num_layers, num_labels,
                                     tag_space=tag_space, embedd_word=word_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn, initializer=initializer)

    network = network.to(device)

    lr = learning_rate
    optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d" % (mode, num_layers, hidden_size, num_filters, tag_space))
    logger.info("training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)" % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))

    # Floor division: with true division this is a float on Python 3 and
    # range(1, num_batches + 1) below raises TypeError.
    num_batches = num_data // batch_size + 1
    dev_correct = 0.0
    best_epoch = 0
    test_correct = 0.0
    test_total = 0
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): ' % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_corr = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, labels, _, _, masks, lengths = conllx_data.get_batch_tensor(data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss, corr, _ = network.loss(word, char, labels, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            loss.backward()
            optim.step()

            # Accumulate running statistics without tracking gradients.
            with torch.no_grad():
                num_tokens = masks.sum()
                train_err += loss * num_tokens
                train_corr += corr
                train_total += num_tokens

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log: erase the previous progress line in place, then rewrite it
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, acc: %.2f%%, time left (estimated): %.2fs' % (batch, num_batches, train_err / train_total, train_corr * 100 / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, acc: %.2f%%, time: %.2fs' % (num_batches, train_err / train_total, train_corr * 100 / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            dev_corr = 0.0
            dev_total = 0
            for batch in conllx_data.iterate_batch_tensor(data_dev, batch_size):
                word, char, labels, _, _, masks, lengths = batch
                _, corr, _ = network.loss(word, char, labels, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                num_tokens = masks.sum()
                dev_corr += corr
                dev_total += num_tokens

            print('dev corr: %d, total: %d, acc: %.2f%%' % (dev_corr, dev_total, dev_corr * 100 / dev_total))

            if dev_correct < dev_corr:
                dev_correct = dev_corr
                best_epoch = epoch

                # evaluate on test data when better performance detected
                test_corr = 0.0
                test_total = 0
                for batch in conllx_data.iterate_batch_tensor(data_test, batch_size):
                    word, char, labels, _, _, masks, lengths = batch
                    _, corr, _ = network.loss(word, char, labels, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                    num_tokens = masks.sum()
                    test_corr += corr
                    test_total += num_tokens

                test_correct = test_corr
            print("best dev  corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (dev_correct, dev_total, dev_correct * 100 / dev_total, best_epoch))
            print("best test corr: %d, total: %d, acc: %.2f%% (epoch: %d)" % (test_correct, test_total, test_correct * 100 / test_total, best_epoch))

        # Every `schedule` epochs, lower the learning rate and rebuild the
        # optimizer with it (a fresh SGD instance, so momentum buffers start over).
        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
Ejemplo n.º 16
0
def main():
    """Train and evaluate a bi-directional RNN-CNN(-CRF) NER tagger from the command line.

    Parses the command-line options, loads the pretrained word embeddings,
    builds the alphabets and datasets, reads the network hyper-parameters from
    the JSON file given by ``--config`` (a copy is written to
    ``<model_path>/config.json``), constructs the network variant selected by
    the ``dropout`` and ``crf`` hyper-parameters and trains it.  After every
    epoch the network is scored on the dev data; when dev F1 improves, the
    weights are saved to ``<model_path>/model.pt`` and the test data is scored
    as well.

    Raises:
        ValueError: if the config's ``dropout`` is neither ``'std'`` nor
            ``'variational'``.
    """
    parser = argparse.ArgumentParser(
        description='NER with bi-directional RNN-CNN')
    parser.add_argument('--config',
                        type=str,
                        help='config file',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--loss_type',
                        choices=['sentence', 'token'],
                        default='sentence',
                        help='loss type (default: sentence)')
    parser.add_argument('--optim',
                        choices=['sgd', 'adam'],
                        help='type of optimizer',
                        required=True)
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.1,
                        help='Learning rate')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.999995,
                        help='Decay rate of learning rate')
    parser.add_argument('--amsgrad', action='store_true', help='AMS Grad')
    parser.add_argument('--grad_clip',
                        type=float,
                        default=0,
                        help='max norm for gradient clip (default 0: no clip')
    parser.add_argument('--warmup_steps',
                        type=int,
                        default=0,
                        metavar='N',
                        help='number of steps to warm up (default: 0)')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0,
                        help='weight for l2 norm decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--train',
                        help='path for training file.',
                        required=True)
    parser.add_argument('--dev', help='path for dev file.', required=True)
    parser.add_argument('--test', help='path for test file.', required=True)
    parser.add_argument('--model_path',
                        help='path for saving model file.',
                        required=True)

    args = parser.parse_args()

    logger = get_logger("NER")

    # The device is chosen from hardware availability, not from a CLI flag.
    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip

    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    embedding = args.embedding
    embedding_path = args.embedding_dict

    print(args)

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet = conll03_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict,
        max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")

    data_train = conll03_data.read_bucketed_data(train_path, word_alphabet,
                                                 char_alphabet, pos_alphabet,
                                                 chunk_alphabet, ner_alphabet)
    # NOTE(review): data_train[1] is presumably the list of per-bucket sizes,
    # so the sum is the number of training sentences -- confirm in conll03_data.
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data(dev_path, word_alphabet, char_alphabet,
                                      pos_alphabet, chunk_alphabet,
                                      ner_alphabet)
    data_test = conll03_data.read_data(test_path, word_alphabet, char_alphabet,
                                       pos_alphabet, chunk_alphabet,
                                       ner_alphabet)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, pos_alphabet,
                           chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix as a float32 torch tensor.

        Each alphabet entry gets its pretrained vector (exact match first, then
        lower-cased match); entries without one are drawn uniformly from
        [-scale, scale] and counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    logger.info("constructing network...")

    # Read the hyper-parameters and store a copy next to the model.  Context
    # managers make sure both handles are closed (and the copy flushed) here
    # rather than whenever the file objects happen to be collected.
    with open(args.config, 'r') as config_file:
        hyps = json.load(config_file)
    with open(os.path.join(model_path, 'config.json'), 'w') as config_copy:
        json.dump(hyps, config_copy, indent=2)
    dropout = hyps['dropout']
    crf = hyps['crf']
    bigram = hyps['bigram']
    # The config must agree with the dimensionality of the loaded embeddings.
    assert embedd_dim == hyps['embedd_dim']
    char_dim = hyps['char_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    out_features = hyps['out_features']
    num_layers = hyps['num_layers']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']

    # Pick the network class from (dropout type, CRF on/off).
    if dropout == 'std':
        if crf:
            network = BiRecurrentConvCRF(embedd_dim,
                                         word_alphabet.size(),
                                         char_dim,
                                         char_alphabet.size(),
                                         mode,
                                         hidden_size,
                                         out_features,
                                         num_layers,
                                         num_labels,
                                         embedd_word=word_table,
                                         p_in=p_in,
                                         p_out=p_out,
                                         p_rnn=p_rnn,
                                         bigram=bigram,
                                         activation=activation)
        else:
            network = BiRecurrentConv(embedd_dim,
                                      word_alphabet.size(),
                                      char_dim,
                                      char_alphabet.size(),
                                      mode,
                                      hidden_size,
                                      out_features,
                                      num_layers,
                                      num_labels,
                                      embedd_word=word_table,
                                      p_in=p_in,
                                      p_out=p_out,
                                      p_rnn=p_rnn,
                                      activation=activation)
    elif dropout == 'variational':
        if crf:
            network = BiVarRecurrentConvCRF(embedd_dim,
                                            word_alphabet.size(),
                                            char_dim,
                                            char_alphabet.size(),
                                            mode,
                                            hidden_size,
                                            out_features,
                                            num_layers,
                                            num_labels,
                                            embedd_word=word_table,
                                            p_in=p_in,
                                            p_out=p_out,
                                            p_rnn=p_rnn,
                                            bigram=bigram,
                                            activation=activation)
        else:
            network = BiVarRecurrentConv(embedd_dim,
                                         word_alphabet.size(),
                                         char_dim,
                                         char_alphabet.size(),
                                         mode,
                                         hidden_size,
                                         out_features,
                                         num_layers,
                                         num_labels,
                                         embedd_word=word_table,
                                         p_in=p_in,
                                         p_out=p_out,
                                         p_rnn=p_rnn,
                                         activation=activation)
    else:
        raise ValueError('Unkown dropout type: {}'.format(dropout))

    network = network.to(device)

    optimizer, scheduler = get_optimizer(network.parameters(), optim,
                                         learning_rate, lr_decay, amsgrad,
                                         weight_decay, warmup_steps)
    model = "{}-CNN{}".format(mode, "-CRF" if crf else "")
    logger.info("Network: %s, num_layer=%d, hidden=%d, act=%s" %
                (model, num_layers, hidden_size, activation))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (weight_decay, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" %
                (dropout, p_in, p_out, p_rnn))
    print('# of Parameters: %d' %
          (sum([param.numel() for param in network.parameters()])))

    best_f1 = 0.0
    best_acc = 0.0
    best_precision = 0.0
    best_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    # Number of consecutive epochs without a dev-F1 improvement.
    patient = 0
    num_batches = num_data // batch_size + 1
    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        num_insts = 0
        num_words = 0
        num_back = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print('Epoch %d (%s, lr=%.6f, lr decay=%.6f, amsgrad=%s, l2=%.1e): ' %
              (epoch, optim, lr, lr_decay, amsgrad, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        for step, data in enumerate(
                iterate_data(data_train,
                             batch_size,
                             bucketed=True,
                             unk_replace=unk_replace,
                             shuffle=True)):
            optimizer.zero_grad()
            words = data['WORD'].to(device)
            chars = data['CHAR'].to(device)
            labels = data['NER'].to(device)
            masks = data['MASK'].to(device)

            nbatch = words.size(0)
            nwords = masks.sum().item()

            # Normalise the summed loss per token or per sentence, as selected
            # by --loss_type.
            loss_total = network.loss(words, chars, labels, mask=masks).sum()
            if loss_ty_token:
                loss = loss_total.div(nwords)
            else:
                loss = loss_total.div(nbatch)
            loss.backward()
            if grad_clip > 0:
                clip_grad_norm_(network.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                num_insts += nbatch
                num_words += nwords
                train_loss += loss_total.item()

            # update log: erase the previous progress line in place, then rewrite it
            if step % 100 == 0:
                # Guarded like the epoch-start call above so CPU-only runs
                # never touch the CUDA allocator.
                if args.cuda:
                    torch.cuda.empty_cache()
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                curr_lr = scheduler.get_lr()[0]
                log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                    step, num_batches, 100. * step / num_batches, curr_lr,
                    train_loss / num_insts, train_loss / num_words)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('total: %d (%d), loss: %.4f (%.4f), time: %.2fs' %
              (num_insts, num_words, train_loss / num_insts,
               train_loss / num_words, time.time() - start_time))
        print('-' * 100)

        # evaluate performance on dev data
        # NOTE(review): `eval` here is called with six arguments, so it must be
        # this module's own evaluation helper rather than the builtin -- it is
        # defined outside this function.
        with torch.no_grad():
            outfile = os.path.join(result_path, 'pred_dev%d' % epoch)
            scorefile = os.path.join(result_path, "score_dev%d" % epoch)
            acc, precision, recall, f1 = eval(data_dev, network, writer,
                                              outfile, scorefile, device)
            print(
                'Dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))
            if best_f1 < f1:
                # New best dev F1: checkpoint the weights and remember the scores.
                torch.save(network.state_dict(), model_name)
                best_f1 = f1
                best_acc = acc
                best_precision = precision
                best_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                outfile = os.path.join(result_path, 'pred_test%d' % epoch)
                scorefile = os.path.join(result_path, "score_test%d" % epoch)
                test_acc, test_precision, test_recall, test_f1 = eval(
                    data_test, network, writer, outfile, scorefile, device)
                print(
                    'test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                    % (test_acc, test_precision, test_recall, test_f1))
                patient = 0
            else:
                patient += 1
            print('-' * 100)

            print(
                "Best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d (%d))"
                % (best_acc, best_precision, best_recall, best_f1, best_epoch,
                   patient))
            print(
                "Best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d (%d))"
                % (test_acc, test_precision, test_recall, test_f1, best_epoch,
                   patient))
            print('=' * 100)

        # After more than four stale epochs, reset the scheduler state and
        # start counting again.
        if patient > 4:
            logger.info('reset optimizer momentums')
            scheduler.reset_state()
            patient = 0
Ejemplo n.º 17
0
def main():
    """Train and evaluate the multi-task bi-directional RNN-CNN-CRF tagger.

    Command-line arguments:
        --config: path to a Python file that exposes an ``entries`` object
            holding the ``embedding``, ``data``, ``rnn``, ``training`` and
            ``multitask`` settings read below.
        --grid: Python expression evaluating to a dict of overrides; every
            ``key = value`` pair is executed as an assignment statement.

    For each epoch the function trains on randomly chosen tasks, decodes the
    dev split of every task and, whenever the dev F1 of a task improves,
    decodes its test split as well. Scores are printed and also logged to
    TensorBoard through ``SummaryWriter``.

    Raises:
        NotImplementedError: if ``config.rnn.dropout`` is not ``'std'``.
    """
    parser = argparse.ArgumentParser(
        description='Tuning with Multitask bi-directional RNN-CNN-CRF')
    parser.add_argument('--config',
                        help='Config file (Python file format)',
                        default="config_multitask.py")
    parser.add_argument('--grid', help='Grid Search Options', default="{}")
    args = parser.parse_args()
    logger = get_logger("Multi-Task")
    use_gpu = torch.cuda.is_available()

    # Config Tensorboard Writer
    log_writer = SummaryWriter()

    # Load from config file (the file is executed as a module named "config").
    spec = importlib.util.spec_from_file_location("config", args.config)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    config = config_module.entries

    # Load options from grid search.
    # SECURITY NOTE: both eval() and exec() run arbitrary code taken from the
    # command line; only use --grid with trusted input.
    # NOTE: exec() inside a function cannot rebind this function's locals, so
    # only assignments that mutate existing objects (e.g. "config.rnn.p")
    # have an effect.
    options = eval(args.grid)
    for k, v in options.items():
        if isinstance(v, six.string_types):
            cmd = "%s = \"%s\"" % (k, v)
        else:
            cmd = "%s = %s" % (k, v)
            log_writer.add_scalar(k, v, 1)
        exec(cmd)

    # Load embedding dict
    embedding = config.embedding.embedding_type
    embedding_path = config.embedding.embedding_dict
    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    # Collect data path: one train/devel/test triple per dataset name.
    data_dir = config.data.data_dir
    data_names = config.data.data_names
    train_paths = [
        os.path.join(data_dir, data_name, "train.tsv")
        for data_name in data_names
    ]
    dev_paths = [
        os.path.join(data_dir, data_name, "devel.tsv")
        for data_name in data_names
    ]
    test_paths = [
        os.path.join(data_dir, data_name, "test.tsv")
        for data_name in data_names
    ]

    # Create alphabets
    logger.info("Creating Alphabets")
    if not os.path.exists('tmp'):
        os.mkdir('tmp')
    alphabet_dir = os.path.join(
        Path(data_dir).abspath(), "alphabets", "_".join(data_names))
    (word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet,
     ner_alphabet, ner_alphabet_task, label_reflect) = \
        bionlp_data.create_alphabets(
            alphabet_dir, train_paths,
            data_paths=dev_paths + test_paths, use_cache=True,
            embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())
    logger.info(
        "NER Alphabet Size per Task: %s",
        str([task_alphabet.size() for task_alphabet in ner_alphabet_task]))

    if embedding == 'elmo':
        logger.info("Loading ELMo Embedder")
        ee = ElmoEmbedder(options_file=config.embedding.elmo_option,
                          weight_file=config.embedding.elmo_weight,
                          cuda_device=config.embedding.elmo_cuda)
    else:
        ee = None

    logger.info("Reading Data")

    # Prepare dataset: each task is read with its own NER label alphabet.
    data_trains = [
        bionlp_data.read_data_to_variable(train_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          elmo_ee=ee)
        for task_id, train_path in enumerate(train_paths)
    ]
    num_data = [sum(data_train[1]) for data_train in data_trains]
    num_tasks = len(num_data)
    num_labels = ner_alphabet.size()
    num_labels_task = [task_item.size() for task_item in ner_alphabet_task]

    data_devs = [
        bionlp_data.read_data_to_variable(dev_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          volatile=True,
                                          elmo_ee=ee)
        for task_id, dev_path in enumerate(dev_paths)
    ]

    data_tests = [
        bionlp_data.read_data_to_variable(test_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          volatile=True,
                                          elmo_ee=ee)
        for task_id, test_path in enumerate(test_paths)
    ]

    writer = BioNLPWriter(word_alphabet, char_alphabet, pos_alphabet,
                          chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the initial word embedding matrix as a torch tensor.

        Words found in ``embedd_dict`` (exact or lower-cased) get their
        pretrained vector; all others get a uniform random vector and are
        counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[bionlp_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if embedd_dict is not None and word in embedd_dict:
                vector = embedd_dict[word]
            elif embedd_dict is not None and word.lower() in embedd_dict:
                vector = embedd_dict[word.lower()]
            else:
                vector = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = vector
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    # Construct network
    window = 3
    num_layers = 1
    mode = config.rnn.mode
    hidden_size = config.rnn.hidden_size
    char_dim = config.rnn.char_dim
    num_filters = config.rnn.num_filters
    tag_space = config.rnn.tag_space
    bigram = config.rnn.bigram
    attention_mode = config.rnn.attention
    if config.rnn.dropout == 'std':
        network = MultiTaskBiRecurrentCRF(
            len(data_trains),
            embedd_dim,
            word_alphabet.size(),
            char_dim,
            char_alphabet.size(),
            num_filters,
            window,
            mode,
            hidden_size,
            num_layers,
            num_labels,
            num_labels_task=num_labels_task,
            tag_space=tag_space,
            embedd_word=word_table,
            p_in=config.rnn.p,
            p_rnn=config.rnn.p,
            bigram=bigram,
            elmo=(embedding == 'elmo'),
            attention_mode=attention_mode,
            adv_loss_coef=config.multitask.adv_loss_coef,
            diff_loss_coef=config.multitask.diff_loss_coef,
            char_level_rnn=config.rnn.char_level_rnn)
    else:
        raise NotImplementedError

    if use_gpu:
        network.cuda()

    # Prepare training
    unk_replace = config.embedding.unk_replace
    num_epochs = config.training.num_epochs
    batch_size = config.training.batch_size
    lr = config.training.learning_rate
    momentum = config.training.momentum
    alpha = config.training.alpha
    lr_decay = config.training.lr_decay
    schedule = config.training.schedule
    gamma = config.training.gamma

    optim = RMSprop(network.parameters(),
                    lr=lr,
                    alpha=alpha,
                    momentum=momentum,
                    weight_decay=gamma)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s"
        % (mode, num_layers, hidden_size, num_filters, tag_space,
           'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %s, batch: %d, dropout: %.2f, unk replace: %.2f)"
        % (gamma, num_data, batch_size, config.rnn.p, unk_replace))

    def decode_and_score(dataset, task_id, tmp_filename):
        """Decode ``dataset`` for ``task_id``, dump it and return its scores.

        Predictions are written to ``tmp_filename`` through ``writer`` and
        the file is then passed to ``evaluate``; its return value
        (acc, precision, recall, f1) is returned unchanged.
        """
        writer.start(tmp_filename)
        for batch in bionlp_data.iterate_batch_variable(dataset, batch_size):
            word, char, pos, chunk, labels, masks, lengths, elmo_embedding = batch
            preds, _ = network.decode(
                task_id,
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=bionlp_data.NUM_SYMBOLIC_TAGS,
                elmo_word=elmo_embedding)
            writer.write(word.data.cpu().numpy(),
                         pos.data.cpu().numpy(),
                         chunk.data.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()
        return evaluate(tmp_filename)

    # Per-task bookkeeping of the best dev scores and the matching test scores.
    num_batches = [x // batch_size + 1 for x in num_data]
    dev_f1 = [0.0] * num_tasks
    dev_acc = [0.0] * num_tasks
    dev_precision = [0.0] * num_tasks
    dev_recall = [0.0] * num_tasks
    test_f1 = [0.0] * num_tasks
    test_acc = [0.0] * num_tasks
    test_precision = [0.0] * num_tasks
    test_recall = [0.0] * num_tasks
    best_epoch = [0] * num_tasks

    # The epoch length is defined by the first task: twice its batch count.
    total_batches = 2 * num_batches[0]
    separator = "=" * 80

    # Training procedure
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, config.rnn.dropout, lr, lr_decay, schedule))
        train_err = 0.
        train_total = 0.

        # Gradient descent on training data
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, total_batches + 1):
            # Task 0 is drawn half of the time, the others share the rest.
            # With a single task there is nothing else to draw from
            # (random.randint(1, 0) would raise ValueError).
            r = random.random()
            if num_tasks == 1 or r <= 0.5:
                task_id = 0
            else:
                task_id = random.randint(1, num_tasks - 1)
            word, char, _, _, labels, masks, lengths, elmo_embedding = bionlp_data.get_batch_variable(
                data_trains[task_id], batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss, task_loss, adv_loss, diff_loss = network.loss(
                task_id,
                word,
                char,
                labels,
                mask=masks,
                elmo_word=elmo_embedding)
            loss.backward()
            clip_grad_norm(network.parameters(), 5.0)
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (total_batches - batch) * time_ave

            # update log: overwrite the previous progress line in place
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, total_batches, train_err / train_total,
                    time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (total_batches, train_err / train_total,
               time.time() - start_time))

        # Evaluate performance on dev data
        network.eval()
        for task_id in range(num_tasks):
            acc, precision, recall, f1 = decode_and_score(
                data_devs[task_id], task_id,
                'tmp/%s_dev%d%d' % (str(uid), epoch, task_id))
            log_writer.add_scalars(
                'dev_task' + str(task_id), {
                    'accuracy': acc,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1
                }, epoch)
            print(
                'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))

            if dev_f1[task_id] < f1:
                dev_f1[task_id] = f1
                dev_acc[task_id] = acc
                dev_precision[task_id] = precision
                dev_recall[task_id] = recall
                best_epoch[task_id] = epoch

                # Evaluate on test data when better performance detected
                (test_acc[task_id], test_precision[task_id],
                 test_recall[task_id], test_f1[task_id]) = decode_and_score(
                     data_tests[task_id], task_id,
                     'tmp/%s_test%d%d' % (str(uid), epoch, task_id))
                log_writer.add_scalars(
                    'test_task' + str(task_id), {
                        'accuracy': test_acc[task_id],
                        'precision': test_precision[task_id],
                        'recall': test_recall[task_id],
                        'f1': test_f1[task_id]
                    }, epoch)

            print(separator)
            print("dataset: %s" % data_names[task_id])
            print(
                "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (dev_acc[task_id], dev_precision[task_id],
                   dev_recall[task_id], dev_f1[task_id], best_epoch[task_id]))
            print(
                "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                %
                (test_acc[task_id], test_precision[task_id],
                 test_recall[task_id], test_f1[task_id], best_epoch[task_id]))
            print(separator + "\n")

            # NOTE(review): this decay sits inside the per-task loop, so on a
            # scheduled epoch the learning rate is multiplied by lr_decay once
            # per task, and only param_groups[0] is updated. Kept as-is to
            # preserve training behaviour - confirm whether once per epoch
            # was intended.
            if epoch % schedule == 0:
                lr = lr * lr_decay
                optim.param_groups[0]['lr'] = lr

    # Flush and close the TensorBoard writer. The prediction writer is
    # already closed at the end of every evaluation pass.
    log_writer.close()
Ejemplo n.º 18
0
def main():
    args_parser = argparse.ArgumentParser(description='Tuning with stack pointer parser')
    args_parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'], help='architecture of rnn', required=True)
    args_parser.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size', type=int, default=256, help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--type_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers', type=int, default=1, help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers', type=int, default=1, help='Number of layers of decoder RNN')
    args_parser.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
    args_parser.add_argument('--pos', action='store_true', help='use part-of-speech embedding.')
    args_parser.add_argument('--char', action='store_true', help='use character embedding and CNN.')
    args_parser.add_argument('--lemma', action='store_true', help='use lemma embedding.')
    args_parser.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
    args_parser.add_argument('--lemma_dim', type=int, default=50, help='Dimension of Lemma embeddings')
    args_parser.add_argument('--opt', choices=['adam', 'sgd', 'adamax'], help='optimization algorithm')
    args_parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.75, help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay', type=int, default=9, help='Number of decays before stop')
    args_parser.add_argument('--double_schedule_decay', type=int, default=5, help='Number of decays to double schedule')
    args_parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
    args_parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    args_parser.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage', type=float, default=0.0, help='weight for coverage loss')
    args_parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    args_parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth', type=float, default=1.0, help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect', action='store_true', help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar', action='store_true', help='use grand parent.')
    args_parser.add_argument('--sibling', action='store_true', help='use sibling.')
    args_parser.add_argument('--prior_order', choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'], help='prior order of children.', required=False)
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--beam', type=int, default=1, help='Beam size for decoding')
    args_parser.add_argument('--word_embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words', required=True)
    args_parser.add_argument('--word_path', help='path for word embedding dict')
    args_parser.add_argument('--freeze', action='store_true', help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters', required=True)
    args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--test2')
    args_parser.add_argument('--model_path', help='path for saving model file.', required=True)
    args_parser.add_argument('--model_name', help='name for saving model file.', required=True)

    args = args_parser.parse_args()

    logger = get_logger("PtrParser")
    print('SEMANTIC DEPENDENCY PARSER with POINTER NETWORKS')	
    print('CUDA?', torch.cuda.is_available())

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    test_path2 = args.test2
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim

    use_lemma = args.lemma
    lemma_dim = args.lemma_dim

    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, lemma_alphabet = conllx_stacked_data.create_alphabets(alphabet_path, train_path, data_paths=[dev_path, test_path, test_path2],
                                                                                                     max_vocabulary_size=50000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()
    num_lemmas = lemma_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)
    logger.info("LEMMA Alphabet Size: %d" % num_lemmas)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_stacked_data.read_stacked_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, lemma_alphabet, use_gpu=use_gpu, prior_order=prior_order)
    num_data = sum(data_train[1])

    data_dev = conllx_stacked_data.read_stacked_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, lemma_alphabet, use_gpu=use_gpu, volatile=True, prior_order=prior_order)
    data_test = conllx_stacked_data.read_stacked_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, lemma_alphabet, use_gpu=use_gpu, volatile=True, prior_order=prior_order)
    data_test2 = conllx_stacked_data.read_stacked_data_to_variable(test_path2, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, lemma_alphabet, use_gpu=use_gpu, volatile=True, prior_order=prior_order)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        #logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
	print(torch.__version__)
        return torch.from_numpy(table)
    
    def construct_lemma_embedding_table():
        scale = np.sqrt(3.0 / lemma_dim)
        table = np.empty([lemma_alphabet.size(), lemma_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.zeros([1, lemma_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, lemma_dim]).astype(np.float32)
        oov = 0
        for lemma, index in lemma_alphabet.items():
            if lemma in word_dict:
                embedding = word_dict[lemma]
            elif lemma.lower() in word_dict:
                embedding = word_dict[lemma.lower()]
            else:
                embedding = np.zeros([1, lemma_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, lemma_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('LEMMA OOV: %d' % oov)
	print(torch.__version__)
        return torch.from_numpy(table)
    

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_stacked_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()
    lemma_table = construct_lemma_embedding_table() 

    window = 3
    network = NewStackPtrNet(word_dim, num_words, lemma_dim, num_lemmas, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                          mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
                          num_types, arc_space, type_space,
                          embedd_word=word_table, embedd_char=char_table, embedd_lemma=lemma_table, p_in=p_in, p_out=p_out, p_rnn=p_rnn,
                          biaffine=True, pos=use_pos, char=use_char, lemma=use_lemma, prior_order=prior_order,
                          skipConnect=skipConnect, grandPar=grandPar, sibling=sibling)
    def save_args():
        arg_path = model_name + '.arg.json'
        arguments = [word_dim, num_words, lemma_dim, num_lemmas, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                     mode, input_size_decoder, hidden_size, encoder_layers, decoder_layers,
                     num_types, arc_space, type_space]
        kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char, 'lemma': use_lemma, 'prior_order': prior_order,
                  'skipConnect': skipConnect, 'grandPar': grandPar, 'sibling': sibling}
        json.dump({'args': arguments, 'kwargs': kwargs}, open(arg_path, 'w'), indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
	print('CUDA IS AVAILABLE')
        network.cuda()
    else:
	print('CUDA IS NOT AVAILABLE', use_gpu)

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, lemma_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, lemma_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    def generate_optimizer(opt, lr, params):
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params, lr=lr, betas=betas, weight_decay=gamma, eps=eps)
        elif opt == 'sgd':
            return SGD(params, lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
        elif opt == 'adamax':
            return Adamax(params, lr=lr, betas=betas, weight_decay=gamma, eps=eps)
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    lemma_status = 'enabled' if use_lemma else 'disabled'
    logger.info("Embedding dim: word=%d (%s), lemma=%d (%s) char=%d (%s), pos=%d (%s)" % (word_dim, word_status, lemma_dim, lemma_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info("RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, encoder_layers, decoder_layers, input_size_decoder, hidden_size, arc_space, type_space))
    logger.info("train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)" % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' % (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d' % (skipConnect, beam))
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    #dev_ucorrect = 0.0
	
    dev_bestLF1 = 0.0
    dev_bestUF1 = 0.0
    dev_bestUprecision = 0.0
    dev_bestLprecision = 0.0
    dev_bestUrecall = 0.0
    dev_bestLrecall = 0.0			


    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    #test_ucomlpete_match = 0.0
    #test_lcomplete_match = 0.0

    #test_ucorrect_nopunc = 0.0
    #test_lcorrect_nopunc = 0.0
    #test_ucomlpete_match_nopunc = 0.0
    #test_lcomplete_match_nopunc = 0.0
    #test_root_correct = 0.0
    test_total_pred = 0
    test_total_gold = 0
    #test_total_nopunc = 0
    test_total_inst = 0
    #test_total_root = 0

    test_LF1 = 0.0
    test_UF1 = 0.0
    test_Uprecision = 0.0
    test_Lprecision = 0.0
    test_Urecall = 0.0
    test_Lrecall = 0.0


    test2_ucorrect = 0.0
    test2_lcorrect = 0.0
    test2_total_pred = 0
    test2_total_gold = 0
    test2_total_inst = 0

    test2_LF1 = 0.0
    test2_UF1 = 0.0
    test2_Uprecision = 0.0
    test2_Lprecision = 0.0
    test2_Urecall = 0.0
    test2_Lrecall = 0.0

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay
    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d (%d, %d))): ' % (
            epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay, max_decay, double_schedule_decay))

        train_err_cov = 0.
	train_err_arc = 0.
	train_err_type = 0.
	train_total = 0.


        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
	
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(data_train, batch_size, unk_replace=unk_replace)
            word, lemma, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, previous, next, masks_d, lengths_d = input_decoder


		
	    #print('HEADSSS', heads)


            optim.zero_grad()


	    loss_arc, \
            loss_type, \
            loss_cov, num = network.loss(word, lemma, char, pos, heads, stacked_heads, children, sibling, stacked_types, previous, next, label_smooth,
                                                            skip_connect=skip_connect, mask_e=masks_e, length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)


            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

	    train_err_arc += loss_arc.data[0] * num

	    train_err_type += loss_type.data[0] * num

	    train_err_cov += loss_cov.data[0] * num


	    train_total += num

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave



        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
	err_arc = train_err_arc / train_total

	err_type = train_err_type / train_total

	err_cov = train_err_cov / train_total

        err = err_arc + err_type + cov * err_cov
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, coverage: %.4f, time: %.2fs' % (
            num_batches, err, err_arc, err_type, err_cov, time.time() - start_time))




        print('======EVALUATING PERFORMANCE ON DEV======')
        # evaluate performance on dev data
        network.eval()
        #pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
	pred_filename = '%spred_dev%d' % (str(uid), epoch)
	pred_filename = os.path.join(model_path, pred_filename)
        pred_writer.start(pred_filename)
        #gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
	gold_filename = '%sgold_dev%d' % (str(uid), epoch)
	gold_filename = os.path.join(model_path, gold_filename)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total_gold = 0
	dev_total_pred = 0
        dev_total_inst = 0.0
	start_time_dev = time.time()
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_dev, batch_size):
            input_encoder, _ = batch
            word, lemma, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(word, lemma, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
	    lemma = lemma.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word, lemma, pos, heads_pred, types_pred, lengths, symbolic_root=True)
	    gold_writer.write(word, lemma, pos, heads, types, lengths, symbolic_root=True)
	    
	    

            #stats, stats_nopunc, stats_root, num_inst = parser.evalF1(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            #ucorr, lcorr, total, ucm, lcm = stats
            #ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            #corr_root, total_root = stats_root
	    ucorr, lcorr, total_gold, total_pred, num_inst = parser.evalF1(word, lemma, pos, heads_pred, types_pred, heads, types, word_alphabet, lemma_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total_gold += total_gold
	    dev_total_pred += total_pred


            dev_total_inst += num_inst

	end_time_dev = time.time()
	lasted_time_dev=end_time_dev-start_time_dev
        pred_writer.close()
        gold_writer.close()

	dev_Uprecision=0.
	dev_Lprecision=0.
	if dev_total_pred!=0: 
		dev_Uprecision=dev_ucorr * 100 / dev_total_pred
		dev_Lprecision=dev_lcorr * 100 / dev_total_pred
	dev_Urecall=dev_ucorr * 100 / dev_total_gold
	dev_Lrecall=dev_lcorr * 100 / dev_total_gold
	if dev_Uprecision ==0. and dev_Urecall==0.: 
		dev_UF1=0
	else:
		dev_UF1=2*(dev_Uprecision*dev_Urecall)/(dev_Uprecision+dev_Urecall)
	if dev_Lprecision ==0. and dev_Lrecall==0.: 
		dev_LF1=0
	else:
		dev_LF1=2*(dev_Lprecision*dev_Lrecall)/(dev_Lprecision+dev_Lrecall)

	print('CUR DEV %d: ucorr: %d, lcorr: %d, tot_gold: %d, tot_pred: %d, Uprec: %.2f%%, Urec: %.2f%%, Lprec: %.2f%%, Lrec: %.2f%%, UF1: %.2f%%, LF1: %.2f%%' % (
            epoch, dev_ucorr, dev_lcorr, dev_total_gold, dev_total_pred, dev_Uprecision, dev_Urecall, dev_Lprecision, dev_Lrecall, dev_UF1, dev_LF1))






        #if dev_lcorrect_nopunc < dev_lcorr_nopunc or (dev_lcorrect_nopunc == dev_lcorr_nopunc and dev_ucorrect_nopunc < dev_ucorr_nopunc):
	if dev_bestLF1 < dev_LF1:
            dev_bestLF1 = dev_LF1
    	    dev_bestUF1 = dev_UF1
    	    dev_bestUprecision = dev_Uprecision
    	    dev_bestLprecision = dev_Lprecision
    	    dev_bestUrecall = dev_Urecall
    	    dev_bestLrecall = dev_Lrecall

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

	    print('======EVALUATING PERFORMANCE ON TEST======')
            #pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
	    pred_filename = '%spred_test%d' % (str(uid), epoch)
	    pred_filename = os.path.join(model_path, pred_filename)
            pred_writer.start(pred_filename)
            #gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
	    gold_filename = '%sgold_test%d' % (str(uid), epoch)
	    gold_filename = os.path.join(model_path, gold_filename)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_total_pred = 0
	    test_total_gold = 0
	    test_total_inst = 0

	    start_time_test = time.time()
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_test, batch_size):
                input_encoder, _ = batch
                word, lemma, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(word, lemma, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
		lemma = lemma.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

		

                pred_writer.write(word, lemma, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                gold_writer.write(word, lemma, pos, heads, types, lengths, symbolic_root=True)

                ucorr, lcorr, total_gold, total_pred, num_inst = parser.evalF1(word, lemma, pos, heads_pred, types_pred, heads, types, word_alphabet, lemma_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                
                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total_gold += total_gold
		test_total_pred += total_pred
                
                test_total_inst += num_inst

	    end_time_test = time.time()
	    lasted_time_test=end_time_test-start_time_test
            pred_writer.close()
            gold_writer.close()

	    test_Uprecision=0.
	    test_Lprecision=0.
	    if test_total_pred!=0:
		    test_Uprecision=test_ucorrect * 100 / test_total_pred
		    test_Lprecision=test_lcorrect * 100 / test_total_pred
	    test_Urecall=test_ucorrect * 100 / test_total_gold
	    test_Lrecall=test_lcorrect * 100 / test_total_gold
	    if test_Uprecision ==0. and test_Urecall==0.: 
		test_UF1=0
	    else: 
	    	test_UF1=2*(test_Uprecision*test_Urecall)/(test_Uprecision+test_Urecall)
	    if test_Lprecision ==0. and test_Lrecall==0.: 
		test_LF1=0
	    else:
	    	test_LF1=2*(test_Lprecision*test_Lrecall)/(test_Lprecision+test_Lrecall)


	    print('======EVALUATING PERFORMANCE ON TEST 2======')
            #pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
	    pred_filename2 = '%spred_test_two%d' % (str(uid), epoch)
	    pred_filename2 = os.path.join(model_path, pred_filename2)
            pred_writer.start(pred_filename2)
            #gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
	    gold_filename2 = '%sgold_test_two%d' % (str(uid), epoch)
	    gold_filename2 = os.path.join(model_path, gold_filename2)
            gold_writer.start(gold_filename2)

            test2_ucorrect = 0.0
            test2_lcorrect = 0.0
            test2_total_pred = 0
	    test2_total_gold = 0
	    test2_total_inst = 0

	    start_time_test2 = time.time()
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(data_test2, batch_size):
                input_encoder, _ = batch
                word, lemma, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(word, lemma, char, pos, mask=masks, length=lengths, beam=beam, leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
		lemma = lemma.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

		

                pred_writer.write(word, lemma, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                gold_writer.write(word, lemma, pos, heads, types, lengths, symbolic_root=True)

                ucorr, lcorr, total_gold, total_pred, num_inst = parser.evalF1(word, lemma, pos, heads_pred, types_pred, heads, types, word_alphabet, lemma_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                
                test2_ucorrect += ucorr
                test2_lcorrect += lcorr
                test2_total_gold += total_gold
		test2_total_pred += total_pred
                
                test2_total_inst += num_inst

	    end_time_test2 = time.time()
	    lasted_time_test2=end_time_test2-start_time_test2
            pred_writer.close()
            gold_writer.close()	

	    test2_Uprecision=0.
	    test2_Lprecision=0.
	    if dev_total_pred!=0:
		    test2_Uprecision=test2_ucorrect * 100 / test2_total_pred
		    test2_Lprecision=test2_lcorrect * 100 / test2_total_pred
	    test2_Urecall=test2_ucorrect * 100 / test2_total_gold
	    test2_Lrecall=test2_lcorrect * 100 / test2_total_gold
	    if test2_Uprecision ==0. and test2_Urecall==0.: 
		test2_UF1=0.
	    else: 
		test2_UF1=2*(test2_Uprecision*test2_Urecall)/(test2_Uprecision+test2_Urecall)
	    if test2_Lprecision ==0 and test2_Lrecall==0: 
		test2_LF1=0.
	    else: 
	    	test2_LF1=2*(test2_Lprecision*test2_Lrecall)/(test2_Lprecision+test2_Lrecall)


        else:
            #if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
	    if dev_LF1 < dev_bestLF1 - 5 or patient >= schedule:
                # network = torch.load(model_name)
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

	
        print('----------------------------------------------------------------------------------------------------------------------------')
	print('TIME DEV: ', lasted_time_dev, 'NUM SENTS DEV: ', dev_total_inst, 'SPEED DEV: ', dev_total_inst/lasted_time_dev)
	print('DEV: Uprec: %.2f%%, Urec: %.2f%%, Lprec: %.2f%%, Lrec: %.2f%%, UF1: %.2f%%, LF1: %.2f%% (epoch: %d)' % (
             dev_bestUprecision, dev_bestUrecall, dev_bestLprecision, dev_bestLrecall, dev_bestUF1, dev_bestLF1, best_epoch))
        print('----------------------------------------------------------------------------------------------------------------------------')
	print('TIME TEST: ', lasted_time_test, 'NUM SENTS TEST: ', test_total_inst, 'SPEED TEST: ', test_total_inst/lasted_time_test)
        print('TEST: ucorr: %d, lcorr: %d, tot_gold: %d, tot_pred: %d, Uprec: %.2f%%, Urec: %.2f%%, Lprec: %.2f%%, Lrec: %.2f%%, UF1: %.2f%%, LF1: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total_gold, test_total_pred, test_Uprecision, test_Urecall, test_Lprecision, test_Lrecall, test_UF1, test_LF1, best_epoch))
	print('----------------------------------------------------------------------------------------------------------------------------')
	print('TIME TEST2: ', lasted_time_test2, 'NUM SENTS TEST: ', test2_total_inst, 'SPEED TEST2: ', test2_total_inst/lasted_time_test2)
        print('TEST2: ucorr: %d, lcorr: %d, tot_gold: %d, tot_pred: %d, Uprec: %.2f%%, Urec: %.2f%%, Lprec: %.2f%%, Lrec: %.2f%%, UF1: %.2f%%, LF1: %.2f%% (epoch: %d)' % (
            test2_ucorrect, test2_lcorrect, test2_total_gold, test2_total_pred, test2_Uprecision, test2_Urecall, test2_Lprecision, test2_Lrecall, test2_UF1, test2_LF1, best_epoch))
        print('============================================================================================================================')

	#exit(0)
        if decay == max_decay:
            break
Ejemplo n.º 19
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.register('type', 'bool', str2bool)

    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=1000,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder.')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode',
                             choices=['mst', 'greedy'],
                             default='mst',
                             help='decoding algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    # args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    # args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument('--data_dir', help='Data directory path')
    args_parser.add_argument(
        '--src_lang',
        required=True,
        help='Src language to train dependency parsing model')
    args_parser.add_argument('--aux_lang',
                             nargs='+',
                             help='Language names for adversarial training')
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    #
    args_parser.add_argument('--attn_on_rnn',
                             action='store_true',
                             help='use self-attention on top of context RNN.')
    args_parser.add_argument('--no_word',
                             type='bool',
                             default=False,
                             help='do not use word embedding.')
    args_parser.add_argument('--use_bert',
                             type='bool',
                             default=False,
                             help='use multilingual BERT.')
    #
    # lrate schedule with warmup in the first iter.
    args_parser.add_argument('--use_warmup_schedule',
                             type='bool',
                             default=False,
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    # encoder selection
    args_parser.add_argument('--encoder_type',
                             choices=['Transformer', 'RNN', 'SelfAttn'],
                             default='RNN',
                             help='do not use context RNN.')
    args_parser.add_argument(
        '--pool_type',
        default='mean',
        choices=['max', 'mean', 'weight'],
        help='pool type to form fixed length vector from word embeddings')
    # Transformer encoder
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    args_parser.add_argument(
        '--use_all_encoder_layers',
        type='bool',
        default=False,
        help='Use a weighted representations of all encoder layers')
    # - positional
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=50,
                             help='Dimension of Position embeddings.')
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')

    args_parser.add_argument('--input_concat_embeds',
                             action='store_true',
                             help="Concat input embeddings, otherwise add.")
    args_parser.add_argument('--input_concat_position',
                             action='store_true',
                             help="Concat position embeddings, otherwise add.")
    args_parser.add_argument(
        '--partitioned',
        type='bool',
        default=False,
        help=
        "Partition the content and positional attention for multi-head attention."
    )
    args_parser.add_argument(
        '--partition_type',
        choices=['content-position', 'lexical-delexical'],
        default='content-position',
        help="How to apply partition in the self-attention.")
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    #
    # regarding adversarial training
    args_parser.add_argument('--pre_model_path',
                             type=str,
                             default=None,
                             help='Path of the pretrained model.')
    args_parser.add_argument('--pre_model_name',
                             type=str,
                             default=None,
                             help='Name of the pretrained model.')
    args_parser.add_argument('--adv_training',
                             type='bool',
                             default=False,
                             help='Use adversarial training.')
    args_parser.add_argument(
        '--lambdaG',
        type=float,
        default=0.001,
        help='Scaling parameter to control generator loss.')
    args_parser.add_argument('--discriminator',
                             choices=['weak', 'not-so-weak', 'strong'],
                             default='weak',
                             help='architecture of the discriminator')
    args_parser.add_argument(
        '--delay',
        type=int,
        default=0,
        help='Number of epochs to be run first for the source task')
    args_parser.add_argument(
        '--n_critic',
        type=int,
        default=5,
        help='Number of training steps for discriminator per iter')
    args_parser.add_argument(
        '--clip_disc',
        type=float,
        default=5.0,
        help='Lower and upper clip value for disc. weights')
    args_parser.add_argument('--debug',
                             type='bool',
                             default=False,
                             help='Use debug portion of the training data')
    args_parser.add_argument('--train_level',
                             type=str,
                             default='word',
                             choices=['word', 'sent'],
                             help='Use X-level adversarial training')
    args_parser.add_argument('--train_type',
                             type=str,
                             default='GAN',
                             choices=['GR', 'GAN', 'WGAN'],
                             help='Type of adversarial training')
    #
    # regarding motivational training
    args_parser.add_argument(
        '--motivate',
        type='bool',
        default=False,
        help='This is opposite of the adversarial training')

    #
    args = args_parser.parse_args()

    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    # if output directory doesn't exist, create it
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    logger = get_logger("GraphParser")

    logger.info('\ncommand-line params : {0}\n'.format(sys.argv[1:]))
    logger.info('{0}\n'.format(args))

    logger.info("Visible GPUs: %s", str(os.environ["CUDA_VISIBLE_DEVICES"]))
    args.parallel = False
    if torch.cuda.device_count() > 1:
        args.parallel = True

    mode = args.mode
    obj = args.objective
    decoding = args.decode

    train_path = args.data_dir + args.src_lang + "_train.debug.1_10.conllu" \
        if args.debug else args.data_dir + args.src_lang + '_train.conllu'
    dev_path = args.data_dir + args.src_lang + "_dev.conllu"
    test_path = args.data_dir + args.src_lang + "_test.conllu"

    #
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path
    model_path = args.model_path
    model_name = args.model_name

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    use_word_emb = not args.no_word
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    attn_on_rnn = args.attn_on_rnn
    encoder_type = args.encoder_type
    if attn_on_rnn:
        assert encoder_type == 'RNN'

    t_types = (args.adv_training, args.motivate)
    t_count = sum(1 for tt in t_types if tt)
    if t_count > 1:
        assert False, "Only one of: adv_training or motivate can be true"

    # ------------------- Loading/initializing embeddings -------------------- #

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)

    # TODO (WARNING): must build vocabs previously
    assert os.path.isdir(alphabet_path), "should have build vocabs previously"
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)
    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    # ------------------------------------------------------------------------- #
    # --------------------- Loading/building the model ------------------------ #

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    def construct_word_embedding_table():
        """Assemble the initial word-embedding matrix as a torch tensor.

        The matrix has ``word_alphabet.size()`` rows of width ``word_dim``.
        Each (word, index) pair yielded by ``word_alphabet.items()`` is filled
        from ``word_dict``: the exact word is tried first, then its lowercase
        form. Words found under neither spelling count as OOV and receive a
        fallback row (zeros when ``freeze`` is set, otherwise a uniform draw
        bounded by ``sqrt(3 / word_dim)``). The UNK row gets the same fallback.
        The OOV count is printed before returning.
        """
        bound = np.sqrt(3.0 / word_dim)

        def _fallback_row():
            # Frozen embeddings cannot be learned, so unknown rows stay at
            # zero; trainable ones start from a small random vector.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound,
                                     [1, word_dim]).astype(np.float32)

        # NOTE: np.empty leaves memory uninitialised, so any row that is
        # neither UNK nor yielded by word_alphabet.items() keeps arbitrary
        # values.
        weights = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        weights[conllx_data.UNK_ID, :] = _fallback_row()

        missing = 0
        for token, row in word_alphabet.items():
            if token in word_dict:
                vector = word_dict[token]
            else:
                lowered = token.lower()
                if lowered in word_dict:
                    vector = word_dict[lowered]
                else:
                    vector = _fallback_row()
                    missing += 1
            weights[row, :] = vector

        print('word OOV: %d' % missing)
        return torch.from_numpy(weights)

    def construct_char_embedding_table():
        """Assemble the initial character-embedding matrix as a torch tensor.

        Returns ``None`` when no pre-trained character embeddings were loaded
        (``char_dict is None``). Otherwise the matrix has ``num_chars`` rows of
        width ``char_dim``; each (char, index) pair from
        ``char_alphabet.items()`` is copied from ``char_dict`` when present,
        and characters that are absent (plus the UNK row) get a uniform draw
        bounded by ``sqrt(3 / char_dim)``. The OOV count is printed before
        returning.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)

        def _random_row():
            return np.random.uniform(-bound, bound,
                                     [1, char_dim]).astype(np.float32)

        # NOTE: np.empty leaves memory uninitialised, so any row that is
        # neither UNK nor yielded by char_alphabet.items() keeps arbitrary
        # values.
        weights = np.empty([num_chars, char_dim], dtype=np.float32)
        weights[conllx_data.UNK_ID, :] = _random_row()

        missing = 0
        for symbol, row in char_alphabet.items():
            if symbol in char_dict:
                vector = char_dict[symbol]
            else:
                vector = _random_row()
                missing += 1
            weights[row, :] = vector

        print('character OOV: %d' % missing)
        return torch.from_numpy(weights)

    word_table = construct_word_embedding_table() if use_word_emb else None
    char_table = construct_char_embedding_table() if use_char else None

    def load_model_arguments_from_json():
        """Read the saved constructor arguments of a pre-trained model.

        The file location comes from ``pre_model_path`` in the enclosing
        scope; it is looked up when this function is called, so it must have
        been assigned by then.

        Returns:
            A ``(args, kwargs)`` pair taken from the ``'args'`` and
            ``'kwargs'`` entries of the JSON document.
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(pre_model_path, 'r') as arg_file:
            arguments = json.load(arg_file)
        return arguments['args'], arguments['kwargs']

    window = 3
    if obj == 'cross_entropy':
        if args.pre_model_path and args.pre_model_name:
            pre_model_name = os.path.join(args.pre_model_path,
                                          args.pre_model_name)
            pre_model_path = pre_model_name + '.arg.json'
            model_args, kwargs = load_model_arguments_from_json()

            network = BiRecurrentConvBiAffine(use_gpu=use_gpu,
                                              *model_args,
                                              **kwargs)
            network.load_state_dict(torch.load(pre_model_name))
            logger.info("Model reloaded from %s" % pre_model_path)

            # Adjust the word embedding layer
            if network.embedder.word_embedd is not None:
                network.embedder.word_embedd = nn.Embedding(num_words,
                                                            word_dim,
                                                            _weight=word_table)

        else:
            network = BiRecurrentConvBiAffine(
                word_dim,
                num_words,
                char_dim,
                num_chars,
                pos_dim,
                num_pos,
                num_filters,
                window,
                mode,
                hidden_size,
                num_layers,
                num_types,
                arc_space,
                type_space,
                embedd_word=word_table,
                embedd_char=char_table,
                p_in=p_in,
                p_out=p_out,
                p_rnn=p_rnn,
                biaffine=True,
                pos=use_pos,
                char=use_char,
                train_position=args.train_position,
                encoder_type=encoder_type,
                trans_hid_size=args.trans_hid_size,
                d_k=args.d_k,
                d_v=args.d_v,
                num_head=args.num_head,
                enc_use_neg_dist=args.enc_use_neg_dist,
                enc_clip_dist=args.enc_clip_dist,
                position_dim=args.position_dim,
                max_sent_length=max_sent_length,
                use_gpu=use_gpu,
                use_word_emb=use_word_emb,
                input_concat_embeds=args.input_concat_embeds,
                input_concat_position=args.input_concat_position,
                attn_on_rnn=attn_on_rnn,
                partitioned=args.partitioned,
                partition_type=args.partition_type,
                use_all_encoder_layers=args.use_all_encoder_layers,
                use_bert=args.use_bert)

    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    # ------------------------------------------------------------------------- #
    # --------------------- Loading data -------------------------------------- #

    train_data = dict()
    dev_data = dict()
    test_data = dict()
    num_data = dict()
    lang_ids = dict()
    reverse_lang_ids = dict()

    # ===== the reading =============================================
    def _read_one(path, is_train):
        """Read one data file through ``conllx_data.read_data_to_variable``.

        The language id is guessed from ``path`` and logged. Data is read with
        ``use_gpu=False`` and ``symbolic_root=True``. For training files
        ``volatile`` is False and ``len_thresh`` is ``args.train_len_thresh``;
        for every other file ``volatile`` is True and ``len_thresh`` is 100000.

        Args:
            path: location of the file to read.
            is_train: whether the file is training data.

        Returns:
            Whatever ``conllx_data.read_data_to_variable`` returns for the file.
        """
        guessed_lang = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, guessed_lang))

        if is_train:
            length_limit = args.train_len_thresh
        else:
            length_limit = 100000

        return conllx_data.read_data_to_variable(
            path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            use_gpu=False,
            volatile=(not is_train),
            symbolic_root=True,
            lang_id=guessed_lang,
            use_bert=args.use_bert,
            len_thresh=length_limit)

    data_train = _read_one(train_path, True)
    train_data[args.src_lang] = data_train
    num_data[args.src_lang] = sum(data_train[1])
    lang_ids[args.src_lang] = len(lang_ids)
    reverse_lang_ids[lang_ids[args.src_lang]] = args.src_lang

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    dev_data[args.src_lang] = data_dev
    test_data[args.src_lang] = data_test

    # ===============================================================

    # ===== reading data for adversarial training ===================
    if t_count > 0:
        for language in args.aux_lang:
            aux_train_path = args.data_dir + language + "_train.debug.1_10.conllu" \
                if args.debug else args.data_dir + language + '_train.conllu'
            aux_train_data = _read_one(aux_train_path, True)
            num_data[language] = sum(aux_train_data[1])
            train_data[language] = aux_train_data
            lang_ids[language] = len(lang_ids)
            reverse_lang_ids[lang_ids[language]] = language
    # ===============================================================

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def save_args():
        """Write the network's constructor arguments to ``<model_name>.arg.json``.

        The JSON document has two entries: ``'args'``, the positional
        arguments in the order they are passed to ``BiRecurrentConvBiAffine``
        above, and ``'kwargs'``, the keyword arguments. The embedding tables
        and ``use_gpu`` are not stored here.
        """
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, hidden_size, num_layers, num_types,
            arc_space, type_space
        ]
        kwargs = {
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'train_position': args.train_position,
            'encoder_type': args.encoder_type,
            'trans_hid_size': args.trans_hid_size,
            'd_k': args.d_k,
            'd_v': args.d_v,
            'num_head': args.num_head,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'enc_clip_dist': args.enc_clip_dist,
            'position_dim': args.position_dim,
            'max_sent_length': max_sent_length,
            'use_word_emb': use_word_emb,
            'input_concat_embeds': args.input_concat_embeds,
            'input_concat_position': args.input_concat_position,
            'attn_on_rnn': attn_on_rnn,
            'partitioned': args.partitioned,
            'partition_type': args.partition_type,
            'use_all_encoder_layers': args.use_all_encoder_layers,
            'use_bert': args.use_bert
        }
        # Close the file explicitly so the JSON is flushed to disk before
        # training continues, rather than whenever the handle is collected.
        with open(arg_path, 'w') as arg_file:
            json.dump({
                'args': arguments,
                'kwargs': kwargs
            },
                      arg_file,
                      indent=4)

    if use_word_emb and freeze:
        freeze_embedding(network.embedder.word_embedd)

    if args.parallel:
        network = torch.nn.DataParallel(network)

    if use_gpu:
        network = network.cuda()

    save_args()

    param_dict = {}
    encoder = network.module.encoder if args.parallel else network.encoder
    for name, param in encoder.named_parameters():
        if param.requires_grad:
            param_dict[name] = np.prod(param.size())

    total_params = np.sum(list(param_dict.values()))
    logger.info('Total Encoder Parameters = %d' % total_params)

    # ------------------------------------------------------------------------- #

    # =============================================
    if args.adv_training:
        disc_feat_size = network.module.encoder.output_dim if args.parallel else network.encoder.output_dim
        reverse_grad = args.train_type == 'GR'
        nclass = len(lang_ids) if args.train_type == 'GR' else 1

        kwargs = {
            'input_size': disc_feat_size,
            'disc_type': args.discriminator,
            'train_level': args.train_level,
            'train_type': args.train_type,
            'reverse_grad': reverse_grad,
            'soft_label': True,
            'nclass': nclass,
            'scale': args.lambdaG,
            'use_gpu': use_gpu,
            'opt': 'adam',
            'lr': 0.001,
            'betas': (0.9, 0.999),
            'gamma': 0,
            'eps': 1e-8,
            'momentum': 0,
            'clip_disc': args.clip_disc
        }
        AdvAgent = Adversarial(**kwargs)
        if use_gpu:
            AdvAgent.cuda()

    elif args.motivate:
        disc_feat_size = network.module.encoder.output_dim if args.parallel else network.encoder.output_dim
        nclass = len(lang_ids)

        kwargs = {
            'input_size': disc_feat_size,
            'disc_type': args.discriminator,
            'train_level': args.train_level,
            'nclass': nclass,
            'scale': args.lambdaG,
            'use_gpu': use_gpu,
            'opt': 'adam',
            'lr': 0.001,
            'betas': (0.9, 0.999),
            'gamma': 0,
            'eps': 1e-8,
            'momentum': 0,
            'clip_disc': args.clip_disc
        }
        MtvAgent = Motivator(**kwargs)
        if use_gpu:
            MtvAgent.cuda()

    # =============================================

    # --------------------- Initializing the optimizer ------------------------ #

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters(), betas, gamma,
                               eps, momentum)
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    # =============================================

    total_data = min(num_data.values())

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" %
        (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info(
        "train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)"
        % (obj, gamma, total_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    # ------------------------------------------------------------------------- #
    # --------------------- Form the mini-batches ----------------------------- #
    num_batches = total_data // batch_size + 1
    aux_lang = []
    if t_count > 0:
        for language in args.aux_lang:
            aux_lang.extend([language] * num_data[language])

        assert num_data[args.src_lang] <= len(aux_lang)
    # ------------------------------------------------------------------------- #

    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    if decoding == 'greedy':
        decode = network.module.decode if args.parallel else network.decode
    elif decoding == 'mst':
        decode = network.module.decode_mst if args.parallel else network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))

    skip_adv_tuning = 0
    loss_fn = network.module.loss if args.parallel else network.loss
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): '
            %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0

        skip_adv_tuning += 1
        loss_d_real, loss_d_fake = [], []
        acc_d_real, acc_d_fake, = [], []
        gen_loss, parsing_loss = [], []
        disent_loss = []

        if t_count > 0 and skip_adv_tuning > args.delay:
            batch_size = args.batch_size // 2
            num_batches = total_data // batch_size + 1

        # ---------------------- Sample the mini-batches -------------------------- #
        if t_count > 0:
            sampled_aux_lang = random.sample(aux_lang, num_batches)
            lang_in_batch = [(args.src_lang, sampled_aux_lang[k])
                             for k in range(num_batches)]
        else:
            lang_in_batch = [(args.src_lang, None) for _ in range(num_batches)]
        assert len(lang_in_batch) == num_batches
        # ------------------------------------------------------------------------- #

        network.train()
        warmup_factor = (lr + 0.) / num_batches
        for batch in range(1, num_batches + 1):
            update_generator = True
            update_discriminator = False

            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate

            # considering source language as real and auxiliary languages as fake
            real_lang, fake_lang = lang_in_batch[batch - 1]
            real_idx, fake_idx = lang_ids.get(real_lang), lang_ids.get(
                fake_lang, -1)

            #
            word, char, pos, heads, types, masks, lengths, bert_inputs = conllx_data.get_batch_variable(
                train_data[real_lang], batch_size, unk_replace=unk_replace)

            if use_gpu:
                word = word.cuda()
                char = char.cuda()
                pos = pos.cuda()
                heads = heads.cuda()
                types = types.cuda()
                masks = masks.cuda()
                lengths = lengths.cuda()
                if bert_inputs[0] is not None:
                    bert_inputs[0] = bert_inputs[0].cuda()
                    bert_inputs[1] = bert_inputs[1].cuda()
                    bert_inputs[2] = bert_inputs[2].cuda()

            real_enc = network(word,
                               char,
                               pos,
                               input_bert=bert_inputs,
                               mask=masks,
                               length=lengths,
                               hx=None)

            # ========== Update the discriminator ==========
            if t_count > 0 and skip_adv_tuning > args.delay:
                # fake examples = 0
                word_f, char_f, pos_f, heads_f, types_f, masks_f, lengths_f, bert_inputs = conllx_data.get_batch_variable(
                    train_data[fake_lang], batch_size, unk_replace=unk_replace)

                if use_gpu:
                    word_f = word_f.cuda()
                    char_f = char_f.cuda()
                    pos_f = pos_f.cuda()
                    heads_f = heads_f.cuda()
                    types_f = types_f.cuda()
                    masks_f = masks_f.cuda()
                    lengths_f = lengths_f.cuda()
                    if bert_inputs[0] is not None:
                        bert_inputs[0] = bert_inputs[0].cuda()
                        bert_inputs[1] = bert_inputs[1].cuda()
                        bert_inputs[2] = bert_inputs[2].cuda()

                fake_enc = network(word_f,
                                   char_f,
                                   pos_f,
                                   input_bert=bert_inputs,
                                   mask=masks_f,
                                   length=lengths_f,
                                   hx=None)

                # TODO: temporary crack
                if t_count > 0 and skip_adv_tuning > args.delay:
                    # skip discriminator training for '|n_critic|' iterations if 'n_critic' < 0
                    if args.n_critic > 0 or (batch - 1) % (-1 *
                                                           args.n_critic) == 0:
                        update_discriminator = True

            if update_discriminator:
                if args.adv_training:
                    real_loss, fake_loss, real_acc, fake_acc = AdvAgent.update(
                        real_enc['output'].detach(),
                        fake_enc['output'].detach(), real_idx, fake_idx)

                    loss_d_real.append(real_loss)
                    loss_d_fake.append(fake_loss)
                    acc_d_real.append(real_acc)
                    acc_d_fake.append(fake_acc)

                elif args.motivate:
                    real_loss, fake_loss, real_acc, fake_acc = MtvAgent.update(
                        real_enc['output'].detach(),
                        fake_enc['output'].detach(), real_idx, fake_idx)

                    loss_d_real.append(real_loss)
                    loss_d_fake.append(fake_loss)
                    acc_d_real.append(real_acc)
                    acc_d_fake.append(fake_acc)

                else:
                    raise NotImplementedError()

                if args.n_critic > 0 and (batch - 1) % args.n_critic != 0:
                    update_generator = False

            # ==============================================

            # =========== Update the generator =============
            if update_generator:
                others_loss = None
                if args.adv_training and skip_adv_tuning > args.delay:
                    # for GAN: L_G= L_parsing - (lambda_G * L_D)
                    # for GR : L_G= L_parsing +  L_D
                    others_loss = AdvAgent.gen_loss(real_enc['output'],
                                                    fake_enc['output'],
                                                    real_idx, fake_idx)
                    gen_loss.append(others_loss.item())

                elif args.motivate and skip_adv_tuning > args.delay:
                    others_loss = MtvAgent.gen_loss(real_enc['output'],
                                                    fake_enc['output'],
                                                    real_idx, fake_idx)
                    gen_loss.append(others_loss.item())

                optim.zero_grad()

                loss_arc, loss_type = loss_fn(real_enc['output'],
                                              heads,
                                              types,
                                              mask=masks,
                                              length=lengths)
                loss = loss_arc + loss_type

                num_inst = word.size(
                    0) if obj == 'crf' else masks.sum() - word.size(0)
                train_err += loss.item() * num_inst
                train_err_arc += loss_arc.item() * num_inst
                train_err_type += loss_type.item() * num_inst
                train_total += num_inst
                parsing_loss.append(loss.item())

                if others_loss is not None:
                    loss = loss + others_loss

                loss.backward()
                clip_grad_norm_(network.parameters(), clip)
                optim.step()

                time_ave = (time.time() - start_time) / batch
                time_left = (num_batches - batch) * time_ave

        if (args.adv_training
                or args.motivate) and skip_adv_tuning > args.delay:
            logger.info(
                'epoch: %d train: %d loss: %.4f, arc: %.4f, type: %.4f, dis_loss: (%.2f, %.2f), dis_acc: (%.2f, %.2f), '
                'gen_loss: %.2f, time: %.2fs' %
                (epoch, num_batches, train_err / train_total,
                 train_err_arc / train_total, train_err_type / train_total,
                 sum(loss_d_real) / len(loss_d_real), sum(loss_d_fake) /
                 len(loss_d_fake), sum(acc_d_real) / len(acc_d_real),
                 sum(acc_d_fake) / len(acc_d_fake),
                 sum(gen_loss) / len(gen_loss), time.time() - start_time))
        else:
            logger.info(
                'epoch: %d train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs'
                % (epoch, num_batches, train_err / train_total,
                   train_err_arc / train_total, train_err_type / train_total,
                   time.time() - start_time))

        ################# Validation on Dependency Parsing Only #################################
        if epoch % args.check_dev != 0:
            continue

        with torch.no_grad():
            # evaluate performance on dev data
            network.eval()

            dev_ucorr = 0.0
            dev_lcorr = 0.0
            dev_total = 0
            dev_ucomlpete = 0.0
            dev_lcomplete = 0.0
            dev_ucorr_nopunc = 0.0
            dev_lcorr_nopunc = 0.0
            dev_total_nopunc = 0
            dev_ucomlpete_nopunc = 0.0
            dev_lcomplete_nopunc = 0.0
            dev_root_corr = 0.0
            dev_total_root = 0.0
            dev_total_inst = 0.0

            for lang, data_dev in dev_data.items():
                for batch in conllx_data.iterate_batch_variable(
                        data_dev, batch_size):
                    word, char, pos, heads, types, masks, lengths, bert_inputs = batch

                    if use_gpu:
                        word = word.cuda()
                        char = char.cuda()
                        pos = pos.cuda()
                        heads = heads.cuda()
                        types = types.cuda()
                        masks = masks.cuda()
                        lengths = lengths.cuda()
                        if bert_inputs[0] is not None:
                            bert_inputs[0] = bert_inputs[0].cuda()
                            bert_inputs[1] = bert_inputs[1].cuda()
                            bert_inputs[2] = bert_inputs[2].cuda()

                    heads_pred, types_pred = decode(
                        word,
                        char,
                        pos,
                        input_bert=bert_inputs,
                        mask=masks,
                        length=lengths,
                        leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                    word = word.cpu().numpy()
                    pos = pos.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.cpu().numpy()
                    types = types.cpu().numpy()

                    stats, stats_nopunc, stats_root, num_inst = parser.eval(
                        word,
                        pos,
                        heads_pred,
                        types_pred,
                        heads,
                        types,
                        word_alphabet,
                        pos_alphabet,
                        lengths,
                        punct_set=punct_set,
                        symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    dev_ucorr += ucorr
                    dev_lcorr += lcorr
                    dev_total += total
                    dev_ucomlpete += ucm
                    dev_lcomplete += lcm

                    dev_ucorr_nopunc += ucorr_nopunc
                    dev_lcorr_nopunc += lcorr_nopunc
                    dev_total_nopunc += total_nopunc
                    dev_ucomlpete_nopunc += ucm_nopunc
                    dev_lcomplete_nopunc += lcm_nopunc

                    dev_root_corr += corr_root
                    dev_total_root += total_root
                    dev_total_inst += num_inst

            print(
                'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
                % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 /
                   dev_total, dev_lcorr * 100 / dev_total, dev_ucomlpete *
                   100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
            print(
                'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
                %
                (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                 dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
                 100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
                 dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
            print('Root: corr: %d, total: %d, acc: %.2f%%' %
                  (dev_root_corr, dev_total_root,
                   dev_root_corr * 100 / dev_total_root))

            if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                    dev_lcorrect_nopunc == dev_lcorr_nopunc
                    and dev_ucorrect_nopunc < dev_ucorr_nopunc):
                dev_ucorrect_nopunc = dev_ucorr_nopunc
                dev_lcorrect_nopunc = dev_lcorr_nopunc
                dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
                dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

                dev_ucorrect = dev_ucorr
                dev_lcorrect = dev_lcorr
                dev_ucomlpete_match = dev_ucomlpete
                dev_lcomplete_match = dev_lcomplete

                dev_root_correct = dev_root_corr

                best_epoch = epoch
                patient = 0

                state_dict = network.module.state_dict(
                ) if args.parallel else network.state_dict()
                torch.save(state_dict, model_name)

            else:
                if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                    state_dict = torch.load(model_name)
                    if args.parallel:
                        network.module.load_state_dict(state_dict)
                    else:
                        network.load_state_dict(state_dict)

                    lr = lr * decay_rate
                    optim = generate_optimizer(opt, lr, network.parameters(),
                                               betas, gamma, eps, momentum)

                    if decoding == 'greedy':
                        decode = network.module.decode if args.parallel else network.decode
                    elif decoding == 'mst':
                        decode = network.module.decode_mst if args.parallel else network.decode_mst
                    else:
                        raise ValueError('Unknown decoding algorithm: %s' %
                                         decoding)

                    patient = 0
                    decay += 1
                    if decay % double_schedule_decay == 0:
                        schedule *= 2
                else:
                    patient += 1

            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            print(
                'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 /
                   dev_total, dev_lcorrect * 100 / dev_total,
                   dev_ucomlpete_match * 100 / dev_total_inst,
                   dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
            print(
                'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
                   dev_ucorrect_nopunc * 100 / dev_total_nopunc,
                   dev_lcorrect_nopunc * 100 / dev_total_nopunc,
                   dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
                   dev_lcomplete_match_nopunc * 100 / dev_total_inst,
                   best_epoch))
            print(
                'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                % (dev_root_correct, dev_total_root,
                   dev_root_correct * 100 / dev_total_root, best_epoch))
            print(
                '----------------------------------------------------------------------------------------------------------------------------'
            )
            if decay == max_decay:
                break

        torch.cuda.empty_cache()  # release memory that can be released
Ejemplo n.º 20
0
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with stack pointer parser')
    args_parser.add_argument('--seed',
                             type=int,
                             default=1234,
                             help='random seed for reproducibility')
    args_parser.add_argument('--mode',
                             choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn',
                             required=True)
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--decoder_input_size',
                             type=int,
                             default=256,
                             help='Number of input units in decoder RNN.')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--type_space',
                             type=int,
                             default=128,
                             help='Dimension of tag space')
    args_parser.add_argument('--encoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of encoder RNN')
    args_parser.add_argument('--decoder_layers',
                             type=int,
                             default=1,
                             help='Number of layers of decoder RNN')
    args_parser.add_argument('--num_filters',
                             type=int,
                             default=50,
                             help='Number of filters in CNN')
    args_parser.add_argument(
        '--trans_hid_size',
        type=int,
        default=1024,
        help='#hidden units in point-wise feed-forward in transformer')
    args_parser.add_argument(
        '--d_k',
        type=int,
        default=64,
        help='d_k for multi-head-attention in transformer encoder')
    args_parser.add_argument(
        '--d_v',
        type=int,
        default=64,
        help='d_v for multi-head-attention in transformer encoder')
    args_parser.add_argument('--multi_head_attn',
                             action='store_true',
                             help='use multi-head-attention.')
    args_parser.add_argument('--num_head',
                             type=int,
                             default=8,
                             help='Value of h in multi-head attention')
    args_parser.add_argument(
        '--pool_type',
        default='mean',
        choices=['max', 'mean', 'weight'],
        help='pool type to form fixed length vector from word embeddings')
    args_parser.add_argument('--train_position',
                             action='store_true',
                             help='train positional encoding for transformer.')
    args_parser.add_argument('--no_word',
                             action='store_true',
                             help='do not use word embedding.')
    args_parser.add_argument('--pos',
                             action='store_true',
                             help='use part-of-speech embedding.')
    args_parser.add_argument('--char',
                             action='store_true',
                             help='use character embedding and CNN.')
    args_parser.add_argument('--no_CoRNN',
                             action='store_true',
                             help='do not use context RNN.')
    args_parser.add_argument('--pos_dim',
                             type=int,
                             default=50,
                             help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim',
                             type=int,
                             default=50,
                             help='Dimension of Character embeddings')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.001,
                             help='Learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--coverage',
                             type=float,
                             default=0.0,
                             help='weight for coverage loss')
    args_parser.add_argument('--p_rnn',
                             nargs='+',
                             type=float,
                             required=True,
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--label_smooth',
                             type=float,
                             default=1.0,
                             help='weight of label smoothing method')
    args_parser.add_argument('--skipConnect',
                             action='store_true',
                             help='use skip connection for decoder RNN.')
    args_parser.add_argument('--grandPar',
                             action='store_true',
                             help='use grand parent.')
    args_parser.add_argument('--sibling',
                             action='store_true',
                             help='use sibling.')
    args_parser.add_argument(
        '--prior_order',
        choices=['inside_out', 'left2right', 'deep_first', 'shallow_first'],
        help='prior order of children.',
        required=True)
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation',
                             nargs='+',
                             type=str,
                             help='List of punctuations')
    args_parser.add_argument('--beam',
                             type=int,
                             default=1,
                             help='Beam size for decoding')
    args_parser.add_argument(
        '--word_embedding',
        choices=['word2vec', 'glove', 'senna', 'sskip', 'polyglot'],
        help='Embedding for words',
        required=True)
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding',
                             choices=['random', 'polyglot'],
                             help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path',
                             help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--vocab_path',
                             help='path for prebuilt alphabets.',
                             default=None)
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             required=True)
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             required=True)
    args_parser.add_argument(
        '--position_embed_num',
        type=int,
        default=200,
        help=
        'Minimum value of position embedding num, which usually is max-sent-length.'
    )
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=2000,
                             help='Number of training epochs')

    # Learning-rate schedule: optional linear warmup over the first epoch.
    args_parser.add_argument('--use_warmup_schedule',
                             action='store_true',
                             help="Use warmup lrate schedule.")
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.75,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--max_decay',
                             type=int,
                             default=9,
                             help='Number of decays before stop')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument('--double_schedule_decay',
                             type=int,
                             default=5,
                             help='Number of decays to double schedule')
    args_parser.add_argument(
        '--check_dev',
        type=int,
        default=5,
        help='Check development performance in every n\'th iteration')
    #
    # about decoder's bi-attention scoring with features (default is not using any)
    args_parser.add_argument(
        '--dec_max_dist',
        type=int,
        default=0,
        help=
        "The clamp range of decoder's distance feature, 0 means turning off.")
    args_parser.add_argument('--dec_dim_feature',
                             type=int,
                             default=10,
                             help="Dim for feature embed.")
    args_parser.add_argument(
        '--dec_use_neg_dist',
        action='store_true',
        help="Use negative distance for dec's distance feature.")
    args_parser.add_argument(
        '--dec_use_encoder_pos',
        action='store_true',
        help="Use pos feature combined with distance feature for child nodes.")
    args_parser.add_argument(
        '--dec_use_decoder_pos',
        action='store_true',
        help="Use pos feature combined with distance feature for head nodes.")
    args_parser.add_argument('--dec_drop_f_embed',
                             type=float,
                             default=0.2,
                             help="Dropout for dec feature embeddings.")
    #
    # about relation-aware self attention for the transformer encoder (default is not using any)
    # args_parser.add_argument('--rel_aware', action='store_true',
    #                          help="Enable relation-aware self-attention (multi_head_attn flag needs to be set).")
    args_parser.add_argument(
        '--enc_use_neg_dist',
        action='store_true',
        help="Use negative distance for enc's relational-distance embedding.")
    args_parser.add_argument(
        '--enc_clip_dist',
        type=int,
        default=0,
        help="The clipping distance for relative position features.")
    #
    # other options about how to combine multiple input features (have to make some dims fit if not concat)
    args_parser.add_argument('--input_concat_embeds',
                             action='store_true',
                             help="Concat input embeddings, otherwise add.")
    args_parser.add_argument('--input_concat_position',
                             action='store_true',
                             help="Concat position embeddings, otherwise add.")
    args_parser.add_argument('--position_dim',
                             type=int,
                             default=300,
                             help='Dimension of Position embeddings.')
    #
    args_parser.add_argument(
        '--train_len_thresh',
        type=int,
        default=100,
        help='In training, discard sentences longer than this.')

    args = args_parser.parse_args()

    # =====
    # fix data-prepare seed
    random.seed(1234)
    np.random.seed(1234)
    # model's seed
    torch.manual_seed(args.seed)

    # =====

    # if output directory doesn't exist, create it
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    logger = get_logger("PtrParser", args.model_path + 'log.txt')

    logger.info('\ncommand-line params : {0}\n'.format(sys.argv[1:]))
    logger.info('{0}\n'.format(args))

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    vocab_path = args.vocab_path if args.vocab_path is not None else args.model_path
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    input_size_decoder = args.decoder_input_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    encoder_layers = args.encoder_layers
    decoder_layers = args.decoder_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    cov = args.coverage
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    label_smooth = args.label_smooth
    unk_replace = args.unk_replace
    prior_order = args.prior_order
    skipConnect = args.skipConnect
    grandPar = args.grandPar
    sibling = args.sibling
    beam = args.beam
    punctuation = args.punctuation

    freeze = args.freeze
    use_word_emb = not args.no_word
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_con_rnn = not args.no_CoRNN

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(
        word_embedding, word_path) if use_word_emb else (None, 0)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(vocab_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)

    # TODO(warn): the vocabularies must already have been built before this point.
    assert os.path.isdir(alphabet_path), "should have build vocabs previously"
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = conllx_stacked_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        max_vocabulary_size=50000,
        embedd_dict=word_dict)
    # word_alphabet, char_alphabet, pos_alphabet, type_alphabet, max_sent_length = create_alphabets(alphabet_path,
    #     train_path, data_paths=[dev_path, test_path], max_vocabulary_size=50000, embedd_dict=word_dict)
    max_sent_length = max(max_sent_length, args.position_embed_num)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    # ===== the reading
    def _read_one(path, is_train):
        """Read one data file into stacked variables.

        The language id is guessed from ``path`` and logged. Training data is
        read with ``args.train_len_thresh`` as the length threshold; any other
        data uses 100000. ``volatile`` is the negation of ``is_train``.
        """
        guessed_lang = guess_language_id(path)
        logger.info("Reading: guess that the language of file %s is %s." %
                    (path, guessed_lang))

        if is_train:
            length_limit = args.train_len_thresh
        else:
            length_limit = 100000

        reader_options = {
            'use_gpu': use_gpu,
            'volatile': (not is_train),
            'prior_order': prior_order,
            'lang_id': guessed_lang,
            'len_thresh': length_limit,
        }
        return conllx_stacked_data.read_stacked_data_to_variable(
            path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet,
            **reader_options)

    data_train = _read_one(train_path, True)
    num_data = sum(data_train[1])

    data_dev = _read_one(dev_path, False)
    data_test = _read_one(test_path, False)
    # =====

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix as a torch tensor.

        Each word in ``word_alphabet`` is looked up in ``word_dict``, first
        verbatim and then lower-cased. Words found neither way (and the UNK
        row) get a fallback row: all zeros when ``freeze`` is set, otherwise a
        uniform draw from ``[-sqrt(3 / word_dim), sqrt(3 / word_dim)]``. The
        number of fallback words is logged as the OOV count.
        """
        bound = np.sqrt(3.0 / word_dim)

        def _fallback_row():
            # Frozen embeddings cannot be trained, so unknown rows stay zero.
            if freeze:
                return np.zeros([1, word_dim]).astype(np.float32)
            return np.random.uniform(-bound, bound,
                                     [1, word_dim]).astype(np.float32)

        matrix = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        matrix[conllx_stacked_data.UNK_ID, :] = _fallback_row()

        missing = 0
        for token, row in word_alphabet.items():
            if token in word_dict:
                matrix[row, :] = word_dict[token]
                continue
            lowered = token.lower()
            if lowered in word_dict:
                matrix[row, :] = word_dict[lowered]
                continue
            # Not covered by the pretrained embeddings at all.
            matrix[row, :] = _fallback_row()
            missing += 1

        logger.info('word OOV: %d' % missing)
        return torch.from_numpy(matrix)

    def construct_char_embedding_table():
        """Build the initial character-embedding matrix, or return None.

        Returns None when ``char_dict`` is None. Otherwise every character in
        ``char_alphabet`` takes its vector from ``char_dict``; characters
        missing from it (and the UNK row) are drawn uniformly from
        ``[-sqrt(3 / char_dim), sqrt(3 / char_dim)]``. The number of missing
        characters is logged as the OOV count.
        """
        if char_dict is None:
            return None

        bound = np.sqrt(3.0 / char_dim)

        def _random_row():
            return np.random.uniform(-bound, bound,
                                     [1, char_dim]).astype(np.float32)

        matrix = np.empty([num_chars, char_dim], dtype=np.float32)
        matrix[conllx_stacked_data.UNK_ID, :] = _random_row()

        missing = 0
        for symbol, row in char_alphabet.items():
            if symbol in char_dict:
                matrix[row, :] = char_dict[symbol]
            else:
                matrix[row, :] = _random_row()
                missing += 1

        logger.info('character OOV: %d' % missing)
        return torch.from_numpy(matrix)

    word_table = construct_word_embedding_table() if use_word_emb else None
    char_table = construct_char_embedding_table()

    window = 3
    network = StackPtrNet(word_dim,
                          num_words,
                          char_dim,
                          num_chars,
                          pos_dim,
                          num_pos,
                          num_filters,
                          window,
                          mode,
                          input_size_decoder,
                          hidden_size,
                          encoder_layers,
                          decoder_layers,
                          num_types,
                          arc_space,
                          type_space,
                          args.pool_type,
                          args.multi_head_attn,
                          args.num_head,
                          max_sent_length,
                          args.trans_hid_size,
                          args.d_k,
                          args.d_v,
                          train_position=args.train_position,
                          embedd_word=word_table,
                          embedd_char=char_table,
                          p_in=p_in,
                          p_out=p_out,
                          p_rnn=p_rnn,
                          biaffine=True,
                          use_word_emb=use_word_emb,
                          pos=use_pos,
                          char=use_char,
                          prior_order=prior_order,
                          use_con_rnn=use_con_rnn,
                          skipConnect=skipConnect,
                          grandPar=grandPar,
                          sibling=sibling,
                          use_gpu=use_gpu,
                          dec_max_dist=args.dec_max_dist,
                          dec_use_neg_dist=args.dec_use_neg_dist,
                          dec_use_encoder_pos=args.dec_use_encoder_pos,
                          dec_use_decoder_pos=args.dec_use_decoder_pos,
                          dec_dim_feature=args.dec_dim_feature,
                          dec_drop_f_embed=args.dec_drop_f_embed,
                          enc_clip_dist=args.enc_clip_dist,
                          enc_use_neg_dist=args.enc_use_neg_dist,
                          input_concat_embeds=args.input_concat_embeds,
                          input_concat_position=args.input_concat_position,
                          position_dim=args.position_dim)

    def save_args():
        """Dump the network's construction arguments to ``<model_name>.arg.json``.

        The JSON object has two entries: ``'args'``, a list in the same order
        as the positional arguments of the ``StackPtrNet(...)`` call in this
        function's enclosing scope, and ``'kwargs'``, a dict of keyword
        settings. The file is written with an indent of 4.
        """
        arg_path = model_name + '.arg.json'
        arguments = [
            word_dim, num_words, char_dim, num_chars, pos_dim, num_pos,
            num_filters, window, mode, input_size_decoder, hidden_size,
            encoder_layers, decoder_layers, num_types, arc_space, type_space,
            args.pool_type, args.multi_head_attn, args.num_head,
            max_sent_length, args.trans_hid_size, args.d_k, args.d_v
        ]
        kwargs = {
            'train_position': args.train_position,
            'use_word_emb': use_word_emb,
            'use_con_rnn': use_con_rnn,
            'p_in': p_in,
            'p_out': p_out,
            'p_rnn': p_rnn,
            'biaffine': True,
            'pos': use_pos,
            'char': use_char,
            'prior_order': prior_order,
            'skipConnect': skipConnect,
            'grandPar': grandPar,
            'sibling': sibling,
            'dec_max_dist': args.dec_max_dist,
            'dec_use_neg_dist': args.dec_use_neg_dist,
            'dec_use_encoder_pos': args.dec_use_encoder_pos,
            'dec_use_decoder_pos': args.dec_use_decoder_pos,
            'dec_dim_feature': args.dec_dim_feature,
            'dec_drop_f_embed': args.dec_drop_f_embed,
            'enc_clip_dist': args.enc_clip_dist,
            'enc_use_neg_dist': args.enc_use_neg_dist,
            'input_concat_embeds': args.input_concat_embeds,
            'input_concat_position': args.input_concat_position,
            'position_dim': args.position_dim
        }
        # Close the file deterministically so the JSON is flushed to disk
        # instead of relying on garbage collection of an anonymous handle.
        with open(arg_path, 'w') as arg_file:
            json.dump({
                'args': arguments,
                'kwargs': kwargs
            },
                      arg_file,
                      indent=4)

    if use_word_emb and freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)

    def generate_optimizer(opt, lr, params):
        """Create the optimizer named by ``opt`` over the trainable parameters.

        Only parameters whose ``requires_grad`` is true are handed to the
        optimizer. ``'adam'`` and ``'adamax'`` share the enclosing scope's
        ``betas``/``eps``; ``'sgd'`` uses ``momentum`` with Nesterov enabled.
        All three use ``gamma`` as weight decay.

        Raises:
            ValueError: if ``opt`` is not one of 'adam', 'sgd' or 'adamax'.
        """
        trainable = (p for p in params if p.requires_grad)

        if opt == 'sgd':
            return SGD(trainable,
                       lr=lr,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)

        if opt in ('adam', 'adamax'):
            # Both Adam variants take exactly the same hyper-parameters here.
            optimizer_cls = Adam if opt == 'adam' else Adamax
            return optimizer_cls(trainable,
                                 lr=lr,
                                 betas=betas,
                                 weight_decay=gamma,
                                 eps=eps)

        raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info(
        "Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" %
        (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info(
        "RNN: %s, num_layer=(%d, %d), input_dec=%d, hidden=%d, arc_space=%d, type_space=%d"
        % (mode, encoder_layers, decoder_layers, input_size_decoder,
           hidden_size, arc_space, type_space))
    logger.info(
        "train: cov: %.1f, (#data: %d, batch: %d, clip: %.2f, label_smooth: %.2f, unk_repl: %.2f)"
        % (cov, num_data, batch_size, clip, label_smooth, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))
    logger.info('prior order: %s, grand parent: %s, sibling: %s, ' %
                (prior_order, grandPar, sibling))
    logger.info('skip connect: %s, beam: %d' % (skipConnect, beam))
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    # lrate decay
    patient = 0
    decay = 0
    max_decay = args.max_decay
    double_schedule_decay = args.double_schedule_decay

    # lrate schedule
    step_num = 0
    use_warmup_schedule = args.use_warmup_schedule
    warmup_factor = (lr + 0.) / num_batches

    if use_warmup_schedule:
        logger.info("Use warmup lrate for the first epoch, from 0 up to %s." %
                    (lr, ))
    #
    for epoch in range(1, num_epochs + 1):
        logger.info(
            'Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f '
            '(schedule=%d, patient=%d, decay=%d (%d, %d))): ' %
            (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay,
             max_decay, double_schedule_decay))
        train_err_arc_leaf = 0.
        train_err_arc_non_leaf = 0.
        train_err_type_leaf = 0.
        train_err_type_non_leaf = 0.
        train_err_cov = 0.
        train_total_leaf = 0.
        train_total_non_leaf = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            # lrate schedule (before each step)
            step_num += 1
            if use_warmup_schedule and epoch <= 1:
                cur_lrate = warmup_factor * step_num
                # set lr
                for param_group in optim.param_groups:
                    param_group['lr'] = cur_lrate

            # train
            input_encoder, input_decoder = conllx_stacked_data.get_batch_stacked_variable(
                data_train, batch_size, unk_replace=unk_replace)
            word, char, pos, heads, types, masks_e, lengths_e = input_encoder
            stacked_heads, children, sibling, stacked_types, skip_connect, masks_d, lengths_d = input_decoder

            optim.zero_grad()
            loss_arc_leaf, loss_arc_non_leaf, \
            loss_type_leaf, loss_type_non_leaf, \
            loss_cov, num_leaf, num_non_leaf = network.loss(word, char, pos, heads, stacked_heads, children, sibling,
                                                            stacked_types, label_smooth,
                                                            skip_connect=skip_connect, mask_e=masks_e,
                                                            length_e=lengths_e, mask_d=masks_d, length_d=lengths_d)
            loss_arc = loss_arc_leaf + loss_arc_non_leaf
            loss_type = loss_type_leaf + loss_type_non_leaf
            loss = loss_arc + loss_type + cov * loss_cov
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_leaf = num_leaf.data[0]
            num_non_leaf = num_non_leaf.data[0]

            train_err_arc_leaf += loss_arc_leaf.data[0] * num_leaf
            train_err_arc_non_leaf += loss_arc_non_leaf.data[0] * num_non_leaf

            train_err_type_leaf += loss_type_leaf.data[0] * num_leaf
            train_err_type_non_leaf += loss_type_non_leaf.data[0] * num_non_leaf

            train_err_cov += loss_cov.data[0] * (num_leaf + num_non_leaf)

            train_total_leaf += num_leaf
            train_total_non_leaf += num_non_leaf

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                err_arc_leaf = train_err_arc_leaf / train_total_leaf
                err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
                err_arc = err_arc_leaf + err_arc_non_leaf

                err_type_leaf = train_err_type_leaf / train_total_leaf
                err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
                err_type = err_type_leaf + err_type_non_leaf

                err_cov = train_err_cov / (train_total_leaf +
                                           train_total_non_leaf)

                err = err_arc + err_type + cov * err_cov
                log_info = 'train: %d/%d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, err, err_arc, err_arc_leaf,
                    err_arc_non_leaf, err_type, err_type_leaf,
                    err_type_non_leaf, err_cov, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        err_arc_leaf = train_err_arc_leaf / train_total_leaf
        err_arc_non_leaf = train_err_arc_non_leaf / train_total_non_leaf
        err_arc = err_arc_leaf + err_arc_non_leaf

        err_type_leaf = train_err_type_leaf / train_total_leaf
        err_type_non_leaf = train_err_type_non_leaf / train_total_non_leaf
        err_type = err_type_leaf + err_type_non_leaf

        err_cov = train_err_cov / (train_total_leaf + train_total_non_leaf)

        err = err_arc + err_type + cov * err_cov
        logger.info(
            'train: %d loss (leaf, non_leaf): %.4f, arc: %.4f (%.4f, %.4f), type: %.4f (%.4f, %.4f), coverage: %.4f, time: %.2fs'
            % (num_batches, err, err_arc, err_arc_leaf, err_arc_non_leaf,
               err_type, err_type_leaf, err_type_non_leaf, err_cov,
               time.time() - start_time))

        ################################################################################################
        if epoch % args.check_dev != 0:
            continue

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                data_dev, batch_size):
            input_encoder, _ = batch
            word, char, pos, heads, types, masks, lengths = input_encoder
            heads_pred, types_pred, _, _ = network.decode(
                word,
                char,
                pos,
                mask=masks,
                length=lengths,
                beam=beam,
                leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word,
                              pos,
                              heads_pred,
                              types_pred,
                              lengths,
                              symbolic_root=True)
            gold_writer.write(word,
                              pos,
                              heads,
                              types,
                              lengths,
                              symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(
                word,
                pos,
                heads_pred,
                types_pred,
                heads,
                types,
                word_alphabet,
                pos_alphabet,
                lengths,
                punct_set=punct_set,
                symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total,
               dev_lcorr * 100 / dev_total, dev_ucomlpete * 100 /
               dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%'
            % (dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
               dev_ucorr_nopunc * 100 / dev_total_nopunc, dev_lcorr_nopunc *
               100 / dev_total_nopunc, dev_ucomlpete_nopunc * 100 /
               dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %
              (dev_root_corr, dev_total_root,
               dev_root_corr * 100 / dev_total_root))

        if dev_lcorrect_nopunc < dev_lcorr_nopunc or (
                dev_lcorrect_nopunc == dev_lcorr_nopunc
                and dev_ucorrect_nopunc < dev_ucorr_nopunc):
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch
            patient = 0
            # torch.save(network, model_name)
            torch.save(network.state_dict(), model_name)

            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_stacked_data.iterate_batch_stacked_variable(
                    data_test, batch_size):
                input_encoder, _ = batch
                word, char, pos, heads, types, masks, lengths = input_encoder
                heads_pred, types_pred, _, _ = network.decode(
                    word,
                    char,
                    pos,
                    mask=masks,
                    length=lengths,
                    beam=beam,
                    leading_symbolic=conllx_stacked_data.NUM_SYMBOLIC_TAGS)

                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word,
                                  pos,
                                  heads_pred,
                                  types_pred,
                                  lengths,
                                  symbolic_root=True)
                gold_writer.write(word,
                                  pos,
                                  heads,
                                  types,
                                  lengths,
                                  symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(
                    word,
                    pos,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_alphabet,
                    pos_alphabet,
                    lengths,
                    punct_set=punct_set,
                    symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
        else:
            if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())
                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

        logger.info(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        logger.info(
            'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect, dev_lcorrect, dev_total,
               dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
               dev_ucomlpete_match * 100 / dev_total_inst,
               dev_lcomplete_match * 100 / dev_total_inst, best_epoch))
        logger.info(
            'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
               dev_ucorrect_nopunc * 100 / dev_total_nopunc,
               dev_lcorrect_nopunc * 100 / dev_total_nopunc,
               dev_ucomlpete_match_nopunc * 100 / dev_total_inst,
               dev_lcomplete_match_nopunc * 100 / dev_total_inst, best_epoch))
        logger.info(
            'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
            (dev_root_correct, dev_total_root,
             dev_root_correct * 100 / dev_total_root, best_epoch))
        logger.info(
            '----------------------------------------------------------------------------------------------------------------------------'
        )
        logger.info(
            'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            % (test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 /
               test_total, test_lcorrect * 100 / test_total,
               test_ucomlpete_match * 100 / test_total_inst,
               test_lcomplete_match * 100 / test_total_inst, best_epoch))
        logger.info(
            'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
            %
            (test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
             test_ucorrect_nopunc * 100 / test_total_nopunc,
             test_lcorrect_nopunc * 100 / test_total_nopunc,
             test_ucomlpete_match_nopunc * 100 / test_total_inst,
             test_lcomplete_match_nopunc * 100 / test_total_inst, best_epoch))
        logger.info(
            'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' %
            (test_root_correct, test_total_root,
             test_root_correct * 100 / test_total_root, best_epoch))
        logger.info(
            '============================================================================================================================'
        )

        if decay == max_decay:
            break
Ejemplo n.º 21
0
def main():
    """Train a bi-directional RNN-CNN tagger and report dev/test scores.

    The function parses command-line options, builds the word/char/target
    alphabets and the pretrained word-embedding table, and trains either
    ``BiRecurrentConv2`` (``--dropout std``) or ``BiVarRecurrentConv``
    (``--dropout variational``).  After every epoch the dev set is scored
    with ``evaluate``; whenever dev accuracy improves, the whole network is
    saved with ``torch.save``.  Once training ends, the best checkpoint is
    reloaded and scored on both the dev and test sets.

    Every ``--schedule`` epochs the learning rate is recomputed as
    ``learning_rate / (1 + epoch * decay_rate)`` and the optimizer rebuilt.

    NOTE(review): ``uid`` is not defined in this function; it is presumably
    a module-level run identifier -- confirm against the file header.
    NOTE(review): ``loss.data[0]`` / ``corr.data[0]`` assume the pre-0.4
    PyTorch tensor API -- confirm the pinned torch version.
    """
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--num_layers',
                        type=int,
                        default=2,
                        help='Number of layers')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--bidirectional', default=True)
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.1,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--data_path')
    parser.add_argument('--modelname',
                        default="ASR_ERR_LSTM.json.pth.tar",
                        help='model name')
    parser.add_argument('--task',
                        default="MEDIA",
                        help='task name : MEDIA or ATIS')
    parser.add_argument('--optim',
                        default="SGD",
                        help=' Optimizer : SGD or ADAM')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"

    args = parser.parse_args()

    # Fail fast: the schedule is used as a ``%d`` format argument and as a
    # modulus in the epoch loop, so a missing or zero value would otherwise
    # only crash after all data and embeddings have been loaded.
    if not args.schedule:
        parser.error('--schedule is required and must be a non-zero integer')

    log_file = '%s/log/log_model_%s_mode_%s_num_epochs_%d_batch_size_%d_hidden_size_%d_num_layers_%d_optim_%s_lr_%f_tag_space_%s.txt' % (
        args.data_path, args.modelname, args.mode, args.num_epochs,
        args.batch_size, args.hidden_size, args.num_layers, args.optim,
        args.learning_rate, str(args.tag_space))
    logger = get_logger("SLU_BLSTM", log_file)

    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    data_path = args.data_path
    bidirectional = args.bidirectional
    p = args.p
    unk_replace = args.unk_replace
    embedding = args.embedding
    embedding_path = args.embedding_dict
    out_path = args.data_path
    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)
    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, target_alphabet = slu_data.create_alphabets(
        '%s/data_dic' % (data_path),
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict,
        max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("Target Alphabet Size: %d" % target_alphabet.size())
    logger.info("Bidirectionnal %s" % bidirectional)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = slu_data.read_data_to_variable(train_path,
                                                word_alphabet,
                                                char_alphabet,
                                                target_alphabet,
                                                use_gpu=use_gpu)
    num_data = sum(data_train[1])
    num_labels = target_alphabet.size()
    print(" num_labels", num_labels)
    data_dev = slu_data.read_data_to_variable(dev_path,
                                              word_alphabet,
                                              char_alphabet,
                                              target_alphabet,
                                              use_gpu=use_gpu,
                                              volatile=True)
    data_test = slu_data.read_data_to_variable(test_path,
                                               word_alphabet,
                                               char_alphabet,
                                               target_alphabet,
                                               use_gpu=use_gpu,
                                               volatile=True)
    writer = SLUWriter(word_alphabet, char_alphabet, target_alphabet)

    def construct_word_embedding_table():
        """Return a tensor of word vectors indexed by word-alphabet id.

        Words (or their lower-cased form) found in ``embedd_dict`` get the
        pretrained vector; all others get a uniform random vector and are
        counted as out-of-vocabulary.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[slu_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                embedding = embedd_dict[word]
            elif word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    print(" embedd_dim ", embedd_dim)
    if args.dropout == 'std':
        network = BiRecurrentConv2(embedd_dim,
                                   word_alphabet.size(),
                                   char_dim,
                                   char_alphabet.size(),
                                   num_filters,
                                   window,
                                   mode,
                                   hidden_size,
                                   num_layers,
                                   num_labels,
                                   tag_space=tag_space,
                                   embedd_word=word_table,
                                   p_rnn=p,
                                   bidirectional=bidirectional)
    else:
        network = BiVarRecurrentConv(embedd_dim,
                                     word_alphabet.size(),
                                     char_dim,
                                     char_alphabet.size(),
                                     num_filters,
                                     window,
                                     mode,
                                     hidden_size,
                                     num_layers,
                                     num_labels,
                                     tag_space=tag_space,
                                     embedd_word=word_table,
                                     p_rnn=p)
    print(network)
    if use_gpu:
        network.cuda()

    def make_optimizer(rate):
        """Build a fresh optimizer over the network parameters at *rate*."""
        if args.optim == "SGD":
            return SGD(network.parameters(),
                       lr=rate,
                       momentum=momentum,
                       weight_decay=gamma,
                       nesterov=True)
        return Adam(network.parameters(),
                    lr=rate,
                    betas=(0.9, 0.9),
                    weight_decay=gamma)

    def write_predictions(model, data, filename):
        """Run *model* over *data* and dump its predictions to *filename*."""
        writer.start(filename)
        for word, char, labels, masks, lengths in \
                slu_data.iterate_batch_variable(data, batch_size):
            _, _, preds = model.loss(
                word,
                char,
                labels,
                mask=masks,
                length=lengths,
                leading_symbolic=slu_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.data.cpu().numpy(),
                         preds.data.cpu().numpy(),
                         labels.data.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()

    lr = learning_rate
    optim = make_optimizer(lr)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d" %
        (mode, num_layers, hidden_size, num_filters, tag_space))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, dropout: %.2f, unk replace: %.2f)"
        % (gamma, num_data, batch_size, p, unk_replace))
    # Floor division keeps this an int on Python 3 as well; a float here
    # would make the ``range`` call below raise TypeError.
    num_batches = num_data // batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    best_epoch = 0
    model_path = ""
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, args.dropout, lr, decay_rate, schedule))
        train_err = 0.
        train_corr = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, labels, masks, lengths = slu_data.get_batch_variable(
                data_train, batch_size, unk_replace=unk_replace)
            optim.zero_grad()
            loss, corr, _ = network.loss(
                word,
                char,
                labels,
                mask=masks,
                length=lengths,
                leading_symbolic=slu_data.NUM_SYMBOLIC_TAGS)
            loss.backward()
            optim.step()

            # Weight the batch loss by its token count so that the running
            # average below is per token.
            num_tokens = masks.data.sum()
            train_err += loss.data[0] * num_tokens
            train_corr += corr.data[0]
            train_total += num_tokens
            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log: erase the previous progress line and redraw it
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, acc: %.2f%%, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total,
                    train_corr * 100 / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)
        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))
        logger.info(
            'train: %d loss: %.4f, time: %.2fs' %
            (num_batches, train_err / train_total, time.time() - start_time))
        print('train: %d loss: %.4f, acc: %.2f%%, time: %.2fs' %
              (num_batches, train_err / train_total,
               train_corr * 100 / train_total, time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        tmp_filename = '%s/predictions/dev_%s_num_layers_%s_%s.txt' % (
            out_path, args.optim, str(args.num_layers), str(uid))
        write_predictions(network, data_dev, tmp_filename)
        acc, precision, recall, f1 = evaluate(tmp_filename, data_path, "dev",
                                              args.task, args.optim)
        print(
            'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
            (acc, precision, recall, f1))
        logger.info(
            'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' %
            (acc, precision, recall, f1))
        if dev_acc < acc:
            dev_f1 = f1
            dev_acc = acc
            dev_precision = precision
            dev_recall = recall
            best_epoch = epoch

            # save best model; the test set is only scored after training,
            # using the checkpoint written here
            model_path = "%s/models/best_model_%s_mode_%s_num_epochs_%d_batch_size_%d_hidden_size_%d_num_layers_%d_bestdevacc_%f_bestepoch_%d_optim_%s_lr_%f_tag_space_%s" % (
                args.data_path, args.modelname, mode, num_epochs, batch_size,
                hidden_size, args.num_layers, dev_acc, best_epoch, args.optim,
                args.learning_rate, str(tag_space))
            torch.save(network, model_path)

        logger.info(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))

        # Periodic learning-rate decay; the optimizer is rebuilt so the new
        # rate takes effect.
        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            optim = make_optimizer(lr)

    # end epoch
    # test evaluation
    # load model
    print("model path ", model_path)
    if not model_path:
        # No epoch ever improved on the initial dev accuracy of 0, so there
        # is no checkpoint to reload.
        logger.error(
            "No checkpoint was saved during training; skipping final evaluation."
        )
        return
    network = torch.load(model_path)
    if use_gpu:
        network.cuda()
    # mode eval
    network.eval()
    # re-score the dev set with the best checkpoint
    tmp_filename = '%s/predictions/dev_best_model_%s_mode_%s_num_epochs_%d_batch_size_%d_hidden_size_%d_num_layers_%d_bestdevacc_%f_bestF1_%f_bestepoch_%d_optim_%s_lr_%f_tag_space_%s' % (
        out_path, args.modelname, mode, num_epochs, batch_size, hidden_size,
        num_layers, dev_acc, dev_f1, best_epoch, args.optim,
        args.learning_rate, tag_space)
    write_predictions(network, data_dev, tmp_filename)
    dev_acc, dev_precision, dev_recall, dev_f1 = evaluate(
        tmp_filename, data_path, "dev", args.task, args.optim)

    # score the test set with the best checkpoint; the file name embeds the
    # dev scores just recomputed above
    tmp_filename = '%s/predictions/test_best_model_%s_mode_%s_num_epochs_%d_batch_size_%d_hidden_size_%d_num_layers_%d_bestdevacc_%f_bestF1_%f_bestepoch_%d_optim_%s_lr_%f_tag_space_%s' % (
        out_path, args.modelname, mode, num_epochs, batch_size, hidden_size,
        num_layers, dev_acc, dev_f1, best_epoch, args.optim,
        args.learning_rate, tag_space)
    write_predictions(network, data_test, tmp_filename)
    test_acc, test_precision, test_recall, test_f1 = evaluate(
        tmp_filename, data_path, "test", args.task, args.optim)
    print(
        "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
        % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
    print(
        "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
        % (test_acc, test_precision, test_recall, test_f1, best_epoch))
    logger.info(
        "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
        % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
    logger.info(
        "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
        % (test_acc, test_precision, test_recall, test_f1, best_epoch))
Ejemplo n.º 22
0
def main():
    """Parse the command-line options and build the SRL / transfer alphabets.

    The function reads its configuration from ``sys.argv`` through
    ``argparse``, loads the word-embedding dictionary selected by
    ``--embedding`` / ``--embedding_dict`` and asks
    ``srl_data.create_alphabets`` to build the word, character, POS, chunk,
    SRL and transfer alphabets from the given train/dev/test files.  The
    size of every alphabet is reported through the ``SRLCRF`` logger.
    """

    def str2bool(value):
        """Turn a command-line string into a real boolean.

        ``type=bool`` is a trap with argparse: it calls ``bool("False")``,
        which is ``True`` because the string is non-empty.  The empty string
        is still mapped to ``False`` because that was the only value that
        produced ``False`` under ``type=bool``.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional RNN-CNN-CRF')
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        required=True)
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=128,
                        help='Number of hidden units in RNN')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_filters',
                        type=int,
                        default=30,
                        help='Number of filters in CNN')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--trig_dim',
                        type=int,
                        default=100,
                        help='Dimension of Trigger embeddings')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.015,
                        help='Learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--dropout',
                        choices=['std', 'variational'],
                        help='type of dropout',
                        required=True)
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')
    parser.add_argument('--schedule',
                        type=int,
                        help='schedule for learning rate decay')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'sskip', 'polyglot'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    # Arguments providing where to get the transfer-learning data
    parser.add_argument('--t_train')
    parser.add_argument('--t_dev')
    parser.add_argument('--t_test')
    # Flag to either run the transfer learning module or not.  The default
    # stays True; "--transfer false" (or no/0/f/n) now really disables it.
    parser.add_argument('--transfer',
                        type=str2bool,
                        default=True,
                        help='Flag to activate transfer learning')

    args = parser.parse_args()

    logger = get_logger("SRLCRF")

    # NOTE(review): several of the settings below are not consumed anywhere
    # in this function as it stands; they are kept untouched.
    mode = args.mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    transfer_train_path = args.t_train
    transfer_dev_path = args.t_dev
    transfer_test_path = args.t_test
    transfer = args.transfer
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p = args.p
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)
    ###########################################################################
    # Load Data from CoNLL task for SRL and the Transfer Data
    # Create alphabets from BOTH SRL and Process Bank
    ###########################################################################
    logger.info("Creating Alphabets")
    (word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, srl_alphabet,
     transfer_alphabet) = srl_data.create_alphabets(
         "data/alphabets/srl_crf/",
         train_path,
         data_paths=[dev_path, test_path],
         transfer_train_path=transfer_train_path,
         transfer_data_paths=[transfer_dev_path, transfer_test_path],
         transfer=transfer,
         embedd_dict=embedd_dict,
         max_vocabulary_size=55000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("SRL Alphabet Size: %d" % srl_alphabet.size())
    logger.info("Transfer Alphabet Size: %d" % transfer_alphabet.size())

    logger.info("Reading Data into Variables")
    use_gpu = torch.cuda.is_available()
Ejemplo n.º 23
0
def train(args):
    """Train a dependency parser and track the best dev/test scores.

    Steps, in order:

    1. Read the run settings from ``args`` (paths, optimizer settings,
       ``loss_type``, ``unk_replace``, ``freeze``, ``punctuation``,
       embedding names/paths, ``config``, ``beam`` and ``reset``).
    2. Load the word (and optionally character) embedding dictionaries and
       create the alphabets under ``<model_path>/alphabets``.
    3. Read the hyper-parameters from the JSON file ``args.config``, store a
       copy as ``<model_path>/config.json`` and build a ``DeepBiAffine``,
       ``NeuroMST`` or ``StackPtrNet`` network from them.
    4. Run ``num_epochs`` training epochs.  After each epoch the dev set is
       evaluated; whenever the sum of unlabeled and labeled correct counts
       without punctuation improves, the weights are saved to
       ``<model_path>/model.pt`` and the test set is evaluated as well.
    5. After ``reset`` epochs without improvement the best checkpoint is
       reloaded and ``scheduler.reset_state()`` is called.

    Prediction/gold files for each evaluation are written to
    ``<model_path>/tmp``.  The function returns ``None``; results are
    reported through ``print`` and the ``Parsing`` logger.

    Args:
        args: namespace-like object carrying the attributes listed above.
    """
    logger = get_logger("Parsing")

    args.cuda = torch.cuda.is_available()
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    optim = args.optim
    learning_rate = args.learning_rate
    lr_decay = args.lr_decay
    amsgrad = args.amsgrad
    eps = args.eps
    betas = (args.beta1, args.beta2)
    warmup_steps = args.warmup_steps
    weight_decay = args.weight_decay
    grad_clip = args.grad_clip

    # True -> normalise the loss per token, otherwise per sentence.
    loss_ty_token = args.loss_type == 'token'
    unk_replace = args.unk_replace
    freeze = args.freeze

    model_path = args.model_path
    model_name = os.path.join(model_path, 'model.pt')
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    print(args)

    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(
            char_embedding, char_path)
    else:
        # 'random' means no pre-trained character table; the dimension is
        # taken from the config file further below.
        char_dict = None
        char_dim = None

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path,
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=word_dict,
        max_vocabulary_size=200000)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    result_path = os.path.join(model_path, 'tmp')
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" %
                    (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word embedding matrix as a float32 tensor.

        Words found in ``word_dict`` (exact or lower-cased) get their
        pre-trained vector.  Every other word, and the UNK row, gets zeros
        when ``freeze`` is set and a uniform random vector otherwise.
        """
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(
            np.float32) if freeze else np.random.uniform(
                -scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(
                    np.float32) if freeze else np.random.uniform(
                        -scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the character embedding matrix, or ``None`` without a dict.

        Characters missing from ``char_dict`` (and the UNK row) are
        initialised with uniform random vectors.
        """
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale,
                                              [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    logger.info("constructing network...")

    # Read the hyper-parameters and keep a copy next to the model.  Both
    # files are closed deterministically so the copy is fully flushed.
    with open(args.config, 'r') as config_file:
        hyps = json.load(config_file)
    with open(os.path.join(model_path, 'config.json'), 'w') as config_copy:
        json.dump(hyps, config_copy, indent=2)
    model_type = hyps['model']
    assert model_type in ['DeepBiAffine', 'NeuroMST', 'StackPtr']
    assert word_dim == hyps['word_dim']
    if char_dim is not None:
        assert char_dim == hyps['char_dim']
    else:
        char_dim = hyps['char_dim']
    use_pos = hyps['pos']
    pos_dim = hyps['pos_dim']
    mode = hyps['rnn_mode']
    hidden_size = hyps['hidden_size']
    arc_space = hyps['arc_space']
    type_space = hyps['type_space']
    p_in = hyps['p_in']
    p_out = hyps['p_out']
    p_rnn = hyps['p_rnn']
    activation = hyps['activation']
    prior_order = None

    # StackPtr is trained on the stacked (transition) data format; the other
    # two models use the plain graph-based format.
    alg = 'transition' if model_type == 'StackPtr' else 'graph'
    if model_type == 'DeepBiAffine':
        num_layers = hyps['num_layers']
        network = DeepBiAffine(word_dim,
                               num_words,
                               char_dim,
                               num_chars,
                               pos_dim,
                               num_pos,
                               mode,
                               hidden_size,
                               num_layers,
                               num_types,
                               arc_space,
                               type_space,
                               embedd_word=word_table,
                               embedd_char=char_table,
                               p_in=p_in,
                               p_out=p_out,
                               p_rnn=p_rnn,
                               pos=use_pos,
                               activation=activation)
    elif model_type == 'NeuroMST':
        num_layers = hyps['num_layers']
        network = NeuroMST(word_dim,
                           num_words,
                           char_dim,
                           num_chars,
                           pos_dim,
                           num_pos,
                           mode,
                           hidden_size,
                           num_layers,
                           num_types,
                           arc_space,
                           type_space,
                           embedd_word=word_table,
                           embedd_char=char_table,
                           p_in=p_in,
                           p_out=p_out,
                           p_rnn=p_rnn,
                           pos=use_pos,
                           activation=activation)
    elif model_type == 'StackPtr':
        encoder_layers = hyps['encoder_layers']
        decoder_layers = hyps['decoder_layers']
        num_layers = (encoder_layers, decoder_layers)
        prior_order = hyps['prior_order']
        grandPar = hyps['grandPar']
        sibling = hyps['sibling']
        network = StackPtrNet(word_dim,
                              num_words,
                              char_dim,
                              num_chars,
                              pos_dim,
                              num_pos,
                              mode,
                              hidden_size,
                              encoder_layers,
                              decoder_layers,
                              num_types,
                              arc_space,
                              type_space,
                              embedd_word=word_table,
                              embedd_char=char_table,
                              prior_order=prior_order,
                              activation=activation,
                              p_in=p_in,
                              p_out=p_out,
                              p_rnn=p_rnn,
                              pos=use_pos,
                              grandPar=grandPar,
                              sibling=sibling)
    else:
        raise RuntimeError('Unknown model type: %s' % model_type)

    if freeze:
        freeze_embedding(network.word_embed)

    network = network.to(device)
    model = "{}-{}".format(model_type, mode)
    logger.info("Network: %s, num_layer=%s, hidden=%d, act=%s" %
                (model, num_layers, hidden_size, activation))
    logger.info("dropout(in, out, rnn): %s(%.2f, %.2f, %s)" %
                ('variational', p_in, p_out, p_rnn))
    logger.info('# of Parameters: %d' %
                (sum([param.numel() for param in network.parameters()])))

    logger.info("Reading Data")
    if alg == 'graph':
        data_train = conllx_data.read_bucketed_data(train_path,
                                                    word_alphabet,
                                                    char_alphabet,
                                                    pos_alphabet,
                                                    type_alphabet,
                                                    symbolic_root=True)
        data_dev = conllx_data.read_data(dev_path,
                                         word_alphabet,
                                         char_alphabet,
                                         pos_alphabet,
                                         type_alphabet,
                                         symbolic_root=True)
        data_test = conllx_data.read_data(test_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          type_alphabet,
                                          symbolic_root=True)
    else:
        data_train = conllx_stacked_data.read_bucketed_data(
            train_path,
            word_alphabet,
            char_alphabet,
            pos_alphabet,
            type_alphabet,
            prior_order=prior_order)
        data_dev = conllx_stacked_data.read_data(dev_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 pos_alphabet,
                                                 type_alphabet,
                                                 prior_order=prior_order)
        data_test = conllx_stacked_data.read_data(test_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  pos_alphabet,
                                                  type_alphabet,
                                                  prior_order=prior_order)
    num_data = sum(data_train[1])
    logger.info("training: #training data: %d, batch: %d, unk replace: %.2f" %
                (num_data, batch_size, unk_replace))

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet,
                               type_alphabet)
    optimizer, scheduler = get_optimizer(network.parameters(), optim,
                                         learning_rate, lr_decay, betas, eps,
                                         amsgrad, weight_decay, warmup_steps)

    # Best dev statistics seen so far (with punctuation, without, and root).
    best_ucorrect = 0.0
    best_lcorrect = 0.0
    best_ucomlpete = 0.0
    best_lcomplete = 0.0

    best_ucorrect_nopunc = 0.0
    best_lcorrect_nopunc = 0.0
    best_ucomlpete_nopunc = 0.0
    best_lcomplete_nopunc = 0.0
    best_root_correct = 0.0
    best_total = 0
    best_total_nopunc = 0
    best_total_inst = 0
    best_total_root = 0

    best_epoch = 0

    # Test statistics measured at the epoch of the best dev score.
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete = 0.0
    test_lcomplete = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_nopunc = 0.0
    test_lcomplete_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    # Number of consecutive epochs without a dev improvement.
    patient = 0
    beam = args.beam
    reset = args.reset
    num_batches = num_data // batch_size + 1
    if optim == 'adam':
        opt_info = 'adam, betas=(%.1f, %.3f), eps=%.1e, amsgrad=%s' % (
            betas[0], betas[1], eps, amsgrad)
    else:
        opt_info = 'sgd, momentum=0.9, nesterov=True'
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_loss = 0.
        train_arc_loss = 0.
        train_type_loss = 0.
        # Per-epoch counters; num_words is reused here as a token counter
        # (the vocabulary size it held earlier is no longer needed).
        num_insts = 0
        num_words = 0
        num_back = 0
        num_nans = 0
        network.train()
        lr = scheduler.get_lr()[0]
        print(
            'Epoch %d (%s, lr=%.6f, lr decay=%.6f, grad clip=%.1f, l2=%.1e): '
            % (epoch, opt_info, lr, lr_decay, grad_clip, weight_decay))
        if args.cuda:
            torch.cuda.empty_cache()
        gc.collect()
        with torch.autograd.set_detect_anomaly(True):
            for step, data in enumerate(
                    iterate_data(data_train,
                                 batch_size,
                                 bucketed=True,
                                 unk_replace=unk_replace,
                                 shuffle=True)):
                optimizer.zero_grad()
                # The BERT inputs are read for every batch, whichever
                # algorithm is active, so each batch must provide them.
                bert_words = data["BERT_WORD"].to(device)
                sub_word_idx = data["SUB_IDX"].to(device)
                words = data['WORD'].to(device)
                chars = data['CHAR'].to(device)
                postags = data['POS'].to(device)
                heads = data['HEAD'].to(device)
                nbatch = words.size(0)
                if alg == 'graph':
                    types = data['TYPE'].to(device)
                    masks = data['MASK'].to(device)
                    # One masked position per sentence is not counted as a
                    # word (presumably the symbolic root -- TODO confirm).
                    nwords = masks.sum() - nbatch
                    # Hard-wired switch: the non-BERT call below is kept
                    # for reference but is never taken.
                    BERT = True
                    if BERT:
                        loss_arc, loss_type = network.loss(bert_words,
                                                           sub_word_idx,
                                                           words,
                                                           chars,
                                                           postags,
                                                           heads,
                                                           types,
                                                           mask=masks)
                    else:
                        loss_arc, loss_type = network.loss(words,
                                                           chars,
                                                           postags,
                                                           heads,
                                                           types,
                                                           mask=masks)
                else:
                    masks_enc = data['MASK_ENC'].to(device)
                    masks_dec = data['MASK_DEC'].to(device)
                    stacked_heads = data['STACK_HEAD'].to(device)
                    children = data['CHILD'].to(device)
                    siblings = data['SIBLING'].to(device)
                    stacked_types = data['STACK_TYPE'].to(device)
                    nwords = masks_enc.sum() - nbatch
                    loss_arc, loss_type = network.loss(words,
                                                       chars,
                                                       postags,
                                                       heads,
                                                       stacked_heads,
                                                       children,
                                                       siblings,
                                                       stacked_types,
                                                       mask_e=masks_enc,
                                                       mask_d=masks_dec)
                loss_arc = loss_arc.sum()
                loss_type = loss_type.sum()
                loss_total = loss_arc + loss_type

                # print("loss", loss_arc, loss_type, loss_total)
                if loss_ty_token:
                    loss = loss_total.div(nwords)
                else:
                    loss = loss_total.div(nbatch)
                loss.backward()
                if grad_clip > 0:
                    grad_norm = clip_grad_norm_(network.parameters(),
                                                grad_clip)
                else:
                    grad_norm = total_grad_norm(network.parameters())

                if math.isnan(grad_norm):
                    # Skip the update (and the statistics) for this batch.
                    num_nans += 1
                else:
                    optimizer.step()
                    scheduler.step()

                    with torch.no_grad():
                        num_insts += nbatch
                        num_words += nwords
                        train_loss += loss_total.item()
                        train_arc_loss += loss_arc.item()
                        train_type_loss += loss_type.item()

                # update log
                if step % 100 == 0:
                    if args.cuda:
                        torch.cuda.empty_cache()
                    # Erase the previous progress line in place.
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    curr_lr = scheduler.get_lr()[0]
                    # Clamp only the denominators; the running counters
                    # themselves must not be altered by logging.
                    insts_denom = max(num_insts, 1)
                    words_denom = max(num_words, 1)
                    log_info = '[%d/%d (%.0f%%) lr=%.6f (%d)] loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches, curr_lr,
                        num_nans, train_loss / insts_denom,
                        train_loss / words_denom,
                        train_arc_loss / insts_denom,
                        train_arc_loss / words_denom,
                        train_type_loss / insts_denom,
                        train_type_loss / words_denom)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)

            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            # Same clamping as above so an epoch in which every batch was
            # skipped cannot divide by zero.
            insts_denom = max(num_insts, 1)
            words_denom = max(num_words, 1)
            print(
                'total: %d (%d), loss: %.4f (%.4f), arc: %.4f (%.4f), type: %.4f (%.4f), time: %.2fs'
                % (num_insts, num_words, train_loss / insts_denom,
                   train_loss / words_denom, train_arc_loss / insts_denom,
                   train_arc_loss / words_denom,
                   train_type_loss / insts_denom,
                   train_type_loss / words_denom, time.time() - start_time))
            print('-' * 125)

            # evaluate performance on dev data
            with torch.no_grad():
                pred_filename = os.path.join(result_path, 'pred_dev%d' % epoch)
                pred_writer.start(pred_filename)
                gold_filename = os.path.join(result_path, 'gold_dev%d' % epoch)
                gold_writer.start(gold_filename)

                print('Evaluating dev:')
                # `eval` is called with evaluation arguments, so it must be
                # a helper defined elsewhere that shadows the builtin.
                dev_stats, dev_stats_nopunct, dev_stats_root = eval(
                    alg,
                    data_dev,
                    network,
                    pred_writer,
                    gold_writer,
                    punct_set,
                    word_alphabet,
                    pos_alphabet,
                    device,
                    beam=beam)

                pred_writer.close()
                gold_writer.close()

                dev_ucorr, dev_lcorr, dev_ucomlpete, dev_lcomplete, dev_total = dev_stats
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_ucomlpete_nopunc, dev_lcomplete_nopunc, dev_total_nopunc = dev_stats_nopunct
                dev_root_corr, dev_total_root, dev_total_inst = dev_stats_root

                # Model selection: unlabeled + labeled correct counts,
                # punctuation excluded.
                if best_ucorrect_nopunc + best_lcorrect_nopunc < dev_ucorr_nopunc + dev_lcorr_nopunc:
                    best_ucorrect_nopunc = dev_ucorr_nopunc
                    best_lcorrect_nopunc = dev_lcorr_nopunc
                    best_ucomlpete_nopunc = dev_ucomlpete_nopunc
                    best_lcomplete_nopunc = dev_lcomplete_nopunc

                    best_ucorrect = dev_ucorr
                    best_lcorrect = dev_lcorr
                    best_ucomlpete = dev_ucomlpete
                    best_lcomplete = dev_lcomplete

                    best_root_correct = dev_root_corr
                    best_total = dev_total
                    best_total_nopunc = dev_total_nopunc
                    best_total_root = dev_total_root
                    best_total_inst = dev_total_inst

                    best_epoch = epoch
                    patient = 0
                    torch.save(network.state_dict(), model_name)

                    pred_filename = os.path.join(result_path,
                                                 'pred_test%d' % epoch)
                    pred_writer.start(pred_filename)
                    gold_filename = os.path.join(result_path,
                                                 'gold_test%d' % epoch)
                    gold_writer.start(gold_filename)

                    print('Evaluating test:')
                    test_stats, test_stats_nopunct, test_stats_root = eval(
                        alg,
                        data_test,
                        network,
                        pred_writer,
                        gold_writer,
                        punct_set,
                        word_alphabet,
                        pos_alphabet,
                        device,
                        beam=beam)

                    test_ucorrect, test_lcorrect, test_ucomlpete, test_lcomplete, test_total = test_stats
                    test_ucorrect_nopunc, test_lcorrect_nopunc, test_ucomlpete_nopunc, test_lcomplete_nopunc, test_total_nopunc = test_stats_nopunct
                    test_root_correct, test_total_root, test_total_inst = test_stats_root

                    pred_writer.close()
                    gold_writer.close()
                else:
                    patient += 1

                # NOTE(review): the summaries below divide by the stored
                # totals, which are still 0 if no epoch has improved on the
                # initial score yet -- confirm this cannot happen in practice.
                print('-' * 125)
                print(
                    'best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (best_ucorrect, best_lcorrect, best_total,
                       best_ucorrect * 100 / best_total, best_lcorrect * 100 /
                       best_total, best_ucomlpete * 100 / best_total_inst,
                       best_lcomplete * 100 / best_total_inst, best_epoch))
                print(
                    'best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (best_ucorrect_nopunc, best_lcorrect_nopunc,
                       best_total_nopunc, best_ucorrect_nopunc * 100 /
                       best_total_nopunc, best_lcorrect_nopunc * 100 /
                       best_total_nopunc, best_ucomlpete_nopunc * 100 /
                       best_total_inst, best_lcomplete_nopunc * 100 /
                       best_total_inst, best_epoch))
                print(
                    'best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                    % (best_root_correct, best_total_root,
                       best_root_correct * 100 / best_total_root, best_epoch))
                print('-' * 125)
                print(
                    'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (test_ucorrect, test_lcorrect, test_total,
                       test_ucorrect * 100 / test_total, test_lcorrect * 100 /
                       test_total, test_ucomlpete * 100 / test_total_inst,
                       test_lcomplete * 100 / test_total_inst, best_epoch))
                print(
                    'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)'
                    % (test_ucorrect_nopunc, test_lcorrect_nopunc,
                       test_total_nopunc, test_ucorrect_nopunc * 100 /
                       test_total_nopunc, test_lcorrect_nopunc * 100 /
                       test_total_nopunc, test_ucomlpete_nopunc * 100 /
                       test_total_inst, test_lcomplete_nopunc * 100 /
                       test_total_inst, best_epoch))
                print(
                    'best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)'
                    % (test_root_correct, test_total_root,
                       test_root_correct * 100 / test_total_root, best_epoch))
                print('=' * 125)

                # Too many epochs without progress: go back to the best
                # checkpoint and reset the scheduler state.
                if patient >= reset:
                    logger.info('reset optimizer momentums')
                    network.load_state_dict(
                        torch.load(model_name, map_location=device))
                    scheduler.reset_state()
                    patient = 0
Ejemplo n.º 24
0
def main():
    """Train and evaluate a recurrent-conv CRF tagger for NER from the CLI.

    The function parses the command-line options, builds the word, character
    and NER alphabets, fills a word embedding table from the loaded embedding
    dictionary, constructs the network variant selected by ``--dropout`` and
    trains it with Nesterov SGD.  After each epoch the dev set is decoded and
    scored; whenever dev F1 improves, the test set is decoded and scored too.
    When training ends, a three-line summary is appended to
    ``--result_file_path``.

    Side effects: creates the results, tmp and alphabets folders (and the
    parent folder of the result file), writes prediction/score files under
    the tmp folder, and prints progress to stdout.
    """
    # Arguments parser
    parser = argparse.ArgumentParser(
        description='Tuning with DNN Model for NER')
    # Model Hyperparameters
    parser.add_argument('--mode',
                        choices=['RNN', 'LSTM', 'GRU'],
                        help='architecture of rnn',
                        default='LSTM')
    parser.add_argument('--encoder_mode',
                        choices=['cnn', 'lstm'],
                        help='Encoder type for sentence encoding',
                        default='lstm')
    parser.add_argument('--char_method',
                        choices=['cnn', 'lstm'],
                        help='Method to create character-level embeddings',
                        required=True)
    parser.add_argument(
        '--hidden_size',
        type=int,
        default=128,
        help='Number of hidden units in RNN for sentence level')
    parser.add_argument('--char_hidden_size',
                        type=int,
                        default=30,
                        help='Output character-level embeddings size')
    parser.add_argument('--char_dim',
                        type=int,
                        default=30,
                        help='Dimension of Character embeddings')
    parser.add_argument('--tag_space',
                        type=int,
                        default=0,
                        help='Dimension of tag space')
    parser.add_argument('--num_layers',
                        type=int,
                        default=1,
                        help='Number of layers of RNN')
    # 'var' is listed so that the BiVarRecurrentConvCRF branch below can
    # actually be selected; it used to be rejected by argparse.
    parser.add_argument('--dropout',
                        choices=['std', 'var', 'weight_drop'],
                        help='Dropout method',
                        default='weight_drop')
    parser.add_argument('--p_em',
                        type=float,
                        default=0.33,
                        help='dropout rate for input embeddings')
    parser.add_argument('--p_in',
                        type=float,
                        default=0.33,
                        help='dropout rate for input of RNN model')
    parser.add_argument('--p_rnn',
                        nargs=2,
                        type=float,
                        required=True,
                        help='dropout rate for RNN')
    parser.add_argument('--p_out',
                        type=float,
                        default=0.33,
                        help='dropout rate for output layer')
    parser.add_argument('--bigram',
                        action='store_true',
                        help='bi-gram parameter for CRF')

    # Data loading and storing params
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--dataset_name',
                        type=str,
                        default='alexa',
                        help='Which dataset to use')
    parser.add_argument('--train',
                        type=str,
                        required=True,
                        help='Path of train set')
    parser.add_argument('--dev',
                        type=str,
                        required=True,
                        help='Path of dev set')
    parser.add_argument('--test',
                        type=str,
                        required=True,
                        help='Path of test set')
    parser.add_argument('--results_folder',
                        type=str,
                        default='results',
                        help='The folder to store results')
    parser.add_argument('--tmp_folder',
                        type=str,
                        default='tmp',
                        help='The folder to store tmp files')
    parser.add_argument('--alphabets_folder',
                        type=str,
                        default='data/alphabets',
                        help='The folder to store alphabets files')
    parser.add_argument('--result_file_name',
                        type=str,
                        default='hyperparameters_tuning',
                        help='File name to store some results')
    parser.add_argument('--result_file_path',
                        type=str,
                        default='results/hyperparameters_tuning',
                        help='File name to store some results')

    # Training parameters
    parser.add_argument('--cuda',
                        action='store_true',
                        help='whether using GPU')
    parser.add_argument('--num_epochs',
                        type=int,
                        default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Number of sentences in each batch')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.001,
                        help='Base learning rate')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.95,
                        help='Decay rate of learning rate')
    parser.add_argument('--schedule',
                        type=int,
                        default=3,
                        help='schedule for learning rate decay')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for l2 regularization')
    parser.add_argument('--max_norm',
                        type=float,
                        default=1.,
                        help='Max norm for gradients')
    parser.add_argument('--gpu_id',
                        type=int,
                        nargs='+',
                        required=True,
                        help='which gpu to use for training')

    # Misc
    parser.add_argument('--embedding',
                        choices=['glove', 'senna', 'alexa'],
                        help='Embedding for words',
                        required=True)
    parser.add_argument('--restore',
                        action='store_true',
                        help='whether restore from stored parameters')
    parser.add_argument('--save_checkpoint',
                        type=str,
                        default='',
                        help='the path to save the model')
    parser.add_argument('--o_tag',
                        type=str,
                        default='O',
                        help='The default tag for outside tag')
    parser.add_argument('--unk_replace',
                        type=float,
                        default=0.,
                        help='The rate to replace a singleton word with UNK')
    parser.add_argument('--evaluate_raw_format',
                        action='store_true',
                        help='The tagging format for evaluation')

    args = parser.parse_args()

    logger = get_logger("NERCRF")

    # rename the parameters
    mode = args.mode
    encoder_mode = args.encoder_mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    char_hidden_size = args.char_hidden_size
    char_method = args.char_method
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    max_norm = args.max_norm
    schedule = args.schedule
    dropout = args.dropout
    p_em = args.p_em
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict
    dataset_name = args.dataset_name
    evaluate_raw_format = args.evaluate_raw_format
    o_tag = args.o_tag
    gpu_id = args.gpu_id
    results_folder = args.results_folder
    tmp_folder = args.tmp_folder
    alphabets_folder = args.alphabets_folder
    use_elmo = False
    p_em_vec = 0.
    result_file_path = args.result_file_path

    # The GPU ids are embedded in every temp file name so that concurrent
    # runs on different devices do not overwrite each other's files.
    gpu_tag = '-'.join(map(str, gpu_id))
    score_file = "%s/score_gpu_%s" % (tmp_folder, gpu_tag)

    # Create every output location up front, including the directory of the
    # result file, so a long run cannot fail at the very last write.
    for folder in (results_folder, tmp_folder, alphabets_folder,
                   os.path.dirname(result_file_path)):
        if folder:
            os.makedirs(folder, exist_ok=True)

    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, ner_alphabet = conll03_data.create_alphabets(
        "{}/{}/".format(alphabets_folder, dataset_name),
        train_path,
        data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict,
        max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    print(device)

    data_train = conll03_data.read_data_to_tensor(train_path,
                                                  word_alphabet,
                                                  char_alphabet,
                                                  ner_alphabet,
                                                  device=device)
    num_data = sum(data_train[1])
    num_labels = ner_alphabet.size()

    data_dev = conll03_data.read_data_to_tensor(dev_path,
                                                word_alphabet,
                                                char_alphabet,
                                                ner_alphabet,
                                                device=device)
    data_test = conll03_data.read_data_to_tensor(test_path,
                                                 word_alphabet,
                                                 char_alphabet,
                                                 ner_alphabet,
                                                 device=device)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the initial word embedding matrix as a float32 tensor.

        Words found in ``embedd_dict`` (exact match first, then lower-cased)
        take their pretrained vector; all others, and the UNK row, are drawn
        uniformly from ``[-scale, scale]``.
        """
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[conll03_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in embedd_dict:
                vector = embedd_dict[word]
            elif word.lower() in embedd_dict:
                vector = embedd_dict[word.lower()]
            else:
                vector = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = vector
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform_

    # All three network variants take the same positional arguments and share
    # most keyword arguments; only the first two accept the ELMo-related ones.
    shared_args = (embedd_dim, word_alphabet.size(), char_dim,
                   char_alphabet.size(), char_hidden_size, window, mode,
                   encoder_mode, hidden_size, num_layers, num_labels)
    shared_kwargs = dict(tag_space=tag_space,
                         embedd_word=word_table,
                         p_em=p_em,
                         p_in=p_in,
                         p_out=p_out,
                         p_rnn=p_rnn,
                         bigram=bigram,
                         initializer=initializer)
    if dropout == 'std':
        network = BiRecurrentConvCRF(*shared_args,
                                     use_elmo=use_elmo,
                                     p_em_vec=p_em_vec,
                                     **shared_kwargs)
    elif dropout == 'var':
        network = BiVarRecurrentConvCRF(*shared_args,
                                        use_elmo=use_elmo,
                                        p_em_vec=p_em_vec,
                                        **shared_kwargs)
    else:
        network = BiWeightDropRecurrentConvCRF(*shared_args, **shared_kwargs)

    network = network.to(device)

    def make_optimizer(rate):
        """Return a fresh Nesterov SGD optimizer over the network parameters."""
        return SGD(network.parameters(),
                   lr=rate,
                   momentum=momentum,
                   weight_decay=gamma,
                   nesterov=True)

    def evaluate_split(data, split):
        """Decode ``data``, dump the predictions to a temp file and score it.

        Returns whatever ``evaluate`` returns, which the caller unpacks as
        ``(acc, precision, recall, f1)``.
        """
        tmp_filename = '%s/gpu_%s_%s' % (tmp_folder, gpu_tag, split)
        writer.start(tmp_filename)
        for aux, word, char, labels, masks, lengths in \
                conll03_data.iterate_batch_tensor(data, batch_size):
            # The first batch field is handed to the network unchanged; its
            # meaning is defined by conll03_data and the network, not here.
            preds, _ = network.decode(
                aux,
                word,
                char,
                target=labels,
                mask=masks,
                leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
            writer.write(word.cpu().numpy(),
                         preds.cpu().numpy(),
                         labels.cpu().numpy(),
                         lengths.cpu().numpy())
        writer.close()
        return evaluate(tmp_filename, score_file, evaluate_raw_format, o_tag)

    lr = learning_rate
    optim = make_optimizer(lr)
    # optim = Adam(network.parameters(), lr=lr, weight_decay=gamma, amsgrad=True)
    logger.info(
        "Network: %s, encoder_mode=%s, num_layer=%d, hidden=%d, char_hidden_size=%d, char_method=%s, tag_space=%d, crf=%s"
        % (mode, encoder_mode, num_layers, hidden_size, char_hidden_size,
           char_method, tag_space, 'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)"
        % (gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" %
                (p_in, p_out, p_rnn))

    num_batches = num_data // batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    best_test_f1 = 0.0
    best_test_acc = 0.0
    best_test_precision = 0.0
    best_test_recall = 0.0
    best_test_epoch = 0
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, dropout, lr, decay_rate, schedule))

        train_err = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            aux, word, char, labels, masks, _ = conll03_data.get_batch_tensor(
                data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss = network.loss(aux, word, char, labels, mask=masks)
            loss.backward()
            # Clipping has to happen after backward() and before step();
            # calling it once before training (when no gradients exist yet)
            # leaves --max_norm without any effect.
            nn.utils.clip_grad_norm_(network.parameters(), max_norm)
            optim.step()

            # .item() keeps the running statistics as plain Python floats
            # rather than accumulating tensors.
            num_inst = word.size(0)
            train_err += loss.item() * num_inst
            train_total += num_inst

            # update log
            if batch % 20 == 0:
                time_ave = (time.time() - start_time) / batch
                time_left = (num_batches - batch) * time_ave
                # Erase the previous progress line in place.
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (num_batches, train_err / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            acc, precision, recall, f1 = evaluate_split(data_dev, 'dev')
            print(
                'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))

            if dev_f1 < f1:
                dev_f1 = f1
                dev_acc = acc
                dev_precision = precision
                dev_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                test_acc, test_precision, test_recall, test_f1 = \
                    evaluate_split(data_test, 'test')
                if best_test_f1 < test_f1:
                    best_test_acc = test_acc
                    best_test_precision = test_precision
                    best_test_recall = test_recall
                    best_test_f1 = test_f1
                    best_test_epoch = epoch

            print(
                "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
            # "best test" is the test score at the best dev epoch, hence the
            # shared best_epoch; the overall test maximum is reported below.
            print(
                "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (test_acc, test_precision, test_recall, test_f1, best_epoch))
            print(
                "overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (best_test_acc, best_test_precision, best_test_recall,
                   best_test_f1, best_test_epoch))

        if epoch % schedule == 0:
            lr = learning_rate / (1.0 + epoch * decay_rate)
            # A new SGD instance starts with empty optimizer state, exactly
            # as the original schedule did.
            optim = make_optimizer(lr)

    with open(result_file_path, 'a') as ofile:
        ofile.write(
            "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n"
            % (dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        ofile.write(
            "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n"
            % (test_acc, test_precision, test_recall, test_f1, best_epoch))
        ofile.write(
            "overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n\n"
            % (best_test_acc, best_test_precision, best_test_recall,
               best_test_f1, best_test_epoch))
    print('Training finished!')
Ejemplo n.º 25
0
def main():
    args_parser = argparse.ArgumentParser(description='Tuning with graph-based parsing')
    args_parser.add_argument('--test_phase', action='store_true', help='Load trained model and run testing phase.')
    args_parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'], help='architecture of rnn', required=True)
    args_parser.add_argument('--cuda', action='store_true', help='using GPU')
    args_parser.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--type_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    args_parser.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
    args_parser.add_argument('--pos', action='store_true', help='use part-of-speech embedding.')
    args_parser.add_argument('--char', action='store_true', help='use character embedding and CNN.')
    args_parser.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
    args_parser.add_argument('--opt', choices=['adam', 'sgd', 'adamax'], help='optimization algorithm')
    args_parser.add_argument('--objective', choices=['cross_entropy', 'crf'], default='cross_entropy', help='objective function of training procedure.')
    args_parser.add_argument('--decode', choices=['mst', 'greedy'], help='decoding algorithm', required=True)
    args_parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
    args_parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    args_parser.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    args_parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--word_embedding', choices=['glove', 'senna', 'sskip', 'polyglot'], help='Embedding for words', required=True)
    args_parser.add_argument('--word_path', help='path for word embedding dict')
    args_parser.add_argument('--freeze', action='store_true', help='frozen the word embedding (disable fine-tuning).')
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters', required=True)
    args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path', help='path for saving model file.', required=True)
    args_parser.add_argument('--model_name', help='name for saving model file.', required=True)

    args = args_parser.parse_args()

    logger = get_logger("GraphParser")

    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    model_name = args.model_name
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    opt = args.opt
    momentum = 0.9
    betas = (0.9, 0.9)
    eps = args.epsilon
    decay_rate = args.decay_rate
    clip = args.clip
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    freeze = args.freeze
    word_embedding = args.word_embedding
    word_path = args.word_path

    use_char = args.char
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(alphabet_path, train_path, data_paths=[dev_path, test_path],
                                                                                             max_vocabulary_size=100000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')

    data_train = conllx_data.read_data_to_tensor(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, symbolic_root=True, device=device)
    # data_train = conllx_data.read_data(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    # num_data = sum([len(bucket) for bucket in data_train])
    num_data = sum(data_train[1])

    data_dev = conllx_data.read_data_to_tensor(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, symbolic_root=True, device=device)
    data_test = conllx_data.read_data_to_tensor(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, symbolic_root=True, device=device)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                                          mode, hidden_size, num_layers, num_types, arc_space, type_space,
                                          embedd_word=word_table, embedd_char=char_table,
                                          p_in=p_in, p_out=p_out, p_rnn=p_rnn, biaffine=True, pos=use_pos, char=use_char)
    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    def save_args():
        """Dump the network's constructor arguments to ``model_name + '.arg.json'``.

        The JSON object has two keys: ``'args'`` (the positional arguments, as
        a list) and ``'kwargs'`` (the keyword arguments, as a dict). All values
        are read from the enclosing function's scope.
        """
        arg_path = model_name + '.arg.json'
        arguments = [word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                     mode, hidden_size, num_layers, num_types, arc_space, type_space]
        kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char}
        # Use a context manager so the handle is flushed and closed right away,
        # instead of relying on garbage collection of an anonymous file object.
        with open(arg_path, 'w') as arg_file:
            json.dump({'args': arguments, 'kwargs': kwargs}, arg_file, indent=4)

    if freeze:
        freeze_embedding(network.word_embedd)

    network = network.to(device)

    save_args()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    def generate_optimizer(opt, lr, params):
        """Create the optimizer selected by ``opt`` over the trainable parameters.

        Only parameters whose ``requires_grad`` is true are passed on.
        ``opt`` must be 'adam', 'sgd' or 'adamax'; any other value raises
        ValueError. Apart from ``lr``, the hyper-parameters (betas, eps,
        momentum, and gamma used as weight decay) come from the enclosing scope.
        """
        # Lazy filter: frozen parameters never reach the optimizer.
        trainable = (p for p in params if p.requires_grad)

        if opt == 'sgd':
            return SGD(trainable, lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)

        # Adam and Adamax take the same keyword arguments here.
        if opt == 'adam':
            adaptive_cls = Adam
        elif opt == 'adamax':
            adaptive_cls = Adamax
        else:
            raise ValueError('Unknown optimization algorithm: %s' % opt)
        return adaptive_cls(trainable, lr=lr, betas=betas, weight_decay=gamma, eps=eps)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)
    elif opt == 'sgd':
        opt_info += 'momentum=%.2f' % momentum
    elif opt == 'adamax':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info("Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" % (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info("RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info("train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)" % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    num_batches = num_data / batch_size + 1
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    patient = 0
    decay = 0
    max_decay = 9
    double_schedule_decay = 5

    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): ' % (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_tensor(data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word, char, pos, heads, types, mask=masks, length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm_(network.parameters(), clip)
            optim.step()

            with torch.no_grad():
                num_inst = word.size(0) if obj == 'crf' else masks.sum() - word.size(0)
                train_err += loss * num_inst
                train_err_arc += loss_arc * num_inst
                train_err_type += loss_type * num_inst
                train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (batch, num_batches, train_err / train_total,
                                                                                                 train_err_arc / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' % (num_batches, train_err / train_total,
                                                                            train_err_arc / train_total, train_err_type / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            dev_ucorr = 0.0
            dev_lcorr = 0.0
            dev_total = 0
            dev_ucomlpete = 0.0
            dev_lcomplete = 0.0
            dev_ucorr_nopunc = 0.0
            dev_lcorr_nopunc = 0.0
            dev_total_nopunc = 0
            dev_ucomlpete_nopunc = 0.0
            dev_lcomplete_nopunc = 0.0
            dev_root_corr = 0.0
            dev_total_root = 0.0
            dev_total_inst = 0.0
            for batch in conllx_data.iterate_batch_tensor(data_dev, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.cpu().numpy()
                pos = pos.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.cpu().numpy()
                types = types.cpu().numpy()

                pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types,
                                                                        word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                dev_ucorr += ucorr
                dev_lcorr += lcorr
                dev_total += total
                dev_ucomlpete += ucm
                dev_lcomplete += lcm

                dev_ucorr_nopunc += ucorr_nopunc
                dev_lcorr_nopunc += lcorr_nopunc
                dev_total_nopunc += total_nopunc
                dev_ucomlpete_nopunc += ucm_nopunc
                dev_lcomplete_nopunc += lcm_nopunc

                dev_root_corr += corr_root
                dev_total_root += total_root

                dev_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()
            print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
                dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
            print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
                dev_lcorr_nopunc * 100 / dev_total_nopunc,
                dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
            print('Root: corr: %d, total: %d, acc: %.2f%%' %(dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

            if dev_lcorrect_nopunc < dev_lcorr_nopunc or (dev_lcorrect_nopunc == dev_lcorr_nopunc and dev_ucorrect_nopunc < dev_ucorr_nopunc):
                dev_ucorrect_nopunc = dev_ucorr_nopunc
                dev_lcorrect_nopunc = dev_lcorr_nopunc
                dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
                dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

                dev_ucorrect = dev_ucorr
                dev_lcorrect = dev_lcorr
                dev_ucomlpete_match = dev_ucomlpete
                dev_lcomplete_match = dev_lcomplete

                dev_root_correct = dev_root_corr

                best_epoch = epoch
                patient = 0
                # torch.save(network, model_name)
                torch.save(network.state_dict(), model_name)

                pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
                pred_writer.start(pred_filename)
                gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
                gold_writer.start(gold_filename)

                test_ucorrect = 0.0
                test_lcorrect = 0.0
                test_ucomlpete_match = 0.0
                test_lcomplete_match = 0.0
                test_total = 0

                test_ucorrect_nopunc = 0.0
                test_lcorrect_nopunc = 0.0
                test_ucomlpete_match_nopunc = 0.0
                test_lcomplete_match_nopunc = 0.0
                test_total_nopunc = 0
                test_total_inst = 0

                test_root_correct = 0.0
                test_total_root = 0
                for batch in conllx_data.iterate_batch_tensor(data_test, batch_size):
                    word, char, pos, heads, types, masks, lengths = batch
                    heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                    word = word.cpu().numpy()
                    pos = pos.cpu().numpy()
                    lengths = lengths.cpu().numpy()
                    heads = heads.cpu().numpy()
                    types = types.cpu().numpy()

                    pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                    gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                    stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types,
                                                                            word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                    ucorr, lcorr, total, ucm, lcm = stats
                    ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                    corr_root, total_root = stats_root

                    test_ucorrect += ucorr
                    test_lcorrect += lcorr
                    test_total += total
                    test_ucomlpete_match += ucm
                    test_lcomplete_match += lcm

                    test_ucorrect_nopunc += ucorr_nopunc
                    test_lcorrect_nopunc += lcorr_nopunc
                    test_total_nopunc += total_nopunc
                    test_ucomlpete_match_nopunc += ucm_nopunc
                    test_lcomplete_match_nopunc += lcm_nopunc

                    test_root_correct += corr_root
                    test_total_root += total_root

                    test_total_inst += num_inst

                pred_writer.close()
                gold_writer.close()
            else:
                if dev_ucorr_nopunc * 100 / dev_total_nopunc < dev_ucorrect_nopunc * 100 / dev_total_nopunc - 5 or patient >= schedule:
                    # network = torch.load(model_name)
                    network.load_state_dict(torch.load(model_name))
                    lr = lr * decay_rate
                    optim = generate_optimizer(opt, lr, network.parameters())

                    if decoding == 'greedy':
                        decode = network.decode
                    elif decoding == 'mst':
                        decode = network.decode_mst
                    else:
                        raise ValueError('Unknown decoding algorithm: %s' % decoding)

                    patient = 0
                    decay += 1
                    if decay % double_schedule_decay == 0:
                        schedule *= 2
                else:
                    patient += 1

            print('----------------------------------------------------------------------------------------------------------------------------')
            print('best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
                dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
                best_epoch))
            print('best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
                dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
                dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
                best_epoch))
            print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
                dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch))
            print('----------------------------------------------------------------------------------------------------------------------------')
            print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
                test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst,
                best_epoch))
            print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
                test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
                test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
                test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst,
                best_epoch))
            print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
                test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
            print('============================================================================================================================')

            if decay == max_decay:
                break
Ejemplo n.º 26
0
def main():
    # Arguments parser
    parser = argparse.ArgumentParser(description='Tuning with DNN Model for NER')
    # Model Hyperparameters
    parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU'], help='architecture of rnn', default='LSTM')
    parser.add_argument('--encoder_mode', choices=['cnn', 'lstm'], help='Encoder type for sentence encoding',
                        default='lstm')
    parser.add_argument('--char_method', choices=['cnn', 'lstm'], help='Method to create character-level embeddings',
                        required=True)
    parser.add_argument('--hidden_size', type=int, default=128, help='Number of hidden units in RNN for sentence level')
    parser.add_argument('--char_hidden_size', type=int, default=30, help='Output character-level embeddings size')
    parser.add_argument('--char_dim', type=int, default=30, help='Dimension of Character embeddings')
    parser.add_argument('--tag_space', type=int, default=0, help='Dimension of tag space')
    parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    parser.add_argument('--dropout', choices=['std', 'gcn'], help='Dropout method',
                        default='gcn')
    parser.add_argument('--p_em', type=float, default=0.33, help='dropout rate for input embeddings')
    parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input of RNN model')
    parser.add_argument('--p_rnn', nargs=3, type=float, required=True, help='dropout rate for RNN')
    parser.add_argument('--p_tag', type=float, default=0.33, help='dropout rate for output layer')
    parser.add_argument('--bigram', action='store_true', help='bi-gram parameter for CRF')

    parser.add_argument('--adj_attn', choices=['cossim', 'flex_cossim', 'flex_cossim2', 'concat', '', 'multihead'],
                        default='')

    # Data loading and storing params
    parser.add_argument('--embedding_dict', help='path for embedding dict')
    parser.add_argument('--dataset_name', type=str, default='alexa', help='Which dataset to use')
    parser.add_argument('--train', type=str, required=True, help='Path of train set')
    parser.add_argument('--dev', type=str, required=True, help='Path of dev set')
    parser.add_argument('--test', type=str, required=True, help='Path of test set')
    parser.add_argument('--results_folder', type=str, default='results', help='The folder to store results')
    parser.add_argument('--alphabets_folder', type=str, default='data/alphabets',
                        help='The folder to store alphabets files')

    # Training parameters
    parser.add_argument('--cuda', action='store_true', help='whether using GPU')
    parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Number of sentences in each batch')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Base learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.95, help='Decay rate of learning rate')
    parser.add_argument('--schedule', type=int, default=3, help='schedule for learning rate decay')
    parser.add_argument('--gamma', type=float, default=0.0, help='weight for l2 regularization')
    parser.add_argument('--max_norm', type=float, default=1., help='Max norm for gradients')
    parser.add_argument('--gpu_id', type=int, nargs='+', required=True, help='which gpu to use for training')

    parser.add_argument('--learning_rate_gcn', type=float, default=5e-4, help='Base learning rate')
    parser.add_argument('--gcn_warmup', type=int, default=200, help='Base learning rate')
    parser.add_argument('--pretrain_lstm', type=float, default=10, help='Base learning rate')

    parser.add_argument('--adj_loss_lambda', type=float, default=0.)
    parser.add_argument('--lambda1', type=float, default=1.)
    parser.add_argument('--lambda2', type=float, default=0.)
    parser.add_argument('--seed', type=int, default=None)

    # Misc
    parser.add_argument('--embedding', choices=['glove', 'senna', 'alexa'], help='Embedding for words', required=True)
    parser.add_argument('--restore', action='store_true', help='whether restore from stored parameters')
    parser.add_argument('--save_checkpoint', type=str, default='', help='the path to save the model')
    parser.add_argument('--o_tag', type=str, default='O', help='The default tag for outside tag')
    parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    parser.add_argument('--evaluate_raw_format', action='store_true', help='The tagging format for evaluation')


    parser.add_argument('--eval_type', type=str, default="micro_f1",choices=['micro_f1', 'acc'])
    parser.add_argument('--show_network', action='store_true', help='whether to display the network structure')
    parser.add_argument('--smooth', action='store_true', help='whether to skip all pdb break points')

    parser.add_argument('--uid', type=str, default='temp')
    parser.add_argument('--misc', type=str, default='')

    args = parser.parse_args()
    show_var(['args'])

    uid = args.uid
    results_folder = args.results_folder
    dataset_name = args.dataset_name
    use_tensorboard = True

    save_dset_dir = '{}../dset/{}/graph'.format(results_folder, dataset_name)
    result_file_path = '{}/{dataset}_{uid}_result'.format(results_folder, dataset=dataset_name, uid=uid)

    save_loss_path = '{}/{dataset}_{uid}_loss'.format(results_folder, dataset=dataset_name, uid=uid)
    save_lr_path = '{}/{dataset}_{uid}_lr'.format(results_folder, dataset=dataset_name, uid='temp')
    save_tb_path = '{}/tensorboard/'.format(results_folder)

    logger = get_logger("NERCRF")
    loss_recorder = LossRecorder(uid=uid)
    record = TensorboardLossRecord(use_tensorboard, save_tb_path, uid=uid)

    # rename the parameters
    mode = args.mode
    encoder_mode = args.encoder_mode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    char_hidden_size = args.char_hidden_size
    char_method = args.char_method
    learning_rate = args.learning_rate
    momentum = 0.9
    decay_rate = args.decay_rate
    gamma = args.gamma
    max_norm = args.max_norm
    schedule = args.schedule
    dropout = args.dropout
    p_em = args.p_em
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_tag = args.p_tag
    unk_replace = args.unk_replace
    bigram = args.bigram
    embedding = args.embedding
    embedding_path = args.embedding_dict
    evaluate_raw_format = args.evaluate_raw_format
    o_tag = args.o_tag
    restore = args.restore
    save_checkpoint = args.save_checkpoint
    alphabets_folder = args.alphabets_folder
    use_elmo = False
    p_em_vec = 0.
    graph_model = 'gnn'
    coref_edge_filt = ''

    learning_rate_gcn = args.learning_rate_gcn
    gcn_warmup = args.gcn_warmup
    pretrain_lstm = args.pretrain_lstm

    adj_loss_lambda = args.adj_loss_lambda
    lambda1 = args.lambda1
    lambda2 = args.lambda2

    if args.smooth:
        import pdb
        pdb.set_trace = lambda: None

    misc = "{}".format(str(args.misc))

    score_file = "{}/{dataset}_{uid}_score".format(results_folder, dataset=dataset_name, uid=uid)

    for folder in [results_folder, alphabets_folder, save_dset_dir]:
        if not os.path.exists(folder):
            os.makedirs(folder)

    def set_seed(seed):
        """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

        Args:
            seed: integer seed, or None to derive one from ``show_time()``.
                Any explicit integer, including 0, is used as given.

        Also sets ``torch.backends.cudnn.deterministic`` to True and prints
        the seed that was finally used.
        """
        # Test for None rather than truthiness: an explicit seed of 0 must be
        # honoured, not silently replaced by a time-based seed.
        if seed is None:
            seed = int(show_time())
        print("[Info] seed set to: {}".format(seed))
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    set_seed(args.seed)

    embedd_dict, embedd_dim = utils.load_embedding_dict(embedding, embedding_path)

    logger.info("Creating Alphabets")
    word_alphabet, char_alphabet, ner_alphabet = conll03_data.create_alphabets(
        "{}/{}/".format(alphabets_folder, dataset_name), train_path, data_paths=[dev_path, test_path],
        embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())

    logger.info("Reading Data")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    print(device)

    data_train = conll03_data.read_data(train_path, word_alphabet, char_alphabet,
                                        ner_alphabet,
                                        graph_model, batch_size, ori_order=False,
                                        total_batch="{}x".format(num_epochs + 1),
                                        unk_replace=unk_replace, device=device,
                                        save_path=save_dset_dir + '/train', coref_edge_filt=coref_edge_filt
                                        )
    # , shuffle=True,
    num_data = data_train.data_len
    num_labels = ner_alphabet.size()
    graph_types = data_train.meta_info['graph_types']

    data_dev = conll03_data.read_data(dev_path, word_alphabet, char_alphabet,
                                      ner_alphabet,
                                      graph_model, batch_size, ori_order=True, unk_replace=unk_replace, device=device,
                                      save_path=save_dset_dir + '/dev',
                                      coref_edge_filt=coref_edge_filt)

    data_test = conll03_data.read_data(test_path, word_alphabet, char_alphabet,
                                       ner_alphabet,
                                       graph_model, batch_size, ori_order=True, unk_replace=unk_replace, device=device,
                                       save_path=save_dset_dir + '/test',
                                       coref_edge_filt=coref_edge_filt)

    writer = CoNLL03Writer(word_alphabet, char_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        """Build the initial word embedding matrix as a torch tensor.

        Each entry of ``word_alphabet`` is looked up in ``embedd_dict``, first
        as-is and then lower-cased. Words found neither way get a uniform
        random vector drawn from [-bound, bound) with
        bound = sqrt(3 / embedd_dim). The number of such words is printed as
        the OOV count.
        """
        bound = np.sqrt(3.0 / embedd_dim)

        def random_row():
            # One freshly sampled float32 row of shape (1, embedd_dim).
            return np.random.uniform(-bound, bound, [1, embedd_dim]).astype(np.float32)

        word_table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        # The UNK row is always randomly initialised, before any other draw.
        word_table[conll03_data.UNK_ID, :] = random_row()

        missing = 0
        for token, row in word_alphabet.items():
            # Prefer the exact surface form; fall back to the lower-cased one.
            if token in embedd_dict:
                key = token
            else:
                lowered = token.lower()
                key = lowered if lowered in embedd_dict else None

            if key is None:
                vector = random_row()
                missing += 1
            else:
                vector = embedd_dict[key]
            word_table[row, :] = vector
        print('oov: %d' % missing)
        return torch.from_numpy(word_table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    char_dim = args.char_dim
    window = 3
    num_layers = args.num_layers
    tag_space = args.tag_space
    initializer = nn.init.xavier_uniform_

    p_gcn = [0.5, 0.5]

    d_graph = 256
    d_out = 256
    d_inner_hid = 128
    d_k = 32
    d_v = 32
    n_head = 4
    n_gcn_layer = 1

    p_rnn2 = [0.0, 0.5, 0.5]

    adj_attn = args.adj_attn
    mask_singles = True
    post_lstm = 1
    position_enc_mode = 'none'

    adj_memory = False

    if dropout == 'gcn':
        network = BiRecurrentConvGraphCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(),
                                          char_hidden_size, window, mode, encoder_mode, hidden_size, num_layers,
                                          num_labels,
                                          graph_model, n_head, d_graph, d_inner_hid, d_k, d_v, p_gcn, n_gcn_layer,
                                          d_out, post_lstm=post_lstm, mask_singles=mask_singles,
                                          position_enc_mode=position_enc_mode, adj_attn=adj_attn,
                                          adj_loss_lambda=adj_loss_lambda,
                                          tag_space=tag_space, embedd_word=word_table,
                                          use_elmo=use_elmo, p_em_vec=p_em_vec, p_em=p_em, p_in=p_in, p_tag=p_tag,
                                          p_rnn=p_rnn, p_rnn2=p_rnn2,
                                          bigram=bigram, initializer=initializer)

    elif dropout == 'std':
        network = BiRecurrentConvCRF(embedd_dim, word_alphabet.size(), char_dim, char_alphabet.size(), char_hidden_size,
                                     window, mode, encoder_mode, hidden_size, num_layers, num_labels,
                                     tag_space=tag_space, embedd_word=word_table, use_elmo=use_elmo, p_em_vec=p_em_vec,
                                     p_em=p_em, p_in=p_in, p_tag=p_tag, p_rnn=p_rnn, bigram=bigram,
                                     initializer=initializer)

    # whether restore from trained model
    if restore:
        network.load_state_dict(torch.load(save_checkpoint + '_best.pth'))  # load trained model

    logger.info("cuda()ing network...")

    network = network.to(device)

    if dataset_name == 'conll03' and data_dev.data_len > 26:
        sample = data_dev.pad_batch(data_dev.dataset[25:26])
    else:
        sample = data_dev.pad_batch(data_dev.dataset[:1])
    plot_att_change(sample, network, record, save_tb_path + 'att/', uid='temp', epoch=0, device=device,
                    word_alphabet=word_alphabet, show_net=args.show_network,
                    graph_types=data_train.meta_info['graph_types'])

    logger.info("finished cuda()ing network...")

    lr = learning_rate
    lr_gcn = learning_rate_gcn
    optim = Optimizer('sgd', 'adam', network, dropout, lr=learning_rate,
                      lr_gcn=learning_rate_gcn,
                      wd=0., wd_gcn=0., momentum=momentum, lr_decay=decay_rate, schedule=schedule,
                      gcn_warmup=gcn_warmup,
                      pretrain_lstm=pretrain_lstm)
    nn.utils.clip_grad_norm_(network.parameters(), max_norm)
    logger.info(
        "Network: %s, encoder_mode=%s, num_layer=%d, hidden=%d, char_hidden_size=%d, char_method=%s, tag_space=%d, crf=%s" % \
        (mode, encoder_mode, num_layers, hidden_size, char_hidden_size, char_method, tag_space,
         'bigram' if bigram else 'unigram'))
    logger.info("training: l2: %f, (#training data: %d, batch: %d, unk replace: %.2f)" % (
        gamma, num_data, batch_size, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_tag, p_rnn))

    num_batches = num_data // batch_size + 1
    dev_f1 = 0.0
    dev_acc = 0.0
    dev_precision = 0.0
    dev_recall = 0.0
    test_f1 = 0.0
    test_acc = 0.0
    test_precision = 0.0
    test_recall = 0.0
    best_epoch = 0
    best_test_f1 = 0.0
    best_test_acc = 0.0
    best_test_precision = 0.0
    best_test_recall = 0.0
    best_test_epoch = 0.0

    loss_recorder.start(save_loss_path, mode='w', misc=misc)
    fwrite('', save_lr_path)
    fwrite(json.dumps(vars(args)) + '\n', result_file_path)

    for epoch in range(1, num_epochs + 1):
        show_var(['misc'])

        lr_state = 'Epoch %d (uid=%s, lr=%.2E, lr_gcn=%.2E, decay rate=%.4f): ' % (
            epoch, uid, Decimal(optim.curr_lr), Decimal(optim.curr_lr_gcn), decay_rate)
        print(lr_state)
        fwrite(lr_state[:-2] + '\n', save_lr_path, mode='a')

        train_err = 0.
        train_err2 = 0.
        train_total = 0.

        start_time = time.time()
        num_back = 0
        network.train()
        for batch_i in range(1, num_batches + 1):

            batch_doc = data_train.next()
            char, word, posi, labels, feats, adjs, words_en = [batch_doc[i] for i in [
                "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]

            sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                word, char, labels)

            optim.zero_grad()

            adjs_into_model = adjs if adj_memory else adjs.clone()

            loss, (ner_loss, adj_loss) = network.loss(None, word, char, adjs_into_model, labels,
                                                      graph_types=graph_types, lambda1=lambda1, lambda2=lambda2)

            # loss = network.loss(_, sent_word, sent_char, sent_labels, mask=sent_mask)
            loss.backward()
            optim.step()

            with torch.no_grad():
                num_inst = sent_mask.size(0)
                train_err += ner_loss * num_inst
                train_err2 += adj_loss * num_inst
                train_total += num_inst

            time_ave = (time.time() - start_time) / batch_i
            time_left = (num_batches - batch_i) * time_ave

            # update log
            if batch_i % 20 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss1: %.4f, loss2: %.4f, time left (estimated): %.2fs' % (
                    batch_i, num_batches, train_err / train_total, train_err2 / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

            optim.update(epoch, batch_i, num_batches, network)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, loss2: %.4f, time: %.2fs' % (
            num_batches, train_err / train_total, train_err2 / train_total, time.time() - start_time))

        # evaluate performance on dev data
        with torch.no_grad():
            network.eval()
            tmp_filename = "{}/{dataset}_{uid}_output_dev".format(results_folder, dataset=dataset_name, uid=uid)

            writer.start(tmp_filename)

            for batch in data_dev:
                char, word, posi, labels, feats, adjs, words_en = [batch[i] for i in [
                    "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]
                sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                    word, char, labels)

                preds, _ = network.decode(
                    None, word, char, adjs.clone(), target=labels, leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS,
                    graph_types=graph_types)
                # preds, _ = network.decode(_, sent_word, sent_char, target=sent_labels, mask=sent_mask,
                #                           leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)
                writer.write(sent_word.cpu().numpy(), preds.cpu().numpy(), sent_labels.cpu().numpy(),
                             sent_length.cpu().numpy())
            writer.close()


            if args.eval_type == "acc":
                acc, precision, recall, f1 =evaluate_tokenacc(tmp_filename)
                f1 = acc
            else:
                acc, precision, recall, f1 = evaluate(tmp_filename, score_file, evaluate_raw_format, o_tag)

            print('dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%' % (acc, precision, recall, f1))

            # plot loss and attention
            record.plot_loss(epoch, train_err / train_total, f1)

            plot_att_change(sample, network, record, save_tb_path + 'att/', uid="{}_{:03d}".format(uid, epoch),
                            epoch=epoch, device=device,
                            word_alphabet=word_alphabet, show_net=False, graph_types=graph_types)

            if dev_f1 < f1:
                dev_f1 = f1
                dev_acc = acc
                dev_precision = precision
                dev_recall = recall
                best_epoch = epoch

                # evaluate on test data when better performance detected
                tmp_filename = "{}/{dataset}_{uid}_output_test".format(results_folder, dataset=dataset_name,
                                                                       uid=uid)
                writer.start(tmp_filename)

                for batch in data_test:
                    char, word, posi, labels, feats, adjs, words_en = [batch[i] for i in [
                        "chars", "word_ids", "posi", "ner_ids", "feat_ids", "adjs", "words_en"]]
                    sent_word, sent_char, sent_labels, sent_mask, sent_length, _ = network._doc2sent(
                        word, char, labels)

                    preds, _ = network.decode(
                        None, word, char, adjs.clone(), target=labels, leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS,
                        graph_types=graph_types)
                    # preds, _ = network.decode(_, sent_word, sent_char, target=sent_labels, mask=sent_mask,
                    #                           leading_symbolic=conll03_data.NUM_SYMBOLIC_TAGS)

                    writer.write(sent_word.cpu().numpy(), preds.cpu().numpy(), sent_labels.cpu().numpy(),
                                 sent_length.cpu().numpy())
                writer.close()

                if args.eval_type == "acc":
                    test_acc, test_precision, test_recall, test_f1 = evaluate_tokenacc(tmp_filename)
                    test_f1 = test_acc
                else:
                    test_acc, test_precision, test_recall, test_f1 = evaluate(tmp_filename, score_file, evaluate_raw_format,																		  o_tag)

                if best_test_f1 < test_f1:
                    best_test_acc, best_test_precision, best_test_recall, best_test_f1 = test_acc, test_precision, test_recall, test_f1
                    best_test_epoch = epoch

                # save the model parameters
                if save_checkpoint:
                    torch.save(network.state_dict(), save_checkpoint + '_best.pth')

            print("best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
            print("best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                test_acc, test_precision, test_recall, test_f1, best_epoch))
            print("overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)" % (
                best_test_acc, best_test_precision, best_test_recall, best_test_f1, best_test_epoch))

        # optim.update(epoch, 1, num_batches, network)
        loss_recorder.write(epoch, train_err / train_total, train_err2 / train_total,
                            Decimal(optim.curr_lr), Decimal(optim.curr_lr_gcn), f1, best_test_f1, test_f1)
    with open(result_file_path, 'a') as ofile:
        ofile.write("best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n" % (
            dev_acc, dev_precision, dev_recall, dev_f1, best_epoch))
        ofile.write("best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n" % (
            test_acc, test_precision, test_recall, test_f1, best_epoch))
        ofile.write("overall best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)\n\n" % (
            best_test_acc, best_test_precision, best_test_recall, best_test_f1, best_test_epoch))

    record.close()

    print('Training finished!')
Ejemplo n.º 27
0
def main():
    """Train and evaluate a biaffine graph-based dependency parser.

    The function is a command-line entry point.  It:

    1. parses the hyper-parameters and data paths from ``sys.argv``;
    2. loads the pre-trained word (and optionally character) embeddings;
    3. builds the word/char/POS/type alphabets and reads train/dev/test data;
    4. builds a ``BiRecurrentConvBiAffine`` network (only the
       ``cross_entropy`` objective is implemented; ``crf`` raises
       ``NotImplementedError``);
    5. trains for ``--num_epochs`` epochs, evaluating on the dev set after
       every epoch and on the test set whenever the dev unlabeled-correct
       count (without punctuation) is at least as good as the best so far;
    6. every ``--schedule`` epochs rebuilds the optimizer with a decayed
       learning rate: Adam while ``epoch < adam_epochs``, SGD afterwards.

    Predictions and gold trees of each evaluation are written under ``tmp/``.

    Raises:
        NotImplementedError: if ``--objective crf`` is selected.
        RuntimeError: if the objective is unknown.
        ValueError: if the decoding algorithm is unknown.
    """
    args_parser = argparse.ArgumentParser(description='Tuning with graph-based parsing')
    args_parser.add_argument('--mode', choices=['RNN', 'LSTM', 'GRU', 'FastLSTM'],
                             help='architecture of rnn', required=True)
    args_parser.add_argument('--num_epochs', type=int, default=200, help='Number of training epochs')
    args_parser.add_argument('--batch_size', type=int, default=64, help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size', type=int, default=256, help='Number of hidden units in RNN')
    args_parser.add_argument('--arc_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--type_space', type=int, default=128, help='Dimension of tag space')
    args_parser.add_argument('--num_layers', type=int, default=1, help='Number of layers of RNN')
    args_parser.add_argument('--num_filters', type=int, default=50, help='Number of filters in CNN')
    args_parser.add_argument('--pos', action='store_true', help='use part-of-speech embedding.')
    args_parser.add_argument('--pos_dim', type=int, default=50, help='Dimension of POS embeddings')
    args_parser.add_argument('--char_dim', type=int, default=50, help='Dimension of Character embeddings')
    args_parser.add_argument('--objective', choices=['cross_entropy', 'crf'], default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--decode', choices=['mst', 'greedy'], help='decoding algorithm', required=True)
    args_parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    args_parser.add_argument('--decay_rate', type=float, default=0.05, help='Decay rate of learning rate')
    args_parser.add_argument('--gamma', type=float, default=0.0, help='weight for regularization')
    args_parser.add_argument('--p_rnn', nargs=2, type=float, required=True, help='dropout rate for RNN')
    args_parser.add_argument('--p_in', type=float, default=0.33, help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out', type=float, default=0.33, help='dropout rate for output layer')
    # NOTE(review): --schedule has no default and is not required, yet it is
    # formatted with %d and used as a modulus in the epoch loop below; leaving
    # it unset raises TypeError at the first epoch.
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0.,
                             help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--word_embedding', choices=['glove', 'senna', 'sskip', 'polyglot'],
                             help='Embedding for words', required=True)
    args_parser.add_argument('--word_path', help='path for word embedding dict')
    args_parser.add_argument('--char_embedding', choices=['random', 'polyglot'], help='Embedding for characters',
                             required=True)
    args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument('--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument('--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument('--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path', help='path for saving model file.', required=True)

    args = args_parser.parse_args()

    # NOTE(review): `uid` is not defined inside this function; presumably a
    # module-level global -- confirm.
    print("*** Model UID: %s ***" % uid)

    logger = get_logger("GraphParser")

    # Unpack the parsed arguments into locals used throughout the script.
    mode = args.mode
    obj = args.objective
    decoding = args.decode
    train_path = args.train
    dev_path = args.dev
    test_path = args.test
    model_path = args.model_path
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    arc_space = args.arc_space
    type_space = args.type_space
    num_layers = args.num_layers
    num_filters = args.num_filters
    learning_rate = args.learning_rate
    momentum = 0.9
    betas = (0.9, 0.9)
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule
    p_rnn = tuple(args.p_rnn)
    p_in = args.p_in
    p_out = args.p_out
    unk_replace = args.unk_replace
    punctuation = args.punctuation

    word_embedding = args.word_embedding
    word_path = args.word_path
    char_embedding = args.char_embedding
    char_path = args.char_path

    use_pos = args.pos
    pos_dim = args.pos_dim
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    # Character embeddings are only loaded from disk when not 'random';
    # otherwise char_dict stays None and char_dim comes from --char_dim.
    char_dict = None
    char_dim = args.char_dim
    if char_embedding != 'random':
        char_dict, char_dim = utils.load_embedding_dict(char_embedding, char_path)

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(alphabet_path, train_path, data_paths=[dev_path, test_path],
                                                                                             max_vocabulary_size=50000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()

    data_train = conllx_data.read_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, symbolic_root=True)
    # presumably data_train[1] holds the per-bucket instance counts -- TODO confirm
    num_data = sum(data_train[1])

    data_dev = conllx_data.read_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)
    data_test = conllx_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word embedding matrix as a torch tensor.

        Words found in ``word_dict`` (exact or lower-cased) get their
        pre-trained vector; all others get a uniform random vector in
        ``[-scale, scale]`` and are counted as OOV.
        """
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    def construct_char_embedding_table():
        """Build the initial character embedding matrix, or return None.

        Returns ``None`` when no pre-trained character dictionary was loaded;
        otherwise mirrors ``construct_word_embedding_table`` for characters
        (without the lower-case fallback).
        """
        if char_dict is None:
            return None

        scale = np.sqrt(3.0 / char_dim)
        table = np.empty([num_chars, char_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
        oov = 0
        for char, index, in char_alphabet.items():
            if char in char_dict:
                embedding = char_dict[char]
            else:
                embedding = np.random.uniform(-scale, scale, [1, char_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('character OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    char_table = construct_char_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim, num_words,
                                          char_dim, num_chars,
                                          pos_dim, num_pos,
                                          num_filters, window,
                                          mode, hidden_size, num_layers,
                                          num_types, arc_space, type_space,
                                          embedd_word=word_table, embedd_char=char_table,
                                          p_in=p_in, p_out=p_out, p_rnn=p_rnn, biaffine=True, pos=use_pos)
    elif obj == 'crf':
        raise NotImplementedError
    else:
        raise RuntimeError('Unknown objective: %s' % obj)

    if use_gpu:
        network.cuda()

    pred_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)
    gold_writer = CoNLLXWriter(word_alphabet, char_alphabet, pos_alphabet, type_alphabet)

    # Training starts with Adam; the schedule block at the end of the epoch
    # loop switches to SGD once `epoch >= adam_epochs`.
    adam_epochs = 50
    adam_rate = 0.001
    if adam_epochs > 0:
        lr = adam_rate
        opt = 'adam'
        optim = Adam(network.parameters(), lr=adam_rate, betas=betas, weight_decay=gamma)
    else:
        opt = 'sgd'
        lr = learning_rate
        optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)

    logger.info("Embedding dim: word=%d, char=%d, pos=%d (%s)" % (word_dim, char_dim, pos_dim, use_pos))
    logger.info("Network: %s, num_layer=%d, hidden=%d, filter=%d, arc_space=%d, type_space=%d" % (
        mode, num_layers, hidden_size, num_filters, arc_space, type_space))
    logger.info("train: obj: %s, l2: %f, (#data: %d, batch: %d, dropout(in, out, rnn): (%.2f, %.2f, %s), unk replace: %.2f)" % (
        obj, gamma, num_data, batch_size, p_in, p_out, p_rnn, unk_replace))
    logger.info("decoding algorithm: %s" % decoding)

    # Floor division keeps num_batches an int: with true division (Python 3)
    # this would be a float and range(1, num_batches + 1) below would raise
    # TypeError.
    num_batches = num_data // batch_size + 1

    # Best-so-far dev statistics (updated whenever dev does not get worse).
    dev_ucorrect = 0.0
    dev_lcorrect = 0.0
    dev_ucomlpete_match = 0.0
    dev_lcomplete_match = 0.0

    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    dev_ucomlpete_match_nopunc = 0.0
    dev_lcomplete_match_nopunc = 0.0
    dev_root_correct = 0.0

    best_epoch = 0

    # Test statistics measured at the best dev epoch.
    test_ucorrect = 0.0
    test_lcorrect = 0.0
    test_ucomlpete_match = 0.0
    test_lcomplete_match = 0.0

    test_ucorrect_nopunc = 0.0
    test_lcorrect_nopunc = 0.0
    test_ucomlpete_match_nopunc = 0.0
    test_lcomplete_match_nopunc = 0.0
    test_root_correct = 0.0
    test_total = 0
    test_total_nopunc = 0
    test_total_inst = 0
    test_total_root = 0

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    for epoch in range(1, num_epochs + 1):
        print('Epoch %d (%s, optim: %s, learning rate=%.4f, decay rate=%.4f (schedule=%d)): ' % (
            epoch, mode, opt, lr, decay_rate, schedule))
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        # Length of the last progress line, used to erase it with backspaces.
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(data_train, batch_size,
                                                                                           unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word, char, pos, heads, types, mask=masks, length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            optim.step()

            # Weight of this batch in the running averages: sentence count for
            # CRF, otherwise unmasked positions minus one per sentence
            # (presumably to exclude the symbolic root -- TODO confirm).
            num_inst = word.size(0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # update log
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left (estimated): %.2fs' % (
                    batch, num_batches, train_err / train_total,
                    train_err_arc / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        # Erase the in-place progress line before printing the epoch summary.
        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' % (
            num_batches, train_err / train_total, train_err_arc / train_total, train_err_type / train_total,
            time.time() - start_time))

        # evaluate performance on dev data
        network.eval()
        pred_filename = 'tmp/%spred_dev%d' % (str(uid), epoch)
        pred_writer.start(pred_filename)
        gold_filename = 'tmp/%sgold_dev%d' % (str(uid), epoch)
        gold_writer.start(gold_filename)

        print('[%s] Epoch %d complete' % (time.strftime("%Y-%m-%d %H:%M:%S"), epoch))

        # Per-epoch dev accumulators.
        dev_ucorr = 0.0
        dev_lcorr = 0.0
        dev_total = 0
        dev_ucomlpete = 0.0
        dev_lcomplete = 0.0
        dev_ucorr_nopunc = 0.0
        dev_lcorr_nopunc = 0.0
        dev_total_nopunc = 0
        dev_ucomlpete_nopunc = 0.0
        dev_lcomplete_nopunc = 0.0
        dev_root_corr = 0.0
        dev_total_root = 0.0
        dev_total_inst = 0.0
        for batch in conllx_data.iterate_batch_variable(data_dev, batch_size):
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
            gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

            stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            ucorr, lcorr, total, ucm, lcm = stats
            ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
            corr_root, total_root = stats_root

            dev_ucorr += ucorr
            dev_lcorr += lcorr
            dev_total += total
            dev_ucomlpete += ucm
            dev_lcomplete += lcm

            dev_ucorr_nopunc += ucorr_nopunc
            dev_lcorr_nopunc += lcorr_nopunc
            dev_total_nopunc += total_nopunc
            dev_ucomlpete_nopunc += ucm_nopunc
            dev_lcomplete_nopunc += lcm_nopunc

            dev_root_corr += corr_root
            dev_total_root += total_root

            dev_total_inst += num_inst

        pred_writer.close()
        gold_writer.close()
        print('W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr, dev_lcorr, dev_total, dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
            dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print('Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucorr_nopunc * 100 / dev_total_nopunc,
            dev_lcorr_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_nopunc * 100 / dev_total_inst, dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' %(
            dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))

        # Model selection: the unlabeled-correct count without punctuation on
        # dev is the criterion; ties count as an improvement (<=).
        if dev_ucorrect_nopunc <= dev_ucorr_nopunc:
            dev_ucorrect_nopunc = dev_ucorr_nopunc
            dev_lcorrect_nopunc = dev_lcorr_nopunc
            dev_ucomlpete_match_nopunc = dev_ucomlpete_nopunc
            dev_lcomplete_match_nopunc = dev_lcomplete_nopunc

            dev_ucorrect = dev_ucorr
            dev_lcorrect = dev_lcorr
            dev_ucomlpete_match = dev_ucomlpete
            dev_lcomplete_match = dev_lcomplete

            dev_root_correct = dev_root_corr

            best_epoch = epoch

            # Re-evaluate on the test set with the current (best-on-dev) model.
            pred_filename = 'tmp/%spred_test%d' % (str(uid), epoch)
            pred_writer.start(pred_filename)
            gold_filename = 'tmp/%sgold_test%d' % (str(uid), epoch)
            gold_writer.start(gold_filename)

            test_ucorrect = 0.0
            test_lcorrect = 0.0
            test_ucomlpete_match = 0.0
            test_lcomplete_match = 0.0
            test_total = 0

            test_ucorrect_nopunc = 0.0
            test_lcorrect_nopunc = 0.0
            test_ucomlpete_match_nopunc = 0.0
            test_lcomplete_match_nopunc = 0.0
            test_total_nopunc = 0
            test_total_inst = 0

            test_root_correct = 0.0
            test_total_root = 0
            for batch in conllx_data.iterate_batch_variable(data_test, batch_size):
                word, char, pos, heads, types, masks, lengths = batch
                heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
                word = word.data.cpu().numpy()
                pos = pos.data.cpu().numpy()
                lengths = lengths.cpu().numpy()
                heads = heads.data.cpu().numpy()
                types = types.data.cpu().numpy()

                pred_writer.write(word, pos, heads_pred, types_pred, lengths, symbolic_root=True)
                gold_writer.write(word, pos, heads, types, lengths, symbolic_root=True)

                stats, stats_nopunc, stats_root, num_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
                ucorr, lcorr, total, ucm, lcm = stats
                ucorr_nopunc, lcorr_nopunc, total_nopunc, ucm_nopunc, lcm_nopunc = stats_nopunc
                corr_root, total_root = stats_root

                test_ucorrect += ucorr
                test_lcorrect += lcorr
                test_total += total
                test_ucomlpete_match += ucm
                test_lcomplete_match += lcm

                test_ucorrect_nopunc += ucorr_nopunc
                test_lcorrect_nopunc += lcorr_nopunc
                test_total_nopunc += total_nopunc
                test_ucomlpete_match_nopunc += ucm_nopunc
                test_lcomplete_match_nopunc += lcm_nopunc

                test_root_correct += corr_root
                test_total_root += total_root

                test_total_inst += num_inst

            pred_writer.close()
            gold_writer.close()

        # Summary of the best dev epoch so far and the matching test scores.
        # The dev totals used as denominators come from the current epoch's
        # pass over the dev set.
        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best dev  W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect, dev_lcorrect, dev_total, dev_ucorrect * 100 / dev_total, dev_lcorrect * 100 / dev_total,
            dev_ucomlpete_match * 100 / dev_total_inst, dev_lcomplete_match * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            dev_ucorrect_nopunc, dev_lcorrect_nopunc, dev_total_nopunc,
            dev_ucorrect_nopunc * 100 / dev_total_nopunc, dev_lcorrect_nopunc * 100 / dev_total_nopunc,
            dev_ucomlpete_match_nopunc * 100 / dev_total_inst, dev_lcomplete_match_nopunc * 100 / dev_total_inst,
            best_epoch))
        print('best dev  Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            dev_root_correct, dev_total_root, dev_root_correct * 100 / dev_total_root, best_epoch))
        print('----------------------------------------------------------------------------------------------------------------------------')
        print('best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect, test_lcorrect, test_total, test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
            test_ucomlpete_match * 100 / test_total_inst, test_lcomplete_match * 100 / test_total_inst,
            best_epoch))
        print('best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% (epoch: %d)' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc, test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomlpete_match_nopunc * 100 / test_total_inst, test_lcomplete_match_nopunc * 100 / test_total_inst,
            best_epoch))
        print('best test Root: corr: %d, total: %d, acc: %.2f%% (epoch: %d)' % (
            test_root_correct, test_total_root, test_root_correct * 100 / test_total_root, best_epoch))
        print('============================================================================================================================')

        # Every `schedule` epochs, rebuild the optimizer with an inverse-time
        # decayed learning rate.  A fresh optimizer object is created, so any
        # internal optimizer state is discarded at this point.
        if epoch % schedule == 0:
            if epoch < adam_epochs:
                opt = 'adam'
                lr = adam_rate / (1.0 + epoch * decay_rate)
                optim = Adam(network.parameters(), lr=lr, betas=betas, weight_decay=gamma)
            else:
                opt = 'sgd'
                lr = learning_rate / (1.0 + (epoch - adam_epochs) * decay_rate)
                optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
Ejemplo n.º 28
0
def main():
    """Train a biaffine graph-based dependency parser and report dev/test accuracy.

    Only ``--schedule``, ``--unk_replace`` and ``--freeze`` are read from the
    command line; data paths and the remaining hyper-parameters are hard-coded
    below.  Every epoch runs ``num_batches`` optimisation steps, decodes a
    sample of the training batches (one in every eleven) and the whole dev
    set, and appends the two unlabeled no-punctuation accuracies to
    ``testout.csv``.

    A dev result counts as an improvement when its labeled no-punctuation
    correct count is higher than the best so far (ties are broken by the
    unlabeled count).  On improvement the model state is saved to
    ``model_name`` and the test set is scored.  Otherwise the best state is
    reloaded and the learning rate multiplied by ``decay_rate`` once
    ``patient`` has reached ``schedule`` or the dev unlabeled accuracy falls
    more than five points below the best.  Training stops after
    ``num_epochs`` epochs or ``max_decay`` decays, whichever comes first.

    Side effects: writes ``<model_name>.arg.json``, the model state file and
    ``testout.csv``; prints progress to stdout.

    Raises:
        ValueError: if the configured objective, optimizer or decoding
            algorithm is not one this script supports.
    """
    args_parser = argparse.ArgumentParser(description='Tuning with graph-based parsing')
    args_parser.add_argument('--schedule', type=int, help='schedule for learning rate decay')
    args_parser.add_argument('--unk_replace', type=float, default=0., help='The rate to replace a singleton word with UNK')
    args_parser.add_argument('--freeze', action='store_true', help='frozen the word embedding (disable fine-tuning).')

    args = args_parser.parse_args()

    logger = get_logger("GraphParser")

    # ---- fixed experiment configuration ----
    mode = "FastLSTM"
    obj = "cross_entropy"
    decoding = "mst"
    train_path = "data/train.stanford.conll"
    dev_path = "data/dev.stanford.conll"
    test_path = "data/test.stanford.conll"
    model_path = "models/parsing/biaffine/"
    model_name = 'network.pt'
    num_epochs = 80
    batch_size = 32
    hidden_size = 512
    arc_space = 512
    type_space = 128
    num_layers = 10
    num_filters = 1
    learning_rate = 0.001
    opt = "adam"
    betas = (0.9, 0.9)
    eps = 1e-4
    decay_rate = 0.75  # multiplier applied to the learning rate at every decay
    clip = 5  # value handed to clip_grad_norm after each backward pass
    gamma = 0  # weight_decay given to the optimizer
    # Patience: a non-improving epoch triggers a decay once `patient` has
    # reached this value.  The flag used to be parsed but ignored; it is now
    # honoured, with the former hard-coded value as the fallback.
    schedule = args.schedule if args.schedule is not None else 10
    p_rnn = (0.05, 0.05)
    p_in = 0.33
    p_out = 0.33
    unk_replace = args.unk_replace
    punctuation = ['.', '``', "''", ':', ',']

    freeze = args.freeze
    word_embedding = 'glove'
    word_path = "data/glove.6B.100d.txt"

    use_char = False
    use_pos = True
    pos_dim = 100
    word_dict, word_dim = utils.load_embedding_dict(word_embedding, word_path)
    char_dim = 0

    logger.info("Creating Alphabets")
    alphabet_path = os.path.join(model_path, 'alphabets/')
    model_name = os.path.join(model_path, model_name)
    word_alphabet, char_alphabet, pos_alphabet, type_alphabet = conllx_data.create_alphabets(
        alphabet_path, train_path, data_paths=[dev_path, test_path],
        max_vocabulary_size=50000, embedd_dict=word_dict)

    num_words = word_alphabet.size()
    num_chars = char_alphabet.size()
    num_pos = pos_alphabet.size()
    num_types = type_alphabet.size()

    logger.info("Word Alphabet Size: %d" % num_words)
    logger.info("Character Alphabet Size: %d" % num_chars)
    logger.info("POS Alphabet Size: %d" % num_pos)
    logger.info("Type Alphabet Size: %d" % num_types)

    logger.info("Reading Data")
    use_gpu = torch.cuda.is_available()
    print(use_gpu)

    data_train = conllx_data.read_data_to_variable(train_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, symbolic_root=True)
    # data_train[1] holds the per-bucket sizes; their sum is the number of training instances.
    num_data = sum(data_train[1])
    data_dev = conllx_data.read_data_to_variable(dev_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)
    data_test = conllx_data.read_data_to_variable(test_path, word_alphabet, char_alphabet, pos_alphabet, type_alphabet, use_gpu=use_gpu, volatile=True, symbolic_root=True)

    punct_set = None
    if punctuation is not None:
        punct_set = set(punctuation)
        logger.info("punctuations(%d): %s" % (len(punct_set), ' '.join(punct_set)))

    def construct_word_embedding_table():
        """Build the initial word-embedding matrix, one row per alphabet entry.

        Words found in ``word_dict`` (exact or lower-cased) get their
        pre-trained vector; the rest get zeros when embeddings are frozen and
        a uniform random vector otherwise.
        """
        scale = np.sqrt(3.0 / word_dim)
        table = np.empty([word_alphabet.size(), word_dim], dtype=np.float32)
        table[conllx_data.UNK_ID, :] = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if word in word_dict:
                embedding = word_dict[word]
            elif word.lower() in word_dict:
                embedding = word_dict[word.lower()]
            else:
                embedding = np.zeros([1, word_dim]).astype(np.float32) if freeze else np.random.uniform(-scale, scale, [1, word_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('word OOV: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()

    window = 3
    if obj == 'cross_entropy':
        network = BiRecurrentConvBiAffine(word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                                          mode, hidden_size, num_layers, num_types, arc_space, type_space,
                                          embedd_word=word_table, embedd_char=None,
                                          p_in=p_in, p_out=p_out, p_rnn=p_rnn, biaffine=True, pos=use_pos, char=use_char)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unknown objective: %s' % obj)

    def save_args():
        """Dump the network constructor arguments next to the model file."""
        arg_path = model_name + '.arg.json'
        arguments = [word_dim, num_words, char_dim, num_chars, pos_dim, num_pos, num_filters, window,
                     mode, hidden_size, num_layers, num_types, arc_space, type_space]
        kwargs = {'p_in': p_in, 'p_out': p_out, 'p_rnn': p_rnn, 'biaffine': True, 'pos': use_pos, 'char': use_char}
        with open(arg_path, 'w') as arg_file:
            json.dump({'args': arguments, 'kwargs': kwargs}, arg_file, indent=4)

    if freeze:
        network.word_embedd.freeze()

    if use_gpu:
        network.cuda()

    save_args()

    print("number of parameters")
    num_param = sum([param.nelement() for param in network.parameters()])
    print(num_param)

    def generate_optimizer(opt, lr, params):
        """Return an optimizer over the trainable entries of ``params``."""
        params = filter(lambda param: param.requires_grad, params)
        if opt == 'adam':
            return Adam(params, lr=lr, betas=betas, weight_decay=gamma, eps=eps)
        raise ValueError('Unknown optimization algorithm: %s' % opt)

    lr = learning_rate
    optim = generate_optimizer(opt, lr, network.parameters())
    opt_info = 'opt: %s, ' % opt
    if opt == 'adam':
        opt_info += 'betas=%s, eps=%.1e' % (betas, eps)

    word_status = 'frozen' if freeze else 'fine tune'
    char_status = 'enabled' if use_char else 'disabled'
    pos_status = 'enabled' if use_pos else 'disabled'
    logger.info("Embedding dim: word=%d (%s), char=%d (%s), pos=%d (%s)" % (word_dim, word_status, char_dim, char_status, pos_dim, pos_status))
    logger.info("CNN: filter=%d, kernel=%d" % (num_filters, window))
    logger.info("RNN: %s, num_layer=%d, hidden=%d, arc_space=%d, type_space=%d" % (mode, num_layers, hidden_size, arc_space, type_space))
    logger.info("train: obj: %s, l2: %f, (#data: %d, batch: %d, clip: %.2f, unk replace: %.2f)" % (obj, gamma, num_data, batch_size, clip, unk_replace))
    logger.info("dropout(in, out, rnn): (%.2f, %.2f, %s)" % (p_in, p_out, p_rnn))
    logger.info("decoding algorithm: %s" % decoding)
    logger.info(opt_info)

    # Floor division keeps this an int on Python 3 as well, so it can feed range().
    num_batches = num_data // batch_size + 1

    if decoding == 'greedy':
        decode = network.decode
    elif decoding == 'mst':
        decode = network.decode_mst
    else:
        raise ValueError('Unknown decoding algorithm: %s' % decoding)

    def percent(count, total):
        """Return ``count`` as a percentage of ``total`` (0.0 when ``total`` is zero)."""
        return count * 100 / total if total else 0.0

    def train_one_epoch():
        """Run ``num_batches`` optimisation steps and print the running losses."""
        train_err = 0.
        train_err_arc = 0.
        train_err_type = 0.
        train_total = 0.
        start_time = time.time()
        num_back = 0
        network.train()
        for batch in range(1, num_batches + 1):
            word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_variable(data_train, batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss_arc, loss_type = network.loss(word, char, pos, heads, types, mask=masks, length=lengths)
            loss = loss_arc + loss_type
            loss.backward()
            clip_grad_norm(network.parameters(), clip)
            optim.step()

            num_inst = word.size(0) if obj == 'crf' else masks.data.sum() - word.size(0)
            train_err += loss.data[0] * num_inst
            train_err_arc += loss_arc.data[0] * num_inst
            train_err_type += loss_type.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (num_batches - batch) * time_ave

            # Overwrite the previous progress line in place every ten batches.
            if batch % 10 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, arc: %.4f, type: %.4f, time left: %.2fs' % (batch, num_batches, train_err / train_total,
                                                                                                 train_err_arc / train_total, train_err_type / train_total, time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, arc: %.4f, type: %.4f, time: %.2fs' % (num_batches, train_err / train_total, train_err_arc / train_total, train_err_type / train_total, time.time() - start_time))

    def evaluate(data, skip=0):
        """Decode ``data`` and sum the per-batch results of ``parser.eval``.

        After every scored batch, ``skip`` batches are discarded, so only one
        batch in ``skip + 1`` is evaluated.

        Returns:
            ``(stats, stats_nopunc, stats_root, num_inst)`` where ``stats`` and
            ``stats_nopunc`` are ``(ucorr, lcorr, total, ucm, lcm)`` tuples and
            ``stats_root`` is ``(corr, total)``.
        """
        stats = [0.0, 0.0, 0, 0.0, 0.0]
        stats_nopunc = [0.0, 0.0, 0, 0.0, 0.0]
        stats_root = [0.0, 0]
        num_inst = 0
        batches = iter(conllx_data.iterate_batch_variable(data, batch_size))
        for batch in batches:
            word, char, pos, heads, types, masks, lengths = batch
            heads_pred, types_pred = decode(word, char, pos, mask=masks, length=lengths, leading_symbolic=conllx_data.NUM_SYMBOLIC_TAGS)
            word = word.data.cpu().numpy()
            pos = pos.data.cpu().numpy()
            lengths = lengths.cpu().numpy()
            heads = heads.data.cpu().numpy()
            types = types.data.cpu().numpy()

            batch_stats, batch_nopunc, batch_root, batch_inst = parser.eval(word, pos, heads_pred, types_pred, heads, types, word_alphabet, pos_alphabet, lengths, punct_set=punct_set, symbolic_root=True)
            stats = [acc + val for acc, val in zip(stats, batch_stats)]
            stats_nopunc = [acc + val for acc, val in zip(stats_nopunc, batch_nopunc)]
            stats_root = [acc + val for acc, val in zip(stats_root, batch_root)]
            num_inst += batch_inst

            for _ in range(skip):
                next(batches, None)
        return tuple(stats), tuple(stats_nopunc), tuple(stats_root), num_inst

    def report(prefix, result, suffix=''):
        """Print the with-punctuation, no-punctuation and root lines for ``result``."""
        stats, stats_nopunc, stats_root, num_inst = result
        for label, (ucorr, lcorr, total, ucm, lcm) in (('W. Punct', stats), ('Wo Punct', stats_nopunc)):
            print('%s%s: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%%s' % (
                prefix, label, ucorr, lcorr, total, percent(ucorr, total), percent(lcorr, total),
                percent(ucm, num_inst), percent(lcm, num_inst), suffix))
        corr_root, total_root = stats_root
        print('%sRoot: corr: %d, total: %d, acc: %.2f%%%s' % (prefix, corr_root, total_root, percent(corr_root, total_root), suffix))

    # Best results so far; all zero until the first improving epoch.
    empty_result = ((0.0, 0.0, 0, 0.0, 0.0), (0.0, 0.0, 0, 0.0, 0.0), (0.0, 0), 0)
    best_dev = empty_result
    best_test = empty_result
    dev_ucorrect_nopunc = 0.0
    dev_lcorrect_nopunc = 0.0
    best_epoch = 0

    patient = 0
    decay = 0
    max_decay = 9
    double_schedule_decay = 5

    with open("testout.csv", "wt") as f:
        writer = csv.writer(f)
        writer.writerow(('train', 'dev'))

        for epoch in range(1, num_epochs + 1):
            print(epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay)
            print('Epoch %d (%s, optim: %s, learning rate=%.6f, eps=%.1e, decay rate=%.2f (schedule=%d, patient=%d, decay=%d)): ' % (epoch, mode, opt, lr, eps, decay_rate, schedule, patient, decay))
            train_one_epoch()

            # evaluate performance on (a sample of) the training data and on dev data
            network.eval()
            _, train_nopunc, _, _ = evaluate(data_train, skip=10)
            dev_result = evaluate(data_dev)
            dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc = dev_result[1][:3]

            train_uas = percent(train_nopunc[0], train_nopunc[2])
            writer.writerow((train_uas, percent(dev_ucorr_nopunc, dev_total_nopunc)))
            f.flush()
            print('Train Wo Punct:%.2f%%' % train_uas)
            report('', dev_result)

            if dev_lcorrect_nopunc < dev_lcorr_nopunc or (dev_lcorrect_nopunc == dev_lcorr_nopunc and dev_ucorrect_nopunc < dev_ucorr_nopunc):
                dev_ucorrect_nopunc = dev_ucorr_nopunc
                dev_lcorrect_nopunc = dev_lcorr_nopunc
                best_dev = dev_result
                best_epoch = epoch
                patient = 0
                torch.save(network.state_dict(), model_name)
                best_test = evaluate(data_test)
            elif percent(dev_ucorr_nopunc, dev_total_nopunc) < percent(dev_ucorrect_nopunc, dev_total_nopunc) - 5 or patient >= schedule:
                # Roll back to the best saved state and continue with a smaller
                # learning rate.  The network object itself is unchanged, so
                # `decode` stays bound to the right method.
                network.load_state_dict(torch.load(model_name))
                lr = lr * decay_rate
                optim = generate_optimizer(opt, lr, network.parameters())

                patient = 0
                decay += 1
                if decay % double_schedule_decay == 0:
                    schedule *= 2
            else:
                patient += 1

            best_suffix = ' (epoch: %d)' % best_epoch
            print('----------------------------------------------------------------------------------------------------------------------------')
            report('best dev  ', best_dev, best_suffix)
            print('----------------------------------------------------------------------------------------------------------------------------')
            report('best test ', best_test, best_suffix)
            print('============================================================================================================================')

            if decay == max_decay:
                break