def setUp(self):
    """Prepare flags and scratch paths for a tiny, synthetic-data run.

    The transformer flags are defined and parsed only on the first call; the
    resulting values are snapshotted on `TransformerTaskTest.local_flags`.
    Every later call restores that snapshot before the overrides below are
    applied, so the overrides always start from the same baseline.
    """
    scratch_dir = self.get_temp_dir()

    if TransformerTaskTest.local_flags is not None:
        # A snapshot already exists: roll the flags back to it.
        flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
    else:
        misc.define_transformer_flags()
        # Loads flags, array cannot be blank.
        flags.FLAGS(['foo'])
        TransformerTaskTest.local_flags = flagsaver.save_flag_values()

    # Flag overrides, applied in the listed order.
    overrides = (
        ('model_dir', os.path.join(scratch_dir, FIXED_TIMESTAMP)),
        ('param_set', 'tiny'),
        ('use_synthetic_data', True),
        ('steps_between_evals', 1),
        ('train_steps', 2),
        ('validation_steps', 1),
        ('batch_size', 8),
        ('num_gpus', 1),
        ('distribution_strategy', 'off'),
        ('dtype', 'fp32'),
    )
    for flag_name, flag_value in overrides:
        setattr(FLAGS, flag_name, flag_value)

    # Paths and parameters exposed on the test instance.
    self.temp_dir = scratch_dir
    self.model_dir = FLAGS.model_dir
    self.vocab_file = os.path.join(scratch_dir, 'vocab')
    self.bleu_source = os.path.join(scratch_dir, 'bleu_source')
    self.bleu_ref = os.path.join(scratch_dir, 'bleu_ref')
    model_params = misc.get_model_params(FLAGS.param_set, 0)
    self.vocab_size = model_params['vocab_size']

    # Record the global mixed-precision policy in effect at setup time.
    # NOTE(review): presumably restored after the test -- confirm in tearDown.
    self.orig_policy = (
        tf.keras.mixed_precision.experimental.global_policy())
return opt def _ensure_dir(log_dir): """Makes log dir if not existed.""" if not tf.io.gfile.exists(log_dir): tf.io.gfile.makedirs(log_dir) def main(_): flags_obj = flags.FLAGS with logger.benchmark_context(flags_obj): task = TransformerTask(flags_obj) if flags_obj.mode == "train": task.train() elif flags_obj.mode == "predict": task.predict() elif flags_obj.mode == "eval": task.eval() else: raise ValueError("Invalid mode {}".format(flags_obj.mode)) if __name__ == "__main__": tf.compat.v1.enable_v2_behavior() logging.set_verbosity(logging.INFO) misc.define_transformer_flags() app.run(main)