def _default_run_finish_fn(success_status):
  """Logs a failed run, then retrieves the embedding variables.

  Args:
    success_status: Truthy when the run succeeded. When falsy, a `run_stop`
      record with status "failure" is emitted first.
  """
  run_failed = not success_status
  if run_failed:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "failure"})
  # The embedding variables are retrieved regardless of the run outcome.
  tf.logging.info("Retrieving embedding vars and writing stats.")
  runner.retrieve_embedding_vars()
def learning_rate_schedule(current_epoch):
  """Builds the learning-rate tensor: linear scaling, warmup and step decay.

  The base learning rate is scaled by `train_batch_size / 256`. The rate first
  ramps up linearly using the first `LR_SCHEDULE` entry; afterwards every
  `(multiplier, start_epoch)` entry switches the rate to
  `scaled_lr * multiplier` once `current_epoch` reaches `start_epoch`.

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  mlp_log.mlperf_print('lars_opt_base_learning_rate', FLAGS.base_learning_rate)
  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

  # Linear ramp defined by the first schedule entry.
  warmup_mult, warmup_end_epoch = LR_SCHEDULE[0]
  rate = scaled_lr * warmup_mult * current_epoch / warmup_end_epoch

  # Each entry overrides the rate from its start epoch onwards.
  for mult, start_epoch in LR_SCHEDULE:
    stepped_rate = scaled_lr * mult
    rate = tf.where(current_epoch < start_epoch, rate, stepped_rate)
  return rate
def poly_rate_schedule(current_epoch, poly_rate=0.0):
  """Builds the LARS learning-rate tensor: linear warmup, polynomial decay.

  The rate grows linearly from 0 until the warmup epochs are over, then
  follows a polynomial decay with power 2.0 down to a fixed end rate. The
  peak rate and warmup length default to values chosen from the train batch
  size and can be overridden by `poly_rate`, `FLAGS.lars_base_learning_rate`
  and `FLAGS.lars_warmup_epochs` (the flags take precedence).

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate; used as the peak rate when positive.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  batch_size = FLAGS.train_batch_size
  if batch_size < 16384:
    plr, w_epochs = 10.0, 5
  elif batch_size < 32768:
    plr, w_epochs = 25.0, 5
  else:
    plr, w_epochs = 31.2, 25

  # Override default poly learning rate and warmup epochs
  if poly_rate > 0.0:
    plr = poly_rate
  if FLAGS.lars_base_learning_rate > 0.0:
    plr = FLAGS.lars_base_learning_rate
  if FLAGS.lars_warmup_epochs > 0:
    w_epochs = FLAGS.lars_warmup_epochs

  mlp_log.mlperf_print('lars_opt_base_learning_rate', plr)
  mlp_log.mlperf_print('lars_opt_learning_rate_warmup_epochs', w_epochs)
  end_lr = 0.0001
  mlp_log.mlperf_print('lars_opt_end_learning_rate', end_lr)

  warmup_rate = plr * current_epoch / w_epochs
  warmup_steps = w_epochs * FLAGS.num_train_images // batch_size

  # Steps elapsed since warmup ended, clamped to at least 1.
  one_step = tf.constant(1, dtype=tf.int64)
  step = tf.train.get_or_create_global_step()
  steps_past_warmup = tf.maximum(one_step, tf.subtract(step, warmup_steps))

  mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps',
                       FLAGS.train_steps - warmup_steps + 1)
  mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)
  decayed_rate = tf.train.polynomial_decay(
      plr,
      steps_past_warmup,
      FLAGS.train_steps - warmup_steps + 1,
      end_lr,
      power=2.0)

  return tf.where(current_epoch <= w_epochs, warmup_rate, decayed_rate)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` record before an eval.

  Args:
    cur_step: Current training step, converted to an epoch number using
      whole steps per epoch (epoch 0 when an epoch has no full step).
  """
  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  if steps_per_epoch == 0:
    cur_epoch = 0
  else:
    cur_epoch = cur_step // steps_per_epoch
  block_info = {'first_epoch_num': cur_epoch, 'epoch_count': 1}
  mlp_log.mlperf_print('block_start', None, metadata=block_info)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` record before an eval.

  Args:
    cur_step: Current training step; floor-divided by the (fractional)
      number of steps per epoch to obtain the epoch number.
  """
  # True division on purpose: steps per epoch may be fractional.
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  block_info = {
      'first_epoch_num': cur_step // steps_per_epoch,
      'epoch_count': 4,
  }
  mlp_log.mlperf_print('block_start', None, metadata=block_info)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` record before an eval.

  Args:
    cur_step: Step reported by the runner; the logged epoch number is this
      value plus `FLAGS.iterations_per_loop`.
  """
  # While BERT pretraining does not have epochs,
  # to make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  loop_size = FLAGS.iterations_per_loop
  block_info = {
      "first_epoch_num": cur_step + loop_size,
      "epoch_count": loop_size,
  }
  mlp_log.mlperf_print("block_start", None, metadata=block_info)
def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
  """Closes the eval block, computes ROC AUC and records the result.

  Args:
    cur_step: Step at which the eval ran.
    eval_output: Mapping whose "results" entry is reshaped into rows of
      (prediction, target) pairs.
    summary_writer: Optional writer that receives an "auc" scalar summary.

  Returns:
    True when the AUC reaches `_ACCURACY_THRESH`, otherwise False.
  """
  epoch_num = cur_step // FLAGS.steps_between_evals + 1
  mlp_log.mlperf_print("eval_stop", None, metadata={"epoch_num": epoch_num})
  mlp_log.mlperf_print("block_stop", None,
                       metadata={"first_epoch_num": epoch_num})
  tf.logging.info(
      "== Eval finished (step {}). Computing metric..".format(cur_step))

  # Column 0 holds predictions, column 1 holds targets.
  pairs = np.reshape(np.array(eval_output["results"]), (-1, 2))
  predictions_np = pairs[:, 0].astype(np.float32)
  targets_np = pairs[:, 1].astype(np.int32)
  roc_auc = roc_metrics.RocMetrics(predictions_np, targets_np).ComputeRocAuc()
  tf.logging.info("== Eval shape: {}. AUC = {:.4f}".format(
      predictions_np.shape, roc_auc))

  success = roc_auc >= _ACCURACY_THRESH
  mlp_log.mlperf_print("eval_accuracy", roc_auc,
                       metadata={"epoch_num": epoch_num})
  if success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})

  if summary_writer:
    summary_writer.add_summary(
        utils.create_scalar_summary("auc", roc_auc),
        global_step=cur_step + FLAGS.steps_between_evals)
  eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
  return success
