Beispiel #1
0
    def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
        """Render `op` onto `diagram`, expanding meta-ops into dotted clusters.

        A `Sometimes` op with a truthy `numpy_op`, or a `OneOf` op with a truthy `numpy_ops`, is drawn as a cluster
        holding the wrapped op(s). Every other op becomes a single node whose plain label and LaTeX label (`texlbl`)
        show its class name and FEID, followed by the model name when the op is a `ModelOp`.

        Args:
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            node_id: The name given to the created node. For a cluster, it is handed to the first wrapped op.
        """
        if isinstance(op, Sometimes) and op.numpy_op:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', f'Sometimes ({op.prob}):')
            cluster.set('labeljust', 'r')
            Traceability._add_node(cluster, op.numpy_op, node_id)
            diagram.add_subgraph(cluster)
            return

        if isinstance(op, OneOf) and op.numpy_ops:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', 'One Of:')
            cluster.set('labeljust', 'r')
            # The first choice inherits the caller-supplied id; the remaining choices are named after themselves
            for position, choice in enumerate(op.numpy_ops):
                choice_id = node_id if position == 0 else str(id(choice))
                Traceability._add_node(cluster, choice, choice_id)
            diagram.add_subgraph(cluster)
            return

        # Plain op: a single node
        class_name = op.__class__.__name__
        if isinstance(op, ModelOp):
            model_name = op.model.model_name
            label = f"{class_name} ({FEID(id(op))}): {model_name}"
            marker = Marker(name=str(model_name), prefix='subsec')
            link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
            model_ref = Hyperref(marker, text=link_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
        else:
            label = f"{class_name} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
    def _draw_subgraph(diagram: pydot.Dot,
                       label_last_seen: DefaultDict[str, str],
                       subgraph_name: str,
                       subgraph_ops: List[Union[Op, Trace]]) -> None:
        """Draw a dashed cluster of ops into an existing `diagram`, wiring each op to the producers of its inputs.

        Args:
            diagram: The diagram to be appended to. Edges are written directly onto it.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
                It is updated in-place with the outputs of every op drawn here.
            subgraph_name: The name (and visible label) to be associated with this subgraph.
            subgraph_ops: The ops to be wrapped in this subgraph.
        """
        cluster = pydot.Cluster(style='dashed', graph_name=subgraph_name)
        cluster.set('label', subgraph_name)
        cluster.set('labeljust', 'l')

        previous_id = None  # Node id of the op drawn in the previous iteration (None for the first op)
        for op in subgraph_ops:
            node_id = str(id(op))
            class_name = op.__class__.__name__

            # Node
            if isinstance(op, ModelOp):
                model_name = op.model.model_name
                label = f"{class_name} ({FEID(id(op))}): {model_name}"
                marker = Marker(name=str(model_name), prefix='subsec')
                link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
                model_ref = Hyperref(marker, text=link_text).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
            else:
                label = f"{class_name} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
            cluster.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))

            # Incoming edges: group the consumed keys by the node which last produced them ('*' is ignored)
            keys_by_source = defaultdict(list)
            for key in op.inputs:
                if key != '*':
                    keys_by_source[label_last_seen[key]].append(key)
            for source_id, keys in keys_by_source.items():
                diagram.add_edge(pydot.Edge(src=source_id, dst=node_id, label=f" {', '.join(keys)} "))

            # This op is now the most recent producer of all of its outputs
            for key in op.outputs:
                label_last_seen[key] = node_id

            if previous_id is not None and isinstance(op, Trace):
                # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
                diagram.add_edge(pydot.Edge(src=previous_id, dst=node_id, style='invis'))
            previous_id = node_id

        diagram.add_subgraph(cluster)
    def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
        """Build a diagram of the dataset, pipeline ops, network ops, and traces active for a mode and epoch.

        Args:
            mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
            epoch: The epoch to summarize.

        Returns:
            A pydot digraph representing the execution flow.
        """
        dataset = self.system.pipeline.data[mode]
        if isinstance(dataset, Scheduler):
            dataset = dataset.get_current_value(epoch)

        # Pipeline ops are only drawn when the data source is an FEDataset
        if isinstance(dataset, FEDataset):
            pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch)
        else:
            pipe_ops = []
        net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
        traces = get_current_items(self.system.traces, run_modes=mode, epoch=epoch)

        dataset_id = str(id(dataset))
        dataset_name = dataset.__class__.__name__

        diagram = pydot.Dot()
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='record')
        dataset_node = pydot.Node(dataset_id,
                                  label=f'{dataset_name} ({FEID(id(dataset))})',
                                  texlbl=HrefFEID(FEID(id(dataset)), name=dataset_name).dumps())
        diagram.add_node(dataset_node)

        # Where was this key last generated (keys not yet written by any op are attributed to the dataset node)
        label_last_seen = defaultdict(lambda: dataset_id)

        sections = (('Pipeline', pipe_ops), ('Network', net_ops), ('Traces', traces))
        for section_name, section_ops in sections:
            self._draw_subgraph(diagram, label_last_seen, section_name, section_ops)
        return diagram
