Ejemplo n.º 1
0
    def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
        """Render `op` as a node, or as a cluster of nodes, on `diagram`.

        A `Sometimes` op with a truthy `numpy_op`, or a `OneOf` op with truthy `numpy_ops`, is drawn as a
        'loosely dotted' cluster that is filled by recursive calls. Any other op becomes a single node whose
        plain label and LaTeX label (`texlbl`) carry the op's class name and FEID; a `ModelOp` additionally
        shows its model name, hyperlinked to the 'subsec' marker named after that model.

        Args:
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            node_id: The id handed to `pydot.Node` for the node drawn for `op`. Inside a `OneOf` cluster only
                the first sub-op uses this id; the others are keyed by their own `id()`.
        """
        if isinstance(op, Sometimes) and op.numpy_op:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', f'Sometimes ({op.prob}):')
            cluster.set('labeljust', 'r')
            Traceability._add_node(cluster, op.numpy_op, node_id)
            diagram.add_subgraph(cluster)
            return

        if isinstance(op, OneOf) and op.numpy_ops:
            cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
            cluster.set('label', 'One Of:')
            cluster.set('labeljust', 'r')
            for position, sub_op in enumerate(op.numpy_ops):
                # The first alternative inherits the caller's node id; the rest get ids of their own
                sub_id = node_id if position == 0 else str(id(sub_op))
                Traceability._add_node(cluster, sub_op, sub_id)
            diagram.add_subgraph(cluster)
            return

        # Plain op: a single node with a text label and a LaTeX label
        class_name = op.__class__.__name__
        if isinstance(op, ModelOp):
            label = f"{class_name} ({FEID(id(op))}): {op.model.model_name}"
            marker = Marker(name=str(op.model.model_name), prefix='subsec')
            link_text = NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) + NoEscape('}')
            model_ref = Hyperref(marker, text=link_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
        else:
            label = f"{class_name} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
Ejemplo n.º 2
0
 def _document_init_params(self) -> None:
     """Write the "Parameters" section of the traceability document.

     Every table in `self.config_tables` is rendered into `self.doc`. Tables whose `fe_id` belongs to a
     TensorFlow / PyTorch model are titled with a hyperlink to the 'subsec' marker named after the table.
     Tables whose `fe_id` belongs to the 'train', 'eval' or 'test' dataset of the pipeline are titled with
     that mode and, when the dataset is an `FEDataset`, extended with rows taken from its summary.
     """
     scalar_types = (int, float, str, bool, type(None))
     # Checked in order for every table; a later match replaces an earlier one
     toc_by_type = ((Estimator, "Estimator"), (BaseNetwork, "Network"), (Pipeline, "Pipeline"))

     with self.doc.create(Section("Parameters")):
         model_ids = set()
         for model in self.system.network.models:
             if isinstance(model, (tf.keras.Model, torch.nn.Module)):
                 model_ids.add(FEID(id(model)))

         # {FEID of dataset: (mode, dataset)}. A mode without a dataset is stored under the FEID of None.
         datasets = {}
         for mode in ('train', 'eval', 'test'):
             dataset = self.system.pipeline.data.get(mode, None)
             datasets[FEID(id(dataset))] = (mode, dataset)

         for tbl in self.config_tables:
             name_override = None
             toc_ref = None
             extra_rows = None

             for base_type, heading in toc_by_type:
                 if issubclass(tbl.type, base_type):
                     toc_ref = heading

             if tbl.fe_id in model_ids:
                 # Link to a later detailed model description
                 marker = Marker(name=str(tbl.name), prefix="subsec")
                 link_text = NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}')
                 name_override = Hyperref(marker, text=link_text)
                 toc_ref = tbl.name

             if tbl.fe_id in datasets:
                 title, dataset = datasets[tbl.fe_id]
                 name_override = bold(f'{tbl.name} ({title.capitalize()})')
                 toc_ref = f"{title.capitalize()} Dataset"
                 # Enhance the dataset summary
                 if isinstance(dataset, FEDataset):
                     summary_items = list(dataset.summary().__getstate__().items())
                     extra_rows = []
                     for key, val in summary_items:
                         row_name = f"{prettify_metric_name(key)}:"
                         if isinstance(val, dict) and val:
                             # Only the first value is inspected to decide how the whole dict is shown
                             if isinstance(next(iter(val.values())), scalar_types):
                                 val = jsonpickle.dumps(val, unpicklable=False)
                             else:
                                 nested = Tabular('l|l')
                                 for k, v in val.items():
                                     # NOTE(review): since Python 3.11 every object has `__getstate__`, so
                                     # this check is always true there - confirm that is intended.
                                     if hasattr(v, '__getstate__'):
                                         v = jsonpickle.dumps(v, unpicklable=False)
                                     nested.add_row((k, v))
                                 val = nested
                         extra_rows.append((row_name, val))

             tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
