Exemplo n.º 1
0
 def setUp(self):
     """Prepare a frozen PlainTextReader and an ExpGlobal context.

     Every sentence of examples/data/head.ja is pulled through the reader
     once and discarded. The reader is then frozen, presumably so that its
     vocabulary stays fixed for the tests (TODO confirm against
     PlainTextReader.freeze).
     """
     # Presumably resets event handlers registered by earlier tests.
     xnmt.events.clear()
     reader = PlainTextReader()
     self.input_reader = reader
     for _ in reader.read_sents('examples/data/head.ja'):
         pass  # iterate only for the reader's side effects
     reader.freeze()
     self.context = ExpGlobal(
         dynet_param_collection=PersistentParamCollection(None, 0))
Exemplo n.º 2
0
    def setUp(self):
        """Build a small DefaultTranslator and load the sample corpora.

        After construction, the model is switched out of training mode
        (``set_train(False)``) and given a beam-1 generator. The example
        source and target files are then read through the model's own
        readers into ``self.src_data`` and ``self.trg_data``.
        """
        xnmt.events.clear()
        self.exp_global = ExpGlobal(
            dynet_param_collection=PersistentParamCollection("some_file", 1))
        eg = self.exp_global
        vocab_size = 100

        # Components are created in the same order Python would evaluate
        # the arguments of one nested DefaultTranslator(...) call, in case
        # construction order matters (e.g. for parameter initialisation).
        src_reader = PlainTextReader()
        trg_reader = PlainTextReader()
        src_embedder = SimpleWordEmbedder(exp_global=eg, vocab_size=vocab_size)
        encoder = BiLSTMSeqTransducer(exp_global=eg)
        attender = MlpAttender(exp_global=eg)
        trg_embedder = SimpleWordEmbedder(exp_global=eg, vocab_size=vocab_size)
        bridge = CopyBridge(exp_global=eg, dec_layers=1)
        decoder = MlpSoftmaxDecoder(exp_global=eg,
                                    vocab_size=vocab_size,
                                    bridge=bridge)

        self.model = DefaultTranslator(
            src_reader=src_reader,
            trg_reader=trg_reader,
            src_embedder=src_embedder,
            encoder=encoder,
            attender=attender,
            trg_embedder=trg_embedder,
            decoder=decoder,
        )
        self.model.set_train(False)
        self.model.initialize_generator(beam=1)

        self.src_data = list(
            self.model.src_reader.read_sents("examples/data/head.ja"))
        self.trg_data = list(
            self.model.trg_reader.read_sents("examples/data/head.en"))
Exemplo n.º 3
0
    def setUp(self):
        """Reset xnmt events, create an ExpGlobal and load the sample corpora.

        Source and target use two separate PlainTextReader instances. The
        sentences each reader yields are materialised as lists in
        ``self.src_data`` and ``self.trg_data``.
        """
        xnmt.events.clear()
        param_collection = NonPersistentParamCollection()
        self.exp_global = ExpGlobal(dynet_param_collection=param_collection)

        self.src_reader, self.trg_reader = PlainTextReader(), PlainTextReader()
        self.src_data = list(self.src_reader.read_sents("examples/data/head.ja"))
        self.trg_data = list(self.trg_reader.read_sents("examples/data/head.en"))
Exemplo n.º 4
0
 def test_overfitting(self):
     """Check that a tiny translator can drive its training loss to ~0.

     A DefaultTranslator with 16-dimensional layers and no dropout is
     trained on examples/data/head.ja -> examples/data/head.en using Adam
     (alpha=0.1). The regimen runs 50 times, one epoch each, with weight
     updates enabled and a no-op save function. The test then asserts that
     the last epoch's summed loss divided by its word count equals 0.0 to
     two decimal places.
     """
     self.exp_global = ExpGlobal(
         dynet_param_collection=NonPersistentParamCollection(), dropout=0.0)
     self.exp_global.default_layer_dim = 16
     eg = self.exp_global
     src_path = "examples/data/head.ja"
     trg_path = "examples/data/head.en"

     # The same batcher instance is shared by the dev task and training.
     batcher = SrcBatcher(batch_size=10, break_ties_randomly=False)
     loss_calculator = LossCalculator()
     model = DefaultTranslator(
         src_reader=PlainTextReader(),
         trg_reader=PlainTextReader(),
         src_embedder=SimpleWordEmbedder(eg, vocab_size=100),
         encoder=BiLSTMSeqTransducer(eg),
         attender=MlpAttender(eg),
         trg_embedder=SimpleWordEmbedder(eg, vocab_size=100),
         decoder=MlpSoftmaxDecoder(eg,
                                   vocab_size=100,
                                   bridge=CopyBridge(exp_global=eg,
                                                     dec_layers=1)),
     )
     dev_tasks = [LossEvalTask(model=model,
                               src_file=src_path,
                               ref_file=trg_path,
                               batcher=batcher)]
     trainer = AdamTrainer(eg, alpha=0.1)

     training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         exp_global=eg,
         src_file=src_path,
         trg_file=trg_path,
         loss_calculator=loss_calculator,
         model=model,
         dev_tasks=dev_tasks,
         run_for_epochs=1,
         trainer=trainer,
         batcher=batcher)
     # Kept from the original; may be redundant with the constructor kwarg.
     training_regimen.exp_global = eg

     num_runs = 50
     for _ in range(num_runs):
         training_regimen.run_training(save_fct=lambda: None,
                                       update_weights=True)

     logger = training_regimen.logger
     loss_per_word = logger.epoch_loss.sum() / logger.epoch_words
     self.assertAlmostEqual(0.0, loss_per_word, places=2)