예제 #1
0
def train(args):
    """Train a multi-word-token (MWT) expander and report dev-set scores.

    A dictionary-based expander is always trained and scored on the dev set
    first. Unless ``args['dict_only']`` is truthy, a seq2seq expander is then
    trained for ``args['num_epoch']`` epochs; after every epoch it is scored
    on the dev set and saved to the model file whenever the score improves.

    Args:
        args: Configuration mapping. Keys read here: ``max_dec_len``,
            ``batch_size``, ``train_file``, ``eval_file``, ``save_dir``,
            ``save_name``, ``shorthand``, ``output_file``, ``gold_file``,
            ``cuda``, ``num_epoch``, ``lr``, ``log_step``, ``decay_epoch``,
            ``lr_decay`` and the optional ``dict_only``, ``ensemble_dict``
            and ``ensemble_early_stop``. ``args['vocab_size']`` is written
            from the training vocabulary.

    Raises:
        SystemExit: With status 0 when the train or dev loader is empty.
    """
    # load data
    print('max_dec_len:', args['max_dec_len'])
    print("Loading data with batch size {}...".format(args['batch_size']))
    train_batch = DataLoader(args['train_file'],
                             args['batch_size'],
                             args,
                             evaluation=False)
    vocab = train_batch.vocab
    args['vocab_size'] = vocab.size
    dev_batch = DataLoader(args['eval_file'],
                           args['batch_size'],
                           args,
                           vocab=vocab,
                           evaluation=True)

    utils.ensure_dir(args['save_dir'])
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_mwt_expander.pt'.format(args['save_dir'],
                                                    args['shorthand'])

    # pred and gold path
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    def score_dev(preds):
        """Write ``preds`` for the dev set to the prediction file and score it."""
        # Close the file before scoring so everything written is on disk
        # when the scorer reads it back.
        with open(system_pred_file, 'w') as pred_out:
            dev_batch.conll.write_conll_with_mwt_expansions(preds, pred_out)
        _, _, f1 = scorer.score(system_pred_file, gold_file)
        return f1

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        print("Skip training because no data available...")
        sys.exit(0)

    # train a dictionary-based MWT expander
    trainer = Trainer(args=args, vocab=vocab, use_cuda=args['cuda'])
    print("Training dictionary-based MWT expander...")
    trainer.train_dict(train_batch.conll.get_mwt_expansions())
    print("Evaluating on dev set...")
    dev_preds = trainer.predict_dict(dev_batch.conll.get_mwt_expansion_cands())
    dev_f = score_dev(dev_preds)
    print("Dev F1 = {:.2f}".format(dev_f * 100))

    if args.get('dict_only', False):
        # save dictionaries
        trainer.save(model_file)
        return

    # train a seq2seq model
    print("Training seq2seq-based MWT expander...")
    global_step = 0
    max_steps = len(train_batch) * args['num_epoch']
    dev_score_history = []
    best_dev_preds = []
    current_lr = args['lr']
    format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    # start training
    epoch = 0  # stays 0 when num_epoch is 0 so the summary below is still valid
    for epoch in range(1, args['num_epoch'] + 1):
        train_loss = 0
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False)  # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                print(format_str.format(
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                    max_steps, epoch, args['num_epoch'], loss, duration,
                    current_lr))

        # eval on dev
        print("Evaluating on dev set...")
        dev_preds = []
        for batch in dev_batch:
            dev_preds += trainer.predict(batch)
        if args.get('ensemble_dict', False) and args.get(
                'ensemble_early_stop', False):
            print("[Ensembling dict with seq2seq model...]")
            dev_preds = trainer.ensemble(
                dev_batch.conll.get_mwt_expansion_cands(), dev_preds)
        dev_score = score_dev(dev_preds)

        # avg loss per batch
        train_loss = train_loss / train_batch.num_examples * args['batch_size']
        print("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(
            epoch, train_loss, dev_score))

        # save best model (the first evaluated epoch is always saved)
        if not dev_score_history or dev_score > max(dev_score_history):
            trainer.save(model_file)
            print("new best model saved.")
            best_dev_preds = dev_preds

        # lr schedule; the history check keeps decay_epoch == 0 from
        # indexing an empty list on the first epoch
        if (epoch > args['decay_epoch'] and dev_score_history
                and dev_score <= dev_score_history[-1]):
            current_lr *= args['lr_decay']
            trainer.change_lr(current_lr)

        dev_score_history.append(dev_score)
        print("")

    print("Training ended with {} epochs.".format(epoch))

    if not dev_score_history:
        # num_epoch was 0: there is no seq2seq result to summarise.
        print("No epochs were run; nothing to report.")
        return

    best_f = max(dev_score_history) * 100
    best_epoch = np.argmax(dev_score_history) + 1
    print("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))

    # try ensembling with dict if necessary
    if args.get('ensemble_dict', False):
        print("[Ensembling dict with seq2seq model...]")
        dev_preds = trainer.ensemble(
            dev_batch.conll.get_mwt_expansion_cands(), best_dev_preds)
        dev_score = score_dev(dev_preds)
        print("Ensemble dev F1 = {:.2f}".format(dev_score * 100))
        # best_f is a percentage, so scale the ensemble score to match
        best_f = max(best_f, dev_score * 100)
예제 #2
0
def train(args):
    """Train the tagger, evaluating on the dev set at a fixed step interval.

    Training loops over the training loader (reshuffling after each pass)
    until ``args['max_steps']`` is reached or the dev score has not improved
    for ``args['max_steps_before_stop']`` steps. The first time that patience
    runs out the optimizer is replaced by Adam with AMSGrad and training
    continues; the second time, training stops. The model is saved whenever
    the dev score improves.

    Args:
        args: Configuration mapping. Keys read here: ``save_dir``,
            ``save_name``, ``shorthand``, ``wordvec_dir``, ``batch_size``,
            ``train_file``, ``eval_file``, ``output_file``, ``gold_file``,
            ``cuda``, ``max_steps``, ``lr``, ``adapt_eval_interval``,
            ``eval_interval``, ``log_step``, ``max_steps_before_stop`` and
            ``beta2``. ``args['eval_interval']`` is overwritten when
            ``adapt_eval_interval`` is truthy.

    Raises:
        SystemExit: With status 0 when the train or dev loader is empty.
    """
    utils.ensure_dir(args['save_dir'])
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_tagger.pt'.format(args['save_dir'],
                                              args['shorthand'])

    # load pretrained vectors
    vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    pretrain_file = '{}/{}.pretrain.pt'.format(args['save_dir'],
                                               args['shorthand'])
    pretrain = Pretrain(pretrain_file, vec_file)

    # load data
    print("Loading data with batch size {}...".format(args['batch_size']))
    train_batch = DataLoader(args['train_file'],
                             args['batch_size'],
                             args,
                             pretrain,
                             evaluation=False)
    vocab = train_batch.vocab
    dev_batch = DataLoader(args['eval_file'],
                           args['batch_size'],
                           args,
                           pretrain,
                           vocab=vocab,
                           evaluation=True)

    # pred and gold path
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        print("Skip training because no data available...")
        # The bare exit() helper is injected by the site module and is not
        # guaranteed to exist; raising SystemExit works in every interpreter.
        raise SystemExit(0)

    print("Training tagger...")
    trainer = Trainer(args=args,
                      vocab=vocab,
                      pretrain=pretrain,
                      use_cuda=args['cuda'])

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    current_lr = args['lr']
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    if args['adapt_eval_interval']:
        args['eval_interval'] = utils.get_adaptive_eval_interval(
            dev_batch.num_examples, 2000, args['eval_interval'])
        print("Evaluating the model every {} steps...".format(
            args['eval_interval']))

    using_amsgrad = False
    last_best_step = 0
    # start training
    train_loss = 0
    while True:
        do_break = False
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False)  # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                print(format_str.format(
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                    max_steps, loss, duration, current_lr))

            if global_step % args['eval_interval'] == 0:
                # eval on dev
                print("Evaluating on dev set...")
                dev_preds = []
                for dev_b in dev_batch:
                    dev_preds += trainer.predict(dev_b)
                dev_batch.conll.set(['upos', 'xpos', 'feats'],
                                    [y for x in dev_preds for y in x])
                dev_batch.conll.write_conll(system_pred_file)
                _, _, dev_score = scorer.score(system_pred_file, gold_file)

                # avg loss per batch since the previous evaluation
                train_loss = train_loss / args['eval_interval']
                print(
                    "step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(
                        global_step, train_loss, dev_score))
                train_loss = 0

                # save best model
                if not dev_score_history or dev_score > max(dev_score_history):
                    last_best_step = global_step
                    trainer.save(model_file)
                    print("new best model saved.")

                dev_score_history.append(dev_score)
                print("")

            if global_step - last_best_step >= args['max_steps_before_stop']:
                if not using_amsgrad:
                    # First time patience runs out: swap the optimizer and
                    # restart the patience window instead of stopping.
                    print("Switching to AMSGrad")
                    last_best_step = global_step
                    using_amsgrad = True
                    trainer.optimizer = optim.Adam(trainer.model.parameters(),
                                                   amsgrad=True,
                                                   lr=args['lr'],
                                                   betas=(.9, args['beta2']),
                                                   eps=1e-6)
                else:
                    do_break = True
                    break

            if global_step >= args['max_steps']:
                do_break = True
                break

        if do_break:
            break

        train_batch.reshuffle()

    print("Training ended with {} steps.".format(global_step))

    if dev_score_history:
        best_f = max(dev_score_history) * 100
        best_eval = np.argmax(dev_score_history) + 1
        print("Best dev F1 = {:.2f}, at iteration = {}".format(
            best_f, best_eval * args['eval_interval']))
    else:
        # Training stopped before the first evaluation step was reached.
        print("No dev evaluation was run; no model was saved.")
