def setUp(self):
    """Reset flags to a tiny fixed configuration and record scratch paths.

    On the first call the transformer flags are defined and parsed, and a
    snapshot of their values is stored on ``TransformerTaskTest.local_flags``.
    Every later call restores that snapshot before the overrides below are
    applied.
    """
    scratch_dir = self.get_temp_dir()

    if TransformerTaskTest.local_flags is not None:
        flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
    else:
        misc.define_transformer_flags()
        # Loads flags, array cannot be blank.
        flags.FLAGS(['foo'])
        TransformerTaskTest.local_flags = flagsaver.save_flag_values()

    # Flag overrides, applied in the listed order.
    flag_overrides = (
        ('model_dir', os.path.join(scratch_dir, FIXED_TIMESTAMP)),
        ('param_set', 'tiny'),
        ('use_synthetic_data', True),
        ('steps_between_evals', 1),
        ('train_steps', 2),
        ('validation_steps', 1),
        ('batch_size', 8),
        ('max_length', 1),
        ('num_gpus', 1),
        ('distribution_strategy', 'off'),
        ('dtype', 'fp32'),
    )
    for flag_name, flag_value in flag_overrides:
        setattr(FLAGS, flag_name, flag_value)

    self.model_dir = FLAGS.model_dir
    self.temp_dir = scratch_dir
    self.vocab_file = os.path.join(scratch_dir, 'vocab')
    model_params = misc.get_model_params(FLAGS.param_set, 0)
    self.vocab_size = model_params['vocab_size']
    self.bleu_source = os.path.join(scratch_dir, 'bleu_source')
    self.bleu_ref = os.path.join(scratch_dir, 'bleu_ref')
    # Record the global mixed-precision policy active at this point.
    # NOTE(review): presumably restored in tearDown -- confirm.
    mixed_precision = tf.compat.v2.keras.mixed_precision.experimental
    self.orig_policy = mixed_precision.global_policy()
def main(_):
    """Build a ``TransformerTask`` from the flags and run the selected mode.

    Args:
      _: Unused positional argument passed by ``app.run``.

    Raises:
      ValueError: If ``flags.FLAGS.mode`` is not "train", "predict" or "eval".
    """
    flags_obj = flags.FLAGS
    if flags_obj.enable_mlir_bridge:
        tf.config.experimental.enable_mlir_bridge()
    task = TransformerTask(flags_obj)

    # Execute flag override logic for better model performance
    if flags_obj.tf_gpu_thread_mode:
        keras_utils.set_gpu_thread_mode_and_count(
            per_gpu_thread_count=flags_obj.per_gpu_thread_count,
            gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
            num_gpus=flags_obj.num_gpus,
            datasets_num_private_threads=flags_obj.datasets_num_private_threads)

    # Map each supported mode to the task method that implements it.
    mode = flags_obj.mode
    mode_handlers = {
        "train": task.train,
        "predict": task.predict,
        "eval": task.eval,
    }
    run_mode = mode_handlers.get(mode)
    if run_mode is None:
        raise ValueError("Invalid mode {}".format(mode))
    run_mode()


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    misc.define_transformer_flags()
    app.run(main)