def learning_rate_schedule(current_epoch):
  """Returns the learning rate for `current_epoch` according to `LR_SCHEDULE`.

  The base learning rate is scaled linearly by `train_batch_size / 256`.
  Before the first boundary in `LR_SCHEDULE` the rate grows linearly with
  `current_epoch`; once a `(multiplier, start_epoch)` boundary has been
  reached, the rate is the scaled base rate times that entry's multiplier.
  The concrete warmup length and decay epochs are whatever `LR_SCHEDULE`
  contains (it is defined outside this function).

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  base_lr = FLAGS.base_learning_rate
  mlp_log.mlperf_print('lars_opt_base_learning_rate', base_lr)

  # Linear scaling rule, relative to a reference batch size of 256.
  peak_lr = base_lr * (FLAGS.train_batch_size / 256.0)

  # The first schedule entry defines the warmup ramp.
  warmup_mult = LR_SCHEDULE[0][0]
  warmup_end_epoch = LR_SCHEDULE[0][1]
  lr = peak_lr * warmup_mult * current_epoch / warmup_end_epoch

  # Every boundary that has been reached overrides the value chosen so far.
  for multiplier, boundary_epoch in LR_SCHEDULE:
    lr = tf.where(current_epoch < boundary_epoch, lr, peak_lr * multiplier)
  return lr
def _default_run_finish_fn(success_status):
  """Logs a failed run, then retrieves the embedding variables.

  Args:
    success_status: Truthy when the run succeeded. A falsy value causes a
      `run_stop` event with status "failure" to be logged.
  """
  run_failed = not success_status
  if run_failed:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "failure"})

  tf.logging.info("Retrieving embedding vars and writing stats.")
  runner.retrieve_embedding_vars()
def eval_finish_fn(cur_step, eval_output, _):
  """Scores one finished eval pass and reports whether the targets are met.

  Args:
    cur_step: Training step at which the eval ran.
    eval_output: Dict of per-host result lists; each value is concatenated
      along axis 0 in place.
    _: Unused.

  Returns:
    True when both the box AP and the mask AP reach their targets, otherwise
    False (including when `eval_steps` is 0).
  """
  if eval_steps == 0:
    return False

  # Concat eval_output as eval_output is a list from each host.
  for name in eval_output:
    eval_output[name] = np.concatenate(eval_output[name], axis=0)

  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  if steps_per_epoch == 0:
    cur_epoch = 0
  else:
    cur_epoch = cur_step // steps_per_epoch
  eval_epoch = cur_epoch + 1

  mlp_log.mlperf_print(
      'block_stop', None,
      metadata={'first_epoch_num': cur_epoch, 'epoch_count': 1})

  eval_multiprocess.eval_multiprocessing(
      eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

  mlp_log.mlperf_print('eval_start', None, metadata={'epoch_num': eval_epoch})
  _, eval_results = eval_metric.evaluate()
  accuracy = {
      'BBOX': float(eval_results['AP']),
      'SEGM': float(eval_results['mask_AP']),
  }
  mlp_log.mlperf_print(
      'eval_accuracy', accuracy, metadata={'epoch_num': eval_epoch})
  mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': eval_epoch})

  targets_met = (
      eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
      eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET)
  if not targets_met:
    return False

  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
  return True
def learning_rate_schedule(params, global_step):
  """Builds the warmup + two-drop learning rate tensor.

  The base rate is scaled by the global batch size (`batch_size` times
  `num_shards`) relative to `ssd_constants.DEFAULT_BATCH_SIZE`. The rate ramps
  linearly over `lr_warmup_step` steps, holds at the scaled rate, and is
  multiplied by 0.1 and 0.01 from the first and second drop steps onward.

  Args:
    params: A dictionary that defines hyperparameters of model.
    global_step: A tensor representing current global step.

  Returns:
    A tensor representing current learning rate.
  """
  base_lr = params['base_learning_rate']
  warmup_steps = params['lr_warmup_step']
  first_drop = params['first_lr_drop_step']
  second_drop = params['second_lr_drop_step']

  global_batch = params['batch_size'] * params['num_shards']
  scale = global_batch / ssd_constants.DEFAULT_BATCH_SIZE
  mlp_log.mlperf_print('opt_learning_rate_warmup_factor', scale)
  mlp_log.mlperf_print('opt_learning_rate_warmup_steps', warmup_steps)

  peak_lr = base_lr * scale
  # Linear warmup; the boundaries below take over once they are reached.
  warmup_fraction = tf.cast(global_step, dtype=tf.float32) / warmup_steps
  lr = warmup_fraction * peak_lr

  boundaries = (
      (1.0, warmup_steps),
      (0.1, first_drop),
      (0.01, second_drop),
  )
  for multiplier, boundary_step in boundaries:
    lr = tf.where(global_step < boundary_step, lr, peak_lr * multiplier)
  return lr
def poly_rate_schedule(current_epoch, poly_rate=0.0):
  """Linear warmup followed by polynomial (power 2.0) learning rate decay.

  Default peak rate and warmup length are picked from the train batch size.
  They can be overridden by `poly_rate`, `FLAGS.lars_base_learning_rate`
  (which wins over `poly_rate`) and `FLAGS.lars_warmup_epochs`, each only when
  positive. While `current_epoch <= warmup epochs` the rate grows linearly
  with `current_epoch`; afterwards it follows `tf.train.polynomial_decay`
  down to an end rate of 0.0001.

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  batch_size = FLAGS.train_batch_size

  # Batch-size dependent defaults: (peak learning rate, warmup epochs).
  if batch_size < 16384:
    peak_lr, warmup_epochs = 10.0, 5
  elif batch_size < 32768:
    peak_lr, warmup_epochs = 25.0, 5
  else:
    peak_lr, warmup_epochs = 31.2, 25

  # Override default poly learning rate and warmup epochs
  if poly_rate > 0.0:
    peak_lr = poly_rate
  if FLAGS.lars_base_learning_rate > 0.0:
    peak_lr = FLAGS.lars_base_learning_rate
  if FLAGS.lars_warmup_epochs > 0:
    warmup_epochs = FLAGS.lars_warmup_epochs

  mlp_log.mlperf_print('lars_opt_base_learning_rate', peak_lr)
  mlp_log.mlperf_print('lars_opt_learning_rate_warmup_epochs', warmup_epochs)
  final_lr = 0.0001
  mlp_log.mlperf_print('lars_opt_end_learning_rate', final_lr)

  warmup_lr = peak_lr * current_epoch / warmup_epochs
  warmup_steps = warmup_epochs * FLAGS.num_train_images // batch_size

  # Steps elapsed since the end of warmup, clamped to at least 1.
  one = tf.constant(1, dtype=tf.int64)
  step = tf.train.get_or_create_global_step()
  steps_into_decay = tf.maximum(one, tf.subtract(step, warmup_steps))

  decay_horizon = FLAGS.train_steps - warmup_steps + 1
  mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps', decay_horizon)
  mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)
  decayed_lr = tf.train.polynomial_decay(
      peak_lr, steps_into_decay, decay_horizon, final_lr, power=2.0)

  return tf.where(current_epoch <= warmup_epochs, warmup_lr, decayed_lr)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` event before an eval.

  Args:
    cur_step: Current training step.
  """
  # NOTE(review): `/` is kept as in the original; under true division the
  # logged epoch is a float - confirm this is intended.
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  block_metadata = {
      'first_epoch_num': cur_step // steps_per_epoch,
      'epoch_count': 4,
  }
  mlp_log.mlperf_print('block_start', None, metadata=block_metadata)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` event before an eval.

  Args:
    cur_step: Current training step.
  """
  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  # Guard against a batch size larger than the epoch (zero steps per epoch).
  if steps_per_epoch == 0:
    cur_epoch = 0
  else:
    cur_epoch = cur_step // steps_per_epoch

  block_metadata = {'first_epoch_num': cur_epoch, 'epoch_count': 1}
  mlp_log.mlperf_print('block_start', None, metadata=block_metadata)
def eval_init_fn(cur_step):
  """Logs the MLPerf `block_start` event before an eval.

  Args:
    cur_step: Current training step.
  """
  # While BERT pretraining does not have epochs,
  # to make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  loop_steps = FLAGS.iterations_per_loop
  block_metadata = {
      "first_epoch_num": cur_step + loop_steps,
      "epoch_count": FLAGS.iterations_per_loop,
  }
  mlp_log.mlperf_print("block_start", None, metadata=block_metadata)
def eval_finish_fn(cur_step, eval_output, _):
  """Logs `block_stop` and optionally queues detections for COCO eval.

  Args:
    cur_step: Current training step.
    eval_output: Dict of eval outputs; its 'detections' entry is put on
      `q_in` when `FLAGS.run_cocoeval` is set.
    _: Unused.
  """
  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  first_epoch = cur_step // steps_per_epoch
  epochs_per_loop = FLAGS.iterations_per_loop // steps_per_epoch
  mlp_log.mlperf_print(
      'block_stop', None,
      metadata={'first_epoch_num': first_epoch, 'epoch_count': epochs_per_loop})

  if not FLAGS.run_cocoeval:
    return
  q_in.put((cur_step, eval_output['detections']))
def init_lars_optimizer(current_epoch):
  """Creates the LARS optimizer driven by the polynomial rate schedule.

  Args:
    current_epoch: `Tensor` for current epoch, forwarded to
      `poly_rate_schedule`.

  Returns:
    A `contrib_opt.LARSOptimizer` configured from flags.
  """
  epsilon = FLAGS.lars_epsilon
  mlp_log.mlperf_print('lars_epsilon', epsilon)
  lr = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
  # Variables matching the skip list are handled specially by the optimizer.
  excluded_vars = ['batch_normalization', 'bias']
  return contrib_opt.LARSOptimizer(
      lr,
      momentum=FLAGS.momentum,
      weight_decay=FLAGS.weight_decay,
      skip_list=excluded_vars,
      epsilon=epsilon)
def log_eval_results_fn():
  """Drains `q_out` and prints the MLPerf eval log lines.

  Results are read until one whose first element equals `_STOP` arrives.
  Until the target is exceeded, each result is logged as `eval_accuracy` /
  `eval_stop`; results received after success are consumed silently. If the
  target is never exceeded, `run_stop` is logged with status 'abort'.
  """
  target_reached = False
  entry = q_out.get()
  while entry[0] != _STOP:
    if not target_reached:
      step, metrics = entry[0], entry[1]
      steps_per_epoch = (
          FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
      epoch = (step + FLAGS.iterations_per_loop) // steps_per_epoch
      epoch_metadata = {'epoch_num': epoch}
      mlp_log.mlperf_print(
          'eval_accuracy', metrics['COCO/AP'], metadata=epoch_metadata)
      mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': epoch})
      if metrics['COCO/AP'] > ssd_constants.EVAL_TARGET:
        target_reached = True
        mlp_log.mlperf_print(
            'run_stop', None, metadata={'status': 'success'})
    # Always fetch the next item so the queue keeps draining after success.
    entry = q_out.get()

  if not target_reached:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
  """Computes ROC AUC for a finished eval and logs the MLPerf events.

  Args:
    cur_step: Training step at which the eval ran.
    eval_output: Dict whose "results" entry reshapes to rows of
      (prediction, target).
    summary_writer: Optional writer that receives an "auc" scalar summary.

  Returns:
    True when the AUC reaches `_ACCURACY_THRESH`, otherwise False.
  """
  block_num = cur_step // FLAGS.steps_between_evals + 1
  mlp_log.mlperf_print("eval_stop", None, metadata={"epoch_num": block_num})
  mlp_log.mlperf_print(
      "block_stop", None, metadata={"first_epoch_num": block_num})
  tf.logging.info(
      "== Eval finished (step {}). Computing metric..".format(cur_step))

  # Column 0 holds predictions, column 1 holds targets.
  pairs = np.reshape(np.array(eval_output["results"]), (-1, 2))
  predictions = pairs[:, 0].astype(np.float32)
  targets = pairs[:, 1].astype(np.int32)
  auc = roc_metrics.RocMetrics(predictions, targets).ComputeRocAuc()
  tf.logging.info("== Eval shape: {}. AUC = {:.4f}".format(
      predictions.shape, auc))

  reached_target = auc >= _ACCURACY_THRESH
  mlp_log.mlperf_print(
      "eval_accuracy", auc, metadata={"epoch_num": block_num})
  if reached_target:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})

  if summary_writer:
    summary_writer.add_summary(
        utils.create_scalar_summary("auc", auc),
        global_step=cur_step + FLAGS.steps_between_evals)
  eval_metrics.append((cur_step + FLAGS.steps_between_evals, auc))
  return reached_target
def _default_eval_init_fn(cur_step):
  """Logs the `block_start` and `eval_start` events before an eval.

  Args:
    cur_step: Current training step.
  """
  # Blocks are numbered from 1, one block per eval interval.
  block_num = cur_step // FLAGS.steps_between_evals + 1
  tf.logging.info("== Block {}. Step {} of {}".format(
      block_num, cur_step, FLAGS.train_steps))

  block_metadata = {"first_epoch_num": block_num, "epoch_count": 1}
  mlp_log.mlperf_print("block_start", None, metadata=block_metadata)
  mlp_log.mlperf_print("eval_start", None, metadata={"epoch_num": block_num})
def eval_init_fn(cur_step):
  """Logs the `block_start` and `eval_start` events before an eval.

  Args:
    cur_step: Current training step.
  """
  steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
  first_epoch = cur_step // steps_per_epoch
  # Number of whole epochs covered by one training loop.
  epochs_per_loop = FLAGS.iterations_per_loop // steps_per_epoch

  mlp_log.mlperf_print(
      'block_start', None,
      metadata={'first_epoch_num': first_epoch, 'epoch_count': epochs_per_loop})
  mlp_log.mlperf_print(
      'eval_start', None,
      metadata={'epoch_num': first_epoch + epochs_per_loop})
def _write_metrics(eval_metrics, train_metrics, host_step,
                   total_training_steps, host_id):
  """Logs the accuracy metrics.

  Does nothing once `RUN_STOP` has been set. Otherwise fetches the metrics
  from device, logs the masked-LM accuracy and the mean training losses, and
  records the stop time and step in the `RUN_STOP` / `TOTAL_STEPS` globals
  when the accuracy reaches `FLAGS.target_accuracy`.

  Args:
    eval_metrics: Pytree with 'masked_lm_weighted_correct' and
      'masked_lm_weighted_count' entries.
    train_metrics: Pytree with 'total_loss', 'lm_loss' and 'sentence_loss'
      entries.
    host_step: Current step; also logged as the MLPerf `epoch_num`.
    total_training_steps: Total number of training steps, for log output.
    host_id: Unused.
  """
  del host_id
  global RUN_STOP
  global TOTAL_STEPS
  if RUN_STOP:
    return

  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias was deprecated and removed in newer JAX releases.
  eval_metrics = jax.tree_util.tree_map(jax.device_get, eval_metrics)
  train_metrics = jax.tree_util.tree_map(jax.device_get, train_metrics)

  masked_lm_accuracy = (
      np.sum(eval_metrics['masked_lm_weighted_correct']) /
      np.sum(eval_metrics['masked_lm_weighted_count']))
  total_loss = np.mean(train_metrics['total_loss'])
  lm_loss = np.mean(train_metrics['lm_loss'])
  sentence_loss = np.mean(train_metrics['sentence_loss'])

  mlp_log.mlperf_print('eval_accuracy', float(masked_lm_accuracy),
                       metadata={'epoch_num': host_step})
  logging.info('(Step %s / %s), masked_lm_accuracy: %s', host_step,
               total_training_steps, masked_lm_accuracy)
  logging.info(
      '(----Step %s / %s) Total loss: %s | LM loss: %s | Sentence loss: %s',
      host_step, total_training_steps, total_loss, lm_loss, sentence_loss)
  mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': host_step})

  if masked_lm_accuracy >= FLAGS.target_accuracy:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    # A non-zero RUN_STOP makes every later call return early.
    RUN_STOP = time.time()
    TOTAL_STEPS = host_step
def eval_finish_fn(cur_step, eval_output, summary_writer):
  """Logs eval accuracy after an eval and reports whether to stop.

  Args:
    cur_step: Training step at which the eval ran.
    eval_output: Dict whose 'total_correct' entry is summed to get the number
      of correct predictions.
    summary_writer: Optional writer that receives an 'accuracy' summary.

  Returns:
    True when the accuracy reaches `FLAGS.stop_threshold`, otherwise False.
  """
  # NOTE(review): `/` is kept as in the original; under true division the
  # logged epoch numbers are floats - confirm this is intended.
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  first_epoch = cur_step // steps_per_epoch

  num_correct = float(np.sum(eval_output['total_correct']))
  eval_accuracy = num_correct / FLAGS.num_eval_images

  if summary_writer:
    with tf.Graph().as_default():
      accuracy_value = tf.Summary.Value(
          tag='accuracy', simple_value=eval_accuracy)
      summary_writer.add_summary(
          tf.Summary(value=[accuracy_value]), cur_step)

  eval_epoch = first_epoch + FLAGS.iterations_per_loop // steps_per_epoch
  mlp_log.mlperf_print(
      'eval_accuracy', eval_accuracy, metadata={'epoch_num': eval_epoch})
  mlp_log.mlperf_print(
      'block_stop', None,
      metadata={'first_epoch_num': first_epoch, 'epoch_count': 4})

  if eval_accuracy >= FLAGS.stop_threshold:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    return True
  return False
def get_learning_rate(self, global_step):
  """Builds the learning rate schedule and logs its MLPerf hyperparameters.

  Args:
    global_step: Current global step, forwarded to
      `lr_policy.learning_rate_schedule`.

  Returns:
    The value returned by `lr_policy.learning_rate_schedule`.
  """
  base_lr = self.params['learning_rate']
  warmup_init = self.params['lr_warmup_init']
  warmup_steps = self.params['lr_warmup_step']
  first_drop = self.params['first_lr_drop_step']
  second_drop = self.params['second_lr_drop_step']

  schedule = lr_policy.learning_rate_schedule(
      base_lr, warmup_init, warmup_steps, first_drop, second_drop,
      global_step)

  mlp_log.mlperf_print(key='opt_base_learning_rate', value=base_lr)
  mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                       value=warmup_steps)
  # The logged warmup factor is the base rate divided by the warmup steps.
  mlp_log.mlperf_print(key='opt_learning_rate_warmup_factor',
                       value=base_lr / warmup_steps)
  return schedule
def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
  """Logs the run start events, then feeds every epoch's data.

  Args:
    sess: Session used to run the ops.
    train_enqueue_ops: Ops that enqueue one epoch of training data.
    eval_enqueue_ops: Ops that enqueue one round of eval data.
    eval_init: Op run at the start of every epoch, before enqueueing.
  """
  # Fixed delay before the timed run is declared started.
  time.sleep(150)
  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)
  first_block = {"first_epoch_num": 1, "epoch_count": 1}
  mlp_log.mlperf_print("block_start", None, metadata=first_block)

  for epoch_index in range(self.hparams.max_train_epochs):
    tf.logging.info("Infeed for epoch: %d", epoch_index + 1)
    sess.run(eval_init)
    sess.run([train_enqueue_ops])
    sess.run([eval_enqueue_ops])
def eval_finish_fn(cur_step, eval_output, summary_writer):
  """Logs masked-LM accuracy after an eval and reports whether to stop.

  Updates the module-level `run_steps` and `masked_lm_accuracy`.

  Args:
    cur_step: Step reported by the runner; `FLAGS.iterations_per_loop` is
      added to it before logging.
    eval_output: Dict with "masked_lm_weighted_correct" and
      "masked_lm_weighted_count" entries.
    summary_writer: Optional writer that receives a "masked_lm_accuracy"
      summary.

  Returns:
    True when the accuracy reaches `FLAGS.stop_threshold` and at least six
    training loops have run, otherwise False.
  """
  global run_steps, masked_lm_accuracy

  cur_step_corrected = cur_step + FLAGS.iterations_per_loop
  run_steps = cur_step_corrected

  correct = eval_output["masked_lm_weighted_correct"]
  count = eval_output["masked_lm_weighted_count"]
  masked_lm_accuracy = np.sum(correct) / np.sum(count)
  # The eval_output may mix up the order of the two arrays; a ratio above 1
  # means they were swapped, so invert it.
  if masked_lm_accuracy > 1:
    masked_lm_accuracy = 1 / masked_lm_accuracy

  if summary_writer:
    with tf.Graph().as_default():
      accuracy_value = tf.Summary.Value(
          tag="masked_lm_accuracy", simple_value=masked_lm_accuracy)
      summary_writer.add_summary(
          tf.Summary(value=[accuracy_value]), cur_step_corrected)

  mlp_log.mlperf_print(
      "block_stop", None,
      metadata={"first_epoch_num": cur_step_corrected})
  # While BERT pretraining does not have epochs,
  # to make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print(
      "eval_accuracy", float(masked_lm_accuracy),
      metadata={"epoch_num": cur_step_corrected})

  should_stop = (
      masked_lm_accuracy >= FLAGS.stop_threshold and
      cur_step_corrected >= FLAGS.iterations_per_loop * 6)
  if not should_stop:
    return False
  mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
  return True
def train_and_predict(self):
  """Run the predict loop on the TPU device.

  Starts one thread that runs the combined train/eval op and one thread that
  feeds data for every epoch. When `self.eval_steps > 0`, the calling thread
  scores the predictions of each epoch, logs the MLPerf eval events and stops
  scoring once `hparams.target_bleu` is reached. Both threads are joined
  before returning.

  Returns:
    `(score, current_step)` of the last scored epoch when
    `self.eval_steps > 0` (`(0.0, 0)` if no epoch was scored), otherwise
    `(None, None)`.
  """
  self.sess.run([self.compile_op])

  # Train and eval thread.
  def train_eval_thread_fn(sess, train_eval_op):
    tf.logging.info("train_eval_op start")
    sess.run([train_eval_op])

  train_eval_thread = threading.Thread(
      target=train_eval_thread_fn, args=(self.sess, self.train_eval_op))
  train_eval_thread.start()

  # Infeed thread.
  def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
    """Start the infeed."""
    time.sleep(150)
    mlp_log.mlperf_print("init_stop", None)
    mlp_log.mlperf_print("run_start", None)
    mlp_log.mlperf_print(
        "block_start", None,
        metadata={"first_epoch_num": 1, "epoch_count": 1})
    for i in range(self.hparams.max_train_epochs):
      tf.logging.info("Infeed for epoch: %d", i + 1)
      sess.run(eval_init)
      sess.run([train_enqueue_ops])
      sess.run([eval_enqueue_ops])

  infeed_thread = threading.Thread(
      target=infeed_thread_fn,
      args=(self.sess, self.enqueue_ops, self.eval_enqueue_ops,
            self.eval_dataset_initializer))
  infeed_thread.start()

  # Bound up front so the final return is valid even when the epoch loop
  # never executes (max_train_epochs == 0); previously that case raised
  # UnboundLocalError.
  current_step = 0
  if self.eval_steps > 0:
    eval_state = {"run_success": False, "score": 0.0}

    for epoch in range(self.hparams.max_train_epochs):
      predictions = list(self.predict())
      mlp_log.mlperf_print(
          "eval_start", None, metadata={"epoch_num": epoch + 1})
      current_step = epoch * self.iterations
      eval_state["score"] = metric.get_metric(
          self.hparams, predictions, current_step)
      tf.logging.info("Score after epoch %d: %f", epoch, eval_state["score"])
      # The score is divided by 100 for the MLPerf accuracy log line.
      mlp_log.mlperf_print(
          "eval_accuracy", eval_state["score"] / 100.0,
          metadata={"epoch_num": epoch + 1})
      mlp_log.mlperf_print(
          "eval_stop", None, metadata={"epoch_num": epoch + 1})
      mlp_log.mlperf_print(
          "block_stop", None,
          metadata={"first_epoch_num": epoch + 1, "epoch_count": 1})
      if eval_state["score"] >= self.hparams.target_bleu:
        eval_state["run_success"] = True
        mlp_log.mlperf_print(
            "run_stop", None, metadata={"status": "success"})
        break
      mlp_log.mlperf_print(
          "block_start", None,
          metadata={"first_epoch_num": epoch + 2, "epoch_count": 1})

    if not eval_state["run_success"]:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})

  infeed_thread.join()
  train_eval_thread.join()

  if self.eval_steps > 0:
    return eval_state["score"], current_step
  else:
    return None, None
def main(argv):
  """Trains and evaluates the Mask-RCNN model with MLPerf logging.

  Builds the model params from hparams and flags, logs the MLPerf
  hyperparameters, initializes a `TrainAndEvalRunner` with the train and eval
  input pipelines, and runs training with per-eval logging callbacks.

  Args:
    argv: Unused.
  """
  del argv  # Unused.

  # TODO(b/132208296): remove this workaround that uses control flow v2.
  control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

  # Parse hparams, then layer the flag-derived settings on top of them.
  hparams = mask_rcnn_params.default_hparams()
  hparams.parse(FLAGS.hparams)
  params = dict(hparams.values())
  params.update(
      transpose_input=FLAGS.input_partition_dims is None,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      num_cores_per_replica=(
          int(np.prod(FLAGS.input_partition_dims))
          if FLAGS.input_partition_dims else 1),
      replicas_per_host=FLAGS.replicas_per_host)

  # MLPerf logging.
  mlp_log.mlperf_print(key='cache_clear', value=True)
  mlp_log.mlperf_print(key='init_start', value=None)
  mlp_log.mlperf_print(key='global_batch_size', value=FLAGS.train_batch_size)
  mlp_log.mlperf_print(key='train_samples',
                       value=FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print(key='eval_samples', value=FLAGS.eval_samples)
  mlp_log.mlperf_print(key='min_image_size',
                       value=params['short_side_image_size'])
  mlp_log.mlperf_print(key='max_image_size',
                       value=params['long_side_max_image_size'])
  mlp_log.mlperf_print(key='num_image_candidates',
                       value=params['rpn_post_nms_topn'])

  train_steps = (
      FLAGS.num_epochs * FLAGS.num_examples_per_epoch //
      FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(float(FLAGS.eval_samples) / FLAGS.eval_batch_size))
  if eval_steps > 0:
    # The eval dataset is not evenly divided. Adding step by one will make sure
    # all eval samples are covered.
    # TODO(b/151732586): regenerate the eval dataset to make all hosts get the
    # same amount of work.
    eval_steps += 1

  runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.num_examples_per_epoch // FLAGS.train_batch_size,
      train_steps, eval_steps, FLAGS.num_shards)

  train_input_fn = dataloader.InputReader(
      FLAGS.training_file_pattern,
      mode=tf.estimator.ModeKeys.TRAIN,
      use_fake_data=FLAGS.use_fake_data)
  eval_reader = dataloader.InputReader(
      FLAGS.validation_file_pattern,
      mode=tf.estimator.ModeKeys.PREDICT,
      distributed_eval=True)
  eval_input_fn = functools.partial(
      eval_reader, num_examples=eval_steps * FLAGS.eval_batch_size)
  eval_metric = coco_metric.EvaluationMetric(
      FLAGS.val_json_file, use_cpp_extension=True)

  def init_fn():
    """Warm-starts the ResNet variables when a checkpoint is given."""
    if FLAGS.resnet_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.resnet_checkpoint,
                                    {'resnet/': 'resnet50/'})

  runner.initialize(train_input_fn, eval_input_fn,
                    mask_rcnn_model.MaskRcnnModelFn(params),
                    FLAGS.train_batch_size, FLAGS.eval_batch_size,
                    FLAGS.input_partition_dims, init_fn, params=params)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  def epoch_of(step):
    """Returns the epoch index of `step` (0 if an epoch has no steps)."""
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    if steps_per_epoch == 0:
      return 0
    return step // steps_per_epoch

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    mlp_log.mlperf_print(
        'block_start', None,
        metadata={'first_epoch_num': epoch_of(cur_step), 'epoch_count': 1})

  def eval_finish_fn(cur_step, eval_output, _):
    """Callback function that's executed after each eval."""
    if eval_steps == 0:
      return False

    # Concat eval_output as eval_output is a list from each host.
    for name in eval_output:
      eval_output[name] = np.concatenate(eval_output[name], axis=0)

    cur_epoch = epoch_of(cur_step)
    eval_epoch = cur_epoch + 1
    mlp_log.mlperf_print(
        'block_stop', None,
        metadata={'first_epoch_num': cur_epoch, 'epoch_count': 1})

    eval_multiprocess.eval_multiprocessing(
        eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

    mlp_log.mlperf_print(
        'eval_start', None, metadata={'epoch_num': eval_epoch})
    _, eval_results = eval_metric.evaluate()
    accuracy = {
        'BBOX': float(eval_results['AP']),
        'SEGM': float(eval_results['mask_AP']),
    }
    mlp_log.mlperf_print(
        'eval_accuracy', accuracy, metadata={'epoch_num': eval_epoch})
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': eval_epoch})

    targets_met = (
        eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET)
    if not targets_met:
      return False
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    return True

  def run_finish_fn(success):
    """Logs an aborted run when the targets were never reached."""
    if success:
      return
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})

  runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
