def decode_dataset(model, src, tgt, config):
    """Decode a whole dataset batch by batch and collect readable tokens.

    The dataset in ``src['data']`` is walked in steps of
    ``config['data']['batch_size']``. Each batch is fetched with
    ``data.minibatch(..., is_test=True)``, decoded with
    ``decode_minibatch`` and converted from ids back to tokens.

    Args:
        model: model handed through to ``decode_minibatch``.
        src: source-side dataset dict; ``'data'`` and ``'id2tok'`` are read.
        tgt: target-side dataset dict; ``'tok2id'`` and ``'id2tok'`` are read.
        config: nested config dict; ``config['data']`` supplies
            ``batch_size``/``max_len`` and ``config['model']`` supplies
            ``model_type``.

    Returns:
        ``(inputs, preds, ground_truths, auxs)`` -- four lists accumulated
        over all batches. ``auxs`` is only extended for the model types
        ``'delete'``, ``'delete_retrieve'`` and ``'seq2seq'``.
    """

    def to_tokens(id_seqs, id2tok, order):
        """Turn a batch of id sequences into token lists via ``data.unsort``."""
        rows = []
        # take the batch off the gpu before iterating over it
        for ids in id_seqs.cpu().numpy():
            words = [id2tok[i] for i in ids]
            # preds were kickstarted with a start token: drop the first one
            try:
                words.remove('<s>')
            except ValueError:
                pass
            # keep only what comes before the end-of-sequence marker, if any
            try:
                words = words[:words.index('</s>')]
            except ValueError:
                pass
            rows.append(words)
        # unsort (restores the order described by the batch's indices)
        return data.unsort(rows, order)

    inputs, preds, ground_truths, auxs = [], [], [], []

    total = len(src['data'])
    batch_size = config['data']['batch_size']
    for start in range(0, total, batch_size):
        # single-line progress indicator
        sys.stdout.write("\r%s/%s..." % (start, total))
        sys.stdout.flush()

        # get batch
        max_len = config['data']['max_len']
        model_type = config['model']['model_type']
        input_content, input_aux, output = data.minibatch(
            src, tgt, start, batch_size, max_len, model_type, is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        # TODO -- beam search
        tgt_pred = decode_minibatch(
            max_len, tgt['tok2id']['<s>'], model,
            input_lines_src, srclens, srcmask,
            input_ids_aux, auxlens, auxmask)

        # convert inputs/preds/targets to human-readable form
        inputs.extend(to_tokens(output_lines_src, src['id2tok'], indices))
        preds.extend(to_tokens(tgt_pred, tgt['id2tok'], indices))
        ground_truths.extend(to_tokens(output_lines_tgt, tgt['id2tok'], indices))

        # the auxiliary column depends on the model flavour
        if model_type == 'delete':
            # one-element lists because of a list comp in inference_metrics()
            auxs.extend([str(x)] for x in input_ids_aux.data.cpu().numpy())
        elif model_type == 'delete_retrieve':
            auxs.extend(to_tokens(input_ids_aux, tgt['id2tok'], indices))
        elif model_type == 'seq2seq':
            auxs.extend(['None'] * len(tgt_pred))

    return inputs, preds, ground_truths, auxs
def evaluate_lpp(model, src, tgt, config):
    """Evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing).

    The dataset in ``src['data']`` is walked in steps of
    ``config['data']['batch_size']``. For each batch the model is run on the
    source inputs together with the target inputs, and a cross-entropy loss
    is computed against the target outputs. The ``<pad>`` class is given
    weight 0 in that loss.

    Args:
        model: callable returning ``(decoder_logit, decoder_probs)``.
        src: source-side dataset dict; ``'data'`` is read for its length.
        tgt: target-side dataset dict; ``'tok2id'`` gives the vocabulary
            size and the ``<pad>`` id.
        config: nested config dict; ``config['data']`` supplies
            ``batch_size``/``max_len`` and ``config['model']`` supplies
            ``model_type``.

    Returns:
        ``np.mean`` of the per-batch losses. Note this is a mean over
        batches, not weighted by the number of tokens in each batch.
    """
    vocab_size = len(tgt['tok2id'])

    # per-class weights: every token counts except padding
    weight_mask = torch.ones(vocab_size)
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    total = len(src['data'])
    batch_size = config['data']['batch_size']
    for j in range(0, total, batch_size):
        # single-line progress indicator
        sys.stdout.write("\r%s/%s..." % (j, total))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = data.minibatch(
            src,
            tgt,
            j,
            batch_size,
            config['data']['max_len'],
            config['model']['model_type'],
            is_test=True)
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        # only the logits are needed for the loss
        decoder_logit, _ = model(input_lines_src, input_lines_tgt,
                                 srcmask, srclens, input_ids_aux,
                                 auxlens, auxmask)

        # flatten to (N, vocab) vs (N,) as expected by CrossEntropyLoss
        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, vocab_size),
            output_lines_tgt.view(-1))

        # `loss.data[0]` indexes a 0-dim tensor, which raises IndexError on
        # PyTorch >= 0.5; prefer `.item()` and keep the old form only as a
        # fallback for releases that predate it.
        if hasattr(loss, 'item'):
            losses.append(loss.item())
        else:
            losses.append(loss.data[0])

    return np.mean(losses)
Beispiel #3
0
            os.system("rm %s" % ckpt_path)
        # replace with new checkpoint
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)

        best_metric = cur_metric
        best_epoch = epoch - 1

    losses = []
    for i in range(0, len(src['content']), batch_size):

        if args.overfit:
            i = 50

        batch_idx = i / batch_size

        input_content, input_aux, output = data.minibatch(
            src, tgt, i, batch_size, max_length, config['model']['model_type'])
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        
        decoder_logit, decoder_probs = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask)

        optimizer.zero_grad()

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, tgt_vocab_size),
            output_lines_tgt.view(-1)
        )