Beispiel #4
0
    def _add_node(progenitor: pydot.Dot,
                  diagram: Union[pydot.Dot, pydot.Cluster],
                  op: Union[Op, Trace, Any],
                  label_last_seen: DefaultDict[str, str],
                  edges: bool = True) -> None:
        """Draw `op` onto `diagram`, expanding Sometimes / OneOf / Fuse / Repeat meta-ops into nested structures.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            edges: Whether to write Edges to/from this Node. Only consulted when `op` is drawn as a plain node.
        """
        node_id = str(id(op))

        if isinstance(op, (Sometimes, SometimesT)) and op.op:
            cluster = pydot.Cluster(style='dotted', color='red', graph_name=str(id(op)))
            cluster.set('label', f'Sometimes ({op.prob}):')
            cluster.set('labeljust', 'l')
            # Look up the producers of the extra inputs before recursing, since the recursive call is handed
            # label_last_seen as well
            extra_sources = defaultdict(list)
            for key in op.extra_inputs or ():
                if key != '*':
                    extra_sources[label_last_seen[key]].append(key)
            Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
            diagram.add_subgraph(cluster)
            # Edges target the first node found in the cluster, with lhead naming the cluster itself
            entry_id = Traceability._get_all_nodes(cluster)[0].get_name()
            for source_id, keys in extra_sources.items():
                progenitor.add_edge(
                    pydot.Edge(src=source_id, dst=entry_id, lhead=cluster.get_name(), label=f" {', '.join(keys)} "))
            return

        if isinstance(op, (OneOf, OneOfT)) and op.ops:
            cluster = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=str(id(op)))
            cluster.set('label', 'One Of:')
            cluster.set('labeljust', 'l')
            # Only the first choice gets edges drawn
            for position, choice in enumerate(op.ops):
                Traceability._add_node(progenitor, cluster, choice, label_last_seen, edges=(position == 0))
            diagram.add_subgraph(cluster)
            return

        if isinstance(op, (Fuse, FuseT)) and op.ops:
            Traceability._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops)
            return

        if isinstance(op, (Repeat, RepeatT)) and op.op:
            cluster = pydot.Cluster(style='dotted', color='darkgreen', graph_name=str(id(op)))
            cluster.set('label', 'Repeat:')
            cluster.set('labeljust', 'l')
            # A small counter node shows the repeat count ('?' when it is not an int)
            counter_label = f'{op.repeat if isinstance(op.repeat, int) else "?"}'
            cluster.add_node(pydot.Node(node_id, label=counter_label, shape='doublecircle', width=0.1))
            # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
            progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
            Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
            # Add repeat edges (computed after the wrapped op has been drawn)
            repeat_sources = defaultdict(list)
            for key in op.outputs:
                if key in op.inputs and key not in op.repeat_inputs:
                    repeat_sources[label_last_seen[key]].append(key)
            for key in op.repeat_inputs:
                repeat_sources[label_last_seen[key]].append(key)
            for source_id, keys in repeat_sources.items():
                progenitor.add_edge(
                    pydot.Edge(src=source_id, dst=node_id, constraint=False, label=f" {', '.join(keys)} "))
            diagram.add_subgraph(cluster)
            return

        # Anything else is drawn as a single node
        class_name = op.__class__.__name__
        if isinstance(op, ModelOp):
            model_name = op.model.model_name
            label = f"{class_name} ({FEID(id(op))}): {model_name}"
            marker = Marker(name=str(model_name), prefix='subsec')
            link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
            model_ref = Hyperref(marker, text=link_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
        else:
            label = f"{class_name} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        if edges and isinstance(op, (Op, Trace)):
            # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
            Traceability._add_edge(progenitor, op, label_last_seen)
Beispiel #5
0
    def _document_models(self) -> None:
        """Write a "Models" section containing a text summary and, when possible, a rendered graph of each model.

        Only `tf.keras.Model` and `torch.nn.Module` instances from `self.system.network.models` are documented. If
        rendering the graph fails, a warning is printed and the figure is left out.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    file_path = None  # Only set once a visualization has been written successfully
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary_lines = []
                        model.summary(line_length=92, print_fn=summary_lines.append)
                        self.doc.append(Verbatim("\n".join(summary_lines)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            pdf_path = os.path.join(self.resource_dir,
                                                    "{}_{}.pdf".format(self.report_name, model.model_name))
                            dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                            # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                            # set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(pdf_path, format='pdf')
                            file_path = pdf_path
                        except Exception:
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if not hasattr(model, 'fe_input_spec'):
                            self.doc.append("This model was not used by the Network during training.")
                        else:
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            # With more than one device, the summary and graph are built from model.module
                            target = model.module if self.system.num_devices > 1 else model
                            self.doc.append(Verbatim(pms.summary(target, inputs, print_summary=False)))
                            with self.doc.create(Center()):
                                self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display', MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    hl_graph = hl.build_graph(target, inputs)
                                dot_graph = hl_graph.build_dot()
                                # Switch it to Top-to-Bottom instead of Left-to-Right
                                dot_graph.attr(rankdir='TB')
                                # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                                # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                                # height, so set the limit lower (100 inches) to leave some wiggle room.
                                dot_graph.attr(size="100,100")
                                dot_graph.attr(margin='0')
                                file_path = dot_graph.render(filename="{}_{}".format(self.report_name,
                                                                                     model.model_name),
                                                             directory=self.resource_dir,
                                                             format='pdf',
                                                             cleanup=True)
                            except Exception:
                                print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                    model.model_name))
                            finally:
                                # Always hand the backend back, whether or not the rendering worked
                                matplotlib.use(old_backend)

                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                            fig.add_image(os.path.relpath(file_path, start=self.save_dir),
                                          width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                            fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value into a representation which can be placed inside an FeSummaryTable.

    Strings and primitives are returned (almost) directly, containers are converted element by element, and objects
    which deserve their own summary get an entry in `tables` and are replaced by a hyperlink to that entry.

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New entries
            are written into it in-place.
        ret_ref: A flag which is set to true when this call creates a reference to a table (this is used to figure
            out whether functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        A representation of `inp`: a (LaTeX-escaped) string or primitive, a `PyContainer` / `ContainerList` of
        converted elements, a plain dict (for `_PartialBind` and `inspect.BoundArguments` inputs), the wrapped value
        of a `_VarWrap`, or an `HrefFEID` pointing at an entry of `tables`.
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have its summary instantiated yet, so it will
        # fall through to the class check at the end to get its id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        # Bound method: describe the owning object, followed by ".method_name"
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                var_names = code.co_varnames
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some other
                # function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        # A private flag records whether any of the arguments needed a table of its own
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        args = inp.arguments
        # NOTE(review): this removes 'self' from the mapping held by `inp` itself, not from a copy
        args.pop('self', None)
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple(_trace_value(x, tables, ret_ref, wrap_str) for x in inp),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data={_trace_value(x, tables, ret_ref, wrap_str) for x in inp},
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        # Keys follow the caller's wrap_str setting, but values are always wrapped
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        inp_type = type(inp)  # Remember the original type, since inp may be converted to numpy below
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, (tf.Tensor, torch.Tensor, tf.Variable)):
                if isinstance(inp, torch.Tensor):
                    # Tensor.numpy() returns a new array rather than converting in-place, so its result must be kept.
                    # This way torch tensors are described from a numpy array just like tf tensors are.
                    inp = inp.cpu().detach().numpy()
                # In the elif here we're sure to be tf
                elif inp.dtype != tf.dtypes.variant:
                    inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            # Only scalars and short vectors have their values written out
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                kwargs = _trace_value({k: v
                                       for k, v in inp.__dict__.items() if not k.startswith("_")},
                                      tables,
                                      ret_ref,
                                      wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            # Replace the placeholder (if any) with the complete table
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        # Fallback for anything without a __class__ attribute
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        Creates a "Models" section holding one subsection per TensorFlow / PyTorch model found in
        `self.system.network.models`, ordered by `humansorted` on the model name. Models of any other type are
        skipped. Each subsection receives a text summary and, if the architecture diagram could be rendered into
        `self.figure_dir`, a figure embedding that diagram. PyTorch models lacking an `fe_input_spec` attribute only
        get a short note, since no dummy input is available to trace them.

        Diagram rendering is best-effort: any exception raised while drawing is reported with a printed warning and
        the figure is simply omitted.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                # Reset for every model. Without this, a PyTorch model that has no 'fe_input_spec' would either hit
                # an UnboundLocalError at the `if file_path` check below (when it is the first model processed) or
                # silently re-use the diagram path left over from the previous model.
                file_path = None
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary = []
                        model.summary(line_length=92, print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(self.figure_dir, f"FE_Model_{model.model_name}.pdf")
                            tf.keras.utils.plot_model(model, to_file=file_path, show_shapes=True, expand_nested=True)
                            # TODO - cap output image size like in the pytorch implementation in case of huge network
                            # TODO - save raw .dot file in case system lacks graphviz
                        except Exception:
                            # The path was assigned before plotting was attempted, so clear it again on failure
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            self.doc.append(Verbatim(pms.summary(model, inputs)))
                            with self.doc.create(Center()):
                                self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display', MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(model, inputs)
                                graph = graph.build_dot()
                                graph.attr(rankdir='TB')  # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(size="200,200")  # LaTeX \maxdim is around 575cm (226 inches)
                                graph.attr(margin='0')
                                # TODO - save raw .dot file in case system lacks graphviz
                                file_path = graph.render(filename=f"FE_Model_{model.model_name}",
                                                         directory=self.figure_dir,
                                                         format='pdf',
                                                         cleanup=True)
                            except Exception:
                                file_path = None
                                print(
                                    f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                                )
                            finally:
                                # Always hand the backend back to whatever the user had selected
                                matplotlib.use(old_backend)
                        else:
                            self.doc.append("This model was not used by the Network during training.")
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                            # The image path is written relative to the directory the document is saved in
                            fig.add_image(
                                os.path.relpath(file_path, start=self.save_dir),
                                width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                            fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))