Ejemplo n.º 1
0
 def get_specific_feed_dict(self, dataset, seq_idx):
     """
     Build the feed dict for exactly one sequence of ``dataset``.

     No Runner is involved; this is a stripped-down version of Runner.run().
     The steps are:

     - a single Batch covers the whole sequence (from frame 0);
     - a BatchSetGenerator wraps that one batch;
     - a FeedDictDataProvider turns it into the feed dict.

     :param Dataset.Dataset dataset:
     :param int seq_idx: index of the sequence inside ``dataset``
     :return: feed_dict for self.tf_session.run()
     :rtype: dict[str,numpy.ndarray]
     """
     # NOTE(review): other callers unpack get_feed_dict() as a (feed_dict, _) tuple;
     # the return type depends on the FeedDictDataProvider version -- confirm.
     full_len = dataset.get_seq_length(seq_idx)
     only_batch = Batch()
     only_batch.add_frames(seq_idx=seq_idx, seq_start_frame=0, length=full_len)
     # The generator yields this one batch and nothing else.
     batch_set = BatchSetGenerator(dataset, generator=iter([only_batch]))
     from TFDataPipeline import FeedDictDataProvider
     provider = FeedDictDataProvider(
         tf_session=self.tf_session,
         extern_data=self.network.extern_data,
         data_keys=self.network.used_data_keys,
         dataset=dataset,
         batches=batch_set)
     return provider.get_feed_dict(single_threaded=True)
Ejemplo n.º 2
0
def test_Updater_simple_batch():
    """
    Build a small linear/linear/softmax network on Task12AXDataset.
    Set up an Updater for it and run its optimizer op once on a generated batch.
    """
    with make_scope() as session:
        from TFNetwork import TFNetwork, ExternData
        from Config import Config
        from GeneratingDataset import Task12AXDataset
        ds = Task12AXDataset()
        ds.init_seq_order(epoch=1)
        ext_data = ExternData()
        ext_data.init_from_dataset(ds)

        cfg = Config()
        net = TFNetwork(extern_data=ext_data, train_flag=True)
        # Two tanh linear layers of width 13, then a softmax output with CE loss on "classes".
        net_dict = {
            "layer1": {"class": "linear", "activation": "tanh", "n_out": 13},
            "layer2": {"class": "linear", "activation": "tanh", "n_out": 13, "from": ["layer1"]},
            "output": {"class": "softmax", "loss": "ce", "target": "classes", "from": ["layer2"]},
        }
        net.construct_from_dict(net_dict)
        net.initialize_params(session=session)

        upd = Updater(config=cfg, network=net)
        upd.set_learning_rate(1.0, session=session)
        upd.set_trainable_vars(net.get_trainable_params())
        upd.init_optimizer_vars(session=session)

        # Build a feed dict from the generated batches and run one optimizer step with it.
        from TFDataPipeline import FeedDictDataProvider
        batch_gen = dataset_batches = ds.generate_batches(
            recurrent_net=net.recurrent, batch_size=100, max_seqs=10,
            max_seq_length=sys.maxsize, used_data_keys=net.used_data_keys)
        provider = FeedDictDataProvider(
            tf_session=session, extern_data=ext_data, data_keys=net.used_data_keys,
            dataset=ds, batches=batch_gen)
        fd, _ = provider.get_feed_dict(single_threaded=True)
        step_op = upd.get_optim_op()
        session.run(step_op, feed_dict=fd)
