def main(argv):
  """Train a Tensor2Tensor model with EPL, using `epl.replicate` by default.

  Initializes EPL with split/replicate colocation enabled, stores this
  worker's index and the cluster's total GPU count in FLAGS, builds the
  hparams, and runs `tf.estimator.Estimator.train` on the problem's TRAIN
  input pipeline.

  Args:
    argv: Command-line arguments as handed over by `tf.app.run`; everything
      after `argv[0]` is forwarded to `set_hparams_from_args`.

  Raises:
    RuntimeError: If `FLAGS.schedule` is not "train" -- only training is
      supported by this entry point.
  """
  config = epl.Config({"cluster.colocate_split_and_replicate": True})
  epl.init(config)
  cluster = epl.Env.get().cluster
  FLAGS.worker_id = cluster.worker_index
  FLAGS.worker_gpu = cluster.total_gpu_num
  # Default strategy: epl.replicate with a device count equal to the
  # cluster's total GPU count.
  epl.set_default_strategy(epl.replicate(FLAGS.worker_gpu))

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  # Reject unsupported schedules before doing any hparams work; this also
  # guarantees `hparams` is always bound for the code below.
  if FLAGS.schedule != "train":
    raise RuntimeError(
        "Support training tasks only for now, you can define tasks in other modes."
    )
  hparams = create_hparams()
  mlperf_log.transformer_print(key=mlperf_log.RUN_START)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  hparams.add_hparam("data_dir", FLAGS.data_dir)
  hparams.add_hparam("schedule", FLAGS.schedule)
  hparams.add_hparam("train_steps", FLAGS.train_steps)
  hparams.add_hparam("warm_start_from", None)
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  # Dataset generation.
  if FLAGS.generate_data:
    generate_data()

  def model_fn_replicate(features, labels, mode):
    """Build the T2T estimator model_fn for FLAGS.model and apply it."""
    # NOTE(review): because this wrapper only declares (features, labels,
    # mode), the Estimator passes no `params`/`config` through to the T2T
    # model_fn -- confirm this is intended.
    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.model, hparams)
    return model_fn(features, labels, mode)

  if is_chief():
    save_metadata(hparams)

  estimator = tf.estimator.Estimator(model_fn=model_fn_replicate,
                                     config=create_run_config())
  hooks = [
      tf.train.StepCounterHook(every_n_steps=FLAGS.log_step_count_steps)
  ]

  # NOTE(review): the Estimator builds its graph inside train(), so the
  # default graph may not hold the model's variables yet at this point --
  # confirm this call logs what is expected.
  optimize.log_variable_sizes(verbose=True)

  train_input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams)
  estimator.train(train_input_fn, max_steps=hparams.train_steps, hooks=hooks)
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""EPL bert pipeline example.

Translates the run_squad command-line flags (gradient checkpointing, AMP and
the pipeline micro-batch count) into an EPL configuration, initializes EPL
with it, and then hands control to `run_squad.main` through `tf.app.run()`.
"""

import tensorflow as tf
# `main` is not referenced here directly: `tf.app.run()` called without a
# `main` argument looks up `main` in the `__main__` module, i.e. this one.
from run_squad import FLAGS, main  # pylint: disable=unused-import
import epl


def _epl_config_from_flags():
  """Collect the EPL settings selected by the command-line flags.

  Returns:
    A dict of EPL config keys to values, suitable for `epl.Config`.
  """
  settings = {}
  if FLAGS.gc:
    settings["gradient_checkpoint.type"] = "auto"
  if FLAGS.amp:
    # Mixed precision at level "o1" with loss scale 128 and debug logging.
    settings.update({
        "amp.level": "o1",
        "amp.loss_scale": 128,
        "amp.debug_log": True,
    })
  settings["pipeline.num_micro_batch"] = FLAGS.num_micro_batch
  return settings


if __name__ == "__main__":
  epl.init(epl.Config(_epl_config_from_flags()))
  tf.app.run()
num_classes=None, is_training=True)[0] features = tf.squeeze(features, [1, 2]) with epl.split(total_gpu_num): logits = tf.layers.dense(features, class_num) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) global_step = tf.train.get_or_create_global_step() optimizer = tf.train.AdamOptimizer(learning_rate=0.9) train_op = optimizer.minimize(loss, global_step=global_step) hooks = [tf.train.StopAtStepHook(last_step=20)] with tf.train.MonitoredTrainingSession(hooks=hooks) as sess: while not sess.should_stop(): starttime = time.time() _, _, step = sess.run([loss, train_op, global_step]) endtime = time.time() tf.logging.info("[Iteration {} ], Time: {:.4} .".format( step, endtime - starttime)) tf.logging.info("[Finished]") if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) config = epl.Config({"cluster.colocate_split_and_replicate": True}) epl.init(config) _total_gpu_num = epl.Env.get().cluster.total_gpu_num run_model(_total_gpu_num)