def run_finish_fn(success):
  """Logs `run_stop` with status 'abort' when the run did not succeed.

  Args:
    success: Truthy when the run reached its target; nothing is logged then.
  """
  if success:
    return
  mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
def resnet_model_fn(features, labels, is_training):
  """The model_fn for ResNet to be used with TPU.

  Args:
    features: `Tensor` of batched images, or a dict holding that tensor under
      the 'feature' key.
    labels: `Tensor` of labels for the data samples
    is_training: whether this is training

  Returns:
    A 2-tuple. In training: `(train_op, None)`. In eval:
    `(None, {'total_correct': tensor})` where the tensor holds the number of
    correctly classified samples, reshaped to rank 1.

  Raises:
    ValueError: if `FLAGS.precision` is neither 'bfloat16' nor 'float32'.
  """
  if isinstance(features, dict):
    features = features['feature']

  # Space-to-depth halves the spatial size and uses 12 channels instead of 3.
  if FLAGS.use_space_to_depth:
    spatial_size, num_channels = FLAGS.image_size // 2, 12
  else:
    spatial_size, num_channels = FLAGS.image_size, 3

  # The incoming layout depends on the per-replica batch size.
  if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
    features = tf.reshape(
        features, [spatial_size, spatial_size, num_channels, -1])
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
  else:
    features = tf.reshape(
        features, [spatial_size, spatial_size, -1, num_channels])
    features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC

  # Normalize the image to zero mean and unit variance.
  # NOTE(review): assumes MEAN_RGB / STDDEV_RGB are compatible with a
  # 12-channel shape when space-to-depth is enabled - confirm.
  features -= tf.constant(
      MEAN_RGB, shape=[1, 1, num_channels], dtype=features.dtype)
  features /= tf.constant(
      STDDEV_RGB, shape=[1, 1, num_channels], dtype=features.dtype)

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    with tf.variable_scope('resnet', reuse=tf.AUTO_REUSE):
      network = resnet_model.resnet_v1(
          resnet_depth=FLAGS.resnet_depth,
          num_classes=FLAGS.num_label_classes,
          use_space_to_depth=FLAGS.use_space_to_depth,
          num_replicas=FLAGS.num_replicas,
          distributed_group_size=FLAGS.distributed_group_size)
      return network(inputs=features, is_training=is_training)

  if FLAGS.precision == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()
  else:
    # Previously this fell through and failed later with an unbound `logits`.
    raise ValueError(
        'Unsupported precision: {!r}; expected "bfloat16" or '
        '"float32".'.format(FLAGS.precision))

  if not is_training:
    predicted = tf.cast(tf.argmax(logits, axis=1), labels.dtype)
    total_correct = tf.reduce_sum(
        tf.cast(tf.equal(predicted, labels), tf.int32))
    return None, {'total_correct': tf.reshape(total_correct, [-1])}

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  # With LARS enabled no L2 term is added to the loss here.
  if FLAGS.enable_lars:
    loss = cross_entropy
  else:
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

  global_step = tf.train.get_or_create_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)

  mlp_log.mlperf_print(
      'model_bn_span',
      FLAGS.distributed_group_size *
      (FLAGS.train_batch_size // FLAGS.num_replicas))

  if FLAGS.enable_lars:
    mlp_log.mlperf_print('opt_name', 'lars')
    optimizer = lars_util.init_lars_optimizer(current_epoch)
  else:
    mlp_log.mlperf_print('opt_name', 'sgd')
    learning_rate = learning_rate_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=FLAGS.momentum,
        use_nesterov=True)
  optimizer = tf.tpu.CrossShardOptimizer(optimizer)

  # Batch normalization requires UPDATE_OPS to be added as a dependency to
  # the train operation.
  with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss, global_step)
  return train_op, None
