Example #1
def main(_):
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    if FLAGS.mode == 'export_only':
        export_squad(FLAGS.model_export_path, input_meta_data)
        return

    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'multi_worker_mirror':
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)
    if FLAGS.mode in ('train', 'train_and_predict'):
        train_squad(strategy, input_meta_data)
    if FLAGS.mode in ('predict', 'train_and_predict'):
        predict_squad(strategy, input_meta_data)
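These main(_) entry points follow the absl.app convention: flags are defined at
module scope and app.run() parses them before invoking main. A minimal sketch of
that wiring, assuming absl-py; the flag names mirror the FLAGS used above, and
the real scripts define many more:

from absl import app, flags

flags.DEFINE_string('mode', 'train_and_predict',
                    'One of export_only, train, predict, train_and_predict.')
flags.DEFINE_string('strategy_type', 'mirror', 'Distribution strategy to use.')
flags.DEFINE_string('tpu', '', 'TPU address; used when strategy_type is tpu.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)  # absl passes leftover argv, hence the unused `_` parameter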
Example #2
def main(_):
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    # with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    #     input_meta_data = json.loads(reader.read().decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'
    #
    # Configuration stuff
    #
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = FLAGS.num_train_epochs
    # train_data_size = 11778
    train_data_size = 5000
    steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)  # 368 with the original train_data_size=11778 and batch size 32
    warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
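    # For illustration (assumed values): with train_data_size=5000, 3 epochs,
    # and train_batch_size=32, warmup_steps = int(3 * 5000 * 0.1 / 32) = 46.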
    initial_lr = FLAGS.learning_rate

    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError('The distribution strategy type is not supported: %s' %
                         FLAGS.strategy_type)

    #
    # Modeling and training
    #

    # the model
    def _get_supervised_model():
        supervised_model, core_model = (
            bert_models.classifier_model(
                bert_config,
                float_type=tf.float32,
                num_labels=1,
                max_seq_length=128))
        return supervised_model, core_model
Example #3
def main(_):
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'
    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)
    if strategy:
        print('***** Number of cores used : ', strategy.num_replicas_in_sync)

    run_bert_pretrain(strategy)
Example #4
def main(_):
    # Users should always run this script under TF 2.x
    assert tf.version.VERSION.startswith('2.')

    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError('The distribution strategy type is not supported: %s' %
                         FLAGS.strategy_type)
    run_bert(strategy, input_meta_data)
Example #5
def main(_):
    # Users should always run this script under TF 2.1
    assert tf.version.VERSION.startswith('2.1')
    tf.random.set_seed(FLAGS.seed)

    # with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    #     input_meta_data = json.loads(reader.read().decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'
    #
    # Configuration stuff
    #
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = FLAGS.num_train_epochs
    # train_data_size = 11778
    train_data_size = 5000
    steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)  # 368 with the original train_data_size=11778 and batch size 32
    warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
    initial_lr = FLAGS.learning_rate

    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)

    #
    # Modeling and training
    #

    # the model
    def _get_dragon_model(do_masking):
        if not FLAGS.fixed_feature_baseline:
            dragon_model, core_model = (bert_models.dragon_model(
                bert_config,
                max_seq_length=FLAGS.max_seq_length,
                binary_outcome=True,
                use_unsup=do_masking,
                max_predictions_per_seq=20,
                unsup_scale=1.))
        else:
            dragon_model, core_model = bert_models.derpy_dragon_baseline(
                bert_config,
                max_seq_length=FLAGS.max_seq_length,
                binary_outcome=True)

        # WARNING: the original optimizer causes a bug where loss increases after first epoch
        # dragon_model.optimizer = optimization.create_optimizer(
        #     FLAGS.train_batch_size * initial_lr, steps_per_epoch * epochs, warmup_steps)
        dragon_model.optimizer = tf.keras.optimizers.SGD(
            learning_rate=FLAGS.train_batch_size * initial_lr)
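        # Linear-scaling heuristic: the effective learning rate grows with the
        # batch size, e.g. train_batch_size=32 with initial_lr=2e-5 gives
        # 6.4e-4 (illustrative values, not the flags' defaults).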
        return dragon_model, core_model

    if FLAGS.mode == 'train_and_predict':
        # training. strategy.scope context allows use of multiple devices
        with strategy.scope():
            keras_train_data = make_dataset(is_training=True,
                                            do_masking=FLAGS.do_masking)

            dragon_model, core_model = _get_dragon_model(FLAGS.do_masking)
            optimizer = dragon_model.optimizer

            if FLAGS.init_checkpoint:
                checkpoint = tf.train.Checkpoint(model=core_model)
                checkpoint.restore(
                    FLAGS.init_checkpoint).assert_existing_objects_matched()

            latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            if latest_checkpoint:
                dragon_model.load_weights(latest_checkpoint)

            dragon_model.compile(optimizer=optimizer,
                                 loss={
                                     'g': 'binary_crossentropy',
                                     'q0': 'binary_crossentropy',
                                     'q1': 'binary_crossentropy'
                                 },
                                 loss_weights={
                                     'g': FLAGS.treatment_loss_weight,
                                     'q0': 0.1,
                                     'q1': 0.1
                                 },
                                 weighted_metrics=make_dragonnet_metrics())

            summary_callback = tf.keras.callbacks.TensorBoard(FLAGS.model_dir,
                                                              update_freq=128)
            checkpoint_dir = os.path.join(FLAGS.model_dir,
                                          'model_checkpoint.{epoch:02d}')
            checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_dir, save_weights_only=True, period=10)
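            # Note: the `period` argument is deprecated in TF 2.x Keras;
            # `save_freq` (an integer number of batches, or 'epoch') is the
            # supported replacement.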

            callbacks = [summary_callback, checkpoint_callback]

            dragon_model.fit(
                x=keras_train_data,
                # validation_data=evaluation_dataset,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                # validation_steps=eval_steps,
                callbacks=callbacks)

        # save a final model checkpoint (so we can restore weights into the model w/o training idiosyncrasies)
        if FLAGS.model_export_path:
            model_export_path = FLAGS.model_export_path
        else:
            model_export_path = os.path.join(FLAGS.model_dir,
                                             'trained/dragon.ckpt')

        checkpoint = tf.train.Checkpoint(model=dragon_model)
        saved_path = checkpoint.save(model_export_path)
    else:
        saved_path = FLAGS.saved_path

    # make predictions and write to file

    # create data and model w/o masking
    eval_data = make_dataset(is_training=False, do_masking=False)
    dragon_model, core_model = _get_dragon_model(do_masking=False)
    # reload the model weights (necessary because we've obliterated the masking)
    checkpoint = tf.train.Checkpoint(model=dragon_model)
    checkpoint.restore(saved_path).assert_existing_objects_matched()
    # loss added as a simple hack for a bizarre Keras bug that requires compile() for predict(), and a loss for compile()
    dragon_model.add_loss(lambda: 0)
    dragon_model.compile()

    outputs = dragon_model.predict(x=eval_data)

    out_dict = {}
    out_dict['g'] = outputs[0].squeeze()
    out_dict['q0'] = outputs[1].squeeze()
    out_dict['q1'] = outputs[2].squeeze()

    predictions = pd.DataFrame(out_dict)

    label_dataset = eval_data.map(lambda f, l: l)
    data_df = dataset_to_pandas_df(label_dataset)

    outs = data_df.join(predictions)
    with tf.io.gfile.GFile(FLAGS.prediction_file, "w") as writer:
        writer.write(outs.to_csv(sep="\t"))
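The dataset_to_pandas_df helper is not shown in these examples; the sketch below
is a hypothetical reconstruction, assuming the label dataset yields dictionaries
of tensors:

import numpy as np
import pandas as pd

def dataset_to_pandas_df(dataset):
    # Hypothetical: flatten each batch of label tensors into per-key numpy
    # columns and concatenate across batches.
    batches = [{k: np.asarray(v) for k, v in batch.items()} for batch in dataset]
    columns = {k: np.concatenate([b[k] for b in batches]).squeeze()
               for k in batches[0]}
    return pd.DataFrame(columns)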
Example #6
def main(unused_argv):
    del unused_argv
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                       FLAGS.train_batch_size, FLAGS.seq_len,
                                       FLAGS.query_len, strategy, True,
                                       FLAGS.train_tfrecord_path)

    test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      FLAGS.query_len, strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["start_n_top"] = FLAGS.start_n_top
    input_meta_data["end_n_top"] = FLAGS.end_n_top
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["predict_dir"] = FLAGS.predict_dir
    input_meta_data["predict_file"] = FLAGS.predict_file
    input_meta_data["n_best_size"] = FLAGS.n_best_size
    input_meta_data["max_answer_length"] = FLAGS.max_answer_length
    input_meta_data["test_feature_path"] = FLAGS.test_feature_path
    input_meta_data["test_batch_size"] = FLAGS.test_batch_size
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["mem_len"] = FLAGS.mem_len
    model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                 FLAGS.start_n_top, FLAGS.end_n_top)

    logging.info("start reading pickle file...")
    with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
        eval_features = pickle.load(f)

    logging.info("finishing reading pickle file...")
    input_meta_data["eval_features"] = eval_features
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_steps, input_meta_data)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         test_input_fn=test_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         total_training_steps=total_training_steps,
                         steps_per_epoch=steps_per_epoch,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir)
Example #7
def main(unused_argv):
    del unused_argv
    num_hosts = 1
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
        topology = FLAGS.tpu_topology.split("x")
        total_num_core = 2 * int(topology[0]) * int(topology[1])
        num_hosts = total_num_core // FLAGS.num_core_per_host
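        # For example (assumed values): tpu_topology="2x2" gives
        # total_num_core = 2 * 2 * 2 = 8; with num_core_per_host=8,
        # num_hosts = 8 // 8 = 1.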
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
        logging.info("***** Number of hosts used : %d", num_hosts)
    train_input_fn = functools.partial(
        data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
        FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
        FLAGS.perm_size, FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict,
        FLAGS.bi_data, FLAGS.uncased, num_hosts)

    total_training_steps = FLAGS.train_steps
    steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
    steps_per_loop = FLAGS.iterations

    optimizer, learning_rate_fn = optimization.create_optimizer(
        init_lr=FLAGS.learning_rate,
        num_train_steps=total_training_steps,
        num_warmup_steps=FLAGS.warmup_steps,
        min_lr_ratio=FLAGS.min_lr_ratio,
        adam_epsilon=FLAGS.adam_epsilon,
        weight_decay_rate=FLAGS.weight_decay_rate)

    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["d_model"] = FLAGS.d_model
    input_meta_data["mem_len"] = FLAGS.mem_len
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["n_layer"] = FLAGS.n_layer
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                 run_config)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=None,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         test_input_fn=None,
                         init_checkpoint=FLAGS.init_checkpoint,
                         total_training_steps=total_training_steps,
                         steps_per_epoch=steps_per_epoch,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
Example #8
def main(_):
    # Users should always run this script under TF 2.1
    assert tf.version.VERSION.startswith('2.1')

    # with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    #     input_meta_data = json.loads(reader.read().decode('utf-8'))

    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'
    #
    # Configuration stuff
    #
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = FLAGS.num_train_epochs
    train_data_size = 11778  # TODO: fix hardcoding
    steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
    warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
    initial_lr = FLAGS.learning_rate

    strategy = None
    if FLAGS.strategy_type == 'mirror':
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == 'tpu':
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)

    #
    # Modeling and training
    #

    num_treatments = FLAGS.num_treatments
    missing_outcomes = FLAGS.missing_outcomes

    # the model
    def _get_hydra_model(do_masking):
        hydra_model, core_model = (bert_models.hydra_model(
            bert_config,
            max_seq_length=FLAGS.max_seq_length,
            binary_outcome=True,
            num_treatments=num_treatments,
            missing_outcomes=missing_outcomes,
            use_unsup=do_masking,
            max_predictions_per_seq=20,
            unsup_scale=1.))

        # WARNING: the original optimizer causes a bug where loss increases after first epoch
        # hydra_model.optimizer = optimization.create_optimizer(
        #     FLAGS.train_batch_size * initial_lr, steps_per_epoch * epochs, warmup_steps)
        hydra_model.optimizer = tf.keras.optimizers.SGD(
            learning_rate=FLAGS.train_batch_size * initial_lr)

        return hydra_model, core_model

    # training. strategy.scope context allows use of multiple devices
    with strategy.scope():
        train_data = make_dataset(tf_record_files=FLAGS.input_files,
                                  is_training=True,
                                  num_treatments=num_treatments,
                                  missing_outcomes=missing_outcomes,
                                  do_masking=FLAGS.do_masking)
        eval_data = make_dataset(tf_record_files=FLAGS.input_files,
                                 is_training=True,
                                 is_eval=True,
                                 num_treatments=num_treatments,
                                 missing_outcomes=missing_outcomes,
                                 do_masking=FLAGS.do_masking)

        hydra_model, core_model = _get_hydra_model(FLAGS.do_masking)
        optimizer = hydra_model.optimizer
        hydra_model.summary()  # summary() prints directly and returns None

        if FLAGS.init_checkpoint:
            checkpoint = tf.train.Checkpoint(model=core_model)
            checkpoint.restore(
                FLAGS.init_checkpoint).assert_existing_objects_matched()

        print("loss construction reached")
        t0 = time.time()
        if not missing_outcomes:
            losses = {'g': tf.keras.losses.SparseCategoricalCrossentropy()}
            loss_weights = {'g': 1.0}
        else:
            losses = {
                'g0': tf.keras.losses.SparseCategoricalCrossentropy(),
                'g1': tf.keras.losses.SparseCategoricalCrossentropy(),
                'y_is_obs': tf.keras.losses.BinaryCrossentropy()
            }
            loss_weights = {'g0': 1.0, 'g1': 1.0, 'y_is_obs': 1.0}

        for treat in range(num_treatments):
            losses[f"q{treat}"] = tf.keras.losses.BinaryCrossentropy()
            loss_weights[f"q{treat}"] = 0.1

        t1 = time.time()
        print(f"Loss construction completed: {t1-t0}")

        print("make metrics reached")
        t0 = time.time()
        hydra_metrics = make_hydra_metrics(num_treatments, missing_outcomes)
        t1 = time.time()
        print(f"metrics construction completed: {t1-t0}")

        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_checkpoint:
            hydra_model.load_weights(latest_checkpoint)

        print("Compile reached")
        t0 = time.time()
        hydra_model.compile(optimizer=optimizer,
                            loss=losses,
                            loss_weights=loss_weights,
                            weighted_metrics=hydra_metrics)
        t1 = time.time()
        print(f"Compile completed: {t1-t0}")
        summary_callback = tf.keras.callbacks.TensorBoard(FLAGS.model_dir,
                                                          update_freq=128)
        checkpoint_dir = os.path.join(FLAGS.model_dir,
                                      'model_checkpoint.{epoch:02d}')
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_dir, save_weights_only=True)

        callbacks = [summary_callback, checkpoint_callback]

        hydra_model.fit(x=train_data,
                        validation_data=eval_data,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        validation_steps=256,
                        callbacks=callbacks)

    # save a final model checkpoint (so we can restore weights into the model w/o training idiosyncrasies)
    hydra_model.optimizer = None
    if FLAGS.model_export_path:
        model_export_path = FLAGS.model_export_path
    else:
        model_export_path = os.path.join(FLAGS.model_dir, 'trained/hydra.ckpt')

    checkpoint = tf.train.Checkpoint(model=hydra_model)
    saved_path = checkpoint.save(model_export_path)

    # make predictions and write to file
    # NOTE: theory suggests we should make predictions on heldout data ("cross fitting" or "sample splitting")
    # but our experiments showed best results by just reusing the data
    # You can accommodate sample splitting by using the splitting arguments for the dataset creation.

    eval_data = make_dataset(FLAGS.input_files,
                             is_training=False,
                             do_masking=False,
                             num_treatments=num_treatments,
                             missing_outcomes=missing_outcomes)

    hydra_model, core_model = _get_hydra_model(do_masking=False)
    checkpoint = tf.train.Checkpoint(model=hydra_model)
    checkpoint.restore(saved_path).assert_existing_objects_matched()
    hydra_model.compile()  # omitting this seems to erratically cause bugs; very puzzling

    outputs = hydra_model.predict(x=eval_data)

    out_dict = {}
    if missing_outcomes:

        for t, g0 in enumerate(tf.unstack(outputs[0], axis=-1)):
            out_dict['g0_' + str(t)] = g0.numpy()

        for t, g1 in enumerate(tf.unstack(outputs[1], axis=-1)):
            out_dict['g1_' + str(t)] = g1.numpy()

        out_dict['prob_y_obs'] = np.squeeze(outputs[2])

        for out, q in enumerate(outputs[3:]):
            out_dict['q' + str(out)] = np.squeeze(q)

    else:

        for t, g in enumerate(tf.unstack(outputs[0], axis=-1)):
            out_dict['g_' + str(t)] = g.numpy()

        for out, q in enumerate(outputs[1:]):
            out_dict['q' + str(out)] = np.squeeze(q)

    predictions = pd.DataFrame(out_dict)

    label_dataset = eval_data.map(lambda f, l: l)
    data_df = dataset_to_pandas_df(label_dataset)

    outs = data_df.join(predictions)
    with tf.io.gfile.GFile(FLAGS.prediction_file, "w") as writer:
        writer.write(outs.to_csv(sep="\t"))
Example #9
def get_distribution_strategy(distribution_strategy="default",
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None):
    """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case insensitive.
      'off' means not to use Distribution Strategy; 'default' means to choose from
      `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
      according to the number of GPUs and number of workers. 'tpu' means to use
      TPUStrategy using `tpu_address`.
    num_gpus: Number of GPUs to run this model.
    num_workers: Number of workers to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl".  If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional.  Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
  Returns:
    tf.distribute.DistributionStrategy object.
  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
    if num_gpus < 0:
        raise ValueError("`num_gpus` can not be negative.")

    distribution_strategy = distribution_strategy.lower()
    if distribution_strategy == "off":
        if num_gpus > 1:
            raise ValueError(
                "When {} GPUs and  {} workers are specified, distribution_strategy "
                "flag cannot be set to 'off'.".format(num_gpus, num_workers))
        return None

    if distribution_strategy == "tpu":
        # When tpu_address is an empty string, we communicate with local TPUs.
        cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
        return tf.distribute.experimental.TPUStrategy(cluster_resolver)

    if distribution_strategy == "multi_worker_mirrored":
        return tf.distribute.experimental.MultiWorkerMirroredStrategy(
            communication=_collective_communication(all_reduce_alg))

    if (distribution_strategy == "one_device"
            or (distribution_strategy == "default" and num_gpus <= 1)):
        if num_gpus == 0:
            return tf.distribute.OneDeviceStrategy("device:CPU:0")
        else:
            if num_gpus > 1:
                raise ValueError(
                    "`OneDeviceStrategy` can not be used for more than "
                    "one device.")
            return tf.distribute.OneDeviceStrategy("device:GPU:0")

    if distribution_strategy in ("mirrored", "default"):
        if num_gpus == 0:
            assert distribution_strategy == "mirrored"
            devices = ["device:CPU:0"]
        else:
            devices = ["device:GPU:%d" % i for i in range(num_gpus)]
        return tf.distribute.MirroredStrategy(
            devices=devices,
            cross_device_ops=_mirrored_cross_device_ops(
                all_reduce_alg, num_packs))

    if distribution_strategy == "parameter_server":
        return tf.distribute.experimental.ParameterServerStrategy()

    raise ValueError("Unrecognized Distribution Strategy: %r" %
                     distribution_strategy)
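A short usage sketch for get_distribution_strategy; the flag values and the toy
model are illustrative assumptions, not taken from the source:

import tensorflow as tf

strategy = get_distribution_strategy(distribution_strategy="mirrored", num_gpus=2)
with strategy.scope():
    # Variables created in this scope are mirrored across the two GPUs.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")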