Example #1
class TestBLEU(unittest.TestCase):
    def setUp(self):
        events.clear()
        self.hyp = ["the taro met the hanako".split()]
        self.ref = ["taro met hanako".split()]

        vocab = Vocab(i2w=["the", "taro", "met", "hanako"])
        self.hyp_id = list(map(vocab.convert, self.hyp[0]))
        self.ref_id = list(map(vocab.convert, self.ref[0]))

    def test_bleu_1gram(self):
        bleu = metrics.BLEUEvaluator(ngram=1)
        exp_bleu = 3.0 / 5.0
        act_bleu = bleu.evaluate(self.ref, self.hyp).value()
        self.assertEqual(act_bleu, exp_bleu)

    @unittest.skipUnless(has_cython(), "requires cython to run")
    def test_bleu_4gram_fast(self):
        bleu = metrics.FastBLEUEvaluator(ngram=4, smooth=1)
        exp_bleu = math.exp(
            math.log(
                (3.0 / 5.0) * (2.0 / 5.0) * (1.0 / 4.0) * (1.0 / 3.0)) / 4.0)
        act_bleu = bleu.evaluate_one_sent(self.ref_id, self.hyp_id)
        self.assertEqual(act_bleu, exp_bleu)
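
The expected value in test_bleu_4gram_fast can be reproduced by hand. Below is a minimal sketch, assuming smooth=1 means add-one smoothing applied only to the n-gram orders above 1 (the reading consistent with the expected precisions 3/5, 2/5, 1/4, 1/3); the brevity penalty is 1 here because the hypothesis (5 tokens) is longer than the reference (3 tokens):

import math
from collections import Counter

def ngram_counts(tokens, n):
    # count all n-grams of the given order in a token list
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

hyp = "the taro met the hanako".split()
ref = "taro met hanako".split()

precisions = []
for n in range(1, 5):
    hyp_counts = ngram_counts(hyp, n)
    ref_counts = ngram_counts(ref, n)
    matches = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
    total = sum(hyp_counts.values())
    if n > 1:  # assumed: smooth=1 adds one to numerator and denominator for n > 1
        matches, total = matches + 1, total + 1
    precisions.append(matches / total)

print(precisions)  # [0.6, 0.4, 0.25, 0.333...] == [3/5, 2/5, 1/4, 1/3]
print(math.exp(sum(map(math.log, precisions)) / 4))  # equals exp_bleu in the test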
Example #2
class TestSegmentingEncoder(unittest.TestCase):
  
  def setUp(self):
    # Seeding
    numpy.random.seed(2)
    random.seed(2)
    layer_dim = 64
    xnmt.events.clear()
    ParamManager.init_param_col()
    self.segment_encoder_bilstm = BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim)
    self.segment_composer = SumComposer()

    self.src_reader = CharFromWordTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.charvocab"))
    self.trg_reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
    self.loss_calculator = FeedbackLoss(child_loss=MLELoss(), repeat=5)

    baseline = Linear(input_dim=layer_dim, output_dim=1)
    policy_network = Linear(input_dim=layer_dim, output_dim=2)
    self.poisson_prior = PoissonPrior(mu=3.3)
    self.eps_greedy = EpsilonGreedy(eps_prob=0.0, prior=self.poisson_prior)
    self.conf_penalty = ConfidencePenalty()
    self.policy_gradient = PolicyGradient(input_dim=layer_dim,
                                          output_dim=2,
                                          baseline=baseline,
                                          policy_network=policy_network,
                                          z_normalization=True,
                                          conf_penalty=self.conf_penalty)
    self.length_prior = PoissonLengthPrior(lmbd=3.3, weight=1)
    self.segmenting_encoder = SegmentingSeqTransducer(
      embed_encoder=self.segment_encoder_bilstm,
      segment_composer=self.segment_composer,
      final_transducer=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
      policy_learning=self.policy_gradient,
      eps_greedy=self.eps_greedy,
      length_prior=self.length_prior,
    )

    self.model = DefaultTranslator(
      src_reader=self.src_reader,
      trg_reader=self.trg_reader,
      src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
      encoder=self.segmenting_encoder,
      attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
      trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
      decoder=AutoRegressiveDecoder(input_dim=layer_dim,
                                    rnn=UniLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim,
                                                             decoder_input_dim=layer_dim, yaml_path="decoder"),
                                    transform=AuxNonLinear(input_dim=layer_dim, output_dim=layer_dim,
                                                           aux_input_dim=layer_dim),
                                    scorer=Softmax(vocab_size=100, input_dim=layer_dim),
                                    trg_embed_dim=layer_dim,
                                    bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
    )
    event_trigger.set_train(True)

    self.layer_dim = layer_dim
    self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
    self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
    my_batcher = batchers.TrgBatcher(batch_size=3)
    self.src, self.trg = my_batcher.pack(self.src_data, self.trg_data)
    dy.renew_cg(immediate_compute=True, check_validity=True)

  def test_reinforce_loss(self):
    fertility_loss = GlobalFertilityLoss()
    mle_loss = MLELoss()
    loss = CompositeLoss(losses=[mle_loss, fertility_loss]).calc_loss(self.model, self.src[0], self.trg[0])
    reinforce_loss = event_trigger.calc_additional_loss(self.trg[0], self.model, loss)
    pl = self.model.encoder.policy_learning
    # Inspect the segmentation produced for the first batch
    src = self.src[0]
    mask = src.mask.np_arr
    outputs = self.segmenting_encoder.compose_output
    actions = self.segmenting_encoder.segment_actions
    # Ensure the sampled actions match the source: the last segment boundary
    # of each sentence must fall on its final unpadded token
    for i, sample_item in enumerate(actions):
      self.assertEqual(sample_item[-1], src[i].len_unpadded())
    self.assertTrue("mle" in loss.expr_factors)
    self.assertTrue("global_fertility" in loss.expr_factors)
    self.assertTrue("rl_reinf" in reinforce_loss.expr_factors)
    self.assertTrue("rl_baseline" in reinforce_loss.expr_factors)
    self.assertTrue("rl_confpen" in reinforce_loss.expr_factors)
    # Ensure we are sampling from the policy learning
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.POLICY)

  def calc_loss_single_batch(self):
    loss = MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
    reinforce_loss = event_trigger.calc_additional_loss(self.trg[0], self.model, loss)
    return loss, reinforce_loss

  def test_gold_input(self):
    self.model.encoder.policy_learning = None
    self.model.encoder.eps_greedy = None
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.GOLD)

  @unittest.skipUnless(has_cython(), "requires cython to run")
  def test_sample_input(self):
    self.model.encoder.eps_greedy.eps_prob = 1.0
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.POLICY_SAMPLE)
    self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.PREDEFINED)
  
  def test_policy_train_test(self):
    event_trigger.set_train(True)
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.POLICY_CLP)
    event_trigger.set_train(False)
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.POLICY_AMAX)

  def test_no_policy_train_test(self):
    self.model.encoder.policy_learning = None
    event_trigger.set_train(True)
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.PURE_SAMPLE)
    event_trigger.set_train(False)
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.PURE_SAMPLE)

  def test_sample_during_search(self):
    event_trigger.set_train(False)
    self.model.encoder.sample_during_search = True
    self.calc_loss_single_batch()
    self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.POLICY)

  @unittest.skipUnless(has_cython(), "requires cython to run")
  def test_policy_gold(self):
    self.model.encoder.eps_greedy.prior = GoldInputPrior("segment")
    self.model.encoder.eps_greedy.eps_prob = 1.0
    self.calc_loss_single_batch()

  def test_reporter(self):
    self.model.encoder.reporter = SegmentingReporter("test/tmp/seg-report.log", self.model.src_reader.vocab)
    self.calc_loss_single_batch()
Example #3
class TestRunningConfig(unittest.TestCase):
    def setUp(self):
        xnmt.events.clear()

    def test_assemble(self):
        run.main(["test/config/assemble.yaml"])

    def test_cascade(self):
        run.main(["test/config/cascade.yaml"])

    def test_classifier(self):
        run.main(["test/config/classifier.yaml"])

    def test_component_sharing(self):
        run.main(["test/config/component_sharing.yaml"])

    def test_encoders(self):
        run.main(["test/config/encoders.yaml"])

    def test_ensembling(self):
        run.main(["test/config/ensembling.yaml"])

    def test_forced(self):
        run.main(["test/config/forced.yaml"])

    def test_lm(self):
        run.main(["test/config/lm.yaml"])

    def test_load_model(self):
        run.main(["test/config/load_model.yaml"])

    def test_multi_task(self):
        run.main(["test/config/multi_task.yaml"])

    def test_multi_task_speech(self):
        run.main(["test/config/multi_task_speech.yaml"])

    def test_preproc(self):
        run.main(["test/config/preproc.yaml"])

    def test_pretrained_emb(self):
        run.main(["test/config/pretrained_embeddings.yaml"])

    def test_random_search_test_params(self):
        run.main(["test/config/random_search_test_params.yaml"])

    def test_random_search_train_params(self):
        run.main(["test/config/random_search_train_params.yaml"])

    def test_reload(self):
        run.main(["test/config/reload.yaml"])

    def test_segmenting(self):
        run.main(["test/config/seg_report.yaml"])

    def test_reload_exception(self):
        with self.assertRaises(ValueError) as context:
            run.main(["test/config/reload_exception.yaml"])
        self.assertEqual(
            str(context.exception),
            'VanillaLSTMGates: x_t has inconsistent dimension 20, expecting 40'
        )

    def test_report(self):
        run.main(["test/config/report.yaml"])

    @unittest.expectedFailure  # TODO: these tests need to be fixed
    def test_retrieval(self):
        run.main(["test/config/retrieval.yaml"])

    def test_score(self):
        run.main(["test/config/score.yaml"])

    def test_self_attentional_am(self):
        run.main(["test/config/self_attentional_am.yaml"])

    def test_seq_labeler(self):
        run.main(["test/config/seq_labeler.yaml"])

    def test_speech(self):
        run.main(["test/config/speech.yaml"])

    @unittest.expectedFailure  # TODO: these tests need to be fixed
    def test_speech_retrieval(self):
        run.main(["test/config/speech_retrieval.yaml"])

    def test_standard(self):
        run.main(["test/config/standard.yaml"])

    @unittest.expectedFailure  # TODO: these tests need to be fixed
    def test_transformer(self):
        run.main(["test/config/transformer.yaml"])

    @unittest.skipUnless(has_cython(), "requires cython to run")
    def test_search_strategy_reinforce(self):
        run.main(["test/config/reinforce.yaml"])

    @unittest.skipUnless(has_cython(), "requires cython to run")
    def test_search_strategy_minrisk(self):
        run.main(["test/config/minrisk.yaml"])

    def tearDown(self):
        try:
            if os.path.isdir("test/tmp"):
                shutil.rmtree("test/tmp")
        except:
            pass
Example #4
File: test_run.py  Project: yzhen-li/xnmt
class TestRunningConfig(unittest.TestCase):

  def setUp(self):
    xnmt.events.clear()

  def tearDown(self):
    try:
      if os.path.isdir("test/tmp"):
        shutil.rmtree("test/tmp")
    except:
      pass

  def test_assemble(self):
    run.main(["test/config/assemble.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_autobatch_fail(self):
    with self.assertRaises(ValueError) as context:
      run.main(["test/config/autobatch-fail.yaml"])
    self.assertEqual(str(context.exception), 'AutobatchTrainingRegimen forces the batcher to have batch_size 1. Use update_every to set the actual batch size in this regimen.')

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_autobatch(self):
    run.main(["test/config/autobatch.yaml"])

  def test_cascade(self):
    run.main(["test/config/cascade.yaml"])

  def test_classifier(self):
    run.main(["test/config/classifier.yaml"])

  def test_component_sharing(self):
    run.main(["test/config/component_sharing.yaml"])

  @unittest.skipUnless(xnmt.backend_torch, "requires PyTorch backend")
  def test_cudnn_lstm(self):
    run.main(["test/config/cudnn-lstm.yaml"])

  def test_encoders(self):
    run.main(["test/config/encoders.yaml"])

  def test_ensembling(self):
    run.main(["test/config/ensembling.yaml"])

  def test_forced(self):
    run.main(["test/config/forced.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_lattice(self):
    run.main(["test/config/lattice.yaml"])

  def test_lm(self):
    run.main(["test/config/lm.yaml"])

  def test_load_model(self):
    run.main(["test/config/load_model.yaml"])

  def test_multi_task(self):
    run.main(["test/config/multi_task.yaml"])

  def test_multi_task_speech(self):
    run.main(["test/config/multi_task_speech.yaml"])

  def test_preproc(self):
    run.main(["test/config/preproc.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_pretrained_emb(self):
    run.main(["test/config/pretrained_embeddings.yaml"])

  def test_random_search_test_params(self):
    run.main(["test/config/random_search_test_params.yaml"])

  def test_random_search_train_params(self):
    run.main(["test/config/random_search_train_params.yaml"])

  def test_reload(self):
    run.main(["test/config/reload.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_segmenting(self):
    run.main(["test/config/seg_report.yaml"])

  def test_reload_exception(self):
    if xnmt.backend_dynet:
      with self.assertRaises(ValueError) as context:
        run.main(["test/config/reload_exception.yaml"])
      # check the message after the with-block, once the exception has been captured
      self.assertEqual(str(context.exception), 'VanillaLSTMGates: x_t has inconsistent dimension 20, expecting 40')
    else:
      with self.assertRaises(RuntimeError) as context:
        run.main(["test/config/reload_exception.yaml"])
      self.assertIn("20", str(context.exception))
      self.assertIn("40", str(context.exception))

  def test_report(self):
    run.main(["test/config/report.yaml"])

  @unittest.expectedFailure # TODO: these tests need to be fixed
  def test_retrieval(self):
    run.main(["test/config/retrieval.yaml"])

  def test_score(self):
    run.main(["test/config/score.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  def test_self_attentional_am(self):
    run.main(["test/config/self_attentional_am.yaml"])

  def test_seq_labeler(self):
    run.main(["test/config/seq_labeler.yaml"])

  def test_speech(self):
    run.main(["test/config/speech.yaml"])

  @unittest.expectedFailure # TODO: these tests need to be fixed
  def test_speech_retrieval(self):
    run.main(["test/config/speech_retrieval.yaml"])

  def test_standard(self):
    run.main(["test/config/standard.yaml"])

  @unittest.expectedFailure # TODO: these tests need to be fixed
  def test_transformer(self):
    run.main(["test/config/transformer.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  @unittest.skipUnless(has_cython(), "requires cython to run")
  def test_search_strategy_reinforce(self):
    run.main(["test/config/reinforce.yaml"])

  @unittest.skipUnless(xnmt.backend_dynet, "requires DyNet backend")
  @unittest.skipUnless(has_cython(), "requires cython to run")
  def test_search_strategy_minrisk(self):
    run.main(["test/config/minrisk.yaml"])