예제 #1
0
 def _document_init_params(self) -> None:
     """Add initialization parameters to the traceability document.

     Creates a "Parameters" section and renders every table found in `self.config_tables` into it. Tables whose id
     matches a network model get a hyperlinked title, and tables whose id matches one of the pipeline datasets get
     a title annotated with the dataset mode plus extra rows derived from the dataset summary.
     """
     with self.doc.create(Section("Parameters")):
         # Ids of the framework (TensorFlow / PyTorch) models owned by the network
         model_ids = set()
         for model in self.system.network.models:
             if isinstance(model, (tf.keras.Model, torch.nn.Module)):
                 model_ids.add(FEID(id(model)))
         # Map the id of each pipeline dataset onto its (mode, dataset) pair. If two modes share an id, the later
         # mode replaces the earlier one.
         datasets = {}
         for mode in ['train', 'eval', 'test']:
             mode_data = self.system.pipeline.data.get(mode, None)
             datasets[FEID(id(mode_data))] = (mode, mode_data)
         for tbl in self.config_tables:
             name_override = None
             extra_rows = None
             toc_ref = None
             # Every base class is checked, so a later match overwrites an earlier one
             for base, heading in ((Estimator, "Estimator"), (BaseNetwork, "Network"), (Pipeline, "Pipeline")):
                 if issubclass(tbl.type, base):
                     toc_ref = heading
             if tbl.fe_id in model_ids:
                 # Link to a later detailed model description
                 marker = Marker(name=str(tbl.name), prefix="subsec")
                 link_text = NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}')
                 name_override = Hyperref(marker, text=link_text)
                 toc_ref = tbl.name
             if tbl.fe_id in datasets:
                 mode, dataset = datasets[tbl.fe_id]
                 pretty_mode = mode.capitalize()
                 name_override = bold(f'{tbl.name} ({pretty_mode})')
                 toc_ref = f"{pretty_mode} Dataset"
                 # Enhance the dataset summary
                 if isinstance(dataset, FEDataset):
                     extra_rows = []
                     for key, val in dataset.summary().__getstate__().items():
                         row_label = f"{prettify_metric_name(key)}:"
                         if isinstance(val, dict) and val:
                             first_val = next(iter(val.values()))
                             if isinstance(first_val, (int, float, str, bool, type(None))):
                                 # Dictionaries of simple values are shown as a single JSON string
                                 val = jsonpickle.dumps(val, unpicklable=False)
                             else:
                                 # Dictionaries of complex values get a nested two-column table
                                 subtable = Tabular('l|l')
                                 for sub_key, sub_val in val.items():
                                     if hasattr(sub_val, '__getstate__'):
                                         sub_val = jsonpickle.dumps(sub_val, unpicklable=False)
                                     subtable.add_row((sub_key, sub_val))
                                 val = subtable
                         extra_rows.append((row_label, val))
             tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
예제 #2
0
 def _document_init_params(self) -> None:
     """Add initialization parameters to the traceability document.

     Creates a "Parameters" section, collects the ids of the network models and of the pipeline datasets
     (including any values held inside Schedulers), and then delegates to `self._loop_tables` once per category
     of object, feeding the index returned by each call into the next one.
     """
     from fastestimator.estimator import Estimator  # Avoid circular import
     with self.doc.create(Section("Parameters")):
         model_ids = set()
         for model in self.system.network.models:
             if isinstance(model, (tf.keras.Model, torch.nn.Module)):
                 model_ids.add(FEID(id(model)))
         # Locate the datasets in order to provide extra details about them later in the summary
         datasets = {}
         for mode in ['train', 'eval', 'test']:
             pending = to_list(self.system.pipeline.data.get(mode, None))
             # NOTE(review): if to_list returns the pipeline's own list object, the extend() below would modify
             # it in place - confirm against the to_list implementation.
             cursor = 0
             while cursor < len(pending):
                 candidate = pending[cursor]
                 cursor += 1
                 if candidate:
                     feid = FEID(id(candidate))
                     if feid in datasets:
                         datasets[feid][0].add(mode)
                     else:
                         datasets[feid] = ({mode}, candidate)
                 if isinstance(candidate, Scheduler):
                     # Queue up everything the scheduler can produce so that it gets inspected too
                     pending.extend(candidate.get_all_values())
         # Parse the config tables, one category at a time
         categories = (
             ((Estimator, BaseNetwork, Pipeline), "Base Classes"),
             (Scheduler, "Schedulers"),
             (Trace, "Traces"),
             (Op, "Ops"),
             ((Dataset, tf.data.Dataset), "Datasets"),
             ((tf.keras.Model, torch.nn.Module), "Models"),
             (types.FunctionType, "Functions"),
             ((np.ndarray, tf.Tensor, tf.Variable, torch.Tensor), "Tensors"),
             (Any, "Miscellaneous"),
         )
         start = 0
         for classes, name in categories:
             start = self._loop_tables(start, classes=classes, name=name, model_ids=model_ids, datasets=datasets)