def _default_eval_init_fn(cur_step):
  """Logs the block/eval start records before every eval.

  Args:
    cur_step: Current training step; the 1-based block number is derived
      from it using `FLAGS.steps_between_evals`.
  """
  block_num = cur_step // FLAGS.steps_between_evals + 1
  tf.logging.info("== Block {}. Step {} of {}".format(
      block_num, cur_step, FLAGS.train_steps))
  block_info = {"first_epoch_num": block_num, "epoch_count": 1}
  mlp_log.mlperf_print("block_start", None, metadata=block_info)
  mlp_log.mlperf_print("eval_start", None, metadata={"epoch_num": block_num})
def init_lars_optimizer(current_epoch):
  """Creates the LARS optimizer driven by the polynomial rate schedule.

  Args:
    current_epoch: `Tensor` for current epoch, forwarded to
      `poly_rate_schedule`.

  Returns:
    The configured `contrib_opt.LARSOptimizer`.
  """
  lars_epsilon = FLAGS.lars_epsilon
  mlp_log.mlperf_print('lars_epsilon', lars_epsilon)
  learning_rate = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
  lars_kwargs = dict(
      momentum=FLAGS.momentum,
      weight_decay=FLAGS.weight_decay,
      # Variables whose names match these entries are passed as the skip list.
      skip_list=['batch_normalization', 'bias'],
      epsilon=lars_epsilon)
  return contrib_opt.LARSOptimizer(learning_rate, **lars_kwargs)
def learning_rate_schedule(peak_learning_rate, lr_warmup_init, lr_warmup_step,
                           first_lr_drop_step, second_lr_drop_step,
                           global_step):
  """Builds the learning-rate tensor: linear warmup followed by two drops.

  Args:
    peak_learning_rate: Rate reached at the end of warmup.
    lr_warmup_init: Rate at step 0.
    lr_warmup_step: Step at which warmup ends.
    first_lr_drop_step: Step from which the rate is `0.1 * peak`.
    second_lr_drop_step: Step from which the rate is `0.01 * peak`.
    global_step: `Tensor` holding the current global step.

  Returns:
    A `Tensor` with the learning rate for `global_step`.
  """
  # lr_warmup_init is the starting learning rate; the learning rate is linearly
  # scaled up to the full learning rate after `lr_warmup_step` before decaying.
  mlp_log.mlperf_print('opt_learning_rate_decay_factor', 0.1)
  mlp_log.mlperf_print('opt_learning_rate_decay_steps',
                       (first_lr_drop_step, second_lr_drop_step))

  warmup_fraction = tf.cast(global_step, dtype=tf.float32) / lr_warmup_step
  linear_warmup = lr_warmup_init + (
      warmup_fraction * (peak_learning_rate - lr_warmup_init))
  learning_rate = tf.where(global_step < lr_warmup_step, linear_warmup,
                           peak_learning_rate)

  # (multiplier, first step at which the multiplier applies)
  drops = (
      (1.0, lr_warmup_step),
      (0.1, first_lr_drop_step),
      (0.01, second_lr_drop_step),
  )
  for mult, start_global_step in drops:
    learning_rate = tf.where(global_step < start_global_step, learning_rate,
                             peak_learning_rate * mult)
  return learning_rate
def eval_finish_fn(cur_step, eval_output, summary_writer):
  """Records eval accuracy after an eval and reports whether to stop.

  Args:
    cur_step: Step at which the eval ran.
    eval_output: Mapping with a 'total_correct' entry that is summed up.
    summary_writer: Optional writer that receives an 'accuracy' summary.

  Returns:
    True when the accuracy reaches `FLAGS.stop_threshold`, otherwise False.
  """
  # True division on purpose: steps per epoch may be fractional.
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  epoch = cur_step // steps_per_epoch
  num_correct = float(np.sum(eval_output['total_correct']))
  eval_accuracy = num_correct / FLAGS.num_eval_images

  if summary_writer:
    with tf.Graph().as_default():
      accuracy_value = tf.Summary.Value(
          tag='accuracy', simple_value=eval_accuracy)
      summary_writer.add_summary(
          tf.Summary(value=[accuracy_value]), cur_step)

  accuracy_info = {
      'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
  }
  mlp_log.mlperf_print('eval_accuracy', eval_accuracy, metadata=accuracy_info)
  block_info = {'first_epoch_num': epoch, 'epoch_count': 4}
  mlp_log.mlperf_print('block_stop', None, metadata=block_info)

  if not eval_accuracy >= FLAGS.stop_threshold:
    return False
  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
  return True
def eval_finish_fn(cur_step, eval_output, _):
  """Scores the eval output after each eval and reports whether to stop.

  Args:
    cur_step: Step at which the eval ran.
    eval_output: Mapping of per-host output lists; each entry is replaced in
      place by its concatenation along axis 0.
    _: Unused.

  Returns:
    True when both the box AP and the mask AP reach their targets, otherwise
    False (also False when no eval steps are configured).
  """
  if eval_steps == 0:
    return False

  # Concat eval_output as eval_output is a list from each host.
  for name in eval_output:
    eval_output[name] = np.concatenate(eval_output[name], axis=0)

  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  if steps_per_epoch == 0:
    cur_epoch = 0
  else:
    cur_epoch = cur_step // steps_per_epoch
  mlp_log.mlperf_print('block_stop', None,
                       metadata={'first_epoch_num': cur_epoch,
                                 'epoch_count': 1})

  eval_multiprocess.eval_multiprocessing(
      eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

  mlp_log.mlperf_print('eval_start', None,
                       metadata={'epoch_num': cur_epoch + 1})
  _, eval_results = eval_metric.evaluate()
  accuracies = {
      'BBOX': float(eval_results['AP']),
      'SEGM': float(eval_results['mask_AP']),
  }
  mlp_log.mlperf_print('eval_accuracy', accuracies,
                       metadata={'epoch_num': cur_epoch + 1})
  mlp_log.mlperf_print('eval_stop', None,
                       metadata={'epoch_num': cur_epoch + 1})

  targets_met = (
      eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
      eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET)
  if not targets_met:
    return False
  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
  return True
def get_learning_rate(self, global_step):
  """Builds the learning-rate tensor and logs its hyperparameters.

  Args:
    global_step: Global step passed through to the schedule builder.

  Returns:
    The learning rate produced by `lr_policy.learning_rate_schedule`.
  """
  config = self.params
  learning_rate = lr_policy.learning_rate_schedule(
      config['learning_rate'],
      config['lr_warmup_init'],
      config['lr_warmup_step'],
      config['first_lr_drop_step'],
      config['second_lr_drop_step'],
      global_step)

  mlp_log.mlperf_print(key='opt_base_learning_rate',
                       value=config['learning_rate'])
  mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                       value=config['lr_warmup_step'])
  # The logged warmup factor is the base rate divided by the warmup steps.
  warmup_factor = config['learning_rate'] / config['lr_warmup_step']
  mlp_log.mlperf_print(key='opt_learning_rate_warmup_factor',
                       value=warmup_factor)
  return learning_rate
