import numpy as np
import torch
import torch.nn.functional as F

import data_utils


def evaluate_r2a_batch(model, task, batch, args, writer=None):
    '''
    Evaluate the network on a batch of examples for the target task.
        model:  a dictionary of networks
        task:   the name of the task
        batch:  a batch of examples for the specified task
        args:   the overall arguments
        writer: a file object. If not None, the prediction results and the
                generated attention will be written to this file.
    '''
    # get the current batch
    text, rat_freq, rationale, gold_att, _, text_len, label, raw, _ = batch

    # convert to tensors (and move to GPU if requested)
    text, text_len, rat_freq, rationale, gold_att = _to_tensor(
        [text, text_len, rat_freq, rationale, gold_att], args.cuda)
    text_mask = _get_mask(text_len, args.cuda)

    # encode the text and map it into the task-invariant space
    hidden, _ = model['encoder'](text, text_len, False)
    invar_hidden = model['transform'](hidden)

    # run R2A to generate attention from the rationales
    pred_att, log_pred_att = model['r2a'](invar_hidden, rationale, rat_freq,
                                          text_len, text_mask)

    # normalize the rationale into a distribution over the unpadded tokens
    normalized_rationale = rationale * text_mask
    normalized_rationale = normalized_rationale / torch.sum(
        normalized_rationale, dim=1, keepdim=True)

    # uniform attention over the unpadded tokens (a trivial baseline)
    uniform = text_mask / torch.sum(text_mask, dim=1, keepdim=True)

    # per-example cosine distance: predicted attention vs. gold attention
    loss_p2g = 1 - F.cosine_similarity(pred_att, gold_att)
    loss_p2g = _to_numpy(loss_p2g)

    # per-example cosine distance: normalized rationale vs. gold attention
    loss_rationale = 1 - F.cosine_similarity(normalized_rationale, gold_att)
    loss_rationale = _to_numpy(loss_rationale)

    # per-example cosine distance: uniform attention vs. gold attention
    loss_uniform = 1 - F.cosine_similarity(uniform, gold_att)
    loss_uniform = _to_numpy(loss_uniform)

    gold_att, rationale, pred_att, rat_freq = _to_numpy(
        [gold_att, rationale, pred_att, rat_freq])

    # write the attention to a tsv file (provides data for the R2A model)
    if writer:
        data_utils.write(writer, task, raw, label, gold_att, rationale,
                         pred_att, rat_freq)

    return {
        'loss_p2g': loss_p2g,
        'loss_uniform': loss_uniform,
        'loss_rationale': loss_rationale,
    }
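# ----------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the original pipeline): aggregating
# the per-example losses returned by evaluate_r2a_batch over a development
# set. The name `dev_batches` is an assumption standing in for whatever batch
# iterator the surrounding training script provides; the averaged cosine
# distances let us compare the predicted attention against the rationale and
# uniform baselines.
# ----------------------------------------------------------------------------
def _example_aggregate_r2a_eval(model, task, dev_batches, args):
    '''Sketch: average the per-example losses over all dev batches.'''
    totals = {'loss_p2g': [], 'loss_uniform': [], 'loss_rationale': []}
    for batch in dev_batches:
        result = evaluate_r2a_batch(model, task, batch, args)
        for key in totals:
            totals[key].extend(result[key].tolist())
    return {key: float(np.mean(vals)) for key, vals in totals.items()}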
def evaluate_batch(model, optimizer, task, batch, src_batches, tar_batches,
                   args, writer=None):
    '''
    Evaluate the network on a batch of examples.
        model:       a dictionary of networks
        optimizer:   the optimizer that updates the network weights
        task:        the name of the task
        batch:       a batch of examples for the specified task
        src_batches: an iterator that generates batches of source examples
                     (used for estimating the Wasserstein distance)
        tar_batches: an iterator that generates batches of target examples
                     (used for estimating the Wasserstein distance)
        args:        the overall arguments
        writer:      a file object. If not None, the prediction results and
                     the generated attention will be written to this file.
    '''
    # ------------------------------------------------------------------------
    # Step 1: train the critic network
    # ------------------------------------------------------------------------
    # set all networks to eval mode (the critic is switched back to train
    # mode below) and clear any stale gradients
    for key in model.keys():
        model[key].eval()
        if key in optimizer:
            optimizer[key].zero_grad()

    # train the critic network for (at least) args.critic_steps iterations
    if args.l_wd != 0 and (args.mode == 'train_r2a' or
                           args.mode == 'test_r2a'):
        model['critic'].train()

        i = 0
        while True:
            # get target and source input, text only, no labels
            tar_text, _, _, _, _, tar_text_len, _, _, _ = next(tar_batches)
            tar_text, tar_text_len = _to_tensor([tar_text, tar_text_len],
                                                args.cuda)

            src_text, _, _, _, _, src_text_len, _, _, _ = next(src_batches)
            src_text, src_text_len = _to_tensor([src_text, src_text_len],
                                                args.cuda)

            # run the encoder
            tar_hidden, _ = model['encoder'](tar_text, tar_text_len, False)
            src_hidden, _ = model['encoder'](src_text, src_text_len, False)

            # apply the transformation layer
            invar_tar_hidden = model['transform'](tar_hidden)
            invar_src_hidden = model['transform'](src_hidden)

            # run the critic network (detach so only the critic is updated)
            optimizer['critic'].zero_grad()
            loss_wd, grad_penalty = model['critic'](
                invar_src_hidden.detach(), src_text_len,
                invar_tar_hidden.detach(), tar_text_len, False)
            loss = -loss_wd + args.l_grad_penalty * grad_penalty

            # backprop
            loss.backward()
            optimizer['critic'].step()

            # By definition, loss_wd should be non-negative. If it is
            # negative, the critic network is not good enough yet, so we
            # keep training it.
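            # (Note: this appears to follow the standard WGAN-GP critic
            # update. The critic estimates the Wasserstein distance as
            #   W(src, tar) ~ max_f E_src[f(h)] - E_tar[f(h)]
            # over approximately 1-Lipschitz functions f; minimizing
            # -loss_wd + args.l_grad_penalty * grad_penalty moves the critic
            # toward that maximum while the gradient penalty enforces the
            # Lipschitz constraint.)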
            i += 1
            if i >= args.critic_steps and _to_number(loss_wd) > 0:
                break

        model['critic'].eval()

    # ------------------------------------------------------------------------
    # Step 2: run all the other networks
    # ------------------------------------------------------------------------
    # get the current batch
    text, rat_freq, rationale, gold_att, pred_att, text_len, label, raw, _ = \
        batch

    # convert to tensors (and move to GPU if requested)
    text, text_len, rat_freq, rationale, gold_att, pred_att, label = \
        _to_tensor([text, text_len, rat_freq, rationale, gold_att, pred_att,
                    label], args.cuda)
    text_mask = _get_mask(text_len, args.cuda)

    # run the encoder on the source batch (with the language-model loss)
    hidden, loss_src_lm = model['encoder'](text, text_len, True)
    invar_hidden = model['transform'](hidden)
    loss_src_lm = np.ones(len(raw)) * _to_number(loss_src_lm)

    # estimate the Wasserstein distance between source and target
    loss_tar_lm = 0
    if args.l_wd != 0 and (args.mode == 'test_r2a' or
                           args.mode == 'train_r2a'):
        # get a target batch
        tar_text, _, _, _, _, tar_text_len, _, _, _ = next(tar_batches)

        # truncate the target batch to match the source batch size
        tar_text_len = tar_text_len[:len(raw)]
        tar_text = tar_text[:len(raw), :max(tar_text_len)]

        # convert to tensors
        tar_text, tar_text_len = _to_tensor([tar_text, tar_text_len],
                                            args.cuda)

        # run the encoder on target
        tar_hidden, loss_tar_lm = model['encoder'](tar_text, tar_text_len,
                                                   True)
        invar_tar_hidden = model['transform'](tar_hidden)

        loss_wd, _ = model['critic'](invar_hidden, text_len,
                                     invar_tar_hidden, tar_text_len, True)
        loss_wd = np.ones(len(raw)) * _to_number(loss_wd)
    else:
        loss_wd = np.zeros(len(raw))

    loss_tar_lm = np.ones(len(raw)) * _to_number(loss_tar_lm)

    # classifier
    out, att, log_att = model[task](hidden, text_mask)

    if args.num_classes[task] == 1:
        # regression task: sigmoid output with mean-squared error
        loss_lbl = _to_numpy(
            F.mse_loss(torch.sigmoid(out.squeeze(1)), label,
                       reduction='none'))
        pred_lbl = _to_numpy(torch.sigmoid(out.squeeze(1)))
    else:
        # classification task: cross entropy
        loss_lbl = _to_numpy(F.cross_entropy(out, label, reduction='none'))
        pred_lbl = np.argmax(_to_numpy(out), axis=1)

    true_lbl = _to_numpy(label)

    if _to_number(torch.min(torch.sum(rationale, dim=1))) < 0.5:
        # some example has no words annotated as rationale; add a small eps
        # to avoid numerical errors when normalizing
        rationale = rationale + 1e-6

    # normalize the rationale into a distribution over the unpadded tokens
    normalized_rationale = rationale * text_mask
    normalized_rationale = normalized_rationale / torch.sum(
        normalized_rationale, dim=1, keepdim=True)

    if args.mode == 'train_clf' or args.mode == 'test_clf':
        # in this case, pred_att is loaded from the generated file and it
        # provides supervision for the attention of the classifier
        if args.att_target == 'gold_att':
            target = gold_att
        elif args.att_target == 'rationale':
            target = normalized_rationale
        elif args.att_target == 'pred_att':
            target = pred_att
        else:
            raise ValueError('Invalid supervision type.')

        log_pred_att = torch.log(pred_att)

    elif args.mode == 'train_r2a':
        # in this case, att (which is derived from the source multi-task
        # learning module) is the supervision target for pred_att
        pred_att, log_pred_att = model['r2a'](invar_hidden, rationale,
                                              rat_freq, text_len, text_mask)
    else:
        raise ValueError('Invalid mode')

    # cosine distance between the classifier's attention and the rationale
    loss_a2r = 1 - F.cosine_similarity(att, normalized_rationale)
    loss_a2r = _to_numpy(loss_a2r)

    # cosine distance between the classifier's attention and pred_att
    loss_r2a = 1 - F.cosine_similarity(att, pred_att)
    loss_r2a = _to_numpy(loss_r2a)

    # write the attention to a tsv file; note that the classifier's own
    # attention (att) is written out in the gold-attention column
    if writer:
        gold_att, rationale, pred_att, rat_freq = _to_numpy(
            [att, rationale, pred_att, rat_freq])
        data_utils.write(writer, task, raw, true_lbl, gold_att, rationale,
                         pred_att, rat_freq)
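    # all the losses below are per-example numpy arrays of length len(raw)
    # (scalar losses are broadcast), so the caller can concatenate them
    # across batches before averaging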
    return {
        'true_lbl': true_lbl,
        'pred_lbl': pred_lbl,
        'loss_r2a': loss_r2a,
        'loss_lbl': loss_lbl,
        'loss_wd': loss_wd,
        'loss_a2r': loss_a2r,
        'loss_src_lm': loss_src_lm,
        'loss_tar_lm': loss_tar_lm,
    }
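# ----------------------------------------------------------------------------
# The evaluation routines above rely on four module-level helpers that are
# defined elsewhere in this codebase: _to_tensor, _get_mask, _to_numpy and
# _to_number. The implementations below are a minimal sketch of their assumed
# behavior, included for reference only; the actual helpers may differ.
# ----------------------------------------------------------------------------
def _to_tensor(data, cuda):
    '''Sketch: convert a numpy array (or a list of arrays) to torch tensors,
    moving them onto the GPU when cuda is True.'''
    if isinstance(data, (list, tuple)):
        return [_to_tensor(x, cuda) for x in data]
    tensor = torch.as_tensor(data)
    return tensor.cuda() if cuda else tensor


def _get_mask(text_len, cuda):
    '''Sketch: binary mask of shape (batch, max_len); 1 for real tokens,
    0 for padding.'''
    max_len = int(torch.max(text_len).item())
    positions = torch.arange(max_len, device=text_len.device)
    mask = (positions.unsqueeze(0) < text_len.unsqueeze(1)).float()
    return mask.cuda() if cuda else mask


def _to_numpy(data):
    '''Sketch: detach tensors and move them back to numpy arrays.'''
    if isinstance(data, (list, tuple)):
        return [_to_numpy(x) for x in data]
    return data.detach().cpu().numpy()


def _to_number(x):
    '''Sketch: extract a Python scalar from a 0-dim tensor or plain number.'''
    return x.item() if torch.is_tensor(x) else float(x)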