import os

import tensorflow as tf

# text_reader, transformer_params and token_generator_bi are assumed to be
# provided by the project's own modules, as in the original source file.


def transformer_params_base(data_dir, vocab_src_name, vocab_tgt_name):
  """A set of basic hyperparameters."""
  hparams = transformer_params()
  # Token-level encoders built from the source and target vocabulary files.
  hparams.vocabulary = {
      "inputs": text_reader.TokenTextEncoder(
          vocab_filename=os.path.join(data_dir, vocab_src_name)),
      "targets": text_reader.TokenTextEncoder(
          vocab_filename=os.path.join(data_dir, vocab_tgt_name)),
  }
  hparams.hidden_size = 512
  hparams.filter_size = 2048
  hparams.num_heads = 8
  hparams.batching_mantissa_bits = 2
  return hparams
def translation_token_generator(data_dir, tmp_dir, train_src_name, train_tgt_name,
                                vocab_src_name, vocab_tgt_name):
  """Generates token ids for parallel source/target training sentences."""
  train_src_path = os.path.join(tmp_dir, train_src_name)
  train_tgt_path = os.path.join(tmp_dir, train_tgt_name)
  token_vocab_src_dir = os.path.join(data_dir, vocab_src_name)
  token_vocab_tgt_dir = os.path.join(data_dir, vocab_tgt_name)
  # Copy the vocabulary files from tmp_dir into data_dir if they are not
  # already there, so later stages can read them from data_dir.
  if not tf.gfile.Exists(token_vocab_src_dir):
    tf.gfile.Copy(os.path.join(tmp_dir, vocab_src_name), token_vocab_src_dir)
  if not tf.gfile.Exists(token_vocab_tgt_dir):
    tf.gfile.Copy(os.path.join(tmp_dir, vocab_tgt_name), token_vocab_tgt_dir)
  token_vocab_src = text_reader.TokenTextEncoder(
      vocab_filename=token_vocab_src_dir)
  token_vocab_tgt = text_reader.TokenTextEncoder(
      vocab_filename=token_vocab_tgt_dir)
  # The trailing 1 is the id of the EOS symbol appended to each sequence.
  return token_generator_bi(train_src_path, train_tgt_path,
                            token_vocab_src, token_vocab_tgt, 1)
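# A minimal usage sketch under assumed file names (train.src, train.tgt,
# vocab.src and vocab.tgt are hypothetical); it builds the hyperparameters and
# pulls a few examples from the generator to sanity-check the data pipeline.
if __name__ == "__main__":
  hparams = transformer_params_base("data", "vocab.src", "vocab.tgt")
  print("hidden_size=%d, num_heads=%d" % (hparams.hidden_size, hparams.num_heads))
  examples = translation_token_generator(
      data_dir="data", tmp_dir="tmp",
      train_src_name="train.src", train_tgt_name="train.tgt",
      vocab_src_name="vocab.src", vocab_tgt_name="vocab.tgt")
  # Print the first few generated examples; the exact structure of each
  # example is whatever token_generator_bi yields.
  for i, example in enumerate(examples):
    print(example)
    if i >= 2:
      break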