Example #1
    def _save(self, step, session):
        """ Saves checkpoints.

        Args:
            step: A python integer, running step.
            session: A TensorFlow Session.
        """
        """Saves the latest checkpoint."""
        self._saver.save(session, self._save_path, global_step=step)
        tf.logging.info("Saving checkpoints for {} into {}".format(
            step, self._save_path))
        if self._summary_writer is not None:
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.CHECKPOINT,
                           checkpoint_path=self._save_path), step)
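All of the `_save` variants on this page reduce to the same two operations: write the checkpoint files with a `Saver`, then record a `SessionLog.CHECKPOINT` event so TensorBoard and the `EventAccumulator` know where a consistent restart point lies. A minimal, self-contained sketch of that pattern (the helper name `save_and_log_checkpoint` is illustrative, not taken from the examples):

from tensorflow.core.util.event_pb2 import SessionLog


def save_and_log_checkpoint(saver, session, writer, save_path, step):
    # Write the checkpoint files (save_path-<step>.index / .data-*).
    saver.save(session, save_path, global_step=step)
    # Record the checkpoint in the events file so readers of the log can
    # treat this step as a restart point.
    writer.add_session_log(
        SessionLog(status=SessionLog.CHECKPOINT, checkpoint_path=save_path),
        step)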
Example #2
  def _save(self, step, session):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    for l in self._listeners:
      l.after_save(session, step)
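Examples #2, #3, #5 and #12 call `before_save`/`after_save` on a list of listeners around the actual save. A minimal sketch of such a listener, assuming the `tf.compat.v1.train.CheckpointSaverListener` interface (the class name `LoggingListener` is illustrative):

import tensorflow.compat.v1 as tf


class LoggingListener(tf.train.CheckpointSaverListener):
    """Illustrative listener invoked around every checkpoint save."""

    def before_save(self, session, global_step_value):
        print("about to write a checkpoint at step", global_step_value)

    def after_save(self, session, global_step_value):
        print("checkpoint written at step", global_step_value)
        # Returning True here asks the saving hook to stop training,
        # as Example #12 shows.
        return False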
Example #3
    def _save(self, episode, session):
        """Saves the latest checkpoint."""
        logging.info("Saving checkpoints for episode {} into {}.".format(
            episode, self._save_path))

        for l in self._listeners:
            l.before_save(session, episode)

        self._get_saver().save(session, self._save_path, global_step=episode)
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.CHECKPOINT,
                       checkpoint_path=self._save_path), episode)

        for l in self._listeners:
            l.after_save(session, episode)
Example #4
  def _save(self, step, session):
    """Saves the latest checkpoint."""
    if step == self._last_saved_step:
      return
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._last_saved_time = time.time()
    self._last_saved_step = step
    if self._saver is None:
      self._scaffold.saver.save(session, self._save_path, global_step=step)
    else:
      self._saver.save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
Example #5
    def _save_fn():
      """Run the saver process."""
      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

      start_time = time.time()
      for l in self._listeners:
        l.before_save(session, step)

      self._get_saver().save(session, self._save_path, global_step=step)
      self._summary_writer.add_session_log(
          SessionLog(
              status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
          step)
      end_time = time.time()
      logging.info("Checkpoint actual writing time: (%.3f sec)",
                   end_time - start_time)
      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
Example #6
    def after_run(self, run_context, run_values):
        if not self._summary_writer:
            return

        stale_global_step = run_values.results['global_step']
        global_step = stale_global_step + 1
        if self._next_step is None or self._request_summary:
            global_step = run_context.session.run(self._global_step_tensor)

        if self._next_step is None:
            self._summary_writer.add_session_log(SessionLog(status=SessionLog.START), global_step)

        if 'summary' in run_values.results:
            self._timer.update_last_triggered_step(global_step)
            for summary in run_values.results['summary']:
                self._summary_writer.add_summary(summary, global_step)

        self._next_step = global_step + 1
Example #7
    def after_run(self, run_context, run_values):
        _ = run_context
        if not self._summary_writer:
            return

        global_step = run_values.results["global_step"]

        if self._next_step is None:
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.START), global_step)

        if self._request_summary:
            self._timer.update_last_triggered_step(global_step)
            if "summary" in run_values.results:
                self._summary_writer.add_summary(run_values.results["summary"],
                                                 global_step)

        self._next_step = global_step + 1
Example #8
    def after_run(self, run_context, run_values):
        _ = run_context
        if not self._summary_writer:
            return

        global_episode = run_values.results["global_episode"]

        if self._next_episode is None:
            self._next_episode = global_episode + 1
            self._summary_writer.add_session_log(SessionLog(status=SessionLog.START), global_episode)

        if self._request_summary and self._timer.should_trigger_for_episode(global_episode):
            self._timer.update_last_triggered_episode(global_episode)
            self._next_episode = global_episode + 1
            if "summary" in run_values.results:
                for summary in run_values.results["summary"]:
                    self._summary_writer.add_summary(summary, global_episode)

        self._current_episode = global_episode
Example #9
    def after_run(self, run_context, run_values):
        _ = run_context
        if not self._summary_writer:
            return

        global_step = run_values.results["global_step"]

        if self._last_saved_step is None:
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.START), global_step)

        if self._request_summary:
            self._last_saved_step = global_step
            if "summary" in run_values.results:
                self._summary_writer.add_summary(run_values.results["summary"],
                                                 global_step)

        self._request_summary = (global_step >=
                                 self._last_saved_step + self._save_steps - 1)
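Examples #6 through #9 are the `after_run` halves of summary-saving hooks. Rather than re-implementing one, the built-in `SummarySaverHook` already follows this before_run/after_run protocol; a minimal sketch of wiring it into a `MonitoredTrainingSession`, assuming TF 1.x compat APIs (the output directory is illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)
tf.summary.scalar("global_step", tf.cast(global_step, tf.float32))

summary_hook = tf.train.SummarySaverHook(
    save_steps=10,
    output_dir="/tmp/summary_demo",   # illustrative path
    summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
    for _ in range(50):
        sess.run(train_op)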
Example #10
  def testSessionLogSummaries(self):
    data = [
        {'session_log': SessionLog(status=SessionLog.START), 'step': 0},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 1},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 2},
        {'session_log': SessionLog(status=SessionLog.CHECKPOINT), 'step': 3},
        {'session_log': SessionLog(status=SessionLog.STOP), 'step': 4},
        {'session_log': SessionLog(status=SessionLog.START), 'step': 5},
        {'session_log': SessionLog(status=SessionLog.STOP), 'step': 6},
    ]

    self._WriteScalarSummaries(data)
    units = efi.get_inspection_units(self.logdir)
    self.assertEqual(1, len(units))
    printable = efi.get_dict_to_print(units[0].field_to_obs)
    self.assertEqual(printable['sessionlog:start']['steps'], [0, 5])
    self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6])
    self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3)
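The test above checks the recorded start/stop/checkpoint steps through the event-file inspector (`efi`). The same information can be pulled straight out of an events file; a minimal sketch, assuming TF 1.x compat APIs (the helper name `session_log_steps` is illustrative):

import tensorflow.compat.v1 as tf
from tensorflow.core.util.event_pb2 import SessionLog

_STATUS_NAMES = {
    SessionLog.START: "start",
    SessionLog.CHECKPOINT: "checkpoint",
    SessionLog.STOP: "stop",
}


def session_log_steps(events_path):
    """Map each SessionLog status in an events file to the steps it was logged at."""
    steps = {}
    for event in tf.train.summary_iterator(events_path):
        if event.HasField("session_log"):
            name = _STATUS_NAMES.get(event.session_log.status, "unknown")
            steps.setdefault(name, []).append(event.step)
    return steps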
Example #11
    def before_run(self, run_context):
        # For the first run, record a SessionLog.START at the pre-run global step.
        if self._current_step is None:
            self._current_step = run_context.session.run(
                self._global_step_tensor)
            with ops.default_session(run_context.session):
                self._summary_writer.add_session_log(
                    SessionLog(status=SessionLog.START), self._current_step)
        requests = {"global_step": self._global_step_tensor}
        self._request_summary = self._timer.should_trigger_for_step(
            self._current_step)
        if self._request_summary:
            self._timer.update_last_triggered_step(self._current_step)
            if self._get_summary_op() is not None:
                requests["summary"] = self._get_summary_op()
        feeds = {}
        if self._placeholder is not None and self._request_summary:
            feeds[self._placeholder] = self._request_summary
        args = SessionRunArgs(fetches=requests, feed_dict=feeds)
        return args
Example #12
    def _save(self, session, step):
        """Saves the latest checkpoint, returns should_stop."""
        logging.info("Saving checkpoints for %d into %s.", step,
                     self._save_path)

        for l in self._listeners:
            l.before_save(session, step)

        self._get_saver().save(session, self._save_path, global_step=step)
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.CHECKPOINT,
                       checkpoint_path=self._save_path), step)

        should_stop = False
        for l in self._listeners:
            if l.after_save(session, step):
                logging.info(
                    "A CheckpointSaverListener requested that training be stopped. "
                    "listener: {}".format(l))
                should_stop = True
        return should_stop
Example #13
  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator.  If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method.  To wait on additional threads, pass the list in this
        parameter.
      close_summary_writer: Whether to close the `summary_writer`.  Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True` ignores threads that remain running after a
        grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        self._graph_added_to_summary = False
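The `STOP` entry is deliberately written without a step, since the session may already be gone. Outside of `Supervisor`, the same START/STOP bracketing can be done directly on a `FileWriter`; a minimal sketch, assuming TF 1.x compat APIs (the log directory is illustrative):

import tensorflow.compat.v1 as tf
from tensorflow.core.util.event_pb2 import SessionLog

writer = tf.summary.FileWriter("/tmp/run_demo")   # illustrative logdir
# Mark the start of a (possibly restarted) session at the current step.
writer.add_session_log(SessionLog(status=SessionLog.START), global_step=0)
# ... add_summary(...) calls from the training loop would go here ...
# Mark the end of the session; no step, mirroring Supervisor.stop() above.
writer.add_session_log(SessionLog(status=SessionLog.STOP))
writer.close()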
Example #14
  def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))

    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    gen.AddScalar('s1', wall_time=1, step=400, value=20)

    gen.AddScalar('s2', wall_time=1, step=202, value=20)
    gen.AddScalar('s2', wall_time=1, step=203, value=20)

    slog = SessionLog(status=SessionLog.START)
    gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=slog))
    acc.Reload()
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200])
    self.assertEqual([x.step for x in acc.Scalars('s2')], [])
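The purge the test exercises can also be observed with TensorBoard's `EventAccumulator` outside the TF test harness; a small sketch, assuming the `tensorboard` package and an existing log directory containing an `s1` tag (paths and tags are illustrative, matching the test above):

from tensorboard.backend.event_processing import event_accumulator

acc = event_accumulator.EventAccumulator("/tmp/run_demo")  # illustrative logdir
acc.Reload()
# Scalars logged at steps >= the latest SessionLog.START step are discarded.
print(acc.Tags())
print([s.step for s in acc.Scalars("s1")])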
Example #15
def train(graph,
          output_dir,
          train_op,
          loss_op,
          global_step_tensor=None,
          init_op=None,
          log_every_steps=10,
          supervisor_is_chief=True,
          supervisor_master='',
          supervisor_save_model_secs=600,
          supervisor_save_summaries_secs=10,
          max_steps=None,
          fail_on_nan_loss=True):
  """Train a model.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model and is expected to increment the `global_step_tensor`, a scalar integer
  tensor counting training steps. The `loss_op` represents the objective
  function of the training. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
  `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
  program is terminated with exit code 1.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save a checkpoint every
      `supervisor_save_model_secs` seconds when training.
    supervisor_save_summaries_secs: Save summaries every
      `supervisor_save_summaries_secs` seconds when training.
    max_steps: Train until `global_step_tensor` evaluates to this value.
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `global_step_tensor` is not provided. See
        `tf.contrib.framework.get_global_step` for how we look it up if not
        provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
        evaluates to `NaN`.
  """
  global_step_tensor = contrib_variables.assert_or_get_global_step(
      graph, global_step_tensor)
  if global_step_tensor is None:
    raise ValueError('No "global_step" was provided or found in the graph.')
  supervisor, session = _prepare_session(
      graph=graph,
      output_dir=output_dir,
      start_services=True,
      global_step_tensor=global_step_tensor,
      init_op=init_op,
      supervisor_is_chief=supervisor_is_chief,
      supervisor_master=supervisor_master,
      supervisor_save_model_secs=supervisor_save_model_secs,
      supervisor_save_summaries_secs=supervisor_save_summaries_secs)

  with session:
    get_current_step = lambda: session.run(global_step_tensor)

    start_step = get_current_step()
    last_step = start_step
    last_log_step = start_step
    loss_value = None
    logging.info('Training steps [%d,%s)', last_step, 'inf'
                 if max_steps is None else str(max_steps))
    try:
      try:
        while not supervisor.ShouldStop() and (
            (max_steps is None) or (last_step < max_steps)):
          start_time = time.time()
          _, loss_value = session.run([train_op, loss_op])
          if np.isnan(loss_value):
            failure_message = 'Model diverged with loss = NaN.'
            if fail_on_nan_loss:
              logging.error(failure_message)
              raise NanLossDuringTrainingError()
            else:
              logging.warning(failure_message)

          this_step = get_current_step()

          if this_step <= last_step:
            logging.error(
                'Global step was not incremented by train op at step %s'
                ': new step %d' % (last_step, this_step))

          last_step = this_step
          is_last_step = (max_steps is not None) and (last_step >= max_steps)
          if is_last_step or (last_step - last_log_step >= log_every_steps):
            logging.info(
                'training step %d, loss = %.5f (%.3f sec/batch).',
                last_step, loss_value, float(time.time() - start_time))
            last_log_step = last_step
      finally:
        # Call supervisor.Stop() from within a try block because it re-raises
        # exceptions thrown by the supervised threads.
        supervisor.Stop(close_summary_writer=False)

        # Save one last checkpoint and summaries
        # TODO(wicke): This should be handled by Supervisor

        # In case we encountered an exception in the try block before we updated
        # last_step, update it here (again).
        last_step = get_current_step()
        if supervisor_is_chief:
          ckpt_path = supervisor.save_path
          logging.info('Saving checkpoint for step %d to checkpoint: %s.' % (
              last_step, ckpt_path))
          supervisor.saver.save(session, ckpt_path, global_step=last_step)
          if supervisor.summary_op is not None:
            summary_strs = session.run(supervisor.summary_op)
            supervisor.summary_writer.add_summary(summary_strs, last_step)
            supervisor.summary_writer.add_session_log(
                SessionLog(status=SessionLog.STOP), last_step)
            supervisor.summary_writer.close()
    # catch OutOfRangeError which is thrown when queue is out of data (and for
    # other reasons as well).
    except errors.OutOfRangeError as e:
      logging.warn('Got exception during tf.learn training loop possibly '
                   'due to exhausted input queue %s.', e)
    return loss_value
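A sketch of calling this `train` function, assuming it and its module-level helpers (the old `tf.contrib.learn`-era training loop shown above) are importable; the graph, ops and output directory below are illustrative stand-ins:

import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    global_step = tf.train.get_or_create_global_step()
    loss_op = tf.reduce_mean(tf.square(tf.random_normal([8])))
    # A stand-in train op: it only advances the global step.
    train_op = tf.assign_add(global_step, 1)

final_loss = train(
    graph=graph,
    output_dir="/tmp/train_demo",   # illustrative output directory
    train_op=train_op,
    loss_op=loss_op,
    max_steps=100)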
Example #16
def start_training(config):
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd
        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized.
        import mpi4py.rc
        mpi4py.rc.initialize = False
        # import mpi4py
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check that size and rank are synchronized
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    network_kwargs = dict(
        (key.lower(), val) for key, val in config.NETWORK.items())
    if "train_validation_saving_size".upper() in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    if use_train_validation_saving:
        top_train_validation_saving_set_accuracy = 0

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            loss = model.loss(output, labels_placeholder,
                              is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            all_vars = tf.global_variables()
            pretrain_var_list = [
                var for var in all_vars
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list,
                                            name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, visible_device_list=str(hvd.local_rank())))
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.ConfigProto(
        )  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    if rank == 0:
        train_writer = tf.summary.FileWriter(
            environment.TENSORBOARD_DIR + "/train", sess.graph)
        if use_train_validation_saving:
            train_val_saving_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train_validation_saving")
        val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR +
                                           "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(
                sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
            sess.run(tf.assign(global_step, 0))

        last_step = 0

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(global_step)
            # TODO(wakisaka): TensorFlow v1.3 keeps the previous event log in TensorBoard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START),
                                         global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START),
                                       global_step=last_step + 1)
            print("recovered. last step", last_step)

    if config.IS_DISTRIBUTION:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)
        # calculate steps per epoch for each node
        train_num_per_epoch = train_dataset.num_per_epoch
        num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
        step_per_epoch = num_per_nodes // config.BATCH_SIZE
        begin_index = (train_num_per_epoch * rank) // num_worker
        end_index = begin_index + num_per_nodes

    last_step = sess.run(global_step)

    # Calculate max steps. config.MAX_EPOCHS takes priority over config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE *
                        config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS
    print("max_steps: {}".format(max_steps))

    for step in range(last_step, max_steps):
        print("step", step)

        if config.IS_DISTRIBUTION:
            # scatter dataset
            if step % step_per_epoch == 0:
                indices = train_dataset.get_shuffle_index(
                ) if rank == 0 else None
                # broadcast shuffled indices
                indices = comm.bcast(indices, 0)
                feed_indices = indices[begin_index:end_index]
                # update each dataset with its split of the shuffled indices
                train_dataset.update_dataset(feed_indices)

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # True on the first step and whenever (step + 1) is a multiple of config.SUMMARISE_STEPS.
        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            # Runtime statistics for development.
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()

            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op],
                feed_dict=feed_dict,
                # options=run_options,
                # run_metadata=run_metadata,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders,
                                              metrics_values)
            }

            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            train_writer.add_summary(metrics_summary, step + 1)
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (
            step + 1) == max_steps or (step + 1) % config.SAVE_STEPS == 0

        if to_be_saved and rank == 0:
            if use_train_validation_saving:

                sess.run(reset_metrics_op)
                train_validation_saving_step_size = int(
                    math.ceil(train_validation_saving_dataset.num_per_epoch /
                              config.BATCH_SIZE))
                print("train_validation_saving_step_size",
                      train_validation_saving_step_size)

                current_train_validation_saving_set_accuracy = 0

                for train_validation_saving_step in range(
                        train_validation_saving_step_size):
                    print("train_validation_saving_step",
                          train_validation_saving_step)

                    images, labels = train_validation_saving_dataset.feed()
                    feed_dict = {
                        is_training_placeholder: False,
                        images_placeholder: images,
                        labels_placeholder: labels,
                    }

                    if train_validation_saving_step % config.SUMMARISE_STEPS == 0:
                        summary, _ = sess.run([summary_op, metrics_update_op],
                                              feed_dict=feed_dict)
                        train_val_saving_writer.add_summary(summary, step + 1)
                    else:
                        sess.run([metrics_update_op], feed_dict=feed_dict)

                metrics_values = sess.run(list(metrics_ops_dict.values()))
                metrics_feed_dict = {
                    placeholder: value
                    for placeholder, value in zip(metrics_placeholders,
                                                  metrics_values)
                }
                metrics_summary, = sess.run(
                    [metrics_summary_op],
                    feed_dict=metrics_feed_dict,
                )
                train_val_saving_writer.add_summary(metrics_summary, step + 1)

                current_train_validation_saving_set_accuracy = sess.run(
                    metrics_ops_dict["accuracy"])

                if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                    top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                    print("New top train_validation_saving accuracy is: ",
                          top_train_validation_saving_set_accuracy)

                    _save_checkpoint(saver, sess, global_step, step)

            else:
                _save_checkpoint(saver, sess, global_step, step)

            if step == 0:
                # create the pb file only on the first step.
                minimal_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    sess.graph.as_graph_def(add_shapes=True),
                    ["output"],
                )
                pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(step +
                                                                        1)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pb_name,
                                     as_text=False)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pbtxt_name,
                                     as_text=True)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(
                math.ceil(validation_dataset.num_per_epoch /
                          config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op],
                                          feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders,
                                              metrics_values)
            }
            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)

    # training loop end.
    print("reach max step")
Example #17
def start_training(config):
    use_horovod = horovod_util.is_enabled()
    print("use_horovod:", use_horovod)
    if use_horovod:
        hvd = horovod_util.setup()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = -1

    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    train_dataset = setup_dataset(config, "train", rank, local_rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank, local_rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if config.TASK == Tasks.OBJECT_DETECTION:
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        is_training_placeholder = tf.compat.v1.placeholder(tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholders()

        output = model.inference(images_placeholder, is_training_placeholder)
        loss = model.loss(output, labels_placeholder)
        opt = model.optimizer()
        if use_horovod:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt)
        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.compat.v1.summary.merge_all()
        metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

        init_op = tf.compat.v1.global_variables_initializer()
        reset_metrics_op = tf.compat.v1.local_variables_initializer()
        if use_horovod:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        saver = tf.compat.v1.train.Saver(max_to_keep=config.KEEP_CHECKPOINT_MAX)

        if config.IS_PRETRAIN:
            all_vars = tf.compat.v1.global_variables()
            pretrain_var_list = [
                var for var in all_vars if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [
                var.name for var in pretrain_var_list
            ])
            pretrain_saver = tf.compat.v1.train.Saver(pretrain_var_list, name="pretrain_saver")

    if use_horovod:
        # For distributed training
        session_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(
                allow_growth=True,
                visible_device_list=str(hvd.local_rank())
            )
        )
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.compat.v1.ConfigProto()  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.compat.v1.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])
    executor.save_pb_file(sess, environment.CHECKPOINTS_DIR)

    if rank == 0:
        train_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/train", sess.graph)
        val_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(model.global_step)
            # TODO(wakisaka): TensorFlow v1.3 keeps the previous event log in TensorBoard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            print("recovered. last step", last_step)

    if use_horovod:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)

    last_step = sess.run(model.global_step)

    # Calculate max steps. config.MAX_EPOCHS takes priority over config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS

    progbar = Progbar(max_steps)
    if rank == 0:
        progbar.update(last_step)
    for step in range(last_step, max_steps):

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # True on the first step and whenever (step + 1) is a multiple of config.SUMMARISE_STEPS.
        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            # Runtime statistics for development.
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()

            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op], feed_dict=feed_dict,
                # options=run_options,
                # run_metadata=run_metadata,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_summary = sess.run(metrics_summary_op)
            train_writer.add_summary(metrics_summary, step + 1)
            train_writer.flush()
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (step + 1) == max_steps or (step + 1) % config.SAVE_CHECKPOINT_STEPS == 0

        if to_be_saved and rank == 0:
            _save_checkpoint(saver, sess, model.global_step)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))

            for test_step in range(test_step_size):

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                        val_writer.flush()
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_summary = sess.run(metrics_summary_op)
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)
                val_writer.flush()

        if rank == 0:
            progbar.update(step + 1)
    # training loop end.
    train_dataset.close()
    validation_dataset.close()
    print("Done")