def __init__(self,
             label: str,
             metric: str,
             label_mapping: Optional[Dict[str, Any]] = None,
             bounds: Union[None, str, Iterable[Union[str, None]]] = "std",
             mode: Union[None, str, Iterable[str]] = "eval",
             ds_id: Union[None, str, Iterable[str]] = None,
             outputs: Optional[str] = None):
    """Track a metric grouped by a label key.

    Args:
        label: The data key holding the label by which the metric will be grouped.
        metric: The data key holding the metric value to track.
        label_mapping: An optional {name: value} mapping; it is inverted here so raw
            label values can be translated back to display names.
        bounds: Which uncertainty bounds to draw: None, 'std', 'range', or an
            iterable combining them.
        mode: Which execution mode(s) this trace runs in.
        ds_id: Which dataset id(s) this trace runs for.
        outputs: The output key to write results into (defaults to
            "<metric>_by_<label>").

    Raises:
        ValueError: If `bounds` contains a value other than None, 'std', or 'range'.
    """
    super().__init__(inputs=[label, metric], outputs=outputs or f"{metric}_by_{label}", mode=mode, ds_id=ds_id)
    self.points = []
    # Lazily create one Summary per label value as it is first encountered.
    self.label_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
    # Invert the mapping so raw label values map to human-readable names.
    self.label_mapping = {val: key for key, val in label_mapping.items()} if label_mapping else None
    bounds = to_set(bounds)
    if not bounds:
        bounds.add(None)
    for option in bounds:
        if option not in (None, "std", "range"):
            # BUG FIX: the message previously referred to 'interval' (not the actual
            # argument name) and printed the entire set rather than the bad entry.
            raise ValueError(f"'bounds' must be either None, 'std', or 'range', but got '{option}'.")
    self.bounds = bounds
def __init__(self, root_log_dir: str, time_stamp: str, network: TFNetwork) -> None:
    """Set up a TensorFlow summary-writer backend.

    Args:
        root_log_dir: The root directory under which logs are written.
        time_stamp: The timestamp folder name for this experiment run.
        network: The TFNetwork whose execution will be logged.
    """
    super().__init__(root_log_dir=root_log_dir, time_stamp=time_stamp, network=network)

    def _new_writer(mode: str):
        # One file writer per mode, created on first access, under
        # <root_log_dir>/<time_stamp>/<mode>.
        return tf.summary.create_file_writer(os.path.join(root_log_dir, time_stamp, mode))

    self.tf_summary_writers = DefaultKeyDict(_new_writer)
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore this object's internal state from a pickled dictionary.

    This method is invoked by pickle.

    Args:
        state: The saved state to be used by this object.
    """
    # Re-wrap the plain dict of summaries in a DefaultKeyDict so that missing
    # labels keep auto-creating empty Summary objects after unpickling.
    summaries = DefaultKeyDict(default=lambda key: Summary(name=key))
    summaries.update(state.get('label_summaries', {}))
    state['label_summaries'] = summaries
    self.__dict__.update(state)
def on_end(self, data: Data) -> None:
    """Collect the tracked score histories and write them out at the end of run.

    Args:
        data: The data store from which to read and into which to write results.
    """
    index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
    for mode in self.mode:
        # Rank every index by its most recent score (ascending).
        final_scores = sorted([(idx, elem[-1][1]) for idx, elem in self.index_history[mode].items()],
                              key=lambda x: x[1])
        # Top n_max_to_keep scores (taken from the tail of the ascending sort).
        max_idx_list = {elem[0] for elem in final_scores[-1:-self.n_max_to_keep - 1:-1]}
        # Bottom n_min_to_keep scores.
        min_idx_list = {elem[0] for elem in final_scores[:self.n_min_to_keep]}
        # BUG FIX: previously this called typing.Set.union, using a typing alias for a
        # runtime set operation. Use a real set union instead; set().union also accepts
        # any iterable for idx_to_keep.
        target_idx_list = set().union(min_idx_list, max_idx_list, self.idx_to_keep)
        for idx in target_idx_list:
            for step, score in self.index_history[mode][idx]:
                index_summaries[idx].history[mode][self.metric_key][step] = score
    self.system.add_graph(self.outputs[0], list(index_summaries.values()))  # So traceability can draw it
    data.write_without_log(self.outputs[0], list(index_summaries.values()))
def __init__(self, root_log_dir: str, time_stamp: str, network: BaseNetwork) -> None:
    """Set up a PyTorch summary-writer backend.

    Args:
        root_log_dir: The root directory under which logs are written.
        time_stamp: The timestamp folder name for this experiment run.
        network: The network whose execution will be logged.
    """
    run_dir = os.path.join(root_log_dir, time_stamp)

    def _new_writer(mode: str) -> SummaryWriter:
        # One SummaryWriter per mode, created on first access, under <run_dir>/<mode>.
        return SummaryWriter(log_dir=os.path.join(run_dir, mode))

    self.summary_writers = DefaultKeyDict(_new_writer)
    self.network = network
def _draw_diagram(self, mode: str, epoch: int, ds_id: str) -> pydot.Dot:
    """Draw a summary diagram of the FastEstimator Ops / Traces.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.
        ds_id: The ds_id to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    ds = self.system.pipeline.data[mode][ds_id]
    if isinstance(ds, Scheduler):
        # Resolve a scheduled dataset to whichever instance is active this epoch.
        ds = ds.get_current_value(epoch)
    # Pipeline ops are only drawn when the data source is an actual Dataset.
    pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch,
                                 ds_id=ds_id) if isinstance(ds, Dataset) else []
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch, ds_id=ds_id)
    net_post = get_current_items(self.system.network.postprocessing, run_modes=mode, epoch=epoch, ds_id=ds_id)
    # Traces are topologically sorted so the diagram reflects their execution order.
    traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch, ds_id=ds_id),
                         ds_ids=self.system.pipeline.get_ds_ids(epoch=epoch, mode=mode))
    diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
    diagram.set('rankdir', 'TB')
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='box')
    # Make the dataset the first of the pipeline ops
    pipe_ops.insert(0, ds)
    # Tracks, for each data key, the node where it was most recently produced; keys
    # never seen before default to the dataset node itself.
    label_last_seen = DefaultKeyDict(lambda k: str(id(ds)))  # Where was this key last generated
    batch_size = ""
    if isinstance(ds, Dataset):
        if hasattr(ds, "fe_batch") and ds.fe_batch:
            # The dataset carries its own batch size; prefer it over the pipeline's.
            batch_size = ds.fe_batch
        else:
            batch_size = self.system.pipeline.batch_size
            # Batch size may itself be scheduled over epochs and/or vary per mode.
            if isinstance(batch_size, Scheduler):
                batch_size = batch_size.get_current_value(epoch)
            if isinstance(batch_size, dict):
                batch_size = batch_size[mode]
    # NOTE(review): when ds is not a Dataset, batch_size stays "" (not None), so the
    # title becomes "Pipeline (Batch Size: )" — confirm whether that is intended.
    if batch_size is not None:
        batch_size = f" (Batch Size: {batch_size})"
    self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_size}', pipe_ops, ds_id)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Network', net_ops + net_post, ds_id)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces, ds_id)
    return diagram