def _warmup(self, eager: bool = True) -> None:
    """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

    Traces are not executed in the warmup since they are likely to contain state variables which could become
    corrupted by running extra steps.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.

    Raises:
        AssertionError: If a pipeline batch is not a dictionary, if any trace input key cannot be satisfied by the
            pipeline/network/trace outputs, or if any monitor name is never produced by any mode/epoch.
    """
    all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
    sort_traces(all_traces)  # This ensures that the traces can sort properly for on_begin and on_end
    # Work on a copy of the monitor names; keys get removed as soon as some mode/epoch is found to produce them.
    monitor_names = self.monitor_names
    for mode in self.pipeline.get_modes() - {"test"}:
        # Gather every schedulable item so we can find the epochs where any schedule changes ("signature" epochs).
        scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
            mode) + self.get_scheduled_items(mode)
        signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
        epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
        for epoch in signature_epochs:
            if epoch not in epochs_with_data:
                # No dataset is scheduled for this epoch, so there is nothing to warm up.
                continue
            network_output_keys = self.network.get_all_output_keys(mode, epoch)
            network_input_keys = self.network.get_effective_input_keys(mode, epoch)
            trace_input_keys = set()
            trace_output_keys = {"*"}  # "*" acts as a wildcard so it is never reported as unmet
            traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
            for idx, trace in enumerate(traces):
                if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                    trace_input_keys.update(trace.inputs)
                trace_output_keys.update(trace.outputs)
            # key checking: fetch a single batch and verify every required key can be produced by someone.
            loader = self._configure_loader(
                self.pipeline.get_loader(mode, epoch,
                                         output_keys=trace_input_keys - network_output_keys | network_input_keys))
            with Suppressor():
                if isinstance(loader, tf.data.Dataset):
                    batch = list(loader.take(1))[0]
                else:
                    batch = next(iter(loader))
            batch = self._configure_tensor(loader, batch)
            assert isinstance(batch, dict), "please make sure data output format is dictionary"
            pipeline_output_keys = to_set(batch.keys())
            # A monitor name is satisfied once any mode/epoch can supply it.
            monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
            unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
            assert not unmet_requirements, \
                "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
            # Sorting may raise if trace dependencies are cyclic/unsatisfiable, hence it is part of the warmup.
            sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
            # Re-include the first (essential) trace's inputs so the network loads everything it will need later.
            trace_input_keys.update(traces[0].inputs)
            self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
            self.network.run_step(batch)
            self.network.unload_epoch()
    assert not monitor_names, "found missing key(s): {}".format(monitor_names)
def _start(self, run_modes: Set[str]) -> None:
    """The outer training loop.

    Invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes the trace
    on_end method. The on_end traces run even when training is cut short by an `EarlyStop`.

    Args:
        run_modes: The current execution modes.
    """
    sorted_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
    try:
        self._run_traces_on_begin(traces=sorted_traces)
        if run_modes.isdisjoint({"train", "eval"}):
            # Non-training invocation (ex. 'test'): run a single epoch as-is.
            self._run_epoch()
        else:
            # Resume one past whatever epoch the system has already completed.
            first_epoch = self.system.epoch_idx + 1
            for self.system.epoch_idx in range(first_epoch, self.system.total_epochs + 1):
                for candidate_mode in ("train", "eval"):
                    if candidate_mode in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = candidate_mode
                        self._run_epoch()
    except EarlyStop:
        pass  # On early stopping we still want to run the final traces and return results
    self._run_traces_on_end(traces=sorted_traces)
