Esempio n. 1
0
def main(_):
    """Load the input metadata, build a distribution strategy and run BERT.

    The strategy is selected by `FLAGS.strategy_type`, which must be one of
    'mirror', 'tpu' or 'multi_worker_mirror'. The multi-worker branch uses a
    hard-coded two-worker localhost cluster and takes this process's task
    index from `sys.argv[1]`.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    # Fall back to a fixed scratch directory when no model dir was given.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ('mirror', 'tpu', 'multi_worker_mirror'):
        raise ValueError('The distribution strategy type is not supported: %s' %
                         FLAGS.strategy_type)

    if strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif strategy_type == 'tpu':
        # Initialize TPU System.
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    else:
        # 'multi_worker_mirror': the cluster layout is hard-coded here.
        worker_hosts = ["localhost:2001", "localhost:2002"]
        task_index = int(sys.argv[1])
        tf_config = {
            'cluster': {'worker': worker_hosts},
            'task': {'type': 'worker', 'index': task_index},
        }
        # TF_CONFIG is exported before the strategy object is constructed.
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
        Context.init_context(len(worker_hosts), task_index)
        # NOTE(review): these attributes are logged as-is, not called --
        # confirm they are not methods that were meant to be invoked.
        logging.info(Context.get_is_init)
        logging.info(Context.get_num_task)

    run_bert(strategy, input_meta_data)
Esempio n. 2
0
def main(_):
    """Dispatch to SQuAD export, training or prediction based on FLAGS.mode.

    When `FLAGS.mode` is 'export_only' the model is exported and the function
    returns before any distribution strategy is built. Otherwise a strategy
    is created from `FLAGS.strategy_type` ('mirror', 'multi_worker_mirror' or
    'tpu') and handed to the train and/or predict routine.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ('mirror', 'multi_worker_mirror', 'tpu'):
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)

    if strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif strategy_type == 'multi_worker_mirror':
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    else:
        # Initialize TPU System.
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    # The mode is re-read for each check; any other mode value is a no-op.
    if FLAGS.mode == 'train':
        train_squad(strategy, input_meta_data)
    if FLAGS.mode == 'predict':
        predict_squad(strategy, input_meta_data)
Esempio n. 3
0
def main(unused_argv):
    """Configure and launch `training_utils.train` for a classification model.

    Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
    'tpu'), the train/test input functions, the optimizer, the model factory
    and the metadata dict, then starts training.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    del unused_argv

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ("mirror", "tpu"):
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    else:
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)

    # Input pipelines: identical apart from batch size, the boolean flag and
    # the record path.
    train_input_fn = functools.partial(
        data_utils.get_classification_input_data,
        FLAGS.train_batch_size, FLAGS.seq_len, strategy, True,
        FLAGS.train_tfrecord_path)
    test_input_fn = functools.partial(
        data_utils.get_classification_input_data,
        FLAGS.test_batch_size, FLAGS.seq_len, strategy, False,
        FLAGS.test_tfrecord_path)

    num_train_steps = FLAGS.train_steps
    loop_steps = FLAGS.iterations
    num_eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                num_eval_steps)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        num_train_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)

    xlnet_model_config = xlnet_config.XLNetConfig(FLAGS)
    xlnet_run_config = xlnet_config.create_run_config(True, False, FLAGS)
    model_fn = functools.partial(get_classificationxlnet_model,
                                 xlnet_model_config, xlnet_run_config,
                                 FLAGS.n_class)

    input_meta_data = {
        "d_model": FLAGS.d_model,
        "mem_len": FLAGS.mem_len,
        # Global train batch size split across the strategy's replicas.
        "batch_size_per_core": int(FLAGS.train_batch_size /
                                   strategy.num_replicas_in_sync),
        "n_layer": FLAGS.n_layer,
        "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
        "n_class": FLAGS.n_class,
    }

    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=eval_fn,
        metric_fn=get_metric_fn,
        train_input_fn=train_input_fn,
        test_input_fn=test_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        init_from_transformerxl=FLAGS.init_from_transformerxl,
        total_training_steps=num_train_steps,
        steps_per_loop=loop_steps,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir,
        save_steps=FLAGS.save_steps)
Esempio n. 4
0
def main(_):
    """Load metadata, build the strategy and dataset functions, run BERT.

    `FLAGS.strategy_type` must be 'mirror' or 'tpu'. The train and eval
    dataset functions share the `max_seq_length` value read from the
    metadata JSON.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
        raw_meta = meta_file.read()
    input_meta_data = json.loads(raw_meta.decode('utf-8'))

    # Fall back to a fixed scratch directory when no model dir was given.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ('mirror', 'tpu'):
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)
    if strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    else:
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    seq_len = input_meta_data['max_seq_length']
    train_input_fn = get_dataset_fn(
        FLAGS.train_data_path, seq_len, FLAGS.train_batch_size,
        is_training=True)
    eval_input_fn = get_dataset_fn(
        FLAGS.eval_data_path, seq_len, FLAGS.eval_batch_size,
        is_training=False)

    run_bert(strategy, input_meta_data, train_input_fn, eval_input_fn)
