Example #1
def predict(args, trainer=None, pretrained=None, use_cuda=False):
    # load pretrained embeddings and model
    if not trainer:
        # load pretrained embeddings
        pretrained = Pretrain(from_pt=args.embeddings)
        # load model
        logger.info("Loading model from {}".format(args.model))
        trainer = Trainer(model_file=args.model, pretrain=pretrained, use_cuda=use_cuda)

    # load data
    logger.info("Loading prediction data...")
    doc = Document(from_file=args.test_data, read_positions=get_read_format_args(args), write_positions=get_write_format_args(args), copy_untouched=args.copy_untouched)
    data = DataLoader(doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=True)
    if len(data) == 0:
        raise RuntimeError("Cannot start prediction because no data is available")

    logger.info("Start prediction...")
    preds = []
    for batch in data:
        preds += trainer.predict(batch)[0]      # don't keep loss
    preds = unsort(preds, data.data_orig_idx)

    # write to file and score
    doc.add_predictions(preds)
    doc.write_to_file(args.test_data_out)
    display_results(doc, args.no_eval_feats, per_feature=True, report_oov=args.w_token_index >= 0)
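
A minimal sketch of driving predict() directly, assuming the module-level imports used above (Pretrain, Trainer, Document, DataLoader) are in scope. The attribute names mirror the fields predict() reads; the paths are hypothetical, and the format helpers (get_read_format_args, get_write_format_args) may read further fields not shown here.

from argparse import Namespace

args = Namespace(
    embeddings="vectors.pt",           # hypothetical pretrained embedding file
    model="tagger.pt",                 # hypothetical saved model file
    test_data="test.conllu",           # input to tag
    test_data_out="test.pred.conllu",  # where predictions are written
    batch_size=32,
    copy_untouched=False,
    no_eval_feats=[],
    w_token_index=-1,                  # < 0 disables the OOV report
)
predict(args, use_cuda=False)
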
Example #2
def ids_to_toks(tok_seqs, id2tok, indices):
    out = []
    # take off the gpu
    tok_seqs = tok_seqs.cpu().numpy()
    # convert to toks, cut off at </s>, delete any start tokens (preds were kickstarted with them)
    for line in tok_seqs:
        toks = [id2tok[x] for x in line]
        # drop every start token, not just the first occurrence
        toks = [t for t in toks if t != '<s>']
        cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
        out.append(toks[:cut_idx])
    # unsort back into the original (pre-sort) batch order
    out = data.unsort(out, indices)
    return out
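
A toy illustration of the mapping above; with identity indices, data.unsort is a no-op, so the expected result can be written inline (the id values and vocabulary are made up for the example):

import torch

id2tok = {0: '<s>', 1: 'hello', 2: 'world', 3: '</s>', 4: '<pad>'}
batch = torch.tensor([[0, 1, 2, 3, 4],
                      [0, 2, 3, 4, 4]])
# ids_to_toks(batch, id2tok, indices=[0, 1])
# -> [['hello', 'world'], ['world']]   (start tokens dropped, cut at '</s>')
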
Example #3
    def predict(self, batch, unsort=True):
        inputs, orig_idx, word_orig_idx, sentlens, wordlens = unpack_batch(batch, self.use_cuda)
        word, word_mask, wordchars, wordchars_mask, pos, feats, pretrained = inputs

        self.model.eval()
        batch_size = word.size(0)
        loss, preds = self.model(word, word_mask, wordchars, wordchars_mask, pos, feats, pretrained, word_orig_idx, sentlens, wordlens)
        pos_seqs = [self.vocab['pos'].unmap(sent) for sent in preds[0].tolist()]
        feats_seqs = [self.vocab['feats'].unmap(sent) for sent in preds[1].tolist()]
        w_unk_seqs = [[(not self.model.use_word) or tokid == UNK_ID for tokid in sent] for sent in word]
        p_unk_seqs = [[(not self.model.use_pretrained) or tokid == UNK_ID for tokid in sent] for sent in pretrained]

        pred_tokens = [[[pos_seqs[i][j], feats_seqs[i][j], w_unk_seqs[i][j] and p_unk_seqs[i][j]] for j in range(sentlens[i])] for i in range(batch_size)]
        if unsort:
            pred_tokens = data.unsort(pred_tokens, orig_idx)
        return pred_tokens, loss.data.item()
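
The shape of the return value, read off the construction above (a sketch, reusing trainer and batch from the surrounding examples):

preds, loss = trainer.predict(batch)   # unsort=True restores batch order
# preds[i][j] == [pos_tag, feats_string, is_oov]; is_oov is True only when
# the token is unknown (or the channel unused) in BOTH the trained word
# vocabulary and the pretrained embeddings.
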
Example #4
def ids_to_toks(tok_seqs, id2tok, sort_indices, cuts=None, save_cuts=False):
    out = []
    cut_indices = []
    # take off the gpu
    if isinstance(tok_seqs, torch.Tensor):
        tok_seqs = tok_seqs.cpu().numpy()
    # convert to toks, cut off at </s>
    for i, line in enumerate(tok_seqs):
        toks = [id2tok[x] for x in line]
        if cuts is not None:
            cut_idx = cuts[i]
        elif '</s>' in toks:
            cut_idx = toks.index('</s>')
        else:
            cut_idx = len(toks)
        out.append(toks[:cut_idx])
        cut_indices += [cut_idx]
    # unsort
    out = data.unsort(out, sort_indices)

    if save_cuts:
        return out, cut_indices
    else:
        return out
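
This is how example #6 below keeps all k prediction ranks aligned: the top-1 pass fixes the cut points and the remaining ranks reuse them. A sketch, where pred_ids is a hypothetical [batch, length, k] tensor of decoder outputs:

# First pass: cut each sentence at '</s>' and remember the cut points.
top1, cut_lens = ids_to_toks(pred_ids[:, :, 0], id2tok, indices, save_cuts=True)
# Remaining ranks: trim to the top-1 cut points instead of searching again.
top2 = ids_to_toks(pred_ids[:, :, 1], id2tok, indices, cuts=cut_lens)
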
Example #5
def train(args, use_cuda=False):
    logger.info("Loading training data...")
    train_doc = Document(from_file=args.training_data, read_positions=get_read_format_args(args), sample_ratio=args.sample_train)
    if args.augment_nopunct:
        train_doc.augment_punct(args.augment_nopunct, args.punct_tag)

    # continue training existing model
    if args.model:
        pretrained = None
        if args.embeddings:
            pretrained = Pretrain(from_pt=args.embeddings)
            if args.embeddings_save:
                pretrained.save_to_pt(args.embeddings_save)
        
        logger.info("Loading model from {}".format(args.model))
        trainer = Trainer(model_file=args.model, pretrain=pretrained, args=vars(args), use_cuda=use_cuda)
        train_data = DataLoader(train_doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=False)

    # create new model from scratch and start training
    else:
        pretrained = None
        if args.embeddings:
            pretrained = Pretrain(from_pt=args.embeddings)
        elif args.emb_data:
            pretrained = Pretrain(from_text=args.emb_data, max_vocab=args.emb_max_vocab)
        if pretrained and args.embeddings_save:
            pretrained.save_to_pt(args.embeddings_save)

        logger.info("Creating new model...")
        train_data = DataLoader(train_doc, args.batch_size, vocab=None, pretrain=pretrained, evaluation=False, word_cutoff=args.w_token_min_freq)
        trainer = Trainer(vocab=train_data.vocab, pretrain=pretrained, args=vars(args), use_cuda=use_cuda)

    if len(train_data) == 0:
        raise RuntimeError("Cannot start training because no training data is available")

    if args.dev_data:
        logger.info("Loading development data...")
        dev_doc = Document(from_file=args.dev_data, read_positions=get_read_format_args(args), write_positions=get_write_format_args(args), copy_untouched=args.copy_untouched, cut_first=args.cut_dev)
        dev_data = DataLoader(dev_doc, args.batch_size, vocab=trainer.vocab, pretrain=pretrained, evaluation=True)
    else:
        dev_doc = None
        dev_data = []

    if not args.eval_interval:
        args.eval_interval = get_adaptive_eval_interval(len(train_data), len(dev_data))
    if len(dev_data) > 0:
        logger.info("Evaluating the model every {} steps".format(args.eval_interval))
    else:
        logger.info("No dev data given, not evaluating the model")

    if not args.log_interval:
        args.log_interval = get_adaptive_log_interval(args.batch_size, max_interval=args.eval_interval, gpu=use_cuda)
    logger.info("Showing log every {} steps".format(args.log_interval))

    if args.scores_out:
        scores_file = open(args.scores_out, "w")
        scores_file.write("Step\tEpoch\tTrainLoss\tDevLoss\tDevScore\tNewBest\n")
        scores_file.flush()
    else:
        scores_file = None

    global_step = 0
    epoch = 0
    dev_score_history = []
    last_best_step = 0
    max_steps = args.max_steps
    current_lr = args.lr
    global_start_time = time.time()
    format_str = 'Finished step {}/{}, loss = {:.6f}, {:.3f} sec/batch, lr: {:.6f}'

    # start training
    logger.info("Start training...")
    using_amsgrad = False
    train_loss = 0
    while True:
        epoch += 1
        epoch_start_time = time.time()
        do_break = False
        for batch in train_data:
            start_time = time.time()
            global_step += 1
            loss = trainer.update(batch, eval=False) # update step
            train_loss += loss
            if global_step % args.log_interval == 0:
                duration = time.time() - start_time
                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))

            if global_step % args.eval_interval == 0:
                new_best = ""
                dev_loss = 0.0
                dev_score = 0.0

                if len(dev_data) > 0:
                    logger.info("Evaluating on dev set...")
                    dev_preds = []
                    dev_loss = 0.0
                    for dev_batch in dev_data:
                        preds, loss = trainer.predict(dev_batch)
                        dev_preds += preds
                        dev_loss += float(loss)
                    dev_preds = unsort(dev_preds, dev_data.data_orig_idx)
                    dev_loss = dev_loss / len(dev_data)
                    dev_doc.add_predictions(dev_preds)
                    dev_doc.write_to_file(args.dev_data_out)
                    dev_score = display_results(dev_doc, args.no_eval_feats, report_oov=args.w_token_index >= 0)

                    # save best model
                    if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
                        logger.info("New best model found")
                        new_best = "*"
                        last_best_step = global_step
                        if args.model_save:
                            trainer.save(args.model_save)
                    dev_score_history += [dev_score]

                train_loss = train_loss / args.eval_interval # avg loss per batch
                logger.info("Step {}/{}: train_loss = {:.6f}, dev_loss = {:.6f}, dev_score = {:.4f}".format(global_step, max_steps, train_loss, dev_loss, dev_score))
                if scores_file:
                    scores_file.write("{}\t{}\t{:.6f}\t{:.6f}\t{:.4f}\t{}\n".format(global_step, epoch, train_loss, dev_loss, dev_score, new_best))
                    scores_file.flush()
                train_loss = 0

            if args.max_steps_before_stop > 0 and global_step - last_best_step >= args.max_steps_before_stop:
                if args.optim == 'adam' and not using_amsgrad:
                    logger.info("Switching to AMSGrad")
                    last_best_step = global_step
                    using_amsgrad = True
                    trainer.set_optimizer('amsgrad', lr=args.lr, betas=(.9, args.beta2), eps=1e-6)
                else:
                    logger.info("Early stopping: dev_score has not improved in {} steps".format(args.max_steps_before_stop))
                    do_break = True
                    break

            if global_step >= args.max_steps:
                do_break = True
                break

        if do_break: break

        epoch_duration = time.time() - epoch_start_time
        logger.info("Finished epoch {} after step {}, {:.3f} sec/epoch".format(epoch, global_step, epoch_duration))
        train_data.reshuffle()

    logger.info("Training ended with {} steps in epoch {}".format(global_step, epoch))

    if len(dev_score_history) > 0:
        best_score, best_step = max(dev_score_history), np.argmax(dev_score_history)+1
        logger.info("Best dev score = {:.2f} at step {}".format(best_score*100, best_step * args.eval_interval))
    elif args.model_save:
        logger.info("Dev set never evaluated, saving final model")
        trainer.save(args.model_save)
    return trainer, pretrained
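
When args.scores_out is set, train() streams one TSV row per evaluation step; a quick, hypothetical way to pull the best dev score back out of that file afterwards:

import csv

with open("scores.tsv") as f:          # hypothetical args.scores_out path
    rows = list(csv.DictReader(f, delimiter="\t"))
best = max(rows, key=lambda r: float(r["DevScore"]))
print("best dev score {} at step {}".format(best["DevScore"], best["Step"]))
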
Example #6
def decode_dataset(model, src, tgt, config, k=20):
    """Evaluate model."""
    inputs = []
    preds = []
    top_k_preds = []
    auxs = []
    ground_truths = []
    raw_srcs = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output, side_info, raw_src = data.minibatch(
            src,
            tgt,
            j,
            config['data']['batch_size'],
            config['data']['max_len'],
            config,
            is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        _, raw_src, _, _, _ = raw_src
        side_info, _, _, _, _ = side_info

        # TODO -- beam search
        tgt_pred_top_k = decode_minibatch(config['data']['max_len'],
                                          tgt['tok2id']['<s>'],
                                          model,
                                          input_lines_src,
                                          srclens,
                                          srcmask,
                                          input_ids_aux,
                                          auxlens,
                                          auxmask,
                                          side_info,
                                          k=k)

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'], indices)
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'], indices)
        raw_srcs += ids_to_toks(raw_src, src['id2tok'], indices)

        # TODO -- refactor this block
        # get the "official" (top-1) predictions from the model
        pred_toks, pred_lens = ids_to_toks(tgt_pred_top_k[:, :, 0],
                                           tgt['id2tok'],
                                           indices,
                                           save_cuts=True)
        preds += pred_toks
        # now get all the other top-k prediction levels
        top_k_pred = [pred_toks]
        for i in range(k - 1):
            top_k_pred.append(
                ids_to_toks(tgt_pred_top_k[:, :, i + 1],
                            tgt['id2tok'],
                            indices,
                            cuts=pred_lens))
        # top_k_pred is [k, batch, length] where length is ragged
        # but we want it in [batch, length, k]. Manual transpose b/c ragged :(
        batch_size = len(top_k_pred[0])  # could be variable at test time
        pred_lens = data.unsort(pred_lens, indices)
        top_k_pred_transposed = [[] for _ in range(batch_size)]
        for bi in range(batch_size):
            for ti in range(pred_lens[bi]):
                top_k_pred_transposed[bi] += [[
                    top_k_pred[ki][bi][ti] for ki in range(k)
                ]]
        top_k_preds += top_k_pred_transposed

        if config['model']['model_type'] == 'delete':
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()]  # because of list comp in inference_metrics()
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'], indices)
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(batch_size)]

    return inputs, preds, top_k_preds, ground_truths, auxs, raw_srcs
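
The manual transpose near the end is the subtle part: top_k_pred is ragged, so it cannot simply be reshaped. A self-contained toy of the same loop, with made-up tokens:

k = 2
top_k_pred = [
    [['a', 'b', 'c'], ['x']],    # rank 0 (top-1) tokens per sentence
    [['A', 'B', 'C'], ['X']],    # rank 1 tokens per sentence
]
pred_lens = [3, 1]               # per-sentence lengths, already unsorted
transposed = [[] for _ in range(len(pred_lens))]
for bi in range(len(pred_lens)):
    for ti in range(pred_lens[bi]):
        transposed[bi].append([top_k_pred[ki][bi][ti] for ki in range(k)])
assert transposed == [[['a', 'A'], ['b', 'B'], ['c', 'C']], [['x', 'X']]]
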