def on_calc_additional_loss(self, reward):
  if not self.learn_segmentation:
    return None
  ret = LossBuilder()
  if self.length_prior_alpha > 0:
    reward += self.segment_length_prior * self.length_prior_alpha
  reward = dy.cdiv(reward - dy.mean_batches(reward), dy.std_batches(reward))
  # Baseline Loss
  if self.use_baseline:
    baseline_loss = []
    for i, baseline in enumerate(self.bs):
      baseline_loss.append(dy.squared_distance(reward, baseline))
    ret.add_loss("Baseline", dy.esum(baseline_loss))
  # Reinforce Loss
  lmbd = self.lmbd.get_value(self.warmup_counter)
  if lmbd > 0.0:
    reinforce_loss = []
    # Calculate the REINFORCE term for every segmentation decision
    for i in range(len(self.segment_decisions)):
      ll = dy.pick_batch(self.segment_logsoftmaxes[i], self.segment_decisions[i])
      if self.use_baseline:
        r_i = reward - self.bs[i]
      else:
        r_i = reward
      reinforce_loss.append(dy.logistic(r_i) * ll)
    ret.add_loss("Reinforce", -dy.esum(reinforce_loss) * lmbd)
  # Total Loss
  return ret
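For reference, the loss this first version assembles can be written as follows; the notation is mine and does not appear in the code (r-hat is the z-normalized, length-prior-adjusted reward, b_i the predicted baseline at step i, pi_i the segmentation softmax, a_i the chosen decision, sigma the logistic function applied by dy.logistic, and lambda the warm-up weight):

\[
\mathcal{L}_{\text{baseline}} = \sum_i \bigl(\hat{r} - b_i\bigr)^2,
\qquad
\mathcal{L}_{\text{reinforce}} = -\lambda \sum_i \sigma\bigl(\hat{r} - b_i\bigr)\,\log \pi_i(a_i)
\]

Note the sigmoid squashing of the advantage term, which the later versions drop in favor of using the advantage directly.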
def on_calc_additional_loss(self, translator_loss):
  if not self.learn_segmentation or self.segment_decisions is None:
    return None
  reward = -translator_loss["mle"]
  if not self.log_reward:
    reward = dy.exp(reward)
  reward = dy.nobackprop(reward)
  # Make sure that the reward is not a scalar, but has one value per batch item
  assert reward.dim()[1] == len(self.src_sent)
  # Mask out padded positions in the encoder input
  enc_mask = self.enc_mask.get_active_one_mask().transpose() if self.enc_mask is not None else None
  # Compose the loss
  ret = LossBuilder()
  ## Length prior
  alpha = self.length_prior_alpha.value() if self.length_prior_alpha is not None else 0
  if alpha > 0:
    reward += self.segment_length_prior * alpha
  ## Reward z-score normalization
  if self.z_normalization:
    reward = dy.cdiv(reward - dy.mean_batches(reward), dy.std_batches(reward) + EPS)
  ## Baseline Loss
  if self.use_baseline:
    baseline_loss = []
    for i, baseline in enumerate(self.bs):
      loss = dy.squared_distance(reward, baseline)
      if enc_mask is not None:
        loss = dy.cmult(dy.inputTensor(enc_mask[i], batched=True), loss)
      baseline_loss.append(loss)
    ret.add_loss("Baseline", dy.esum(baseline_loss))
    if self.print_sample:
      print(dy.exp(self.segment_logsoftmaxes[i]).npvalue().transpose()[0])
  ## Reinforce Loss
  lmbd = self.lmbd.value()
  if lmbd > 0.0:
    reinforce_loss = []
    # Calculate the REINFORCE term for every segmentation decision
    for i in range(len(self.segment_decisions)):
      ll = dy.pick_batch(self.segment_logsoftmaxes[i], self.segment_decisions[i])
      if self.use_baseline:
        r_i = reward - dy.nobackprop(self.bs[i])
      else:
        r_i = reward
      if enc_mask is not None:
        ll = dy.cmult(dy.inputTensor(enc_mask[i], batched=True), ll)
      reinforce_loss.append(r_i * -ll)
    loss = dy.esum(reinforce_loss) * lmbd
    ret.add_loss("Reinforce", loss)
  ## Confidence Penalty
  if self.confidence_penalty:
    ls_loss = self.confidence_penalty(self.segment_logsoftmaxes, enc_mask)
    ret.add_loss("Confidence Penalty", ls_loss)
  # Total Loss
  return ret
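The main additions in this second version are the per-position encoder mask and the EPS-stabilized z-normalization. Below is a minimal, self-contained sketch of the masking idiom it relies on; the toy values and variable names are mine, only the dy.cmult / dy.inputTensor(..., batched=True) pattern comes from the code above.

# Minimal sketch of the batched masking idiom used above (toy values, my own
# variable names). Each per-position loss is multiplied by a {0, 1} mask so
# that padded batch entries contribute nothing to the summed loss.
import numpy as np
import dynet as dy

dy.renew_cg()
reward   = dy.inputTensor(np.array([1.0, -0.5, 2.0]), batched=True)  # one reward per batch item
baseline = dy.inputTensor(np.array([0.8,  0.0, 1.5]), batched=True)  # predicted baseline per batch item
mask     = np.array([1.0, 1.0, 0.0])                                 # third batch item is padding

loss = dy.squared_distance(reward, baseline)
loss = dy.cmult(dy.inputTensor(mask, batched=True), loss)  # zero out the padded position
print(dy.sum_batches(loss).value())  # only the two active positions contribute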
def calc_loss(self, rewards):
  loss = FactoredLossExpr()
  ## Z-Normalization
  if self.z_normalization:
    reward_batches = dy.concatenate_to_batch(rewards)
    mean_batches = dy.mean_batches(reward_batches)
    std_batches = dy.std_batches(reward_batches)
    rewards = [dy.cdiv(reward - mean_batches, std_batches) for reward in rewards]
  ## Calculate baseline
  if self.baseline is not None:
    pred_reward, baseline_loss = self.calc_baseline_loss(rewards)
    loss.add_loss("rl_baseline", baseline_loss)
  ## Calculate Confidence Penalty
  if self.confidence_penalty:
    loss.add_loss("rl_confpen", self.confidence_penalty.calc_loss(self.policy_lls))
  ## Calculate Reinforce Loss
  reinf_loss = []
  # Loop through all actions in the sequence
  for i, (policy, action_sample) in enumerate(zip(self.policy_lls, self.actions)):
    # Subtract this step's predicted baseline from every sampled reward
    if self.baseline is not None:
      step_rewards = [reward - pred_reward[i] for reward in rewards]
    else:
      step_rewards = rewards
    # Main REINFORCE calculation
    sample_loss = []
    for action, reward in zip(action_sample, step_rewards):
      ll = dy.pick_batch(policy, action)
      if self.valid_pos is not None:
        ll = dy.pick_batch_elems(ll, self.valid_pos[i])
        reward = dy.pick_batch_elems(reward, self.valid_pos[i])
      sample_loss.append(dy.sum_batches(ll * reward))
    # Take the average of the losses across multiple samples
    reinf_loss.append(dy.esum(sample_loss) / len(sample_loss))
  loss.add_loss("rl_reinf", self.weight * -dy.esum(reinf_loss))
  ## Return the composed losses
  return loss
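In my notation (not taken from the source), with K sampled action sequences, r-hat^(k) the normalized reward of sample k, r-bar_i the step-i baseline prediction (dropped when no baseline is configured), pi_i the policy at step i, w the value of self.weight, and b ranging over the valid batch positions, this refactored version computes:

\[
\mathcal{L}_{\text{rl\_reinf}}
= -\,w \sum_i \frac{1}{K} \sum_{k=1}^{K} \sum_{b}
\log \pi_i\bigl(a^{(k)}_{i,b}\bigr)\,\bigl(\hat{r}^{(k)}_{b} - \bar{r}_{i,b}\bigr)
\]

Compared with the earlier versions, the reward is now averaged over multiple samples per step and the masking is expressed through dy.pick_batch_elems over the valid positions rather than an explicit multiplicative mask.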