Esempio n. 5
0
    def _build_tpu_strategy(self):
        """Initializes the configured TPU and returns a TPUStrategy for it.

        The TPU identifier is read from `self._strategy_config.tpu`; when it
        is None an empty string is logged in its place, but the original
        value is still what gets passed to `tpu_lib.tpu_initialize`.
        """
        tpu_address = self._strategy_config.tpu
        logging.info('Use TPU at %s',
                     '' if tpu_address is None else tpu_address)
        resolver = tpu_lib.tpu_initialize(tpu_address)
        return tf.distribute.experimental.TPUStrategy(resolver)
Esempio n. 6
0
 def _init_strategy(self):
     """Lazily create `self._strategy` if it has not been set yet.

     Selection order: a TPU strategy when `self._tpu` is set; otherwise the
     current default strategy when `self._distribution_strategy` is None or
     'default'; a one-device CPU strategy for 'cpu'; a mirrored strategy for
     'mirrored'.

     Raises:
         ValueError: If `self._distribution_strategy` is any other value.
     """
     # Already initialized: nothing to do.
     if self._strategy is not None:
         return

     if self._tpu is not None:
         tpu_resolver = tpu_lib.tpu_initialize(self._tpu)
         self._strategy = tf.distribute.experimental.TPUStrategy(
             tpu_resolver)
     elif (self._distribution_strategy is None
           or self._distribution_strategy == 'default'):
         self._strategy = tf.distribute.get_strategy()
     elif self._distribution_strategy == 'cpu':
         self._strategy = tf.distribute.OneDeviceStrategy('/device:cpu:0')
     elif self._distribution_strategy == 'mirrored':
         self._strategy = tf.distribute.MirroredStrategy()
     else:
         raise ValueError(
             f'Invalid distribution strategy="{self._distribution_strategy}"'
         )
