def _save(self, step, session): """ Saves checkpoints. Args: step: A python integer, running step. session: A TensorFlow Session. """ """Saves the latest checkpoint.""" self._saver.save(session, self._save_path, global_step=step) tf.logging.info("Saving checkpoints for {} into {}".format( step, self._save_path)) if self._summary_writer is not None: self._summary_writer.add_session_log( SessionLog(status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step)
def _save(self, step, session): """Saves the latest checkpoint.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) for l in self._listeners: l.before_save(session, step) self._get_saver().save(session, self._save_path, global_step=step) self._summary_writer.add_session_log( SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) for l in self._listeners: l.after_save(session, step)
def _save(self, episode, session):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for episode {} into {}.".format(
        episode, self._save_path))
    for l in self._listeners:
        l.before_save(session, episode)
    self._get_saver().save(session, self._save_path, global_step=episode)
    self._summary_writer.add_session_log(
        SessionLog(status=SessionLog.CHECKPOINT,
                   checkpoint_path=self._save_path),
        episode)
    for l in self._listeners:
        l.after_save(session, episode)

def _save(self, step, session): """Saves the latest checkpoint.""" if step == self._last_saved_step: return logging.info("Saving checkpoints for %d into %s.", step, self._save_path) self._last_saved_time = time.time() self._last_saved_step = step if self._saver is None: self._scaffold.saver.save(session, self._save_path, global_step=step) else: self._saver.save(session, self._save_path, global_step=step) self._summary_writer.add_session_log( SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step)
def _save_fn():
    """Run the saver process."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    start_time = time.time()
    for l in self._listeners:
        l.before_save(session, step)
    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(status=SessionLog.CHECKPOINT,
                   checkpoint_path=self._save_path),
        step)
    end_time = time.time()
    logging.info("Checkpoint actual writing time: (%.3f sec)",
                 end_time - start_time)
    logging.info("Checkpoint finished for %d into %s.", step, self._save_path)

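# Illustrative sketch, not taken from the snippets above: _save() methods like
# these are typically driven by tf.train.CheckpointSaverHook (TF 1.x API)
# attached to a MonitoredTrainingSession. The checkpoint directory is made up.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

ckpt_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir="/tmp/ckpt_example",  # hypothetical directory
    save_steps=100)

with tf.train.MonitoredTrainingSession(hooks=[ckpt_hook]) as sess:
    for _ in range(300):
        sess.run(train_op)
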
def after_run(self, run_context, run_values):
    if not self._summary_writer:
        return

    stale_global_step = run_values.results['global_step']
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
        global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.START), global_step)

    if 'summary' in run_values.results:
        self._timer.update_last_triggered_step(global_step)
        for summary in run_values.results['summary']:
            self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
        return

    global_step = run_values.results["global_step"]

    if self._next_step is None:
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
        self._timer.update_last_triggered_step(global_step)
        if "summary" in run_values.results:
            self._summary_writer.add_summary(run_values.results["summary"],
                                             global_step)

    self._next_step = global_step + 1

def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
        return

    global_episode = run_values.results["global_episode"]

    if self._next_episode is None:
        self._next_episode = global_episode + 1
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.START), global_episode)

    if self._request_summary and self._timer.should_trigger_for_episode(global_episode):
        self._timer.update_last_triggered_episode(global_episode)
        self._next_episode = global_episode + 1
        if "summary" in run_values.results:
            for summary in run_values.results["summary"]:
                self._summary_writer.add_summary(summary, global_episode)

    self._current_episode = global_episode

def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
        return

    global_step = run_values.results["global_step"]

    if self._last_saved_step is None:
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
        self._last_saved_step = global_step
        if "summary" in run_values.results:
            self._summary_writer.add_summary(run_values.results["summary"],
                                             global_step)

    self._request_summary = (
        global_step >= self._last_saved_step + self._save_steps - 1)

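# Illustrative sketch, not taken from the snippets above: after_run() hooks of
# this shape belong to tf.train.SummarySaverHook-style hooks (TF 1.x API); this
# is one way such a hook gets attached. The output directory is made up.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)
loss = tf.random_uniform([])
tf.summary.scalar("loss", loss)

summary_hook = tf.train.SummarySaverHook(
    save_steps=10,
    output_dir="/tmp/summary_hook_example",  # hypothetical directory
    summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
    for _ in range(50):
        sess.run(train_op)
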
def testSessionLogSummaries(self):
    data = [
        {'session_log': SessionLog(status=SessionLog.START), 'step': 0},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 1},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 2},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 3},
        {'session_log': SessionLog(status=SessionLog.STOP), 'step': 4},
        {'session_log': SessionLog(status=SessionLog.START), 'step': 5},
        {'session_log': SessionLog(status=SessionLog.STOP), 'step': 6},
    ]
    self._WriteScalarSummaries(data)
    units = efi.get_inspection_units(self.logdir)
    self.assertEqual(1, len(units))
    printable = efi.get_dict_to_print(units[0].field_to_obs)
    self.assertEqual(printable['sessionlog:start']['steps'], [0, 5])
    self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6])
    self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3)

def before_run(self, run_context):
    # For the first run, record a SessionLog.START at the pre-run global step.
    if self._current_step is None:
        self._current_step = run_context.session.run(self._global_step_tensor)
        with ops.default_session(run_context.session):
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.START), self._current_step)

    requests = {"global_step": self._global_step_tensor}
    self._request_summary = self._timer.should_trigger_for_step(
        self._current_step)
    if self._request_summary:
        self._timer.update_last_triggered_step(self._current_step)
        if self._get_summary_op() is not None:
            requests["summary"] = self._get_summary_op()

    feeds = {}
    if self._placeholder is not None and self._request_summary:
        feeds[self._placeholder] = self._request_summary
    args = SessionRunArgs(fetches=requests, feed_dict=feeds)
    return args

def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
        l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(status=SessionLog.CHECKPOINT,
                   checkpoint_path=self._save_path),
        step)

    should_stop = False
    for l in self._listeners:
        if l.after_save(session, step):
            logging.info(
                "A CheckpointSaverListener requested that training be stopped. "
                "listener: {}".format(l))
            should_stop = True
    return should_stop

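# Sketch of a CheckpointSaverListener like the ones the _save() above iterates
# over, assuming the TF 1.x tf.train.CheckpointSaverListener API. In that
# variant, a truthy return from after_save() asks the hook to stop training.
# The class name and stopping rule here are purely illustrative.
import tensorflow as tf

class StopAfterNSavesListener(tf.train.CheckpointSaverListener):

    def __init__(self, max_saves=3):
        self._max_saves = max_saves
        self._num_saves = 0

    def before_save(self, session, global_step_value):
        tf.logging.info("About to save checkpoint at step %d", global_step_value)

    def after_save(self, session, global_step_value):
        self._num_saves += 1
        # Returning True signals the saving hook that training should stop.
        return self._num_saves >= self._max_saves
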
def stop(self,
         threads=None,
         close_summary_writer=True,
         ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
        threads: Optional list of threads to join with the coordinator. If
            `None`, defaults to the threads running the standard services, the
            threads started for `QueueRunners`, and the threads started by the
            `loop()` method. To wait on additional threads, pass the list in
            this parameter.
        close_summary_writer: Whether to close the `summary_writer`. Defaults
            to `True` if the summary writer was created by the supervisor,
            `False` otherwise.
        ignore_live_threads: If `True` ignores threads that remain running
            after a grace period when joining threads via the coordinator,
            instead of raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
        # coord.join() re-raises the first reported exception; the "finally"
        # block ensures that we clean up whether or not an exception was
        # reported.
        self._coord.join(
            threads,
            stop_grace_period_secs=self._stop_grace_secs,
            ignore_live_threads=ignore_live_threads)
    finally:
        # Close the writer last, in case one of the running threads was using it.
        if close_summary_writer and self._summary_writer:
            # Stop messages are not logged with event.step,
            # since the session may have already terminated.
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.STOP))
            self._summary_writer.close()
            self._graph_added_to_summary = False

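# Usage sketch (TF 1.x Supervisor API): managed_session() calls stop() on exit,
# which is when the SessionLog.STOP event above is written and the summary
# writer is closed. The logdir is illustrative.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    global_step = tf.train.get_or_create_global_step()
    increment_step = tf.assign_add(global_step, 1)

    sv = tf.train.Supervisor(logdir="/tmp/supervisor_example")  # hypothetical logdir
    with sv.managed_session() as sess:
        for _ in range(10):
            if sv.should_stop():
                break
            sess.run(increment_step)
    # Leaving managed_session() invokes sv.stop(), i.e. the method shown above.
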
def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))

    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    gen.AddScalar('s1', wall_time=1, step=400, value=20)

    gen.AddScalar('s2', wall_time=1, step=202, value=20)
    gen.AddScalar('s2', wall_time=1, step=203, value=20)

    slog = SessionLog(status=SessionLog.START)
    gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=slog))

    acc.Reload()
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200])
    self.assertEqual([x.step for x in acc.Scalars('s2')], [])

def train(graph,
          output_dir,
          train_op,
          loss_op,
          global_step_tensor=None,
          init_op=None,
          log_every_steps=10,
          supervisor_is_chief=True,
          supervisor_master='',
          supervisor_save_model_secs=600,
          supervisor_save_summaries_secs=10,
          max_steps=None,
          fail_on_nan_loss=True):
    """Train a model.

    Given `graph`, a directory to write outputs to (`output_dir`), and some
    ops, run a training loop. The given `train_op` performs one step of
    training on the model. The `loss_op` represents the objective function of
    the training. It is expected to increment the `global_step_tensor`, a
    scalar integer tensor counting training steps.

    This function uses `Supervisor` to initialize the graph (from a checkpoint
    if one is available in `output_dir`), write summaries defined in the graph,
    and write regular checkpoints as defined by `supervisor_save_model_secs`.

    Training continues until `global_step_tensor` evaluates to `max_steps`, or,
    if `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
    program is terminated with exit code 1.

    Args:
        graph: A graph to train. It is expected that this graph is not in use
            elsewhere.
        output_dir: A directory to write outputs to.
        train_op: An op that performs one training step when run.
        loss_op: A scalar loss tensor.
        global_step_tensor: A tensor representing the global step. If none is
            given, one is extracted from the graph using the same logic as in
            `Supervisor`.
        init_op: An op that initializes the graph. If `None`, use
            `Supervisor`'s default.
        log_every_steps: Output logs regularly. The logs contain timing data
            and the current loss.
        supervisor_is_chief: Whether the current process is the chief
            supervisor in charge of restoring the model and running standard
            services.
        supervisor_master: The master string to use when preparing the session.
        supervisor_save_model_secs: Save a checkpoint every
            `supervisor_save_model_secs` seconds when training.
        supervisor_save_summaries_secs: Save summaries every
            `supervisor_save_summaries_secs` seconds when training.
        max_steps: Train until `global_step_tensor` evaluates to this value.
        fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if
            `loss_op` evaluates to `NaN`. If false, continue training as if
            nothing happened.

    Returns:
        The final loss value.

    Raises:
        ValueError: If `global_step_tensor` is not provided. See
            `tf.contrib.framework.get_global_step` for how we look it up if
            not provided explicitly.
        NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss
            ever evaluates to `NaN`.
    """
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
        raise ValueError('No "global_step" was provided or found in the graph.')

    supervisor, session = _prepare_session(
        graph=graph,
        output_dir=output_dir,
        start_services=True,
        global_step_tensor=global_step_tensor,
        init_op=init_op,
        supervisor_is_chief=supervisor_is_chief,
        supervisor_master=supervisor_master,
        supervisor_save_model_secs=supervisor_save_model_secs,
        supervisor_save_summaries_secs=supervisor_save_summaries_secs)

    with session:
        get_current_step = lambda: session.run(global_step_tensor)

        start_step = get_current_step()
        last_step = start_step
        last_log_step = start_step
        loss_value = None
        logging.info('Training steps [%d,%s)', last_step,
                     'inf' if max_steps is None else str(max_steps))
        try:
            try:
                while not supervisor.ShouldStop() and (
                        (max_steps is None) or (last_step < max_steps)):
                    start_time = time.time()
                    _, loss_value = session.run([train_op, loss_op])
                    if np.isnan(loss_value):
                        failure_message = 'Model diverged with loss = NaN.'
                        if fail_on_nan_loss:
                            logging.error(failure_message)
                            raise NanLossDuringTrainingError()
                        else:
                            logging.warning(failure_message)

                    this_step = get_current_step()
                    if this_step <= last_step:
                        logging.error(
                            'Global step was not incremented by train op at step %s'
                            ': new step %d' % (last_step, this_step))
                    last_step = this_step
                    is_last_step = (max_steps is not None) and (last_step >= max_steps)
                    if is_last_step or (last_step - last_log_step >= log_every_steps):
                        logging.info(
                            'training step %d, loss = %.5f (%.3f sec/batch).',
                            last_step, loss_value, float(time.time() - start_time))
                        last_log_step = last_step
            finally:
                # Call supervisor.Stop() from within a try block because it
                # re-raises exceptions thrown by the supervised threads.
                supervisor.Stop(close_summary_writer=False)

                # Save one last checkpoint and summaries
                # TODO(wicke): This should be handled by Supervisor

                # In case we encountered an exception in the try block before we
                # updated last_step, update it here (again).
                last_step = get_current_step()
                if supervisor_is_chief:
                    ckpt_path = supervisor.save_path
                    logging.info('Saving checkpoint for step %d to checkpoint: %s.' %
                                 (last_step, ckpt_path))
                    supervisor.saver.save(session, ckpt_path, global_step=last_step)

                    if supervisor.summary_op is not None:
                        summary_strs = session.run(supervisor.summary_op)
                        supervisor.summary_writer.add_summary(summary_strs, last_step)
                        supervisor.summary_writer.add_session_log(
                            SessionLog(status=SessionLog.STOP), last_step)
                        supervisor.summary_writer.close()
        # catch OutOfRangeError which is thrown when queue is out of data (and
        # for other reasons as well).
        except errors.OutOfRangeError as e:
            logging.warn('Got exception during tf.learn training loop possibly '
                         'due to exhausted input queue %s.', e)

        return loss_value

def start_training(config):
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd
        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized.
        import mpi4py.rc
        mpi4py.rc.initialize = False
        # import mpi4py
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check size and rank are syncronized
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    network_kwargs = dict((key.lower(), val) for key, val in config.NETWORK.items())

    if "train_validation_saving_size".upper() in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    if use_train_validation_saving:
        top_train_validation_saving_set_accuracy = 0

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:", train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            loss = model.loss(output, labels_placeholder, is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)
        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            all_vars = tf.global_variables()
            pretrain_var_list = [
                var for var in all_vars
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list, name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                allow_growth=True,
                visible_device_list=str(hvd.local_rank())))
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.ConfigProto()  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    if rank == 0:
        train_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR + "/train", sess.graph)
        if use_train_validation_saving:
            train_val_saving_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR + "/train_validation_saving")
        val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR + "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
            sess.run(tf.assign(global_step, 0))

        last_step = 0

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(global_step)
            # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            print("recovered. last step", last_step)

    if config.IS_DISTRIBUTION:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)

        # calculate step per epoch for each nodes
        train_num_per_epoch = train_dataset.num_per_epoch
        num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
        step_per_epoch = num_per_nodes // config.BATCH_SIZE
        begin_index = (train_num_per_epoch * rank) // num_worker
        end_index = begin_index + num_per_nodes

    last_step = sess.run(global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS
    print("max_steps: {}".format(max_steps))

    for step in range(last_step, max_steps):
        print("step", step)

        if config.IS_DISTRIBUTION:
            # scatter dataset
            if step % step_per_epoch == 0:
                indices = train_dataset.get_shuffle_index() if rank == 0 else None
                # broadcast shuffled indices
                indices = comm.bcast(indices, 0)
                feed_indices = indices[begin_index:end_index]
                # update each dataset by splited indices
                train_dataset.update_dataset(feed_indices)

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            # Runtime statistics for develop.
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()

            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op],
                feed_dict=feed_dict,
                # options=run_options,
                # run_metadata=run_metadata,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders, metrics_values)
            }

            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            train_writer.add_summary(metrics_summary, step + 1)
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (step + 1) == max_steps or (step + 1) % config.SAVE_STEPS == 0

        if to_be_saved and rank == 0:
            if use_train_validation_saving:
                sess.run(reset_metrics_op)
                train_validation_saving_step_size = int(
                    math.ceil(train_validation_saving_dataset.num_per_epoch / config.BATCH_SIZE))
                print("train_validation_saving_step_size", train_validation_saving_step_size)

                current_train_validation_saving_set_accuracy = 0

                for train_validation_saving_step in range(train_validation_saving_step_size):
                    print("train_validation_saving_step", train_validation_saving_step)

                    images, labels = train_validation_saving_dataset.feed()
                    feed_dict = {
                        is_training_placeholder: False,
                        images_placeholder: images,
                        labels_placeholder: labels,
                    }

                    if train_validation_saving_step % config.SUMMARISE_STEPS == 0:
                        summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
                        train_val_saving_writer.add_summary(summary, step + 1)
                    else:
                        sess.run([metrics_update_op], feed_dict=feed_dict)

                metrics_values = sess.run(list(metrics_ops_dict.values()))
                metrics_feed_dict = {
                    placeholder: value
                    for placeholder, value in zip(metrics_placeholders, metrics_values)
                }
                metrics_summary, = sess.run(
                    [metrics_summary_op],
                    feed_dict=metrics_feed_dict,
                )
                train_val_saving_writer.add_summary(metrics_summary, step + 1)

                current_train_validation_saving_set_accuracy = sess.run(metrics_ops_dict["accuracy"])

                if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                    top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                    print("New top train_validation_saving accuracy is: ",
                          top_train_validation_saving_set_accuracy)

                    _save_checkpoint(saver, sess, global_step, step)
            else:
                _save_checkpoint(saver, sess, global_step, step)

            if step == 0:
                # check create pb on only first step.
                minimal_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    sess.graph.as_graph_def(add_shapes=True),
                    ["output"],
                )
                pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(step + 1)
                tf.train.write_graph(minimal_graph, environment.CHECKPOINTS_DIR, pb_name, as_text=False)
                tf.train.write_graph(minimal_graph, environment.CHECKPOINTS_DIR, pbtxt_name, as_text=True)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders, metrics_values)
            }
            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)

    # training loop end.
    print("reach max step")

def start_training(config):
    use_horovod = horovod_util.is_enabled()
    print("use_horovod:", use_horovod)
    if use_horovod:
        hvd = horovod_util.setup()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = -1

    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    train_dataset = setup_dataset(config, "train", rank, local_rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank, local_rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if config.TASK == Tasks.OBJECT_DETECTION:
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        is_training_placeholder = tf.compat.v1.placeholder(tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholders()
        output = model.inference(images_placeholder, is_training_placeholder)
        loss = model.loss(output, labels_placeholder)
        opt = model.optimizer()
        if use_horovod:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt)
        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)
        summary_op = tf.compat.v1.summary.merge_all()
        metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

        init_op = tf.compat.v1.global_variables_initializer()
        reset_metrics_op = tf.compat.v1.local_variables_initializer()
        if use_horovod:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        saver = tf.compat.v1.train.Saver(max_to_keep=config.KEEP_CHECKPOINT_MAX)

        if config.IS_PRETRAIN:
            all_vars = tf.compat.v1.global_variables()
            pretrain_var_list = [
                var for var in all_vars
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.compat.v1.train.Saver(pretrain_var_list, name="pretrain_saver")

    if use_horovod:
        # For distributed training
        session_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(
                allow_growth=True,
                visible_device_list=str(hvd.local_rank())
            )
        )
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.compat.v1.ConfigProto()  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.compat.v1.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])
    executor.save_pb_file(sess, environment.CHECKPOINTS_DIR)

    if rank == 0:
        train_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/train", sess.graph)
        val_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(model.global_step)
            # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            print("recovered. last step", last_step)

    if use_horovod:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)

    last_step = sess.run(model.global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS

    progbar = Progbar(max_steps)
    if rank == 0:
        progbar.update(last_step)

    for step in range(last_step, max_steps):
        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            # Runtime statistics for develop.
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()

            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op],
                feed_dict=feed_dict,
                # options=run_options,
                # run_metadata=run_metadata,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_summary = sess.run(metrics_summary_op)
            train_writer.add_summary(metrics_summary, step + 1)
            train_writer.flush()
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (step + 1) == max_steps or (step + 1) % config.SAVE_CHECKPOINT_STEPS == 0

        if to_be_saved and rank == 0:
            _save_checkpoint(saver, sess, model.global_step)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))

            for test_step in range(test_step_size):
                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                        val_writer.flush()
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_summary = sess.run(metrics_summary_op)
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)
                val_writer.flush()

        if rank == 0:
            progbar.update(step + 1)

    # training loop end.
    train_dataset.close()
    validation_dataset.close()
    print("Done")

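# Illustrative distillation of the recovery pattern used by both training
# scripts above: after restoring a checkpoint, write a SessionLog.START event
# one step past the restored step so TensorBoard discards stale events written
# by the interrupted run. Assumes TF 1.x; directory names are made up.
import tensorflow as tf
from tensorflow.core.util.event_pb2 import SessionLog

checkpoint_dir = "/tmp/ckpt_example"  # hypothetical checkpoint directory
writer = tf.summary.FileWriter("/tmp/tensorboard_example/train")

global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver()

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        last_step = sess.run(global_step)
        writer.add_session_log(SessionLog(status=SessionLog.START),
                               global_step=last_step + 1)
writer.close()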