def __init__(self, flags_obj): """Init function of TransformerMain. Args: flags_obj: Object containing parsed flag values, i.e., FLAGS. """ self.flags_obj = flags_obj self.predict_model = None # Add flag-defined parameters to params object num_gpus = flags_core.get_num_gpus(flags_obj) self.distribution_strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=flags_core.get_num_gpus(flags_obj)) self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus) params["num_gpus"] = num_gpus params["data_dir"] = flags_obj.data_dir params["model_dir"] = flags_obj.model_dir params["static_batch"] = flags_obj.static_batch params["max_length"] = flags_obj.max_length params["num_parallel_calls"] = ( flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) params["use_synthetic_data"] = flags_obj.use_synthetic_data params["batch_size"] = flags_obj.batch_size or params["default_batch_size"] params["repeat_dataset"] = None
def setUp(self):
  temp_dir = self.get_temp_dir()
  if TransformerTaskTest.local_flags is None:
    misc.define_transformer_flags()
    # Loads flags; the argv list cannot be empty.
    flags.FLAGS(['foo'])
    TransformerTaskTest.local_flags = flagsaver.save_flag_values()
  else:
    flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
  FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
  FLAGS.param_set = 'tiny'
  FLAGS.use_synthetic_data = True
  FLAGS.steps_between_evals = 1
  FLAGS.train_steps = 2
  FLAGS.validation_steps = 1
  FLAGS.batch_size = 8
  FLAGS.num_gpus = 1
  FLAGS.distribution_strategy = 'off'
  FLAGS.dtype = 'fp32'
  self.model_dir = FLAGS.model_dir
  self.temp_dir = temp_dir
  self.vocab_file = os.path.join(temp_dir, 'vocab')
  self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
  self.bleu_source = os.path.join(temp_dir, 'bleu_source')
  self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
  self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
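Because this setUp captures the global mixed-precision policy into self.orig_policy, a matching tearDown would normally restore it so that a test running with fp16 does not leak its policy into later tests. A minimal sketch, assuming the same tf.keras.mixed_precision.experimental API used above and the TransformerTaskTest class named in setUp:

def tearDown(self):
  # Restore the mixed-precision policy captured in setUp so subsequent
  # tests start from a clean global state.
  tf.keras.mixed_precision.experimental.set_policy(self.orig_policy)
  super(TransformerTaskTest, self).tearDown()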
def __init__(self, flags_obj): """Init function of TransformerMain. Args: flags_obj: Object containing parsed flag values, i.e., FLAGS. Raises: ValueError: if not using static batch for input data on TPU. """ self.flags_obj = flags_obj self.predict_model = None # Add flag-defined parameters to params object num_gpus = flags_core.get_num_gpus(flags_obj) self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus) params["num_gpus"] = num_gpus params["use_ctl"] = flags_obj.use_ctl params["is_tpu_pod"] = flags_obj.is_tpu_pod params["data_dir"] = flags_obj.data_dir params["model_dir"] = flags_obj.model_dir params["static_batch"] = flags_obj.static_batch params["max_length"] = flags_obj.max_length params["num_parallel_calls"] = (flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) params["use_synthetic_data"] = flags_obj.use_synthetic_data params["batch_size"] = flags_obj.batch_size or params[ "default_batch_size"] params["repeat_dataset"] = None params["dtype"] = flags_core.get_tf_dtype(flags_obj) params[ "enable_metrics_in_training"] = flags_obj.enable_metrics_in_training if params["dtype"] == tf.float16: # TODO(reedwm): It's pretty ugly to set the global policy in a constructor # like this. What if multiple instances of TransformerTask are created? # We should have a better way in the tf.keras.mixed_precision API of doing # this. policy = tf.keras.mixed_precision.experimental.Policy( "infer_float32_vars") tf.keras.mixed_precision.experimental.set_policy(policy) self.distribution_strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=num_gpus, tpu_address=flags_obj.tpu or "") if self.use_tpu: if not params["static_batch"]: raise ValueError("TPU requires static batch for input data.") else: print("Running transformer with num_gpus =", num_gpus) if self.distribution_strategy: print("For training, using distribution strategy: ", self.distribution_strategy) else: print("Not using any distribution strategy.")
def setUp(self):
  temp_dir = self.get_temp_dir()
  FLAGS.model_dir = temp_dir
  FLAGS.init_logdir_timestamp = FIXED_TIMESTAMP
  FLAGS.param_set = param_set = "tiny"
  FLAGS.use_synthetic_data = True
  FLAGS.steps_per_epoch = 1
  FLAGS.validation_steps = 1
  FLAGS.train_epochs = 1
  FLAGS.batch_size = 8
  FLAGS.init_weight_path = None
  self.cur_log_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
  self.vocab_file = os.path.join(self.cur_log_dir, "vocab")
  self.vocab_size = misc.get_model_params(param_set, 0)["vocab_size"]
  self.bleu_source = os.path.join(self.cur_log_dir, "bleu_source")
  self.bleu_ref = os.path.join(self.cur_log_dir, "bleu_ref")
  self.flags_file = os.path.join(self.cur_log_dir, "flags")
def setUp(self):
  temp_dir = self.get_temp_dir()
  FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
  FLAGS.param_set = "tiny"
  FLAGS.use_synthetic_data = True
  FLAGS.steps_between_evals = 1
  FLAGS.train_steps = 2
  FLAGS.validation_steps = 1
  FLAGS.batch_size = 8
  FLAGS.num_gpus = 1
  FLAGS.distribution_strategy = "off"
  self.model_dir = FLAGS.model_dir
  self.temp_dir = temp_dir
  self.vocab_file = os.path.join(temp_dir, "vocab")
  self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
  self.bleu_source = os.path.join(temp_dir, "bleu_source")
  self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
def setUp(self):
  temp_dir = self.get_temp_dir()
  FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
  FLAGS.param_set = "tiny"
  FLAGS.use_synthetic_data = True
  FLAGS.steps_between_evals = 1
  FLAGS.train_steps = 2
  FLAGS.validation_steps = 1
  FLAGS.batch_size = 8
  FLAGS.num_gpus = 1
  FLAGS.distribution_strategy = "off"
  FLAGS.dtype = "fp32"
  self.model_dir = FLAGS.model_dir
  self.temp_dir = temp_dir
  self.vocab_file = os.path.join(temp_dir, "vocab")
  self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
  self.bleu_source = os.path.join(temp_dir, "bleu_source")
  self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
  self.orig_policy = tf.keras.mixed_precision.experimental.global_policy()
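With the flags configured above (tiny model, synthetic data, two training steps), a typical test body constructs the task from FLAGS and drives a short run. A minimal sketch, assuming a TransformerTask class exposed by transformer_main with a train() method; the test name is illustrative:

def test_train_no_dist_strat(self):
  # Build the task from the flag values set in setUp and run the short
  # synthetic-data training loop (train_steps=2, one eval in between).
  t = transformer_main.TransformerTask(FLAGS)
  t.train()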
def __init__(self, flags_obj): """Init function of TransformerMain. Args: flags_obj: Object containing parsed flag values, i.e., FLAGS. Raises: ValueError: if not using static batch for input data on TPU. """ self.flags_obj = flags_obj self.predict_model = None # Add flag-defined parameters to params object num_gpus = flags_core.get_num_gpus(flags_obj) self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus) params["num_gpus"] = num_gpus params["use_ctl"] = flags_obj.use_ctl params["data_dir"] = flags_obj.data_dir params["model_dir"] = flags_obj.model_dir params["static_batch"] = flags_obj.static_batch params["max_length"] = flags_obj.max_length params["decode_batch_size"] = flags_obj.decode_batch_size params["decode_max_length"] = flags_obj.decode_max_length params["padded_decode"] = flags_obj.padded_decode params["num_parallel_calls"] = ( flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) params["use_synthetic_data"] = flags_obj.use_synthetic_data params["batch_size"] = flags_obj.batch_size or params["default_batch_size"] params["repeat_dataset"] = None params["dtype"] = flags_core.get_tf_dtype(flags_obj) params["enable_tensorboard"] = flags_obj.enable_tensorboard params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training params["steps_between_evals"] = flags_obj.steps_between_evals num_workers = distribution_utils.configure_cluster( flags_obj.worker_hosts, flags_obj.task_index) self.distribution_strategy = distribution_utils.get_distribution_strategy( distribution_strategy=flags_obj.distribution_strategy, num_gpus=num_gpus, num_workers=num_workers, all_reduce_alg=flags_obj.all_reduce_alg, num_packs=flags_obj.num_packs, tpu_address=flags_obj.tpu or "") if self.use_tpu: params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync if not params["static_batch"]: raise ValueError("TPU requires static batch for input data.") else: logging.info("Running transformer with num_gpus = %d", num_gpus) if self.distribution_strategy: logging.info("For training, using distribution strategy: %s", self.distribution_strategy) else: logging.info("Not using any distribution strategy.") if params["dtype"] == tf.float16: # TODO(reedwm): It's pretty ugly to set the global policy in a constructor # like this. What if multiple instances of TransformerTask are created? # We should have a better way in the tf.keras.mixed_precision API of doing # this. loss_scale = flags_core.get_loss_scale( flags_obj, default_for_fp16="dynamic") policy = tf.compat.v2.keras.mixed_precision.experimental.Policy( "mixed_float16", loss_scale=loss_scale) tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy) elif params["dtype"] == tf.bfloat16: policy = tf.compat.v2.keras.mixed_precision.experimental.Policy( "mixed_bfloat16") tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)