Esempio n. 7
0
def main(_):
    """Build a distribution strategy and start BERT pretraining.

    `FLAGS.strategy_type` must be 'mirror' or 'tpu'.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    # Fall back to a fixed scratch directory when no model dir was given.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ('mirror', 'tpu'):
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)
    if strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    else:
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    if strategy:
        print('***** Number of cores used : ', strategy.num_replicas_in_sync)

    run_bert_pretrain(strategy)
Esempio n. 8
0
def main(_):
  """Load the input metadata, build a distribution strategy and run BERT.

  `FLAGS.strategy_type` must be 'mirror' or 'tpu'.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If `FLAGS.strategy_type` is not a supported value.
  """
  # Users should always run this script under TF 2.x
  assert tf.version.VERSION.startswith('2.')

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
    raw_meta = meta_file.read()
  input_meta_data = json.loads(raw_meta.decode('utf-8'))

  # Fall back to a fixed scratch directory when no model dir was given.
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  strategy_type = FLAGS.strategy_type
  if strategy_type not in ('mirror', 'tpu'):
    raise ValueError('The distribution strategy type is not supported: %s' %
                     FLAGS.strategy_type)
  if strategy_type == 'mirror':
    strategy = tf.distribute.MirroredStrategy()
  else:
    # Initialize TPU System.
    resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  run_bert(strategy, input_meta_data)
Esempio n. 9
0
def get_distribution_strategy(distribution_strategy="default",
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
    """Return a DistributionStrategy for running the model.

    Args:
        distribution_strategy: a string specifying which distribution
            strategy to use. Accepted values are 'off', 'default',
            'one_device', 'mirrored', 'parameter_server',
            'multi_worker_mirrored', and 'tpu' -- case insensitive. 'off'
            means not to use Distribution Strategy. 'default' selects
            `OneDeviceStrategy` when `num_gpus` <= 1 and `MirroredStrategy`
            otherwise. 'tpu' means to use TPUStrategy with `tpu_address`.
        num_gpus: Number of GPUs to run this model.
        num_workers: Number of workers to run this model. Only used in the
            error message raised for 'off'.
        all_reduce_alg: Optional. Specifies which algorithm to use when
            performing all-reduce. For `MirroredStrategy`, valid values are
            "nccl" and "hierarchical_copy". For
            `MultiWorkerMirroredStrategy`, valid values are "ring" and
            "nccl". If None, DistributionStrategy will choose based on
            device topology.
        num_packs: Optional. Sets the `num_packs` in
            `tf.distribute.NcclAllReduce` or
            `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
        tpu_address: Optional. String that represents the TPU to connect to.
            It is forwarded unchanged to `tpu_lib.tpu_initialize`; this
            function does not validate it.

    Returns:
        A tf.distribute.DistributionStrategy object, or None when
        `distribution_strategy` is 'off'.

    Raises:
        ValueError: if `num_gpus` is negative; if `distribution_strategy` is
            'off' or 'one_device' and `num_gpus` is larger than 1; or if
            `distribution_strategy` is not a recognized value.
    """
    if num_gpus < 0:
        raise ValueError("`num_gpus` can not be negative.")

    strategy_name = distribution_strategy.lower()

    if strategy_name == "off":
        if num_gpus <= 1:
            return None
        raise ValueError(
            "When {} GPUs and  {} workers are specified, distribution_strategy "
            "flag cannot be set to 'off'.".format(num_gpus, num_workers))

    if strategy_name == "tpu":
        # When tpu_address is an empty string, we communicate with local TPUs.
        # Initialize TPU System.
        resolver = tpu_lib.tpu_initialize(tpu_address)
        return tf.distribute.experimental.TPUStrategy(resolver)

    if strategy_name == "multi_worker_mirrored":
        return tf.distribute.experimental.MultiWorkerMirroredStrategy(
            communication=_collective_communication(all_reduce_alg))

    wants_single_device = strategy_name == "one_device" or (
        strategy_name == "default" and num_gpus <= 1)
    if wants_single_device:
        # Only reachable with more than one GPU for an explicit 'one_device'.
        if num_gpus > 1:
            raise ValueError(
                "`OneDeviceStrategy` can not be used for more than "
                "one device.")
        device = "device:CPU:0" if num_gpus == 0 else "device:GPU:0"
        return tf.distribute.OneDeviceStrategy(device)

    if strategy_name in ("mirrored", "default"):
        if num_gpus == 0:
            # 'default' with zero GPUs was already handled above.
            assert strategy_name == "mirrored"
            devices = ["device:CPU:0"]
        else:
            devices = ["device:GPU:%d" % i for i in range(num_gpus)]
        cross_device_ops = _mirrored_cross_device_ops(all_reduce_alg,
                                                      num_packs)
        return tf.distribute.MirroredStrategy(
            devices=devices, cross_device_ops=cross_device_ops)

    if strategy_name == "parameter_server":
        return tf.distribute.experimental.ParameterServerStrategy()

    raise ValueError("Unrecognized Distribution Strategy: %r" %
                     strategy_name)
Esempio n. 10
0
def main(unused_argv):
    """Configure and launch `training_utils.train` for a pretraining model.

    Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
    'tpu'), the pretraining input function, the optimizer, the model factory
    and the metadata dict, then starts training without evaluation.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    del unused_argv

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ("mirror", "tpu"):
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
        num_hosts = 1
    else:
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        # Topology flag is "<a>x<b>"; core count is computed as 2 * a * b.
        dims = FLAGS.tpu_topology.split("x")
        total_num_core = 2 * int(dims[0]) * int(dims[1])
        num_hosts = total_num_core // FLAGS.num_core_per_host

    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
        logging.info("***** Number of hosts used : %d", num_hosts)

    train_input_fn = functools.partial(
        data_utils.get_pretrain_input_data,
        FLAGS.train_batch_size,
        FLAGS.seq_len,
        strategy,
        FLAGS.train_tfrecord_path,
        FLAGS.reuse_len,
        FLAGS.perm_size,
        FLAGS.mask_alpha,
        FLAGS.mask_beta,
        FLAGS.num_predict,
        FLAGS.bi_data,
        FLAGS.uncased,
        num_hosts)

    num_train_steps = FLAGS.train_steps
    epoch_steps = int(FLAGS.train_data_size / FLAGS.train_batch_size)
    loop_steps = FLAGS.iterations

    optimizer, learning_rate_fn = optimization.create_optimizer(
        init_lr=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=FLAGS.warmup_steps,
        min_lr_ratio=FLAGS.min_lr_ratio,
        adam_epsilon=FLAGS.adam_epsilon,
        weight_decay_rate=FLAGS.weight_decay_rate)

    xlnet_model_config = xlnet_config.XLNetConfig(FLAGS)
    xlnet_run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {
        "d_model": FLAGS.d_model,
        "mem_len": FLAGS.mem_len,
        # Global train batch size split across the strategy's replicas.
        "batch_size_per_core": int(FLAGS.train_batch_size /
                                   strategy.num_replicas_in_sync),
        "n_layer": FLAGS.n_layer,
        "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
    }
    model_fn = functools.partial(get_pretrainxlnet_model, xlnet_model_config,
                                 xlnet_run_config)

    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=None,
        metric_fn=None,
        train_input_fn=train_input_fn,
        test_input_fn=None,
        init_checkpoint=FLAGS.init_checkpoint,
        total_training_steps=num_train_steps,
        steps_per_epoch=epoch_steps,
        steps_per_loop=loop_steps,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir,
        save_steps=FLAGS.save_steps)
Esempio n. 11
0
def main(unused_argv):
    """Configure and launch `training_utils.train` for a QA model.

    Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
    'tpu'), the train/test input functions, the optimizer, the model factory
    and the metadata dict, loads pickled evaluation features, and runs
    training under the device returned by `get_primary_cpu_task`.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    del unused_argv

    strategy_type = FLAGS.strategy_type
    if strategy_type not in ("mirror", "tpu"):
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
        use_remote_tpu = False
    else:
        # Initialize TPU System.
        resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        use_remote_tpu = True

    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)

    # Input pipelines: identical apart from batch size, the boolean flag and
    # the record path.
    train_input_fn = functools.partial(
        data_utils.get_squad_input_data,
        FLAGS.train_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy,
        True, FLAGS.train_tfrecord_path)
    test_input_fn = functools.partial(
        data_utils.get_squad_input_data,
        FLAGS.test_batch_size, FLAGS.seq_len, FLAGS.query_len, strategy,
        False, FLAGS.test_tfrecord_path)

    num_train_steps = FLAGS.train_steps
    epoch_steps = int(FLAGS.train_data_size / FLAGS.train_batch_size)
    loop_steps = FLAGS.iterations
    num_eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        num_train_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)

    xlnet_model_config = xlnet_config.XLNetConfig(FLAGS)
    xlnet_run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {
        "start_n_top": FLAGS.start_n_top,
        "end_n_top": FLAGS.end_n_top,
        "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
        "predict_dir": FLAGS.predict_dir,
        "predict_file": FLAGS.predict_file,
        "n_best_size": FLAGS.n_best_size,
        "max_answer_length": FLAGS.max_answer_length,
        "test_feature_path": FLAGS.test_feature_path,
        "test_batch_size": FLAGS.test_batch_size,
        # Global train batch size split across the strategy's replicas.
        "batch_size_per_core": int(FLAGS.train_batch_size /
                                   strategy.num_replicas_in_sync),
        "mem_len": FLAGS.mem_len,
    }
    model_fn = functools.partial(get_qaxlnet_model, xlnet_model_config,
                                 xlnet_run_config, FLAGS.start_n_top,
                                 FLAGS.end_n_top)

    logging.info("start reading pickle file...")
    # NOTE(review): pickle.load executes arbitrary code from the file; the
    # feature file must come from a trusted source.
    with tf.io.gfile.GFile(input_meta_data["test_feature_path"],
                           "rb") as feature_file:
        eval_features = pickle.load(feature_file)
    logging.info("finishing reading pickle file...")

    input_meta_data["eval_features"] = eval_features
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                num_eval_steps, input_meta_data)

    with tf.device(get_primary_cpu_task(use_remote_tpu)):
        training_utils.train(
            strategy=strategy,
            model_fn=model_fn,
            input_meta_data=input_meta_data,
            eval_fn=eval_fn,
            metric_fn=None,
            train_input_fn=train_input_fn,
            test_input_fn=test_input_fn,
            init_checkpoint=FLAGS.init_checkpoint,
            total_training_steps=num_train_steps,
            steps_per_epoch=epoch_steps,
            steps_per_loop=loop_steps,
            optimizer=optimizer,
            learning_rate_fn=learning_rate_fn,
            model_dir=FLAGS.model_dir)
Esempio n. 12
0
def main(unused_argv):
    """Configure and launch `training_utils.train` for a QA model.

    Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
    'tpu'), the train/test input functions, the optimizer, the model factory
    and the metadata dict. Evaluation features are loaded from the pickle at
    `FLAGS.test_feature_path` when that flag is set; otherwise they are
    created from `FLAGS.predict_file` with the SentencePiece model at
    `FLAGS.spiece_model_file`.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.strategy_type` is not a supported value.
    """
    del unused_argv
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                       FLAGS.train_batch_size, FLAGS.seq_len,
                                       FLAGS.query_len, strategy, True,
                                       FLAGS.train_tfrecord_path)

    test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      FLAGS.query_len, strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {
        "start_n_top": FLAGS.start_n_top,
        "end_n_top": FLAGS.end_n_top,
        "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
        "predict_dir": FLAGS.predict_dir,
        "n_best_size": FLAGS.n_best_size,
        "max_answer_length": FLAGS.max_answer_length,
        "test_batch_size": FLAGS.test_batch_size,
        # Global train batch size split across the strategy's replicas.
        "batch_size_per_core": int(FLAGS.train_batch_size /
                                   strategy.num_replicas_in_sync),
        "mem_len": FLAGS.mem_len,
    }
    model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                 FLAGS.start_n_top, FLAGS.end_n_top)
    eval_examples = squad_utils.read_squad_examples(FLAGS.predict_file,
                                                    is_training=False)
    if FLAGS.test_feature_path:
        logging.info("start reading pickle file...")
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # the feature file must come from a trusted source.
        with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
            eval_features = pickle.load(f)
        logging.info("finishing reading pickle file...")
    else:
        sp_model = spm.SentencePieceProcessor()
        # Read the serialized model inside a context manager so the file
        # handle is closed deterministically instead of being leaked.
        with tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb") as model_file:
            serialized_model = model_file.read()
        sp_model.LoadFromSerializedProto(serialized_model)
        spm_basename = os.path.basename(FLAGS.spiece_model_file)
        eval_features = squad_utils.create_eval_data(
            spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
            FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

    with tf.io.gfile.GFile(FLAGS.predict_file) as f:
        original_data = json.load(f)["data"]
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_examples, eval_features, original_data,
                                eval_steps, input_meta_data)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
def main(unused_argv):
  """Configure and launch `training_utils.train` for a classification model.

  Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
  'tpu'), the train/test input functions, the optimizer, the model factory
  and the metadata dict, then runs training under the device returned by
  `get_primary_cpu_task`.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `FLAGS.strategy_type` is not a supported value.
  """
  del unused_argv

  strategy_type = FLAGS.strategy_type
  if strategy_type not in ("mirror", "tpu"):
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
    use_remote_tpu = False
  else:
    # Initialize TPU System.
    resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    use_remote_tpu = True

  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)

  # Input pipelines: identical apart from batch size, the boolean flag and
  # the record path.
  train_input_fn = functools.partial(
      data_utils.get_classification_input_data,
      FLAGS.train_batch_size, FLAGS.seq_len, strategy, True,
      FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(
      data_utils.get_classification_input_data,
      FLAGS.test_batch_size, FLAGS.seq_len, strategy, False,
      FLAGS.test_tfrecord_path)

  num_train_steps = FLAGS.train_steps
  epoch_steps = int(FLAGS.train_data_size / FLAGS.train_batch_size)
  loop_steps = FLAGS.iterations
  num_eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              num_eval_steps)

  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      num_train_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)

  xlnet_model_config = xlnet_config.XLNetConfig(FLAGS)
  xlnet_run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model,
                               xlnet_model_config, xlnet_run_config,
                               FLAGS.n_class)

  input_meta_data = {
      "d_model": FLAGS.d_model,
      "mem_len": FLAGS.mem_len,
      # Global train batch size split across the strategy's replicas.
      "batch_size_per_core": int(FLAGS.train_batch_size /
                                 strategy.num_replicas_in_sync),
      "n_layer": FLAGS.n_layer,
      "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
      "n_class": FLAGS.n_class,
  }
  print("DEBUG: ", str(input_meta_data))

  def logits_init_fn():
    """Return a float32 zero tensor of shape (batch_size_per_core, n_class)."""
    zero_shape = (input_meta_data["batch_size_per_core"],
                  input_meta_data["n_class"])
    return tf.zeros(shape=zero_shape, dtype=tf.float32)

  with tf.device(get_primary_cpu_task(use_remote_tpu)):
    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=eval_fn,
        metric_fn=get_metric_fn,
        logits_init_fn=logits_init_fn,
        train_input_fn=train_input_fn,
        test_input_fn=test_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        total_training_steps=num_train_steps,
        steps_per_epoch=epoch_steps,
        steps_per_loop=loop_steps,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir)
def main(unused_argv):
  """Run pretraining and export the trained transformer-xl sub-model.

  Builds a distribution strategy from `FLAGS.strategy_type` ('mirror' or
  'tpu'), the online-masking config, the pretraining input function, the
  optimizer, the model factory and the metadata dict. After
  `training_utils.train` returns, the model's `transformerxl_model` is saved
  as a checkpoint under `FLAGS.model_dir`.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `FLAGS.strategy_type` is not a supported value.
  """
  del unused_argv

  strategy_type = FLAGS.strategy_type
  if strategy_type not in ("mirror", "tpu"):
    raise ValueError("The distribution strategy type is not supported: %s" %
                     FLAGS.strategy_type)
  if strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
    num_hosts = 1
  else:
    resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
    # Topology flag is "<a>x<b>"; core count is computed as 2 * a * b.
    dims = FLAGS.tpu_topology.split("x")
    total_num_core = 2 * int(dims[0]) * int(dims[1])
    num_hosts = total_num_core // FLAGS.num_core_per_host

  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
    logging.info("***** Number of hosts used : %d", num_hosts)

  masking_config = data_utils.OnlineMaskingConfig(
      sample_strategy=FLAGS.sample_strategy,
      max_num_tokens=FLAGS.max_num_tokens,
      min_num_tokens=FLAGS.min_num_tokens,
      max_num_words=FLAGS.max_num_words,
      min_num_words=FLAGS.min_num_words)

  train_input_fn = functools.partial(
      data_utils.get_pretrain_input_data,
      FLAGS.train_batch_size,
      FLAGS.seq_len,
      strategy,
      FLAGS.train_tfrecord_path,
      FLAGS.reuse_len,
      FLAGS.perm_size,
      FLAGS.leak_ratio,
      FLAGS.num_predict,
      FLAGS.uncased,
      masking_config,
      num_hosts)

  num_train_steps = FLAGS.train_steps
  loop_steps = FLAGS.iterations

  optimizer, learning_rate_fn = optimization.create_optimizer(
      init_lr=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=FLAGS.warmup_steps,
      min_lr_ratio=FLAGS.min_lr_ratio,
      adam_epsilon=FLAGS.adam_epsilon,
      weight_decay_rate=FLAGS.weight_decay_rate)

  xlnet_model_config = xlnet_config.XLNetConfig(FLAGS)
  xlnet_run_config = xlnet_config.create_run_config(True, False, FLAGS)
  input_meta_data = {
      "d_model": FLAGS.d_model,
      "mem_len": FLAGS.mem_len,
      # Global train batch size split across the strategy's replicas.
      "batch_size_per_core": int(FLAGS.train_batch_size /
                                 strategy.num_replicas_in_sync),
      "n_layer": FLAGS.n_layer,
      "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
  }
  model_fn = functools.partial(get_pretrainxlnet_model, xlnet_model_config,
                               xlnet_run_config)

  trained_model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=None,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=num_train_steps,
      steps_per_loop=loop_steps,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)

  # Export transformer-xl model checkpoint to be used in finetuning.
  export_prefix = os.path.join(FLAGS.model_dir,
                               "pretrained/transformer_xl.ckpt")
  checkpoint = tf.train.Checkpoint(
      transformer_xl=trained_model.transformerxl_model)
  saved_path = checkpoint.save(export_prefix)
  logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
               saved_path)