예제 #3
0
def fe_summary(self) -> List[FeSummaryTable]:
    """Return a summary of how this class was instantiated (for traceability).

    The summary tables are sorted by the category of the object they describe, and the FEID translation dictionary
    is updated so that each table is referenced as "@FE<position>" according to that ordering.

    Args:
        self: The bound class instance.

    Returns:
        The summary tables of the instance, in display order.
    """
    # Delayed imports to avoid circular dependency
    from fastestimator.estimator import Estimator
    from fastestimator.network import TFNetwork, TorchNetwork
    from fastestimator.pipeline import Pipeline
    from fastestimator.op.op import Op
    from fastestimator.trace.trace import Trace
    from fastestimator.schedule.schedule import Scheduler
    from torch.utils.data import Dataset

    # Display priority: the position within this tuple is the sort rank of the matching tables
    priority = (
        Estimator,
        (TFNetwork, TorchNetwork),
        Pipeline,
        Scheduler,
        Trace,
        Op,
        (Dataset, tf.data.Dataset),
        (tf.keras.Model, torch.nn.Module),
        types.FunctionType,
        (np.ndarray, tf.Tensor, tf.Variable, torch.Tensor),
    )

    def _rank(entry):
        # The first matching category wins; anything unmatched is sorted after all of the known categories
        table_type = entry[1].type
        for rank, classes in enumerate(priority):
            if issubclass(table_type, classes):
                return rank
        return len(priority)

    # re-number the references for nicer viewing
    ordered_items = sorted(self._fe_traceability_summary.items(), key=_rank)
    key_mapping = {}
    for idx, (fe_id, _) in enumerate(ordered_items):
        key_mapping[fe_id] = f"@FE{idx}"
    FEID.set_translation_dict(key_mapping)
    return [table for _, table in ordered_items]
예제 #4
0
    def fix_split_traceabilty(cls, parent: 'FEDataset', children: List['FEDataset'],
                              fractions: Tuple[Union[float, int, Iterable[int]], ...], seed: Optional[int],
                              stratify: Optional[str]) -> None:
        """A method to fix traceability information after invoking the dataset .split() method.

        Note that the default implementation of the .split() function invokes this already, so this only needs to be
        invoked if you override the .split() method when defining a subclass (ex. BatchDataset).

        Nothing happens if the `parent` has no `_fe_traceability_summary` attribute.

        Args:
            parent: The parent dataset on which .split() was invoked.
            children: The datasets generated by performing the split.
            fractions: The fraction arguments used to generate the children (should be one-to-one with the children).
            seed: The random seed used to generate the split.
            stratify: The stratify key used to generate the split.
        """
        if not hasattr(parent, '_fe_traceability_summary'):
            return
        parent_id = FEID(id(parent))
        # Convert every split argument into its printable form (ranges are spelled out explicitly)
        frac_labels = []
        for frac in fractions:
            if isinstance(frac, range):
                frac_labels.append(f"range({frac.start}, {frac.stop}, {frac.step})")
            else:
                frac_labels.append(f"{frac}")
        for child, label in zip(children, frac_labels):
            # noinspection PyProtectedMember
            summary_tables = deepcopy(child._fe_traceability_summary)
            # Update the ID if necessary
            child_id = FEID(id(child))
            if child_id in summary_tables:
                child_table = summary_tables[child_id]
            else:
                # The child was created without invoking its __init__ method, so its internal summary will have the
                # wrong id
                child_table = summary_tables.pop(parent_id)
                child_table.fe_id = child_id
                summary_tables[child_id] = child_table
            child_split = child_table.fields.get('split', FeSplitSummary())
            child_split.add_split(parent=parent_id, fraction=label, seed=seed, stratify=stratify)
            child_table.fields['split'] = child_split
            child._fe_traceability_summary = summary_tables
        # Record on the parent what was taken away from it
        # noinspection PyUnresolvedReferences
        parent_table = parent._fe_traceability_summary.get(parent_id)
        parent_split = parent_table.fields.get('split', FeSplitSummary())
        removed = ", ".join([f"-{label}" for label in frac_labels])
        parent_split.add_split(parent='self', fraction=removed, seed=seed, stratify=stratify)
        parent_table.fields['split'] = parent_split
        # Put the new parent summary into the child table to ensure it will always exist in the final set of tables
        for child in children:
            child._fe_traceability_summary[parent_id] = deepcopy(parent_table)
