Example #1
  def __init__(self, engine, dataset, batches, train, eval=True):
    """
    :param Engine engine:
    :param Dataset.Dataset dataset:
    :param BatchSetGenerator batches:
    :param bool train: whether to do updates on the model
    :param bool eval: whether to evaluate (i.e. calculate loss/error)
    """
    self.engine = engine
    self.data_provider = DataProvider(
      tf_session=engine.tf_session, extern_data=engine.network.extern_data,
      data_keys=engine.network.used_data_keys,
      dataset=dataset, batches=batches)
    self._should_train = train
    self._should_eval = eval
    self.store_metadata_mod_step = False  # 500
    self.finalized = False
    self.num_steps = None
    self.device_crash_batch = None
    self.start_time = None
    self.elapsed = None
    self._results_accumulated = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.results = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.score = {}  # type: dict[str,float]  # entries like "cost:output"
    self.error = {}  # type: dict[str,float]  # entries like "error:output"
    self.stats = {}  # type: dict[str,float]  # entries like "stats:..."

    from Util import terminal_size
    terminal_width, _ = terminal_size()
    self._show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)
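
A minimal usage sketch of this constructor, assuming the enclosing class is the engine's Runner, that `engine`, `dataset` and `batches` have already been set up by the surrounding engine code, and that the class exposes a `run()` entry point not shown in this excerpt:

  # Hypothetical driver code; `engine`, `dataset` and `batches` come from the
  # surrounding engine setup and are not defined in this excerpt.
  runner = Runner(engine=engine, dataset=dataset, batches=batches, train=False, eval=True)
  runner.run()  # assumed entry point; fills runner.score / runner.error / runner.stats
  assert runner.finalized
  print("scores:", runner.score)  # e.g. {"cost:output": ...}
  print("errors:", runner.error)  # e.g. {"error:output": ...}
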
Example #2
  def __init__(self, engine, dataset, batches, train, eval=True, extra_fetches=None, extra_fetches_callback=None):
    """
    :param Engine engine:
    :param Dataset.Dataset dataset:
    :param BatchSetGenerator batches:
    :param bool train: whether to do updates on the model
    :param bool eval: whether to evaluate (i.e. calculate loss/error)
    :param dict[str,tf.Tensor|TFUtil.Data|TFNetworkLayer.LayerBase]|None extra_fetches: additional fetches per step.
      `extra_fetches_callback` will be called with these. For Data/LayerBase entries, the fetched value is a list,
      where each item corresponds to one sequence in the batch.
      It might also be useful to add `network.get_extern_data("seq_idx")` and `network.get_extern_data("seq_tag")`.
    :param (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str]])->None extra_fetches_callback:
      called with the fetched `extra_fetches` values after each step, if `extra_fetches` is set
    """
    from TFDataPipeline import FeedDictDataProvider, DataProviderBase
    self.engine = engine
    self.data_provider = FeedDictDataProvider(
      tf_session=engine.tf_session, extern_data=engine.network.extern_data,
      data_keys=engine.network.used_data_keys,
      dataset=dataset, batches=batches)
    assert isinstance(self.data_provider, DataProviderBase)
    self._should_train = train
    self._should_eval = eval
    self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0)
    self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0)
    self.finalized = False
    self.num_steps = None
    self.device_crash_batch = None  # type: int|None
    self.start_time = None
    self.elapsed = None
    self._results_accumulated = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.results = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
    self.score = {}  # type: dict[str,float]  # entries like "cost:output"
    self.error = {}  # type: dict[str,float]  # entries like "error:output"
    self.stats = {}  # type: dict[str,float]  # entries like "stats:..."
    self.extra_fetches = extra_fetches
    if extra_fetches is not None:
      assert extra_fetches_callback
    self.extra_fetches_callback = extra_fetches_callback

    from Util import terminal_size
    terminal_width, _ = terminal_size()
    self._show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)
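
The docstring above describes the extra_fetches mechanism; here is a hedged sketch of how it might be wired up, again assuming an existing `engine`, `dataset` and `batches`, and that the enclosing class is constructed as `Runner`. Per the docstring, the callback receives the fetched values as keyword arguments, and Data/LayerBase fetches arrive as lists with one item per sequence in the batch:

  # Hypothetical callback; the argument names must match the keys of extra_fetches.
  def dump_seq_tags(seq_idx, seq_tag, **_other):
    for idx, tag in zip(seq_idx, seq_tag):
      print("seq %r has tag %r" % (idx, tag))

  runner = Runner(
    engine=engine, dataset=dataset, batches=batches, train=False, eval=True,
    extra_fetches={
      "seq_idx": engine.network.get_extern_data("seq_idx"),
      "seq_tag": engine.network.get_extern_data("seq_tag")},
    extra_fetches_callback=dump_seq_tags)
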
Example #3
    def run_inner(self):
        self.start_time = time.time()
        for device in self.devices:
            device.prepare(epoch=self.epoch, **self.get_device_prepare_args())
        self.initialize()
        terminal_width, _ = terminal_size()
        self.interactive = (log.v[3] and terminal_width >= 0)
        print("starting task", self.task, file=log.v5)

        for device in self.devices:
            device.eval_batch_idx = -1
            device.start_epoch_stats()
            device.num_frames = 0
            device.num_updates = 0
            device.tot = 0

        num_device_runs = 1 if self.share_batches else len(self.devices)
        deviceRuns = [
            self.DeviceBatchRun(
                self,
                [self.devices[i]] if not self.share_batches else self.devices)
            for i in range(num_device_runs)
        ]

        results = {'batchess': [], 'results': [], 'num_frames': NumbersDict(0)}
        run_frames = NumbersDict(0)
        cost_result_format = -1

        crashed = False
        assert num_device_runs > 0

        while True:
            if getattr(sys, "exited", False):
                # This happens when we exit Python.
                # Without this check, this thread would keep running until all exit handlers of Python are done.
                print("%s stopped" % self, file=log.v5)
                crashed = True
                break

            for i in range(num_device_runs):
                if deviceRuns[i].crashed or not deviceRuns[i].is_alive():
                    crashed = True
                    break
                if deviceRuns[i].finished:
                    results['batchess'] += deviceRuns[i].result['batchess'][:]
                    results['results'] += deviceRuns[i].result['results'][:]
                    results['result_format'] = deviceRuns[i].result[
                        'result_format']
                    deviceRuns[i].finished = False
            if crashed:
                break

            if cost_result_format < 0 and deviceRuns[i].result['result_format']:
                for idx, fmt in enumerate(
                        deviceRuns[i].result['result_format']):
                    if fmt and fmt.startswith('cost:'):
                        cost_result_format = idx
            total_cost = 0
            if results['results'] and cost_result_format >= 0:
                total_cost = numpy.asarray(
                    results['results'])[:, cost_result_format].sum()
            if total_cost >= self.eval_batch_size or not self.batches.has_more():
                if all(not (dev.finished or dev.allocated or dev.processing)
                       for dev in deviceRuns):
                    results['num_frames'] = run_frames
                    self.num_frames += run_frames
                    if self.share_batches: run_frames *= len(self.devices)
                    self.reduce(run_frames)
                    self.eval_batch_idx += 1
                    run_frames = NumbersDict(0)
                    results['batchess'] = []
                    results['results'] = []
                    for device in self.devices:
                        device.num_frames = 0
                        device.num_updates = 0
                    if not self.batches.has_more():
                        break
                else:
                    time.sleep(0.01)

            match = True
            while self.batches.has_more() and total_cost < self.eval_batch_size and match:
                self.batch_idx = self.batches.get_current_batch_idx()
                if self.batch_idx < self.start_batch:
                    self.batches.advance(1)
                    break
                match = False
                for i in range(num_device_runs):
                    if not deviceRuns[i].allocated:
                        deviceRuns[i].allocate()
                        run_frames += deviceRuns[i].run_frames
                        match = True
                        break
            if not match:
                time.sleep(0.01)

        for run in deviceRuns:
            run.stop()
        if crashed: return
        for device in self.devices:
            device.finish_epoch_stats()
        self.finalize()
        if self.interactive: progress_bar()
        self.elapsed = (time.time() - self.start_time)
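
The loop above keeps allocating batches until the accumulated cost reaches `self.eval_batch_size`: it locates the cost column once via `result_format` and then sums that column over all collected result rows with numpy. A small self-contained sketch of that column-sum step, with made-up numbers (only numpy is required):

    import numpy

    # Mirrors the cost accumulation in the loop above; the data here is invented.
    result_format = ["cost:output", "error:output", "num_frames"]
    results = [
        [1.5, 0.2, 100.0],  # one row per finished device run
        [2.0, 0.3, 120.0],
    ]

    # Locate the cost column once, as done when cost_result_format < 0.
    cost_result_format = -1
    for idx, fmt in enumerate(result_format):
        if fmt and fmt.startswith('cost:'):
            cost_result_format = idx

    total_cost = 0
    if results and cost_result_format >= 0:
        total_cost = numpy.asarray(results)[:, cost_result_format].sum()
    print(total_cost)  # 3.5
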
Example #4
    def run_inner(self):
      self.start_time = time.time()
      for device in self.devices:
        device.prepare(epoch=self.epoch, **self.get_device_prepare_args())
      self.initialize()
      terminal_width, _ = terminal_size()
      self.interactive = (log.v[3] and terminal_width >= 0)
      print("starting task", self.task, file=log.v5)

      for device in self.devices:
        device.eval_batch_idx = -1
        device.start_epoch_stats()
        device.num_frames = 0
        device.num_updates = 0
        device.tot = 0

      num_device_runs = 1 if self.share_batches else len(self.devices)
      deviceRuns = [
        self.DeviceBatchRun(self, [self.devices[i]] if not self.share_batches else self.devices)
        for i in range(num_device_runs)]

      results = {'batchess': [], 'results': [], 'num_frames': NumbersDict(0)}
      run_frames = NumbersDict(0)

      crashed = False

      while True:
        if getattr(sys, "exited", False):
          # This happens when we exit Python.
          # Without this check, this thread would keep running until all exit handlers of Python are done.
          print("%s stopped" % self, file=log.v5)
          crashed = True
          break

        for i in range(num_device_runs):
          if deviceRuns[i].crashed:
            crashed = True
            break
          if deviceRuns[i].finished:
            results['batchess'] += deviceRuns[i].result['batchess'][:]
            results['results'] += deviceRuns[i].result['results'][:]
            results['result_format'] = deviceRuns[i].result['result_format']
            deviceRuns[i].finished = False
        if crashed:
          break

        if run_frames.max_value() >= self.eval_batch_size or not self.batches.has_more():
          if all(not (dev.finished or dev.allocated or dev.processing) for dev in deviceRuns):
            results['num_frames'] = run_frames
            self.num_frames += run_frames
            if self.share_batches: run_frames *= len(self.devices)
            self.reduce(run_frames)
            self.eval_batch_idx += 1
            run_frames = NumbersDict(0)
            results['batchess'] = []
            results['results'] = []
            for device in self.devices:
              device.num_frames = 0
              device.num_updates = 0
            if not self.batches.has_more():
              break
          else:
            time.sleep(0.01)

        match = True
        while self.batches.has_more() and run_frames.max_value() < self.eval_batch_size and match:
          self.batch_idx = self.batches.get_current_batch_idx()
          if self.batch_idx < self.start_batch:
            self.batches.advance(1)
            break
          match = False
          for i in range(num_device_runs):
            if not deviceRuns[i].allocated:
              deviceRuns[i].allocate()
              run_frames += deviceRuns[i].run_frames
              match = True
              break
        if not match:
          time.sleep(0.01)

      for run in deviceRuns:
        run.stop()
      if crashed: return
      for device in self.devices:
        device.finish_epoch_stats()
      self.finalize()
      if self.interactive: progress_bar()
      self.elapsed = (time.time() - self.start_time)
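
This older variant gates the reduce step on accumulated frame counts instead of cost: `run_frames` is a NumbersDict (from RETURNN's Util module) that sums per-key frame counts over the allocated device runs, and `run_frames.max_value() >= self.eval_batch_size` triggers the reduce. A self-contained sketch of that accumulation logic, using a plain dict as a stand-in for NumbersDict and made-up frame counts:

  # Stand-in for the NumbersDict accumulation above; a plain dict is used so the
  # sketch runs without RETURNN's Util module. All numbers are invented.
  def add_frames(acc, new):
    for key, value in new.items():
      acc[key] = acc.get(key, 0) + value

  run_frames = {}  # like NumbersDict(0)
  eval_batch_size = 500
  batch_frame_counts = [{"data": 200, "classes": 200}, {"data": 350, "classes": 350}]

  for frames in batch_frame_counts:
    add_frames(run_frames, frames)  # like `run_frames += deviceRuns[i].run_frames`
    if max(run_frames.values()) >= eval_batch_size:  # like run_frames.max_value() >= eval_batch_size
      print("reduce over", run_frames)  # here: {'data': 550, 'classes': 550}
      run_frames = {}  # reset, like NumbersDict(0)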