def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
  """Logs the run start, then feeds train and eval data for every epoch.

  Args:
    sess: Session used to run the ops.
    train_enqueue_ops: Ops run once per epoch to enqueue training data.
    eval_enqueue_ops: Ops run once per epoch to enqueue eval data.
    eval_init: Op run at the start of each epoch, before any enqueueing.
  """
  # Fixed delay before the run is declared started.
  time.sleep(150)
  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)
  first_block = {"first_epoch_num": 1, "epoch_count": 1}
  mlp_log.mlperf_print("block_start", None, metadata=first_block)

  num_epochs = self.hparams.max_train_epochs
  for epoch_num in range(1, num_epochs + 1):
    tf.logging.info("Infeed for epoch: %d", epoch_num)
    sess.run(eval_init)
    sess.run([train_enqueue_ops])
    sess.run([eval_enqueue_ops])
def eval_finish_fn(cur_step, eval_output, summary_writer):
  """Records masked-LM accuracy after an eval and reports whether to stop.

  Updates the module-level `run_steps` and `masked_lm_accuracy`.

  Args:
    cur_step: Step reported by the runner; `FLAGS.iterations_per_loop` is
      added to it before it is logged.
    eval_output: Mapping with "masked_lm_weighted_correct" and
      "masked_lm_weighted_count" entries.
    summary_writer: Optional writer that receives the accuracy summary.

  Returns:
    True when the accuracy reaches `FLAGS.stop_threshold` and at least six
    loops of `FLAGS.iterations_per_loop` steps have run, otherwise False.
  """
  global run_steps
  global masked_lm_accuracy

  step = cur_step + FLAGS.iterations_per_loop
  run_steps = step

  correct_total = np.sum(eval_output["masked_lm_weighted_correct"])
  count_total = np.sum(eval_output["masked_lm_weighted_count"])
  masked_lm_accuracy = correct_total / count_total
  # the eval_output may mix up the order of the two arrays
  # swap the order if it did got mix up
  if masked_lm_accuracy > 1:
    masked_lm_accuracy = 1 / masked_lm_accuracy

  if summary_writer:
    with tf.Graph().as_default():
      accuracy_value = tf.Summary.Value(tag="masked_lm_accuracy",
                                        simple_value=masked_lm_accuracy)
      summary_writer.add_summary(tf.Summary(value=[accuracy_value]), step)

  mlp_log.mlperf_print("block_stop", None,
                       metadata={"first_epoch_num": step})
  # While BERT pretraining does not have epochs,
  # to make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print("eval_accuracy", float(masked_lm_accuracy),
                       metadata={"epoch_num": step})

  converged = (masked_lm_accuracy >= FLAGS.stop_threshold and
               step >= FLAGS.iterations_per_loop * 6)
  if not converged:
    return False
  mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
  return True
def main(unused_argv):
  """Trains and evaluates ResNet with the low-level runner.

  Defines the eval/run callbacks, emits the MLPerf header records, builds the
  ImageNet input functions, initializes the runner and starts the
  train-and-eval loop.

  Args:
    unused_argv: Ignored.
  """

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    # True division on purpose: steps per epoch may be fractional.
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    block_info = {
        'first_epoch_num': cur_step // steps_per_epoch,
        'epoch_count': 4,
    }
    mlp_log.mlperf_print('block_start', None, metadata=block_info)

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    epoch = cur_step // steps_per_epoch
    num_correct = float(np.sum(eval_output['total_correct']))
    eval_accuracy = num_correct / FLAGS.num_eval_images

    if summary_writer:
      with tf.Graph().as_default():
        accuracy_value = tf.Summary.Value(
            tag='accuracy', simple_value=eval_accuracy)
        summary_writer.add_summary(
            tf.Summary(value=[accuracy_value]), cur_step)

    accuracy_info = {
        'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
    }
    mlp_log.mlperf_print(
        'eval_accuracy', eval_accuracy, metadata=accuracy_info)
    block_info = {'first_epoch_num': epoch, 'epoch_count': 4}
    mlp_log.mlperf_print('block_stop', None, metadata=block_info)

    if not eval_accuracy >= FLAGS.stop_threshold:
      return False
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    return True

  def run_finish_fn(success):
    """Executed once when the run ends."""
    run_failed = not success
    if run_failed:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
    mlp_log.mlperf_print('run_final', None)

  # Enough eval steps to cover every eval image.
  eval_steps = int(math.ceil(FLAGS.num_eval_images / FLAGS.eval_batch_size))
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.train_steps, eval_steps,
      FLAGS.num_replicas)

  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
  mlp_log.mlperf_print('lars_opt_momentum', FLAGS.momentum)
  mlp_log.mlperf_print('submission_benchmark', 'resnet')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print('submission_platform',
                       'tpu-v3-%d' % FLAGS.num_replicas)
  mlp_log.mlperf_print('submission_status', 'research')

  assert FLAGS.precision in ('bfloat16', 'float32'), (
      'Invalid value for --precision flag; must be bfloat16 or float32.')
  if FLAGS.precision == 'bfloat16':
    input_dtype = tf.bfloat16
  else:
    input_dtype = tf.float32
  # Decoded images are only cached for runs with more than 2048 replicas.
  cache_decoded_image = FLAGS.num_replicas > 2048

  def make_input_fn(is_training):
    """Builds the ImageNet input function for training or eval."""
    return imagenet_input.get_input_fn(
        FLAGS.data_dir,
        is_training,
        input_dtype,
        FLAGS.image_size,
        FLAGS.input_partition_dims is None,
        cache_decoded_image=cache_decoded_image)

  imagenet_train = make_input_fn(True)
  imagenet_eval = make_input_fn(False)

  low_level_runner.initialize(imagenet_train, imagenet_eval, resnet_model_fn,
                              FLAGS.train_batch_size, FLAGS.eval_batch_size,
                              FLAGS.input_partition_dims)

  mlp_log.mlperf_print('train_samples', FLAGS.num_train_images)
  mlp_log.mlperf_print('eval_samples', FLAGS.num_eval_images)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
