def _document_fe_graph(self) -> None:
    """Write the FE execution graphs into the traceability document.

    A "FastEstimator Architecture" section is created with one subsection per key of the pipeline data. Inside each
    of those, every epoch returned by `get_signature_epochs` which is also reported by the pipeline as having data
    gets its own labelled subsubsection containing the diagram for that (mode, epoch) pair, converted from dot to
    LaTeX and placed inside an AdjustBox limited to the text width and 0.9 of the text height.
    """
    float_barrier = r'\FloatBarrier'
    # The option name contains a space, so it can only be handed over through dictionary unpacking
    box_options = {'max width': r'\textwidth, max height=0.9\textheight'}
    with self.doc.create(Section("FastEstimator Architecture")):
        for mode in self.system.pipeline.data.keys():
            self.doc.append(NoEscape(float_barrier))
            with self.doc.create(Subsection(mode.capitalize())):
                pipeline_items = self.system.pipeline.get_scheduled_items(mode)
                network_items = self.system.network.get_scheduled_items(mode)
                scheduled_items = pipeline_items + network_items + self.system.traces
                signature_epochs = get_signature_epochs(scheduled_items,
                                                        total_epochs=self.system.epoch_idx,
                                                        mode=mode)
                epochs_with_data = self.system.pipeline.get_epochs_with_data(total_epochs=self.system.epoch_idx,
                                                                             mode=mode)
                # Lazily drop the signature epochs for which the pipeline has no data
                drawable_epochs = (candidate for candidate in signature_epochs if candidate in epochs_with_data)
                for epoch in drawable_epochs:
                    self.doc.append(NoEscape(float_barrier))
                    marker = Marker(name=f"{mode}{epoch}", prefix="ssubsec")
                    heading = Subsubsection(f"Epoch {epoch}", label=Label(marker))
                    with self.doc.create(heading):
                        dot_source = self._draw_diagram(mode, epoch).to_string()
                        latex = d2t.dot2tex(dot_source, figonly=True)
                        box_args = Arguments(**box_options)
                        # The option string is raw LaTeX and must be emitted unescaped
                        box_args.escape = False
                        with self.doc.create(AdjustBox(arguments=box_args)) as box:
                            box.append(NoEscape(latex))
    def _add_node(progenitor: pydot.Dot,
                  diagram: Union[pydot.Dot, pydot.Cluster],
                  op: Union[Op, Trace, Any],
                  label_last_seen: DefaultDict[str, str],
                  edges: bool = True) -> None:
        """Render a single op (or trace) as a node, or as a cluster of nodes, on a diagram.

        Sometimes, OneOf, Fuse and Repeat wrappers which actually hold ops are expanded recursively; anything else
        becomes one plain node labelled with its class name and FEID (plus the model name for a ModelOp).

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            edges: Whether to write Edges to/from this Node.
        """
        node_id = str(id(op))

        def new_cluster(color, title):
            """Create a dotted cluster named after `op` carrying a left-justified `title`."""
            cluster = pydot.Cluster(style='dotted', color=color, graph_name=node_id)
            cluster.set('label', title)
            cluster.set('labeljust', 'l')
            return cluster

        if isinstance(op, (Sometimes, SometimesT)) and op.op:
            wrapper = new_cluster('red', f'Sometimes ({op.prob}):')
            # Group the extra input keys by the node which last produced them (must happen before recursing)
            edge_srcs = defaultdict(list)
            for key in op.extra_inputs or ():
                if key != '*':
                    edge_srcs[label_last_seen[key]].append(key)
            Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
            diagram.add_subgraph(wrapper)
            # Edges need a concrete node as destination; 'lhead' then attaches them to the cluster itself
            dst_id = Traceability._get_all_nodes(wrapper)[0].get_name()
            for src, labels in edge_srcs.items():
                edge = pydot.Edge(src=src, dst=dst_id, lhead=wrapper.get_name(), label=f" {', '.join(labels)} ")
                progenitor.add_edge(edge)
        elif isinstance(op, (OneOf, OneOfT)) and op.ops:
            wrapper = new_cluster('darkorchid4', 'One Of:')
            # Only the first alternative is allowed to draw edges
            for position, sub_op in enumerate(op.ops):
                Traceability._add_node(progenitor, wrapper, sub_op, label_last_seen, edges=(position == 0))
            diagram.add_subgraph(wrapper)
        elif isinstance(op, (Fuse, FuseT)) and op.ops:
            Traceability._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops)
        elif isinstance(op, (Repeat, RepeatT)) and op.op:
            wrapper = new_cluster('darkgreen', 'Repeat:')
            repeat_count = op.repeat if isinstance(op.repeat, int) else "?"
            counter_node = pydot.Node(node_id, label=f'{repeat_count}', shape='doublecircle', width=0.1)
            wrapper.add_node(counter_node)
            # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
            self_loop = pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300')
            progenitor.add_edge(self_loop)
            Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
            # Add repeat edges: outputs which are fed back in as inputs first, followed by the repeat inputs
            fed_back = [key for key in op.outputs if key in op.inputs and key not in op.repeat_inputs]
            edge_srcs = defaultdict(list)
            for key in [*fed_back, *op.repeat_inputs]:
                edge_srcs[label_last_seen[key]].append(key)
            for src, labels in edge_srcs.items():
                edge = pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(labels)} ")
                progenitor.add_edge(edge)
            diagram.add_subgraph(wrapper)
        else:
            op_name = op.__class__.__name__
            if isinstance(op, ModelOp):
                model_name = op.model.model_name
                label = f"{op_name} ({FEID(id(op))}): {model_name}"
                marker = Marker(name=str(model_name), prefix='subsec')
                link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
                model_ref = Hyperref(marker, text=link_text).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=op_name).dumps()}: {model_ref}"
            else:
                label = f"{op_name} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=op_name).dumps()
            diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
            # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
            if edges and isinstance(op, (Op, Trace)):
                Traceability._add_edge(progenitor, op, label_last_seen)
    def _document_models(self) -> None:
        """Write a "Models" section with a text summary and, where possible, a rendered graph for each model.

        The network's models are visited in human-sorted order of their names; anything which is neither a Keras
        model nor a torch module is skipped. If a graph cannot be rendered, a warning is printed and the figure is
        simply left out.
        """
        with self.doc.create(Section("Models")):
            ordered_models = humansorted(self.system.network.models, key=lambda m: m.model_name)
            for model in ordered_models:
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary_lines = []
                        model.summary(line_length=92, print_fn=summary_lines.append)
                        self.doc.append(Verbatim("\n".join(summary_lines)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            pdf_name = f"{self.report_name}_{model.model_name}.pdf"
                            file_path = os.path.join(self.resource_dir, pdf_name)
                            dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                            # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                            # set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(file_path, format='pdf')
                        except Exception:
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability")
                    elif isinstance(model, torch.nn.Module):
                        if not hasattr(model, 'fe_input_spec'):
                            file_path = None
                            self.doc.append("This model was not used by the Network during training.")
                        else:
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            # With more than one device the model's `.module` attribute is summarized instead
                            target = model.module if self.system.num_devices > 1 else model
                            self.doc.append(Verbatim(pms.summary(target, inputs, print_summary=False)))
                            with self.doc.create(Center()):
                                self.doc.append(HrefFEID(FEID(id(model)), model.model_name))
                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                for module_name in ('IPython', 'IPython.display'):
                                    sys.modules.setdefault(module_name, MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    hl_graph = hl.build_graph(target, inputs)
                                dot = hl_graph.build_dot()
                                # Switch it to Top-to-Bottom instead of Left-to-Right
                                dot.attr(rankdir='TB')
                                # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                                # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                                # height, so set the limit lower (100 inches) to leave some wiggle room.
                                dot.attr(size="100,100")
                                dot.attr(margin='0')
                                file_path = dot.render(filename=f"{self.report_name}_{model.model_name}",
                                                       directory=self.resource_dir,
                                                       format='pdf',
                                                       cleanup=True)
                            except Exception:
                                file_path = None
                                print(f"FastEstimator-Warn: Model {model.model_name} could not be visualized by "
                                      "Traceability")
                            finally:
                                # Hand the backend back regardless of whether the rendering succeeded
                                matplotlib.use(old_backend)
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            figure_marker = Marker(name=str(FEID(id(model))), prefix="model")
                            fig.append(Label(figure_marker))
                            relative_path = os.path.relpath(file_path, start=self.save_dir)
                            image_options = NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio')
                            fig.add_image(relative_path, width=image_options)
                            caption = HrefFEID(FEID(id(model)), model.model_name).dumps()
                            fig.add_caption(NoEscape(caption))
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        A "Models" section is created with one subsection per model of the network, visited in human-sorted order
        of the model names. Objects which are neither a Keras model nor a torch module are skipped. For each model a
        text summary is written, followed by a figure of the model graph whenever one could be rendered. A torch
        model without an `fe_input_spec` attribute only gets a note saying that it was not used during training.

        Rendering problems are not fatal: they are reported with a printed warning and the figure is omitted.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models,
                                     key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    # Start every model without a figure. Otherwise a torch model lacking 'fe_input_spec' would
                    # either hit an unbound 'file_path' below (first model) or silently re-use the figure of the
                    # model documented before it.
                    file_path = None
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary = []
                        model.summary(line_length=92,
                                      print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(
                                HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(
                                self.figure_dir,
                                f"FE_Model_{model.model_name}.pdf")
                            tf.keras.utils.plot_model(model,
                                                      to_file=file_path,
                                                      show_shapes=True,
                                                      expand_nested=True)
                            # TODO - cap output image size like in the pytorch implementation in case of huge network
                            # TODO - save raw .dot file in case system lacks graphviz
                        except Exception:
                            # Best effort: the path was already assigned above, so clear it to suppress the figure
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            self.doc.append(
                                Verbatim(pms.summary(model, inputs)))
                            with self.doc.create(Center()):
                                self.doc.append(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display',
                                                       MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(model, inputs)
                                graph = graph.build_dot()
                                graph.attr(
                                    rankdir='TB'
                                )  # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(
                                    size="200,200"
                                )  # LaTeX \maxdim is around 575cm (226 inches)
                                graph.attr(margin='0')
                                # TODO - save raw .dot file in case system lacks graphviz
                                file_path = graph.render(
                                    filename=f"FE_Model_{model.model_name}",
                                    directory=self.figure_dir,
                                    format='pdf',
                                    cleanup=True)
                            except Exception:
                                file_path = None
                                print(
                                    "FastEstimator-Warn: Model {} could not be visualized by Traceability"
                                    .format(model.model_name))
                            finally:
                                # Restore the caller's backend whether or not the rendering succeeded
                                matplotlib.use(old_backend)
                        else:
                            # No input spec means nothing to summarize or draw; 'file_path' stays None
                            self.doc.append(
                                "This model was not used by the Network during training."
                            )
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(
                                Label(
                                    Marker(name=str(FEID(id(model))),
                                           prefix="model")))
                            # The image is referenced relative to the directory the document is saved in
                            fig.add_image(
                                os.path.relpath(file_path,
                                                start=self.save_dir),
                                width=NoEscape(
                                    r'1.0\textwidth,height=0.95\textheight,keepaspectratio'
                                ))
                            fig.add_caption(
                                NoEscape(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name).dumps()))
Exemple #5
0
def test_hyperref():
    """Smoke-test that a Hyperref built from a Marker and a text string can be constructed and passed to repr()."""
    hr = Hyperref(Marker("marker", "prefix"), "text")
    # The repr() result is discarded; this line only verifies that the call does not raise
    repr(hr)