def _prepare_traces(self, run_modes: Set[str]) -> None:
    """Prepare information about the traces for training.

    Add default traces into the traces_in_use list, also prints a warning if no model saver trace is detected.

    Args:
        run_modes: The current execution modes.
    """
    self.traces_in_use = list(self.traces)
    if self.system.log_steps is not None:
        self.traces_in_use.append(Logger())
    # Look for any monitor names which should be automagically added: keys a trace wants to monitor
    # that are not produced by any trace sorted before it.
    produced_so_far = set()
    extra_monitor_keys = set()
    for trace in sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes)):
        produced_so_far.update(trace.outputs)
        extra_monitor_keys.update(trace.fe_monitor_names - produced_so_far)
    # Add the essential traces
    if "train" in run_modes:
        self.traces_in_use.insert(0, TrainEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
        save_trace_present = any(
            isinstance(trace, (ModelSaver, BestModelSaver))
            for trace in get_current_items(self.traces_in_use, run_modes=run_modes))
        if not save_trace_present:
            print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")
    if "eval" in run_modes and "eval" in self.pipeline.get_modes():
        self.traces_in_use.insert(1, EvalEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
    # insert system instance to trace
    for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
        trace.system = self.system
def _start(self, run_modes: Set[str], eager: bool) -> None:
    """The outer training loop.

    This method invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes
    the trace on_end method.

    Args:
        run_modes: The current execution modes.
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.
    """
    all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
    try:
        self._run_traces_on_begin(traces=all_traces)
        if "train" in run_modes or "eval" in run_modes:
            # If the training is re-starting from a restore wizard, it should re-run the last eval epoch
            if self.system.epoch_idx > 0 and "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                self.system.mode = "eval"
                self._run_epoch(eager=eager)
            # Note: the loop variable deliberately mutates self.system.epoch_idx in place so that traces
            # and _run_epoch always see the current epoch through the shared system object.
            for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                if "train" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                    self.system.mode = "train"
                    self._run_epoch(eager=eager)
                if "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                    self.system.mode = "eval"
                    self._run_epoch(eager=eager)
        else:
            # Non train/eval invocation (ex. 'test'): a single epoch is run with whatever mode is already set.
            self._run_epoch(eager=eager)
    except EarlyStop:
        pass  # On early stopping we still want to run the final traces and return results
    self._run_traces_on_end(traces=all_traces)
def _run_epoch(self, eager: bool) -> None:
    """A method to perform an epoch of activity.

    This method requires that the current mode and epoch already be specified within the self.system object.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.
    """
    traces = get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx)
    trace_input_keys = set()
    for trace in traces:
        trace_input_keys.update(trace.inputs)
    network_input_keys = self.network.get_effective_input_keys(self.system.mode, self.system.epoch_idx)
    network_output_keys = self.network.get_all_output_keys(self.system.mode, self.system.epoch_idx)
    # The pipeline only needs to produce keys the network consumes plus trace inputs it can't get elsewhere.
    loader = self._configure_loader(
        self.pipeline.get_loader(self.system.mode,
                                 self.system.epoch_idx,
                                 output_keys=trace_input_keys - network_output_keys | network_input_keys))
    iterator = iter(loader)
    self.network.load_epoch(mode=self.system.mode,
                            epoch=self.system.epoch_idx,
                            output_keys=trace_input_keys,
                            eager=eager)
    self.system.batch_idx = None
    with Suppressor():
        # Fetch the first batch up-front so trace sorting can see the actual available keys.
        batch = next(iterator)
    traces = sort_traces(traces, available_outputs=to_set(batch.keys()) | network_output_keys)
    self._run_traces_on_epoch_begin(traces=traces)
    # Batch loop: terminated by StopIteration, either from the data iterator running dry or raised
    # manually when the configured max steps per epoch is reached.
    while True:
        try:
            if self.system.mode == "train":
                self.system.update_global_step()
            self.system.update_batch_idx()
            batch = self._configure_tensor(loader, batch)
            self._run_traces_on_batch_begin(batch, traces=traces)
            batch, prediction = self.network.run_step(batch)
            self._run_traces_on_batch_end(batch, prediction, traces=traces)
            # The step-cap check only applies to torch DataLoaders (tf datasets are capped elsewhere -
            # presumably in _configure_loader; verify against that method).
            if isinstance(loader, DataLoader) and (
                    (self.system.batch_idx == self.system.max_train_steps_per_epoch
                     and self.system.mode == "train") or
                    (self.system.batch_idx == self.system.max_eval_steps_per_epoch
                     and self.system.mode == "eval")):
                raise StopIteration
            with Suppressor():
                batch = next(iterator)
        except StopIteration:
            break
    self._run_traces_on_epoch_end(traces=traces)
    self.network.unload_epoch()
def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
    """Draw a summary diagram of the FastEstimator Ops / Traces.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    ds = self.system.pipeline.data[mode]
    if isinstance(ds, Scheduler):
        ds = ds.get_current_value(epoch)
    # Pipeline ops only apply when the data source is a real Dataset (as opposed to ex. a raw loader).
    pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode,
                                 epoch=epoch) if isinstance(ds, Dataset) else []
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
    traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch))
    diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
    diagram.set('rankdir', 'TB')
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='box')
    # Make the dataset the first of the pipeline ops
    pipe_ops.insert(0, ds)
    label_last_seen = defaultdict(lambda: str(id(ds)))  # Where was this key last generated
    batch_size = ""
    if isinstance(ds, Dataset) and not isinstance(ds, BatchDataset):
        batch_size = self.system.pipeline.batch_size
        if isinstance(batch_size, Scheduler):
            batch_size = batch_size.get_current_value(epoch)
        # Bug fix: a None batch size previously fell through unchanged, so the f-string below rendered the
        # cluster label as 'PipelineNone'. Treat None as "unknown" and emit no suffix instead.
        batch_size = f" (Batch Size: {batch_size})" if batch_size is not None else ""
    self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_size}', pipe_ops)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Network', net_ops)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces)
    return diagram
def _run_epoch(self, eager: bool) -> None:
    """A method to perform an epoch of activity.

    This method requires that the current mode and epoch already be specified within the self.system object.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.
    """
    ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx, self.system.mode)
    epoch_traces = sort_traces(
        get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx),
        ds_ids=ds_ids)
    self._run_traces_on_epoch_begin(traces=epoch_traces)
    self.system.batch_idx = None
    end_epoch_data = Data()  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
    # run for each dataset (the loop variable deliberately mutates self.system.ds_id in place so that
    # traces and the network always see the current dataset through the shared system object)
    for self.system.ds_id in ds_ids:
        ds_traces = get_current_items(self.traces_in_use,
                                      run_modes=self.system.mode,
                                      epoch=self.system.epoch_idx,
                                      ds_id=self.system.ds_id)
        trace_input_keys = set()
        for ds_trace in ds_traces:
            trace_input_keys.update(ds_trace.inputs)
        network_input_keys = self.network.get_effective_input_keys(mode=self.system.mode,
                                                                   epoch=self.system.epoch_idx,
                                                                   ds_id=self.system.ds_id)
        network_output_keys = self.network.get_all_output_keys(mode=self.system.mode,
                                                               epoch=self.system.epoch_idx,
                                                               ds_id=self.system.ds_id)
        self.network.load_epoch(mode=self.system.mode,
                                epoch=self.system.epoch_idx,
                                ds_id=self.system.ds_id,
                                output_keys=trace_input_keys,
                                eager=eager)
        # The pipeline context manages loader setup/teardown; it only needs to produce keys the network
        # consumes plus trace inputs that the network won't generate itself.
        with self.pipeline(mode=self.system.mode,
                           epoch=self.system.epoch_idx,
                           ds_id=self.system.ds_id,
                           steps_per_epoch=self.system.steps_per_epoch,
                           output_keys=trace_input_keys - network_output_keys | network_input_keys) as loader:
            loader = self._configure_loader(loader)
            iterator = iter(loader)
            with Suppressor():
                # Fetch the first batch up-front so trace sorting can see the actual available keys.
                batch = next(iterator)
            ds_traces = sort_traces(ds_traces,
                                    available_outputs=to_set(batch.keys()) | network_output_keys,
                                    ds_ids=ds_ids)
            # Only PerDSTrace instances get the per-dataset begin/end callbacks.
            per_ds_traces = [trace for trace in ds_traces if isinstance(trace, PerDSTrace)]
            self._run_traces_on_ds_begin(traces=per_ds_traces)
            # Batch loop: terminated by StopIteration, either from the iterator running dry or raised
            # manually when the configured steps-per-epoch cap is reached.
            while True:
                try:
                    if self.system.mode == "train":
                        self.system.update_global_step()
                    self.system.update_batch_idx()
                    batch = self._configure_tensor(loader, batch)
                    self._run_traces_on_batch_begin(batch, traces=ds_traces)
                    batch, prediction = self.network.run_step(batch)
                    self._run_traces_on_batch_end(batch, prediction, traces=ds_traces)
                    # The step-cap check only applies to torch DataLoaders (tf datasets are presumably
                    # capped elsewhere - verify against _configure_loader).
                    if isinstance(loader, DataLoader) and (
                            (self.system.batch_idx == self.system.train_steps_per_epoch
                             and self.system.mode == "train") or
                            (self.system.batch_idx == self.system.eval_steps_per_epoch
                             and self.system.mode == "eval")):
                        raise StopIteration
                    with Suppressor():
                        batch = next(iterator)
                except StopIteration:
                    break
            self._run_traces_on_ds_end(traces=per_ds_traces, data=end_epoch_data)
        self.network.unload_epoch()
    self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)