def train_and_predict(self):
  """Run the predict loop on the TPU device.

  Starts one thread that runs the fused train/eval op and one thread that
  feeds the input queues. When `self.eval_steps` is positive, the calling
  thread scores the predictions after every epoch, and stops scoring as
  soon as the score reaches `self.hparams.target_bleu`.

  Returns:
    A `(score, current_step)` tuple with the last computed score and the step
    it was computed for, or `(None, None)` when `self.eval_steps` is not
    positive. If no epoch ran at all, `(0.0, 0)` is returned.
  """
  self.sess.run([self.compile_op])

  # Train and eval thread.
  def train_eval_thread_fn(sess, train_eval_op):
    tf.logging.info("train_eval_op start")
    sess.run([train_eval_op])

  train_eval_thread = threading.Thread(
      target=train_eval_thread_fn, args=(self.sess, self.train_eval_op))
  train_eval_thread.start()

  # Infeed thread.
  def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
    """Start the infeed."""
    time.sleep(150)
    mlp_log.mlperf_print("init_stop", None)
    mlp_log.mlperf_print("run_start", None)
    mlp_log.mlperf_print("block_start", None, metadata={
        "first_epoch_num": 1,
        "epoch_count": 1
    })
    for i in range(self.hparams.max_train_epochs):
      tf.logging.info("Infeed for epoch: %d", i + 1)
      sess.run(eval_init)
      sess.run([train_enqueue_ops])
      sess.run([eval_enqueue_ops])

  infeed_thread = threading.Thread(
      target=infeed_thread_fn,
      args=(self.sess, self.enqueue_ops, self.eval_enqueue_ops,
            self.eval_dataset_initializer))
  infeed_thread.start()

  if self.eval_steps > 0:
    eval_state = {"run_success": False, "score": 0.0}
    # Bound before the loop so the return below cannot raise
    # UnboundLocalError when max_train_epochs is 0.
    current_step = 0

    for epoch in range(self.hparams.max_train_epochs):
      predictions = list(self.predict())
      mlp_log.mlperf_print("eval_start", None,
                           metadata={"epoch_num": epoch + 1})
      current_step = epoch * self.iterations
      eval_state["score"] = metric.get_metric(
          self.hparams, predictions, current_step)
      tf.logging.info("Score after epoch %d: %f", epoch, eval_state["score"])
      # The score is divided by 100 before it is logged as the accuracy.
      mlp_log.mlperf_print("eval_accuracy", eval_state["score"] / 100.0,
                           metadata={"epoch_num": epoch + 1})
      mlp_log.mlperf_print("eval_stop", None,
                           metadata={"epoch_num": epoch + 1})
      mlp_log.mlperf_print("block_stop", None, metadata={
          "first_epoch_num": epoch + 1,
          "epoch_count": 1
      })
      if eval_state["score"] >= self.hparams.target_bleu:
        eval_state["run_success"] = True
        mlp_log.mlperf_print("run_stop", None,
                             metadata={"status": "success"})
        break
      # Target not reached: open the block for the next epoch.
      mlp_log.mlperf_print("block_start", None, metadata={
          "first_epoch_num": epoch + 2,
          "epoch_count": 1
      })

    if not eval_state["run_success"]:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})

  infeed_thread.join()
  train_eval_thread.join()

  if self.eval_steps > 0:
    return eval_state["score"], current_step
  else:
    return None, None
def main(argv):
  """Trains and evaluates Mask R-CNN with the low-level runner.

  Parses hyperparameters, emits the MLPerf header records, builds the runner
  and the input functions, then starts the train-and-eval loop with the
  callbacks defined below.

  Args:
    argv: Unused.
  """
  del argv  # Unused.

  # TODO(b/132208296): remove this workaround that uses control flow v2.
  control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

  # Parse hparams
  hparams = mask_rcnn_params.default_hparams()
  hparams.parse(FLAGS.hparams)

  if FLAGS.input_partition_dims:
    cores_per_replica = int(np.prod(FLAGS.input_partition_dims))
  else:
    cores_per_replica = 1
  params = dict(
      hparams.values(),
      # The input is only transposed when it is not spatially partitioned.
      transpose_input=FLAGS.input_partition_dims is None,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      num_cores_per_replica=cores_per_replica,
      replicas_per_host=FLAGS.replicas_per_host)

  # MLPerf logging.
  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('train_samples', FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print('eval_samples', FLAGS.eval_samples)
  mlp_log.mlperf_print('min_image_size', params['short_side_image_size'])
  mlp_log.mlperf_print('max_image_size', params['long_side_max_image_size'])
  mlp_log.mlperf_print('num_image_candidates', params['rpn_post_nms_topn'])

  train_steps = (
      FLAGS.num_epochs * FLAGS.num_examples_per_epoch //
      FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(float(FLAGS.eval_samples) / FLAGS.eval_batch_size))
  if eval_steps > 0:
    # The eval dataset is not evenly divided. Adding step by one will make sure
    # all eval samples are covered.
    # TODO(b/151732586): regenerate the eval dataset to make all hosts get the
    # same amount of work.
    eval_steps += 1

  iterations_per_loop = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  runner = train_and_eval_runner.TrainAndEvalRunner(
      iterations_per_loop, train_steps, eval_steps, FLAGS.num_shards)

  train_input_fn = dataloader.InputReader(
      FLAGS.training_file_pattern,
      mode=tf.estimator.ModeKeys.TRAIN,
      use_fake_data=FLAGS.use_fake_data)
  eval_reader = dataloader.InputReader(
      FLAGS.validation_file_pattern,
      mode=tf.estimator.ModeKeys.PREDICT,
      distributed_eval=True)
  eval_input_fn = functools.partial(
      eval_reader, num_examples=eval_steps * FLAGS.eval_batch_size)
  eval_metric = coco_metric.EvaluationMetric(
      FLAGS.val_json_file, use_cpp_extension=True)

  def init_fn():
    """Restores the ResNet backbone when a checkpoint is configured."""
    if FLAGS.resnet_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.resnet_checkpoint,
                                    {'resnet/': 'resnet50/'})

  runner.initialize(
      train_input_fn,
      eval_input_fn,
      mask_rcnn_model.MaskRcnnModelFn(params),
      FLAGS.train_batch_size,
      FLAGS.eval_batch_size,
      FLAGS.input_partition_dims,
      init_fn,
      params=params)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    if steps_per_epoch == 0:
      cur_epoch = 0
    else:
      cur_epoch = cur_step // steps_per_epoch
    block_info = {'first_epoch_num': cur_epoch, 'epoch_count': 1}
    mlp_log.mlperf_print('block_start', None, metadata=block_info)

  def eval_finish_fn(cur_step, eval_output, _):
    """Callback function that's executed after each eval."""
    if eval_steps == 0:
      return False

    # Concat eval_output as eval_output is a list from each host.
    for name in eval_output:
      eval_output[name] = np.concatenate(eval_output[name], axis=0)

    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    if steps_per_epoch == 0:
      cur_epoch = 0
    else:
      cur_epoch = cur_step // steps_per_epoch
    mlp_log.mlperf_print('block_stop', None,
                         metadata={'first_epoch_num': cur_epoch,
                                   'epoch_count': 1})

    eval_multiprocess.eval_multiprocessing(
        eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

    mlp_log.mlperf_print('eval_start', None,
                         metadata={'epoch_num': cur_epoch + 1})
    _, eval_results = eval_metric.evaluate()
    accuracies = {
        'BBOX': float(eval_results['AP']),
        'SEGM': float(eval_results['mask_AP']),
    }
    mlp_log.mlperf_print('eval_accuracy', accuracies,
                         metadata={'epoch_num': cur_epoch + 1})
    mlp_log.mlperf_print('eval_stop', None,
                         metadata={'epoch_num': cur_epoch + 1})

    targets_met = (
        eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET)
    if not targets_met:
      return False
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    return True

  def run_finish_fn(success):
    """Logs an aborted run when the targets were never reached."""
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})

  runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
