Example #1
# FastEstimator Trace method: at the end of training, collect the score history
# of selected data instances into Summary objects for graphing.
def on_end(self, data: Data) -> None:
    index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
    for mode in self.mode:
        # Sort (index, final_score) pairs ascending by their most recent score.
        final_scores = sorted([(idx, elem[-1][1]) for idx, elem in self.index_history[mode].items()],
                              key=lambda x: x[1])
        # The reversed slice [-1:-n-1:-1] picks the n highest-scoring indices.
        max_idx_list = {elem[0] for elem in final_scores[-1:-self.n_max_to_keep - 1:-1]}
        min_idx_list = {elem[0] for elem in final_scores[:self.n_min_to_keep]}
        target_idx_list = set.union(min_idx_list, max_idx_list, self.idx_to_keep)
        for idx in target_idx_list:
            for step, score in self.index_history[mode][idx]:
                index_summaries[idx].history[mode][self.metric_key][step] = score
    self.system.add_graph(self.outputs[0], list(index_summaries.values()))  # So traceability can draw it
    data.write_without_log(self.outputs[0], list(index_summaries.values()))
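
The index-selection slices above are easy to verify in isolation. A minimal standalone sketch (plain Python; the scores and keep-counts are made-up illustrative values, not taken from the example):

# Hypothetical data: (index, final_score) pairs, sorted ascending by score.
final_scores = sorted([(0, 0.2), (1, 0.9), (2, 0.5), (3, 0.1)], key=lambda x: x[1])
n_max_to_keep, n_min_to_keep = 2, 1  # illustrative keep-counts

# Walk backwards from the end to grab the n_max_to_keep highest scorers.
max_idx_list = {elem[0] for elem in final_scores[-1:-n_max_to_keep - 1:-1]}  # {1, 2}
# The n_min_to_keep lowest scorers sit at the front of the sorted list.
min_idx_list = {elem[0] for elem in final_scores[:n_min_to_keep]}  # {3}
print(max_idx_list | min_idx_list)  # {1, 2, 3}: the indices whose histories are kept
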
Example #2
# Imports needed by this example (module paths assume a recent FastEstimator
# layout and may differ between versions):
import statistics as stats
from collections import defaultdict
from typing import Any, Dict, Iterable, Optional, Union

from fastestimator.summary.summary import Summary, ValWithError
from fastestimator.trace.trace import Trace
from fastestimator.util.data import Data
from fastestimator.util.util import DefaultKeyDict, to_number, to_set


class LabelTracker(Trace):
    """A Trace to track metrics grouped by labels, for example per-class loss over time during training.

    Use this in conjunction with ImageViewer or ImageSaver to see the graph at training end. This also automatically
    integrates with Traceability reports.

    Args:
        label: The key of the labels by which to group data.
        metric: The key of the metric by which to score data.
        label_mapping: A mapping of {DisplayName: LabelValue} to use when generating the graph. This can also be used to
            limit which label values are graphed, since any label values not included here will not be graphed. A None
            value will monitor all label values.
        bounds: What error bounds should be graphed around the mean value. Options include None, 'std' for standard
            deviation, and 'range' to plot (min_value, mean, max_value). Multiple values can be specified, e.g.
            ['std', 'range'] to generate multiple graphs.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        outputs: The name of the output which will be generated by this trace at the end of training. If None then it
            will default to "<metric>_by_<label>".

    Raises:
        ValueError: If `bounds` is not one of the allowed options.
    """
    def __init__(self,
                 label: str,
                 metric: str,
                 label_mapping: Optional[Dict[str, Any]] = None,
                 bounds: Union[None, str, Iterable[Union[str, None]]] = "std",
                 mode: Union[None, str, Iterable[str]] = "eval",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 outputs: Optional[str] = None):
        super().__init__(inputs=[label, metric],
                         outputs=outputs or f"{metric}_by_{label}",
                         mode=mode,
                         ds_id=ds_id)
        self.points = []
        self.label_summaries = DefaultKeyDict(
            default=lambda x: Summary(name=x))
        # Invert the user-provided {DisplayName: LabelValue} mapping so that raw
        # label values can be swapped for their display names at graphing time.
        self.label_mapping = {val: key
                              for key, val in label_mapping.items()
                              } if label_mapping else None
        bounds = to_set(bounds)
        if not bounds:
            bounds.add(None)
        for option in bounds:
            if option not in (None, "std", "range"):
                raise ValueError(
                    f"'bounds' must be either None, 'std', or 'range', but got '{option}'."
                )
        self.bounds = bounds

    @property
    def label_key(self) -> str:
        return self.inputs[0]

    @property
    def metric_key(self) -> str:
        return self.inputs[1]

    def on_batch_end(self, data: Data) -> None:
        self.points.append((to_number(data[self.label_key]),
                            to_number(data[self.metric_key])))

    def on_epoch_end(self, data: Data) -> None:
        label_scores = defaultdict(list)
        for batch in self.points:
            for label, metric in zip(batch[0], batch[1]):
                label_scores[label.item()].append(metric.item())
        for label, metric in label_scores.items():  # here 'metric' is the list of scores for one label
            if self.label_mapping:
                if label in self.label_mapping:
                    label = self.label_mapping[label]
                else:
                    # Skip labels which the user does not want to inspect
                    continue
            if 'std' in self.bounds:
                mean = stats.mean(metric)
                std = stats.stdev(metric) if len(metric) > 1 else 0.0
                val = ValWithError(mean - std, mean, mean + std)
                key = f"{self.metric_key} ($\\mu \\pm \\sigma$)"
                # {label: {mode: {key: {step: value}}}}
                self.label_summaries[label].history[self.system.mode][key][
                    self.system.global_step] = val
            if 'range' in self.bounds:
                val = ValWithError(min(metric), stats.mean(metric),
                                   max(metric))
                key = f"{self.metric_key} ($min, \\mu, max$)"
                self.label_summaries[label].history[self.system.mode][key][
                    self.system.global_step] = val
            if None in self.bounds:
                val = stats.mean(metric)
                key = self.metric_key
                self.label_summaries[label].history[self.system.mode][key][
                    self.system.global_step] = val
        self.points = []

    def on_end(self, data: Data) -> None:
        self.system.add_graph(
            self.outputs[0],
            list(self.label_summaries.values()))  # So traceability can draw it
        data.write_without_log(self.outputs[0],
                               list(self.label_summaries.values()))

    def __getstate__(self) -> Dict[str, Any]:
        """Get a representation of the state of this object.

        This method is invoked by pickle.

        Returns:
            The information to be recorded by a pickle summary of this object.
        """
        state = self.__dict__.copy()
        # Cast to a plain dict so the un-picklable lambda default factory is dropped.
        state['label_summaries'] = dict(state['label_summaries'])
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set this objects internal state from a dictionary of variables.

        This method is invoked by pickle.

        Args:
            state: The saved state to be used by this object.
        """
        label_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
        label_summaries.update(state.get('label_summaries', {}))
        state['label_summaries'] = label_summaries
        self.__dict__.update(state)
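
A minimal usage sketch for LabelTracker, assuming a typical FastEstimator setup; the key names ("y" for labels, "ce" for a per-sample loss), the label mapping, and the pre-built pipeline/network objects are illustrative assumptions, not part of the example above:

import fastestimator as fe
from fastestimator.trace.io import ImageSaver

tracker = LabelTracker(label="y",                            # label key in the batch
                       metric="ce",                          # per-sample metric key
                       label_mapping={"cat": 0, "dog": 1},   # graph only these classes
                       bounds=["std", "range"],              # one graph per bound type
                       mode="eval")
estimator = fe.Estimator(pipeline=pipeline,  # assumed to be defined elsewhere
                         network=network,    # assumed to be defined elsewhere
                         epochs=10,
                         traces=[tracker,
                                 ImageSaver(inputs="ce_by_y", save_dir="./outputs")])
estimator.fit()

Note that the default output key "ce_by_y" follows the "<metric>_by_<label>" rule from the docstring, which is why ImageSaver listens on that name.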
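
Both examples lean on DefaultKeyDict, which (unlike collections.defaultdict) passes the missing key to its factory so each auto-created Summary can be named after that key. A minimal sketch of that behavior, illustrative only and not FastEstimator's actual implementation:

class DefaultKeyDict(dict):
    """A dict whose default factory receives the missing key (sketch only)."""
    def __init__(self, default, **kwargs):
        super().__init__(**kwargs)
        self.factory = default  # factory is called with the missing key

    def __missing__(self, key):
        value = self.factory(key)
        self[key] = value
        return value

summaries = DefaultKeyDict(default=lambda x: f"Summary(name={x!r})")
print(summaries["cat"])  # -> "Summary(name='cat')", created on first access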