def __init__(self, engine, dataset, batches, train, eval=True):
  """
  :param Engine engine:
  :param Dataset.Dataset dataset:
  :param BatchSetGenerator batches:
  :param bool train: whether to do updates on the model
  :param bool eval: whether to evaluate (i.e. calculate loss/error)
  """
  self.engine = engine
  self.data_provider = DataProvider(
    tf_session=engine.tf_session, extern_data=engine.network.extern_data,
    data_keys=engine.network.used_data_keys,
    dataset=dataset, batches=batches)
  self._should_train = train
  self._should_eval = eval
  self.store_metadata_mod_step = False  # 500
  self.finalized = False
  self.num_steps = None
  self.device_crash_batch = None
  self.start_time = None
  self.elapsed = None
  self._results_accumulated = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
  self.results = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
  self.score = {}  # type: dict[str,float]  # entries like "cost:output"
  self.error = {}  # type: dict[str,float]  # entries like "error:output"
  self.stats = {}  # type: dict[str,float]  # entries like "stats:..."
  from Util import terminal_size
  terminal_width, _ = terminal_size()
  self._show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)
def __init__(self, engine, dataset, batches, train, eval=True, extra_fetches=None, extra_fetches_callback=None):
  """
  :param Engine engine:
  :param Dataset.Dataset dataset:
  :param BatchSetGenerator batches:
  :param bool train: whether to do updates on the model
  :param bool eval: whether to evaluate (i.e. calculate loss/error)
  :param dict[str,tf.Tensor|TFUtil.Data|TFNetworkLayer.LayerBase]|None extra_fetches: additional fetches per step.
    `extra_fetches_callback` will be called with these. In case of Data/LayerBase, it will return a list,
    where each item corresponds to the batch-seq.
    It might also be useful to add `network.get_extern_data("seq_idx")` and `network.get_extern_data("seq_tag")`.
  :param (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str]])->None extra_fetches_callback:
    called if extra_fetches is set
  """
  from TFDataPipeline import FeedDictDataProvider, DataProviderBase
  self.engine = engine
  self.data_provider = FeedDictDataProvider(
    tf_session=engine.tf_session, extern_data=engine.network.extern_data,
    data_keys=engine.network.used_data_keys,
    dataset=dataset, batches=batches)
  assert isinstance(self.data_provider, DataProviderBase)
  self._should_train = train
  self._should_eval = eval
  self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0)
  self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0)
  self.finalized = False
  self.num_steps = None
  self.device_crash_batch = None  # type: int|None
  self.start_time = None
  self.elapsed = None
  self._results_accumulated = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
  self.results = {}  # type: dict[str,float]  # entries like "cost:output" or "loss"
  self.score = {}  # type: dict[str,float]  # entries like "cost:output"
  self.error = {}  # type: dict[str,float]  # entries like "error:output"
  self.stats = {}  # type: dict[str,float]  # entries like "stats:..."
  self.extra_fetches = extra_fetches
  if extra_fetches is not None:
    assert extra_fetches_callback
  self.extra_fetches_callback = extra_fetches_callback
  from Util import terminal_size
  terminal_width, _ = terminal_size()
  self._show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)
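# A minimal usage sketch (not part of the original source) for the extra_fetches /
# extra_fetches_callback hooks documented in the constructor above. The enclosing class is
# assumed here to be the Runner; the callback receives the fetched values as keyword arguments,
# one per key in extra_fetches, and for Data/LayerBase entries each value is a list with one
# item per sequence in the batch. All concrete names at the call site are hypothetical.

def _dump_seq_tags(seq_idx, seq_tag, **_unused_other):
  # seq_idx / seq_tag are lists with one entry per sequence in the current batch
  for idx, tag in zip(seq_idx, seq_tag):
    print("seq %i: %r" % (idx, tag))

# Hypothetical call site; assumes engine, dataset and batches are set up as for the constructor
# above, and that the network declares "seq_idx" and "seq_tag" as extern data.
runner = Runner(
  engine=engine, dataset=dataset, batches=batches, train=False, eval=False,
  extra_fetches={
    "seq_idx": engine.network.get_extern_data("seq_idx"),
    "seq_tag": engine.network.get_extern_data("seq_tag")},
  extra_fetches_callback=_dump_seq_tags)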
def run_inner(self):
  self.start_time = time.time()
  for device in self.devices:
    device.prepare(epoch=self.epoch, **self.get_device_prepare_args())
  self.initialize()
  terminal_width, _ = terminal_size()
  self.interactive = (log.v[3] and terminal_width >= 0)
  print("starting task", self.task, file=log.v5)

  for device in self.devices:
    device.eval_batch_idx = -1
    device.start_epoch_stats()
    device.num_frames = 0
    device.num_updates = 0
    device.tot = 0
  num_device_runs = 1 if self.share_batches else len(self.devices)
  deviceRuns = [
    self.DeviceBatchRun(self, [self.devices[i]] if not self.share_batches else self.devices)
    for i in range(num_device_runs)]

  results = {'batchess': [], 'results': [], 'num_frames': NumbersDict(0)}
  run_frames = NumbersDict(0)
  cost_result_format = -1

  crashed = False
  assert num_device_runs > 0

  while True:
    if getattr(sys, "exited", False):
      # This happens when we exit Python.
      # Without this check, this thread would keep running until all exit handlers of Python are done.
      print("%s stopped" % self, file=log.v5)
      crashed = True
      break

    for i in range(num_device_runs):
      if deviceRuns[i].crashed or not deviceRuns[i].is_alive():
        crashed = True
        break
      if deviceRuns[i].finished:
        results['batchess'] += deviceRuns[i].result['batchess'][:]
        results['results'] += deviceRuns[i].result['results'][:]
        results['result_format'] = deviceRuns[i].result['result_format']
        deviceRuns[i].finished = False
    if crashed:
      break

    if cost_result_format < 0 and deviceRuns[i].result['result_format']:
      for idx, fmt in enumerate(deviceRuns[i].result['result_format']):
        if fmt and fmt.startswith('cost:'):
          cost_result_format = idx
    total_cost = 0
    if results['results'] and cost_result_format >= 0:
      total_cost = numpy.asarray(results['results'])[:, cost_result_format].sum()
    if total_cost >= self.eval_batch_size or not self.batches.has_more():
      if all(not (dev.finished or dev.allocated or dev.processing) for dev in deviceRuns):
        results['num_frames'] = run_frames
        self.num_frames += run_frames
        if self.share_batches:
          run_frames *= len(self.devices)
        self.reduce(run_frames)
        self.eval_batch_idx += 1
        run_frames = NumbersDict(0)
        results['batchess'] = []
        results['results'] = []
        for device in self.devices:
          device.num_frames = 0
          device.num_updates = 0
        if not self.batches.has_more():
          break
      else:
        time.sleep(0.01)

    match = True
    while self.batches.has_more() and total_cost < self.eval_batch_size and match:
      self.batch_idx = self.batches.get_current_batch_idx()
      if self.batch_idx < self.start_batch:
        self.batches.advance(1)
        break
      match = False
      for i in range(num_device_runs):
        if not deviceRuns[i].allocated:
          deviceRuns[i].allocate()
          run_frames += deviceRuns[i].run_frames
          match = True
          break
    if not match:
      time.sleep(0.01)

  for run in deviceRuns:
    run.stop()
  if crashed:
    return
  for device in self.devices:
    device.finish_epoch_stats()
  self.finalize()
  if self.interactive:
    progress_bar()
  self.elapsed = (time.time() - self.start_time)
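# Illustrative sketch (made-up numbers, not from the source) of the cost-column lookup that
# drives the stopping criterion in the run_inner variant above: result_format names the columns
# of each accumulated device result, and the running total is the sum over the "cost:*" column.

import numpy

result_format = ["cost:output", "error:output"]  # hypothetical format reported by a device run
results = [[1.2, 0.3], [0.8, 0.1]]  # one row per finished device run
cost_result_format = next(
  (idx for idx, fmt in enumerate(result_format) if fmt and fmt.startswith("cost:")), -1)
total_cost = numpy.asarray(results)[:, cost_result_format].sum() if cost_result_format >= 0 else 0
# total_cost (here 2.0) is then compared against eval_batch_size to decide when to close the eval batch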
def run_inner(self):
  self.start_time = time.time()
  for device in self.devices:
    device.prepare(epoch=self.epoch, **self.get_device_prepare_args())
  self.initialize()
  terminal_width, _ = terminal_size()
  self.interactive = (log.v[3] and terminal_width >= 0)
  print("starting task", self.task, file=log.v5)

  for device in self.devices:
    device.eval_batch_idx = -1
    device.start_epoch_stats()
    device.num_frames = 0
    device.num_updates = 0
    device.tot = 0
  num_device_runs = 1 if self.share_batches else len(self.devices)
  deviceRuns = [
    self.DeviceBatchRun(self, [self.devices[i]] if not self.share_batches else self.devices)
    for i in range(num_device_runs)]

  results = {'batchess': [], 'results': [], 'num_frames': NumbersDict(0)}
  run_frames = NumbersDict(0)

  crashed = False

  while True:
    if getattr(sys, "exited", False):
      # This happens when we exit Python.
      # Without this check, this thread would keep running until all exit handlers of Python are done.
      print("%s stopped" % self, file=log.v5)
      crashed = True
      break

    for i in range(num_device_runs):
      if deviceRuns[i].crashed:
        crashed = True
        break
      if deviceRuns[i].finished:
        results['batchess'] += deviceRuns[i].result['batchess'][:]
        results['results'] += deviceRuns[i].result['results'][:]
        results['result_format'] = deviceRuns[i].result['result_format']
        deviceRuns[i].finished = False
    if crashed:
      break

    if run_frames.max_value() >= self.eval_batch_size or not self.batches.has_more():
      if all(not (dev.finished or dev.allocated or dev.processing) for dev in deviceRuns):
        results['num_frames'] = run_frames
        self.num_frames += run_frames
        if self.share_batches:
          run_frames *= len(self.devices)
        self.reduce(run_frames)
        self.eval_batch_idx += 1
        run_frames = NumbersDict(0)
        results['batchess'] = []
        results['results'] = []
        for device in self.devices:
          device.num_frames = 0
          device.num_updates = 0
        if not self.batches.has_more():
          break
      else:
        time.sleep(0.01)

    match = True
    while self.batches.has_more() and run_frames.max_value() < self.eval_batch_size and match:
      self.batch_idx = self.batches.get_current_batch_idx()
      if self.batch_idx < self.start_batch:
        self.batches.advance(1)
        break
      match = False
      for i in range(num_device_runs):
        if not deviceRuns[i].allocated:
          deviceRuns[i].allocate()
          run_frames += deviceRuns[i].run_frames
          match = True
          break
    if not match:
      time.sleep(0.01)

  for run in deviceRuns:
    run.stop()
  if crashed:
    return
  for device in self.devices:
    device.finish_epoch_stats()
  self.finalize()
  if self.interactive:
    progress_bar()
  self.elapsed = (time.time() - self.start_time)
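# Simplified, self-contained sketch (hypothetical config value, made-up frame counts) of the
# frame-based boundary check used in the variant above: per-device frame counts accumulate in a
# NumbersDict (the same Util helper the code above relies on), and the eval batch is closed once
# max_value() reaches eval_batch_size.

from Util import NumbersDict

eval_batch_size = 5000  # hypothetical value
run_frames = NumbersDict(0)
for device_frames in (NumbersDict({"data": 2000, "classes": 60}),
                      NumbersDict({"data": 3500, "classes": 90})):
  run_frames += device_frames
  if run_frames.max_value() >= eval_batch_size:
    # boundary reached: hand the accumulated results to reduce(), then reset the counter
    run_frames = NumbersDict(0)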