def run_main(flags, default_hparams, estimator_fn):
  """Seeds the RNGs, logs the MLPerf header and runs `estimator_fn`.

  Args:
    flags: Parsed flags providing `jobid`, `random_seed`, `out_dir` and
      `hparams_path`.
    default_hparams: Default hyperparameters handed to
      `create_or_load_hparams`.
    estimator_fn: Callable invoked with the loaded hparams.

  Returns:
    Whatever `estimator_fn` returns.
  """
  # Job
  jobid = flags.jobid
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  seed_requested = random_seed is not None and random_seed > 0
  if seed_requested:
    utils.print_out("# Set random seed to %d" % random_seed)
    # Python and NumPy are seeded per job; TensorFlow gets the plain seed.
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)
    tf.set_random_seed(random_seed)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  # NOTE(review): the benchmark is logged as "resnet" although the other
  # records here (max_sequence_length, examples_to_infer) look like a
  # sequence model -- confirm this value is intended.
  mlp_log.mlperf_print("submission_benchmark", "resnet")
  mlp_log.mlperf_print("submission_division", "closed")
  mlp_log.mlperf_print("submission_org", "google")
  mlp_log.mlperf_print("submission_platform",
                       "tpu-v3-%d" % FLAGS.num_shards)
  mlp_log.mlperf_print("submission_status", "research")

  mlp_log.mlperf_print("global_batch_size", FLAGS.batch_size)
  mlp_log.mlperf_print("opt_learning_rate_alt_decay_func", "True")
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_learning_rate_decay_interval",
                       FLAGS.decay_interval)
  mlp_log.mlperf_print("opt_learning_rate_decay_factor", FLAGS.decay_factor)
  mlp_log.mlperf_print("opt_learning_rate_decay_steps", FLAGS.decay_steps)
  mlp_log.mlperf_print("opt_learning_rate_remain_steps", FLAGS.decay_start)
  mlp_log.mlperf_print("opt_learning_rate_alt_warmup_func",
                       FLAGS.warmup_scheme)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.warmup_steps)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.src_max_len,
                       metadata={"method": "discard"})
  mlp_log.mlperf_print("train_samples", FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print("eval_samples", FLAGS.examples_to_infer)

  # Model output directory
  out_dir = flags.out_dir
  needs_out_dir = out_dir and not tf.gfile.Exists(out_dir)
  if needs_out_dir:
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  hparams = create_or_load_hparams(default_hparams, flags.hparams_path)

  # Train or Evaluation
  return estimator_fn(hparams)
def main(unused_argv):
  """Run the reinforcement learning loop.

  Emits the MLPerf header records, runs `FLAGS.iterations` training
  iterations (resetting the TPU worker after each one), writes the abort
  file and finally logs the average number of selfplay game files per
  iteration.

  Args:
    unused_argv: Ignored.

  Raises:
    RuntimeError: If no selfplay directories exist under
      `FLAGS.selfplay_dir`.
  """
  logger = logging.getLogger()
  logger.setLevel(logging.INFO)
  formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                '%Y-%m-%d %H:%M:%S')

  # ML Perf Logging.
  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print(key='train_batch_size',
                       value=FLAGS.training_batch_size)
  mlp_log.mlperf_print(key='filter_amount', value=FLAGS.filter_amount)
  mlp_log.mlperf_print(key='window_size', value=FLAGS.window_size)
  mlp_log.mlperf_print(key='lr_boundaries',
                       value=str(FLAGS.lr_boundaries).strip('[]'))
  mlp_log.mlperf_print(key='lr_rates',
                       value=str(FLAGS.lr_rates).strip('[]'))
  mlp_log.mlperf_print(key='opt_weight_decay', value=FLAGS.l2_strength)
  mlp_log.mlperf_print(key='min_selfplay_games_per_generation',
                       value=FLAGS.mlperf_num_games)
  mlp_log.mlperf_print(key='train_samples', value=FLAGS.mlperf_num_games)
  mlp_log.mlperf_print(key='eval_samples', value=FLAGS.mlperf_num_games)
  mlp_log.mlperf_print(key='num_readouts', value=FLAGS.mlperf_num_readouts)
  mlp_log.mlperf_print(key='value_init_penalty',
                       value=FLAGS.mlperf_value_init_penalty)
  mlp_log.mlperf_print(key='holdout_pct', value=FLAGS.mlperf_holdout_pct)
  mlp_log.mlperf_print(key='disable_resign_pct',
                       value=FLAGS.mlperf_disable_resign_pct)
  # The resign threshold is logged as the mean of the configured values.
  mlp_log.mlperf_print(key='resign_threshold',
                       value=(sum(FLAGS.mlperf_resign_threshold) /
                              len(FLAGS.mlperf_resign_threshold)))
  mlp_log.mlperf_print(key='parallel_games',
                       value=FLAGS.mlperf_parallel_games)
  mlp_log.mlperf_print(key='virtual_losses',
                       value=FLAGS.mlperf_virtual_losses)
  mlp_log.mlperf_print(key='gating_win_rate',
                       value=FLAGS.mlperf_gating_win_rate)
  mlp_log.mlperf_print(key='eval_games', value=FLAGS.mlperf_eval_games)

  for handler in logger.handlers:
    handler.setFormatter(formatter)

  # The training loop must be bootstrapped; either by running bootstrap.sh
  # to generate training data from random games, or by running
  # copy_checkpoint.sh to copy an already generated checkpoint.
  model_dirs = list_selfplay_dirs(FLAGS.selfplay_dir)
  iteration_model_names = []

  if not model_dirs:
    # Interpolate the directory so the message names the path that was
    # searched instead of showing a literal "%s".
    raise RuntimeError(
        'Couldn\'t find any selfplay games under %s. Either bootstrap.sh '
        'or init_from_checkpoint.sh must be run before the train loop is '
        'started' % FLAGS.selfplay_dir)
  model_num = int(os.path.basename(model_dirs[0]))

  # Only the part of the TPU name before the first ':' is used.
  tpu_name = FLAGS.tpu_name.split(':')[0]
  session_config = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=True)
  timeout_run_options = tf.RunOptions(
      timeout_in_ms=FLAGS.worker_reset_timeout_ms)

  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  with minigo_utils.logged_timer('Total time'):
    state = State(model_num)
    while state.iter_num < FLAGS.iterations:
      state.iter_num += 1
      iteration_model_names.append(state.train_model_name)
      mlp_log.mlperf_print(key='epoch_start', value=None,
                           metadata={'epoch_num': state.iter_num})
      train_once(state)
      mlp_log.mlperf_print(key='epoch_stop', value=None,
                           metadata={'epoch_num': state.iter_num})
      mlp_log.mlperf_print(key='save_model',
                           value='{iteration_num: ' + str(state.iter_num) +
                           ' }')

      # In the case where iterations are fast, TPUEstimator can deadlock
      # between iterations on TPU Init. We attempt to manually make sure
      # the worker can Init with deadlines so we don't get stuck.
      while True:
        try:
          tf.logging.info('Attempting to shutdown worker.')
          gc.collect()
          with tf.Graph().as_default():
            with tf.Session(tpu_name, config=session_config) as sess:
              sess.run(tf.tpu.shutdown_system(job='tpu_worker'),
                       options=timeout_run_options)
          tf.logging.info('Attempting to initialize worker.')
          with tf.Graph().as_default():
            with tf.Session(tpu_name, config=session_config) as sess:
              init_result = sess.run(
                  tf.tpu.initialize_system(job='tpu_worker'),
                  options=timeout_run_options)
          if init_result:
            tf.logging.info('Worker reset.')
            break
        except tf.errors.DeadlineExceededError:
          # Deliberate best-effort: a timed-out reset is simply retried.
          pass

  with tf.gfile.GFile(FLAGS.abort_file_path, 'w') as f:
    f.write('abort')

  total_file_count = 0
  for iteration_model_name in iteration_model_names:
    total_file_count = total_file_count + len(
        tf.io.gfile.glob(FLAGS.selfplay_dir + '/' + iteration_model_name +
                         '/*/*/*'))

  # No iteration ran (e.g. FLAGS.iterations <= 0): report 0 instead of
  # dividing by zero.
  if iteration_model_names:
    games_per_generation = int(total_file_count / len(iteration_model_names))
  else:
    games_per_generation = 0
  mlp_log.mlperf_print(key='actual_selfplay_games_per_generation',
                       value=games_per_generation)
