def _document_fe_graph(self) -> None:
    """Add FE execution graphs into the traceability document."""
    with self.doc.create(Section("FastEstimator Architecture")):
        for mode in self.system.pipeline.data.keys():
            scheduled_items = self.system.pipeline.get_scheduled_items(
                mode) + self.system.network.get_scheduled_items(mode) + self.system.traces
            signature_epochs = get_signature_epochs(scheduled_items, total_epochs=self.system.epoch_idx, mode=mode)
            epochs_with_data = self.system.pipeline.get_epochs_with_data(total_epochs=self.system.epoch_idx,
                                                                         mode=mode)
            # Only document this mode if at least one signature epoch actually has data
            if set(signature_epochs) & epochs_with_data:
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(mode.capitalize())):
                    for epoch in signature_epochs:
                        if epoch not in epochs_with_data:
                            continue
                        self.doc.append(NoEscape(r'\FloatBarrier'))
                        with self.doc.create(
                                Subsubsection(f"Epoch {epoch}",
                                              label=Label(Marker(name=f"{mode}{epoch}", prefix="ssubsec")))):
                            # Convert the pydot execution graph into a TikZ figure
                            diagram = self._draw_diagram(mode, epoch)
                            ltx = d2t.dot2tex(diagram.to_string(), figonly=True)
                            # Scale the figure so that it never overflows the page
                            args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
                            args.escape = False
                            with self.doc.create(Center()):
                                with self.doc.create(AdjustBox(arguments=args)) as box:
                                    box.append(NoEscape(ltx))
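# A minimal, self-contained sketch of the dot -> TikZ -> LaTeX embedding pattern
# used by _document_fe_graph() above, assuming only the public pydot, dot2tex,
# and pylatex APIs. `_DemoAdjustBox` and `_demo_embed_graph` are hypothetical
# names invented for this sketch; the real code uses FastEstimator's own
# AdjustBox environment wrapper rather than the stand-in defined here.
def _demo_embed_graph():
    import dot2tex as d2t
    import pydot
    from pylatex import Center, Document, NoEscape, Package
    from pylatex.base_classes import Arguments, Environment

    class _DemoAdjustBox(Environment):
        """Hypothetical stand-in for FastEstimator's AdjustBox environment."""
        _latex_name = 'adjustbox'
        packages = [Package('adjustbox')]

    doc = Document()
    graph = pydot.Dot(graph_type='digraph')  # a trivial two-node execution graph
    graph.add_edge(pydot.Edge('Pipeline', 'Network'))
    ltx = d2t.dot2tex(graph.to_string(), figonly=True)  # TikZ body only, no preamble
    args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
    args.escape = False  # keep the raw LaTeX length macros intact
    with doc.create(Center()):
        with doc.create(_DemoAdjustBox(arguments=args)) as box:
            box.append(NoEscape(ltx))
    return doc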
def _warmup(self, eager: bool = True) -> None:
    """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

    Traces are not executed in the warmup since they are likely to contain state variables which could become
    corrupted by running extra steps.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
            PyTorch by nature is always in eager mode.
    """
    all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
    sort_traces(all_traces)  # This ensures that the traces can sort properly for on_begin and on_end
    monitor_names = self.monitor_names
    for mode in self.pipeline.get_modes() - {"test"}:
        scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
            mode) + self.get_scheduled_items(mode)
        signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
        epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
        for epoch in signature_epochs:
            if epoch not in epochs_with_data:
                continue
            # Collect what the network consumes and produces during this epoch
            network_output_keys = self.network.get_all_output_keys(mode, epoch)
            network_input_keys = self.network.get_effective_input_keys(mode, epoch)
            trace_input_keys = set()
            trace_output_keys = {"*"}
            traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
            for idx, trace in enumerate(traces):
                if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                    trace_input_keys.update(trace.inputs)
                trace_output_keys.update(trace.outputs)
            # Key checking: the loader must provide the trace inputs not covered by the network,
            # plus the network's own inputs
            loader = self._configure_loader(
                self.pipeline.get_loader(mode,
                                         epoch,
                                         output_keys=trace_input_keys - network_output_keys | network_input_keys))
            # Draw a single batch to verify the pipeline output
            with Suppressor():
                if isinstance(loader, tf.data.Dataset):
                    batch = list(loader.take(1))[0]
                else:
                    batch = next(iter(loader))
            batch = self._configure_tensor(loader, batch)
            assert isinstance(batch, dict), "please make sure the data output format is a dictionary"
            pipeline_output_keys = to_set(batch.keys())
            monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
            # Every trace input must be produced by the pipeline, the network, or another trace
            unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
            assert not unmet_requirements, \
                "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
            # Verify that the traces can be topologically sorted given the available keys
            sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
            # Re-include the essential trace's inputs now that requirement checking is done
            trace_input_keys.update(traces[0].inputs)
            self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
            self.network.run_step(batch)
            self.network.unload_epoch()
    assert not monitor_names, "found missing key(s): {}".format(monitor_names)
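# A worked illustration of the set algebra used for key checking in _warmup()
# above. All key names here are hypothetical, and `_demo_key_check` is not part
# of FastEstimator; it only restates the two expressions with concrete values.
def _demo_key_check() -> None:
    network_input_keys = {"x"}
    network_output_keys = {"y_pred", "ce"}
    trace_input_keys = {"y_pred", "y", "lr"}
    trace_output_keys = {"*", "accuracy", "lr"}
    # The loader must supply whatever the traces need that the network does not
    # produce, plus the network's own inputs ('-' binds tighter than '|'):
    loader_keys = trace_input_keys - network_output_keys | network_input_keys
    assert loader_keys == {"x", "y", "lr"}
    # If the pipeline batch then provides these keys...
    pipeline_output_keys = {"x", "y"}
    # ...every trace input is covered by the pipeline ("y"), the network
    # ("y_pred"), or another trace ("lr"), so there are no unmet requirements:
    unmet = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
    assert not unmet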