Ejemplo n.º 3
0
    def _write_tables(self, tables: List[FeSummaryTable], model_ids: Set[FEID],
                      datasets: Dict[FEID, Tuple[Set[str], Any]]) -> None:
        """Render each of the given tables into `self.doc` as LaTeX.

        Tables whose `fe_id` is in `model_ids` are titled with a hyperlink to the 'subsec' marker named after
        the table. Tables whose `fe_id` is in `datasets` are titled with the capitalized, comma-joined modes
        of that dataset and, when the dataset is an `FEDataset`, gain extra rows built from its summary.

        Args:
            tables: The tables to write into the doc.
            model_ids: The ids of any known models.
            datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.
        """
        scalar_types = (int, float, str, bool, type(None))

        def _as_cell(val: Any) -> Any:
            """Turn one dataset-summary value into the object placed in the table cell."""
            if not (isinstance(val, dict) and val):
                return val
            # Only the first value is inspected to decide how the whole dict is shown
            if isinstance(next(iter(val.values())), scalar_types):
                return jsonpickle.dumps(val, unpicklable=False)
            nested = Tabularx('l|X', width_argument=NoEscape(r'\linewidth'))
            for k, v in val.items():
                # NOTE(review): since Python 3.11 every object has `__getstate__`, so this check is always
                # true there - confirm that is intended.
                if hasattr(v, '__getstate__'):
                    v = jsonpickle.dumps(v, unpicklable=False)
                nested.add_row((k, v))
            # To nest TabularX, have to wrap it in brackets
            return ContainerList(data=[NoEscape("{"), nested, NoEscape("}")])

        for tbl in tables:
            name_override = None
            extra_rows = None

            if tbl.fe_id in model_ids:
                # Link to a later detailed model description
                marker = Marker(name=str(tbl.name), prefix="subsec")
                link_text = NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}')
                name_override = Hyperref(marker, text=link_text)

            if tbl.fe_id in datasets:
                modes, dataset = datasets[tbl.fe_id]
                title = ", ".join(mode.capitalize() for mode in modes)
                name_override = bold(f'{tbl.name} ({title})')
                # Enhance the dataset summary
                if isinstance(dataset, FEDataset):
                    summary_items = list(dataset.summary().__getstate__().items())
                    extra_rows = [(f"{prettify_metric_name(key)}:", _as_cell(val)) for key, val in summary_items]

            # This method never selects a table-of-contents reference, so None is always passed
            tbl.render_table(self.doc, name_override=name_override, toc_ref=None, extra_rows=extra_rows)
