Example #1
 def __init__(self,
              label: str,
              metric: str,
              label_mapping: Optional[Dict[str, Any]] = None,
              bounds: Union[None, str, Iterable[Union[str, None]]] = "std",
              mode: Union[None, str, Iterable[str]] = "eval",
              ds_id: Union[None, str, Iterable[str]] = None,
              outputs: Optional[str] = None):
     super().__init__(inputs=[label, metric],
                      outputs=outputs or f"{metric}_by_{label}",
                      mode=mode,
                      ds_id=ds_id)
     self.points = []
     self.label_summaries = DefaultKeyDict(
         default=lambda x: Summary(name=x))
     self.label_mapping = {val: key
                           for key, val in label_mapping.items()
                           } if label_mapping else None
     bounds = to_set(bounds)
     if not bounds:
         bounds.add(None)
     for option in bounds:
         if option not in (None, "std", "range"):
             raise ValueError(
                 f"'interval' must be either None, 'std', or 'range', but got '{bounds}'."
             )
     self.bounds = bounds
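
Example #1 relies on DefaultKeyDict passing the missing key to its factory, so that label_summaries['cat'] lazily becomes Summary(name='cat'). The sketch below illustrates that key-aware behavior; KeyedDefaultDict is a hypothetical stand-in written for illustration, not FastEstimator's implementation.

# Minimal sketch of a key-aware default dict, assuming DefaultKeyDict behaves like
# collections.defaultdict except that the factory receives the missing key.
# KeyedDefaultDict is a hypothetical stand-in, not the FastEstimator class.
from typing import Any, Callable

class KeyedDefaultDict(dict):
    def __init__(self, default: Callable[[Any], Any]):
        super().__init__()
        self.default = default

    def __missing__(self, key: Any) -> Any:
        value = self.default(key)  # The factory sees the key, unlike collections.defaultdict
        self[key] = value
        return value

summaries = KeyedDefaultDict(default=lambda name: {"name": name, "history": {}})
summaries["cat"]["history"][0] = 0.93  # First access lazily creates the 'cat' entry
print(summaries["cat"])                # {'name': 'cat', 'history': {0: 0.93}}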
Example #2
 def __init__(self, root_log_dir: str, time_stamp: str,
              network: TFNetwork) -> None:
     super().__init__(root_log_dir=root_log_dir,
                      time_stamp=time_stamp,
                      network=network)
     self.tf_summary_writers = DefaultKeyDict(
         lambda key: (tf.summary.create_file_writer(
             os.path.join(root_log_dir, time_stamp, key))))
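
Example #2 builds one TensorFlow summary writer per key (typically the mode name), each logging under root_log_dir/time_stamp/<key>, and only when that key is first accessed. A usage sketch follows; the log path and values are illustrative.

# Usage sketch: what the factory builds for a 'train' key, and how a trace would
# typically log to it. The path and scalar values are illustrative.
import tensorflow as tf

writer = tf.summary.create_file_writer("logs/2024-01-01/train")
with writer.as_default():
    tf.summary.scalar("loss", 0.42, step=100)  # Appears under the 'train' run in TensorBoard
writer.flush()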
Example #3
    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set this objects internal state from a dictionary of variables.

        This method is invoked by pickle.

        Args:
            state: The saved state to be used by this object.
        """
        label_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
        label_summaries.update(state.get('label_summaries', {}))
        state['label_summaries'] = label_summaries
        self.__dict__.update(state)
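
Example #3 rebuilds label_summaries as a DefaultKeyDict after unpickling, because a lambda-backed default generally cannot be pickled and the saved state therefore only carries the plain key/value pairs. The sketch below shows the matching save/restore pattern; Tracker and its __getstate__ are illustrative assumptions (reusing the KeyedDefaultDict stand-in from Example #1), not FastEstimator's classes.

# Sketch of the pickling pattern this __setstate__ supports: drop the unpicklable
# lambda-backed dict on save, rebuild it on load. Tracker and __getstate__ are
# assumed for illustration; KeyedDefaultDict is the stand-in defined above.
import pickle
from typing import Any, Dict

class Tracker:
    def __init__(self) -> None:
        self.label_summaries = KeyedDefaultDict(default=lambda name: {"name": name})

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state['label_summaries'] = dict(self.label_summaries)  # A plain dict pickles cleanly
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        summaries = KeyedDefaultDict(default=lambda name: {"name": name})
        summaries.update(state.get('label_summaries', {}))
        state['label_summaries'] = summaries
        self.__dict__.update(state)

tracker = Tracker()
tracker.label_summaries["dog"]["count"] = 3
restored = pickle.loads(pickle.dumps(tracker))
assert restored.label_summaries["dog"]["count"] == 3  # The key-aware default is live again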
Example #4
 def on_end(self, data: Data) -> None:
     index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
     for mode in self.mode:
         final_scores = sorted([(idx, elem[-1][1]) for idx, elem in self.index_history[mode].items()],
                               key=lambda x: x[1])
         max_idx_list = {elem[0] for elem in final_scores[-1:-self.n_max_to_keep - 1:-1]}
         min_idx_list = {elem[0] for elem in final_scores[:self.n_min_to_keep]}
         target_idx_list = set.union(min_idx_list, max_idx_list, self.idx_to_keep)
         for idx in target_idx_list:
             for step, score in self.index_history[mode][idx]:
                 index_summaries[idx].history[mode][self.metric_key][step] = score
     self.system.add_graph(self.outputs[0], list(index_summaries.values()))  # So traceability can draw it
     data.write_without_log(self.outputs[0], list(index_summaries.values()))
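
In Example #4, the reverse slice final_scores[-1:-self.n_max_to_keep - 1:-1] walks the ascending-sorted list from its end, yielding the n_max_to_keep best indices, while the plain slice keeps the n_min_to_keep worst. A worked example with made-up scores:

# Worked example of the top/bottom selection (indices and scores are made up).
final_scores = [(7, 0.10), (2, 0.35), (5, 0.80), (9, 0.95)]  # (index, final score), ascending
n_max_to_keep, n_min_to_keep = 2, 1
max_idx_list = {elem[0] for elem in final_scores[-1:-n_max_to_keep - 1:-1]}  # {9, 5}: two best
min_idx_list = {elem[0] for elem in final_scores[:n_min_to_keep]}            # {7}: single worst
target_idx_list = set.union(min_idx_list, max_idx_list, {2})                 # plus always-keep ids
print(sorted(target_idx_list))  # [2, 5, 7, 9]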
Example #5
 def __init__(self, root_log_dir: str, time_stamp: str,
              network: BaseNetwork) -> None:
     self.summary_writers = DefaultKeyDict(lambda key: (SummaryWriter(
         log_dir=os.path.join(root_log_dir, time_stamp, key))))
     self.network = network
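
Example #5 is the PyTorch counterpart of Example #2: the first access to a mode key creates a torch.utils.tensorboard SummaryWriter under root_log_dir/time_stamp/<key>. A usage sketch with an illustrative path and values:

# Usage sketch: what the factory builds for an 'eval' key and a typical write.
# The path and values are illustrative.
import os
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=os.path.join("logs", "2024-01-01", "eval"))
writer.add_scalar("accuracy", 0.87, global_step=100)  # Shows up under the 'eval' run
writer.close()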
Example #6
    def _draw_diagram(self, mode: str, epoch: int, ds_id: str) -> pydot.Dot:
        """Draw a summary diagram of the FastEstimator Ops / Traces.

        Args:
            mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
            epoch: The epoch to summarize.
            ds_id: The ds_id to summarize.

        Returns:
            A pydot digraph representing the execution flow.
        """
        ds = self.system.pipeline.data[mode][ds_id]
        if isinstance(ds, Scheduler):
            ds = ds.get_current_value(epoch)
        pipe_ops = get_current_items(
            self.system.pipeline.ops, run_modes=mode, epoch=epoch,
            ds_id=ds_id) if isinstance(ds, Dataset) else []
        net_ops = get_current_items(self.system.network.ops,
                                    run_modes=mode,
                                    epoch=epoch,
                                    ds_id=ds_id)
        net_post = get_current_items(self.system.network.postprocessing,
                                     run_modes=mode,
                                     epoch=epoch,
                                     ds_id=ds_id)
        traces = sort_traces(get_current_items(self.system.traces,
                                               run_modes=mode,
                                               epoch=epoch,
                                               ds_id=ds_id),
                             ds_ids=self.system.pipeline.get_ds_ids(
                                 epoch=epoch, mode=mode))
        diagram = pydot.Dot(
            compound='true'
        )  # Compound lets you draw edges which terminate at sub-graphs
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='box')

        # Make the dataset the first of the pipeline ops
        pipe_ops.insert(0, ds)
        label_last_seen = DefaultKeyDict(
            lambda k: str(id(ds)))  # Where was this key last generated

        batch_size = ""
        if isinstance(ds, Dataset):
            if hasattr(ds, "fe_batch") and ds.fe_batch:
                batch_size = ds.fe_batch
            else:
                batch_size = self.system.pipeline.batch_size
                if isinstance(batch_size, Scheduler):
                    batch_size = batch_size.get_current_value(epoch)
                if isinstance(batch_size, dict):
                    batch_size = batch_size[mode]
        if batch_size is None:
            batch_size = ""
        elif batch_size != "":
            batch_size = f" (Batch Size: {batch_size})"
        self._draw_subgraph(diagram, diagram, label_last_seen,
                            f'Pipeline{batch_size}', pipe_ops, ds_id)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Network',
                            net_ops + net_post, ds_id)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces',
                            traces, ds_id)
        return diagram
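
The compound='true' attribute set in Example #6 is what lets edges terminate at a whole sub-graph (a Graphviz cluster) via lhead/ltail rather than at an individual node. Below is a minimal pydot sketch of that mechanism; the node and cluster names are illustrative, not the trace's internals.

# Minimal pydot sketch of a compound digraph: the edge stops at the cluster
# border because of lhead. Node and cluster names are illustrative only.
import pydot

diagram = pydot.Dot(graph_type='digraph', compound='true')
diagram.set('rankdir', 'TB')
diagram.set_node_defaults(shape='box')

cluster = pydot.Cluster('network', label='Network')  # Rendered as 'cluster_network'
cluster.add_node(pydot.Node('model_op', label='ModelOp'))
diagram.add_subgraph(cluster)

diagram.add_node(pydot.Node('dataset', label='Dataset'))
diagram.add_edge(pydot.Edge('dataset', 'model_op', lhead='cluster_network'))
print(diagram.to_string())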