def get_translator(self, model):
    """Get lazy singleton translator instance."""
    if self.translator is None:
        args_clone = copy.copy(self.args)
        if self.args.loss_beam:
            # Override beam size if necessary
            args_clone.beam = self.args.loss_beam
        self.translator = generate.build_sequence_generator(
            args_clone, self.task, [model]
        )
    return self.translator
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    src_tokens = sample["net_input"]["src_tokens"]
    beam_size = self.args.rl_num_trajectory
    bsz, srclen = src_tokens.size()
    encoder_input = {
        "src_tokens": sample["net_input"]["src_tokens"],
        "src_lengths": sample["net_input"]["src_lengths"],
    }

    # 1) Generate hypotheses with the current model
    translator = generate.build_sequence_generator(self.args, self.task, [model])
    with torch.no_grad():
        seq_hypos = translator.generate(
            encoder_input,
            beam_size,
            maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
        )
    # Word-level hypos: the reference target for each sentence
    word_hypos = [[{"tokens": sample["target"][k]}] for k in range(bsz)]
    # Mix sequence-level and word-level hypos, then flatten the nested list
    hypos = [seq_hypos[j] + word_hypos[j] for j in range(bsz)]
    hypos = [hypo for hypo_list in hypos for hypo in hypo_list]
    hypos_len = (
        torch.tensor([len(hypo["tokens"]) for hypo in hypos])
        .type_as(src_tokens)
        .float()
    )
    # Indices of the word-level hypos (i.e. the target sentences) in the
    # flattened list: every (beam_size + 1)-th entry, starting at beam_size
    mask_index = torch.arange(beam_size, (beam_size + 1) * bsz, beam_size + 1).view(-1)

    # 2) Compute (log)-probs via the forward model
    self.self_rescorer.model = model
    self.self_rescorer.task = self.task
    model.train()
    assert self.self_rescorer.model.training, "model should be in training phase"
    hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
        src_tokens, hypos
    )
    hypo_logprobs, hypo_encoder_outs, forward_logprobs = self.self_rescorer.score_tokens(
        hypo_encoder_inputs, hypo_tokens
    )
    hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

    # 3) Sequence-level loss
    seq_loss = torch.zeros(1).type_as(hypo_logprobs)
    if self.args.rl_weight > 0.0:
        # 3.1) Compute sequence-level rewards from the rescoring models
        with torch.no_grad():
            rescorer = Rescorer(self.args, self.task, self.rescore_models)
            scores = rescorer.score(src_tokens, hypos)
            rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
        assert not rewards.requires_grad, "no grads flow back to generation"
        # 3.2) Compute policy gradient loss; mask out the word-level hypos
        rewards = rewards.type_as(hypo_logprobs)
        seq_mask = hypo_logprobs.new_ones(hypo_logprobs.size())
        seq_mask[mask_index] = 0.0
        seq_loss = -1.0 * (seq_mask * hypo_logprobs * rewards).sum()

    # 4) Word-level loss
    word_loss = torch.zeros(1).type_as(hypo_logprobs)
    if self.args.word_weight > 0.0:
        # 4.1) Compute word-level rewards from a left-to-right rescoring (teacher) model
        with torch.no_grad():
            teacher_model = self.rescore_models[self.args.word_model]
            teacher = SimpleModelScorer(self.args, None, teacher_model, self.task)
            _, _, teacher_logprobs = teacher.score_tokens(
                hypo_encoder_inputs, hypo_tokens
            )
        # 4.2) Compute word-level loss on the student's top-k tokens only
        f_logprob, f_index = forward_logprobs.topk(self.args.topk_words)
        word_mask = f_logprob.new_zeros(f_logprob.size())
        word_mask[mask_index, :, :] = 1.0
        # KL(p_s || p_t) = sum p_s log p_s - sum p_s log p_t, a.k.a. RL + max-entropy
        word_loss = (
            word_mask
            * f_logprob.exp()
            * (f_logprob - 1.0 * teacher_logprobs.gather(-1, f_index))
        ).sum()

    # 5) Compute cross-entropy loss on the reference target
    eos = self.task.target_dictionary.eos()
    target_tokens = torch.cat(
        (
            torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
            sample["target"],
        ),
        dim=1,
    )
    target_encoder_inputs = (
        encoder_input["src_tokens"],
        [encoder_input["src_lengths"][0].item()],
    )
    target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
        target_encoder_inputs, target_tokens
    )
    nll_loss = -1.0 * target_logprobs.sum()

    # 6) Gather losses
    loss = (
        self.args.rl_weight * seq_loss
        + self.args.word_weight * word_loss
        + nll_loss
    )

    # Logging
    sample_size = (
        sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
    )
    logging_output = {
        "loss": utils.item(loss.data) if reduce else loss.data,
        "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["target"].size(0),
        "sample_size": sample_size,
    }
    return loss, sample_size, logging_output
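
# Minimal standalone sketch (not part of the criterion above) of the word-level
# KL term used in step 4.2: KL(p_s || p_t) restricted to the student's top-k
# tokens. Shapes mirror (num_hypos, seq_len, vocab); all sizes and values below
# are made up for illustration and the function is never called by the criterion.
def _word_level_kl_sketch():
    import torch

    num_hypos, seq_len, vocab, topk = 2, 3, 10, 4
    # Student (forward model) and teacher log-probs over the full vocabulary
    student = torch.log_softmax(torch.randn(num_hypos, seq_len, vocab), dim=-1)
    teacher = torch.log_softmax(torch.randn(num_hypos, seq_len, vocab), dim=-1)
    # Restrict to the student's top-k tokens, as the criterion does
    f_logprob, f_index = student.topk(topk, dim=-1)
    teacher_at_topk = teacher.gather(-1, f_index)
    # Only selected hypos (e.g. the reference targets) receive this loss
    word_mask = torch.zeros_like(f_logprob)
    word_mask[0] = 1.0
    # KL(p_s || p_t) on the top-k support: sum p_s * (log p_s - log p_t)
    return (word_mask * f_logprob.exp() * (f_logprob - teacher_at_topk)).sum()
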
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    src_tokens = sample["net_input"]["src_tokens"]
    beam_size = self.args.rl_num_trajectory
    bsz, srclen = src_tokens.size()
    encoder_input = {
        "src_tokens": sample["net_input"]["src_tokens"],
        "src_lengths": sample["net_input"]["src_lengths"],
    }

    # 1) Generate hypotheses with the current model
    translator = generate.build_sequence_generator(self.args, self.task, [model])
    with torch.no_grad():
        hypos = translator.generate(
            encoder_input,
            beam_size,
            maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
        )
    # Flatten the nested list into bsz * beam_size hypos
    hypos = [hypo for hypo_list in hypos for hypo in hypo_list]
    hypos_len = (
        torch.tensor([len(hypo["tokens"]) for hypo in hypos])
        .type_as(src_tokens)
        .float()
    )

    # 2) Compute (log)-probs via the forward model
    self.self_rescorer.model = model
    self.self_rescorer.task = self.task
    model.train()
    assert self.self_rescorer.model.training, "model should be in training phase"
    hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
        src_tokens, hypos
    )
    hypo_logprobs, hypo_encoder_outs, _ = self.self_rescorer.score_tokens(
        hypo_encoder_inputs, hypo_tokens
    )
    hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

    # 3) Compute rewards from the rescoring models
    with torch.no_grad():
        rescorer = Rescorer(self.args, self.task, self.rescore_models)
        scores = rescorer.score(src_tokens, hypos)
        rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
    assert not rewards.requires_grad, "no grads flow back to generation"

    # 4) Compute policy gradient loss
    rewards = rewards.type_as(hypo_logprobs)
    rl_loss = -1.0 * (hypo_logprobs * rewards).sum()

    # 5) Compute cross-entropy loss on the reference target
    eos = self.task.target_dictionary.eos()
    target_tokens = torch.cat(
        (
            torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
            sample["target"],
        ),
        dim=1,
    )
    target_encoder_inputs = (
        encoder_input["src_tokens"],
        [encoder_input["src_lengths"][0].item()],
    )
    target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
        target_encoder_inputs, target_tokens
    )
    nll_loss = -1.0 * target_logprobs.sum()

    # 6) Gather losses
    loss = self.args.rl_weight * rl_loss + nll_loss

    # Logging
    sample_size = (
        sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
    )
    logging_output = {
        "loss": utils.item(loss.data) if reduce else loss.data,
        "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["target"].size(0),
        "sample_size": sample_size,
    }
    return loss, sample_size, logging_output
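
# Minimal standalone sketch (illustration only, not part of the criterion above)
# of the policy-gradient term in step 4: a reward-weighted negative
# log-likelihood over sampled hypotheses. The sizes and values below are made
# up; `hypo_logprobs` stands in for one length-normalized sequence
# log-probability per hypothesis, and rewards carry no gradient.
def _policy_gradient_sketch():
    import torch

    num_hypos = 6  # e.g. bsz * beam_size flattened hypos
    # Length-normalized sequence log-probabilities (negative values)
    hypo_logprobs = -torch.rand(num_hypos, requires_grad=True)
    # Rewards come from the rescoring models and are treated as constants
    rewards = torch.rand(num_hypos)
    # REINFORCE-style loss: increase log-probs of highly rewarded hypotheses
    return -1.0 * (hypo_logprobs * rewards).sum()
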