def run_model(params, eval_init_fn=None, eval_finish_fn=None,
              run_finish_fn=None):
  """Run the DLRM model, using a pre-defined configuration.

  Logs the MLPerf hyperparameters, builds the TPU embedding and the runner,
  then trains and evaluates using either the supplied callbacks or the
  default ones defined inside this function.

  Args:
    params: HPTuner object that provides new params for the trial.
    eval_init_fn: Lambda to run at start of eval. None means use the default.
    eval_finish_fn: Lambda for end of eval. None means use the default.
    run_finish_fn: Lambda for end of execution. None means use the default.

  Returns:
    A list of tuples, each entry describing the eval metric for one eval. Each
    tuple entry is (global_step, metric_value).
  """
  mlp_log.mlperf_print(key="cache_clear", value=True)
  mlp_log.mlperf_print(key="init_start", value=None)
  mlp_log.mlperf_print("global_batch_size", params["batch_size"])
  mlp_log.mlperf_print("train_samples", _NUM_TRAIN_EXAMPLES)
  mlp_log.mlperf_print("eval_samples", _NUM_EVAL_EXAMPLES)
  # The logged rate is scaled by batch size relative to 2048.
  adjusted_lr = params["learning_rate"] * (params["batch_size"] / 2048.0)
  mlp_log.mlperf_print("opt_base_learning_rate", adjusted_lr)
  mlp_log.mlperf_print("sgd_opt_base_learning_rate", adjusted_lr)
  mlp_log.mlperf_print("sgd_opt_learning_rate_decay_poly_power", 2)
  mlp_log.mlperf_print("sgd_opt_learning_rate_decay_steps",
                       params["decay_steps"])
  mlp_log.mlperf_print("lr_decay_start_steps", params["decay_start_step"])
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                       params["lr_warmup_steps"])

  # Used for vizier. List of tuples. Each entry is (global_step, auc_metric).
  eval_metrics = [(0, 0.0)]

  feature_config = fc.FeatureConfig(params)
  feature_to_config_dict, table_to_config_dict = (
      feature_config.get_feature_tbl_config())

  # Both parameter objects are built up front; params["optimizer"] picks one.
  sgd_params = tpu_embedding.StochasticGradientDescentParameters(
      learning_rate=params["learning_rate"])
  adagrad_params = tpu_embedding.AdagradParameters(
      learning_rate=params["learning_rate"],
      initial_accumulator=params["adagrad_init_accum"])
  opt_params = {"sgd": sgd_params, "adagrad": adagrad_params}

  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      params["batch_size"],
      mode=tpu_embedding.TRAINING,
      optimization_parameters=opt_params[params["optimizer"]],
      partition_strategy="mod",
      pipeline_execution_with_tensor_core=FLAGS.pipeline_execution,
      master=FLAGS.master)

  runner = dlrm_embedding_runner.DLRMEmbeddingRunner(
      iterations_per_loop=FLAGS.steps_between_evals,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      num_replicas=FLAGS.num_tpu_shards,
      sparse_features_key="cat-features",
      embedding=embedding)

  train_input_fn, eval_input_fn = get_input_fns(params, feature_config)
  model_fn = functools.partial(dlrm.dlrm_llr_model_fn, params, feature_config)
  runner.initialize(
      train_input_fn,
      eval_input_fn,
      model_fn,
      params["batch_size"],
      params["eval_batch_size"],
      train_has_labels=False,
      eval_has_labels=False)

  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)

  def _default_eval_init_fn(cur_step):
    """Logging statements executed before every eval."""
    block_num = cur_step // FLAGS.steps_between_evals + 1
    tf.logging.info("== Block {}. Step {} of {}".format(
        block_num, cur_step, FLAGS.train_steps))
    mlp_log.mlperf_print(
        "block_start", None,
        metadata={"first_epoch_num": block_num, "epoch_count": 1})
    mlp_log.mlperf_print(
        "eval_start", None, metadata={"epoch_num": block_num})

  def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
    """Computes ROC AUC for one eval, logs it, and returns success."""
    block_num = cur_step // FLAGS.steps_between_evals + 1
    mlp_log.mlperf_print(
        "eval_stop", None, metadata={"epoch_num": block_num})
    mlp_log.mlperf_print(
        "block_stop", None, metadata={"first_epoch_num": block_num})
    tf.logging.info(
        "== Eval finished (step {}). Computing metric..".format(cur_step))

    # Column 0 holds predictions, column 1 holds targets.
    pairs = np.reshape(np.array(eval_output["results"]), (-1, 2))
    predictions = pairs[:, 0].astype(np.float32)
    targets = pairs[:, 1].astype(np.int32)
    auc = roc_metrics.RocMetrics(predictions, targets).ComputeRocAuc()
    tf.logging.info("== Eval shape: {}. AUC = {:.4f}".format(
        predictions.shape, auc))

    reached_target = auc >= _ACCURACY_THRESH
    mlp_log.mlperf_print(
        "eval_accuracy", auc, metadata={"epoch_num": block_num})
    if reached_target:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
    if summary_writer:
      summary_writer.add_summary(
          utils.create_scalar_summary("auc", auc),
          global_step=cur_step + FLAGS.steps_between_evals)
    eval_metrics.append((cur_step + FLAGS.steps_between_evals, auc))
    return reached_target

  def _default_run_finish_fn(success_status):
    """Logs a failed run, then retrieves the embedding variables."""
    run_failed = not success_status
    if run_failed:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "failure"})
    tf.logging.info("Retrieving embedding vars and writing stats.")
    runner.retrieve_embedding_vars()

  runner.train_and_eval(
      eval_init_fn=eval_init_fn or _default_eval_init_fn,
      eval_finish_fn=eval_finish_fn or _default_eval_finish_fn,
      run_finish_fn=run_finish_fn or _default_run_finish_fn)

  return eval_metrics