예제 #5
0
    def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
        """Draw a node onto a diagram based on a given op.

        `Sometimes` and `OneOf` ops which wrap other ops are drawn as a dotted cluster containing the wrapped
        op(s); every other op is drawn as a single node.

        Args:
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            node_id: The id to use as the node label.
        """
        if isinstance(op, Sometimes) and op.numpy_op:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', f'Sometimes ({op.prob}):')
            cluster.set('labeljust', 'r')
            # The wrapped op takes over the node id of the wrapper
            Traceability._add_node(cluster, op.numpy_op, node_id)
            diagram.add_subgraph(cluster)
            return

        if isinstance(op, OneOf) and op.numpy_ops:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', 'One Of:')
            cluster.set('labeljust', 'r')
            # The first choice takes over the node id of the wrapper, the others are keyed by their own ids
            first_choice, other_choices = op.numpy_ops[0], op.numpy_ops[1:]
            Traceability._add_node(cluster, first_choice, node_id)
            for choice in other_choices:
                Traceability._add_node(cluster, choice, str(id(choice)))
            diagram.add_subgraph(cluster)
            return

        op_type = op.__class__.__name__
        if isinstance(op, ModelOp):
            # Model ops additionally display (and link to) the name of their model
            label = f"{op_type} ({FEID(id(op))}): {op.model.model_name}"
            model_text = NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) + NoEscape('}')
            model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'), text=model_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=op_type).dumps()}: {model_ref}"
        else:
            label = f"{op_type} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=op_type).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
예제 #6
0
def trace_model(model: Model, model_idx: int, model_fn: Any, optimizer_fn: Any, weights_path: Any) -> Model:
    """A function to add traceability information to an FE-compiled model.

    A summary table describing the model is built and stored on the model as `_fe_traceability_summary`, and the
    module level `fe_summary` function is bound to the model as a method.

    Args:
        model: The model to be made traceable.
        model_idx: Which of the return values from the `model_fn` is this model (or -1 if only a single return value).
        model_fn: The function used to generate this model.
        optimizer_fn: The thing used to define this model's optimizer. If this is a list, the entry at position
            `model_idx` is taken to be the optimizer of this model.
        weights_path: The path to the weights for this model.

    Returns:
        The `model`, but now with an fe_summary() method.
    """
    tables = {}
    description = {'definition': _trace_value(model_fn, tables, ret_ref=Flag())}
    if model_idx != -1:
        description['index'] = model_idx
    # Pick out the optimizer which belongs to this particular model before deciding whether to record it. The
    # previous condition (`a or b and c`, which parses as `a or (b and c)`) only inspected element 0 of a list when
    # the list was empty, which raised an IndexError, and it recorded a `None` optimizer for any non-empty list.
    if isinstance(optimizer_fn, list):
        optimizer = optimizer_fn[model_idx] if optimizer_fn else None
        has_optimizer = optimizer is not None
    else:
        optimizer = optimizer_fn
        has_optimizer = bool(optimizer_fn)
    if has_optimizer:
        description['optimizer'] = _trace_value(optimizer, tables, ret_ref=Flag())
    if weights_path:
        description['weights'] = _trace_value(weights_path, tables, ret_ref=Flag())
    fe_id = FEID(id(model))
    tbl = FeSummaryTable(name=model.model_name, fe_id=fe_id, target_type=type(model), **description)
    tables[fe_id] = tbl
    # Have to put this in a ChainMap b/c dict gets put into model._layers automatically somehow
    model._fe_traceability_summary = ChainMap(tables)

    # Use MethodType to bind the method to the class instance
    setattr(model, 'fe_summary', types.MethodType(fe_summary, model))
    return model
