def test_policy(self):
  """Smoke-test policy-gradient wiring: attach a PolicyGradient to the model,
  compute an MLE loss, and fire the additional-loss event without error."""
  # Training mode is required so the policy-learning path is active.
  event_trigger.set_train(True)
  self.model.policy_learning = PolicyGradient(input_dim=3 * self.layer_dim)
  # Base MLE loss on the first src/trg batch.
  base_loss = MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
  # Trigger computation of the policy's additional (RL) loss terms.
  event_trigger.calc_additional_loss(self.trg[0], self.model, base_loss)
def test_reinforce_loss(self):
  """Check that the REINFORCE path produces the expected loss factors and
  that sampled segmentation actions end at the sentence boundary.

  Fix: removed dead local bindings (``pl``, ``mask``, ``outputs``) that were
  assigned but never used; ``src`` is kept because the action-boundary check
  reads it.
  """
  # Composite supervised loss: MLE + global fertility.
  fertility_loss = GlobalFertilityLoss()
  mle_loss = MLELoss()
  loss = CompositeLoss(pt_losses=[mle_loss, fertility_loss]).calc_loss(
      self.model, self.src[0], self.trg[0])
  # Additional RL (REINFORCE) loss computed via the event trigger.
  reinforce_loss = event_trigger.calc_additional_loss(
      self.trg[0], self.model, loss)
  src = self.src[0]
  actions = self.segmenting_encoder.segment_actions
  # Each sampled action sequence must segment exactly at the end of the
  # (unpadded) source sentence.
  for i, sample_item in enumerate(actions):
    self.assertEqual(sample_item[-1], src[i].len_unpadded())
  # Supervised loss carries its two factors ...
  self.assertTrue("mle" in loss.expr_factors)
  self.assertTrue("global_fertility" in loss.expr_factors)
  # ... and the RL loss carries reinforce, baseline and confidence-penalty.
  self.assertTrue("rl_reinf" in reinforce_loss.expr_factors)
  self.assertTrue("rl_baseline" in reinforce_loss.expr_factors)
  self.assertTrue("rl_confpen" in reinforce_loss.expr_factors)
  # Ensure we are sampling from the policy learning.
  self.assertEqual(self.model.encoder.segmenting_action,
                   SegmentingSeqTransducer.SegmentingAction.POLICY)
def calc_loss(self,
              model: 'model_base.ConditionedModel',
              src: Union[sent.Sentence, 'batcher.Batch'],
              trg: Union[sent.Sentence, 'batcher.Batch']) -> losses.FactoredLossExpr:
  """Accumulate the child loss (plus any event-triggered additional loss)
  over ``self.repeat`` repetitions into one factored loss expression.

  Args:
    model: the conditioned model to compute losses for.
    src: source sentence or batch.
    trg: target sentence or batch.

  Returns:
    A ``FactoredLossExpr`` containing all child and additional loss factors.
  """
  total = losses.FactoredLossExpr()
  for _rep in range(self.repeat):
    child = self.child_loss.calc_loss(model, src, trg)
    extra = event_trigger.calc_additional_loss(trg, model, child)
    total.add_factored_loss_expr(child)
    total.add_factored_loss_expr(extra)
  return total
def calc_loss_single_batch(self):
  """Compute MLE and event-triggered REINFORCE losses for the first batch.

  Returns:
    Tuple of (MLE loss, additional RL loss) for ``self.src[0]``/``self.trg[0]``.
  """
  mle = MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
  rl = event_trigger.calc_additional_loss(self.trg[0], self.model, mle)
  return mle, rl