def run_finish_fn(success):
  """Logs the end of the run.

  Args:
    success: Truthy when the run already logged a successful stop. When
      falsy, a `run_stop` record with status "abort" is emitted first.
  """
  run_failed = not success
  if run_failed:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
  # `run_final` is logged for successful and aborted runs alike.
  mlp_log.mlperf_print("run_final", None)
def run_model(params,
              eval_init_fn=None,
              eval_finish_fn=None,
              run_finish_fn=None):
  """Run the DLRM model, using a pre-defined configuration.

  Args:
    params: HPTuner object that provides new params for the trial.
    eval_init_fn: Lambda to run at start of eval. None means use the default.
    eval_finish_fn: Lambda for end of eval. None means use the default.
    run_finish_fn: Lambda for end of execution. None means use the default.

  Returns:
    A list of tuples, each entry describing the eval metric for one eval. Each
    tuple entry is (global_step, metric_value).
  """
  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("global_batch_size", params["batch_size"])
  mlp_log.mlperf_print("train_samples", _NUM_TRAIN_EXAMPLES)
  mlp_log.mlperf_print("eval_samples", _NUM_EVAL_EXAMPLES)
  mlp_log.mlperf_print("opt_base_learning_rate", params["learning_rate"])
  mlp_log.mlperf_print("sgd_opt_base_learning_rate", params["learning_rate"])
  mlp_log.mlperf_print("sgd_opt_learning_rate_decay_poly_power", 2)
  mlp_log.mlperf_print("sgd_opt_learning_rate_decay_steps",
                       params["decay_steps"])
  mlp_log.mlperf_print("lr_decay_start_steps", params["decay_start_step"])
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                       params["lr_warmup_steps"])

  # Used for vizier. List of tuples. Each entry is (global_step, auc_metric).
  eval_metrics = [(0, 0.0)]

  feature_config = fc.FeatureConfig(params)
  feature_to_config_dict, table_to_config_dict = (
      feature_config.get_feature_tbl_config())

  # Both parameter objects are built; params["optimizer"] selects one.
  sgd_params = tpu_embedding.StochasticGradientDescentParameters(
      learning_rate=params["learning_rate"])
  adagrad_params = tpu_embedding.AdagradParameters(
      learning_rate=params["learning_rate"],
      initial_accumulator=params["adagrad_init_accum"])
  opt_params = {"sgd": sgd_params, "adagrad": adagrad_params}

  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      params["batch_size"],
      mode=tpu_embedding.TRAINING,
      optimization_parameters=opt_params[params["optimizer"]],
      partition_strategy="mod",
      pipeline_execution_with_tensor_core=FLAGS.pipeline_execution,
      master=FLAGS.master)

  runner = dlrm_embedding_runner.DLRMEmbeddingRunner(
      iterations_per_loop=FLAGS.steps_between_evals,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      num_replicas=FLAGS.num_tpu_shards,
      sparse_features_key="cat-features",
      embedding=embedding)

  train_input_fn, eval_input_fn = get_input_fns(params, feature_config)
  model_fn = functools.partial(dlrm.dlrm_llr_model_fn, params, feature_config)
  runner.initialize(
      train_input_fn,
      eval_input_fn,
      model_fn,
      params["batch_size"],
      params["eval_batch_size"],
      train_has_labels=False,
      eval_has_labels=False)

  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)

  def _default_eval_init_fn(cur_step):
    """Logging statements executed before every eval."""
    block_num = cur_step // FLAGS.steps_between_evals + 1
    tf.logging.info("== Block {}. Step {} of {}".format(
        block_num, cur_step, FLAGS.train_steps))
    block_info = {"first_epoch_num": block_num, "epoch_count": 1}
    mlp_log.mlperf_print("block_start", None, metadata=block_info)
    mlp_log.mlperf_print("eval_start", None,
                         metadata={"epoch_num": block_num})

  def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
    """Closes the eval block, computes ROC AUC and records the result."""
    epoch_num = cur_step // FLAGS.steps_between_evals + 1
    mlp_log.mlperf_print("eval_stop", None,
                         metadata={"epoch_num": epoch_num})
    mlp_log.mlperf_print("block_stop", None,
                         metadata={"first_epoch_num": epoch_num})
    tf.logging.info(
        "== Eval finished (step {}). Computing metric..".format(cur_step))

    # Column 0 holds predictions, column 1 holds targets.
    pairs = np.reshape(np.array(eval_output["results"]), (-1, 2))
    predictions_np = pairs[:, 0].astype(np.float32)
    targets_np = pairs[:, 1].astype(np.int32)
    roc_auc = roc_metrics.RocMetrics(
        predictions_np, targets_np).ComputeRocAuc()
    tf.logging.info("== Eval shape: {}. AUC = {:.4f}".format(
        predictions_np.shape, roc_auc))

    success = roc_auc >= _ACCURACY_THRESH
    mlp_log.mlperf_print("eval_accuracy", roc_auc,
                         metadata={"epoch_num": epoch_num})
    if success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})

    if summary_writer:
      summary_writer.add_summary(
          utils.create_scalar_summary("auc", roc_auc),
          global_step=cur_step + FLAGS.steps_between_evals)
    eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
    return success

  def _default_run_finish_fn(success_status):
    """Logs a failed run, then retrieves the embedding variables."""
    run_failed = not success_status
    if run_failed:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "failure"})
    tf.logging.info("Retrieving embedding vars and writing stats.")
    runner.retrieve_embedding_vars()

  # Any falsy callback argument falls back to the matching default.
  runner.train_and_eval(
      eval_init_fn=eval_init_fn if eval_init_fn else _default_eval_init_fn,
      eval_finish_fn=(
          eval_finish_fn if eval_finish_fn else _default_eval_finish_fn),
      run_finish_fn=run_finish_fn if run_finish_fn else _default_run_finish_fn)

  return eval_metrics