예제 #7
0
    def fix_split_traceabilty(cls, parent: 'FEDataset', children: List['FEDataset'],
                              fractions: Tuple[Union[float, int, Iterable[int]], ...]) -> None:
        """A method to fix traceability information after invoking the dataset .split() method.

        Note that the default implementation of the .split() function invokes this already, so this only needs to be
        invoked if you override the .split() method when defining a subclass (ex. BatchDataset).

        Nothing happens if the `parent` has no `_fe_traceability_summary` attribute. Otherwise a 'split' field is
        written into the summary table of every child and of the parent, with any previous 'split' text appended
        after a " | " separator.

        Args:
            parent: The parent dataset on which .split() was invoked.
            children: The datasets generated by performing the split.
            fractions: The fraction arguments used to generate the children (should be one-to-one with the children).
        """
        if not hasattr(parent, '_fe_traceability_summary'):
            return
        parent_id = FEID(id(parent))
        for child, frac in zip(children, fractions):
            # noinspection PyProtectedMember
            summary_tables = deepcopy(child._fe_traceability_summary)
            # Update the ID if necessary
            child_id = FEID(id(child))
            if child_id in summary_tables:
                child_table = summary_tables[child_id]
            else:
                # The child was created without invoking its __init__ method, so its internal summary will have the
                # wrong id
                child_table = summary_tables.pop(parent_id)
                child_table.fe_id = child_id
                summary_tables[child_id] = child_table
            history = child_table.fields.get('split', '')
            if history:
                history = f" | {history}"
            param = f"range({frac.start}, {frac.stop}, {frac.step})" if isinstance(frac, range) else frac
            child_table.fields['split'] = f"Child with final size: {len(child)}, split param: {param}{history}"
            child._fe_traceability_summary = summary_tables
        # noinspection PyUnresolvedReferences
        parent_table = parent._fe_traceability_summary.get(parent_id)
        history = parent_table.fields.get('split', '')
        if history:
            history = f" | {history}"
        parent_table.fields['split'] = f"Parent with final size: {len(parent)}{history}"
예제 #8
0
    def _draw_subgraph(diagram: pydot.Dot, label_last_seen: DefaultDict[str, str], subgraph_name: str,
                       subgraph_ops: List[Union[Op, Trace]]) -> None:
        """Draw a subgraph of ops into an existing `diagram`.

        Each op becomes a node inside a dashed cluster. Edges are drawn onto the `diagram` from whichever node last
        produced each of the op's inputs, and `label_last_seen` is updated with the op's outputs.

        Args:
            diagram: The diagram to be appended to.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            subgraph_name: The name to be associated with this subgraph.
            subgraph_ops: The ops to be wrapped in this subgraph.
        """
        cluster = pydot.Cluster(style='dashed', graph_name=subgraph_name)
        cluster.set('label', subgraph_name)
        cluster.set('labeljust', 'l')
        previous_id = None  # Node id of the op drawn during the previous iteration
        for position, op in enumerate(subgraph_ops):
            node_id = str(id(op))
            op_type = op.__class__.__name__
            if isinstance(op, ModelOp):
                # Model ops additionally display (and link to) the name of their model
                label = f"{op_type} ({FEID(id(op))}): {op.model.model_name}"
                model_text = NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) + NoEscape('}')
                model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'), text=model_text).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=op_type).dumps()}: {model_ref}"
            else:
                label = f"{op_type} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=op_type).dumps()
            cluster.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
            # Group the input keys by the node which produced them, so that each source gets a single edge
            keys_by_source = defaultdict(list)
            for key in op.inputs:
                if key != '*':
                    keys_by_source[label_last_seen[key]].append(key)
            for source, keys in keys_by_source.items():
                diagram.add_edge(pydot.Edge(src=source, dst=node_id, label=f" {', '.join(keys)} "))
            for key in op.outputs:
                label_last_seen[key] = node_id
            if isinstance(op, Trace) and position > 0:
                # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
                diagram.add_edge(pydot.Edge(src=previous_id, dst=node_id, style='invis'))
            previous_id = node_id
        diagram.add_subgraph(cluster)
