Example #1
  def setUp(self):
    temp_dir = self.get_temp_dir()
    if TransformerTaskTest.local_flags is None:
      misc.define_transformer_flags()
      # Parse flags; the argv list cannot be empty (argv[0] is the program name).
      flags.FLAGS(['foo'])
      TransformerTaskTest.local_flags = flagsaver.save_flag_values()
    else:
      flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
    FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
    FLAGS.param_set = 'tiny'
    FLAGS.use_synthetic_data = True
    FLAGS.steps_between_evals = 1
    FLAGS.train_steps = 2
    FLAGS.validation_steps = 1
    FLAGS.batch_size = 8
    FLAGS.max_length = 1
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'off'
    FLAGS.dtype = 'fp32'
    self.model_dir = FLAGS.model_dir
    self.temp_dir = temp_dir
    self.vocab_file = os.path.join(temp_dir, 'vocab')
    self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
    self.bleu_source = os.path.join(temp_dir, 'bleu_source')
    self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
    self.orig_policy = (
        tf.compat.v2.keras.mixed_precision.experimental.global_policy())
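
The flagsaver save/restore pair in setUp keeps flag mutations from leaking between test cases. Below is a minimal, self-contained sketch of that pattern, assuming only absl-py; the flag name demo_flag is hypothetical:

from absl import flags
from absl.testing import flagsaver

flags.DEFINE_string('demo_flag', 'default', 'Hypothetical flag for illustration.')
flags.FLAGS(['prog'])  # Parse flags; argv[0] (a program name) must be present.

saved = flagsaver.save_flag_values()  # Snapshot every flag's current value.
flags.FLAGS.demo_flag = 'changed'
flagsaver.restore_flag_values(saved)  # Roll all flags back to the snapshot.
assert flags.FLAGS.demo_flag == 'default'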

Example #2
  def __init__(self, flags_obj):
    """Init function of TransformerMain.

    Args:
      flags_obj: Object containing parsed flag values, i.e., FLAGS.

    Raises:
      ValueError: if not using static batch for input data on TPU.
    """
    self.flags_obj = flags_obj
    self.predict_model = None

    # Add flag-defined parameters to params object
    num_gpus = flags_core.get_num_gpus(flags_obj)
    self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)

    params["num_gpus"] = num_gpus
    params["use_ctl"] = flags_obj.use_ctl
    params["data_dir"] = flags_obj.data_dir
    params["model_dir"] = flags_obj.model_dir
    params["static_batch"] = flags_obj.static_batch
    params["max_length"] = flags_obj.max_length
    params["decode_batch_size"] = flags_obj.decode_batch_size
    params["decode_max_length"] = flags_obj.decode_max_length
    params["padded_decode"] = flags_obj.padded_decode
    params["num_parallel_calls"] = (
        flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)

    params["use_synthetic_data"] = flags_obj.use_synthetic_data
    params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
    params["repeat_dataset"] = None
    params["dtype"] = flags_core.get_tf_dtype(flags_obj)
    params["enable_tensorboard"] = flags_obj.enable_tensorboard
    params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
    params["steps_between_evals"] = flags_obj.steps_between_evals
    params["enable_checkpointing"] = flags_obj.enable_checkpointing

    self.distribution_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=num_gpus,
        all_reduce_alg=flags_obj.all_reduce_alg,
        num_packs=flags_obj.num_packs,
        tpu_address=flags_obj.tpu or "")
    if self.use_tpu:  # use_tpu is a property defined elsewhere on the class.
      params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
      if not params["static_batch"]:
        raise ValueError("TPU requires static batch for input data.")
    else:
      logging.info("Running transformer with num_gpus = %d", num_gpus)

    if self.distribution_strategy:
      logging.info("For training, using distribution strategy: %s",
                   self.distribution_strategy)
    else:
      logging.info("Not using any distribution strategy.")

    performance.set_mixed_precision_policy(
        params["dtype"],
        flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))
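
A hedged sketch of how a constructor like this is driven end to end, assuming the official TensorFlow models repository layout (module paths vary across releases):

from absl import app, flags
from official.nlp.transformer import misc, transformer_main

def main(_):
  # FLAGS has been parsed by app.run() at this point; it serves as flags_obj.
  task = transformer_main.TransformerTask(flags.FLAGS)
  task.train()

if __name__ == '__main__':
  misc.define_transformer_flags()
  app.run(main)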
Example #3
    def __init__(self, flags_obj):
        """Init function of TransformerMain.

        Args:
          flags_obj: Object containing parsed flag values, i.e., FLAGS.

        Raises:
          ValueError: if not using static batch for input data on TPU.
        """
        self.flags_obj = flags_obj
        self.predict_model = None

        # Add flag-defined parameters to params object
        num_gpus = flags_core.get_num_gpus(flags_obj)
        self.params = params = misc.get_model_params(flags_obj.param_set,
                                                     num_gpus)

        params["num_gpus"] = num_gpus
        params["use_ctl"] = flags_obj.use_ctl
        params["data_dir"] = flags_obj.data_dir
        params["model_dir"] = flags_obj.model_dir
        params["static_batch"] = flags_obj.static_batch
        params["max_length"] = flags_obj.max_length
        params["decode_batch_size"] = flags_obj.decode_batch_size
        params["decode_max_length"] = flags_obj.decode_max_length
        params["padded_decode"] = flags_obj.padded_decode
        params["num_parallel_calls"] = (flags_obj.num_parallel_calls
                                        or tf.data.experimental.AUTOTUNE)

        params["use_synthetic_data"] = flags_obj.use_synthetic_data
        params["batch_size"] = flags_obj.batch_size or params[
            "default_batch_size"]
        params["repeat_dataset"] = None
        params["dtype"] = flags_core.get_tf_dtype(flags_obj)
        params["enable_tensorboard"] = flags_obj.enable_tensorboard
        params["enable_metrics_in_training"] = (
            flags_obj.enable_metrics_in_training)
        params["steps_between_evals"] = flags_obj.steps_between_evals

        self.distribution_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy=flags_obj.distribution_strategy,
            num_gpus=num_gpus,
            all_reduce_alg=flags_obj.all_reduce_alg,
            num_packs=flags_obj.num_packs,
            tpu_address=flags_obj.tpu or "")
        if self.use_tpu:  # use_tpu is a property defined elsewhere on the class.
            params["num_replicas"] = (
                self.distribution_strategy.num_replicas_in_sync)
            if not params["static_batch"]:
                raise ValueError("TPU requires static batch for input data.")
        else:
            logging.info("Running transformer with num_gpus = %d", num_gpus)

        if self.distribution_strategy:
            logging.info("For training, using distribution strategy: %s",
                         self.distribution_strategy)
        else:
            logging.info("Not using any distribution strategy.")

        if params["dtype"] == tf.float16:
            # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
            # like this. What if multiple instances of TransformerTask are created?
            # We should have a better way in the tf.keras.mixed_precision API of doing
            # this.
            loss_scale = flags_core.get_loss_scale(flags_obj,
                                                   default_for_fp16="dynamic")
            policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
                "mixed_float16", loss_scale=loss_scale)
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

        elif params["dtype"] == tf.bfloat16:
            policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
                "mixed_bfloat16")
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
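
The experimental mixed-precision API used above was stabilized in TF 2.4. A hedged equivalent of the fp16 branch, with loss scaling moved onto the optimizer (the Adam optimizer here is only a stand-in):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
# LossScaleOptimizer applies dynamic loss scaling by default, replacing the
# loss_scale argument of the old experimental Policy.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam())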