Ejemplo n.º 4
0
    def _draw_subgraph(diagram: pydot.Dot, label_last_seen: DefaultDict[str, str], subgraph_name: str,
                       subgraph_ops: List[Union[Op, Trace]]) -> None:
        """Draw the given ops as one dashed, labelled cluster inside an existing `diagram`.

        One node is created per op inside the cluster. Edges are written to `diagram` itself: for every input
        key of an op (except '*') an edge is drawn from the node which `label_last_seen` reports for that key,
        with keys coming from the same node sharing a single edge. Afterwards the op is recorded in
        `label_last_seen` as the latest producer of each of its output keys.

        Args:
            diagram: The diagram to be appended to.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            subgraph_name: The name to be associated with this subgraph.
            subgraph_ops: The ops to be wrapped in this subgraph.
        """
        cluster = pydot.Cluster(style='dashed', graph_name=subgraph_name)
        cluster.set('label', subgraph_name)
        cluster.set('labeljust', 'l')

        previous_id = None  # Node id of the op handled in the previous iteration, if any
        for op in subgraph_ops:
            node_id = str(id(op))
            class_name = op.__class__.__name__
            if isinstance(op, ModelOp):
                label = f"{class_name} ({FEID(id(op))}): {op.model.model_name}"
                marker = Marker(name=str(op.model.model_name), prefix='subsec')
                link_text = NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) + NoEscape('}')
                model_ref = Hyperref(marker, text=link_text).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
            else:
                label = f"{class_name} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
            cluster.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))

            # Group the input keys by the node that last produced them: one edge per source node
            keys_by_source = defaultdict(list)
            for key in op.inputs:
                if key != '*':
                    keys_by_source[label_last_seen[key]].append(key)
            for src, keys in keys_by_source.items():
                diagram.add_edge(pydot.Edge(src=src, dst=node_id, label=f" {', '.join(keys)} "))

            for key in op.outputs:
                label_last_seen[key] = node_id

            if isinstance(op, Trace) and previous_id is not None:
                # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
                diagram.add_edge(pydot.Edge(src=previous_id, dst=node_id, style='invis'))
            previous_id = node_id

        diagram.add_subgraph(cluster)
