Example #1
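# NB: this excerpt relies on module-level names defined elsewhere in the project:
# logger, Pretrain, Trainer, Document, DataLoader, unsort, display_results, and the
# get_read_format_args / get_write_format_args helpers.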
def predict(args, trainer=None, pretrained=None, use_cuda=False):
    # load pretrained embeddings and model
    if not trainer:
        # load pretrained embeddings
        pretrained = Pretrain(from_pt=args.embeddings)
        # load model
        logger.info("Loading model from {}".format(args.model))
        trainer = Trainer(model_file=args.model, pretrain=pretrained, use_cuda=use_cuda)

    # load data
    logger.info("Loading prediction data...")
    doc = Document(from_file=args.test_data, read_positions=get_read_format_args(args), write_positions=get_write_format_args(args), copy_untouched=args.copy_untouched)
    data = DataLoader(doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=True)
    if len(data) == 0:
        raise RuntimeError("Cannot start prediction because no data is available")

    logger.info("Start prediction...")
    preds = []
    for batch in data:
        preds += trainer.predict(batch)[0]      # predict() returns (preds, loss); keep predictions only
    preds = unsort(preds, data.data_orig_idx)

    # write to file and score
    doc.add_predictions(preds)
    doc.write_to_file(args.test_data_out)
    display_results(doc, args.no_eval_feats, per_feature=True, report_oov=args.w_token_index >= 0)
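The `unsort` call above undoes the length-based sorting that the DataLoader applies for efficient batching, so the predictions line up with the input sentences again. A minimal sketch of that behavior, assuming `data_orig_idx[i]` records the pre-sort position of item `i` (the project's own helper may be implemented differently):

def unsort(sorted_items, orig_idx):
    # Place each item back at the position it occupied before sorting.
    restored = [None] * len(sorted_items)
    for item, idx in zip(sorted_items, orig_idx):
        restored[idx] = item
    return restored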
Example #2
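# NB: besides the names noted in Example #1, this excerpt assumes time, numpy (as np),
# and the get_adaptive_eval_interval / get_adaptive_log_interval helpers are available
# at module level.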
def train(args, use_cuda=False):
    logger.info("Loading training data...")
    train_doc = Document(from_file=args.training_data, read_positions=get_read_format_args(args), sample_ratio=args.sample_train)
    if args.augment_nopunct:
        train_doc.augment_punct(args.augment_nopunct, args.punct_tag)

    # continue training existing model
    if args.model:
        pretrained = None
        if args.embeddings:
            pretrained = Pretrain(from_pt=args.embeddings)
            if args.embeddings_save:
                pretrained.save_to_pt(args.embeddings_save)
        
        logger.info("Loading model from {}".format(args.model))
        trainer = Trainer(model_file=args.model, pretrain=pretrained, args=vars(args), use_cuda=use_cuda)
        train_data = DataLoader(train_doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=False)

    # create new model from scratch and start training
    else:
        pretrained = None
        if args.embeddings:
            pretrained = Pretrain(from_pt=args.embeddings)
        elif args.emb_data:
            pretrained = Pretrain(from_text=args.emb_data, max_vocab=args.emb_max_vocab)
        if pretrained and args.embeddings_save:
            pretrained.save_to_pt(args.embeddings_save)

        logger.info("Creating new model...")
        train_data = DataLoader(train_doc, args.batch_size, vocab=None, pretrain=pretrained, evaluation=False, word_cutoff=args.w_token_min_freq)
        trainer = Trainer(vocab=train_data.vocab, pretrain=pretrained, args=vars(args), use_cuda=use_cuda)

    if len(train_data) == 0:
        raise RuntimeError("Cannot start training because no training data is available")

    if args.dev_data:
        logger.info("Loading development data...")
        dev_doc = Document(from_file=args.dev_data, read_positions=get_read_format_args(args), write_positions=get_write_format_args(args), copy_untouched=args.copy_untouched, cut_first=args.cut_dev)
        dev_data = DataLoader(dev_doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=True)
    else:
        dev_doc = None
        dev_data = []

    if not args.eval_interval:
        args.eval_interval = get_adaptive_eval_interval(len(train_data), len(dev_data))
    if len(dev_data) > 0:
        logger.info("Evaluating the model every {} steps".format(args.eval_interval))
    else:
        logger.info("No dev data given, not evaluating the model")

    if not args.log_interval:
        args.log_interval = get_adaptive_log_interval(args.batch_size, max_interval=args.eval_interval, gpu=use_cuda)
    logger.info("Showing log every {} steps".format(args.log_interval))

    if args.scores_out:
        scores_file = open(args.scores_out, "w")
        scores_file.write("Step\tEpoch\tTrainLoss\tDevLoss\tDevScore\tNewBest\n")
        scores_file.flush()
    else:
        scores_file = None


    global_step = 0
    epoch = 0
    dev_score_history = []
    last_best_step = 0
    max_steps = args.max_steps
    current_lr = args.lr
    global_start_time = time.time()
    format_str = 'Finished step {}/{}, loss = {:.6f}, {:.3f} sec/batch, lr: {:.6f}'

    # start training
    logger.info("Start training...")
    using_amsgrad = False
    train_loss = 0
    while True:
        epoch += 1
        epoch_start_time = time.time()
        do_break = False
        for batch in train_data:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args.log_interval == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))

            if global_step % args.eval_interval == 0:
                new_best = ""
                dev_loss = 0.0
                dev_score = 0.0

                if len(dev_data) > 0:
                    logger.info("Evaluating on dev set...")
                    dev_preds = []
                    dev_loss = 0.0
                    for dev_batch in dev_data:
                        preds, loss = trainer.predict(dev_batch)
                        dev_preds += preds
                        dev_loss += float(loss)
                    dev_preds = unsort(dev_preds, dev_data.data_orig_idx)
                    dev_loss = dev_loss / len(dev_data)
                    dev_doc.add_predictions(dev_preds)
                    dev_doc.write_to_file(args.dev_data_out)
                    dev_score = display_results(dev_doc, args.no_eval_feats, report_oov=args.w_token_index >= 0)

                    # save best model
                    if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                        logger.info("New best model found")
                        new_best = "*"
                        last_best_step = global_step
                        if args.model_save:
                            trainer.save(args.model_save)
                    dev_score_history += [dev_score]

                train_loss = train_loss / args.eval_interval # avg loss per batch
                logger.info("Step {}/{}: train_loss = {:.6f}, dev_loss = {:.6f}, dev_score = {:.4f}".format(global_step, max_steps, train_loss, dev_loss, dev_score))
                if scores_file:
                    scores_file.write("{}\t{}\t{:.6f}\t{:.6f}\t{:.4f}\t{}\n".format(global_step, epoch, train_loss, dev_loss, dev_score, new_best))
                    scores_file.flush()
                train_loss = 0

            if args.max_steps_before_stop > 0 and global_step - last_best_step >= args.max_steps_before_stop:
                if args.optim == 'adam' and not using_amsgrad:
                    logger.info("Switching to AMSGrad")
                    last_best_step = global_step
                    using_amsgrad = True
                    trainer.set_optimizer('amsgrad', lr=args.lr, betas=(.9, args.beta2), eps=1e-6)
                else:
                    logger.info("Early stopping: dev_score has not improved in {} steps".format(args.max_steps_before_stop))
                    do_break = True
                    break

            if global_step >= args.max_steps:
                do_break = True
                break

        if do_break: break

        epoch_duration = time.time() - epoch_start_time
        logger.info("Finished epoch {} after step {}, {:.3f} sec/epoch".format(epoch, global_step, epoch_duration))
        train_data.reshuffle()

    logger.info("Training ended with {} steps in epoch {}".format(global_step, epoch))

    if len(dev_score_history) > 0:
        best_score, best_step = max(dev_score_history), np.argmax(dev_score_history)+1
        logger.info("Best dev score = {:.2f} at step {}".format(best_score*100, best_step * args.eval_interval))
    elif args.model_save:
        logger.info("Dev set never evaluated, saving final model")
        trainer.save(args.model_save)
    if scores_file:
        scores_file.close()
    return trainer, pretrained
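For reference, a minimal driver sketch (not part of the source): the attribute names on `args` are exactly the ones the two functions above read and would normally be filled in by the project's argument parser; the values here are illustrative placeholders only, and the get_*_format_args helpers may read additional fields not listed.

from argparse import Namespace

args = Namespace(
    # data files
    training_data="train.conllu", dev_data="dev.conllu", dev_data_out="dev.pred.conllu",
    test_data="test.conllu", test_data_out="test.pred.conllu",
    # embeddings
    embeddings="vectors.pt", emb_data=None, emb_max_vocab=250000, embeddings_save=None,
    # model and outputs
    model=None, model_save="tagger.pt", scores_out="scores.tsv",
    # optimization
    batch_size=32, lr=3e-3, beta2=0.999, optim="adam",
    max_steps=50000, max_steps_before_stop=3000, eval_interval=None, log_interval=None,
    # data handling
    sample_train=1.0, augment_nopunct=0.0, punct_tag="PUNCT", cut_dev=None,
    w_token_min_freq=1, w_token_index=-1, no_eval_feats=[], copy_untouched=False,
)

trainer, pretrained = train(args, use_cuda=False)
predict(args, trainer=trainer, pretrained=pretrained, use_cuda=False)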