コード例 #1
0
ファイル: test_simultaneous.py プロジェクト: yzhen-li/xnmt
 def test_policy(self):
     """Smoke-test computing the additional loss with a PolicyGradient learner.

     Switches the event trigger into training mode, attaches a PolicyGradient
     learner (input dimension 3 * self.layer_dim) to the model, computes the
     MLE loss on the first source/target batch and hands it to
     event_trigger.calc_additional_loss. There are no assertions: the test
     passes as long as none of these calls raises.
     """
     event_trigger.set_train(True)
     policy_input_dim = 3 * self.layer_dim
     self.model.policy_learning = PolicyGradient(input_dim=policy_input_dim)
     first_src, first_trg = self.src[0], self.trg[0]
     standard_loss = MLELoss().calc_loss(self.model, first_src, first_trg)
     event_trigger.calc_additional_loss(first_trg, self.model, standard_loss)
コード例 #2
0
 def test_reinforce_loss(self):
     """Check the loss factors and segmentation bookkeeping of a composite loss.

     Computes a CompositeLoss of MLELoss and GlobalFertilityLoss on the first
     batch, then asks event_trigger for the additional loss built from it.
     Asserts that:
       * for every sentence i, the last recorded segmentation action equals
         the unpadded length of source sentence i;
       * the composite loss exposes the "mle" and "global_fertility" factors;
       * the additional loss exposes the "rl_reinf", "rl_baseline" and
         "rl_confpen" factors;
       * the model's encoder is set to sample segmentations from its policy.
     """
     fertility_loss = GlobalFertilityLoss()
     mle_loss = MLELoss()
     loss = CompositeLoss(pt_losses=[mle_loss, fertility_loss]).calc_loss(
         self.model, self.src[0], self.trg[0])
     reinforce_loss = event_trigger.calc_additional_loss(
         self.trg[0], self.model, loss)
     src = self.src[0]
     actions = self.segmenting_encoder.segment_actions
     # The final segmentation action recorded for sentence i must equal that
     # sentence's unpadded length.
     for i, sample_item in enumerate(actions):
         self.assertEqual(sample_item[-1], src[i].len_unpadded())
     # Factors contributed by the composite (MLE + global fertility) loss.
     self.assertIn("mle", loss.expr_factors)
     self.assertIn("global_fertility", loss.expr_factors)
     # Reinforcement-learning factors expected in the additional loss.
     self.assertIn("rl_reinf", reinforce_loss.expr_factors)
     self.assertIn("rl_baseline", reinforce_loss.expr_factors)
     self.assertIn("rl_confpen", reinforce_loss.expr_factors)
     # Ensure we are sampling from the policy learning
     self.assertEqual(self.model.encoder.segmenting_action,
                      SegmentingSeqTransducer.SegmentingAction.POLICY)
コード例 #3
0
ファイル: loss_calculators.py プロジェクト: ustcmike/xnmt
 def calc_loss(self,
               model: 'model_base.ConditionedModel',
               src: Union[sent.Sentence, 'batcher.Batch'],
               trg: Union[sent.Sentence, 'batcher.Batch']) -> losses.FactoredLossExpr:
   """Accumulate the child loss and its event-triggered additional loss.

   The child loss is computed ``self.repeat`` times. After each computation,
   ``event_trigger.calc_additional_loss`` is called with that child loss, and
   then the child loss followed by the additional loss are both added to a
   single FactoredLossExpr.

   Args:
     model: model passed to the child loss and to the additional-loss hook.
     src: source sentence or batch.
     trg: target sentence or batch.

   Returns:
     A FactoredLossExpr containing every child loss and additional loss, added
     in the order they were computed.
   """
   accumulated = losses.FactoredLossExpr()
   for _repetition in range(self.repeat):
     child = self.child_loss.calc_loss(model, src, trg)
     extra = event_trigger.calc_additional_loss(trg, model, child)
     for part in (child, extra):
       accumulated.add_factored_loss_expr(part)
   return accumulated
コード例 #4
0
ファイル: test_segmenting.py プロジェクト: seeledu/xnmt-devel
 def calc_loss_single_batch(self):
   """Compute the MLE loss and the event-triggered additional loss on batch 0.

   Returns:
     A ``(loss, reinforce_loss)`` pair: the MLELoss value for
     ``self.src[0]`` / ``self.trg[0]``, and whatever
     ``event_trigger.calc_additional_loss`` returns for that loss.
   """
   mle = MLELoss()
   first_trg = self.trg[0]
   mle_value = mle.calc_loss(self.model, self.src[0], first_trg)
   additional_value = event_trigger.calc_additional_loss(first_trg, self.model, mle_value)
   return mle_value, additional_value