def testTrainLanguageModel(self):
    src = test_util.make_data_file(
        os.path.join(self.get_temp_dir(), "src.txt"),
        ["1 2 3 4", "5 6 7 8 9", "3 2"])
    vocab = test_util.make_vocab(
        os.path.join(self.get_temp_dir(), "vocab.txt"),
        list(map(str, range(10))))
    config = {
        "data": {
            "train_features_file": src,
            "vocabulary": vocab,
        },
        "params": {
            "learning_rate": 0.0005,
            "optimizer": "Adam",
        },
        "train": {
            "batch_size": 10,
            "max_step": 2,
        },
    }
    model = models.LanguageModel(
        decoders.SelfAttentionDecoder(2, num_units=32, ffn_inner_dim=32),
        embedding_size=16,
        reuse_embedding=False)
    runner = Runner(model, config)
    runner.train()
def testUpdateVocab(self):
    config = {
        "params": {
            "learning_rate": 0.0005,
            "optimizer": "Adam",
        }
    }
    runner = self._getTransliterationRunner(config)

    # Reverse order of non-special tokens.
    new_en_vocab = os.path.join(self.get_temp_dir(), "en.vocab.new")
    with open(os.path.join(runner._config["model_dir"], "en.vocab")) as en_vocab, \
         open(new_en_vocab, "w") as new_vocab:
        tokens = en_vocab.readlines()
        for token in tokens[:3]:
            new_vocab.write(token)
        for token in reversed(tokens[3:]):
            new_vocab.write(token)

    output_dir = os.path.join(self.get_temp_dir(), "updated_vocab")
    self.assertEqual(
        runner.update_vocab(output_dir, tgt_vocab=new_en_vocab), output_dir)

    # Check that the translation is unchanged.
    new_config = copy.deepcopy(runner._config)
    new_config["model_dir"] = output_dir
    new_config["data"]["target_vocabulary"] = new_en_vocab
    runner = Runner(runner._model, new_config)
    ar_file, _ = self._makeTransliterationData()
    en_file = os.path.join(self.get_temp_dir(), "output.txt")
    runner.infer(ar_file, predictions_file=en_file)
    with open(en_file) as f:
        self.assertEqual(next(f).strip(), "a t z m o n")
def _getTransliterationRunner(self, base_config=None, model_version="v2"):
    model_dir = os.path.join(self.get_temp_dir(), "model")
    shutil.copytree(
        os.path.join(test_data, "transliteration-aren-v2", model_version),
        model_dir)
    config = {}
    config["model_dir"] = model_dir
    config["data"] = {
        "source_vocabulary": os.path.join(model_dir, "ar.vocab"),
        "target_vocabulary": os.path.join(model_dir, "en.vocab"),
    }
    if base_config is not None:
        config = misc.merge_dict(config, base_config)
    model = load_model(model_dir)
    runner = Runner(model, config)
    return runner