예제 #9
0
    def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
        """Draw a summary diagram of the FastEstimator Ops / Traces.

        The dataset for the given `mode` becomes the root node of the diagram, followed by one subgraph each for
        the pipeline ops, the network ops, and the traces which are active for the given `mode` and `epoch`.

        Args:
            mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
            epoch: The epoch to summarize.

        Returns:
            A pydot digraph representing the execution flow.
        """
        dataset = self.system.pipeline.data[mode]
        if isinstance(dataset, Scheduler):
            dataset = dataset.get_current_value(epoch)
        # Pipeline ops are only displayed when the data source is an FEDataset
        if isinstance(dataset, FEDataset):
            pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch)
        else:
            pipe_ops = []
        net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
        traces = get_current_items(self.system.traces, run_modes=mode, epoch=epoch)

        diagram = pydot.Dot()
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='record')
        root_id = str(id(dataset))
        dataset_type = dataset.__class__.__name__
        root_node = pydot.Node(root_id,
                               label=f'{dataset_type} ({FEID(id(dataset))})',
                               texlbl=HrefFEID(FEID(id(dataset)), name=dataset_type).dumps())
        diagram.add_node(root_node)
        # Where was this key last generated (keys which no op has produced yet are attributed to the dataset)
        label_last_seen = defaultdict(lambda: root_id)

        for section_name, section_ops in (('Pipeline', pipe_ops), ('Network', net_ops), ('Traces', traces)):
            self._draw_subgraph(diagram, label_last_seen, section_name, section_ops)
        return diagram
예제 #10
0
    def _add_node(progenitor: pydot.Dot,
                  diagram: Union[pydot.Dot, pydot.Cluster],
                  op: Union[Op, Trace, Any],
                  label_last_seen: DefaultDict[str, str],
                  edges: bool = True) -> None:
        """Draw a node onto a diagram based on a given op.

        Meta ops (Sometimes, OneOf, Fuse, Repeat) which wrap other ops are expanded into a cluster / subgraph
        containing the wrapped op(s); anything else is drawn as a single node.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            edges: Whether to write Edges to/from this Node.
        """
        node_id = str(id(op))

        if isinstance(op, (Sometimes, SometimesT)) and op.op:
            cluster = pydot.Cluster(style='dotted', color='red', graph_name=node_id)
            cluster.set('label', f'Sometimes ({op.prob}):')
            cluster.set('labeljust', 'l')
            # The sources of the extra inputs are looked up before the wrapped op is drawn
            keys_by_source = defaultdict(list)
            for key in op.extra_inputs or ():
                if key != '*':
                    keys_by_source[label_last_seen[key]].append(key)
            Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
            diagram.add_subgraph(cluster)
            # The extra input edges point at the first node of the cluster, with the cluster itself as the head
            target_id = Traceability._get_all_nodes(cluster)[0].get_name()
            for source, keys in keys_by_source.items():
                progenitor.add_edge(
                    pydot.Edge(src=source, dst=target_id, lhead=cluster.get_name(), label=f" {', '.join(keys)} "))
            return

        if isinstance(op, (OneOf, OneOfT)) and op.ops:
            cluster = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=node_id)
            cluster.set('label', 'One Of:')
            cluster.set('labeljust', 'l')
            # Only the first of the choices is drawn with edges
            for position, choice in enumerate(op.ops):
                Traceability._add_node(progenitor, cluster, choice, label_last_seen, edges=(position == 0))
            diagram.add_subgraph(cluster)
            return

        if isinstance(op, (Fuse, FuseT)) and op.ops:
            Traceability._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops)
            return

        if isinstance(op, (Repeat, RepeatT)) and op.op:
            cluster = pydot.Cluster(style='dotted', color='darkgreen', graph_name=node_id)
            cluster.set('label', 'Repeat:')
            cluster.set('labeljust', 'l')
            # The repeat node displays the repeat count when it is an integer, and "?" otherwise
            count_label = f'{op.repeat}' if isinstance(op.repeat, int) else "?"
            cluster.add_node(pydot.Node(node_id, label=count_label, shape='doublecircle', width=0.1))
            # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
            progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
            Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
            # Add repeat edges (these are looked up after the wrapped op has been drawn)
            keys_by_source = defaultdict(list)
            for key in op.outputs:
                if key in op.inputs and key not in op.repeat_inputs:
                    keys_by_source[label_last_seen[key]].append(key)
            for key in op.repeat_inputs:
                keys_by_source[label_last_seen[key]].append(key)
            for source, keys in keys_by_source.items():
                progenitor.add_edge(
                    pydot.Edge(src=source, dst=node_id, constraint=False, label=f" {', '.join(keys)} "))
            diagram.add_subgraph(cluster)
            return

        op_type = op.__class__.__name__
        if isinstance(op, ModelOp):
            # Model ops additionally display (and link to) the name of their model
            label = f"{op_type} ({FEID(id(op))}): {op.model.model_name}"
            model_text = NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) + NoEscape('}')
            model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'), text=model_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=op_type).dumps()}: {model_ref}"
        else:
            label = f"{op_type} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=op_type).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        if edges and isinstance(op, (Op, Trace)):
            # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
            Traceability._add_edge(progenitor, op, label_last_seen)
