Example no. 1
0
  def __init__(self, flags_obj):
    """Init function of TransformerMain.

    Args:
      flags_obj: Object containing parsed flag values, i.e., FLAGS.
    """
    self.flags_obj = flags_obj
    self.predict_model = None

    # Add flag-defined parameters to params object
    num_gpus = flags_core.get_num_gpus(flags_obj)
    self.distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus)

    self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

    params["num_gpus"] = num_gpus
    params["data_dir"] = flags_obj.data_dir
    params["model_dir"] = flags_obj.model_dir
    params["static_batch"] = flags_obj.static_batch
    params["max_length"] = flags_obj.max_length
    params["num_parallel_calls"] = (
        flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

    params["use_synthetic_data"] = flags_obj.use_synthetic_data
    params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
    params["repeat_dataset"] = None
Example no. 2
0
 def setUp(self):
     temp_dir = self.get_temp_dir()
     if TransformerTaskTest.local_flags is None:
         misc.define_transformer_flags()
         # Loads flags; the argv list cannot be empty (the first entry is
         # the program name).
         flags.FLAGS(['foo'])
         TransformerTaskTest.local_flags = flagsaver.save_flag_values()
     else:
         flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
     FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
     FLAGS.param_set = 'tiny'
     FLAGS.use_synthetic_data = True
     FLAGS.steps_between_evals = 1
     FLAGS.train_steps = 2
     FLAGS.validation_steps = 1
     FLAGS.batch_size = 8
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'off'
     FLAGS.dtype = 'fp32'
     self.model_dir = FLAGS.model_dir
     self.temp_dir = temp_dir
     self.vocab_file = os.path.join(temp_dir, 'vocab')
     self.vocab_size = misc.get_model_params(FLAGS.param_set,
                                             0)['vocab_size']
     self.bleu_source = os.path.join(temp_dir, 'bleu_source')
     self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
     self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.

        Raises:
          ValueError: if not using static batch for input data on TPU.
        """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set,
                                                     num_gpus)

        params["num_gpus"] = num_gpus
        params["use_ctl"] = flags_obj.use_ctl
        params["is_tpu_pod"] = flags_obj.is_tpu_pod
        params["data_dir"] = flags_obj.data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_length"] = flags_obj.max_length
        params["num_parallel_calls"] = (flags_obj.num_parallel_calls
                                        or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size or params[
            "default_batch_size"]
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params[
            "enable_metrics_in_training"] = flags_obj.enable_metrics_in_training

        if params["dtype"] == tf.float16:
            # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
            # like this. What if multiple instances of TransformerTask are created?
            # We should have a better way in the tf.keras.mixed_precision API of doing
            # this.
            policy = tf.keras.mixed_precision.experimental.Policy(
                "infer_float32_vars")
            tf.keras.mixed_precision.experimental.set_policy(policy)

        self.distribution_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            tpu_address=flags_obj.tpu or "")
        if self.use_tpu:
            if not params["static_batch"]:
                raise ValueError("TPU requires static batch for input data.")
        else:
            print("Running transformer with num_gpus =", num_gpus)

        if self.distribution_strategy:
            print("For training, using distribution strategy: ",
                  self.distribution_strategy)
        else:
            print("Not using any distribution strategy.")
 def setUp(self):
     temp_dir = self.get_temp_dir()
     FLAGS.model_dir = temp_dir
     FLAGS.init_logdir_timestamp = FIXED_TIMESTAMP
     FLAGS.param_set = param_set = "tiny"
     FLAGS.use_synthetic_data = True
     FLAGS.steps_per_epoch = 1
     FLAGS.validation_steps = 1
     FLAGS.train_epochs = 1
     FLAGS.batch_size = 8
     FLAGS.init_weight_path = None
     self.cur_log_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
     self.vocab_file = os.path.join(self.cur_log_dir, "vocab")
     self.vocab_size = misc.get_model_params(param_set, 0)["vocab_size"]
     self.bleu_source = os.path.join(self.cur_log_dir, "bleu_source")
     self.bleu_ref = os.path.join(self.cur_log_dir, "bleu_ref")
     self.flags_file = os.path.join(self.cur_log_dir, "flags")
Example no. 5
0
 def setUp(self):
   temp_dir = self.get_temp_dir()
   FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
   FLAGS.param_set = "tiny"
   FLAGS.use_synthetic_data = True
   FLAGS.steps_between_evals = 1
   FLAGS.train_steps = 2
   FLAGS.validation_steps = 1
   FLAGS.batch_size = 8
   FLAGS.num_gpus = 1
   FLAGS.distribution_strategy = "off"
   self.model_dir = FLAGS.model_dir
   self.temp_dir = temp_dir
   self.vocab_file = os.path.join(temp_dir, "vocab")
   self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
   self.bleu_source = os.path.join(temp_dir, "bleu_source")
   self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
 def setUp(self):
   temp_dir = self.get_temp_dir()
   FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
   FLAGS.param_set = "tiny"
   FLAGS.use_synthetic_data = True
   FLAGS.steps_between_evals = 1
   FLAGS.train_steps = 2
   FLAGS.validation_steps = 1
   FLAGS.batch_size = 8
   FLAGS.num_gpus = 1
   FLAGS.distribution_strategy = "off"
   FLAGS.dtype = "fp32"
   self.model_dir = FLAGS.model_dir
   self.temp_dir = temp_dir
   self.vocab_file = os.path.join(temp_dir, "vocab")
   self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
   self.bleu_source = os.path.join(temp_dir, "bleu_source")
   self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
   self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
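The setUp variants above snapshot the global mixed-precision policy into self.orig_policy before a test runs, but no matching tearDown appears in these excerpts. A minimal sketch that restores the snapshot, assuming the same tf.keras.mixed_precision.experimental API used above:

  def tearDown(self):
    # Restore the policy captured in setUp so a constructor that called
    # set_policy(...) during the test does not leak global state into later
    # tests. (Sketch only; not part of the excerpts.)
    tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
    super(TransformerTaskTest, self).tearDown()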
Example no. 7
0
  def __init__(self, flags_obj):
    """Init function of TransformerMain.

    Args:
      flags_obj: Object containing parsed flag values, i.e., FLAGS.

    Raises:
      ValueError: if not using static batch for input data on TPU.
    """
    self.flags_obj = flags_obj
    self.predict_model = None

    # Add flag-defined parameters to params object
    num_gpus = flags_core.get_num_gpus(flags_obj)
    self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

    params["num_gpus"] = num_gpus
    params["use_ctl"] = flags_obj.use_ctl
    params["data_dir"] = flags_obj.data_dir
    params["model_dir"] = flags_obj.model_dir
    params["static_batch"] = flags_obj.static_batch
    params["max_length"] = flags_obj.max_length
    params["decode_batch_size"] = flags_obj.decode_batch_size
    params["decode_max_length"] = flags_obj.decode_max_length
    params["padded_decode"] = flags_obj.padded_decode
    params["num_parallel_calls"] = (
        flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

    params["use_synthetic_data"] = flags_obj.use_synthetic_data
    params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
    params["repeat_dataset"] = None
    params["dtype"] = flags_core.get_tf_dtype(flags_obj)
    params["enable_tensorboard"] = flags_obj.enable_tensorboard
    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
    params["steps_between_evals"] = flags_obj.steps_between_evals

    num_workers = distribution_utils.configure_cluster(
        flags_obj.worker_hosts, flags_obj.task_index)

    self.distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus,
        num_workers=num_workers,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu or "")
    if self.use_tpu:
      params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
      if not params["static_batch"]:
        raise ValueError("TPU requires static batch for input data.")
    else:
      logging.info("Running transformer with num_gpus = %d", num_gpus)

    if self.distribution_strategy:
      logging.info("For training, using distribution strategy: %s",
                   self.distribution_strategy)
    else:
      logging.info("Not using any distribution strategy.")

    if params["dtype"] == tf.float16:
      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
      # like this. What if multiple instances of TransformerTask are created?
      # We should have a better way in the tf.keras.mixed_precision API of doing
      # this.
      loss_scale = flags_core.get_loss_scale(
          flags_obj, default_for_fp16="dynamic")
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          "mixed_float16", loss_scale=loss_scale)
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    elif params["dtype"] == tf.bfloat16:
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          "mixed_bfloat16")
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
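Both constructor variants above branch on self.use_tpu, which is not defined anywhere in these excerpts. A plausible sketch, assuming it only checks whether the resolved distribution strategy targets a TPU (the property name and the TPUStrategy check are assumptions matching the tf.compat.v2-era API used above):

  @property
  def use_tpu(self):
    # True when the strategy built in __init__ is a TPU strategy; the
    # constructors use this to require static batching for TPU input data.
    if self.distribution_strategy:
      return isinstance(self.distribution_strategy,
                        tf.distribute.experimental.TPUStrategy)
    return False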