Ejemplo n.º 5
0
    def _add_node(progenitor: pydot.Dot,
                  diagram: Union[pydot.Dot, pydot.Cluster],
                  op: Union[Op, Trace, Any],
                  label_last_seen: DefaultDict[str, str],
                  edges: bool = True) -> None:
        """Draw a node onto a diagram based on a given op.

        Wrapper ops whose wrapped op(s) are truthy are expanded instead of being drawn as a single node:
        `Sometimes`/`SometimesT`, `OneOf`/`OneOfT` and `Repeat`/`RepeatT` each become a dotted cluster that is
        filled through recursive calls, while `Fuse`/`FuseT` is delegated to `_draw_subgraph`. Anything else is
        drawn as one plain node. Clusters and nodes are attached to `diagram`, whereas every edge created
        directly in this method is written to `progenitor`.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            edges: Whether to write Edges to/from this Node. Only consulted when `op` is drawn as a plain node.
        """
        node_id = str(id(op))
        if isinstance(op, (Sometimes, SometimesT)) and op.op:
            # Conditional op: a red dotted cluster, labelled with the probability, around the wrapped op
            wrapper = pydot.Cluster(style='dotted',
                                    color='red',
                                    graph_name=str(id(op)))
            wrapper.set('label', f'Sometimes ({op.prob}):')
            wrapper.set('labeljust', 'l')
            # Group the extra input keys by the node that last produced them. This is done before recursing,
            # i.e. before the wrapped op gets a chance to update `label_last_seen`.
            edge_srcs = defaultdict(lambda: [])
            if op.extra_inputs:
                for inp in op.extra_inputs:
                    if inp == '*':
                        continue
                    edge_srcs[label_last_seen[inp]].append(inp)
            Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
            diagram.add_subgraph(wrapper)
            # Edges need a node as their destination, so use the first node reported for the cluster.
            # NOTE(review): Graphviz only honors `lhead` when the graph sets compound=true; that setting is
            # not visible in this block.
            dst_id = Traceability._get_all_nodes(wrapper)[0].get_name()
            for src, labels in edge_srcs.items():
                progenitor.add_edge(
                    pydot.Edge(src=src,
                               dst=dst_id,
                               lhead=wrapper.get_name(),
                               label=f" {', '.join(labels)} "))
        elif isinstance(op, (OneOf, OneOfT)) and op.ops:
            # Alternatives: a purple dotted cluster holding every candidate op
            wrapper = pydot.Cluster(style='dotted',
                                    color='darkorchid4',
                                    graph_name=str(id(op)))
            wrapper.set('label', 'One Of:')
            wrapper.set('labeljust', 'l')
            # Only the first candidate is drawn with edges; the remaining ones are drawn with edges=False
            # (presumably to avoid repeating the same edges for every alternative - TODO confirm)
            Traceability._add_node(progenitor,
                                   wrapper,
                                   op.ops[0],
                                   label_last_seen,
                                   edges=True)
            for sub_op in op.ops[1:]:
                Traceability._add_node(progenitor,
                                       wrapper,
                                       sub_op,
                                       label_last_seen,
                                       edges=False)
            diagram.add_subgraph(wrapper)
        elif isinstance(op, (Fuse, FuseT)) and op.ops:
            # Fused ops are drawn by the subgraph helper under the name 'Fuse:'
            Traceability._draw_subgraph(progenitor, diagram, label_last_seen,
                                        f'Fuse:', op.ops)
        elif isinstance(op, (Repeat, RepeatT)) and op.op:
            # Loop: a green dotted cluster holding a small counter node plus the repeated op
            wrapper = pydot.Cluster(style='dotted',
                                    color='darkgreen',
                                    graph_name=str(id(op)))
            wrapper.set('label', f'Repeat:')
            wrapper.set('labeljust', 'l')
            # The counter node shows the repeat count, or "?" when `op.repeat` is not an int
            wrapper.add_node(
                pydot.Node(
                    node_id,
                    label=f'{op.repeat if isinstance(op.repeat, int) else "?"}',
                    shape='doublecircle',
                    width=0.1))
            # Self-loop on the counter node, from its ':ne' port to its ':w' port
            # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
            progenitor.add_edge(
                pydot.Edge(src=node_id + ":ne",
                           dst=node_id + ":w",
                           color='#006300'))
            Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
            # Add repeat edges
            # These point back at the counter node and are computed after the recursive call above, so the
            # sources are whatever `label_last_seen` holds once the repeated op has been drawn. Keys which are
            # both an output and an input are included, as is every key in `op.repeat_inputs`.
            edge_srcs = defaultdict(lambda: [])
            for out in op.outputs:
                if out in op.inputs and out not in op.repeat_inputs:
                    edge_srcs[label_last_seen[out]].append(out)
            for inp in op.repeat_inputs:
                edge_srcs[label_last_seen[inp]].append(inp)
            for src, labels in edge_srcs.items():
                # NOTE(review): per the Graphviz docs constraint=false keeps an edge from influencing node
                # ranking - presumably so these back-edges do not distort the layout.
                progenitor.add_edge(
                    pydot.Edge(src=src,
                               dst=node_id,
                               constraint=False,
                               label=f" {', '.join(labels)} "))
            diagram.add_subgraph(wrapper)
        else:
            # Plain op (or a wrapper whose wrapped op(s) are falsy): a single node with a text label and a
            # LaTeX label. A ModelOp additionally shows its model name, hyperlinked to the 'subsec' marker
            # named after the model.
            if isinstance(op, ModelOp):
                label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
                model_ref = Hyperref(Marker(name=str(op.model.model_name),
                                            prefix='subsec'),
                                     text=NoEscape(r'\textcolor{blue}{') +
                                     bold(op.model.model_name) +
                                     NoEscape('}')).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
            else:
                label = f"{op.__class__.__name__} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)),
                                  name=op.__class__.__name__).dumps()
            diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
            if isinstance(op, (Op, Trace)) and edges:
                # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
                Traceability._add_edge(progenitor, op, label_last_seen)
Ejemplo n.º 6
0
def test_hyperref():
    """Smoke test: a `Hyperref` built from a `Marker` and plain text can be passed to `repr()` without raising."""
    marker = Marker("marker", "prefix")
    link = Hyperref(marker, "text")
    repr(link)