예제 #1
0
    def setUp(self):
        """Build a small translator and a corpus parser over sample data.

        The translator uses 100-entry source/target vocabularies and a
        CopyBridge with one decoder layer. After construction,
        ``set_train(False)`` and ``initialize_generator()`` are called on it.
        The bundled ``head.ja``/``head.en`` files serve as both the training
        and the dev data of the corpus.
        """
        # Clear globally registered events before building new components.
        xnmt.events.clear()

        ctx = self.model_context = ModelContext()
        ctx.dynet_param_collection = PersistentParamCollection("some_file", 1)

        # Components are created in the same order as DefaultTranslator's
        # keyword arguments (bridge before decoder), in case construction
        # order matters, e.g. for parameter allocation.
        src_embedder = SimpleWordEmbedder(ctx, vocab_size=100)
        encoder = BiLSTMSeqTransducer(ctx)
        attender = MlpAttender(ctx)
        trg_embedder = SimpleWordEmbedder(ctx, vocab_size=100)
        bridge = CopyBridge(ctx, dec_layers=1)
        decoder = MlpSoftmaxDecoder(ctx, vocab_size=100, bridge=bridge)

        model = DefaultTranslator(
            src_embedder=src_embedder,
            encoder=encoder,
            attender=attender,
            trg_embedder=trg_embedder,
            decoder=decoder,
        )
        self.model = model
        model.initialize_training_strategy(TrainingStrategy())
        model.set_train(False)
        model.initialize_generator()

        src_path = "examples/data/head.ja"
        trg_path = "examples/data/head.en"
        self.training_corpus = BilingualTrainingCorpus(
            train_src=src_path, train_trg=trg_path,
            dev_src=src_path, dev_trg=trg_path)

        src_reader = PlainTextReader()
        trg_reader = PlainTextReader()
        self.corpus_parser = BilingualCorpusParser(
            src_reader=src_reader,
            trg_reader=trg_reader,
            training_corpus=self.training_corpus)
예제 #2
0
파일: test_encoder.py 프로젝트: nvog/xnmt
 def setUp(self):
     """Prepare a model context and a corpus parser over sample data.

     The model context is given a ``PersistentParamCollection`` constructed
     with ``"some_file"``. The bundled ``head.ja``/``head.en`` files serve as
     both the training and the dev data of the corpus.
     """
     # Clear globally registered events before building new components.
     xnmt.events.clear()

     ctx = self.model_context = ModelContext()
     ctx.dynet_param_collection = PersistentParamCollection("some_file", 1)

     src_path = "examples/data/head.ja"
     trg_path = "examples/data/head.en"
     self.training_corpus = BilingualTrainingCorpus(
         train_src=src_path, train_trg=trg_path,
         dev_src=src_path, dev_trg=trg_path)

     src_reader = PlainTextReader()
     trg_reader = PlainTextReader()
     self.corpus_parser = BilingualCorpusParser(
         src_reader=src_reader,
         trg_reader=trg_reader,
         training_corpus=self.training_corpus)
예제 #3
0
 def test_overfitting(self):
     """Train a small model on the sample corpus and check it overfits.

     Runs 50 epochs of weight updates with Adam (alpha=0.1) and a default
     layer dimension of 16. The last epoch's training loss divided by its
     word count must be within 2 decimal places of 0.0.
     """
     ctx = self.model_context = ModelContext()
     ctx.dynet_param_collection = PersistentParamCollection("some_file", 1)
     ctx.default_layer_dim = 16

     src_path = "examples/data/head.ja"
     trg_path = "examples/data/head.en"
     training_corpus = BilingualTrainingCorpus(
         train_src=src_path, train_trg=trg_path,
         dev_src=src_path, dev_trg=trg_path)

     # Entries are listed (and their values constructed) in a fixed order:
     # parser, strategy, model, file settings, trainer, batcher.
     train_args = {
         'corpus_parser': BilingualCorpusParser(
             training_corpus=training_corpus,
             src_reader=PlainTextReader(),
             trg_reader=PlainTextReader()),
         'training_strategy': TrainingStrategy(),
         'model': DefaultTranslator(
             src_embedder=SimpleWordEmbedder(ctx, vocab_size=100),
             encoder=BiLSTMSeqTransducer(ctx),
             attender=MlpAttender(ctx),
             trg_embedder=SimpleWordEmbedder(ctx, vocab_size=100),
             decoder=MlpSoftmaxDecoder(ctx, vocab_size=100),
         ),
         'model_file': None,
         'save_num_checkpoints': 0,
         'trainer': AdamTrainer(ctx, alpha=0.1),
         'batcher': SrcBatcher(batch_size=10, break_ties_randomly=False),
     }

     regimen = xnmt.train.TrainingRegimen(yaml_context=ctx, **train_args)
     regimen.model_context = ctx
     for _epoch in range(50):
         regimen.one_epoch(update_weights=True)

     logger = regimen.logger
     loss_per_word = logger.epoch_loss.loss_values['loss'] / logger.epoch_words
     self.assertAlmostEqual(0.0, loss_per_word, places=2)
예제 #4
0
 def test_train_dev_loss_equal(self):
     """Check that train and dev loss agree when weights are not updated.

     The sample files are used as both training and dev data. Training runs
     for one epoch with ``update_weights=False`` and no trainer. The last
     epoch's per-word training loss is then compared with the logged dev
     loss using ``assertAlmostEqual``'s default tolerance.
     """
     ctx = self.model_context = ModelContext()
     ctx.dynet_param_collection = NonPersistentParamCollection()

     src_path = "examples/data/head.ja"
     trg_path = "examples/data/head.en"
     training_corpus = BilingualTrainingCorpus(
         train_src=src_path, train_trg=trg_path,
         dev_src=src_path, dev_trg=trg_path)

     # Keyword order fixes both construction order and key insertion order.
     train_args = dict(
         corpus_parser=BilingualCorpusParser(
             training_corpus=training_corpus,
             src_reader=PlainTextReader(),
             trg_reader=PlainTextReader()),
         loss_calculator=LossCalculator(),
         model=DefaultTranslator(
             src_embedder=SimpleWordEmbedder(ctx, vocab_size=100),
             encoder=BiLSTMSeqTransducer(ctx),
             attender=MlpAttender(ctx),
             trg_embedder=SimpleWordEmbedder(ctx, vocab_size=100),
             decoder=MlpSoftmaxDecoder(ctx, vocab_size=100),
         ),
         trainer=None,
         batcher=SrcBatcher(batch_size=5, break_ties_randomly=False),
         run_for_epochs=1,
     )

     regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         yaml_context=ctx, **train_args)
     regimen.model_context = ctx
     regimen.run_training(update_weights=False)

     logger = regimen.logger
     train_loss_per_word = (logger.epoch_loss.loss_values['loss'] /
                            logger.epoch_words)
     self.assertAlmostEqual(train_loss_per_word, logger.dev_score.loss)