def main(_):
  # Users should always run this script under TF 2.x.
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  elif FLAGS.strategy_type == 'multi_worker_mirror':
    workers = ['localhost:2001', 'localhost:2002']
    # The worker index is passed as the first positional argument.
    task_index = int(sys.argv[1])
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': workers
        },
        'task': {'type': 'worker', 'index': task_index}
    })
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    Context.init_context(len(workers), task_index)
    logging.info(Context.get_is_init())
    logging.info(Context.get_num_task())
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)
  run_bert(strategy, input_meta_data)

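# A minimal sketch of how the two-worker variant above might be launched
# (the script filename is illustrative, not taken from this code). Each
# worker process is started with its own index as the first positional
# argument, which main() folds into TF_CONFIG; both processes must agree on
# the same `workers` list and differ only in task.index:
#
#   python run_classifier.py 0 --strategy_type=multi_worker_mirror ...
#   python run_classifier.py 1 --strategy_type=multi_worker_mirror ...
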
def main(_):
  # Users should always run this script under TF 2.x.
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'multi_worker_mirror':
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)

  if FLAGS.mode == 'train':
    train_squad(strategy, input_meta_data)
  elif FLAGS.mode == 'predict':
    predict_squad(strategy, input_meta_data)

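# The three values of --mode accepted by main() above, as illustrative
# command lines (the script filename is assumed, not taken from this code):
#
#   python run_squad.py --mode=train --strategy_type=mirror ...
#   python run_squad.py --mode=predict --strategy_type=mirror ...
#   python run_squad.py --mode=export_only --model_export_path=<path> ...
#
# Note that export_only returns before any distribution strategy is built,
# so no accelerator is needed for export.
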
def main(unused_argv): del unused_argv if FLAGS.strategy_type == "mirror": strategy = tf.distribute.MirroredStrategy() elif FLAGS.strategy_type == "tpu": cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu) strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) else: raise ValueError( "The distribution strategy type is not supported: %s" % FLAGS.strategy_type) if strategy: logging.info("***** Number of cores used : %d", strategy.num_replicas_in_sync) train_input_fn = functools.partial( data_utils.get_classification_input_data, FLAGS.train_batch_size, FLAGS.seq_len, strategy, True, FLAGS.train_tfrecord_path) test_input_fn = functools.partial(data_utils.get_classification_input_data, FLAGS.test_batch_size, FLAGS.seq_len, strategy, False, FLAGS.test_tfrecord_path) total_training_steps = FLAGS.train_steps steps_per_loop = FLAGS.iterations eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size) eval_fn = functools.partial(run_evaluation, strategy, test_input_fn, eval_steps) optimizer, learning_rate_fn = optimization.create_optimizer( FLAGS.learning_rate, total_training_steps, FLAGS.warmup_steps, adam_epsilon=FLAGS.adam_epsilon) model_config = xlnet_config.XLNetConfig(FLAGS) run_config = xlnet_config.create_run_config(True, False, FLAGS) model_fn = functools.partial(get_classificationxlnet_model, model_config, run_config, FLAGS.n_class) input_meta_data = {} input_meta_data["d_model"] = FLAGS.d_model input_meta_data["mem_len"] = FLAGS.mem_len input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size / strategy.num_replicas_in_sync) input_meta_data["n_layer"] = FLAGS.n_layer input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate input_meta_data["n_class"] = FLAGS.n_class training_utils.train(strategy=strategy, model_fn=model_fn, input_meta_data=input_meta_data, eval_fn=eval_fn, metric_fn=get_metric_fn, train_input_fn=train_input_fn, test_input_fn=test_input_fn, init_checkpoint=FLAGS.init_checkpoint, init_from_transformerxl=FLAGS.init_from_transformerxl, total_training_steps=total_training_steps, steps_per_loop=steps_per_loop, optimizer=optimizer, learning_rate_fn=learning_rate_fn, model_dir=FLAGS.model_dir, save_steps=FLAGS.save_steps)
def main(_):
  # Users should always run this script under TF 2.x.
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)

  max_seq_length = input_meta_data['max_seq_length']
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      max_seq_length,
      FLAGS.train_batch_size,
      is_training=True)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)

  run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)

def _build_tpu_strategy(self):
  """Builds a TPUStrategy object."""
  tpu = self._strategy_config.tpu
  logging.info('Use TPU at %s', tpu if tpu is not None else '')
  cluster_resolver = tpu_lib.tpu_initialize(tpu)
  strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  return strategy

def _init_strategy(self):
  """Initializes the distribution strategy (e.g. TPU/GPU/Mirrored)."""
  if self._strategy is None:
    if self._tpu is not None:
      resolver = tpu_lib.tpu_initialize(self._tpu)
      self._strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif (self._distribution_strategy is None or
          self._distribution_strategy == 'default'):
      self._strategy = tf.distribute.get_strategy()
    elif self._distribution_strategy == 'cpu':
      self._strategy = tf.distribute.OneDeviceStrategy('/device:cpu:0')
    elif self._distribution_strategy == 'mirrored':
      self._strategy = tf.distribute.MirroredStrategy()
    else:
      raise ValueError(
          f'Invalid distribution strategy="{self._distribution_strategy}"')

def main(_):
  # Users should always run this script under TF 2.x.
  assert tf.version.VERSION.startswith('2.')

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)

  if strategy:
    logging.info('***** Number of cores used: %d',
                 strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy)

def main(_):
  # Users should always run this script under TF 2.x.
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy = None
  if FLAGS.strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == 'tpu':
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)

  run_bert(strategy, input_meta_data)

def get_distribution_strategy(distribution_strategy="default",
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
      insensitive. 'off' means not to use Distribution Strategy; 'default'
      means to choose from `MirroredStrategy`, `MultiWorkerMirroredStrategy`,
      or `OneDeviceStrategy` according to the number of GPUs and number of
      workers. 'tpu' means to use TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    num_workers: Number of workers to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.

  Returns:
    tf.distribute.DistributionStrategy object.

  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or if `num_gpus` is negative; or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  distribution_strategy = distribution_strategy.lower()

  if distribution_strategy == "off":
    if num_gpus > 1:
      raise ValueError(
          "When {} GPUs and {} workers are specified, distribution_strategy "
          "flag cannot be set to 'off'.".format(num_gpus, num_workers))
    return None

  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
    return tf.distribute.experimental.TPUStrategy(cluster_resolver)

  if distribution_strategy == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if (distribution_strategy == "one_device" or
      (distribution_strategy == "default" and num_gpus <= 1)):
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if distribution_strategy in ("mirrored", "default"):
    if num_gpus == 0:
      assert distribution_strategy == "mirrored"
      devices = ["device:CPU:0"]
    else:
      devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == "parameter_server":
    return tf.distribute.experimental.ParameterServerStrategy()

  raise ValueError("Unrecognized Distribution Strategy: %r" %
                   distribution_strategy)

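# Usage sketch for get_distribution_strategy (flag values are illustrative;
# build_model() is a hypothetical model builder, not part of this module):
def _example_mirrored_setup():
  # Single host, 4 GPUs, NCCL all-reduce with two gradient packs.
  strategy = get_distribution_strategy(
      distribution_strategy="mirrored",
      num_gpus=4,
      all_reduce_alg="nccl",
      num_packs=2)
  # Variables created inside the scope are mirrored across the 4 replicas.
  with strategy.scope():
    model = build_model()  # hypothetical
  return model
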
def main(unused_argv):
  del unused_argv

  num_hosts = 1
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    # A "MxN" topology describes M*N TPU chips with two cores each.
    topology = FLAGS.tpu_topology.split("x")
    total_num_core = 2 * int(topology[0]) * int(topology[1])
    num_hosts = total_num_core // FLAGS.num_core_per_host
  else:
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)
    logging.info("***** Number of hosts used: %d", num_hosts)

  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
      FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
      FLAGS.perm_size, FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict,
      FLAGS.bi_data, FLAGS.uncased, num_hosts)

  total_training_steps = FLAGS.train_steps
  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
  steps_per_loop = FLAGS.iterations

  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=total_training_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)

  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                               run_config)

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      test_input_fn=None,
      init_checkpoint=FLAGS.init_checkpoint,
      total_training_steps=total_training_steps,
      steps_per_epoch=steps_per_epoch,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

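# Worked example of the topology arithmetic in main() above (a sketch; it
# assumes, as the code does, that a "MxN" topology string describes M*N TPU
# chips with two cores each):
def _num_hosts(tpu_topology, num_core_per_host):
  x, y = (int(d) for d in tpu_topology.split("x"))
  return (2 * x * y) // num_core_per_host

# E.g. a 4x4 slice has 2 * 4 * 4 = 32 cores; with 8 cores per host that is
# 32 // 8 = 4 input-pipeline hosts, while a 2x2 slice (8 cores) is 1 host.
assert _num_hosts("4x4", 8) == 4
assert _num_hosts("2x2", 8) == 1
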
def main(unused_argv):
  del unused_argv

  use_remote_tpu = False
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    use_remote_tpu = True
  else:
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)

  train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     FLAGS.query_len, strategy, True,
                                     FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    FLAGS.query_len, strategy, False,
                                    FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)

  input_meta_data = {}
  input_meta_data["start_n_top"] = FLAGS.start_n_top
  input_meta_data["end_n_top"] = FLAGS.end_n_top
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["predict_dir"] = FLAGS.predict_dir
  input_meta_data["predict_file"] = FLAGS.predict_file
  input_meta_data["n_best_size"] = FLAGS.n_best_size
  input_meta_data["max_answer_length"] = FLAGS.max_answer_length
  input_meta_data["test_feature_path"] = FLAGS.test_feature_path
  input_meta_data["test_batch_size"] = FLAGS.test_batch_size
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["mem_len"] = FLAGS.mem_len
  model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                               FLAGS.start_n_top, FLAGS.end_n_top)

  logging.info("start reading pickle file...")
  with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
    eval_features = pickle.load(f)
  logging.info("finished reading pickle file...")
  input_meta_data["eval_features"] = eval_features

  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps, input_meta_data)

  with tf.device(get_primary_cpu_task(use_remote_tpu)):
    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=eval_fn,
        metric_fn=None,
        train_input_fn=train_input_fn,
        test_input_fn=test_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        total_training_steps=total_training_steps,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir)

def main(unused_argv): del unused_argv if FLAGS.strategy_type == "mirror": strategy = tf.distribute.MirroredStrategy() elif FLAGS.strategy_type == "tpu": cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu) strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) else: raise ValueError( "The distribution strategy type is not supported: %s" % FLAGS.strategy_type) if strategy: logging.info("***** Number of cores used : %d", strategy.num_replicas_in_sync) train_input_fn = functools.partial(data_utils.get_squad_input_data, FLAGS.train_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy, True, FLAGS.train_tfrecord_path) test_input_fn = functools.partial(data_utils.get_squad_input_data, FLAGS.test_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy, False, FLAGS.test_tfrecord_path) total_training_steps = FLAGS.train_steps steps_per_loop = FLAGS.iterations eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size) optimizer, learning_rate_fn = optimization.create_optimizer( FLAGS.learning_rate, total_training_steps, FLAGS.warmup_steps, adam_epsilon=FLAGS.adam_epsilon) model_config = xlnet_config.XLNetConfig(FLAGS) run_config = xlnet_config.create_run_config(True, False, FLAGS) input_meta_data = {} input_meta_data["start_n_top"] = FLAGS.start_n_top input_meta_data["end_n_top"] = FLAGS.end_n_top input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate input_meta_data["predict_dir"] = FLAGS.predict_dir input_meta_data["n_best_size"] = FLAGS.n_best_size input_meta_data["max_answer_length"] = FLAGS.max_answer_length input_meta_data["test_batch_size"] = FLAGS.test_batch_size input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size / strategy.num_replicas_in_sync) input_meta_data["mem_len"] = FLAGS.mem_len model_fn = functools.partial(get_qaxlnet_model, model_config, run_config, FLAGS.start_n_top, FLAGS.end_n_top) eval_examples = squad_utils.read_squad_examples(FLAGS.predict_file, is_training=False) if FLAGS.test_feature_path: logging.info("start reading pickle file...") with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f: eval_features = pickle.load(f) logging.info("finishing reading pickle file...") else: sp_model = spm.SentencePieceProcessor() sp_model.LoadFromSerializedProto( tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read()) spm_basename = os.path.basename(FLAGS.spiece_model_file) eval_features = squad_utils.create_eval_data( spm_basename, sp_model, eval_examples, FLAGS.max_seq_length, FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased) with tf.io.gfile.GFile(FLAGS.predict_file) as f: original_data = json.load(f)["data"] eval_fn = functools.partial(run_evaluation, strategy, test_input_fn, eval_examples, eval_features, original_data, eval_steps, input_meta_data) training_utils.train(strategy=strategy, model_fn=model_fn, input_meta_data=input_meta_data, eval_fn=eval_fn, metric_fn=None, train_input_fn=train_input_fn, init_checkpoint=FLAGS.init_checkpoint, init_from_transformerxl=FLAGS.init_from_transformerxl, total_training_steps=total_training_steps, steps_per_loop=steps_per_loop, optimizer=optimizer, learning_rate_fn=learning_rate_fn, model_dir=FLAGS.model_dir, save_steps=FLAGS.save_steps)
def main(unused_argv):
  del unused_argv

  use_remote_tpu = False
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    # Initialize TPU System.
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    use_remote_tpu = True
  else:
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)

  train_input_fn = functools.partial(data_utils.get_classification_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     strategy, True, FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    strategy, False, FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps)

  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model, model_config,
                               run_config, FLAGS.n_class)

  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["n_class"] = FLAGS.n_class
  logging.info("input_meta_data: %s", input_meta_data)

  def logits_init_fn():
    # Dummy logits with shape (batch_size_per_core, n_class) used to build
    # the initial metric/evaluation state.
    return tf.zeros(
        shape=(input_meta_data["batch_size_per_core"],
               input_meta_data["n_class"]),
        dtype=tf.float32)

  with tf.device(get_primary_cpu_task(use_remote_tpu)):
    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=eval_fn,
        metric_fn=get_metric_fn,
        logits_init_fn=logits_init_fn,
        train_input_fn=train_input_fn,
        test_input_fn=test_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        total_training_steps=total_training_steps,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=steps_per_loop,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir)

def main(unused_argv):
  del unused_argv

  num_hosts = 1
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    topology = FLAGS.tpu_topology.split("x")
    total_num_core = 2 * int(topology[0]) * int(topology[1])
    num_hosts = total_num_core // FLAGS.num_core_per_host
  else:
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy:
    logging.info("***** Number of cores used: %d",
                 strategy.num_replicas_in_sync)
    logging.info("***** Number of hosts used: %d", num_hosts)

  online_masking_config = data_utils.OnlineMaskingConfig(
      sample_strategy=FLAGS.sample_strategy,
      max_num_tokens=FLAGS.max_num_tokens,
      min_num_tokens=FLAGS.min_num_tokens,
      max_num_words=FLAGS.max_num_words,
      min_num_words=FLAGS.min_num_words)

  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
      FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
      FLAGS.perm_size, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
      online_masking_config, num_hosts)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations

  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=total_training_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)

  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                               run_config)

  model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

  # Export transformer-xl model checkpoint to be used in finetuning.
  checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
  saved_path = checkpoint.save(
      os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
               saved_path)
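
# A minimal sketch of restoring the checkpoint exported above during
# finetuning. It assumes the finetuning model exposes the same
# `transformerxl_model` attribute; `restore_transformer_xl` itself is
# hypothetical, not part of this module.
import os

import tensorflow as tf


def restore_transformer_xl(finetune_model, model_dir):
  # Mirror the save-side structure: the submodule was saved under the
  # `transformer_xl` key, so it must be restored under the same key.
  checkpoint = tf.train.Checkpoint(
      transformer_xl=finetune_model.transformerxl_model)
  latest = tf.train.latest_checkpoint(os.path.join(model_dir, "pretrained"))
  checkpoint.restore(latest).assert_nontrivial_match()
  return finetune_model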