Example 1
def setUp(self):
  """Sets up tiny-model flags, synthetic data, and temp paths for each test."""
  temp_dir = self.get_temp_dir()
  if TransformerTaskTest.local_flags is None:
    misc.define_transformer_flags()
    # Parse the flags once; the argv list passed to FLAGS() must not be empty.
    flags.FLAGS(['foo'])
    TransformerTaskTest.local_flags = flagsaver.save_flag_values()
  else:
    flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
  FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
  FLAGS.param_set = 'tiny'
  FLAGS.use_synthetic_data = True
  FLAGS.steps_between_evals = 1
  FLAGS.train_steps = 2
  FLAGS.validation_steps = 1
  FLAGS.batch_size = 8
  FLAGS.max_length = 1
  FLAGS.num_gpus = 1
  FLAGS.distribution_strategy = 'off'
  FLAGS.dtype = 'fp32'
  self.model_dir = FLAGS.model_dir
  self.temp_dir = temp_dir
  self.vocab_file = os.path.join(temp_dir, 'vocab')
  self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
  self.bleu_source = os.path.join(temp_dir, 'bleu_source')
  self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
  self.orig_policy = (
      tf.compat.v2.keras.mixed_precision.experimental.global_policy())
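
The save/restore calls above use absl's flagsaver utility, which snapshots the current flag values so each test starts from the same flag state. A minimal standalone sketch of that pattern (the 'steps' flag name and dummy program name are illustrative, not part of the Transformer code):

from absl import flags
from absl.testing import flagsaver

flags.DEFINE_integer('steps', 10, 'Illustrative flag.')
FLAGS = flags.FLAGS
FLAGS(['prog'])  # Parse once; the argv list must at least contain a program name.

saved = flagsaver.save_flag_values()   # Snapshot every flag's current value.
FLAGS.steps = 2                        # Mutate flags for one specific test.
flagsaver.restore_flag_values(saved)   # Undo the mutation afterwards.
assert FLAGS.steps == 10
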
Example 2

def main(_):
  flags_obj = flags.FLAGS
  if flags_obj.enable_mlir_bridge:
    tf.config.experimental.enable_mlir_bridge()
  task = TransformerTask(flags_obj)

  # Execute flag override logic for better model performance
  if flags_obj.tf_gpu_thread_mode:
    keras_utils.set_gpu_thread_mode_and_count(
        per_gpu_thread_count=flags_obj.per_gpu_thread_count,
        gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
        num_gpus=flags_obj.num_gpus,
        datasets_num_private_threads=flags_obj.datasets_num_private_threads)

  if flags_obj.mode == "train":
    task.train()
  elif flags_obj.mode == "predict":
    task.predict()
  elif flags_obj.mode == "eval":
    task.eval()
  else:
    raise ValueError("Invalid mode {}".format(flags_obj.mode))


if __name__ == "__main__":
  logging.set_verbosity(logging.INFO)
  misc.define_transformer_flags()
  app.run(main)
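
For reference, the same entry point can be exercised with an explicit argv rather than the real command line. A hedged sketch, assuming main and misc.define_transformer_flags from the example above are in scope and using purely illustrative flag values:

# Sketch only: app.run accepts an argv list, which it parses through absl flags
# before calling main(); by default it would use sys.argv instead.
misc.define_transformer_flags()
app.run(main, argv=[
    'transformer_main.py',
    '--mode=train',
    '--param_set=tiny',
    '--train_steps=2',
    '--use_synthetic_data=true',
])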