Ejemplo n.º 3
0
def test_Updater_multiple_optimizers_and_opts():
  """
  Give layers individual "updater_opts" (own optimizer class, learning-rate multiplier,
  gradient accumulation, gradient noise).
  Check that the Updater ends up with a WrapOptimizer holding 3 optimizers.
  Then run one optimizer step on a generated batch.
  """
  with make_scope() as session:
    from TFNetwork import TFNetwork, ExternData
    from Config import Config
    from GeneratingDataset import Task12AXDataset
    ds = Task12AXDataset()
    ds.init_seq_order(epoch=1)
    ext_data = ExternData()
    ext_data.init_from_dataset(ds)

    cfg = Config()
    net = TFNetwork(extern_data=ext_data, train_flag=True)
    layer1_opts = {"optimizer": {"class": "Adam"}, "accum_grad_multiple_step": 2}
    layer2_opts = {"optimizer": {"class": "Adagrad", "learning_rate_multiplier": 3}, "gradient_noise": 0.1}
    net.construct_from_dict({
      "layer1": {"class": "linear", "activation": "tanh", "n_out": 13,
                 "updater_opts": layer1_opts},
      "layer2": {"class": "linear", "activation": "tanh", "n_out": 13, "from": ["layer1"],
                 "updater_opts": layer2_opts},
      "output": {"class": "softmax", "loss": "ce", "target": "classes", "from": ["layer2"]},
    })
    net.initialize_params(session=session)

    upd = Updater(config=cfg, network=net)
    upd.set_learning_rate(1.0, session=session)
    upd.set_trainable_vars(net.get_trainable_params())
    upd.init_optimizer_vars(session=session)

    # get_optim_op() is called before inspecting upd.optimizer, as in the original ordering.
    step_op = upd.get_optim_op()
    wrapped = upd.optimizer
    assert isinstance(wrapped, WrapOptimizer)
    assert len(wrapped.optimizers) == 3

    from TFDataPipeline import FeedDictDataProvider
    batch_gen = ds.generate_batches(
      recurrent_net=net.recurrent, batch_size=100, max_seqs=10,
      max_seq_length=sys.maxsize, used_data_keys=net.used_data_keys)
    provider = FeedDictDataProvider(
      tf_session=session, extern_data=ext_data, data_keys=net.used_data_keys,
      dataset=ds, batches=batch_gen)
    fd, _ = provider.get_feed_dict(single_threaded=True)
    session.run(step_op, feed_dict=fd)
Ejemplo n.º 4
0
  def __init__(self, engine, dataset, batches, train, eval=True, extra_fetches=None, extra_fetches_callback=None):
    """
    Create the feed-dict data provider for ``dataset``/``batches`` and reset all run bookkeeping.

    :param Engine engine:
    :param Dataset.Dataset dataset:
    :param BatchSetGenerator batches:
    :param bool train: whether to do updates on the model
    :param bool eval: whether to evaluate (i.e. calculate loss/error)
    :param dict[str,tf.Tensor|TFUtil.Data|TFNetworkLayer.LayerBase]|None extra_fetches: additional fetches per step.
      `extra_fetches_callback` will be called with these. In case of Data/LayerBase, it will return a list,
      where each item corresponds to the batch-seq.
      It might also be useful to add `network.get_extern_data("seq_idx")` and `network.get_extern_data("seq_tag")`.
    :param (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None extra_fetches_callback: called if extra_fetches
    """
    from TFDataPipeline import FeedDictDataProvider, DataProviderBase
    self.engine = engine
    net = engine.network
    self.data_provider = FeedDictDataProvider(
      tf_session=engine.tf_session,
      extern_data=net.extern_data,
      data_keys=net.used_data_keys,
      dataset=dataset,
      batches=batches)
    assert isinstance(self.data_provider, DataProviderBase)

    # What each step should do.
    self._should_train = train
    self._should_eval = eval
    # Step intervals read from the config (default 0).
    cfg = engine.config
    self.store_metadata_mod_step = cfg.int("store_metadata_mod_step", 0)
    self.reset_updater_vars_mod_step = cfg.int("reset_updater_vars_mod_step", 0)

    # Run state, filled in during/after the run.
    self.finalized = False
    self.num_steps = None
    self.device_crash_batch = None  # type: int|None
    self.start_time = None
    self.elapsed = None

    # Result containers. Keys look like "cost:output", "error:output", "loss" or "stats:...".
    self._results_accumulated = {}  # type: dict[str,float]
    self.results = {}  # type: dict[str,float]
    self.score = {}  # type: dict[str,float]
    self.error = {}  # type: dict[str,float]
    self.stats = {}  # type: dict[str,float]

    # Extra fetches are only useful with a callback that receives them.
    self.extra_fetches = extra_fetches
    assert extra_fetches is None or extra_fetches_callback
    self.extra_fetches_callback = extra_fetches_callback

    from Util import terminal_size
    width, _height = terminal_size()
    self._show_interactive_process_bar = (
      log.verbose[3] and not log.verbose[5] and width >= 0)
