Exemplo n.º 1
0
def evaluate(args):
    """Load a saved model, run it over the evaluation files and log the OOV rate.

    Args:
        args: mapping of settings. Keys read directly here are
            ``mwt_json_file``, ``cuda``, ``cpu``, ``load_name``, ``save_name``,
            ``txt_file``, ``label_file``, ``conll_file`` and ``max_seqlen``.
            The mapping is updated in place with settings stored in the
            loaded model (see the loop below).

    The only result produced here is a log line; predictions are handled by
    ``output_predictions``, which receives ``args['conll_file']``.
    """
    mwt_dict = load_mwt_dict(args['mwt_json_file'])
    # An explicit cpu flag overrides the cuda flag.
    use_cuda = args['cuda'] and not args['cpu']
    trainer = Trainer(model_file=args['load_name'] or args['save_name'],
                      use_cuda=use_cuda)
    loaded_args, vocab = trainer.args, trainer.vocab

    # Settings saved with the model take precedence over the caller's, except
    # for file paths and the listed run-specific keys.
    for k in loaded_args:
        if not k.endswith('_file') and k not in [
                'cuda', 'mode', 'save_dir', 'load_name', 'save_name'
        ]:
            args[k] = loaded_args[k]

    eval_input_files = {'txt': args['txt_file'], 'label': args['label_file']}

    batches = DataLoader(args,
                         input_files=eval_input_files,
                         vocab=vocab,
                         evaluation=True,
                         dictionary=trainer.dictionary)

    oov_count, N, _, _ = output_predictions(args['conll_file'], trainer,
                                            batches, vocab, mwt_dict,
                                            args['max_seqlen'])

    # Guard against an empty evaluation set: without this the division below
    # raises ZeroDivisionError after all the prediction work is already done.
    oov_rate = oov_count / N * 100 if N else 0.0

    logger.info("OOV rate: {:6.3f}% ({:6d}/{:6d})".format(
        oov_rate, oov_count, N))
Exemplo n.º 2
0
def train(args):
    """Train a model, evaluating on the dev set periodically.

    Args:
        args: mapping of settings. It is modified in place:
            ``num_dict_feat`` is always set, ``feat_dim`` is increased when
            ``use_dictionary`` is truthy, ``vocab_size`` is set from the
            training vocab, and ``use_mwt`` is filled in from the training
            data when it is ``None``.

    Every ``eval_steps`` steps the dev score is computed; the model is saved
    to ``args['save_name']`` whenever that score beats the best seen so far.
    The learning rate is multiplied by ``args['anneal']`` when the dev score
    drops relative to the previous evaluation (once ``anneal_after`` steps
    have passed). Training stops early when more than
    ``max_steps_before_stop`` steps pass without a new best score. If the dev
    set is never evaluated, the final model is saved instead.
    """
    # Optional dictionary features.
    if not args['use_dictionary']:
        args['num_dict_feat'] = 0
        lexicon = None
        dictionary = None
    else:
        lexicon, args['num_dict_feat'] = load_lexicon(args)
        dictionary = create_dictionary(lexicon)
        # Widen the feature dimension by two slots per dictionary feature.
        args['feat_dim'] += args['num_dict_feat'] * 2

    mwt_dict = load_mwt_dict(args['mwt_json_file'])

    # Training data; its vocab is reused for the dev data below.
    train_batches = DataLoader(
        args,
        input_files={'txt': args['txt_file'], 'label': args['label_file']},
        dictionary=dictionary)
    vocab = train_batches.vocab
    args['vocab_size'] = len(vocab)

    dev_batches = DataLoader(
        args,
        input_files={
            'txt': args['dev_txt_file'],
            'label': args['dev_label_file'],
        },
        vocab=vocab,
        evaluation=True,
        dictionary=dictionary)

    # When the caller did not decide on use_mwt, take it from the training data.
    if args['use_mwt'] is None:
        args['use_mwt'] = train_batches.has_mwt()
        mwt_prefix = "" if args['use_mwt'] else "no "
        logger.info(
            "Found {}mwts in the training data.  Setting use_mwt to {}".format(
                mwt_prefix, args['use_mwt']))

    trainer = Trainer(args=args,
                      vocab=vocab,
                      lexicon=lexicon,
                      dictionary=dictionary,
                      use_cuda=args['cuda'])

    if args['load_name'] is not None:
        trainer.load(os.path.join(args['save_dir'], args['load_name']))
    trainer.change_lr(args['lr0'])

    # Total number of updates: explicit 'steps', else derived from epochs.
    num_train = len(train_batches)
    if args['steps'] is not None:
        total_steps = args['steps']
    else:
        total_steps = int(num_train * args['epochs'] / args['batch_size'] + .5)
    current_lr = args['lr0']

    last_score = -1
    top_score = -1
    top_step = -1

    for step in range(1, total_steps + 1):
        batch = train_batches.next(unit_dropout=args['unit_dropout'],
                                   feat_unit_dropout=args['feat_unit_dropout'])
        loss = trainer.update(batch)

        if step % args['report_steps'] == 0:
            logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(
                step, total_steps, loss))

        if args['shuffle_steps'] > 0 and step % args['shuffle_steps'] == 0:
            train_batches.shuffle()

        # Everything below only happens on evaluation steps.
        if step % args['eval_steps'] != 0:
            continue

        dev_score = eval_model(args, trainer, dev_batches, vocab, mwt_dict)
        reports = ['Dev score: {:6.3f}'.format(dev_score * 100)]

        # Anneal when the dev score dropped since the previous evaluation.
        if step >= args['anneal_after'] and dev_score < last_score:
            annealed_lr = current_lr * args['anneal']
            reports.append('lr: {:.6f} -> {:.6f}'.format(
                current_lr, annealed_lr))
            current_lr = annealed_lr
            trainer.change_lr(current_lr)

        last_score = dev_score

        should_stop = False
        if dev_score > top_score:
            reports.append('New best dev score!')
            top_score = dev_score
            top_step = step
            trainer.save(args['save_name'])
        elif top_step > 0:
            stale_steps = step - top_step
            if stale_steps > args['max_steps_before_stop']:
                reports.append(
                    'Stopping training after {} steps with no improvement'.
                    format(stale_steps))
                should_stop = True

        logger.info('\t'.join(reports))
        if should_stop:
            break

    if top_step <= -1:
        logger.info('Dev set never evaluated.  Saving final model')
        trainer.save(args['save_name'])
    else:
        logger.info('Best dev score={} at step {}'.format(
            top_score, top_step))