예제 #3
0
def train(args):
    """Train a lemmatizer and report dev-set scores.

    A dictionary-based lemmatizer is always trained (optionally extended with
    an external tab-separated dictionary) and scored on the dev set. Unless
    ``args['dict_only']`` is truthy, a seq2seq lemmatizer is then trained for
    ``args['num_epoch']`` epochs, scored on the dev set after every epoch and
    saved whenever the score improves.

    Args:
        args: Configuration mapping. Keys read here: ``batch_size``,
            ``train_file``, ``eval_file``, ``model_dir``, ``model_file``,
            ``output_file``, ``gold_file``, ``cuda``, ``num_epoch``, ``lr``,
            ``log_step``, ``beam_size``, ``decay_epoch``, ``optim``,
            ``lr_decay`` and the optional ``external_dict``, ``dict_only``
            and ``ensemble_dict``. ``args['vocab_size']`` and
            ``args['pos_vocab_size']`` are written from the training
            vocabulary.

    Raises:
        SystemExit: With status 0 when the train or dev loader is empty.
    """
    # load data
    print("[Loading data with batch size {}...]".format(args['batch_size']))
    train_batch = DataLoader(args['train_file'],
                             args['batch_size'],
                             args,
                             evaluation=False)
    vocab = train_batch.vocab
    args['vocab_size'] = vocab['char'].size
    args['pos_vocab_size'] = vocab['pos'].size
    dev_batch = DataLoader(args['eval_file'],
                           args['batch_size'],
                           args,
                           vocab=vocab,
                           evaluation=True)

    utils.ensure_dir(args['model_dir'])
    model_file = '{}/{}_lemmatizer.pt'.format(args['model_dir'],
                                              args['model_file'])

    # pred and gold path
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    utils.print_config(args)

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        print("[Skip training because no data available...]")
        sys.exit(0)

    def dev_keys():
        """Return fresh (lowercased word, upos, feats) tuples for the dev set."""
        return [(e[0].lower(), e[1], e[2])
                for e in dev_batch.conll.get(['word', 'upos', 'feats'])]

    # start training
    # train a dictionary-based lemmatizer
    trainer = Trainer(args=args, vocab=vocab, use_cuda=args['cuda'])
    print("[Training dictionary-based lemmatizer...]")
    lemma_dict = [
        (e[0].lower(), e[1], e[2], e[3])
        for e in train_batch.conll.get(['word', 'upos', 'feats', 'lemma'])
    ]
    if args.get('external_dict', None) is not None:
        # External entries are stored as word, lemma, upos, feats per line
        # and reordered to match the (word, upos, feats, lemma) tuples above.
        with open(args['external_dict']) as dict_file:
            for line in dict_file:
                word, lemma, upos, feats = line.rstrip('\r\n').split('\t')
                lemma_dict.append((word.lower(), upos, feats, lemma))
    trainer.train_dict(lemma_dict)
    print("Evaluating on dev set...")
    dev_preds = trainer.predict_dict(dev_keys())
    dev_batch.conll.write_conll_with_lemmas(dev_preds, system_pred_file)
    _, _, dev_f = scorer.score(system_pred_file, gold_file)
    print("Dev F1 = {:.2f}".format(dev_f * 100))

    if args.get('dict_only', False):
        # save dictionaries
        trainer.save(model_file)
        return

    # train a seq2seq model
    print("[Training seq2seq-based lemmatizer...]")
    global_step = 0
    max_steps = len(train_batch) * args['num_epoch']
    dev_score_history = []
    current_lr = args['lr']
    format_str = '{}: step {}/{} (epoch {}/{}), loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    # start training
    epoch = 0  # stays 0 when num_epoch is 0 so the summary below is still valid
    for epoch in range(1, args['num_epoch'] + 1):
        train_loss = 0
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False)  # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                print(format_str.format(
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                    max_steps, epoch, args['num_epoch'], loss, duration,
                    current_lr))

        # eval on dev
        print("Evaluating on dev set...")
        dev_preds = []
        dev_edits = []
        use_ensemble = args.get('ensemble_dict', False)

        # NOTE(review): the result of this call was never used by the original
        # code; the call is kept in case predict_dict has side effects —
        # confirm and remove if it does not.
        trainer.predict_dict(dev_keys())

        # try speeding up dev eval: with ensembling, only run seq2seq on the
        # entries the trainer does not mark as skippable
        if use_ensemble:
            skip = trainer.skip_seq2seq(dev_keys())
            seq2seq_batch = DataLoader(args['eval_file'],
                                       args['batch_size'],
                                       args,
                                       vocab=vocab,
                                       evaluation=True,
                                       skip=skip)
        else:
            seq2seq_batch = dev_batch
        for dev_b in seq2seq_batch:
            ps, es = trainer.predict(dev_b, args['beam_size'])
            dev_preds += ps
            if es is not None:
                dev_edits += es

        if use_ensemble:
            kept_words = [
                x for x, y in zip(dev_batch.conll.get(['word']), skip)
                if not y
            ]
            dev_preds = trainer.postprocess(kept_words,
                                            dev_preds,
                                            edits=dev_edits)
            print("[Ensembling dict with seq2seq lemmatizer...]")
            # Re-align seq2seq output with the full dev set: skipped entries
            # get an empty placeholder, the rest consume predictions in order.
            next_pred = 0
            aligned_preds = []
            for skipped in skip:
                if skipped:
                    aligned_preds.append('')
                else:
                    aligned_preds.append(dev_preds[next_pred])
                    next_pred += 1
            dev_preds = trainer.ensemble(dev_keys(), aligned_preds)
        else:
            dev_preds = trainer.postprocess(dev_batch.conll.get(['word']),
                                            dev_preds,
                                            edits=dev_edits)

        dev_batch.conll.write_conll_with_lemmas(dev_preds, system_pred_file)
        _, _, dev_score = scorer.score(system_pred_file, gold_file)

        # avg loss per batch
        train_loss = train_loss / train_batch.num_examples * args['batch_size']
        print("epoch {}: train_loss = {:.6f}, dev_score = {:.4f}".format(
            epoch, train_loss, dev_score))

        # save best model (the first evaluated epoch is always saved)
        if not dev_score_history or dev_score > max(dev_score_history):
            trainer.save(model_file)
            print("new best model saved.")

        # lr schedule; the history check keeps decay_epoch == 0 from
        # indexing an empty list on the first epoch
        if (epoch > args['decay_epoch'] and dev_score_history
                and dev_score <= dev_score_history[-1]
                and args['optim'] in ['sgd', 'adagrad']):
            current_lr *= args['lr_decay']
            trainer.update_lr(current_lr)

        dev_score_history.append(dev_score)
        print("")

    print("Training ended with {} epochs.".format(epoch))

    if not dev_score_history:
        # num_epoch was 0: there is no seq2seq result to summarise.
        print("No epochs were run; nothing to report.")
        return

    best_f = max(dev_score_history) * 100
    best_epoch = np.argmax(dev_score_history) + 1
    print("Best dev F1 = {:.2f}, at epoch = {}".format(best_f, best_epoch))
