Example #1
def evaluate_r2a_batch(model, task, batch, args, writer=None):
    '''
        Evaluate the network on a batch of examples for the target task

        model: a dictionary of networks
        task: the name of the task
        batch: a batch of examples for the specified task
        args: the overall argument namespace
        writer: a file object. If not None, the prediction results and the generated
            attention are written to this file
    '''
    # get the current batch
    text, rat_freq, rationale, gold_att, _, text_len, label, raw, _ = batch

    # convert to variable and tensor
    text, text_len, rat_freq, rationale, gold_att = _to_tensor(
        [text, text_len, rat_freq, rationale, gold_att], args.cuda)
    text_mask = _get_mask(text_len, args.cuda)

    # Encoder
    hidden, _ = model['encoder'](text, text_len, False)
    invar_hidden = model['transform'](hidden)

    # run r2a to generate attention
    pred_att, log_pred_att = model['r2a'](invar_hidden, rationale, rat_freq,
                                          text_len, text_mask)

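    # two reference attention distributions over valid tokens: the human
    # rationale (masked and renormalized) and a uniform distribution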
    normalized_rationale = rationale * text_mask
    normalized_rationale = normalized_rationale / torch.sum(
        normalized_rationale, dim=1, keepdim=True)
    uniform = text_mask / torch.sum(text_mask, dim=1, keepdim=True)

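    # each loss below is a cosine distance (1 - cosine similarity) between a
    # candidate attention and the gold attention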
    loss_p2g = 1 - F.cosine_similarity(pred_att, gold_att)
    loss_p2g = _to_numpy(loss_p2g)

    loss_rationale = 1 - F.cosine_similarity(normalized_rationale, gold_att)
    loss_rationale = _to_numpy(loss_rationale)

    loss_uniform = 1 - F.cosine_similarity(uniform, gold_att)
    loss_uniform = _to_numpy(loss_uniform)

    gold_att, rationale, pred_att, rat_freq = _to_numpy(
        [gold_att, rationale, pred_att, rat_freq])

    # Write the R2A-generated attention (together with the rationale and gold attention) to a tsv file
    if writer:
        data_utils.write(writer, task, raw, label, gold_att, rationale,
                         pred_att, rat_freq)

    return {
        'loss_p2g': loss_p2g,
        'loss_uniform': loss_uniform,
        'loss_rationale': loss_rationale,
    }
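
A minimal sketch of how evaluate_r2a_batch might be driven over a full development set. The iterator name dev_batches and the aggregation below are assumptions for illustration, not part of the repository code above.

import collections

import numpy as np


def evaluate_r2a(model, task, dev_batches, args, writer=None):
    # put every network into eval mode before scoring
    for key in model.keys():
        model[key].eval()

    # accumulate the per-example losses returned by evaluate_r2a_batch
    totals = collections.defaultdict(list)
    for batch in dev_batches:
        result = evaluate_r2a_batch(model, task, batch, args, writer)
        for name, values in result.items():
            totals[name].extend(values)

    # report the mean cosine-distance loss of each attention source
    return {name: float(np.mean(values)) for name, values in totals.items()}
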
Example #2
def evaluate_batch(model,
                   optimizer,
                   task,
                   batch,
                   src_batches,
                   tar_batches,
                   args,
                   writer=None):
    '''
        Evaluate the network on a batch of examples

        model: a dictionary of networks
        optimizer: a dictionary of optimizers, one per trainable network
        task: the name of the task
        batch: a batch of examples for the specified task
        src_batches: an iterator that yields batches of source examples (used for estimating the
            Wasserstein distance)
        tar_batches: an iterator that yields batches of target examples (used for estimating the
            Wasserstein distance)
        args: the overall argument namespace
        writer: a file object. If not None, the prediction results and the generated
            attention are written to this file
    '''
    # ------------------------------------------------------------------------
    # Step 1:  Training the critic network
    # ------------------------------------------------------------------------
    # set all networks to eval mode and clear their gradients; the critic is
    # switched back to train mode below when adversarial training is enabled
    for key in model.keys():
        model[key].eval()
        if key in optimizer:
            optimizer[key].zero_grad()

    # train the critic for at least args.critic_steps iterations, continuing
    # until the estimated Wasserstein distance loss_wd is positive
    if args.l_wd != 0 and (args.mode == 'train_r2a'
                           or args.mode == 'test_r2a'):
        model['critic'].train()
        i = 0
        while True:
            # get target and source input, text only, no labels
            tar_text, _, _, _, _, tar_text_len, _, _, _ = next(tar_batches)
            tar_text, tar_text_len = _to_tensor([tar_text, tar_text_len],
                                                args.cuda)
            src_text, _, _, _, _, src_text_len, _, _, _ = next(src_batches)
            src_text, src_text_len = _to_tensor([src_text, src_text_len],
                                                args.cuda)

            # run the encoder
            tar_hidden, _ = model['encoder'](tar_text, tar_text_len, False)
            src_hidden, _ = model['encoder'](src_text, src_text_len, False)

            # apply the transformation layer
            invar_tar_hidden = model['transform'](tar_hidden)
            invar_src_hidden = model['transform'](src_hidden)

            # run the critic network
            optimizer['critic'].zero_grad()
            loss_wd, grad_penalty = model['critic'](invar_src_hidden.detach(),
                                                    src_text_len,
                                                    invar_tar_hidden.detach(),
                                                    tar_text_len, False)
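            # WGAN-style critic update: the critic maximizes the Wasserstein
            # estimate loss_wd (hence the negation) while the gradient penalty
            # keeps it approximately 1-Lipschitz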
            loss = -loss_wd + args.l_grad_penalty * grad_penalty

            # backprop
            loss.backward()
            optimizer['critic'].step()

            # by definition, loss_wd should be non-negative. If it is negative, it means the critic
            # network is not good enough. Thus, we need to train it more.
            i += 1
            if i >= args.critic_steps and _to_number(loss_wd) > 0:
                break

    model['critic'].eval()

    # ------------------------------------------------------------------------
    # Step 2: Run all other networks
    # ------------------------------------------------------------------------

    # get the current batch
    text, rat_freq, rationale, gold_att, pred_att, text_len, label, raw, _ = batch

    # convert to variable and tensor
    text, text_len, rat_freq, rationale, gold_att, pred_att, label = _to_tensor(
        [text, text_len, rat_freq, rationale, gold_att, pred_att, label],
        args.cuda)
    text_mask = _get_mask(text_len, args.cuda)

    # Run the encoder on source
    hidden, loss_src_lm = model['encoder'](text, text_len, True)
    invar_hidden = model['transform'](hidden)
    loss_src_lm = np.ones(len(raw)) * _to_number(loss_src_lm)

    # Estimating l_wd
    loss_tar_lm = 0
    if args.l_wd != 0 and (args.mode == 'test_r2a'
                           or args.mode == 'train_r2a'):
        # Run the encoder on target
        tar_text, _, _, _, _, tar_text_len, _, _, _ = next(tar_batches)

        # truncate to match the src batch size
        tar_text_len = tar_text_len[:len(raw)]
        tar_text = tar_text[:len(raw), :max(tar_text_len)]

        # convert to tensor and variable
        tar_text, tar_text_len = _to_tensor([tar_text, tar_text_len],
                                            args.cuda)

        # run the encoder on target
        tar_hidden, loss_tar_lm = model['encoder'](tar_text, tar_text_len,
                                                   True)
        invar_tar_hidden = model['transform'](tar_hidden)

        loss_wd, _ = model['critic'](invar_hidden, text_len, invar_tar_hidden,
                                     tar_text_len, True)
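        # broadcast the scalar Wasserstein estimate to a per-example array so
        # it can be aggregated like the other losses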
        loss_wd = np.ones(len(raw)) * _to_number(loss_wd)

    else:
        loss_wd = np.zeros(len(raw))

    loss_tar_lm = np.ones(len(raw)) * _to_number(loss_tar_lm)

    # Classifier
    out, att, log_att = model[task](hidden, text_mask)

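    # single-output tasks are scored as regression (sigmoid + MSE);
    # multi-class tasks use standard cross-entropy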
    if args.num_classes[task] == 1:
        loss_lbl = _to_numpy(
            F.mse_loss(torch.sigmoid(out.squeeze(1)), label, reduction='none'))
        pred_lbl = _to_numpy(torch.sigmoid(out.squeeze(1)))
    else:
        loss_lbl = _to_numpy(F.cross_entropy(out, label, reduction='none'))
        pred_lbl = np.argmax(_to_numpy(out), axis=1)

    true_lbl = _to_numpy(label)

    if _to_number(torch.min(torch.sum(rationale, dim=1))) < 0.5:
        # at least one example has no annotated rationale words; add a small
        # eps to avoid division by zero during normalization
        rationale = rationale + 1e-6

    # mask out padding and renormalize the rationale scores so they sum to one over each document
    normalized_rationale = rationale * text_mask
    normalized_rationale = normalized_rationale / torch.sum(
        normalized_rationale, dim=1, keepdim=True)

    if args.mode == 'train_clf' or args.mode == 'test_clf':
        # in this case, pred_att is loaded from the generated file and provides supervision
        # for the attention of the classifier
        if args.att_target == 'gold_att':
            target = gold_att
        elif args.att_target == 'rationale':
            target = normalized_rationale
        elif args.att_target == 'pred_att':
            target = pred_att
        else:
            raise ValueError('Invalid supervision type.')

        log_pred_att = torch.log(pred_att)

    elif args.mode == 'train_r2a':
        # in this case, att (which is derived from the source multitask learning module)
        # is the supervision target for pred_att
        pred_att, log_pred_att = model['r2a'](invar_hidden, rationale,
                                              rat_freq, text_len, text_mask)

    else:
        raise ValueError('Invalid mode')

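    # cosine-distance agreement of the classifier attention with the
    # normalized rationale (loss_a2r) and with pred_att (loss_r2a)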
    loss_a2r = 1 - F.cosine_similarity(att, normalized_rationale)
    loss_a2r = _to_numpy(loss_a2r)

    loss_r2a = 1 - F.cosine_similarity(att, pred_att)
    loss_r2a = _to_numpy(loss_r2a)

    # Write attention to a tsv file (the classifier's attention `att` occupies the gold-attention column)
    if writer:
        gold_att, rationale, pred_att, rat_freq = _to_numpy(
            [att, rationale, pred_att, rat_freq])
        data_utils.write(writer, task, raw, true_lbl, gold_att, rationale,
                         pred_att, rat_freq)

    return {
        'true_lbl': true_lbl,
        'pred_lbl': pred_lbl,
        'loss_r2a': loss_r2a,
        'loss_lbl': loss_lbl,
        'loss_wd': loss_wd,
        'loss_a2r': loss_a2r,
        'loss_src_lm': loss_src_lm,
        'loss_tar_lm': loss_tar_lm,
    }