def run_pretrain(optimizer):
  """Run BERT pretraining with periodic evaluation and MLPerf logging.

  Two training modes are selected by `FLAGS.infeed`:
    * infeed (device loop): a pmapped `lax.while_loop` runs a whole "epoch"
      (`num_steps_per_epoch` steps) on device while the host pushes batches
      to the device infeed from a background thread pool.
    * host loop: the host calls the pmapped `train_step` once per batch.

  In the MLPerf log lines, "epochs" are steps and "examples" are sequences.

  Args:
    optimizer: optimizer wrapping the BERT pretraining model; its `.target`
      is what gets passed to the eval step.

  Returns:
    A tuple `(optimizer, result_stats)`. `optimizer` is the optimizer after
    training. `result_stats` is a dict that contains 'total_time' and
    'total_steps' only when the module-level `RUN_STOP` is truthy at the end
    of the run; otherwise it is empty.
  """
  result_stats = {}

  def get_input_context():
    """Builds an object describing this host's slice of the input pipeline."""

    class InputContext():

      def __init__(self):
        self.input_pipeline_id = jax.host_id()
        self.num_input_pipelines = jax.host_count()

    return InputContext()

  # Single worker, used below to run `_write_metrics` off the main thread.
  summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  host_id = jax.host_id()
  # Get input dataset
  input_files = []
  for input_pattern in FLAGS.input_files.split(','):
    input_files.extend(tf.io.gfile.glob(input_pattern))
  logging.info('*** Input Files ***')
  for input_file in input_files:
    logging.info(' %s', input_file)

  eval_input_files = []
  for input_pattern in FLAGS.eval_input_files.split(','):
    eval_input_files.extend(tf.io.gfile.glob(input_pattern))
  logging.info('*** Eval Input Files ***')
  for input_file in eval_input_files:
    logging.info(' %s', input_file)

  train_input_fn = input_pipeline.input_fn_builder(
      input_files=input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=True,
      num_cpu_threads=8)

  # The global batch sizes are split evenly across hosts.
  host_train_batch_size = FLAGS.train_batch_size // jax.host_count()
  host_eval_batch_size = FLAGS.eval_batch_size // jax.host_count()

  params = {'batch_size': host_train_batch_size}
  input_context = get_input_context()
  train_dataset = train_input_fn(params, input_context)
  train_iterator = iter(train_dataset)

  eval_input_fn = input_pipeline.input_fn_builder(
      input_files=eval_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=False,
      num_cpu_threads=8,
      global_input_size=FLAGS.eval_sample_size)
  eval_params = {'batch_size': host_eval_batch_size}
  eval_dataset = eval_input_fn(eval_params, input_context)
  eval_iterator = iter(eval_dataset)

  # train step
  total_training_steps = FLAGS.total_training_steps
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate,
      warmup_steps=FLAGS.warmup_steps,
      total_training_steps=FLAGS.total_training_steps,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step)

  # Device training loop cond.
  def device_train_loop_cond(args):
    """Keeps looping while `step` still falls inside epoch number `epoch`."""
    _, _, _, _, _, _, step, epoch, num_steps_per_epoch = args
    return step // num_steps_per_epoch == epoch

  # Device training loop body.
  def device_train_loop_body(args):
    """Device training loop body."""
    (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token,
     step, epoch, num_steps_per_epoch) = args
    # Shapes must be known statically for the infeed; they are derived from
    # the global batch size split across all devices.
    device_batch_size = FLAGS.train_batch_size // jax.device_count()
    input_shape = [device_batch_size, FLAGS.max_seq_length]
    input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq]
    # The tuple order here must match the order used by the host when it
    # calls `transfer_to_infeed` in the training loop below.
    (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
     masked_lm_weights, next_sentence_labels), token = lax.infeed(
         token,
         shape=(jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.float32),
                jax.ShapedArray([device_batch_size, 1], jnp.int32)))
    inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
    labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
    optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step(
        optimizer,
        inputs,
        labels,
        learning_rate_fn,
        dropout_rng=new_dropout_rng)
    step += 1
    return (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng,
            token, step, epoch, num_steps_per_epoch)

  # Device training loop.
  def device_train_loop(optimizer, dropout_rng, total_loss, lm_loss,
                        sentence_loss, step, epoch, num_steps_per_epoch):
    """Device training loop."""
    token = lax.create_token(step)
    (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, _, step, epoch,
     num_steps_per_epoch) = lax.while_loop(
         device_train_loop_cond, device_train_loop_body,
         (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, token,
          step, epoch, num_steps_per_epoch))
    return optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, step

  if FLAGS.infeed:
    pmap_fn = jax.pmap
    if FLAGS.enable_buffer_donation:
      pmap_fn = functools.partial(pmap_fn, donate_argnums=(0, 1))
    if FLAGS.enable_wus:
      # Only the dropout rng (argument 1) is mapped over devices; every other
      # argument is passed unmapped (in_axes None).
      pmap_fn = functools.partial(
          pmap_fn, in_axes=(None, 0, None, None, None, None, None, None))
    p_train_epoch = pmap_fn(device_train_loop, axis_name='batch')
  else:
    # without infeed.
    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn),
        axis_name='batch')
  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  pmap_fn = jax.pmap
  # Weight update sharding is not implemented yet for host train loop.
  # Enable wus on eval only if device loop is used.
  if FLAGS.enable_wus and FLAGS.infeed:
    pmap_fn = functools.partial(pmap_fn, in_axes=(None, 0, 0))
  p_eval_step = pmap_fn(eval_step, axis_name='batch')

  rng = random.PRNGKey(0)
  device_count = jax.local_device_count()
  dropout_rngs = random.split(rng, device_count)
  num_steps_per_epoch = np.int32(FLAGS.num_steps_per_epoch)
  if FLAGS.precompile:
    # Run the train and eval functions once before 'init_stop' is logged, so
    # these first calls are not part of the timed run.
    if FLAGS.infeed:
      if FLAGS.enable_wus:
        total_loss = np.float32(0.0)
        lm_loss = np.float32(0.0)
        sentence_loss = np.float32(0.0)
        host_step = 0
        host_epoch = 1
        optimizer = unbroadcast(optimizer)
        # the device training loop condition will immediately be false
        optimizer, total_loss, lm_loss, sentence_loss, _, _ = p_train_epoch(
            optimizer, dropout_rngs, total_loss, lm_loss, sentence_loss,
            host_step, host_epoch, num_steps_per_epoch)
      else:
        total_loss = jax_utils.replicate(np.float32(0.0))
        lm_loss = jax_utils.replicate(np.float32(0.0))
        sentence_loss = jax_utils.replicate(np.float32(0.0))
        device_step = jax_utils.replicate(0)
        device_epoch = jax_utils.replicate(1)
        # the device training loop condition will immediately be false
        optimizer, total_loss, lm_loss, sentence_loss, _, _ = p_train_epoch(
            optimizer, dropout_rngs, total_loss, lm_loss, sentence_loss,
            device_step, device_epoch, jax_utils.replicate(num_steps_per_epoch))
    else:
      # Host loop: call the train step once on random placeholder inputs of
      # the per-host training shapes.
      train_input_shape = (host_train_batch_size, FLAGS.max_seq_length)
      train_input_shape_pred = (host_train_batch_size,
                                FLAGS.max_predictions_per_seq)
      word_id_data = jax.random.randint(rng, train_input_shape, 0, 10)
      mask_data = jax.random.randint(rng, train_input_shape, 0, 1)
      type_id_data = jax.random.randint(rng, train_input_shape, 0, 3)
      lm_mask = jax.random.randint(rng, train_input_shape_pred, 0, 5)
      masked_lm_ids = jax.random.randint(rng, train_input_shape_pred, 0, 2)
      masked_lm_weights = jax.random.randint(rng, train_input_shape_pred, 1,
                                             1).astype(np.float32)
      next_sentence_labels = jax.random.randint(rng, (host_train_batch_size, 1),
                                                0, 1)

      labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
      train_inputs = [word_id_data, mask_data, type_id_data, lm_mask]
      train_inputs = common_utils.shard(train_inputs)
      labels = common_utils.shard(labels)
      p_train_step(optimizer, train_inputs, labels, dropout_rng=dropout_rngs)

    # Same idea for the eval step, using the per-host eval shapes.
    eval_input_shape = (host_eval_batch_size, FLAGS.max_seq_length)
    eval_input_shape_pred = (host_eval_batch_size,
                             FLAGS.max_predictions_per_seq)
    word_id_data = jax.random.randint(rng, eval_input_shape, 0, 10)
    mask_data = jax.random.randint(rng, eval_input_shape, 0, 1)
    type_id_data = jax.random.randint(rng, eval_input_shape, 0, 3)
    lm_mask = jax.random.randint(rng, eval_input_shape_pred, 0, 5)
    masked_lm_ids = jax.random.randint(rng, eval_input_shape_pred, 0, 2)
    masked_lm_weights = jax.random.randint(
        rng, eval_input_shape_pred, 1, 1).astype(np.float32)
    next_sentence_labels = jax.random.randint(rng, (host_eval_batch_size, 1), 0,
                                              1)

    eval_inputs = {
        'input_ids': word_id_data,
        'input_mask': mask_data,
        'segment_ids': type_id_data,
        'masked_lm_positions': lm_mask,
        'masked_lm_ids': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
        'next_sentence_labels': next_sentence_labels
    }

    eval_inputs = common_utils.shard(eval_inputs)
    metrics = empty_metrics()
    optimizer_target = optimizer.target
    # Weight update sharding is not implemented yet for host train loop.
    # Enable wus on eval only if device loop is used.
    if FLAGS.enable_wus and FLAGS.infeed:
      optimizer_target = unbroadcast(optimizer_target)
    metrics = p_eval_step(optimizer_target, eval_inputs, metrics)
    metrics = allreduce_metrics(metrics)
  metrics = empty_metrics()
  time.sleep(FLAGS.init_sleep)
  # Block until the all-reduce has produced its result before declaring
  # initialization finished.
  allreduce_metrics(metrics)['masked_lm_weighted_correct'].block_until_ready()
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  # To make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print('train_samples',
                       FLAGS.total_training_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print('eval_samples', FLAGS.eval_sample_size)
  xprof = None
  run_start = time.time()
  # NOTE(review): these module globals are only reset here; they are
  # presumably updated elsewhere (e.g. by `_write_metrics`) once the target
  # is reached. `RUN_STOP` is later used as a timestamp -- confirm.
  global RUN_STOP
  global TOTAL_STEPS
  RUN_STOP = False
  TOTAL_STEPS = False

  # Profiling is only ever started on host 0.
  if host_id == 0:
    if FLAGS.end_to_end_profile:
      xprof = xprof_session.XprofSession()
      xprof.start_session(device_name='REDACTED',
                          enable_python_tracer=True,
                          host_trace_level=2)
    elif FLAGS.profile:
      profile_with_xprof_on_background(start_after_sec=FLAGS.profile_latency,
                                       profile_time_sec=FLAGS.profile_duration)

  if FLAGS.infeed:
    # h_* are plain host scalars (used with enable_wus); d_* are replicated
    # per-device copies (used without it).
    h_total_loss = np.float32(0.0)
    h_lm_loss = np.float32(0.0)
    h_sentence_loss = np.float32(0.0)

    d_total_loss = jax_utils.replicate(np.float32(0.0))
    d_lm_loss = jax_utils.replicate(np.float32(0.0))
    d_sentence_loss = jax_utils.replicate(np.float32(0.0))

  host_step, device_step = 0, jax_utils.replicate(0)
  device_epoch = jax_utils.replicate(0)
  num_train_epochs = FLAGS.total_training_steps // FLAGS.num_steps_per_epoch
  steps_per_epoch = num_steps_per_epoch
  if num_train_epochs >= 6:
    # Merge the first 6 epochs, as we do not have to do eval.
    steps_per_epoch = np.int32(num_steps_per_epoch * 6)
  for host_epoch in range(num_train_epochs):
    block_step = host_step
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num': block_step,
            'epoch_count': FLAGS.num_steps_per_epoch
        })

    # When the first six epochs are merged, the device loop launched at
    # host_epoch 0 already covers epochs 1-5, so no new launch happens then.
    if not (num_train_epochs >= 6 and
            host_epoch in (1, 2, 3, 4, 5)) and FLAGS.infeed:
      if FLAGS.enable_wus:
        optimizer = unbroadcast(optimizer)
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs,
         device_step) = p_train_epoch(optimizer, dropout_rngs, h_total_loss,
                                      h_lm_loss, h_sentence_loss, host_step,
                                      host_epoch, steps_per_epoch)
      else:
        device_epoch = jax_utils.replicate(host_epoch)
        device_steps_per_epoch = jax_utils.replicate(steps_per_epoch)
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs,
         device_step) = p_train_epoch(optimizer, dropout_rngs, d_total_loss,
                                      d_lm_loss, d_sentence_loss, device_step,
                                      device_epoch, device_steps_per_epoch)
    # After first epoch, reduce the steps per epoch back to normal number.
    steps_per_epoch = num_steps_per_epoch

    # Training for one epoch.
    while int(host_step // FLAGS.num_steps_per_epoch) == host_epoch:
      input_data = next(train_iterator)
      input_data = jax.tree_map(lambda x: x.numpy(), input_data)
      input_data = jax.tree_map(common_utils.shard, input_data)
      input_ids = input_data['input_ids']
      input_mask = input_data['input_mask']
      segment_ids = input_data['segment_ids']
      masked_lm_positions = input_data['masked_lm_positions']
      masked_lm_ids = input_data['masked_lm_ids']
      masked_lm_weights = input_data['masked_lm_weights']
      next_sentence_labels = input_data['next_sentence_labels']

      # Infeed data to infeed queue.
      if FLAGS.infeed:
        # One transfer per local device; index i selects that device's shard.
        for i, device in enumerate(jax.local_devices()):
          infeed_pool.submit(
              partial(device.transfer_to_infeed,
                      (input_ids[i], input_mask[i], segment_ids[i],
                       masked_lm_positions[i], masked_lm_ids[i],
                       masked_lm_weights[i], next_sentence_labels[i])))
      else:
        inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
        labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs
        ) = p_train_step(optimizer, inputs, labels, dropout_rng=dropout_rngs)
      host_step += 1

    mlp_log.mlperf_print('block_stop', None, metadata={
        'first_epoch_num': block_step,
        'epoch_count': FLAGS.num_steps_per_epoch
    })
    # No need to do eval in the first 5 epochs as it has to traverse min 3M
    # samples.
    if host_epoch < 5:
      continue
    if host_step % FLAGS.num_steps_per_epoch == 0:
      mlp_log.mlperf_print(
          'eval_start', None, metadata={'epoch_num': host_step})
      optimizer_target = optimizer.target
      if FLAGS.enable_wus and FLAGS.infeed:
        optimizer_target = unbroadcast(optimizer_target)
      metrics = empty_metrics()
      for _ in range(FLAGS.max_eval_steps):
        inputs = jax.tree_map(lambda x: x.numpy(), next(eval_iterator))
        inputs = jax.tree_map(common_utils.shard, inputs)
        # Weight update sharding is not implemented yet for host train loop.
        # Enable wus on eval only if device loop is used.
        metrics = p_eval_step(optimizer_target, inputs, metrics)
      metrics = allreduce_metrics(metrics)
      train_metrics = {'total_loss': total_loss, 'lm_loss': lm_loss,
                       'sentence_loss': sentence_loss}
      # masked_lm_accuracy = get_masked_lm_accuracy(metrics)
      # Metrics are handed to the summary thread so the training loop does not
      # wait on `_write_metrics`.
      summary_thread.submit(partial(
          _write_metrics, metrics, train_metrics,
          host_step, total_training_steps, host_id))
    if host_step % FLAGS.num_steps_per_epoch == 0 and FLAGS.save_checkpoint:
      if host_id == 0:
        checkpoints.save_checkpoint(
            FLAGS.model_dir, optimizer, host_step, prefix='checkpoint', keep=1)
  allreduce_metrics(metrics)['masked_lm_weighted_correct'].block_until_ready()
  # Wait for any pending `_write_metrics` call before checking RUN_STOP.
  summary_thread.shutdown()
  if not RUN_STOP:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
  mlp_log.mlperf_print('run_final', None)

  if host_id == 0:
    if FLAGS.end_to_end_profile:
      xprof_url = xprof.end_session_and_get_url(tag='')
      logging.info('Xprof profile is at %s', xprof_url)

  if RUN_STOP:
    result_stats['total_time'] = RUN_STOP - run_start
    result_stats['total_steps'] = TOTAL_STEPS
  return optimizer, result_stats