예제 #4
0
def train(args):
    """Train the language model, tracking dev loss at a fixed step interval.

    Training loops over the training loader (reshuffling after each pass)
    until ``args['max_steps']`` is reached or the dev loss has not improved
    for ``args['max_steps_before_stop']`` steps. The model is saved whenever
    the average dev loss reaches a new minimum.

    Args:
        args: Configuration mapping. Keys read here: ``save_dir``,
            ``save_name``, ``shorthand``, ``wordvec_dir``, ``batch_size``,
            ``eval_batch_size``, ``train_file``, ``eval_file``, ``cuda``,
            ``max_steps``, ``lr``, ``log_step``, ``eval_interval`` and
            ``max_steps_before_stop``.

    Raises:
        SystemExit: With status 0 when the train or dev loader is empty.
    """
    utils.ensure_dir(args['save_dir'])
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_lm.pt'.format(args['save_dir'], args['shorthand'])

    # load pretrained vectors
    vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    pretrain_file = '{}/{}.pretrain.pt'.format(args['save_dir'],
                                               args['shorthand'])
    pretrain = Pretrain(pretrain_file, vec_file)

    # load data
    print("Loading data with batch size {}...".format(args['batch_size']))
    train_batch = DataLoader(args['train_file'],
                             args['batch_size'],
                             args,
                             pretrain,
                             evaluation=False)
    vocab = train_batch.vocab
    dev_batch = DataLoader(args['eval_file'],
                           args['eval_batch_size'],
                           args,
                           pretrain,
                           vocab=vocab,
                           evaluation=True)

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        print("Skip training because no data available...")
        sys.exit(0)

    print("Training language model...")
    trainer = Trainer(args=args,
                      vocab=vocab,
                      pretrain=pretrain,
                      use_cuda=args['cuda'])

    # list every trainable parameter and the total element count
    print()
    print('Parameters:')
    n_param = 0
    for p_name, p in trainer.model.named_parameters():
        if p.requires_grad:
            n_param += np.prod(list(p.size()))
            print('\t{:10}    {}'.format(p_name, p.size()))
    print('\tTotal paramamters: {}'.format(n_param))

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []  # average dev loss per evaluation (lower is better)
    current_lr = args['lr']
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), ppl = {:.6f}, lr: {:.6f}'

    last_best_step = 0
    log_loss = 0  # loss accumulated since the last log line
    train_loss = 0  # loss accumulated since the last evaluation
    while True:
        do_break = False
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False)  # update step
            log_loss += loss
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                log_loss /= args['log_step']
                print(
                    format_str.format(
                        datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        global_step, max_steps, log_loss, duration,
                        np.exp(log_loss), current_lr))
                log_loss = 0

            if global_step % args['eval_interval'] == 0:
                # eval on dev: update(..., eval=True) is used to obtain the
                # loss for each dev batch
                print("Evaluating on dev set...")
                dev_loss = 0
                for dev_b in dev_batch:
                    dev_loss += trainer.update(dev_b, eval=True)
                dev_loss /= len(dev_batch)

                # avg loss per batch since the previous evaluation
                train_loss = train_loss / args['eval_interval']
                print("step {}: train_ppl = {:.6f}, dev_ppl = {:.6f}".format(
                    global_step, np.exp(train_loss), np.exp(dev_loss)))
                train_loss = 0

                # save best model
                if not dev_score_history or dev_loss < min(dev_score_history):
                    last_best_step = global_step
                    trainer.save(model_file)
                    print("new best model saved.")
                dev_score_history.append(dev_loss)
                print()

            if global_step - last_best_step >= args['max_steps_before_stop']:
                do_break = True
                break

            if global_step >= args['max_steps']:
                do_break = True
                break

        if do_break:
            break

        train_batch.reshuffle()

    print("Training ended with {} steps.".format(global_step))

    if dev_score_history:
        best_ppl = np.exp(min(dev_score_history))
        best_eval = np.argmin(dev_score_history) + 1
        print("Best dev ppl = {:.2f}, at iteration = {}".format(
            best_ppl, best_eval * args['eval_interval']))
    else:
        # Training stopped before the first evaluation step was reached.
        print("No dev evaluation was run; no model was saved.")
