def calc_loss(self, translator, src, trg):
    """Compute the REINFORCE loss summed over every search output of ``translator``.

    For each search output, every sampled hypothesis (with EOS removed) is scored
    against its unpadded reference using ``self.evaluation_metric``. The scores,
    negated when ``self.inv_eval`` is set, form a batched reward tensor.

    If ``self.baseline`` is set, each step's log-softmax term is weighted by
    ``reward - baseline`` and multiplied by the step mask. A squared-distance term
    ("reinf_baseline") is also added to train the baseline. Otherwise the summed
    log-softmaxes are weighted directly by the reward.

    Side effects: ``self.eval_score`` and ``self.reward`` are overwritten with the
    values for the most recently processed search output.

    Args:
      translator: object providing ``generate_search_output(src, search_strategy)``.
      src: source-side input passed through to the translator.
      trg: references; each item exposes ``words`` and ``len_unpadded()``.

    Returns:
      A ``FactoredLossExpr`` accumulating the losses of all search outputs.
    """
    search_outputs = translator.generate_search_output(src, self.search_strategy)
    # inv_eval flips the sign of the metric score before it becomes the reward.
    sign = -1 if self.inv_eval else 1

    total_loss = FactoredLossExpr()
    for search_output in search_outputs:
        self.eval_score = []
        for trg_i, sample_i in zip(trg, search_output.word_ids):
            # Removing EOS
            sample_i = self.remove_eos(sample_i.tolist())
            ref_i = trg_i.words[:trg_i.len_unpadded()]
            score = self.evaluation_metric.evaluate_one_sent(ref_i, sample_i)
            self.eval_score.append(sign * score)
        self.reward = dy.inputTensor(self.eval_score, batched=True)

        # Composing losses
        loss = FactoredLossExpr()
        if self.baseline is not None:
            baseline_loss = []
            losses = []
            for state, logsoft, mask in zip(search_output.state,
                                            search_output.logsoftmaxes,
                                            search_output.mask):
                bs_score = self.baseline.transform(state)
                baseline_loss.append(dy.squared_distance(self.reward, bs_score))
                loss_i = dy.cmult(logsoft, self.reward - bs_score)
                # Keep only batch entries whose mask value is nonzero at this step.
                losses.append(dy.cmult(loss_i, dy.inputTensor(mask, batched=True)))
            loss.add_loss("reinforce", dy.sum_elems(dy.esum(losses)))
            loss.add_loss("reinf_baseline", dy.sum_elems(dy.esum(baseline_loss)))
        else:
            # The reward computed above replaces the previously undefined
            # ``self.true_score`` / ``logsofts`` names, which raised NameError.
            loss.add_loss(
                "reinforce",
                dy.sum_elems(dy.cmult(self.reward, dy.esum(search_output.logsoftmaxes))))
        total_loss.add_factored_loss_expr(loss)
    # Return the accumulated loss over all search outputs, not just the last one.
    return total_loss
def calc_loss(self, model: 'model_base.ConditionedModel', src: Union[sent.Sentence, 'batchers.Batch'], trg: Union[sent.Sentence, 'batchers.Batch']):
    """Combine the child losses into one weighted ``FactoredLossExpr``.

    Every entry of ``self.losses`` is evaluated on the same ``(model, src, trg)``.
    The result is scaled by the weight at the matching position of
    ``self.loss_weight``. Pairing stops at the shorter of the two sequences.

    Returns:
      A ``FactoredLossExpr`` into which every weighted child loss has been added.
    """
    combined = FactoredLossExpr()
    weighted_children = zip(self.losses, self.loss_weight)
    for child, weight_factor in weighted_children:
        child_expr = child.calc_loss(model, src, trg)
        combined.add_factored_loss_expr(child_expr * weight_factor)
    return combined
def calc_loss(self, model: 'model_base.ConditionedModel', src: Union[sent.Sentence, 'batchers.Batch'], trg: Union[sent.Sentence, 'batchers.Batch']):
    """Evaluate the child loss ``self.repeat`` times, adding auxiliary losses each time.

    Each repetition:
      1. calls ``self.child_loss.calc_loss(model, src, trg)``;
      2. passes that result to ``event_trigger.calc_additional_loss(trg, model, ...)``;
      3. adds both expressions to the returned accumulator.

    With ``self.repeat == 0`` an empty ``FactoredLossExpr`` is returned.

    Args:
      model: the model the losses are computed for.
      src: source input (a single sentence or a batch).
      trg: target input (a single sentence or a batch).

    Returns:
      A ``FactoredLossExpr`` containing the standard and additional losses of
      every repetition.
    """
    # NOTE: the forward references above use 'batchers.Batch' to match the other
    # loss calculator in this file (previously spelled 'batcher.Batch').
    loss_builder = FactoredLossExpr()
    for _ in range(self.repeat):
        standard_loss = self.child_loss.calc_loss(model, src, trg)
        # The additional loss is computed from this repetition's standard loss.
        additional_loss = event_trigger.calc_additional_loss(trg, model, standard_loss)
        loss_builder.add_factored_loss_expr(standard_loss)
        loss_builder.add_factored_loss_expr(additional_loss)
    return loss_builder