Example #1
    def _warmup(self, eager: bool = True) -> None:
        """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

        Traces are not executed in the warmup since they are likely to contain state variables which could become
        corrupted by running extra steps.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
        sort_traces(all_traces)  # Verify that the traces can be sorted properly for on_begin and on_end
        monitor_names = self.monitor_names
        for mode in self.pipeline.get_modes() - {"test"}:
            scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
                mode) + self.get_scheduled_items(mode)
            signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
            epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
            for epoch in signature_epochs:
                if epoch not in epochs_with_data:
                    continue
                network_output_keys = self.network.get_all_output_keys(mode, epoch)
                network_input_keys = self.network.get_effective_input_keys(mode, epoch)
                trace_input_keys = set()
                trace_output_keys = {"*"}
                traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
                for idx, trace in enumerate(traces):
                    if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                        trace_input_keys.update(trace.inputs)
                    trace_output_keys.update(trace.outputs)
                # key checking
                loader = self._configure_loader(
                    self.pipeline.get_loader(mode,
                                             epoch,
                                             output_keys=trace_input_keys - network_output_keys | network_input_keys))
                with Suppressor():
                    if isinstance(loader, tf.data.Dataset):
                        batch = list(loader.take(1))[0]
                    else:
                        batch = next(iter(loader))
                batch = self._configure_tensor(loader, batch)
                assert isinstance(batch, dict), "please make sure the data output format is a dictionary"
                pipeline_output_keys = to_set(batch.keys())

                monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
                unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
                assert not unmet_requirements, \
                    "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
                sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
                trace_input_keys.update(traces[0].inputs)
                self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
                self.network.run_step(batch)
                self.network.unload_epoch()
        assert not monitor_names, "found missing key(s): {}".format(monitor_names)
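
A minimal, self-contained sketch of the key-checking arithmetic above, with hypothetical key names standing in for the real pipeline, network, and trace objects. Note that '-' binds tighter than '|' in Python, so the loader is asked for the trace inputs that the network cannot provide, plus everything the network itself needs:

    pipeline_output_keys = {"x", "y"}       # keys found in the warmup batch
    network_input_keys = {"x", "y"}         # keys the network ops consume
    network_output_keys = {"y_pred", "ce"}  # keys the network ops produce
    trace_output_keys = {"*", "accuracy"}   # keys produced by earlier traces
    trace_input_keys = {"y", "y_pred"}      # keys the traces want to consume

    # parses as (trace_input_keys - network_output_keys) | network_input_keys
    loader_keys = trace_input_keys - network_output_keys | network_input_keys
    print(loader_keys)  # e.g. {'x', 'y'}

    # A trace input that no upstream component provides is an unmet requirement
    unmet = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
    assert not unmet, "found missing key(s): {}".format(unmet)
    # adding e.g. "lr" to trace_input_keys would trip this assertion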
Example #2
    def _start(self, run_modes: Set[str]) -> None:
        """The outer training loop.

        This method invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes
        the trace on_end method.

        Args:
            run_modes: The current execution modes.
        """
        all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
        try:
            self._run_traces_on_begin(traces=all_traces)
            if "train" in run_modes or "eval" in run_modes:
                for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                    if "train" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "train"
                        self._run_epoch()
                    if "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "eval"
                        self._run_epoch()
            else:
                self._run_epoch()
        except EarlyStop:
            pass  # On early stopping we still want to run the final traces and return results
        self._run_traces_on_end(traces=all_traces)
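
A minimal sketch of the control flow above, with print statements standing in for the real epoch and trace machinery, and EarlyStop assumed to be a plain Exception subclass raised by a trace mid-training:

    class EarlyStop(Exception):
        pass

    def start(total_epochs, stop_at):
        print("on_begin")
        try:
            for epoch in range(1, total_epochs + 1):
                print("train epoch", epoch)
                if epoch == stop_at:
                    raise EarlyStop  # e.g. raised by an early-stopping trace
                print("eval epoch", epoch)
        except EarlyStop:
            pass  # fall through so that on_end still executes
        print("on_end")  # runs whether or not training stopped early

    start(total_epochs=5, stop_at=3)  # stops during epoch 3; on_end still runs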
Example #3
    def _prepare_traces(self, run_modes: Set[str]) -> None:
        """Prepare information about the traces for training.

        Add default traces into the traces_in_use list, and print a warning if no model saver trace is detected.

        Args:
            run_modes: The current execution modes.
        """
        self.traces_in_use = [trace for trace in self.traces]
        if self.system.log_steps is not None:
            self.traces_in_use.append(Logger())
        # Look for any monitor names which should be automagically added.
        trace_outputs = set()
        extra_monitor_keys = set()
        for trace in sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes)):
            trace_outputs.update(trace.outputs)
            extra_monitor_keys.update(trace.fe_monitor_names - trace_outputs)
        # Add the essential traces
        if "train" in run_modes:
            self.traces_in_use.insert(0, TrainEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
            no_save_warning = True
            for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
                if isinstance(trace, (ModelSaver, BestModelSaver)):
                    no_save_warning = False
            if no_save_warning:
                print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")
        if "eval" in run_modes and "eval" in self.pipeline.get_modes():
            self.traces_in_use.insert(1, EvalEssential(monitor_names=self.monitor_names.union(extra_monitor_keys)))
        # insert system instance to trace
        for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
            trace.system = self.system
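
A minimal sketch of the saver-detection scan above, with empty stand-in classes replacing the real traces from fastestimator:

    class Trace:
        pass

    class ModelSaver(Trace):
        pass

    class BestModelSaver(Trace):
        pass

    class Accuracy(Trace):
        pass

    traces_in_use = [Accuracy()]  # no saver trace present
    no_save_warning = not any(isinstance(t, (ModelSaver, BestModelSaver)) for t in traces_in_use)
    if no_save_warning:
        print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")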
Example #4
    def _start(self, run_modes: Set[str], eager: bool) -> None:
        """The outer training loop.

        This method invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes
        the trace on_end method.

        Args:
            run_modes: The current execution modes.
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
        try:
            self._run_traces_on_begin(traces=all_traces)
            if "train" in run_modes or "eval" in run_modes:
                # If the training is re-starting from a restore wizard, it should re-run the last eval epoch
                if self.system.epoch_idx > 0 and "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                    self.system.mode = "eval"
                    self._run_epoch(eager=eager)
                for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                    if "train" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "train"
                        self._run_epoch(eager=eager)
                    if "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "eval"
                        self._run_epoch(eager=eager)
            else:
                self._run_epoch(eager=eager)
        except EarlyStop:
            pass  # On early stopping we still want to run the final traces and return results
        self._run_traces_on_end(traces=all_traces)
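
A minimal sketch of the resume behavior above: when epoch_idx is restored to a nonzero value, the last eval epoch is repeated before the range() picks up at the next training epoch. The indices here are hypothetical:

    def start(epoch_idx, total_epochs):
        if epoch_idx > 0:
            print("re-run eval epoch", epoch_idx)  # restore eval-side state
        for epoch_idx in range(epoch_idx + 1, total_epochs + 1):
            print("train epoch", epoch_idx)
            print("eval epoch", epoch_idx)

    start(epoch_idx=3, total_epochs=5)  # eval 3, then train/eval 4 and 5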
Example #5
    def _run_epoch(self, eager: bool) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        traces = get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx)
        trace_input_keys = set()
        for trace in traces:
            trace_input_keys.update(trace.inputs)
        network_input_keys = self.network.get_effective_input_keys(self.system.mode, self.system.epoch_idx)
        network_output_keys = self.network.get_all_output_keys(self.system.mode, self.system.epoch_idx)
        loader = self._configure_loader(
            self.pipeline.get_loader(self.system.mode,
                                     self.system.epoch_idx,
                                     output_keys=trace_input_keys - network_output_keys | network_input_keys))
        iterator = iter(loader)
        self.network.load_epoch(mode=self.system.mode,
                                epoch=self.system.epoch_idx,
                                output_keys=trace_input_keys,
                                eager=eager)
        self.system.batch_idx = None
        with Suppressor():
            batch = next(iterator)
        traces = sort_traces(traces, available_outputs=to_set(batch.keys()) | network_output_keys)
        self._run_traces_on_epoch_begin(traces=traces)
        while True:
            try:
                if self.system.mode == "train":
                    self.system.update_global_step()
                self.system.update_batch_idx()
                batch = self._configure_tensor(loader, batch)
                self._run_traces_on_batch_begin(batch, traces=traces)
                batch, prediction = self.network.run_step(batch)
                self._run_traces_on_batch_end(batch, prediction, traces=traces)
                if isinstance(loader, DataLoader) and (
                    (self.system.batch_idx == self.system.max_train_steps_per_epoch and self.system.mode == "train") or
                    (self.system.batch_idx == self.system.max_eval_steps_per_epoch and self.system.mode == "eval")):
                    raise StopIteration
                with Suppressor():
                    batch = next(iterator)
            except StopIteration:
                break
        self._run_traces_on_epoch_end(traces=traces)
        self.network.unload_epoch()
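
A minimal sketch of the batch loop above: the first batch is drawn before on_epoch_begin so the traces can be sorted against the keys it actually contains, and the step cap is enforced by raising StopIteration by hand:

    def run_epoch(loader, max_steps=None):
        iterator = iter(loader)
        batch = next(iterator)  # prefetch so the batch keys are known up front
        print("on_epoch_begin")
        batch_idx = 0
        while True:
            try:
                batch_idx += 1
                print("step", batch_idx, batch)
                if batch_idx == max_steps:
                    raise StopIteration  # cap the epoch length by hand
                batch = next(iterator)  # raises StopIteration when exhausted
            except StopIteration:
                break
        print("on_epoch_end")

    run_epoch([{"x": 1}, {"x": 2}, {"x": 3}], max_steps=2)  # runs only 2 steps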
Example #6
    def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
        """Draw a summary diagram of the FastEstimator Ops / Traces.

        Args:
            mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
            epoch: The epoch to summarize.

        Returns:
            A pydot digraph representing the execution flow.
        """
        ds = self.system.pipeline.data[mode]
        if isinstance(ds, Scheduler):
            ds = ds.get_current_value(epoch)
        pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode,
                                     epoch=epoch) if isinstance(ds, Dataset) else []
        net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
        traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch))
        diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='box')

        # Make the dataset the first of the pipeline ops
        pipe_ops.insert(0, ds)
        label_last_seen = defaultdict(lambda: str(id(ds)))  # Where was this key last generated

        batch_size = ""
        if isinstance(ds, Dataset) and not isinstance(ds, BatchDataset):
            batch_size = self.system.pipeline.batch_size
            if isinstance(batch_size, Scheduler):
                batch_size = batch_size.get_current_value(epoch)
            if batch_size is not None:
                batch_size = f" (Batch Size: {batch_size})"
        self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_size}', pipe_ops)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Network', net_ops)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces)
        return diagram
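
A minimal sketch of the pydot calls above, assuming pydot and the Graphviz binaries are installed; the cluster, node names, and labels are placeholders:

    import pydot

    diagram = pydot.Dot(compound='true')  # compound allows edges into sub-graphs
    diagram.set('rankdir', 'TB')  # lay the graph out top-to-bottom
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='box')

    pipeline = pydot.Cluster('pipeline', label='Pipeline (Batch Size: 32)')
    pipeline.add_node(pydot.Node('ds', label='Dataset'))
    pipeline.add_node(pydot.Node('minmax', label='Minmax'))
    pipeline.add_edge(pydot.Edge('ds', 'minmax'))
    diagram.add_subgraph(pipeline)

    diagram.write_png('diagram.png')  # invokes the Graphviz 'dot' executable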
Example #7
    def _run_epoch(self, eager: bool) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx, self.system.mode)
        epoch_traces = sort_traces(
            get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx),
            ds_ids=ds_ids)
        self._run_traces_on_epoch_begin(traces=epoch_traces)
        self.system.batch_idx = None
        end_epoch_data = Data()  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
        # run for each dataset
        for self.system.ds_id in ds_ids:
            ds_traces = get_current_items(self.traces_in_use,
                                          run_modes=self.system.mode,
                                          epoch=self.system.epoch_idx,
                                          ds_id=self.system.ds_id)
            trace_input_keys = set()
            for ds_trace in ds_traces:
                trace_input_keys.update(ds_trace.inputs)
            network_input_keys = self.network.get_effective_input_keys(mode=self.system.mode,
                                                                       epoch=self.system.epoch_idx,
                                                                       ds_id=self.system.ds_id)
            network_output_keys = self.network.get_all_output_keys(mode=self.system.mode,
                                                                   epoch=self.system.epoch_idx,
                                                                   ds_id=self.system.ds_id)
            self.network.load_epoch(mode=self.system.mode,
                                    epoch=self.system.epoch_idx,
                                    ds_id=self.system.ds_id,
                                    output_keys=trace_input_keys,
                                    eager=eager)

            with self.pipeline(mode=self.system.mode,
                               epoch=self.system.epoch_idx,
                               ds_id=self.system.ds_id,
                               steps_per_epoch=self.system.steps_per_epoch,
                               output_keys=trace_input_keys - network_output_keys | network_input_keys) as loader:
                loader = self._configure_loader(loader)
                iterator = iter(loader)
                with Suppressor():
                    batch = next(iterator)
                ds_traces = sort_traces(ds_traces,
                                        available_outputs=to_set(batch.keys()) | network_output_keys,
                                        ds_ids=ds_ids)
                per_ds_traces = [trace for trace in ds_traces if isinstance(trace, PerDSTrace)]
                self._run_traces_on_ds_begin(traces=per_ds_traces)
                while True:
                    try:
                        if self.system.mode == "train":
                            self.system.update_global_step()
                        self.system.update_batch_idx()
                        batch = self._configure_tensor(loader, batch)
                        self._run_traces_on_batch_begin(batch, traces=ds_traces)
                        batch, prediction = self.network.run_step(batch)
                        self._run_traces_on_batch_end(batch, prediction, traces=ds_traces)
                        if isinstance(loader, DataLoader) and (
                            (self.system.batch_idx == self.system.train_steps_per_epoch
                             and self.system.mode == "train") or
                            (self.system.batch_idx == self.system.eval_steps_per_epoch
                             and self.system.mode == "eval")):
                            raise StopIteration
                        with Suppressor():
                            batch = next(iterator)
                    except StopIteration:
                        break
                self._run_traces_on_ds_end(traces=per_ds_traces, data=end_epoch_data)
            self.network.unload_epoch()
        self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)
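
A minimal sketch of the per-dataset aggregation above, with a plain dict standing in for fastestimator's Data container and hypothetical metric values:

    end_epoch_data = {}  # collects results written by traces during on_ds_end

    for ds_id in ("ds_0", "ds_1"):
        # ... run this dataset's batches, then let the per-ds traces report ...
        end_epoch_data["accuracy|" + ds_id] = 0.9  # hypothetical on_ds_end write

    print("on_epoch_end:", end_epoch_data)  # one combined report for the epoch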