Exemplo n.º 1
0
 def testTrainLanguageModel(self):
     """Smoke-test a short LanguageModel training run through Runner.

     Writes a three-line numeric training corpus and a vocabulary made of the
     strings "0".."9" into the test temp dir. It then trains a
     SelfAttentionDecoder-based LanguageModel with Adam for max_step=2.
     There are no explicit assertions: the test passes if train() returns
     without raising.
     """
     tmp_dir = self.get_temp_dir()
     train_lines = ["1 2 3 4", "5 6 7 8 9", "3 2"]
     vocab_tokens = [str(digit) for digit in range(10)]
     features_path = test_util.make_data_file(
         os.path.join(tmp_dir, "src.txt"), train_lines)
     vocab_path = test_util.make_vocab(
         os.path.join(tmp_dir, "vocab.txt"), vocab_tokens)

     data_section = {"train_features_file": features_path, "vocabulary": vocab_path}
     params_section = {"learning_rate": 0.0005, "optimizer": "Adam"}
     train_section = {"batch_size": 10, "max_step": 2}
     config = {"data": data_section, "params": params_section, "train": train_section}

     # First positional argument 2 is presumably the number of decoder layers.
     decoder = decoders.SelfAttentionDecoder(2, num_units=32, ffn_inner_dim=32)
     model = models.LanguageModel(decoder, embedding_size=16, reuse_embedding=False)
     Runner(model, config).train()
Exemplo n.º 2
0
  def testUpdateVocab(self):
    """Check that update_vocab() keeps translations stable under a reordered vocab.

    The target vocabulary of the pretrained transliteration model is rewritten
    as follows: the first three lines (special tokens) stay in place and the
    remaining lines are reversed. The checkpoint is updated into a new
    directory. Inference is then run with the new vocabulary, and its first
    prediction must still equal the expected transliteration.
    """
    adam_config = {"params": {"learning_rate": 0.0005, "optimizer": "Adam"}}
    runner = self._getTransliterationRunner(adam_config)

    # Keep the first three (special) tokens; reverse the order of the rest.
    new_en_vocab = os.path.join(self.get_temp_dir(), "en.vocab.new")
    vocab_in_path = os.path.join(runner._config["model_dir"], "en.vocab")
    with open(vocab_in_path) as src_vocab, open(new_en_vocab, "w") as dst_vocab:
      lines = src_vocab.readlines()
      dst_vocab.writelines(lines[:3])
      dst_vocab.writelines(reversed(lines[3:]))

    output_dir = os.path.join(self.get_temp_dir(), "updated_vocab")
    returned_dir = runner.update_vocab(output_dir, tgt_vocab=new_en_vocab)
    self.assertEqual(returned_dir, output_dir)

    # Point a fresh Runner at the updated checkpoint and new target vocab;
    # the translation is expected to be unchanged.
    updated_config = copy.deepcopy(runner._config)
    updated_config["model_dir"] = output_dir
    updated_config["data"]["target_vocabulary"] = new_en_vocab
    runner = Runner(runner._model, updated_config)
    ar_file, _ = self._makeTransliterationData()
    en_file = os.path.join(self.get_temp_dir(), "output.txt")
    runner.infer(ar_file, predictions_file=en_file)
    with open(en_file) as predictions:
      first_prediction = next(predictions).strip()
    self.assertEqual(first_prediction, "a t z m o n")
Exemplo n.º 3
0
 def _getTransliterationRunner(self, base_config=None, model_version="v2"):
   """Build a Runner around a temp-dir copy of the transliteration test model.

   The directory test_data/transliteration-aren-v2/<model_version> is copied
   to <temp dir>/model. This is presumably done so tests can write into the
   model dir without touching the shared test data.

   NOTE(review): shutil.copytree raises if the destination already exists.
   This helper therefore appears to be usable at most once per test; confirm
   this before calling it twice.

   Args:
     base_config: Optional config dict combined with the default config
       (model_dir plus ar/en vocabulary paths) via
       misc.merge_dict(default_config, base_config).
     model_version: Subdirectory of transliteration-aren-v2 to copy.

   Returns:
     A Runner wrapping the model returned by load_model on the copied dir.
   """
   model_dir = os.path.join(self.get_temp_dir(), "model")
   source_dir = os.path.join(test_data, "transliteration-aren-v2", model_version)
   shutil.copytree(source_dir, model_dir)
   config = {
       "model_dir": model_dir,
       "data": {
           "source_vocabulary": os.path.join(model_dir, "ar.vocab"),
           "target_vocabulary": os.path.join(model_dir, "en.vocab"),
       },
   }
   if base_config is not None:
     config = misc.merge_dict(config, base_config)
   return Runner(load_model(model_dir), config)