Example #1
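Data parallelism on a tensor2tensor Transformer trainer: epl.replicate over all GPUs is installed as the default strategy, so the Estimator graph built in main() is replicated on every device.
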
# NOTE: the imports below are reconstructed; helpers such as create_hparams,
# create_run_config, generate_data, is_chief, save_metadata and
# set_hparams_from_args (and the FLAGS definitions) live in the adapted
# tensor2tensor t2t_trainer script that this main() belongs to.
import tensorflow as tf
import epl
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import optimize
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import trainer_lib


def main(argv):
    # Initialize EPL and wire the cluster layout into tensor2tensor's flags.
    config = epl.Config({"cluster.colocate_split_and_replicate": True})
    epl.init(config)
    FLAGS.worker_id = epl.Env.get().cluster.worker_index
    FLAGS.worker_gpu = epl.Env.get().cluster.total_gpu_num
    # Replicate the full model on every GPU, i.e. data parallelism.
    epl.set_default_strategy(epl.replicate(FLAGS.worker_gpu))

    # Create HParams.
    if argv:
        set_hparams_from_args(argv[1:])
    if FLAGS.schedule != "run_std_server":
        hparams = create_hparams()

    if FLAGS.schedule == "train":
        mlperf_log.transformer_print(key=mlperf_log.RUN_START)
    else:
        raise RuntimeError(
            "Only the train schedule is supported for now; add handling for "
            "other schedules yourself."
        )
    trainer_lib.set_random_seed(FLAGS.random_seed)

    hparams.add_hparam("data_dir", FLAGS.data_dir)
    hparams.add_hparam("schedule", FLAGS.schedule)
    hparams.add_hparam("train_steps", FLAGS.train_steps)
    hparams.add_hparam("warm_start_from", None)
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    # Dataset generation.
    if FLAGS.generate_data:
        generate_data()

    def model_fn_replicate(features, labels, mode):
        # Build the stock T2T Estimator model_fn; because epl.replicate is
        # the default strategy, the model it constructs is replicated.
        model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.model, hparams)
        return model_fn(features, labels, mode)

    if is_chief():
        save_metadata(hparams)

    estimator = tf.estimator.Estimator(model_fn=model_fn_replicate,
                                       config=create_run_config())
    hooks = []
    hooks.append(
        tf.train.StepCounterHook(every_n_steps=FLAGS.log_step_count_steps))

    optimize.log_variable_sizes(verbose=True)

    problem = hparams.problem
    train_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hparams)

    estimator.train(train_input_fn, max_steps=hparams.train_steps, hooks=hooks)
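
Apart from the epl calls at the top of main(), the rest is the standard tensor2tensor Estimator training flow; data parallelism requires no changes to the model code itself.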
Example #2
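Pipeline parallelism for BERT fine-tuning: EPL is configured before tf.app.run() hands control to the unmodified run_squad script. Gradient checkpointing and AMP are toggled by flags, and pipeline.num_micro_batch controls how many micro batches are pipelined per step.
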
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""EPL bert pipeline example."""
import tensorflow as tf
from run_squad import FLAGS, main  # pylint: disable=unused-import
import epl

if __name__ == "__main__":
    config_json = {}
    # Optional gradient checkpointing: recompute activations in the backward
    # pass to trade compute for GPU memory.
    if FLAGS.gc:
        config_json["gradient_checkpoint.type"] = "auto"
    # Optional automatic mixed precision with a fixed loss scale of 128.
    if FLAGS.amp:
        config_json["amp.level"] = "o1"
        config_json["amp.loss_scale"] = 128
        config_json["amp.debug_log"] = True
    # Number of micro batches pipelined through the stages per training step.
    config_json["pipeline.num_micro_batch"] = FLAGS.num_micro_batch
    epl.init(epl.Config(config_json))
    # tf.app.run() dispatches to the `main` imported from run_squad above.
    tf.app.run()
Example #3
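Hybrid parallelism for large-scale image classification: a convolutional backbone (reconstructed below as ResNet-50) runs under a replicate strategy, while the large classification layer is split across all GPUs.
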
import time

import tensorflow as tf
from tensorflow.contrib.slim.nets import resnet_v1
import epl


def run_model(total_gpu_num, class_num=100000, batch_size=32):
    # The original snippet begins mid-call; the function header, the dummy
    # inputs, and the ResNet-50 backbone call are reconstructed here (only
    # the trailing arguments of the backbone call appeared in the source).
    images = tf.random_uniform([batch_size, 224, 224, 3])
    labels = tf.random_uniform([batch_size], maxval=class_num, dtype=tf.int64)
    with epl.replicate(device_count=1):
        features = resnet_v1.resnet_v1_50(images,
                                          num_classes=None,
                                          is_training=True)[0]
        features = tf.squeeze(features, [1, 2])

    # Split the classification layer across all GPUs so its large weight
    # matrix is sharded rather than replicated (tensor model parallelism).
    with epl.split(total_gpu_num):
        logits = tf.layers.dense(features, class_num)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                      logits=logits)

    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.9)
    train_op = optimizer.minimize(loss, global_step=global_step)

    hooks = [tf.train.StopAtStepHook(last_step=20)]
    with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
        while not sess.should_stop():
            starttime = time.time()
            _, _, step = sess.run([loss, train_op, global_step])
            endtime = time.time()
            tf.logging.info("[Iteration {} ], Time: {:.4} .".format(
                step, endtime - starttime))
    tf.logging.info("[Finished]")


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    # Place split and replicate scopes on the same set of devices.
    config = epl.Config({"cluster.colocate_split_and_replicate": True})
    epl.init(config)
    _total_gpu_num = epl.Env.get().cluster.total_gpu_num
    run_model(_total_gpu_num)