def run_exp():
  """Log the MLPerf run header for BERT, build the model and pretrain it.

  Emits the initialization and hyperparameter log lines, creates the
  pretraining model and its optimizer, then delegates to `run_pretrain`.

  Returns:
    The `result_stats` dict produced by `run_pretrain`.
  """
  # `cache_clear` is logged with True, matching the other entry points in
  # this file (it was previously logged with None here).
  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('opt_learning_rate_warmup_steps', FLAGS.warmup_steps)
  mlp_log.mlperf_print('num_warmup_steps', FLAGS.warmup_steps)
  mlp_log.mlperf_print('start_warmup_step', FLAGS.start_warmup_step)
  mlp_log.mlperf_print('opt_lamb_weight_decay_rate', FLAGS.lamb_weight_decay)
  mlp_log.mlperf_print('max_sequence_length', FLAGS.max_seq_length)
  mlp_log.mlperf_print('opt_base_learning_rate', FLAGS.learning_rate)
  mlp_log.mlperf_print('opt_lamb_beta_1', FLAGS.lamb_beta_1)
  mlp_log.mlperf_print('opt_lamb_beta_2', FLAGS.lamb_beta_2)
  # Report the power that the learning-rate schedule in `run_pretrain` really
  # uses, instead of a hard-coded 1 that goes stale when the flag changes.
  mlp_log.mlperf_print('opt_lamb_learning_rate_decay_poly_power',
                       FLAGS.poly_power)
  mlp_log.mlperf_print('opt_gradient_accumulation_steps', 0)
  mlp_log.mlperf_print('max_predictions_per_seq', FLAGS.max_predictions_per_seq)
  # The flag stores log10(epsilon).
  mlp_log.mlperf_print('opt_epsilon', 10**FLAGS.log_epsilon)
  mlp_log.mlperf_print('opt_learning_rate_training_steps',
                       FLAGS.total_training_steps)
  mlp_log.mlperf_print('submission_benchmark', 'bert')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print('submission_platform', 'tpu-v3-%d' % jax.device_count())
  mlp_log.mlperf_print('submission_status', 'research')

  jax_model, model_kwargs = get_pretrain_model()
  optimizer = create_optimizer(jax_model, model_kwargs, learning_rate=None)
  # Only the run statistics are returned; the trained optimizer is dropped.
  _, result_stats = run_pretrain(optimizer)
  return result_stats