def run_pretraining(hparams):
  """Runs BERT pretraining with interleaved evaluation.

  Hyperparameters found in `hparams` overwrite the matching `FLAGS`
  attributes before the input pipelines and the low level runner are built.

  Args:
    hparams: hyperparameter container that supports `in` membership tests and
      attribute access. Only `learning_rate`, `lamb_weight_decay_rate`,
      `lamb_beta_1`, `lamb_beta_2`, `epsilon`, `num_warmup_steps` and
      `num_train_steps` are read.

  Returns:
    A `(masked_lm_accuracy, run_steps)` tuple holding the module-level values
    last written by the eval callback; both are reset to 0 on entry.

  Raises:
    ValueError: if neither `FLAGS.do_train` nor `FLAGS.do_eval` is set.
  """
  global masked_lm_accuracy, run_steps

  masked_lm_accuracy = 0
  run_steps = 0

  def eval_init_fn(cur_step):
    """Logs the start of a block before every eval."""
    # BERT pretraining has no epochs. To stay consistent with the other
    # mlperf models, every mlp_log entry treats steps as epochs and
    # sequences as examples.
    block_metadata = {
        "first_epoch_num": cur_step + FLAGS.iterations_per_loop,
        "epoch_count": FLAGS.iterations_per_loop,
    }
    mlp_log.mlperf_print("block_start", None, metadata=block_metadata)

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Records the eval accuracy; returns True once the run may stop."""
    global masked_lm_accuracy, run_steps

    finished_step = cur_step + FLAGS.iterations_per_loop
    run_steps = finished_step

    correct_total = np.sum(eval_output["masked_lm_weighted_correct"])
    count_total = np.sum(eval_output["masked_lm_weighted_count"])
    masked_lm_accuracy = correct_total / count_total
    # The eval output may deliver the two arrays in swapped order. A ratio
    # above 1 signals that, so invert it.
    if masked_lm_accuracy > 1:
      masked_lm_accuracy = 1 / masked_lm_accuracy

    if summary_writer:
      with tf.Graph().as_default():
        accuracy_value = tf.Summary.Value(
            tag="masked_lm_accuracy", simple_value=masked_lm_accuracy)
        summary_writer.add_summary(
            tf.Summary(value=[accuracy_value]), finished_step)

    mlp_log.mlperf_print(
        "block_stop", None, metadata={"first_epoch_num": finished_step})
    # As above: in mlp_log, epochs are steps and examples are sequences.
    mlp_log.mlperf_print(
        "eval_accuracy",
        float(masked_lm_accuracy),
        metadata={"epoch_num": finished_step})

    # Success needs the accuracy target AND at least six loops of steps.
    if not (masked_lm_accuracy >= FLAGS.stop_threshold and
            finished_step >= FLAGS.iterations_per_loop * 6):
      return False
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
    return True

  def run_finish_fn(success):
    """Logs an aborted run_stop on failure, then always logs run_final."""
    if not success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
    mlp_log.mlperf_print("run_final", None)

  def init_fn():
    """Warm-starts the bert/ and cls/ scopes when a checkpoint is given."""
    if not FLAGS.init_checkpoint:
      return
    scope_map = {
        "bert/": "bert/",
        "cls/": "cls/",
    }
    tf.train.init_from_checkpoint(FLAGS.init_checkpoint, scope_map)

  # Passing the hyperparameters
  overridable_flags = (
      "learning_rate",
      "lamb_weight_decay_rate",
      "lamb_beta_1",
      "lamb_beta_2",
      "epsilon",
      "num_warmup_steps",
      "num_train_steps",
  )
  for flag_name in overridable_flags:
    if flag_name in hparams:
      setattr(FLAGS, flag_name, getattr(hparams, flag_name))

  # Input handling
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.repeatable:
    tf.set_random_seed(123)

  if not (FLAGS.do_train or FLAGS.do_eval):
    raise ValueError(
        "At least one of `do_train` or `do_eval` must be True.")

  input_files = [
      matched_path
      for input_pattern in FLAGS.input_file.split(",")
      for matched_path in tf.gfile.Glob(input_pattern)
  ]

  tf.logging.info("*** Input Files ***")
  tf.logging.info("%s Files." % len(input_files))

  # Both input functions share the file list and sequence limits.
  shared_input_kwargs = dict(
      input_files=input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq)
  dataset_train = dataset_input.input_fn_builder(
      is_training=True, **shared_input_kwargs)
  dataset_eval = dataset_input.input_fn_builder(
      is_training=False,
      num_eval_samples=FLAGS.num_eval_samples,
      **shared_input_kwargs)

  # Create the low level runner
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop,
      FLAGS.stop_steps + 1,
      FLAGS.max_eval_steps,
      FLAGS.num_tpu_cores // FLAGS.num_partitions)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("global_batch_size", FLAGS.train_batch_size)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                       FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("num_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("start_warmup_step", FLAGS.start_warmup_step)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.max_seq_length)
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_lamb_beta_1", FLAGS.lamb_beta_1)
  mlp_log.mlperf_print("opt_lamb_beta_2", FLAGS.lamb_beta_2)
  mlp_log.mlperf_print("opt_epsilon", 10**FLAGS.log_epsilon)
  mlp_log.mlperf_print("opt_learning_rate_training_steps",
                       FLAGS.num_train_steps)
  mlp_log.mlperf_print("opt_lamb_weight_decay_rate",
                       FLAGS.lamb_weight_decay_rate)
  mlp_log.mlperf_print("opt_lamb_learning_rate_decay_poly_power", 1)
  mlp_log.mlperf_print("opt_gradient_accumulation_steps", 0)
  mlp_log.mlperf_print("max_predictions_per_seq",
                       FLAGS.max_predictions_per_seq)

  low_level_runner.initialize(
      dataset_train,
      dataset_eval,
      bert_model_fn,
      FLAGS.train_batch_size,
      FLAGS.eval_batch_size,
      input_partition_dims=None,
      init_fn=init_fn,
      train_has_labels=False,
      eval_has_labels=False,
      num_partitions=FLAGS.num_partitions)

  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)

  # To make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print("train_samples",
                       FLAGS.num_train_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print("eval_samples",
                       FLAGS.max_eval_steps * FLAGS.eval_batch_size)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
  return masked_lm_accuracy, run_steps
def run_finish_fn(success):
  """Emits the closing MLPerf log events for a run.

  Args:
    success: truthy when the run already ended successfully; when falsy a
      'run_stop' event with status 'abort' is logged first.
  """
  run_aborted = not success
  if run_aborted:
    abort_metadata = {'status': 'abort'}
    mlp_log.mlperf_print('run_stop', None, metadata=abort_metadata)
  # 'run_final' is logged regardless of the outcome.
  mlp_log.mlperf_print('run_final', None)
def resnet_model_fn(features, labels, is_training):
  """The model_fn for ResNet to be used with TPU.

  Args:
    features: `Tensor` of batched images, or a dict holding that tensor under
      the 'feature' key.
    labels: `Tensor` of labels for the data samples.
    is_training: whether this is training.

  Returns:
    `(train_op, None)` when `is_training` is true; otherwise
    `(None, {'total_correct': ...})`, where the dict value is the count of
    examples whose argmax prediction equals the label, reshaped to rank 1.

  Raises:
    ValueError: if `FLAGS.precision` is neither 'bfloat16' nor 'float32'.
  """
  if isinstance(features, dict):
    features = features['feature']

  # Space-to-depth inputs have half the spatial size and 12 channels
  # instead of 3; everything below is otherwise identical for both cases.
  if FLAGS.use_space_to_depth:
    spatial_size = FLAGS.image_size // 2
    num_channels = 12
  else:
    spatial_size = FLAGS.image_size
    num_channels = 3

  # NOTE(review): the layout test uses the training batch size even when
  # is_training is False -- confirm eval inputs share the same layout.
  if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
    features = tf.reshape(
        features, [spatial_size, spatial_size, num_channels, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
  else:
    features = tf.reshape(
        features, [spatial_size, spatial_size, -1, num_channels])
    features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(
      MEAN_RGB, shape=[1, 1, num_channels], dtype=features.dtype)
  features /= tf.constant(
      STDDEV_RGB, shape=[1, 1, num_channels], dtype=features.dtype)

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    with tf.variable_scope('resnet', reuse=tf.AUTO_REUSE):
      network = resnet_model.resnet_v1(
          resnet_depth=FLAGS.resnet_depth,
          num_classes=FLAGS.num_label_classes,
          use_space_to_depth=FLAGS.use_space_to_depth,
          num_replicas=FLAGS.num_replicas,
          distributed_group_size=FLAGS.distributed_group_size)
      return network(inputs=features, is_training=is_training)

  if FLAGS.precision == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()
  else:
    # Without this branch `logits` would stay unbound and the function
    # would fail later with an opaque UnboundLocalError.
    raise ValueError(
        'Unsupported precision %r; expected "bfloat16" or "float32".' %
        (FLAGS.precision,))

  if not is_training:
    predictions = tf.cast(tf.argmax(logits, axis=1), labels.dtype)
    total_correct = tf.reduce_sum(
        tf.cast(tf.equal(predictions, labels), tf.int32))
    return None, {'total_correct': tf.reshape(total_correct, [-1])}

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  # With LARS enabled no L2 term is added here (presumably the LARS
  # optimizer applies weight decay itself -- confirm in lars_util).
  if FLAGS.enable_lars:
    loss = cross_entropy
  else:
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

  global_step = tf.train.get_or_create_global_step()
  # True division: steps_per_epoch and current_epoch are fractional.
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)

  mlp_log.mlperf_print(
      'model_bn_span',
      FLAGS.distributed_group_size *
      (FLAGS.train_batch_size // FLAGS.num_replicas))

  if FLAGS.enable_lars:
    mlp_log.mlperf_print('opt_name', 'lars')
    optimizer = lars_util.init_lars_optimizer(current_epoch)
  else:
    mlp_log.mlperf_print('opt_name', 'sgd')
    learning_rate = learning_rate_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=FLAGS.momentum,
        use_nesterov=True)
  optimizer = tf.tpu.CrossShardOptimizer(optimizer)

  # Batch normalization requires UPDATE_OPS to be added as a dependency to
  # the train operation.
  with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss, global_step)
  return train_op, None