def __init__(self):
  fk = ssd_constants.IMAGE_SIZE / np.array(ssd_constants.STEPS)

  self.default_boxes = []
  # size of feature and number of feature
  for idx, feature_size in enumerate(ssd_constants.FEATURE_SIZES):
    sk1 = ssd_constants.SCALES[idx] / ssd_constants.IMAGE_SIZE
    sk2 = ssd_constants.SCALES[idx + 1] / ssd_constants.IMAGE_SIZE
    sk3 = math.sqrt(sk1 * sk2)
    all_sizes = [(sk1, sk1), (sk3, sk3)]

    for alpha in ssd_constants.ASPECT_RATIOS[idx]:
      w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha)
      all_sizes.append((w, h))
      all_sizes.append((h, w))

    assert len(all_sizes) == ssd_constants.NUM_DEFAULTS[idx]

    for w, h in all_sizes:
      for i, j in it.product(range(feature_size), repeat=2):
        cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx]
        box = tuple(np.clip(k, 0, 1) for k in (cy, cx, h, w))
        self.default_boxes.append(box)

  assert len(self.default_boxes) == ssd_constants.NUM_SSD_BOXES
  mlp_log.mlperf_print('max_samples', ssd_constants.NUM_SSD_BOXES)

  def to_ltrb(cy, cx, h, w):
    return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

  # For IoU calculation
  self.default_boxes_ltrb = tuple(to_ltrb(*i) for i in self.default_boxes)
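# Illustrative sketch (not part of the original file): pairwise IoU between the
# LTRB default boxes produced above and a set of ground-truth boxes, in plain
# NumPy. The argument names `default_boxes_ltrb` and `gt_boxes` are assumptions;
# boxes are (top, left, bottom, right) in normalized [0, 1] coordinates.
import numpy as np


def iou_matrix(default_boxes_ltrb, gt_boxes):
  """Returns a [num_defaults, num_gt] IoU matrix for LTRB boxes."""
  d = np.asarray(default_boxes_ltrb)[:, None, :]  # [D, 1, 4]
  g = np.asarray(gt_boxes)[None, :, :]            # [1, G, 4]
  top = np.maximum(d[..., 0], g[..., 0])
  left = np.maximum(d[..., 1], g[..., 1])
  bottom = np.minimum(d[..., 2], g[..., 2])
  right = np.minimum(d[..., 3], g[..., 3])
  inter = np.clip(bottom - top, 0, None) * np.clip(right - left, 0, None)
  area_d = (d[..., 2] - d[..., 0]) * (d[..., 3] - d[..., 1])
  area_g = (g[..., 2] - g[..., 0]) * (g[..., 3] - g[..., 1])
  return inter / (area_d + area_g - inter)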
def train_and_eval(self,
                   train_steps,
                   local_batch_size,  # pylint: disable=unused-argument
                   num_threads=2):  # pylint: disable=unused-argument
  """Run the training loop on the TPU device."""
  tf.logging.info("LowLevelRunner: train for %d steps in total.", train_steps)

  if train_steps % self.iterations != 0:
    tf.logging.warning(
        "train_steps %d is not divisible by iterations_per_loop %d",
        train_steps, self.iterations)
    train_steps = self.iterations * int(
        math.ceil(train_steps / self.iterations))

  # Train and eval/predict thread.
  def train_eval_thread_fn(sess, train_eval_op, steps):
    sess.run([train_eval_op],
             feed_dict={self.num_epochs_tensor: steps,
                        self.train_steps_tensor: self.iterations,
                        self.eval_steps_tensor: self.eval_steps})

  self.train_eval_thread = threading.Thread(
      target=train_eval_thread_fn,
      args=(self.sess, self.train_eval_op, train_steps // self.iterations))
  self.train_eval_thread.start()

  # Infeed thread.
  def infeed_thread_fn(sess, eval_sess, enqueue_ops, eval_enqueue_ops,
                       eval_dataset_initializer):
    """Build and infeed session.run calls in a background thread."""
    for i in range(train_steps // self.iterations):
      mlp_log.mlperf_print(
          "block_start", None,
          metadata={"first_epoch_num": i + 1, "epoch_count": 1})
      tf.logging.info(
          "Start to infeed %d batches for training of epoch %d.",
          self.iterations, i)
      sess.run([enqueue_ops])
      eval_sess.run(eval_dataset_initializer)
      eval_sess.run([eval_enqueue_ops])

  self.infeed_thread = threading.Thread(
      target=infeed_thread_fn,
      args=(self.input_sess, self.eval_input_sess, self.enqueue_ops,
            self.eval_enqueue_ops, self.eval_dataset_initializer))

  time.sleep(240)
  mlp_log.mlperf_print(key="init_stop", value=None)
  mlp_log.mlperf_print(key="run_start", value=None)
  self.infeed_thread.start()
def run_main(flags, default_hparams, estimator_fn):
  """Run main."""
  # Job
  jobid = flags.jobid
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)
    tf.set_random_seed(random_seed)

  # Model output directory
  out_dir = flags.out_dir
  if out_dir and not tf.gfile.Exists(out_dir):
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  hparams = create_or_load_hparams(default_hparams, flags.hparams_path)

  # TODO(dehao): move init time closer to model construction if necessary.
  mlp_log.mlperf_print("init_start", None)

  # Train or Evaluation
  return estimator_fn(hparams)
def learning_rate_schedule(params, global_step):
  """Handles learning rate scaling, linear warmup, and learning rate decay.

  Args:
    params: A dictionary that defines hyperparameters of model.
    global_step: A tensor representing current global step.

  Returns:
    A tensor representing current learning rate.
  """
  base_learning_rate = params['base_learning_rate']
  lr_warmup_step = params['lr_warmup_step']
  first_lr_drop_step = params['first_lr_drop_step']
  second_lr_drop_step = params['second_lr_drop_step']
  batch_size = (params['batch_size'] * params['num_shards']
                if params['use_tpu'] else params['batch_size'])
  scaling_factor = batch_size / ssd_constants.DEFAULT_BATCH_SIZE
  mlp_log.mlperf_print('opt_learning_rate_warmup_factor', scaling_factor)
  mlp_log.mlperf_print('opt_learning_rate_warmup_steps', lr_warmup_step)
  adjusted_learning_rate = base_learning_rate * scaling_factor
  learning_rate = (tf.cast(global_step, dtype=tf.float32) /
                   lr_warmup_step) * adjusted_learning_rate
  lr_schedule = [[1.0, lr_warmup_step], [0.1, first_lr_drop_step],
                 [0.01, second_lr_drop_step]]
  for mult, start_global_step in lr_schedule:
    learning_rate = tf.where(global_step < start_global_step, learning_rate,
                             adjusted_learning_rate * mult)
  return learning_rate
def learning_rate_schedule(current_epoch):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.  After 5
  epochs we reach the base learning rate (scaled to account for batch size).
  After 30, 60 and 80 epochs the learning rate is divided by 10.  After 90
  epochs training stops and the LR is set to 0.  This ensures that we train
  for exactly 90 epochs for reproducibility.

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  mlp_log.mlperf_print('base_learning_rate', FLAGS.base_learning_rate)
  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

  decay_rate = (scaled_lr * LR_SCHEDULE[0][0] * current_epoch /
                LR_SCHEDULE[0][1])
  for mult, start_epoch in LR_SCHEDULE:
    decay_rate = tf.where(current_epoch < start_epoch, decay_rate,
                          scaled_lr * mult)
  return decay_rate
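# Illustrative sketch (not part of the original file): the same warmup-then-
# step-decay shape in plain Python, assuming the canonical ResNet-50 schedule
# [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)] for LR_SCHEDULE and a base
# learning rate of 0.1; the real values come from FLAGS and LR_SCHEDULE.
def reference_lr(current_epoch, base_lr=0.1, batch_size=1024,
                 lr_schedule=((1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80))):
  """Returns the learning rate at `current_epoch` (a float, e.g. 3.5)."""
  scaled_lr = base_lr * (batch_size / 256.0)
  # Linear warmup toward `scaled_lr` over the first lr_schedule[0][1] epochs.
  lr = scaled_lr * lr_schedule[0][0] * current_epoch / lr_schedule[0][1]
  for mult, start_epoch in lr_schedule:
    if current_epoch >= start_epoch:
      lr = scaled_lr * mult
  return lr

# Example: reference_lr(2.5) == 0.2 (halfway through warmup at batch size
# 1024) and reference_lr(35.0) == 0.04 (after the first 10x drop).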
def train(self):
  """Run the train loops and write a summary to directory."""
  mlp_log.mlperf_print(key='init_stop', value=None)
  mlp_log.mlperf_print(key='run_start', value=None)
  if self.use_tpu_estimator:
    self.runner.train(input_fn=self.input_fn,
                      max_steps=self.params['total_steps'])
  else:
    self.runner.train()
def poly_rate_schedule(current_epoch, poly_rate=0.0):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.  After
  FLAGS.poly_warmup_epochs, we reach the base learning rate (scaled to account
  for batch size).  The learning rate is then decayed using a polynomial rate
  decay schedule with power 2.0.

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  batch_size = FLAGS.train_batch_size
  if batch_size < 16384:
    plr = 10.0
    w_epochs = 5
  elif batch_size < 32768:
    plr = 25.0
    w_epochs = 5
  else:
    plr = 33.0
    w_epochs = 25

  # Override default poly learning rate and warmup epochs
  if poly_rate > 0.0:
    plr = poly_rate
  if FLAGS.lars_base_learning_rate > 0.0:
    plr = FLAGS.lars_base_learning_rate
  if FLAGS.lars_warmup_epochs > 0:
    w_epochs = FLAGS.lars_warmup_epochs
  mlp_log.mlperf_print('opt_base_learning_rate', plr)
  mlp_log.mlperf_print('opt_learning_rate_warmup_epochs', w_epochs)
  mlp_log.mlperf_print('lars_opt_end_learning_rate', 0.0001)

  wrate = (plr * current_epoch / w_epochs)
  w_steps = (w_epochs * FLAGS.num_train_images // batch_size)
  min_step = tf.constant(1, dtype=tf.int64)
  global_step = tf.train.get_or_create_global_step()
  decay_steps = tf.maximum(min_step, tf.subtract(global_step, w_steps))
  mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps',
                       FLAGS.train_steps - w_steps + 1)
  mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)
  poly_rate = tf.train.polynomial_decay(plr,
                                        decay_steps,
                                        FLAGS.train_steps - w_steps + 1,
                                        power=2.0)
  decay_rate = tf.where(current_epoch <= w_epochs, wrate, poly_rate)
  return decay_rate
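# Illustrative sketch (not part of the original file): the polynomial decay
# used above written as a plain function of the global step, ignoring the
# small end learning rate (0.0001). `plr`, `w_steps`, and `train_steps` are
# assumptions standing in for the flag-derived values.
def poly_decay_lr(global_step, plr=10.0, w_steps=1000, train_steps=10000,
                  power=2.0):
  """LR after warmup: plr * (1 - decay_progress) ** power."""
  decay_steps = max(1, global_step - w_steps)
  total_decay_steps = train_steps - w_steps + 1
  progress = min(decay_steps, total_decay_steps) / total_decay_steps
  return plr * (1.0 - progress) ** power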
def init_lars_optimizer(current_epoch):
  """Initialize the LARS Optimizer."""
  mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
  mlp_log.mlperf_print('lars_epsilon', 0.0)
  learning_rate = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
  optimizer = tf.contrib.opt.LARSOptimizer(
      learning_rate,
      momentum=FLAGS.momentum,
      weight_decay=FLAGS.weight_decay,
      skip_list=['batch_normalization', 'bias'])
  return optimizer
def learning_rate_schedule(hparams):
  """Learning rate schedule based on hparams."""
  mlp_log.mlperf_print(key="opt_base_learning_rate",
                       value=hparams.learning_rate_constant)
  mlp_log.mlperf_print(key="opt_learning_rate_warmup_steps",
                       value=hparams.learning_rate_warmup_steps)
  step_num = _global_step(hparams)
  schedule_string = hparams.learning_rate_schedule
  names = schedule_string.split("*")
  names = [name.strip() for name in names if name.strip()]
  ret = tf.constant(1.0)
  for name in names:
    ret *= learning_rate_factor(name, step_num, hparams)
  return ret
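# Illustrative sketch (not part of the original file): how a schedule string
# such as "constant*linear_warmup*rsqrt_decay" is interpreted as a product of
# per-step factors. The factor formulas below mirror the common Transformer
# setup; the actual `learning_rate_factor` implementation may differ.
import math


def sketch_lr(step, schedule="constant*linear_warmup*rsqrt_decay",
              constant=2.0, warmup_steps=4000):
  factors = {
      "constant": lambda s: constant,
      "linear_warmup": lambda s: min(1.0, s / warmup_steps),
      "rsqrt_decay": lambda s: 1.0 / math.sqrt(max(s, warmup_steps)),
  }
  ret = 1.0
  for name in schedule.split("*"):
    ret *= factors[name.strip()](step)
  return ret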
def init_lars_optimizer(current_epoch):
  """Initialize the LARS Optimizer."""
  lars_epsilon = FLAGS.lars_epsilon
  mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
  mlp_log.mlperf_print('lars_epsilon', lars_epsilon)
  learning_rate = get_lars_lr(current_epoch)
  optimizer = tf.contrib.opt.LARSOptimizer(
      learning_rate,
      momentum=FLAGS.momentum,
      weight_decay=FLAGS.weight_decay,
      skip_list=['batch_normalization', 'bias'],
      epsilon=lars_epsilon)
  return optimizer, learning_rate
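# Schematic sketch (not part of the original file) of the per-layer LARS
# trust-ratio update, following the LARS paper; the exact tf.contrib.opt
# implementation may differ in details. `eeta` is the LARS coefficient and
# all argument defaults here are illustrative assumptions.
import numpy as np


def lars_update(w, g, lr, v, eeta=0.001, momentum=0.9, weight_decay=1e-4,
                epsilon=0.0):
  """One LARS step for a single weight tensor `w` with gradient `g`."""
  w_norm = np.linalg.norm(w)
  g_norm = np.linalg.norm(g)
  if w_norm > 0 and g_norm > 0:
    trust_ratio = eeta * w_norm / (g_norm + weight_decay * w_norm + epsilon)
  else:
    trust_ratio = 1.0
  scaled_lr = lr * trust_ratio
  v = momentum * v + scaled_lr * (g + weight_decay * w)  # momentum buffer
  return w - v, v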
def infeed_thread_fn(sess, eval_sess, enqueue_ops, eval_enqueue_ops,
                     eval_dataset_initializer):
  """Build and infeed session.run calls in a background thread."""
  for i in range(train_steps // self.iterations):
    mlp_log.mlperf_print(
        "block_start", None,
        metadata={"first_epoch_num": i + 1, "epoch_count": 1})
    tf.logging.info(
        "Start to infeed %d batches for training of epoch %d.",
        self.iterations, i)
    sess.run([enqueue_ops])
    eval_sess.run(eval_dataset_initializer)
    eval_sess.run([eval_enqueue_ops])
def log_eval_result_fn(results):
  """Log eval results."""
  cur_step, eval_results = results
  if cur_step == _STOP:
    return
  epoch = cur_step // self.params["steps_per_epoch"]
  with tf.Graph().as_default():
    summaries = []
    for metric in eval_results:
      summaries.append(
          tf.Summary.Value(tag=metric, simple_value=eval_results[metric]))
    tf_summary = tf.Summary(value=list(summaries))
    summary_writer.add_summary(tf_summary, cur_step)
  mlp_log.mlperf_print("eval_accuracy",
                       eval_results["COCO/AP"],
                       metadata={"epoch_num": epoch + 1})
  mlp_log.mlperf_print("eval_stop", None, metadata={"epoch_num": epoch + 1})
  if epoch in self.epoch_count:
    epoch_count = self.epoch_count[epoch]
  else:
    epoch_count = 1
  mlp_log.mlperf_print("block_stop", None,
                       metadata={
                           "first_epoch_num": epoch - epoch_count + 1,
                           "epoch_count": epoch_count
                       })
  self.log_epochs[epoch] = True
  if eval_results["COCO/AP"] >= ssd_constants.EVAL_TARGET:
    self.run_success = True
    if epoch < self.success_epoch:
      self.success_epoch = epoch
  log_run_final = self.run_success
  for epoch in self.log_epochs:
    if epoch < self.success_epoch and not self.log_epochs[epoch]:
      log_run_final = False
      break
  # Log run_final when all the previous eval results are logged.
  if log_run_final and not self.log_run_success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
    self.log_run_success = True
def get_predict_results(self, cur_epoch):
  """Run the predict loop on the TPU device."""
  for step in range(self.eval_steps):
    tf.logging.info(
        "TrainAndEvalLowLevelRunner: reading eval step %d results", step)
    predictions = {name: [] for name in self.outfeed_names}
    for outfeed_dict in self.train_eval_sess.run(self.dequeue_ops):
      for name, tensors in six.iteritems(outfeed_dict):
        predictions[name].extend(tensors)
    if step == self.eval_steps - 1:
      # All predictions have been read from the device; async eval
      # post-processing starts, and the next train loop on the device also
      # starts.
      mlp_log.mlperf_print("block_stop", None,
                           metadata={
                               "first_epoch_num": cur_epoch,
                               "epoch_count": 1
                           })
      mlp_log.mlperf_print("eval_start", None,
                           metadata={"epoch_num": cur_epoch})
      tf.logging.info(
          "TrainAndEvalLowLevelRunner: start eval epoch %d.", cur_epoch)
      mlp_log.mlperf_print("block_start", None,
                           metadata={
                               "first_epoch_num": cur_epoch + 1,
                               "epoch_count": 1
                           })
    yield predictions
def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
  """Runs evaluator."""
  step = np.asscalar(run_context.session.run(self._global_step_tensor))
  if self._timer.should_trigger_for_step(step):
    logging.info('Starting eval.')
    eval_results = self._evaluate(run_context.session, step)
    mlp_log.mlperf_print(
        'eval_accuracy',
        float(eval_results[_EVAL_METRIC]),
        metadata={'epoch_num': max(step // self._steps_per_epoch - 1, 0)})
    # The ImageNet eval size is hard coded.
    if eval_results[_EVAL_METRIC] >= self._stop_threshold:
      self._run_success = True
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      mlp_log.mlperf_print('run_final', None)
      run_context.request_stop()
  if step // self._steps_per_epoch == self._eval_every_epoch_from:
    self._timer = training.SecondOrStepTimer(
        every_steps=self._steps_per_epoch)
    self._timer.reset()
def post_processing_thread_fn():
  """Run post-processing on CPU for predictions."""
  mlp_log.mlperf_print(
      "block_start", None,
      metadata={"first_epoch_num": 0, "epoch_count": 1})
  for cur_epoch in range(self.total_epoch):
    eval_begin = time.time()
    # Enables multi-processing to accelerate post-processing.
    eval_multiprocess.eval_multiprocessing(
        self.eval_steps, self.get_predict_results(cur_epoch),
        self.eval_metric, self.eval_params["eval_worker_count"])
    pred_end = time.time()
    tf.logging.info("prediction takes %d seconds.", pred_end - eval_begin)
    num_eval_samples, eval_results = self.eval_metric.evaluate()
    eval_end = time.time()
    tf.logging.info("COCO evaluates %d samples", num_eval_samples)
    if num_eval_samples != self.eval_params["eval_samples"]:
      tf.logging.info("COCO fails to evaluate all %d samples, exit!" %
                      self.eval_params["eval_samples"])
      self.run_success = False
      self.continue_train = False
      return
    tf.logging.info("one evaluation takes %d seconds", eval_end - eval_begin)
    self.write_eval_summary(self.eval_summary_writer, eval_results,
                            cur_epoch * self.iterations_per_loop)
    tf.logging.info("AP: %s" % eval_results["AP"])
    tf.logging.info("mask_AP: %s" % eval_results["mask_AP"])
    # Eval epoch is 0-indexed (for MLPerf log parsing).
    mlp_log.mlperf_print(
        "eval_stop", None, metadata={"epoch_num": cur_epoch})
    # TODO(b/127959551): use both metrics once the bug is resolved.
    mlp_log.mlperf_print(
        "eval_accuracy",
        (float(eval_results["AP"]), float(eval_results["mask_AP"])),
        metadata={"epoch_num": cur_epoch})
    if (eval_results["AP"] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results["mask_AP"] >= mask_rcnn_params.MASK_EVAL_TARGET):
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      self.run_success = True
      self.continue_train = False
      return
def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
  """Start the infeed."""
  time.sleep(300)
  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)
  for i in range(self.hparams.max_train_epochs):
    tf.logging.info("Infeed for epoch: %d", i + 1)
    mlp_log.mlperf_print(
        "block_start", None,
        metadata={"first_epoch_num": i + 1, "epoch_count": 1})
    mlp_log.mlperf_print("epoch_start", None, metadata={"epoch_num": i + 1})
    sess.run(eval_init)
    sess.run([train_enqueue_ops])
    sess.run([eval_enqueue_ops])
def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
  tf.logging.info("Using optimizer %s", optimizer_name)
  mlp_log.mlperf_print(key="opt_name", value=optimizer_name)
  mlp_log.mlperf_print(key="opt_adam_beta_1",
                       value=hparams.optimizer_adam_beta1)
  mlp_log.mlperf_print(key="opt_adam_beta_2",
                       value=hparams.optimizer_adam_beta2)
  mlp_log.mlperf_print(key="opt_adam_epsilon",
                       value=hparams.optimizer_adam_epsilon)
  self._bfloat16_grads_all_reduce = hparams.bfloat16_grads_all_reduce

  if optimizer_name == "Adam":
    # We change the default epsilon for Adam.
    # Using LazyAdam as it's much faster for large vocabulary embeddings.
    self._opt = tf.contrib.opt.LazyAdamOptimizer(
        lr,
        beta1=hparams.optimizer_adam_beta1,
        beta2=hparams.optimizer_adam_beta2,
        epsilon=hparams.optimizer_adam_epsilon)
  elif optimizer_name == "Momentum":
    self._opt = tf.train.MomentumOptimizer(
        lr,
        momentum=hparams.optimizer_momentum_momentum,
        use_nesterov=hparams.optimizer_momentum_nesterov)
  elif optimizer_name == "TrueAdam":
    self._opt = tf.train.AdamOptimizer(
        lr,
        beta1=hparams.optimizer_adam_beta1,
        beta2=hparams.optimizer_adam_beta2,
        epsilon=hparams.optimizer_adam_epsilon)
  elif optimizer_name == "Adafactor":
    self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
  # BEGIN GOOGLE-INTERNAL
  elif optimizer_name == "SM3":
    self._opt = sm3.SM3Optimizer(
        lr, momentum=hparams.optimizer_momentum_momentum)
  # END GOOGLE-INTERNAL
  else:
    self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
def get_learning_rate(self, params, global_step):
  """Sets up learning rate schedule."""
  learning_rate = lr_policy.learning_rate_schedule(
      params['learning_rate'], params['lr_warmup_init'],
      params['lr_warmup_step'], params['first_lr_drop_step'],
      params['second_lr_drop_step'], global_step)
  mlp_log.mlperf_print(key='opt_base_learning_rate',
                       value=params['learning_rate'])
  mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                       value=params['lr_warmup_step'])
  mlp_log.mlperf_print(
      key='opt_learning_rate_warmup_factor',
      value=params['lr_warmup_init'] / params['learning_rate'])
  return learning_rate
def end(self, session):  # pylint: disable=unused-argument
  """Runs evaluator for final model."""
  # Only runs eval at the end if highest accuracy so far
  # is less than self._stop_threshold.
  if not self._run_success:
    step = np.asscalar(session.run(self._global_step_tensor))
    logging.info('Starting eval.')
    eval_results = self._evaluate(session, step)
    mlp_log.mlperf_print(
        'eval_accuracy',
        float(eval_results[_EVAL_METRIC]),
        metadata={'epoch_num': max(step // self._steps_per_epoch - 1, 0)})
    if eval_results[_EVAL_METRIC] >= self._stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    else:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
  mlp_log.mlperf_print('run_final', None)
def evaluate(self, ckpt):
  """Performs evaluation against `ckpt` and writes a summary to directory."""
  current_step, num_epochs = self.get_step_and_epoch_number(ckpt)
  mlp_log.mlperf_print('eval_start', None, metadata={'epoch_num': num_epochs})
  eval_begin = time.time()
  if self.use_tpu_estimator:
    input_fn = functools.partial(
        self.input_fn,
        num_examples=self.eval_steps * self.params['eval_batch_size'])
    predictor = self.runner.predict(input_fn=input_fn,
                                    checkpoint_path=ckpt,
                                    yield_single_examples=False)
  else:
    predictor = self.runner.predict(checkpoint_path=ckpt,
                                    eval_steps=self.eval_steps)
  # Enables multi-processing to accelerate post-processing.
  eval_multiprocess.eval_multiprocessing(
      self.eval_steps, predictor, self.eval_metric,
      self.params['eval_worker_count'])
  pred_end = time.time()
  tf.logging.info('prediction takes %d seconds.', pred_end - eval_begin)
  num_eval_samples, eval_results = self.eval_metric.evaluate()
  eval_end = time.time()
  tf.logging.info('COCO evaluates %d samples', num_eval_samples)
  assert num_eval_samples == self.params['eval_samples']
  tf.logging.info('one evaluation takes %d seconds', eval_end - eval_begin)
  self.write_summary(eval_results, current_step)
  tf.logging.info('AP: %s' % eval_results['AP'])
  tf.logging.info('mask_AP: %s' % eval_results['mask_AP'])
  mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': num_epochs})
  # TODO(b/127959551): use both metrics once the bug is resolved.
  mlp_log.mlperf_print(
      'eval_accuracy',
      (float(eval_results['AP']), float(eval_results['mask_AP'])),
      metadata={'epoch_num': num_epochs})
  return eval_results
def compute_bleu_summaries(hook_args):
  """Compute BLEU core summaries using the decoder output.

  Args:
    hook_args: DecodeHookArgs namedtuple.

  Returns:
    A list of tf.Summary values if hook_args.hparams contains the
    reference file and the translated file.
  """
  outputs, references = [], []
  for output, reference in hook_args.predictions:
    outputs.append(output)
    references.append(reference)
  decode_hparams = hook_args.decode_hparams

  values = []
  bleu = 100 * bleu_hook.bleu_wrapper(references, outputs)
  values.append(tf.Summary.Value(tag="BLEU", simple_value=bleu))
  tf.logging.info("BLEU = %6.2f" % bleu)

  if hook_args.hparams.mlperf_mode:
    current_step = decode_hparams.mlperf_decode_step
    mlp_log.mlperf_print(
        "eval_stop", None,
        metadata={
            "epoch_num": max(
                current_step // decode_hparams.iterations_per_loop, 1)
        })
    mlp_log.mlperf_print(
        "eval_accuracy", bleu,
        metadata={
            "epoch_num": max(
                current_step // decode_hparams.iterations_per_loop, 1)
        })
    if bleu >= decode_hparams.mlperf_threshold:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      decode_hparams.set_hparam("mlperf_success", True)

  return values
def train(self, max_steps=None):
  """Train for max_steps."""
  mlp_log.mlperf_print(key="init_stop", value=None)
  mlp_log.mlperf_print(key="run_start", value=None)
  mlp_log.mlperf_print("block_start", None,
                       metadata={"first_epoch_num": 1, "epoch_count": 1})
  if self._hparams.train_with_low_level_api:
    self._trunner.train(self._hparams.train_steps, self._hparams.batch_size)
    self._trunner.shutdown()
  else:
    self._estimator.train(
        self._train_spec.input_fn,
        hooks=self._train_spec.hooks,
        max_steps=max_steps or self._train_spec.max_steps)
def infeed_thread_fn():
  """Build and infeed session.run calls in a background thread."""
  # Starts the clock.
  time.sleep(60)
  mlp_log.mlperf_print(key="init_stop", value=None)
  mlp_log.mlperf_print(key="run_start", value=None)
  mlp_log.mlperf_print("block_start", None,
                       metadata={"first_epoch_num": 0, "epoch_count": 1})
  for cur_epoch in range(self.total_epoch):
    tf.logging.info("Start to infeed train batches for epoch %d", cur_epoch)
    self.input_sess.run([self.enqueue_ops])
    tf.logging.info("Start to infeed eval batches for epoch %d", cur_epoch)
    self.input_sess.run([self.eval_enqueue_ops])
  tf.logging.info("infeed thread exited.")
def continuous_decode_on_eval_data(self):
  """Decode from dataset on new checkpoint."""
  if self._hparams.mlperf_mode:
    ckpt_generator = next_undecoded_checkpoint(self._hparams.model_dir)
  else:
    ckpt_generator = next_checkpoint(self._hparams.model_dir)

  for ckpt in ckpt_generator:
    current_step = int(os.path.basename(ckpt).split("-")[1])
    tf.logging.info("Decoding step %d" % current_step)
    # Skip checkpoint 0.
    if current_step == 0:
      continue
    # Decode the latest checkpoint by default.
    checkpoint_path = None
    if self._hparams.mlperf_mode:
      self._decode_hparams.mlperf_decode_step = current_step
      checkpoint_path = ckpt
      mlp_log.mlperf_print(
          "eval_start", None,
          metadata={
              "epoch_num": max(
                  current_step // self._decode_hparams.iterations_per_loop, 1)
          })
    self.decode(dataset_split=tf.estimator.ModeKeys.EVAL,
                checkpoint_path=checkpoint_path)
    if self._hparams.mlperf_mode and self._decode_hparams.mlperf_success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      break

  if self._hparams.mlperf_mode and not self._decode_hparams.mlperf_success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
def train_and_eval(self):
  """Performs distributed model eval and writes a summary to directory."""
  self.run_success = False
  self.continue_train = True

  # queues for predictions post-processing.
  def post_processing_thread_fn():
    """Run post-processing on CPU for predictions."""
    for cur_epoch in range(self.total_epoch):
      eval_begin = time.time()
      # Enables multi-processing to accelerate post-processing.
      eval_multiprocess.eval_multiprocessing(
          self.eval_steps, self.get_predict_results(cur_epoch),
          self.eval_metric, self.eval_params["eval_worker_count"])
      pred_end = time.time()
      tf.logging.info("prediction takes %d seconds.", pred_end - eval_begin)
      num_eval_samples, eval_results = self.eval_metric.evaluate()
      eval_end = time.time()
      tf.logging.info("COCO evaluates %d samples", num_eval_samples)
      if num_eval_samples != self.eval_params["eval_samples"]:
        tf.logging.info("COCO fails to evaluate all %d samples, exit!" %
                        self.eval_params["eval_samples"])
        self.run_success = False
        self.continue_train = False
        return
      tf.logging.info("one evaluation takes %d seconds",
                      eval_end - eval_begin)
      self.write_eval_summary(self.eval_summary_writer, eval_results,
                              cur_epoch * self.iterations_per_loop)
      tf.logging.info("AP: %s" % eval_results["AP"])
      tf.logging.info("mask_AP: %s" % eval_results["mask_AP"])
      # Eval epoch is 0-indexed (for MLPerf log parsing).
      mlp_log.mlperf_print("eval_stop", None,
                           metadata={"epoch_num": cur_epoch})
      # TODO(b/127959551): use both metrics once the bug is resolved.
      mlp_log.mlperf_print(
          "eval_accuracy",
          (float(eval_results["AP"]), float(eval_results["mask_AP"])),
          metadata={"epoch_num": cur_epoch})
      if (eval_results["AP"] >= mask_rcnn_params.BOX_EVAL_TARGET and
          eval_results["mask_AP"] >= mask_rcnn_params.MASK_EVAL_TARGET):
        mlp_log.mlperf_print("run_stop", None,
                             metadata={"status": "success"})
        self.run_success = True
        self.continue_train = False
        return

  # Run predict post processing thread on the background.
  post_processing_thread = threading.Thread(target=post_processing_thread_fn)
  post_processing_thread.start()

  if self.train_params["all_in_one_session"]:
    tf.logging.info("TrainAndEvalLowLevelRunner: start train_eval sessions")
    self.train_eval_sess.run(self.train_eval_op)
  else:
    if self.train_params["train_and_eval_save_checkpoint"]:
      ckpt_saver = runner_utils.AsyncCheckpointSaver(
          _MAX_NUM_CHECKPOINT_THREADS, self.saver, self.model_dir,
          self.train_eval_sess)
    cur_epoch = 0
    while cur_epoch < self.total_epoch and self.continue_train:
      tf.logging.info("TrainAndEvalLowLevelRunner: start train epoch: %d",
                      cur_epoch)
      start = time.time()
      self.train_eval_sess.run(self.train_eval_op)
      end = time.time()
      self.write_summary(
          summary_writer=self.summary_writer,
          graph=self.train_eval_graph,
          global_step=cur_epoch * self.iterations_per_loop,
          elapsed_time=end - start,
          elapsed_steps=self.iterations_per_loop,
          trained_examples=self.train_params["num_examples_per_epoch"])
      if self.train_params["train_and_eval_save_checkpoint"]:
        ckpt_saver.checkpoint(cur_epoch * self.iterations_per_loop)
      if self.run_success or not self.continue_train:
        break
      cur_epoch += 1

  post_processing_thread.join()
  if not self.run_success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
def train_and_eval(self, output_summaries=False, enable_tracing=True):
  """Run the Train steps on the TPU device."""
  if output_summaries:
    output_dir = os.path.join(FLAGS.model_dir, "eval", self.tpu_name)
    tf.gfile.MakeDirs(output_dir)
    # Summary writer writes out eval metrics.
    summary_writer = tf.summary.FileWriter(output_dir)
    if FLAGS.save_graphs:
      summary_writer.add_graph(self.graph)
      summary_writer.add_graph(self.input_graph)
      summary_writer.add_graph(self.eval_input_graph)
      summary_writer.add_graph(self.eval_output_graph)

  def infeed_thread_fn():
    """Build and infeed session.run calls in a background thread."""
    # Build infeed session and run infeed session.run calls.
    tf.logging.info("Start infeed thread")
    for _ in range(self.train_steps // self.iterations):
      self.input_sess.run([self.enqueue_ops])
      self.eval_input_sess.run([self.eval_enqueue_ops])

  # Note: the background infeed thread is disabled here; `enq` below runs the
  # infeed ops inline when self.infeed_thread is None.
  if False:
    self.infeed_thread = threading.Thread(target=infeed_thread_fn)
    self.infeed_thread.start()

  # Gather trace for the first few steps.
  if enable_tracing:
    self.launch_profiler()

  self.cur_step = 0
  success = False

  def enq(self, run=True):
    if self.infeed_thread is None:
      tf.logging.info("TrainAndEvalRunner: input_sess enqueue")
      self.input_sess.run([self.enqueue_ops])
      self.eval_input_sess.run([self.eval_enqueue_ops])
      tf.logging.info("TrainAndEvalRunner: enqueue (done)")
    if run:
      tf.logging.info("TrainAndEvalRunner: train_eval_op...")
      result = self.sess.run([self.train_eval_op])
      tf.logging.info("TrainAndEvalRunner: train_eval_op... (done)")
      return result

  def checkpoint_thread_fn(tpu_name, saver, sess, step):
    name = ''.join(['_' if not c.isalnum() else c for c in tpu_name])
    if FLAGS.export_dir is None:
      tf.logging.info("Not saving model %d: %s (FLAGS.export_dir is unset)",
                      step, name)
    else:
      name = FLAGS.export_dir + "/model-%s.ckpt-%d" % (name, step)
      tf.logging.info("Saving model %d: %s", step, name)
      saver.save(sess, name)

  @tflex.register_command
  def save():
    checkpoint_thread_fn(self.tpu_name, self.saver, self.sess, self.cur_step)

  # Take care of the first JIT compilation.
  enq(self, run=False)

  while self.cur_step < self.train_steps or True:
    tflex.check_commands()
    if tflex.should_quit():
      import pdb
      pdb.set_trace()
      break
    self.start = time.time()
    tf.logging.info("TrainAndEvalRunner: start next %d steps",
                    self.iterations)
    self.cur_step = self.coordinator.claim(self.iterations)
    self.sess.run(self.global_step_init,
                  {self.global_step_in: self.cur_step})
    epoch = self.cur_step // self.steps_per_epoch - 1
    mlp_log.mlperf_print("block_start", None,
                         metadata={
                             "first_epoch_num": epoch + 1,
                             "epoch_count": 4
                         })
    self.step_loss = enq(self)
    self.eval_results = self.eval(self.eval_steps)
    self.end = time.time()
    self.step_time = self.end - self.start
    self.examples_sec = (
        self.iterations * self.cfg['train_batch_size'] / self.step_time)
    self.eval_results['examples_sec'] = self.examples_sec
    self.eval_results['step_time'] = self.step_time
    if self.step_loss is not None:
      self.eval_results['loss'] = self.step_loss[0]
    if 'global_step' in self.eval_results:
      self.eval_results['global_step_sec'] = self.iterations / self.step_time
    tf.logging.info(
        "TrainAndEvalRunner ({}): step {} step time {} sec {} examples/sec"
        .format(self.tpu_name, self.cur_step, self.step_time,
                self.examples_sec))

    # Run eval.
    # Write out summary to tensorboard.
    if output_summaries:
      with tf.Graph().as_default():
        summaries = []
        for metric in self.eval_results:
          summaries.append(
              tf.Summary.Value(tag=metric,
                               simple_value=self.eval_results[metric]))
        tf_summary = tf.Summary(value=list(summaries))
        summary_writer.add_summary(tf_summary, self.cur_step)

      def flush(i):
        tf.logging.info("Flushing summaries...")
        start = time.time()
        summary_writer.flush()
        end = time.time()
        tf.logging.info("Flushing summaries (done in %.2fs)", (end - start))

      if (self.flush_summaries_thread is not None and
          self.flush_summaries_thread.is_alive()):
        start = time.time()
        self.flush_summaries_thread.join()
        end = time.time()
        tf.logging.info("Flushing summaries [BLOCKED] (done in %.2fs)",
                        (end - start))
      self.flush_summaries_thread = dispatch([0], flush)[0]

    # MLPerf logging for eval results.
    mlp_log.mlperf_print("eval_accuracy",
                         float(self.eval_results["top_1_accuracy"]),
                         metadata={"epoch_num": epoch + 1})
    mlp_log.mlperf_print("block_stop", None,
                         metadata={"first_epoch_num": epoch + 1})
    tf.logging.info("Eval results at step %d: %s", self.cur_step,
                    self.eval_results)
    if self.eval_results["top_1_accuracy"] >= FLAGS.stop_threshold:
      success = True
      if FLAGS.export_dir is not None:
        self.checkpoint_thread = threading.Thread(
            target=checkpoint_thread_fn,
            args=(self.tpu_name, self.saver, self.sess, self.cur_step))
        self.checkpoint_thread.start()
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      import pdb
      pdb.set_trace()
      break

    if enable_tracing and self.cur_step > self.train_steps // 4:
      self.launch_profiler()
      enable_tracing = False

  if not success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
  mlp_log.mlperf_print("run_final", None)

  if output_summaries:
    summary_writer.close()
def main(argv):
  del argv  # Unused.

  # TODO(b/132208296): remove this workaround that uses control flow v2.
  control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

  tpu = FLAGS.tpu or FLAGS.master
  tpu_cluster_resolver = runner_utils.create_tpu_cluster_resolver(
      FLAGS.use_tpu, tpu, FLAGS.tpu_zone, FLAGS.gcp_project)
  if tpu_cluster_resolver:
    tpu_grpc_url = tpu_cluster_resolver.get_master()
    tf.Session.reset(tpu_grpc_url)

  # Check data path
  run_train = FLAGS.mode in ('train', 'train_and_eval')
  if run_train and FLAGS.training_file_pattern is None:
    raise RuntimeError(
        'You must specify --training_file_pattern for training.')
  run_eval = FLAGS.mode in ('eval', 'train_and_eval') or (
      FLAGS.mode == 'train' and FLAGS.eval_after_training)
  if run_eval:
    if FLAGS.validation_file_pattern is None:
      raise RuntimeError('You must specify --validation_file_pattern '
                         'for evaluation.')
    if FLAGS.val_json_file is None:
      raise RuntimeError('You must specify --val_json_file for evaluation.')

  # Parse hparams
  hparams = mask_rcnn_params.default_hparams()
  hparams.parse(FLAGS.hparams)

  # The following is for spatial partitioning. `features` has one tensor while
  # `labels` has 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
  # partition is performed on `features` and all partitionable tensors of
  # `labels`, see the partition logic below.
  # Note: In the below code, TPUEstimator uses both `shard` and `replica`
  # (with the same meaning).
  # Note that spatial partition is part of the model-parallelism optimization.
  # See core_assignment_utils.py for more details about model parallelism.
  if FLAGS.input_partition_dims:
    labels_partition_dims = {
        'gt_boxes': None,
        'gt_classes': None,
        'cropped_gt_masks': None,
    }
    for level in range(hparams.get('min_level'),
                       hparams.get('max_level') + 1):
      labels_partition_dims['box_targets_%d' % level] = None
      labels_partition_dims['score_targets_%d' % level] = None
    num_cores_per_replica = int(np.prod(FLAGS.input_partition_dims))
    image_partition_dims = [
        FLAGS.input_partition_dims[i] for i in [1, 0, 2]
    ] if hparams.get('transpose_input') else FLAGS.input_partition_dims
    features_partition_dims = {
        'images': image_partition_dims,
        'source_ids': None,
        'image_info': None,
    }
    input_partition_dims = [features_partition_dims, labels_partition_dims]
    num_shards = FLAGS.num_cores // num_cores_per_replica
  else:
    num_cores_per_replica = None
    input_partition_dims = None
    num_shards = FLAGS.num_cores

  params = dict(
      hparams.values(),
      num_shards=num_shards,
      num_cores_per_replica=num_cores_per_replica,
      use_tpu=FLAGS.use_tpu,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      model_dir=FLAGS.model_dir)

  tpu_config = tf.contrib.tpu.TPUConfig(
      params['iterations_per_loop'],
      num_shards=num_shards,
      num_cores_per_replica=params['num_cores_per_replica'],
      input_partition_dims=input_partition_dims,
      per_host_input_for_training=(
          tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2),
      tpu_job_name=FLAGS.tpu_job_name,
  )

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      log_step_count_steps=params['iterations_per_loop'],
      tpu_config=tpu_config,
      save_checkpoints_steps=params['iterations_per_loop'],
  )

  train_replicas_per_worker = (
      params['cores_per_worker'] // params['num_cores_per_replica']
  ) if params['num_cores_per_replica'] else params['cores_per_worker']
  train_params = dict(
      params,
      replicas_per_worker=train_replicas_per_worker,
  )
  eval_params = dict(
      params,
      input_rand_hflip=False,
      resnet_checkpoint=None,
      is_training_bn=False,
      transpose_input=False,
  )

  # MLPerf logging.
  mlp_log.mlperf_print(key='init_start', value=None)
  mlp_log.mlperf_print(key='global_batch_size',
                       value=params['train_batch_size'])

  runner = None
  if run_train and run_eval:
    if params['train_use_tpu_estimator'] or params['eval_use_tpu_estimator']:
      raise RuntimeError(
          'train_and_eval runner does not support TPUEstimator.')
    dist_eval_params = dict(
        eval_params,
        replicas_per_worker=train_replicas_per_worker,
    )
    runner = mask_rcnn_runner.TrainEvalRunner(
        model_fn=mask_rcnn_model.MaskRcnnModelFn(),
        input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                        mode=tf.estimator.ModeKeys.TRAIN,
                                        use_fake_data=FLAGS.use_fake_data),
        eval_input_fn=dataloader.InputReader(
            FLAGS.validation_file_pattern,
            mode=tf.estimator.ModeKeys.PREDICT,
            distributed_eval=True),
        eval_metric=coco_metric.EvaluationMetric(FLAGS.val_json_file,
                                                 use_cpp_extension=True),
        train_params=train_params,
        eval_params=dist_eval_params,
        run_config=run_config)
  elif run_train:
    # Check low-level train runner compatibility.
    if not params['train_use_tpu_estimator']:
      if FLAGS.mode == 'train_and_eval':
        raise RuntimeError('Low level train runner does not support mode '
                           'train_and_eval yet.')
    train_params = dict(
        params,
        replicas_per_worker=train_replicas_per_worker,
    )
    runner = mask_rcnn_runner.TrainRunner(
        model_fn=mask_rcnn_model.MaskRcnnModelFn(),
        input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                        mode=tf.estimator.ModeKeys.TRAIN,
                                        use_fake_data=FLAGS.use_fake_data),
        params=train_params,
        run_config=run_config,
        use_tpu_estimator=train_params['train_use_tpu_estimator'])
  else:
    sidecar_eval_params = dict(
        eval_params,
        # sidecar eval only uses one worker and does not use spatial
        # partition.
        replicas_per_worker=FLAGS.num_cores,
    )
    runner = mask_rcnn_runner.EvalRunner(
        mask_rcnn_model.MaskRcnnModelFn(),
        dataloader.InputReader(FLAGS.validation_file_pattern,
                               mode=tf.estimator.ModeKeys.PREDICT),
        coco_metric.EvaluationMetric(FLAGS.val_json_file,
                                     use_cpp_extension=True),
        sidecar_eval_params,
        run_config,
        use_tpu_estimator=sidecar_eval_params['eval_use_tpu_estimator'])

  if FLAGS.mode == 'train':
    runner.train()
  elif FLAGS.mode == 'eval':

    def terminate_eval():
      tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
                      FLAGS.eval_timeout)
      return True

    run_success = False
    # Run evaluation when there's a new checkpoint
    for ckpt in tf.contrib.training.checkpoints_iterator(
        params['model_dir'],
        min_interval_secs=FLAGS.min_eval_interval,
        timeout=FLAGS.eval_timeout,
        timeout_fn=terminate_eval):

      tf.logging.info('Starting to evaluate.')
      try:
        eval_results = runner.evaluate(ckpt)
        current_step, _ = runner.get_step_and_epoch_number(ckpt)
        if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
            eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
          mlp_log.mlperf_print(key='run_stop', value=None,
                               metadata={'status': 'success'})
          run_success = True
          break
        if int(current_step) >= params['total_steps']:
          tf.logging.info('Evaluation finished after training step %d' %
                          current_step)
          break
      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long
        # after the CPU job tells it to start evaluating. In this case, the
        # checkpoint file could have been deleted already.
        tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint'
                        % ckpt)
    if not run_success:
      mlp_log.mlperf_print(key='run_stop', value=None,
                           metadata={'status': 'aborted'})
  elif FLAGS.mode == 'train_and_eval':
    runner.train_and_eval()
  else:
    tf.logging.info('Mode not found.')
      tf.logging.info(
          "Evaluation finished but failed to reach target score.")
      break
    except tf.errors.NotFoundError:
      tf.logging.info(
          "Checkpoint %s no longer exists, skipping checkpoint" % ckpt)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  mlp_log.mlperf_print("global_batch_size", FLAGS.batch_size)
  mlp_log.mlperf_print("opt_learning_rate_alt_decay_func", "True")
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_learning_rate_decay_interval",
                       FLAGS.decay_interval)
  mlp_log.mlperf_print("opt_learning_rate_decay_factor", FLAGS.decay_factor)
  mlp_log.mlperf_print("opt_learning_rate_decay_steps", FLAGS.decay_steps)
  mlp_log.mlperf_print("opt_learning_rate_remain_steps", FLAGS.decay_start)
  mlp_log.mlperf_print("opt_learning_rate_alt_warmup_func",
                       FLAGS.warmup_scheme)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.warmup_steps)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.src_max_len)
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
def train_and_eval(self, train_steps):
  """Run the Train and Eval loop on the TPU device."""
  output_dir = os.path.join(FLAGS.model_dir, "eval")
  tf.gfile.MakeDirs(output_dir)
  # Summary writer writes out eval metrics.
  summary_writer = tf.summary.FileWriter(output_dir)
  self.run_success = False

  def log_eval_result_fn(results):
    """Log eval results."""
    cur_step, eval_results = results
    if cur_step == _STOP:
      return
    epoch = cur_step // self.params["steps_per_epoch"]
    with tf.Graph().as_default():
      summaries = []
      for metric in eval_results:
        summaries.append(
            tf.Summary.Value(tag=metric, simple_value=eval_results[metric]))
      tf_summary = tf.Summary(value=list(summaries))
      summary_writer.add_summary(tf_summary, cur_step)
    mlp_log.mlperf_print("eval_accuracy",
                         eval_results["COCO/AP"],
                         metadata={"epoch_num": epoch + 1})
    mlp_log.mlperf_print("eval_stop", None,
                         metadata={"epoch_num": epoch + 1})
    if epoch in self.epoch_count:
      epoch_count = self.epoch_count[epoch]
    else:
      epoch_count = 1
    mlp_log.mlperf_print("block_stop", None,
                         metadata={
                             "first_epoch_num": epoch - epoch_count + 1,
                             "epoch_count": epoch_count
                         })
    self.log_epochs[epoch] = True
    if eval_results["COCO/AP"] >= ssd_constants.EVAL_TARGET:
      self.run_success = True
      if epoch < self.success_epoch:
        self.success_epoch = epoch
    log_run_final = self.run_success
    for epoch in self.log_epochs:
      if epoch < self.success_epoch and not self.log_epochs[epoch]:
        log_run_final = False
        break
    # Log run_final when all the previous eval results are logged.
    if log_run_final and not self.log_run_success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      self.log_run_success = True

  tf.logging.info(
      "TrainAndEvalLowLevelRunner: train for %d steps in total", train_steps)
  if train_steps % self.iterations != 0:
    tf.logging.warning(
        "train_steps %d is not divisible by iterations_per_loop %d",
        train_steps, self.iterations)
    train_steps = self.iterations * int(
        math.ceil(train_steps / self.iterations))

  # Start train and eval op on the background.
  def train_eval_thread_fn(sess, train_eval_op):
    sess.run([train_eval_op])

  train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                       args=(self.sess, self.train_eval_op))
  train_eval_thread.start()

  q_in = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
  q_out = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
  processes = [
      multiprocessing.Process(target=predict_post_processing,
                              args=(q_in, q_out))
      for _ in range(self.num_multiprocessing_workers)
  ]

  time.sleep(self.sleep_seconds)
  mlp_log.mlperf_print("init_stop", None)
  mlp_log.mlperf_print("run_start", None)
  for p in processes:
    p.start()
  self.infeed_thread.start()

  def log_eval_results_fn():
    result = q_out.get()
    cur_step, _ = result
    while cur_step != _STOP:
      log_eval_result_fn(result)
      result = q_out.get()
      cur_step, _ = result

  log_eval_result_thread = threading.Thread(target=log_eval_results_fn)
  log_eval_result_thread.start()

  cur_step = 0
  current_epoch = 0

  # Train and eval loop.
  while cur_step < train_steps:
    if self.run_success:
      break
    tf.logging.info("TrainAndEvalLowLevelRunner: start train step:%d",
                    cur_step)
    cur_step += self.iterations
    current_epoch = cur_step // self.params["steps_per_epoch"]
    if self.run_success:
      break
    if (self.params["eval_every_checkpoint"] or
        current_epoch in self.eval_epochs):
      if current_epoch in self.epoch_count:
        epoch_count = self.epoch_count[current_epoch]
      else:
        epoch_count = 1
      mlp_log.mlperf_print("block_start", None,
                           metadata={
                               "first_epoch_num":
                                   current_epoch - epoch_count + 1,
                               "epoch_count": epoch_count
                           })
      mlp_log.mlperf_print("eval_start", None,
                           metadata={"epoch_num": current_epoch + 1})
      # Run predict on device.
      start = time.time()
      predictions = list(self.predict())
      end = time.time()
      tf.logging.info("TrainAndEvalRunner: step {} step time {} sec".format(
          cur_step, end - start))
      # Run predict post processing.
      q_in.put((cur_step, predictions))

  train_eval_thread.join()

  # Turn off predict thread.
  for _ in processes:
    q_in.put((_STOP, None))
  for p in processes:
    p.join(timeout=self.sleep_seconds)
  q_out.put((_STOP, None))
  log_eval_result_thread.join()

  # Clear out all the queues to avoid deadlock.
  while not q_out.empty():
    log_eval_result_fn(q_out.get())
  while not q_in.empty():
    q_in.get()

  summary_writer.close()
  if not self.run_success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
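# Illustrative sketch (not the real predict_post_processing): the queue
# protocol the loop above relies on. Each worker process pulls
# (step, predictions) tuples from q_in, runs the COCO-style post-processing,
# and pushes (step, eval_results) to q_out until it sees the stop sentinel;
# the log_eval_results_fn thread then drains q_out. `compute_metrics` and
# `_STOP_SENTINEL` are hypothetical stand-ins for the actual metric code and
# the _STOP marker.
_STOP_SENTINEL = -1


def predict_post_processing_sketch(q_in, q_out, compute_metrics):
  """Worker loop: (step, predictions) in, (step, eval_results) out."""
  while True:
    step, predictions = q_in.get()
    if step == _STOP_SENTINEL:
      break
    q_out.put((step, compute_metrics(predictions)))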
def initialize(self, input_fn, eval_input_fn, model_fn, params):
  """Build graph and do initialization for training."""
  tf.logging.info("TrainAndEvalLowLevelRunner: initialize method")
  mlp_log.mlperf_print("init_start", None)
  self.params = params
  self.build_enqueue_ops(input_fn, params, host_id=0)

  def infeed_thread_fn():
    """Build and infeed session.run calls in a background thread."""
    # Initialize dataset variables
    for i in range(self.max_train_iterations):
      tf.logging.info(
          "TrainAndEvalRunner: start infeed for %d steps", self.iterations)
      self.input_sess.run([self.enqueue_ops])
      if self.params["eval_every_checkpoint"] or i in self.eval_iterations:
        self.input_sess.run(self.eval_dataset_initializer)
        self.input_sess.run([self.eval_enqueue_ops])

  def tpu_train_step(loss):
    """Generate the TPU graph."""
    del loss
    values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                    values)
    features = unflattened_inputs["features"]
    labels = unflattened_inputs["labels"]
    estimator_spec = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                              params)
    loss, train_op = estimator_spec.loss, estimator_spec.train_op
    self.scaffold_fn = estimator_spec.scaffold_fn
    with tf.control_dependencies([train_op]):
      return tf.identity(loss)

  def train_loop():
    return training_loop.repeat(self.iterations, tpu_train_step,
                                [_INITIAL_LOSS])

  # Start the build of the train graph.
  self.train_loop = train_loop
  for i in range(1, self.num_hosts):
    self.build_enqueue_ops(input_fn, params, host_id=i)

  # Init for eval.
  self.initialize_eval(eval_input_fn, model_fn, params)

  with self.graph.as_default():
    if self.scaffold_fn:
      self.scaffold_fn()
    global_initializer = tf.global_variables_initializer()
    local_initializer = tf.local_variables_initializer()
    graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                         FLAGS.model_dir, "graph.pbtxt")

  # Build tpu train model session and initialize graph
  self.sess = tf.Session(self.master,
                         graph=self.graph,
                         config=self.session_config)
  self.input_sess = tf.Session(self.master,
                               graph=self.input_graph,
                               config=self.session_config)
  self.sess.run(global_initializer)
  self.sess.run(local_initializer)
  self.input_sess.run(self.dataset_initializer)
  self.input_sess.run(self.eval_dataset_initializer)

  # Complete infeed graph generation.
  self.infeed_thread = threading.Thread(target=infeed_thread_fn)

  # Compile.
  self.sess.run([self.train_eval_compile_op])