def main(argv):
  """Train and evaluate SSD with the low-level runner, emitting MLPerf logs.

  Builds the run config and input readers, initializes the
  `TrainAndEvalRunner`, and runs the train/eval loop. When
  `FLAGS.run_cocoeval` is set, eval detections are pushed to worker processes
  through `q_in`, and a background thread reads scored results from `q_out`
  to log accuracy and the final run status.

  Args:
    argv: unused command-line arguments.
  """
  del argv  # Unused.

  def _steps_per_epoch():
    # Number of training steps that make up one epoch.
    return FLAGS.num_examples_per_epoch // FLAGS.train_batch_size

  params = construct_run_config(FLAGS.iterations_per_loop)
  mlp_log.mlperf_print(key='cache_clear', value=True)
  mlp_log.mlperf_print(key='init_start', value=None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('opt_base_learning_rate', params['base_learning_rate'])
  lr_drop_epochs = [
      params['first_lr_drop_epoch'], params['second_lr_drop_epoch']
  ]
  mlp_log.mlperf_print('opt_learning_rate_decay_boundary_epochs',
                       lr_drop_epochs)
  mlp_log.mlperf_print('opt_weight_decay', params['weight_decay'])
  bn_span = (
      FLAGS.train_batch_size // FLAGS.num_shards *
      params['distributed_group_size'])
  mlp_log.mlperf_print('model_bn_span', bn_span)
  mlp_log.mlperf_print('max_samples', ssd_constants.NUM_CROP_PASSES)
  mlp_log.mlperf_print('train_samples', FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print('eval_samples', FLAGS.eval_samples)

  params['batch_size'] = FLAGS.train_batch_size // FLAGS.num_shards
  input_partition_dims = FLAGS.input_partition_dims
  train_steps = (
      FLAGS.num_epochs * FLAGS.num_examples_per_epoch //
      FLAGS.train_batch_size)
  eval_steps = int(math.ceil(FLAGS.eval_samples / FLAGS.eval_batch_size))
  runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, train_steps, eval_steps, FLAGS.num_shards)

  train_input_fn = dataloader.SSDInputReader(
      FLAGS.training_file_pattern,
      params['transpose_input'],
      is_training=True,
      use_fake_data=FLAGS.use_fake_data,
      params=params)
  eval_input_fn = dataloader.SSDInputReader(
      FLAGS.validation_file_pattern,
      is_training=False,
      use_fake_data=FLAGS.use_fake_data,
      distributed_eval=True,
      count=eval_steps * FLAGS.eval_batch_size,
      params=params)

  def init_fn():
    """Initializes the ResNet backbone variables from the checkpoint."""
    assignment_map = {'resnet/': 'resnet%s/' % ssd_constants.RESNET_DEPTH}
    tf.train.init_from_checkpoint(params['resnet_checkpoint'], assignment_map)

  model_fn = functools.partial(ssd_model.ssd_model_fn, params)
  runner.initialize(train_input_fn, eval_input_fn, model_fn,
                    FLAGS.train_batch_size, FLAGS.eval_batch_size,
                    input_partition_dims, init_fn)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  if FLAGS.run_cocoeval:
    # copybara:strip_begin
    q_in, q_out = REDACTEDprocess.get_user_data()
    processes = [
        REDACTEDprocess.Process(target=REDACTED_predict_post_processing)
        for _ in range(4)
    ]
    # copybara:strip_end_and_replace_begin
    # q_in = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
    # q_out = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
    # processes = [
    #     multiprocessing.Process(
    #         target=predict_post_processing, args=(q_in, q_out))
    #     for _ in range(self.num_multiprocessing_workers)
    # ]
    # copybara:replace_end
    for worker in processes:
      worker.start()

    def log_eval_results_fn():
      """Print out MLPerf log."""
      target_reached = False
      while True:
        result = q_out.get()
        if result[0] == _STOP:
          break
        if target_reached:
          # Keep draining the queue, but log nothing once the run succeeded.
          continue
        epoch = (result[0] + FLAGS.iterations_per_loop) // _steps_per_epoch()
        coco_ap = result[1]['COCO/AP']
        mlp_log.mlperf_print(
            'eval_accuracy', coco_ap, metadata={'epoch_num': epoch})
        mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': epoch})
        if coco_ap > ssd_constants.EVAL_TARGET:
          target_reached = True
          mlp_log.mlperf_print(
              'run_stop', None, metadata={'status': 'success'})
      if not target_reached:
        mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})

    log_eval_result_thread = threading.Thread(target=log_eval_results_fn)
    log_eval_result_thread.start()

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    per_epoch = _steps_per_epoch()
    first_epoch = cur_step // per_epoch
    epochs_per_loop = FLAGS.iterations_per_loop // per_epoch
    block_metadata = {
        'first_epoch_num': first_epoch,
        'epoch_count': epochs_per_loop
    }
    mlp_log.mlperf_print('block_start', None, metadata=block_metadata)
    mlp_log.mlperf_print(
        'eval_start',
        None,
        metadata={'epoch_num': first_epoch + epochs_per_loop})

  def eval_finish_fn(cur_step, eval_output, _):
    """Executed after every eval; queues detections for COCO scoring."""
    per_epoch = _steps_per_epoch()
    block_metadata = {
        'first_epoch_num': cur_step // per_epoch,
        'epoch_count': FLAGS.iterations_per_loop // per_epoch
    }
    mlp_log.mlperf_print('block_stop', None, metadata=block_metadata)
    if FLAGS.run_cocoeval:
      q_in.put((cur_step, eval_output['detections']))

  runner.train_and_eval(eval_init_fn, eval_finish_fn)

  if FLAGS.run_cocoeval:
    # One stop marker per worker process.
    for _ in processes:
      q_in.put((_STOP, None))

    for worker in processes:
      try:
        worker.join(timeout=10)
      except Exception:  # pylint: disable=broad-except
        pass
    # Tell the logging thread to finish, then wait for it.
    q_out.put((_STOP, None))
    log_eval_result_thread.join()

    # Clear out all the queues to avoid deadlock.
    for leftover in (q_out, q_in):
      while not leftover.empty():
        leftover.get()