Ejemplo n.º 5
0
    def __init__(self, engine, dataset, batches, train, eval=True):
        """
        Create the feed-dict data provider and reset all run bookkeeping.

        :param Engine engine:
        :param Dataset.Dataset dataset:
        :param BatchSetGenerator batches:
        :param bool train: whether to do updates on the model
        :param bool eval: whether to evaluate (i.e. calculate loss/error)
        """
        from TFDataPipeline import FeedDictDataProvider, DataProviderBase
        self.engine = engine
        net = engine.network
        provider = FeedDictDataProvider(
            tf_session=engine.tf_session,
            extern_data=net.extern_data,
            data_keys=net.used_data_keys,
            dataset=dataset,
            batches=batches)
        self.data_provider = provider
        assert isinstance(self.data_provider, DataProviderBase)

        # What each step should do.
        self._should_train = train
        self._should_eval = eval
        # Step intervals read from the config (default 0).
        cfg = engine.config
        self.store_metadata_mod_step = cfg.int("store_metadata_mod_step", 0)
        self.reset_updater_vars_mod_step = cfg.int("reset_updater_vars_mod_step", 0)

        # Run state, filled in during/after the run.
        self.finalized = False
        self.num_steps = None
        self.device_crash_batch = None  # type: int|None
        self.start_time = None
        self.elapsed = None

        # Result containers. Keys look like "cost:output", "error:output", "loss" or "stats:...".
        self._results_accumulated = {}  # type: dict[str,float]
        self.results = {}  # type: dict[str,float]
        self.score = {}  # type: dict[str,float]
        self.error = {}  # type: dict[str,float]
        self.stats = {}  # type: dict[str,float]

        from Util import terminal_size
        width, _height = terminal_size()
        self._show_interactive_process_bar = (
            log.verbose[3] and not log.verbose[5] and width >= 0)
Ejemplo n.º 6
0
def test_DataProvider():
    """
    Feed a single sequence of a small DummyDataset through FeedDictDataProvider.

    Checks that the resulting feed dict:

    - contains the "data" and "classes" placeholders and their size placeholders;
    - holds numpy arrays with the expected (batch, time[, dim]) shapes;
    - holds the expected sequence lengths and values.
    """
    # NOTE(review): `session` is not defined in this function; presumably a module-level
    # TF session of the test module -- confirm.
    from GeneratingDataset import DummyDataset
    seq_len = 5
    n_data_dim = 2
    n_classes_dim = 3
    dataset = DummyDataset(input_dim=n_data_dim,
                           output_dim=n_classes_dim,
                           num_seqs=2,
                           seq_len=seq_len)
    dataset.init_seq_order(epoch=1)

    extern_data = ExternData()
    extern_data.init_from_dataset(dataset)

    # No Runner instance here but a very simplified version of Runner.run():
    # a custom BatchSetGenerator which yields only one single batch for the given sequence idx.
    seq_idx = 0
    n_batch = 1
    batch = Batch()
    batch.add_frames(seq_idx=seq_idx,
                     seq_start_frame=0,
                     length=dataset.get_seq_length(seq_idx))
    batches = BatchSetGenerator(dataset, generator=iter([batch]))
    from TFDataPipeline import FeedDictDataProvider
    data_provider = FeedDictDataProvider(tf_session=session,
                                         extern_data=extern_data,
                                         data_keys=["data", "classes"],
                                         dataset=dataset,
                                         batches=batches)

    feed_dict = data_provider.get_feed_dict(single_threaded=True)
    print(feed_dict)
    assert_is_instance(feed_dict, dict)
    data_ph = extern_data.data["data"]
    classes_ph = extern_data.data["classes"]
    assert data_ph.placeholder in feed_dict
    assert data_ph.size_placeholder[0] in feed_dict
    assert classes_ph.placeholder in feed_dict
    assert classes_ph.size_placeholder[0] in feed_dict
    data = feed_dict[data_ph.placeholder]
    data_size = feed_dict[data_ph.size_placeholder[0]]
    classes = feed_dict[classes_ph.placeholder]
    classes_size = feed_dict[classes_ph.size_placeholder[0]]
    for value in (data, data_size, classes, classes_size):
        assert_is_instance(value, numpy.ndarray)
    # Batch-major layout: (batch, time, dim) for dense data, (batch, time) for sparse classes.
    assert_equal(data.shape, (n_batch, seq_len, n_data_dim))
    assert_equal(data_size.shape, (n_batch, ))
    assert_equal(classes.shape, (n_batch, seq_len))
    assert_equal(classes_size.shape, (n_batch, ))
    assert_equal(list(data_size), [seq_len])
    assert_equal(list(classes_size), [seq_len])
    numpy.testing.assert_almost_equal(list(data[0, 0]), [-0.5, -0.4])
    numpy.testing.assert_almost_equal(list(data[0, -1]), [0.3, 0.4])
    assert_equal(classes.tolist(), [[1, 2, 0, 1, 2]])
