def setUp(self):
  layer_dim = 512
  xnmt.events.clear()
  ParamManager.init_param_col()
  self.model = DefaultTranslator(
    src_reader=PlainTextReader(),
    trg_reader=PlainTextReader(),
    src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
    encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
    attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
    trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
    decoder=AutoRegressiveDecoder(
      input_dim=layer_dim,
      trg_embed_dim=layer_dim,
      rnn=UniLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim,
                               decoder_input_dim=layer_dim, yaml_path="model.decoder.rnn"),
      transform=NonLinear(input_dim=layer_dim * 2, output_dim=layer_dim),
      scorer=Softmax(input_dim=layer_dim, vocab_size=100),
      bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
  )
  self.model.set_train(False)
  self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
  self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
  self.search = GreedySearch()
def test_single(self):
  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.src_data[0],
                                    trg=self.trg_data[0],
                                    loss_calculator=AutoRegressiveMLELoss()).value()
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                GreedySearch(),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  output_score = outputs[0].score
  # Greedy decoding forced to the reference should reproduce the training loss (up to sign).
  self.assertAlmostEqual(-output_score, train_loss, places=5)
def test_single(self):
  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.src_data[0],
                                    trg=self.trg_data[0],
                                    loss_calculator=MLELoss()).value()
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, GreedySearch(),
                                       forced_trg_ids=self.trg_data[0])
  output_score = outputs[0].score
  # The forced-decoding score should match the training loss (up to sign).
  self.assertAlmostEqual(-output_score, train_loss, places=5)
def initialize_generator(self, **kwargs): if kwargs.get("len_norm_type", None) is None: len_norm = xnmt.length_normalization.NoNormalization() else: len_norm = xnmt.serializer.YamlSerializer().initialize_if_needed(kwargs["len_norm_type"]) search_args = {} if kwargs.get("max_len", None) is not None: search_args["max_len"] = kwargs["max_len"] if kwargs.get("beam", None) is None: self.search_strategy = GreedySearch(**search_args) else: search_args["beam_size"] = kwargs.get("beam", 1) search_args["len_norm"] = len_norm self.search_strategy = BeamSearch(**search_args) self.report_path = kwargs.get("report_path", None) self.report_type = kwargs.get("report_type", None)
def test_greedy_vs_beam(self):
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                BeamSearch(beam_size=1),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  output_score1 = outputs[0].score
  dy.renew_cg()
  outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0],
                                GreedySearch(),
                                forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
  output_score2 = outputs[0].score
  self.assertAlmostEqual(output_score1, output_score2)
def initialize_generator(self, **kwargs): if kwargs.get("len_norm_type", None) is None: len_norm = xnmt.length_normalization.NoNormalization() else: len_norm = initialize_object(kwargs["len_norm_type"]) search_args = {} if kwargs.get("max_len", None) is not None: search_args["max_len"] = kwargs["max_len"] self.max_len = kwargs.get("max_len", 50) if kwargs.get("beam", None) is None: self.search_strategy = GreedySearch(**search_args) else: search_args["beam_size"] = kwargs.get("beam", 1) search_args["len_norm"] = len_norm # self.search_strategy = TransformerBeamSearch(**search_args) self.report_path = kwargs.get("report_path", None) self.report_type = kwargs.get("report_type", None)
def test_greedy_vs_beam(self):
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, BeamSearch(beam_size=1),
                                       forced_trg_ids=self.trg_data[0])
  output_score1 = outputs[0].score
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.src_data[0], 0, GreedySearch(),
                                       forced_trg_ids=self.trg_data[0])
  output_score2 = outputs[0].score
  self.assertAlmostEqual(output_score1, output_score2)