def run_finish_fn(success):
  """Log the closing MLPerf events for a run.

  Args:
    success: truthy when the run already reported success; when falsy a
      'run_stop' event with status 'abort' is logged first.
  """
  aborted = not success
  if aborted:
    abort_metadata = {"status": "abort"}
    mlp_log.mlperf_print("run_stop", None, metadata=abort_metadata)
  # 'run_final' is logged in both cases.
  mlp_log.mlperf_print("run_final", None)
def main(unused_argv):
  """Train and evaluate ResNet with the low-level runner, with MLPerf logs.

  Defines the eval and run callbacks, builds the `TrainAndEvalRunner` and the
  ImageNet input functions, logs the MLPerf header, and starts the train/eval
  loop.

  Args:
    unused_argv: unused command-line arguments.
  """

  def _steps_per_epoch():
    # True division on purpose: the epoch numbers derived from it are floats.
    return FLAGS.num_train_images / FLAGS.train_batch_size

  def _block_metadata(cur_step):
    # Metadata shared by the 'block_start' and 'block_stop' events.
    return {
        'first_epoch_num': cur_step // _steps_per_epoch(),
        'epoch_count': 4
    }

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    mlp_log.mlperf_print(
        'block_start', None, metadata=_block_metadata(cur_step))

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    per_epoch = _steps_per_epoch()
    epoch = cur_step // per_epoch
    num_correct = np.sum(eval_output['total_correct'])
    eval_accuracy = float(num_correct) / FLAGS.num_eval_images

    if summary_writer:
      with tf.Graph().as_default():
        accuracy_value = tf.Summary.Value(
            tag='accuracy', simple_value=eval_accuracy)
        summary_writer.add_summary(
            tf.Summary(value=[accuracy_value]), cur_step)

    eval_epoch = epoch + FLAGS.iterations_per_loop // per_epoch
    mlp_log.mlperf_print(
        'eval_accuracy', eval_accuracy, metadata={'epoch_num': eval_epoch})
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': epoch,
            'epoch_count': 4
        })

    if not eval_accuracy >= FLAGS.stop_threshold:
      return False
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    return True

  def run_finish_fn(success):
    """Logs 'run_stop' with status 'abort' if needed, then 'run_final'."""
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
    mlp_log.mlperf_print('run_final', None)

  num_eval_steps = int(
      math.ceil(FLAGS.num_eval_images / FLAGS.eval_batch_size))
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.train_steps, num_eval_steps,
      FLAGS.num_replicas)

  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
  mlp_log.mlperf_print('lars_opt_momentum', FLAGS.momentum)
  mlp_log.mlperf_print('submission_benchmark', 'resnet')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print('submission_platform',
                       'tpu-v3-%d' % FLAGS.num_replicas)
  mlp_log.mlperf_print('submission_status', 'research')

  assert FLAGS.precision in ('bfloat16', 'float32'), (
      'Invalid value for --precision flag; must be bfloat16 or float32.')
  if FLAGS.precision == 'bfloat16':
    input_dtype = tf.bfloat16
  else:
    input_dtype = tf.float32
  # Decoded-image caching is switched on only above 2048 replicas.
  cache_decoded_image = FLAGS.num_replicas > 2048

  def _make_input_fn(is_training):
    # Same arguments for both pipelines; only the training flag differs.
    return imagenet_input.get_input_fn(
        FLAGS.data_dir,
        is_training,
        input_dtype,
        FLAGS.image_size,
        FLAGS.input_partition_dims is None,
        cache_decoded_image=cache_decoded_image)

  imagenet_train = _make_input_fn(True)
  imagenet_eval = _make_input_fn(False)

  low_level_runner.initialize(imagenet_train, imagenet_eval, resnet_model_fn,
                              FLAGS.train_batch_size, FLAGS.eval_batch_size,
                              FLAGS.input_partition_dims)

  mlp_log.mlperf_print('train_samples', FLAGS.num_train_images)
  mlp_log.mlperf_print('eval_samples', FLAGS.num_eval_images)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