예제 #11
0
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        Creates a "Models" section containing one subsection per TensorFlow / PyTorch model of the network (sorted
        by model name). Each subsection holds a text summary of the model and, if the model graph could be
        rendered to a pdf file, a figure displaying that file. A warning is printed when the rendering fails.
        """
        with self.doc.create(Section("Models")):
            sorted_models = humansorted(self.system.network.models, key=lambda m: m.model_name)
            for model in sorted_models:
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary_lines = []
                        model.summary(line_length=92, print_fn=lambda line: summary_lines.append(line))
                        self.doc.append(Verbatim("\n".join(summary_lines)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            pdf_name = "{}_{}.pdf".format(self.report_name, model.model_name)
                            file_path = os.path.join(self.resource_dir, pdf_name)
                            dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                            # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                            # set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(file_path, format='pdf')
                        except Exception:
                            # Best effort: a model which cannot be drawn is documented without a figure
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if not hasattr(model, 'fe_input_spec'):
                            file_path = None
                            self.doc.append("This model was not used by the Network during training.")
                        else:
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            # With more than one device the `module` attribute of the model is summarized instead
                            target = model.module if self.system.num_devices > 1 else model
                            self.doc.append(Verbatim(pms.summary(target, inputs, print_summary=False)))
                            with self.doc.create(Center()):
                                self.doc.append(HrefFEID(FEID(id(model)), model.model_name))
                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display', MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(target, inputs)
                                graph = graph.build_dot()
                                # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(rankdir='TB')
                                # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                                # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                                # height, so set the limit lower (100 inches) to leave some wiggle room.
                                graph.attr(size="100,100")
                                graph.attr(margin='0')
                                file_path = graph.render(filename="{}_{}".format(self.report_name, model.model_name),
                                                         directory=self.resource_dir,
                                                         format='pdf',
                                                         cleanup=True)
                            except Exception:
                                # Best effort: a model which cannot be drawn is documented without a figure
                                file_path = None
                                print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                    model.model_name))
                            finally:
                                # Always hand the user back their original matplotlib backend
                                matplotlib.use(old_backend)
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            marker = Marker(name=str(FEID(id(model))), prefix="model")
                            fig.append(Label(marker))
                            image_path = os.path.relpath(file_path, start=self.save_dir)
                            fig.add_image(image_path,
                                          width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                            caption = HrefFEID(FEID(id(model)), model.model_name).dumps()
                            fig.add_caption(NoEscape(caption))
예제 #12
0
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value to a FESummaryTable table representation.

    Simple values (strings, numbers, None, existing references / containers) are returned more or less directly,
    collections are converted element-wise into `PyContainer`s, and everything else (functions, classes, models,
    tensors, arbitrary objects, ...) is registered in `tables` under the FEID of the object and returned as an
    `HrefFEID` pointing at that table.

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New tables
            are added to this dictionary in place.
        ret_ref: A flag to indicate that _trace_value is returning a reference (this is used to figure out whether
            functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        An FESummaryTable representation of the input. Depending on `inp` this is an escaped string, the input
        itself, an `HrefFEID`, a `ContainerList`, a `PyContainer`, or a plain dictionary (for `_PartialBind` and
        `inspect.BoundArguments` inputs).
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have its summary instantiated yet, so it will
        # fall through to the class check at the end to get its id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        # Describe a bound method as "<owner>.<method name>"
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                var_names = code.co_varnames
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some other
                # function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        # Trace the arguments with a private flag so we can tell whether any of them required a reference
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        args = inp.arguments
        args.pop('self', None)
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple(_trace_value(x, tables, ret_ref, wrap_str) for x in inp),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data={_trace_value(x, tables, ret_ref, wrap_str) for x in inp},
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        # Keys follow the caller's quoting preference, but string values are always quoted
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        # Remember the original type, since `inp` may be replaced by a numpy array below
        inp_type = type(inp)
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, torch.Tensor):
                # The numpy() result must be kept: discarding it left `inp` as a torch tensor, so torch inputs were
                # summarized differently than TF ones (torch.Size shapes and 'tensor(...)' value strings).
                inp = inp.cpu().detach().numpy()
            elif isinstance(inp, (tf.Tensor, tf.Variable)) and inp.dtype != tf.dtypes.variant:
                inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            # Only scalars and short vectors are small enough to print their values
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                public_attrs = {k: v for k, v in inp.__dict__.items() if not k.startswith("_")}
                kwargs = _trace_value(public_attrs, tables, ret_ref, wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            # Replace the placeholder (if any) with the complete table
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        # Defensive fallback: every Python object has a __class__, so this is not expected to be reached
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")
예제 #13
0
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        Every Keras or PyTorch model in `self.system.network.models` gets its own subsection (sorted by model name).
        Keras models receive a text summary plus a diagram from `tf.keras.utils.plot_model`. PyTorch models which have
        an `fe_input_spec` receive a text summary plus a diagram built with hiddenlayer; PyTorch models without one
        only receive a note that they were not used during training. A figure is appended only when a diagram file
        was successfully produced for the current model.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models,
                                     key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                # Reset for every model: a branch below may not produce a diagram, and without this reset the
                # figure check would either raise UnboundLocalError or reuse the previous model's image.
                file_path = None
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary_lines = []
                        model.summary(line_length=92,
                                      print_fn=summary_lines.append)
                        self.doc.append(Verbatim("\n".join(summary_lines)))
                        with self.doc.create(Center()):
                            self.doc.append(
                                HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(
                                self.figure_dir,
                                f"FE_Model_{model.model_name}.pdf")
                            tf.keras.utils.plot_model(model,
                                                      to_file=file_path,
                                                      show_shapes=True,
                                                      expand_nested=True)
                            # TODO - cap output image size like in the pytorch implementation in case of huge network
                            # TODO - save raw .dot file in case system lacks graphviz
                        except Exception:
                            # Visualization is best-effort: warn and carry on without a figure
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            self.doc.append(
                                Verbatim(pms.summary(model, inputs)))
                            with self.doc.create(Center()):
                                self.doc.append(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display',
                                                       MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(model, inputs)
                                graph = graph.build_dot()
                                graph.attr(
                                    rankdir='TB'
                                )  # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(
                                    size="200,200"
                                )  # LaTeX \maxdim is around 575cm (226 inches)
                                graph.attr(margin='0')
                                # TODO - save raw .dot file in case system lacks graphviz
                                file_path = graph.render(
                                    filename=f"FE_Model_{model.model_name}",
                                    directory=self.figure_dir,
                                    format='pdf',
                                    cleanup=True)
                            except Exception:
                                # Visualization is best-effort: warn and carry on without a figure
                                file_path = None
                                print(
                                    f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                                )
                            finally:
                                # Always restore the backend which was active before the diagram was rendered
                                matplotlib.use(old_backend)
                        else:
                            self.doc.append(
                                "This model was not used by the Network during training."
                            )
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(
                                Label(
                                    Marker(name=str(FEID(id(model))),
                                           prefix="model")))
                            fig.add_image(
                                os.path.relpath(file_path,
                                                start=self.save_dir),
                                width=NoEscape(
                                    r'1.0\textwidth,height=0.95\textheight,keepaspectratio'
                                ))
                            fig.add_caption(
                                NoEscape(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name).dumps()))