Example #1
def main(unused_argv):
    del unused_argv
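    # Select the distribution strategy requested via --strategy_type.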
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
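    # Bind the train and eval dataset builders to their flags.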
    train_input_fn = functools.partial(
        data_utils.get_classification_input_data, FLAGS.train_batch_size,
        FLAGS.seq_len, strategy, True, FLAGS.train_tfrecord_path)
    test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_steps)
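    # Create the optimizer and its learning-rate schedule for the full run.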
    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    model_fn = functools.partial(get_classificationxlnet_model, model_config,
                                 run_config, FLAGS.n_class, FLAGS.summary_type)
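    # Metadata consumed by the training loop and the model function.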
    input_meta_data = {}
    input_meta_data["d_model"] = FLAGS.d_model
    input_meta_data["mem_len"] = FLAGS.mem_len
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["n_layer"] = FLAGS.n_layer
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["n_class"] = FLAGS.n_class

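    # Run the custom training loop.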
    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=get_metric_fn,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
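These main() functions assume the enclosing script defines its flags with absl and dispatches through app.run. A minimal sketch of that scaffolding, with flag names taken from the snippets and defaults that are purely illustrative:

import functools

from absl import app
from absl import flags
from absl import logging

import tensorflow as tf

flags.DEFINE_string("strategy_type", "mirror",
                    "Distribution strategy: 'mirror' or 'tpu'.")
flags.DEFINE_string("tpu", "", "TPU address, used when strategy_type is 'tpu'.")
flags.DEFINE_integer("train_batch_size", 32, "Global training batch size.")
flags.DEFINE_integer("train_steps", 10000, "Total number of training steps.")

FLAGS = flags.FLAGS

if __name__ == "__main__":
    app.run(main)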
Example #2
def main(unused_argv):
    del unused_argv
    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(
        data_utils.get_classification_input_data, FLAGS.train_batch_size,
        FLAGS.seq_len, strategy, True, FLAGS.train_tfrecord_path)
    test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_steps)
    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    model_fn = functools.partial(modeling.classification_model, model_config,
                                 run_config, FLAGS.n_class, FLAGS.summary_type)
    input_meta_data = {}
    input_meta_data["d_model"] = FLAGS.d_model
    input_meta_data["mem_len"] = FLAGS.mem_len
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["n_layer"] = FLAGS.n_layer
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["n_class"] = FLAGS.n_class

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=get_metric_fn,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
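Example #2 replaces the hand-written strategy branch of Example #1 with the distribute_utils helper. For the two strategy types used here, the helper resolves to roughly the sketch below (a simplification; the real helper supports more strategies and options):

def get_distribution_strategy(distribution_strategy, tpu_address=""):
    # Multi-GPU, single host.
    if distribution_strategy in ("mirror", "mirrored"):
        return tf.distribute.MirroredStrategy()
    # TPU: connect to the cluster and initialize the TPU system first.
    if distribution_strategy == "tpu":
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_address)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    raise ValueError(
        "The distribution strategy type is not supported: %s" %
        distribution_strategy)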
Example #3
def main(unused_argv):
    del unused_argv
    num_hosts = 1
    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
    if FLAGS.strategy_type == "tpu":
        num_hosts = strategy.extended.num_hosts
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
        logging.info("***** Number of hosts used : %d", num_hosts)
    online_masking_config = data_utils.OnlineMaskingConfig(
        sample_strategy=FLAGS.sample_strategy,
        max_num_tokens=FLAGS.max_num_tokens,
        min_num_tokens=FLAGS.min_num_tokens,
        max_num_words=FLAGS.max_num_words,
        min_num_words=FLAGS.min_num_words)

    train_input_fn = functools.partial(
        data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
        FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
        FLAGS.perm_size, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
        online_masking_config, num_hosts)

    total_training_steps = FLAGS.train_steps

    steps_per_loop = FLAGS.iterations

    optimizer, learning_rate_fn = optimization.create_optimizer(
        init_lr=FLAGS.learning_rate,
        num_train_steps=total_training_steps,
        num_warmup_steps=FLAGS.warmup_steps,
        min_lr_ratio=FLAGS.min_lr_ratio,
        adam_epsilon=FLAGS.adam_epsilon,
        weight_decay_rate=FLAGS.weight_decay_rate)

    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["d_model"] = FLAGS.d_model
    input_meta_data["mem_len"] = FLAGS.mem_len
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["n_layer"] = FLAGS.n_layer
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                 run_config)

    model = training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        eval_fn=None,
        metric_fn=None,
        train_input_fn=train_input_fn,
        init_checkpoint=FLAGS.init_checkpoint,
        init_from_transformerxl=FLAGS.init_from_transformerxl,
        total_training_steps=total_training_steps,
        steps_per_loop=steps_per_loop,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir,
        save_steps=FLAGS.save_steps)

    # Export transformer-xl model checkpoint to be used in finetuning.
    checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
    saved_path = checkpoint.save(
        os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
    logging.info(
        "Exporting the transformer-xl model as a new TF checkpoint: %s",
        saved_path)
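The exported checkpoint can then seed a finetuning run. A minimal sketch of restoring it into a finetuning model's Transformer-XL backbone (finetune_model is hypothetical; the attribute and directory names follow the snippet):

backbone_ckpt = tf.train.Checkpoint(
    transformer_xl=finetune_model.transformerxl_model)
status = backbone_ckpt.restore(
    tf.train.latest_checkpoint(os.path.join(FLAGS.model_dir, "pretrained")))
status.assert_existing_objects_matched()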
Example #4
def main(unused_argv):
    del unused_argv
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                       FLAGS.train_batch_size, FLAGS.seq_len,
                                       FLAGS.query_len, strategy, True,
                                       FLAGS.train_tfrecord_path)

    test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      FLAGS.query_len, strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["start_n_top"] = FLAGS.start_n_top
    input_meta_data["end_n_top"] = FLAGS.end_n_top
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["predict_dir"] = FLAGS.predict_dir
    input_meta_data["n_best_size"] = FLAGS.n_best_size
    input_meta_data["max_answer_length"] = FLAGS.max_answer_length
    input_meta_data["test_batch_size"] = FLAGS.test_batch_size
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["mem_len"] = FLAGS.mem_len
    model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                 FLAGS.start_n_top, FLAGS.end_n_top)
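    # Read SQuAD eval examples; reuse cached features when a pickle is provided.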
    eval_examples = squad_utils.read_squad_examples(FLAGS.predict_file,
                                                    is_training=False)
    if FLAGS.test_feature_path:
        logging.info("start reading pickle file...")
        with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
            eval_features = pickle.load(f)
        logging.info("finishing reading pickle file...")
    else:
        sp_model = spm.SentencePieceProcessor()
        sp_model.LoadFromSerializedProto(
            tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
        spm_basename = os.path.basename(FLAGS.spiece_model_file)
        eval_features = squad_utils.create_eval_data(
            spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
            FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

    with tf.io.gfile.GFile(FLAGS.predict_file) as f:
        original_data = json.load(f)["data"]
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_examples, eval_features, original_data,
                                eval_steps, input_meta_data)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
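Example #4 loads cached eval features when FLAGS.test_feature_path is set and otherwise rebuilds them with SentencePiece. A first run can write that cache so later runs skip the rebuild; a minimal sketch (the output path is illustrative, and the pickle format mirrors what the snippet reads back):

import pickle

with tf.io.gfile.GFile("/tmp/eval_features.pickle", "wb") as f:
    pickle.dump(eval_features, f)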