def run_pretraining(hparams):
  """Run BERT pretraining with the given hyperparameters.

  Hyperparameters present in `hparams` override the matching flags. The
  function then builds the train/eval input functions, initializes the
  low-level runner and runs the train/eval loop. In the MLPerf log lines,
  "epochs" are steps and "examples" are sequences.

  Args:
    hparams: hyperparameter container supporting `in` and attribute access.

  Returns:
    A tuple `(masked_lm_accuracy, run_steps)` read from the module globals
    that the eval callback updates: the last computed masked-LM accuracy and
    the step at which it was computed.
  """
  global masked_lm_accuracy
  global run_steps
  masked_lm_accuracy = 0
  run_steps = 0

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    block_metadata = {
        "first_epoch_num": cur_step + FLAGS.iterations_per_loop,
        "epoch_count": FLAGS.iterations_per_loop
    }
    mlp_log.mlperf_print("block_start", None, metadata=block_metadata)

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    global run_steps
    global masked_lm_accuracy
    cur_step_corrected = cur_step + FLAGS.iterations_per_loop
    run_steps = cur_step_corrected

    correct_sum = np.sum(eval_output["masked_lm_weighted_correct"])
    count_sum = np.sum(eval_output["masked_lm_weighted_count"])
    masked_lm_accuracy = correct_sum / count_sum
    # the eval_output may mix up the order of the two arrays
    # swap the order if it did got mix up
    if masked_lm_accuracy > 1:
      masked_lm_accuracy = 1 / masked_lm_accuracy

    if summary_writer:
      with tf.Graph().as_default():
        accuracy_value = tf.Summary.Value(
            tag="masked_lm_accuracy", simple_value=masked_lm_accuracy)
        summary_writer.add_summary(
            tf.Summary(value=[accuracy_value]), cur_step_corrected)

    mlp_log.mlperf_print(
        "block_stop",
        None,
        metadata={
            "first_epoch_num": cur_step_corrected,
        })
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print(
        "eval_accuracy",
        float(masked_lm_accuracy),
        metadata={"epoch_num": cur_step_corrected})

    # Success needs the accuracy target and at least six loops of training.
    target_met = (
        masked_lm_accuracy >= FLAGS.stop_threshold and
        cur_step_corrected >= FLAGS.iterations_per_loop * 6)
    if not target_met:
      return False
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
    return True

  def run_finish_fn(success):
    """Logs 'run_stop' with status 'abort' if needed, then 'run_final'."""
    if not success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
    mlp_log.mlperf_print("run_final", None)

  def init_fn():
    """Restores the 'bert/' and 'cls/' scopes when a checkpoint is given."""
    if FLAGS.init_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint, {
          "bert/": "bert/",
          "cls/": "cls/",
      })

  # Passing the hyperparameters
  overridable_flags = (
      "learning_rate",
      "lamb_weight_decay_rate",
      "lamb_beta_1",
      "lamb_beta_2",
      "epsilon",
      "num_warmup_steps",
      "num_train_steps",
  )
  for flag_name in overridable_flags:
    if flag_name in hparams:
      setattr(FLAGS, flag_name, getattr(hparams, flag_name))

  # Input handling
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.repeatable:
    tf.set_random_seed(123)

  if not (FLAGS.do_train or FLAGS.do_eval):
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  train_input_files = [
      path for input_pattern in FLAGS.input_file.split(",")
      for path in tf.gfile.Glob(input_pattern)
  ]

  eval_input_file = "/REDACTED/je-d/home/staging-REDACTED-gpu-dedicated/bert/eval_original_dataset/part-*"
  eval_input_files = [
      path for input_pattern in eval_input_file.split(",")
      for path in tf.gfile.Glob(input_pattern)
  ]

  tf.logging.info("*** Input Files ***")
  tf.logging.info("%s Files.", len(train_input_files))

  dataset_train = dataset_input.input_fn_builder(
      input_files=train_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=True,
      num_cpu_threads=8)

  dataset_eval = dataset_input.input_fn_builder(
      input_files=eval_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=False,
      num_cpu_threads=8,
      num_eval_samples=FLAGS.num_eval_samples)

  # Create the low level runner
  num_replicas = FLAGS.num_tpu_cores // FLAGS.num_partitions
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.stop_steps + 1, FLAGS.max_eval_steps,
      num_replicas)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("global_batch_size", FLAGS.train_batch_size)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("num_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("start_warmup_step", FLAGS.start_warmup_step)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.max_seq_length)
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_lamb_beta_1", FLAGS.lamb_beta_1)
  mlp_log.mlperf_print("opt_lamb_beta_2", FLAGS.lamb_beta_2)
  mlp_log.mlperf_print("opt_epsilon", 10 ** FLAGS.log_epsilon)
  mlp_log.mlperf_print("opt_learning_rate_training_steps",
                       FLAGS.num_train_steps)
  mlp_log.mlperf_print("opt_lamb_weight_decay_rate",
                       FLAGS.lamb_weight_decay_rate)
  mlp_log.mlperf_print("opt_lamb_learning_rate_decay_poly_power", 1)
  mlp_log.mlperf_print("opt_gradient_accumulation_steps", 0)
  mlp_log.mlperf_print("max_predictions_per_seq", FLAGS.max_predictions_per_seq)

  low_level_runner.initialize(
      dataset_train,
      dataset_eval,
      bert_model_fn,
      FLAGS.train_batch_size,
      FLAGS.eval_batch_size,
      input_partition_dims=None,
      init_fn=init_fn,
      train_has_labels=False,
      eval_has_labels=False,
      num_partitions=FLAGS.num_partitions)

  mlp_log.mlperf_print("init_stop", None)

  mlp_log.mlperf_print("run_start", None)
  # To make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print("train_samples",
                       FLAGS.num_train_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print("eval_samples", FLAGS.num_eval_samples)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
  return masked_lm_accuracy, run_steps