예제 #5
0
def train(args):
    """Train the NER tagger, scoring entities on the dev set periodically.

    Training loops over the training loader (reshuffling after each pass)
    until ``args['max_steps']`` is reached or the optimizer's learning rate
    drops to ``args['min_lr']`` or below. When ``args['lr_decay']`` is
    positive, a ``ReduceLROnPlateau`` scheduler is stepped with the dev score
    after every evaluation. The model is saved whenever the dev score
    improves.

    Args:
        args: Configuration mapping. Keys read here: ``save_dir``,
            ``save_name``, ``shorthand``, ``wordvec_file``,
            ``pretrain_max_vocab``, ``charlm``, ``charlm_shorthand``,
            ``charlm_save_dir``, ``batch_size``, ``train_file``,
            ``eval_file``, ``cuda``, ``max_steps``, ``lr_decay``,
            ``patience``, ``min_lr``, ``log_step`` and ``eval_interval``.
            ``args['charlm_forward_file']`` and
            ``args['charlm_backward_file']`` are written when ``charlm`` is
            truthy.

    Raises:
        SystemExit: With status 0 when ``charlm`` is requested without a
            ``charlm_shorthand``, or when the train or dev loader is empty.
    """
    utils.ensure_dir(args['save_dir'])
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_nertagger.pt'.format(args['save_dir'],
                                                 args['shorthand'])

    # load pretrained vectors
    # NOTE(review): the pretrain file name is built from save_name, which the
    # model_file logic above allows to be None — confirm this is intended.
    vec_file = args['wordvec_file']
    pretrain_file = '{}/{}.pretrain.pt'.format(args['save_dir'],
                                               args['save_name'])
    pretrain = Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])

    if args['charlm']:
        if args['charlm_shorthand'] is None:
            print(
                "CharLM Shorthand is required for loading pretrained CharLM model..."
            )
            # NOTE(review): this is an error path but exits with status 0;
            # kept as-is because callers may depend on it.
            sys.exit(0)
        print('Use pretrained contextualized char embedding')
        args['charlm_forward_file'] = '{}/{}_forward_charlm.pt'.format(
            args['charlm_save_dir'], args['charlm_shorthand'])
        args['charlm_backward_file'] = '{}/{}_backward_charlm.pt'.format(
            args['charlm_save_dir'], args['charlm_shorthand'])

    # load data
    print("Loading data with batch size {}...".format(args['batch_size']))
    # Read each JSON file inside a context manager so the handle is closed
    # as soon as parsing is done.
    with open(args['train_file']) as train_in:
        train_doc = Document(json.load(train_in))
    train_batch = DataLoader(train_doc,
                             args['batch_size'],
                             args,
                             pretrain,
                             evaluation=False)
    vocab = train_batch.vocab
    with open(args['eval_file']) as dev_in:
        dev_doc = Document(json.load(dev_in))
    dev_batch = DataLoader(dev_doc,
                           args['batch_size'],
                           args,
                           pretrain,
                           vocab=vocab,
                           evaluation=True)
    dev_gold_tags = dev_batch.tags

    # skip training if the language does not have training or dev data
    if len(train_batch) == 0 or len(dev_batch) == 0:
        print("Skip training because no data available...")
        sys.exit(0)

    print("Training tagger...")
    trainer = Trainer(args=args,
                      vocab=vocab,
                      pretrain=pretrain,
                      use_cuda=args['cuda'])
    print(trainer.model)

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    current_lr = trainer.optimizer.param_groups[0]['lr']
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

    # LR scheduling
    if args['lr_decay'] > 0:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            trainer.optimizer, mode='max', factor=args['lr_decay'],
            patience=args['patience'], verbose=True, min_lr=args['min_lr'])
    else:
        scheduler = None

    # start training
    train_loss = 0
    while True:
        should_stop = False
        for batch in train_batch:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False)  # update step
            train_loss += loss
            if global_step % args['log_step'] == 0:
                duration = time.time() - start_time
                print(format_str.format(
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"), global_step,
                    max_steps, loss, duration, current_lr))

            if global_step % args['eval_interval'] == 0:
                # eval on dev
                print("Evaluating on dev set...")
                dev_preds = []
                for dev_b in dev_batch:
                    dev_preds += trainer.predict(dev_b)
                _, _, dev_score = scorer.score_by_entity(
                    dev_preds, dev_gold_tags)

                # avg loss per batch since the previous evaluation
                train_loss = train_loss / args['eval_interval']
                print(
                    "step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(
                        global_step, train_loss, dev_score))
                train_loss = 0

                # save best model
                if not dev_score_history or dev_score > max(dev_score_history):
                    trainer.save(model_file)
                    print("New best model saved.")

                dev_score_history.append(dev_score)
                print("")

                # lr schedule
                if scheduler is not None:
                    scheduler.step(dev_score)

            # check stopping; re-read the lr because the scheduler may have
            # lowered it during the evaluation above
            current_lr = trainer.optimizer.param_groups[0]['lr']
            if global_step >= args['max_steps'] or current_lr <= args['min_lr']:
                should_stop = True
                break

        if should_stop:
            break

        train_batch.reshuffle()

    print("Training ended with {} steps.".format(global_step))

    if dev_score_history:
        best_f = max(dev_score_history) * 100
        best_eval = np.argmax(dev_score_history) + 1
        print("Best dev F1 = {:.2f}, at iteration = {}".format(
            best_f, best_eval * args['eval_interval']))
    else:
        # Training stopped before the first evaluation step was reached.
        print("No dev evaluation was run; no model was saved.")
