Esempio n. 1
0
 def _document_fe_graph(self) -> None:
     """Add FE execution graphs into the traceability document.

     Creates a "FastEstimator Architecture" section. For each mode the pipeline holds data for, it computes the
     signature epochs of the pipeline ops, network ops and traces. Every signature epoch that also has data gets its
     own labelled subsubsection, which holds the graph returned by `_draw_diagram`. That graph is converted to LaTeX
     with dot2tex and wrapped in an adjustbox bounded by the text width and 0.9 of the text height. Modes without any
     such epoch produce no subsection at all.
     """
     pipeline = self.system.pipeline
     last_epoch = self.system.epoch_idx
     with self.doc.create(Section("FastEstimator Architecture")):
         for mode in pipeline.data.keys():
             # Plain `+` (never `+=`) so the lists returned by the getters are not mutated in place.
             scheduled = (pipeline.get_scheduled_items(mode) + self.system.network.get_scheduled_items(mode) +
                          self.system.traces)
             signature_epochs = get_signature_epochs(scheduled, total_epochs=last_epoch, mode=mode)
             epochs_with_data = pipeline.get_epochs_with_data(total_epochs=last_epoch, mode=mode)
             if not set(signature_epochs) & epochs_with_data:
                 # No signature epoch of this mode has data, so there is nothing to draw for it.
                 continue
             self.doc.append(NoEscape(r'\FloatBarrier'))
             with self.doc.create(Subsection(mode.capitalize())):
                 drawable_epochs = [epoch for epoch in signature_epochs if epoch in epochs_with_data]
                 for epoch in drawable_epochs:
                     self.doc.append(NoEscape(r'\FloatBarrier'))
                     # Label is what other parts of the document can reference this epoch's graph by.
                     epoch_label = Label(Marker(name=f"{mode}{epoch}", prefix="ssubsec"))
                     with self.doc.create(Subsubsection(f"Epoch {epoch}", label=epoch_label)):
                         dot_source = self._draw_diagram(mode, epoch).to_string()
                         tikz = d2t.dot2tex(dot_source, figonly=True)
                         # Both size limits are smuggled through a single key so they reach adjustbox unescaped.
                         box_args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
                         box_args.escape = False
                         with self.doc.create(Center()):
                             with self.doc.create(AdjustBox(arguments=box_args)) as box:
                                 box.append(NoEscape(tikz))
Esempio n. 2
0
    def _warmup(self, eager: bool = True) -> None:
        """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

        For every mode except "test", and every signature epoch of that mode that has data, this method:
        - draws one batch from the pipeline;
        - checks that the keys the traces need are produced by the pipeline, the network or another trace;
        - runs a single network step on that batch with `warmup=True`.

        It also checks that every monitor name is produced by the pipeline or the network in at least one of those
        epochs.

        Traces are not executed in the warmup since they are likely to contain state variables which could become
        corrupted by running extra steps.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.

        Raises:
            AssertionError: If the pipeline output is not a dictionary, if a trace input key is not produced by the
                pipeline, the network, or any trace, or if a monitor name is never produced. These are raised
                explicitly rather than with `assert`, so the checks still run under `python -O`.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
        sort_traces(all_traces)  # This ensures that the traces can sort properly for on_begin and on_end
        monitor_names = self.monitor_names
        for mode in self.pipeline.get_modes() - {"test"}:
            scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
                mode) + self.get_scheduled_items(mode)
            signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
            epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
            for epoch in signature_epochs:
                if epoch not in epochs_with_data:
                    continue
                network_output_keys = self.network.get_all_output_keys(mode, epoch)
                network_input_keys = self.network.get_effective_input_keys(mode, epoch)
                trace_input_keys = set()
                # "*" is seeded as always satisfied (presumably a wildcard some traces request) - TODO confirm
                trace_output_keys = {"*"}
                traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
                for idx, trace in enumerate(traces):
                    if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                        trace_input_keys.update(trace.inputs)
                    trace_output_keys.update(trace.outputs)
                # key checking: ask the pipeline for what the traces need beyond the network's outputs, plus
                # everything the network itself consumes
                loader = self._configure_loader(
                    self.pipeline.get_loader(mode,
                                             epoch,
                                             output_keys=trace_input_keys - network_output_keys | network_input_keys))
                with Suppressor():
                    # Only the first batch is needed for the test run.
                    if isinstance(loader, tf.data.Dataset):
                        batch = list(loader.take(1))[0]
                    else:
                        batch = next(iter(loader))
                batch = self._configure_tensor(loader, batch)
                if not isinstance(batch, dict):
                    raise AssertionError("please make sure data output format is dictionary")
                pipeline_output_keys = to_set(batch.keys())

                # Monitor names only need to be produced in some epoch, so shrink the outstanding set as we go.
                monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
                unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
                if unmet_requirements:
                    raise AssertionError("found missing key(s) during epoch {} mode {}: {}".format(
                        epoch, mode, unmet_requirements))
                sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
                # The first trace's inputs were excluded from the requirement check above, but the network must
                # still be asked to output them.
                trace_input_keys.update(traces[0].inputs)
                self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
                try:
                    self.network.run_step(batch)
                finally:
                    # Keep load_epoch/unload_epoch paired even when the step raises.
                    self.network.unload_epoch()
        if monitor_names:
            raise AssertionError("found missing key(s): {}".format(monitor_names))