Ejemplo n.º 7
0
class Runner(object):
  """
  Runs one pass over a dataset with the engine's network in the engine's TF session.

  Per step:

  - batches are fed through a FeedDictDataProvider;
  - the optimizer op is run if `train` is set;
  - loss/cost/error are fetched if `train` or `eval` is set;
  - layer stats and summaries are always fetched.

  Step stats are printed and accumulated. After a successful pass they are normalized
  into `results`, `score` and `error`. If a step raises, `device_crash_batch` is set.
  """

  def __init__(self, engine, dataset, batches, train, eval=True, extra_fetches=None, extra_fetches_callback=None):
    """
    :param Engine engine:
    :param Dataset.Dataset dataset:
    :param BatchSetGenerator batches:
    :param bool train: whether to do updates on the model
    :param bool eval: whether to evaluate (i.e. calculate loss/error)
    :param dict[str,tf.Tensor|TFUtil.Data|TFNetworkLayer.LayerBase]|None extra_fetches: additional fetches per step.
      `extra_fetches_callback` will be called with these. In case of Data/LayerBase, it will return a list,
      where each item corresponds to the batch-seq.
      It might also be useful to add `network.get_extern_data("seq_idx")` and `network.get_extern_data("seq_tag")`.
    :param (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None extra_fetches_callback: called if extra_fetches
    """
    from TFDataPipeline import FeedDictDataProvider, DataProviderBase
    self.engine = engine
    self.data_provider = FeedDictDataProvider(
      tf_session=engine.tf_session, extern_data=engine.network.extern_data,
      data_keys=engine.network.used_data_keys,
      dataset=dataset, batches=batches)
    assert isinstance(self.data_provider, DataProviderBase)
    self._should_train = train
    self._should_eval = eval
    self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0)
    self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0)
    self.finalized = False
    self.num_steps = None
    self.device_crash_batch = None  # type: int|None
    self.start_time = None
    self.elapsed = None
    self._results_accumulated = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.results = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.score = {}  # type: dict[str,float]  # entries like "cost:output"
    self.error = {}  # type: dict[str,float]  # entries like "error:output"
    self.stats = {}  # type: dict[str,float]  # entries like "stats:..."
    self.extra_fetches = extra_fetches
    if extra_fetches is not None:
      assert extra_fetches_callback
    self.extra_fetches_callback = extra_fetches_callback

    from Util import terminal_size
    terminal_width, _ = terminal_size()
    self._show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)

  def _get_fetches_dict(self):
    """
    :return: values and actions which should be calculated and executed in self.run() by the TF session for each step
    :rtype: dict[str,tf.Tensor|tf.Operation]
    """
    # Note that it is important that we do not recreate graph nodes for every call to this function.
    # Thus everything which we access here should be cached.
    d = {}
    for key in self.data_provider.data_keys:
      data = self.data_provider.extern_data.get_data(key)
      for dim, v in data.size_placeholder.items():
        d["size:%s:%i" % (key, dim)] = v
    if self._should_train or self._should_eval:
      # These values are cached internally and the graph nodes are created on the first call.
      loss = self.engine.network.get_objective()
      # get_objective() can return a plain Python zero (no loss at all), which session.run()
      # cannot fetch, so replace it by a constant tensor.
      # Check type+value instead of `loss is 0`, which relies on small-int identity.
      if isinstance(loss, (int, float)) and loss == 0:
        loss = self.engine.get_const_tensor(key="zero_loss", value=0.0)
      d["loss"] = loss
      for layer_name, loss in self.engine.network.loss_by_layer.items():
        d["cost:%s" % layer_name] = loss
      for layer_name, error in self.engine.network.error_by_layer.items():
        d["error:%s" % layer_name] = error
    for layer in self.engine.network.layers.values():
      for k, v in layer.stats.items():
        d["stats:%s:%s" % (layer.name, k)] = v
    if self._should_train:
      assert self.engine.updater

      def callback_on_new():
        # Force a new check.
        self.engine._checked_uninitialized_vars = False

      d["optim_op"] = self.engine.updater.get_optim_op(callback_on_new=callback_on_new)
    if self.extra_fetches is not None:
      from TFNetworkLayer import LayerBase
      from TFUtil import Data
      for k, v in self.extra_fetches.items():
        if isinstance(v, tf.Tensor):
          d["extra:%s" % k] = v
          continue
        if isinstance(v, LayerBase):
          v = v.output
        assert isinstance(v, Data)
        d["extra:%s" % k] = v.placeholder
        for i, s in v.size_placeholder.items():
          d["extra:%s:size_%i" % (k, i)] = s
    d["summary"] = self.engine.get_all_merged_summaries()
    return d

  def _print_process(self, report_prefix, step, step_duration, eval_info):
    """
    Report progress for one step.

    Verbose logging (log level 5) prints one line per step. Otherwise the interactive
    progress bar is updated if it is enabled.

    :param str report_prefix: prefix for the log line
    :param int step:
    :param float step_duration: seconds spent on this step
    :param dict[str,object] eval_info: see self._collect_eval_info()
    """
    # Nothing to report if neither verbose logging nor the progress bar is active.
    # (log.verbose is the list of flags; log.v holds stream objects, which are always truthy.)
    if not self._show_interactive_process_bar and not log.verbose[5]:
      return
    start_elapsed = time.time() - self.start_time
    complete = self.data_provider.get_complete_frac()
    assert complete > 0
    total_time_estimated = start_elapsed / complete
    remaining_estimated = total_time_estimated - start_elapsed
    if log.verbose[5]:
      info = [
        report_prefix,
        "step %i" % step]
      if eval_info:  # Such as score.
        info += ["%s %s" % item for item in sorted(eval_info.items())]
      info += [
        "%.3f sec/step" % step_duration,
        "elapsed %s" % hms(start_elapsed),
        "exp. remaining %s" % hms(remaining_estimated),
        "complete %.02f%%" % (complete * 100)]
      print(", ".join(filter(None, info)), file=log.v5)
    elif self._show_interactive_process_bar:
      from Util import progress_bar
      progress_bar(complete, hms(remaining_estimated))

  def _print_finish_process(self):
    """
    Finalize the interactive progress bar, if it is shown.
    """
    if self._show_interactive_process_bar:
      from Util import progress_bar
      progress_bar()

  def _get_target_for_key(self, key):
    """
    :param str key: e.g. "cost:output" where the last part is the layer name. or "loss"
    :return: target name which is the data-key in the dataset, e.g. "classes"
    :rtype: str
    """
    if ":" in key:
      layer = self.engine.network.layers[key.split(':')[-1]]
      if layer.target:
        return layer.target
    return self.engine.network.get_default_target()  # e.g. "classes"

  def _epoch_norm_factor_for_result(self, key):
    """
    :param str key: e.g. "cost:output"
    :return: factor to multiply with such accumulated values for the final epoch stats
    :rtype: float
    """
    target = self._get_target_for_key(key)
    # Default: Normalize by number of frames.
    return 1.0 / float(self.data_provider.get_num_frames()[target])

  def _finalize(self, num_steps):
    """
    Called at the end of an epoch.
    Normalizes the accumulated results per frame and splits them into score (cost:*) and error (error:*).

    :param int num_steps: number of steps we did for this epoch
    """
    assert not self.data_provider.have_more_data(session=self.engine.tf_session)
    assert self.data_provider.get_num_frames()["data"] > 0
    results = {key: value * self._epoch_norm_factor_for_result(key)
               for (key, value) in self._results_accumulated.items()}
    self.results = results
    self.score = {key: value for (key, value) in results.items() if key.startswith("cost:")}
    self.error = {key: value for (key, value) in results.items() if key.startswith("error:")}
    self.num_steps = num_steps
    self.finalized = True

  def _step_seq_len(self, fetches_results, data_key):
    """
    :param dict[str,numpy.ndarray|None] fetches_results: results of calculations, see self._get_fetches_dict()
    :param str data_key: e.g. "classes"
    :return: total number of frames of `data_key` in this batch, i.e. the sum over its seq lengths
    :rtype: int
    """
    num_frames = numpy.sum(fetches_results["size:%s:0" % data_key])
    return num_frames

  def _collect_eval_info(self, fetches_results):
    """
    Accumulate the step's loss/cost/error values for the epoch stats.
    Build the per-step info for printing.

    :param dict[str,numpy.ndarray|None] fetches_results: results of calculations, see self._get_fetches_dict()
    :return: dict for printing the step stats, see self._print_process(), e.g. {"cost:output": 2.3}
    :rtype: dict[str,float]
    """
    # See see self._get_fetches_dict() for the keys.
    keys = [k for k in fetches_results.keys() if k.startswith("cost:") or k.startswith("error:") or k == "loss"]

    # Accumulate for epoch stats.
    # Use non-mutating `+`: the stored object may be the very array returned by session.run().
    for key in keys:
      value = fetches_results[key]
      if key in self._results_accumulated:
        self._results_accumulated[key] = self._results_accumulated[key] + value
      else:
        self._results_accumulated[key] = value

    # Prepare eval info stats for this batch run.
    # Normalize per frame with a non-mutating division. An in-place `/=` on a 0-d ndarray
    # would also change the first-step entry of self._results_accumulated, which aliases it.
    eval_info = {}
    for key in keys:
      value = fetches_results[key]
      target = self._get_target_for_key(key)
      if value:
        value = value / float(self._step_seq_len(fetches_results=fetches_results, data_key=target))
      eval_info[key] = value

    # Add raw stats.
    for k, v in fetches_results.items():
      if k.startswith("stats:"):
        if v.ndim == 1:
          v = list(v)  # looks nicer in logs
        eval_info[k] = v
        self.stats[k] = v  # Always just store latest value.

    return eval_info

  def _maybe_handle_extra_fetches(self, fetches_results):
    """
    If extra fetches were requested, unpack them and call self.extra_fetches_callback.
    Data/LayerBase fetches become one entry per batch-seq, cut to the seq length if there is a time axis.

    :param dict[str,numpy.ndarray|str] fetches_results: results of calculations, see self._get_fetches_dict()
    """
    if self.extra_fetches is None:
      return
    d = {}
    from TFNetworkLayer import LayerBase
    from TFUtil import Data
    for k, v in self.extra_fetches.items():
      r = fetches_results["extra:%s" % k]
      if isinstance(v, tf.Tensor):
        d[k] = r
        continue
      if isinstance(v, LayerBase):
        v = v.output
      assert isinstance(v, Data)
      if v.batch_dim_axis != 0:
        r = numpy.moveaxis(r, v.batch_dim_axis, 0)
      if v.have_time_axis():
        assert v.time_dim_axis_excluding_batch == 0
        assert list(v.size_placeholder.keys()) == [0]
        seq_lens = fetches_results["extra:%s:size_0" % k]  # shape: (batch,)
        assert seq_lens.shape == (r.shape[0],)
        d[k] = [r[i, :seq_lens[i]] for i in range(seq_lens.shape[0])]
      else:
        d[k] = list(r)
    self.extra_fetches_callback(**d)

  def run(self, report_prefix):
    """
    Run over all batches from the data provider, one session.run() per step.

    Exceptions raised during the steps are not propagated:

    - a KeyboardInterrupt is printed;
    - any other exception is printed and its step stored in self.device_crash_batch.

    On success, self._finalize() fills self.results / self.score / self.error.

    :param str report_prefix: prefix for logging
    """
    sess = self.engine.tf_session
    if self.engine.config.has("tf_log_dir"):
      logdir = self.engine.config.value("tf_log_dir", None)
    else:
      logdir = os.path.dirname(self.engine.model_filename) or os.getcwd()
    if logdir:
      logdir += "/%s" % self.data_provider.get_dataset_name()
      if not self._should_train:  # like eval
        logdir += "-%i" % self.engine.epoch
      if self.engine.use_search_flag:
        logdir += "-search"
      writer = tf.summary.FileWriter(logdir)
      writer.add_graph(sess.graph)
    else:
      writer = None
    run_metadata = tf.RunMetadata()
    debug_shell_in_runner = self.engine.config.bool("debug_shell_in_runner", False)
    debug_shell_in_runner_step = self.engine.config.int("debug_shell_in_runner_step", 1)

    # Not sure if this is the best thing to do for an evaluation but it's ok for now.
    # We could also set it to 0 for non train epochs.
    step_offset = self.engine.network.get_global_train_step(session=sess)

    coord = self.data_provider.coord

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    self.data_provider.start_threads()
    self.start_time = time.time()
    step = None
    try:
      # step is like mini-batch in our usual terminology
      step = 0
      fetches_dict = self._get_fetches_dict()
      # After get_fetches_dict, maybe some new uninitialized vars. Last check.
      self.engine.check_uninitialized_vars()
      while self.data_provider.have_more_data(session=sess):
        feed_dict = self.data_provider.get_feed_dict()
        if isinstance(self.engine.network.train_flag, tf.Tensor):
          feed_dict[self.engine.network.train_flag] = self._should_train
        start_time = time.time()
        if self._should_train and self.reset_updater_vars_mod_step and step % self.reset_updater_vars_mod_step == 0:
          print("Reset updater vars in step %i." % step, file=log.v5)
          self.engine.updater.init_optimizer_vars()

        if debug_shell_in_runner and debug_shell_in_runner_step == step:
          print("debug_shell_in_runner, step %i" % step, file=log.v1)
          import Debug
          Debug.debug_shell(user_ns=locals(), user_global_ns=globals(), exit_afterwards=False)

        # Now do one calculation step. Optionally with metadata.
        # Storing metadata needs the summary writer and a logdir. Without them (writer is None),
        # do a normal step instead of failing on writer.add_summary / os.path.join(None, ...).
        if self.store_metadata_mod_step and writer is not None and step % self.store_metadata_mod_step == 0:
          # Slow run that stores extra information for debugging.
          print('Storing metadata', file=log.v5)
          run_options = tf.RunOptions(
            trace_level=tf.RunOptions.FULL_TRACE)
          # We could use tfdbg.add_debug_tensor_watch here.
          fetches_results = sess.run(
            fetches_dict,
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=run_metadata)  # type: dict[str,numpy.ndarray|str]
          writer.add_summary(fetches_results["summary"], step + step_offset)
          writer.add_run_metadata(run_metadata, 'step_{:04d}'.format(step + step_offset))
          tl = timeline.Timeline(run_metadata.step_stats)
          timeline_path = os.path.join(logdir, 'timeline.trace')
          with open(timeline_path, 'w') as f:
            f.write(tl.generate_chrome_trace_format(show_memory=True))
        else:
          fetches_results = sess.run(fetches_dict, feed_dict=feed_dict)  # type: dict[str,numpy.ndarray|str]
          if writer:
            writer.add_summary(fetches_results["summary"], step + step_offset)

        eval_info = self._collect_eval_info(fetches_results=fetches_results)
        self._maybe_handle_extra_fetches(fetches_results)
        duration = time.time() - start_time
        self._print_process(report_prefix=report_prefix, step=step, step_duration=duration,
                            eval_info=eval_info)
        step += 1

      self._print_finish_process()

      if not self.data_provider.have_reached_end():
        raise Exception("Did not successfully reached the end of the dataset.")

      if self._should_train:
        final_global_train_step = self.engine.network.get_global_train_step(session=sess)
        assert step + step_offset == final_global_train_step

      self._finalize(num_steps=step)

    except KeyboardInterrupt:
      print("KeyboardInterrupt in step %r." % step)

    except BaseException as exc:
      print("Exception %r in step %r." % (exc, step), file=log.v1)
      sys.excepthook(*sys.exc_info())
      self.device_crash_batch = step

    finally:
      from Util import try_and_ignore_exception
      from TFUtil import stop_event_writer_thread
      if writer:
        try_and_ignore_exception(writer.close)
        try_and_ignore_exception(lambda: stop_event_writer_thread(writer.event_writer))
      try_and_ignore_exception(coord.request_stop)
      try_and_ignore_exception(lambda: coord.join(threads))
      try_and_ignore_exception(self.data_provider.stop_threads)
      self.elapsed = time.time() - self.start_time