예제 #6
0
def train(args):
    """Jointly train the parser with a language-model objective.

    Every step takes one batch from the LM loader and one from the
    dependency-parsing loader (each loader is restarted when exhausted) and
    passes both to ``trainer.update``. Every ``args['eval_interval']`` steps
    the model is scored on the training data and on the dev set, and saved
    when the dev score improves. Training stops at ``args['max_steps']`` or,
    after one switch to Adam with AMSGrad, when the dev score has not
    improved for ``args['max_steps_before_stop']`` steps.

    Args:
        args: Configuration mapping. Keys read here: ``save_dir``,
            ``save_name``, ``shorthand``, ``wordvec_dir``, ``cuda``, ``cpu``,
            ``lm_file``, ``lm_batch_size``, ``vocab_cutoff``, ``train_file``,
            ``batch_size``, ``eval_file``, ``output_file``, ``gold_file``,
            ``wdecay``, ``max_steps``, ``lr``, ``log_step``,
            ``eval_interval``, ``max_steps_before_stop`` and ``beta2``.
    """
    utils.ensure_dir(args['save_dir'])
    if args['save_name'] is not None:
        model_file = args['save_dir'] + '/' + args['save_name']
    else:
        model_file = '{}/{}_parser.pt'.format(args['save_dir'],
                                              args['shorthand'])

    pretrain_file = '{}/{}.pretrain.pt'.format(args['save_dir'],
                                               args['shorthand'])
    vec_file = utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
    pretrain = Pretrain(pretrain_file, vec_file)
    use_cuda = args['cuda'] and not args['cpu']

    lm_train_batch = LMDataLoader(args['lm_file'],
                                  args['lm_batch_size'],
                                  args,
                                  pretrain,
                                  vocab=None,
                                  evaluation=False,
                                  cutoff=args['vocab_cutoff'])
    vocab = lm_train_batch.vocab
    # A first parsing loader is built only to obtain its 'deprel' vocabulary;
    # the loader actually used for training is rebuilt with the merged vocab.
    dp_train_batch = DPDataLoader(args['train_file'],
                                  args['batch_size'],
                                  args,
                                  pretrain,
                                  vocab=None,
                                  evaluation=False,
                                  cutoff=args['vocab_cutoff'])
    vocab['deprel'] = dp_train_batch.vocab['deprel']
    dp_train_batch = DPDataLoader(args['train_file'],
                                  args['batch_size'],
                                  args,
                                  pretrain,
                                  vocab=vocab,
                                  evaluation=False,
                                  cutoff=args['vocab_cutoff'])
    train_dev_batch = DPDataLoader(args['train_file'],
                                   args['batch_size'],
                                   args,
                                   pretrain,
                                   vocab=vocab,
                                   evaluation=True)
    dev_batch = DPDataLoader(args['eval_file'],
                             args['batch_size'],
                             args,
                             pretrain,
                             vocab=vocab,
                             evaluation=True)

    def next_or_restart(iterator, loader):
        """Return (next batch, iterator), restarting the loader when exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    lm_train_iter = iter(lm_train_batch)
    dp_train_iter = iter(dp_train_batch)

    # pred and gold path
    system_pred_file = args['output_file']
    gold_file = args['gold_file']

    print("Training parser...")
    # Pass the combined flag so that args['cpu'] is honoured; the original
    # computed use_cuda but handed args['cuda'] to the trainer.
    trainer = Trainer(args=args,
                      vocab=vocab,
                      pretrain=pretrain,
                      use_cuda=use_cuda,
                      weight_decay=args['wdecay'])
    print()
    print('Parameters that require grad:')
    for p_name, p in trainer.model.named_parameters():
        if p.requires_grad:
            print('\t{:10}    {}'.format(p_name, p.size()))

    global_step = 0
    max_steps = args['max_steps']
    dev_score_history = []
    current_lr = args['lr']
    format_str = '{}: step {}/{}, loss = {:.6f} ({:.3f} sec/batch), dp_loss = {:.4f}, ppl = {:.2f}, lr: {:.6f}'

    using_amsgrad = False
    last_best_step = 0
    # start training
    # running sums since the last log line, ordered [lm_loss, dp_loss, loss]
    log_loss = np.zeros(3)
    while True:
        lm_batch, lm_train_iter = next_or_restart(lm_train_iter,
                                                  lm_train_batch)
        dp_batch, dp_train_iter = next_or_restart(dp_train_iter,
                                                  dp_train_batch)

        start_time = time.time()
        global_step += 1
        dp_loss, lm_loss, loss = trainer.update(dp_batch, lm_batch,
                                                eval=False)  # update step
        log_loss += np.array([lm_loss, dp_loss, loss])
        if global_step % args['log_step'] == 0:
            duration = time.time() - start_time
            avg_loss = log_loss / args['log_step']
            print(
                format_str.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                                  global_step, max_steps,
                                  avg_loss[2], duration, avg_loss[1],
                                  np.exp(avg_loss[0]), current_lr))
            log_loss[:] = 0

        if global_step % args['eval_interval'] == 0:
            # eval on train; the predictions go to the same output file and
            # are overwritten by the dev predictions below
            train_preds = []
            for batch in train_dev_batch:
                train_preds += trainer.predict(batch)

            train_dev_batch.conll.set(['head', 'deprel'],
                                      [y for x in train_preds for y in x])
            train_dev_batch.conll.write_conll(system_pred_file)
            _, _, train_score = scorer.score(system_pred_file,
                                             args['train_file'])

            # eval on dev
            print("Evaluating on dev set...")
            dev_preds = []
            for batch in dev_batch:
                dev_preds += trainer.predict(batch)

            dev_batch.conll.set(['head', 'deprel'],
                                [y for x in dev_preds for y in x])
            dev_batch.conll.write_conll(system_pred_file)
            _, _, dev_score = scorer.score(system_pred_file, gold_file)

            print("step {}: train_score = {:.4f}, dev_score = {:.4f}".format(
                global_step, train_score, dev_score))

            # save best model
            if not dev_score_history or dev_score > max(dev_score_history):
                last_best_step = global_step
                trainer.save(model_file)
                print("new best model saved.")

            dev_score_history.append(dev_score)
            print("")

        if global_step - last_best_step >= args['max_steps_before_stop']:
            if using_amsgrad:
                break
            # First time patience runs out: swap the optimizer and restart
            # the patience window instead of stopping.
            print("Switching to AMSGrad")
            last_best_step = global_step
            using_amsgrad = True
            trainer.optimizer = optim.Adam(trainer.model.parameters(),
                                           amsgrad=True,
                                           lr=args['lr'],
                                           betas=(.9, args['beta2']),
                                           eps=1e-6)

        if global_step >= args['max_steps']:
            break

    print("Training ended with {} steps.".format(global_step))

    if dev_score_history:
        best_f = max(dev_score_history) * 100
        best_eval = np.argmax(dev_score_history) + 1
        print("Best dev F1 = {:.2f}, at iteration = {}".format(
            best_f, best_eval * args['eval_interval']))
    else:
        # Training stopped before the first evaluation step was reached.
        print("No dev evaluation was run; no model was saved.")