def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
    """Render `op` into `diagram`, either as a single node or as a dotted cluster of nodes.

    A `Sometimes` whose `numpy_op` is truthy, or a `OneOf` whose `numpy_ops` is truthy, is drawn as a
    'loosely dotted' cluster (right-justified label) containing the wrapped op(s). Anything else becomes one
    node whose plain label is "<ClassName> (<FEID>)" and whose LaTeX label hyperlinks the FEID. For a ModelOp,
    both labels are suffixed with the model name; the LaTeX version links to the model's 'subsec' marker.

    Args:
        diagram: The graph or cluster that receives the new node(s).
        op: The op (or trace) to draw.
        node_id: Node identifier used for `op` itself (for a OneOf, used for its first member; the other
            members use their own object ids).
    """
    if isinstance(op, Sometimes) and op.numpy_op:
        cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
        cluster.set('label', f'Sometimes ({op.prob}):')
        cluster.set('labeljust', 'r')
        Traceability._add_node(cluster, op.numpy_op, node_id)
        diagram.add_subgraph(cluster)
        return

    if isinstance(op, OneOf) and op.numpy_ops:
        first_member, *other_members = op.numpy_ops
        cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
        cluster.set('label', 'One Of:')
        cluster.set('labeljust', 'r')
        Traceability._add_node(cluster, first_member, node_id)
        for member in other_members:
            Traceability._add_node(cluster, member, str(id(member)))
        diagram.add_subgraph(cluster)
        return

    fe_id = FEID(id(op))
    class_name = op.__class__.__name__
    label = f"{class_name} ({fe_id})"
    texlbl = HrefFEID(fe_id, name=class_name).dumps()
    if isinstance(op, ModelOp):
        model_name = op.model.model_name
        model_link = Hyperref(Marker(name=str(model_name), prefix='subsec'),
                              text=NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')).dumps()
        label = f"{label}: {model_name}"
        texlbl = f"{texlbl}: {model_link}"
    diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
def _document_init_params(self) -> None:
    """Add initialization parameters to the traceability document.

    Opens a "Parameters" section and renders every table in `self.config_tables` into it:
      * Estimator / BaseNetwork / Pipeline tables get the matching TOC reference.
      * Tables whose id belongs to a tf.keras / torch model get a blue hyperlinked name pointing at the
        model's 'subsec' marker, and use the table name as TOC reference.
      * Tables whose id belongs to the pipeline's train/eval/test data are labeled with every mode that
        dataset is used for. For FEDatasets, rows from the dataset summary are appended.
    """
    with self.doc.create(Section("Parameters")):
        model_ids = {
            FEID(id(model))
            for model in self.system.network.models if isinstance(model, (tf.keras.Model, torch.nn.Module))
        }
        # {FEID: ([modes...], dataset)}. One dataset object may back several modes (e.g. train and eval), so
        # collect all of them instead of letting the last mode overwrite the earlier ones. Absent modes (None)
        # are skipped so they don't all collapse onto the id of None.
        datasets = {}
        for title in ['train', 'eval', 'test']:
            dataset = self.system.pipeline.data.get(title, None)
            if dataset is None:
                continue
            datasets.setdefault(FEID(id(dataset)), ([], dataset))[0].append(title)

        # Python 3.11 gave every object a default __getstate__, so a bare hasattr() check would jsonpickle
        # plain strings/numbers too (wrapping strings in quotes). Only serialize values whose class provides
        # its own __getstate__. On older Pythons default_getstate is None and this reduces to the old check.
        default_getstate = getattr(object, '__getstate__', None)

        for tbl in self.config_tables:
            name_override = None
            toc_ref = None
            extra_rows = None
            if issubclass(tbl.type, Estimator):
                toc_ref = "Estimator"
            if issubclass(tbl.type, BaseNetwork):
                toc_ref = "Network"
            if issubclass(tbl.type, Pipeline):
                toc_ref = "Pipeline"
            if tbl.fe_id in model_ids:
                # Link to a later detailed model description
                name_override = Hyperref(Marker(name=str(tbl.name), prefix="subsec"),
                                         text=NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}'))
                toc_ref = tbl.name
            if tbl.fe_id in datasets:
                modes, dataset = datasets[tbl.fe_id]
                title = ", ".join(mode.capitalize() for mode in modes)
                name_override = bold(f'{tbl.name} ({title})')
                toc_ref = f"{title} Dataset"
                # Enhance the dataset summary
                if isinstance(dataset, FEDataset):
                    extra_rows = list(dataset.summary().__getstate__().items())
                    for idx, (key, val) in enumerate(extra_rows):
                        key = f"{prettify_metric_name(key)}:"
                        if isinstance(val, dict) and val:
                            if isinstance(next(iter(val.values())), (int, float, str, bool, type(None))):
                                # Dict of primitives: show it inline as JSON
                                val = jsonpickle.dumps(val, unpicklable=False)
                            else:
                                # Dict of richer objects: show it as a nested two-column table
                                subtable = Tabular('l|l')
                                for k, v in val.items():
                                    if getattr(type(v), '__getstate__', default_getstate) is not default_getstate:
                                        v = jsonpickle.dumps(v, unpicklable=False)
                                    subtable.add_row((k, v))
                                val = subtable
                        extra_rows[idx] = (key, val)
            tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
def _write_tables(self,
                  tables: List[FeSummaryTable],
                  model_ids: Set[FEID],
                  datasets: Dict[FEID, Tuple[Set[str], Any]]) -> None:
    """Insert a LaTeX representation of a list of tables into the current doc.

    Tables whose id is in `model_ids` get a blue hyperlinked name pointing at the model's 'subsec' marker.
    Tables whose id is in `datasets` are labeled with the mode(s) the dataset is used for. For FEDatasets,
    rows built from the dataset summary are appended.

    Args:
        tables: The tables to write into the doc.
        model_ids: The ids of any known models.
        datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.
    """
    # `modes` is a set, so joining it directly would order the rendered title by string hash, which can differ
    # between interpreter runs. Put well-known modes in pipeline order and anything else alphabetically after.
    mode_rank = {'train': 0, 'eval': 1, 'test': 2, 'infer': 3}
    # Python 3.11 gave every object a default __getstate__, so a bare hasattr() check would jsonpickle plain
    # strings/numbers too (wrapping strings in quotes). Only serialize values whose class provides its own
    # __getstate__. On older Pythons default_getstate is None and this reduces to the old check.
    default_getstate = getattr(object, '__getstate__', None)

    for tbl in tables:
        name_override = None
        toc_ref = None
        extra_rows = None
        if tbl.fe_id in model_ids:
            # Link to a later detailed model description
            name_override = Hyperref(Marker(name=str(tbl.name), prefix="subsec"),
                                     text=NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}'))
        if tbl.fe_id in datasets:
            modes, dataset = datasets[tbl.fe_id]
            ordered_modes = sorted(modes, key=lambda mode: (mode_rank.get(mode, len(mode_rank)), mode))
            title = ", ".join(mode.capitalize() for mode in ordered_modes)
            name_override = bold(f'{tbl.name} ({title})')
            # Enhance the dataset summary
            if isinstance(dataset, FEDataset):
                extra_rows = list(dataset.summary().__getstate__().items())
                for idx, (key, val) in enumerate(extra_rows):
                    key = f"{prettify_metric_name(key)}:"
                    if isinstance(val, dict) and val:
                        if isinstance(next(iter(val.values())), (int, float, str, bool, type(None))):
                            # Dict of primitives: show it inline as JSON
                            val = jsonpickle.dumps(val, unpicklable=False)
                        else:
                            # Dict of richer objects: show it as a nested full-width table
                            subtable = Tabularx('l|X', width_argument=NoEscape(r'\linewidth'))
                            for k, v in val.items():
                                if getattr(type(v), '__getstate__', default_getstate) is not default_getstate:
                                    v = jsonpickle.dumps(v, unpicklable=False)
                                subtable.add_row((k, v))
                            # To nest TabularX, have to wrap it in brackets
                            val = ContainerList(data=[NoEscape("{"), subtable, NoEscape("}")])
                    extra_rows[idx] = (key, val)
        tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
def _draw_subgraph(diagram: pydot.Dot,
                   label_last_seen: DefaultDict[str, str],
                   subgraph_name: str,
                   subgraph_ops: List[Union[Op, Trace]]) -> None:
    """Wrap a sequence of ops in a dashed, left-labeled cluster and attach it to `diagram`.

    Each op becomes one node in the cluster. For each of an op's input keys (except '*'), the producing node
    is looked up in `label_last_seen`. One edge per distinct producer is added to `diagram`, labeled with all
    keys it supplies. The op's outputs are then recorded in `label_last_seen`. A Trace that is not the first
    op is also joined to its predecessor by an invisible edge, so that the traces aren't squashed
    horizontally in the rendered image.

    Args:
        diagram: Top-level graph receiving the cluster and every edge.
        label_last_seen: {data_dict_key: node_id} of the node that most recently produced each key. It is
            indexed (not .get) for inputs and updated in place for outputs.
        subgraph_name: Name and visible label of the cluster.
        subgraph_ops: Ops to draw, in order.
    """
    cluster = pydot.Cluster(style='dashed', graph_name=subgraph_name)
    cluster.set('label', subgraph_name)
    cluster.set('labeljust', 'l')

    previous_id = None
    for op in subgraph_ops:
        current_id = str(id(op))

        fe_id = FEID(id(op))
        class_name = op.__class__.__name__
        label = f"{class_name} ({fe_id})"
        texlbl = HrefFEID(fe_id, name=class_name).dumps()
        if isinstance(op, ModelOp):
            model_name = op.model.model_name
            model_link = Hyperref(Marker(name=str(model_name), prefix='subsec'),
                                  text=NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')).dumps()
            label = f"{label}: {model_name}"
            texlbl = f"{texlbl}: {model_link}"
        cluster.add_node(pydot.Node(current_id, label=label, texlbl=texlbl))

        # Group input keys by the node that produced them so each producer yields a single labeled edge
        producers = defaultdict(list)
        for key in op.inputs:
            if key != '*':
                producers[label_last_seen[key]].append(key)
        for src, keys in producers.items():
            diagram.add_edge(pydot.Edge(src=src, dst=current_id, label=f" {', '.join(keys)} "))

        for key in op.outputs:
            label_last_seen[key] = current_id

        if isinstance(op, Trace) and previous_id is not None:
            # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
            diagram.add_edge(pydot.Edge(src=previous_id, dst=current_id, style='invis'))
        previous_id = current_id

    diagram.add_subgraph(cluster)
def _add_node(progenitor: pydot.Dot,
              diagram: Union[pydot.Dot, pydot.Cluster],
              op: Union[Op, Trace, Any],
              label_last_seen: DefaultDict[str, str],
              edges: bool = True) -> None:
    """Render `op` into `diagram`, recursing into wrapper ops.

    * Sometimes: a red dotted cluster labeled with the probability, containing the wrapped op. Its extra
      inputs are drawn as edges into the cluster's first node, clipped at the cluster border (lhead).
    * OneOf: a dark-orchid dotted cluster containing every alternative. Only the first alternative is drawn
      with edges.
    * Fuse: delegated to `Traceability._draw_subgraph` with the label 'Fuse:'.
    * Repeat: a dark-green dotted cluster holding a double-circle counter node (the repeat count, or '?' when
      it is not an int) with a self-loop, followed by the wrapped op. Loop-back edges run from the wrapped
      op's producers to the counter.
    * Anything else: a single node labeled "<ClassName> (<FEID>)" (plus the model name for a ModelOp). Edges
      are added via `Traceability._add_edge` when `edges` is set and the object is an Op or Trace.

    Args:
        progenitor: The very top level diagram onto which Edges should be written.
        diagram: The diagram (graph or cluster) to be appended to.
        op: The op (or trace) to be visualized. May also be a non-op object such as a dataset.
        label_last_seen: {data_dict_key: node_id} of the node that most recently produced each key.
        edges: Whether to write Edges to/from this Node.
    """
    node_id = str(id(op))

    if isinstance(op, (Sometimes, SometimesT)) and op.op:
        cluster = pydot.Cluster(style='dotted', color='red', graph_name=node_id)
        cluster.set('label', f'Sometimes ({op.prob}):')
        cluster.set('labeljust', 'l')
        # Resolve producers of the extra inputs before drawing the wrapped op, which may update label_last_seen
        producers = defaultdict(list)
        for key in op.extra_inputs or ():
            if key != '*':
                producers[label_last_seen[key]].append(key)
        Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
        diagram.add_subgraph(cluster)
        entry_id = Traceability._get_all_nodes(cluster)[0].get_name()
        for src, keys in producers.items():
            progenitor.add_edge(
                pydot.Edge(src=src, dst=entry_id, lhead=cluster.get_name(), label=f" {', '.join(keys)} "))

    elif isinstance(op, (OneOf, OneOfT)) and op.ops:
        cluster = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=node_id)
        cluster.set('label', 'One Of:')
        cluster.set('labeljust', 'l')
        # Only the first alternative is drawn with edges; presumably to avoid duplicate edges for every option
        for position, alternative in enumerate(op.ops):
            Traceability._add_node(progenitor, cluster, alternative, label_last_seen, edges=position == 0)
        diagram.add_subgraph(cluster)

    elif isinstance(op, (Fuse, FuseT)) and op.ops:
        Traceability._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops)

    elif isinstance(op, (Repeat, RepeatT)) and op.op:
        cluster = pydot.Cluster(style='dotted', color='darkgreen', graph_name=node_id)
        cluster.set('label', 'Repeat:')
        cluster.set('labeljust', 'l')
        counter_label = f'{op.repeat}' if isinstance(op.repeat, int) else '?'
        cluster.add_node(pydot.Node(node_id, label=counter_label, shape='doublecircle', width=0.1))
        # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
        progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
        Traceability._add_node(progenitor, cluster, op.op, label_last_seen)
        # Add repeat edges: keys fed back into the loop, grouped by their producer
        loop_sources = defaultdict(list)
        for key in op.outputs:
            if key in op.inputs and key not in op.repeat_inputs:
                loop_sources[label_last_seen[key]].append(key)
        for key in op.repeat_inputs:
            loop_sources[label_last_seen[key]].append(key)
        for src, keys in loop_sources.items():
            progenitor.add_edge(pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(keys)} "))
        diagram.add_subgraph(cluster)

    else:
        fe_id = FEID(id(op))
        class_name = op.__class__.__name__
        label = f"{class_name} ({fe_id})"
        texlbl = HrefFEID(fe_id, name=class_name).dumps()
        if isinstance(op, ModelOp):
            model_name = op.model.model_name
            model_link = Hyperref(Marker(name=str(model_name), prefix='subsec'),
                                  text=NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')).dumps()
            label = f"{label}: {model_name}"
            texlbl = f"{texlbl}: {model_link}"
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
        if edges and isinstance(op, (Op, Trace)):
            Traceability._add_edge(progenitor, op, label_last_seen)
def test_hyperref():
    """Smoke test: a Hyperref wrapping a Marker can be constructed and repr()'d without raising."""
    target = Marker("marker", "prefix")
    link = Hyperref(target, "text")
    repr(link)