def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
    """Render `op` into `diagram`.

    Sometimes and OneOf ops are drawn as 'loosely dotted' clusters that contain their wrapped op(s); every other op
    becomes a single node whose LaTeX label hyperlinks to the op's FE id (and, for a ModelOp, to its model's
    subsection).

    Args:
        diagram: The graph (or cluster) that receives the new node(s).
        op: The op (or trace) to be rendered.
        node_id: The id assigned to the primary node drawn for `op`.
    """
    def _make_wrapper(title: str) -> pydot.Cluster:
        # Shared styling for the clusters which enclose wrapped ops
        cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
        cluster.set('label', title)
        cluster.set('labeljust', 'r')
        return cluster

    if isinstance(op, Sometimes) and op.numpy_op:
        wrapper = _make_wrapper(f'Sometimes ({op.prob}):')
        Traceability._add_node(wrapper, op.numpy_op, node_id)
        diagram.add_subgraph(wrapper)
        return

    if isinstance(op, OneOf) and op.numpy_ops:
        wrapper = _make_wrapper('One Of:')
        first_op, *other_ops = op.numpy_ops
        # The first alternative inherits the caller's node id; the others are keyed by their own object ids
        Traceability._add_node(wrapper, first_op, node_id)
        for alternative in other_ops:
            Traceability._add_node(wrapper, alternative, str(id(alternative)))
        diagram.add_subgraph(wrapper)
        return

    class_name = op.__class__.__name__
    fe_id = FEID(id(op))
    label = f"{class_name} ({fe_id})"
    texlbl = HrefFEID(fe_id, name=class_name).dumps()
    if isinstance(op, ModelOp):
        model_name = op.model.model_name
        label = f"{label}: {model_name}"
        model_ref = Hyperref(Marker(name=str(model_name), prefix='subsec'),
                             text=NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')).dumps()
        texlbl = f"{texlbl}: {model_ref}"
    diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
def _draw_subgraph(diagram: pydot.Dot,
                   label_last_seen: DefaultDict[str, str],
                   subgraph_name: str,
                   subgraph_ops: List[Union[Op, Trace]]) -> None:
    """Draw a subgraph of ops into an existing `diagram`.

    Each op becomes a node inside a dashed cluster named `subgraph_name`. For every input key (except '*'), an edge is
    drawn from the node that last generated that key according to `label_last_seen`. All keys coming from the same
    source node are merged into one labeled edge. After an op is drawn, its output keys are recorded as having been
    generated by that op's node.

    Args:
        diagram: The diagram to be appended to.
        label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key. It
            is updated in place with the outputs of every op in `subgraph_ops`.
        subgraph_name: The name to be associated with this subgraph.
        subgraph_ops: The ops to be wrapped in this subgraph.
    """
    subgraph = pydot.Cluster(style='dashed', graph_name=subgraph_name)
    subgraph.set('label', subgraph_name)
    subgraph.set('labeljust', 'l')
    for idx, op in enumerate(subgraph_ops):
        node_id = str(id(op))
        if isinstance(op, ModelOp):
            label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
            model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'),
                                 text=NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) +
                                 NoEscape('}')).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
        else:
            label = f"{op.__class__.__name__} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()
        subgraph.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        # Group the input keys by the node which produced them so that each source gets a single multi-key edge
        edge_srcs = defaultdict(list)
        for inp in op.inputs:
            if inp == '*':
                continue
            edge_srcs[label_last_seen[inp]].append(inp)
        for src, labels in edge_srcs.items():
            diagram.add_edge(pydot.Edge(src=src, dst=node_id, label=f" {', '.join(labels)} "))
        for out in op.outputs:
            label_last_seen[out] = node_id
        if isinstance(op, Trace) and idx > 0:
            # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
            diagram.add_edge(pydot.Edge(src=str(id(subgraph_ops[idx - 1])), dst=node_id, style='invis'))
    diagram.add_subgraph(subgraph)
def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
    """Draw a summary diagram of the FastEstimator Ops / Traces.

    The diagram starts with a node for the data source used in `mode`, followed by 'Pipeline', 'Network', and
    'Traces' subgraphs. Input keys that no op has generated yet are attributed to the data source node.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    # NOTE(review): assumes pipeline.data maps each mode to its (possibly Scheduled) data source -- confirm
    ds = self.system.pipeline.data[mode]
    if isinstance(ds, Scheduler):
        ds = ds.get_current_value(epoch)
    # Pipeline ops are only collected when the data source is an FEDataset
    pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch) if isinstance(
        ds, FEDataset) else []
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
    traces = get_current_items(self.system.traces, run_modes=mode, epoch=epoch)
    diagram = pydot.Dot()
    diagram.set('rankdir', 'TB')
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='record')
    diagram.add_node(
        pydot.Node(str(id(ds)),
                   label=f'{ds.__class__.__name__} ({FEID(id(ds))})',
                   texlbl=HrefFEID(FEID(id(ds)), name=ds.__class__.__name__).dumps()))
    # Where was this key last generated (defaults to the data source node)
    label_last_seen = defaultdict(lambda: str(id(ds)))
    self._draw_subgraph(diagram, label_last_seen, 'Pipeline', pipe_ops)
    self._draw_subgraph(diagram, label_last_seen, 'Network', net_ops)
    self._draw_subgraph(diagram, label_last_seen, 'Traces', traces)
    return diagram
def _add_node(progenitor: pydot.Dot,
              diagram: Union[pydot.Dot, pydot.Cluster],
              op: Union[Op, Trace, Any],
              label_last_seen: DefaultDict[str, str],
              edges: bool = True) -> None:
    """Draw a node onto a diagram based on a given op.

    Wrapper ops are drawn as dotted clusters around the op(s) they wrap:
    - Sometimes / SometimesT: red.
    - OneOf / OneOfT: darkorchid4.
    - Repeat / RepeatT: darkgreen, plus a small double-circle node showing the repeat count.
    Fuse / FuseT ops are delegated to `Traceability._draw_subgraph`. Anything else becomes a single node. Its LaTeX
    label links to the object's FE id and, for a ModelOp, also to the model's subsection.

    Args:
        progenitor: The very top level diagram onto which Edges should be written.
        diagram: The diagram to be appended to.
        op: The op (or trace) to be visualized.
        label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
        edges: Whether to write Edges to/from this Node. This is only checked in the plain-node branch. The wrapper
            branches do not forward it to the ops they wrap.
    """
    node_id = str(id(op))
    if isinstance(op, (Sometimes, SometimesT)) and op.op:
        wrapper = pydot.Cluster(style='dotted', color='red', graph_name=str(id(op)))
        wrapper.set('label', f'Sometimes ({op.prob}):')
        wrapper.set('labeljust', 'l')
        # Resolve the sources of the wrapper's extra inputs BEFORE drawing the wrapped op, since the recursive call
        # is handed label_last_seen (and presumably records the wrapped op's outputs in it)
        edge_srcs = defaultdict(lambda: [])
        if op.extra_inputs:
            for inp in op.extra_inputs:
                if inp == '*':
                    continue
                edge_srcs[label_last_seen[inp]].append(inp)
        Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
        diagram.add_subgraph(wrapper)
        # Extra-input edges target the first node inside the wrapper, with lhead naming the wrapper cluster
        # (assumes _get_all_nodes returns nodes in insertion order -- confirm)
        dst_id = Traceability._get_all_nodes(wrapper)[0].get_name()
        for src, labels in edge_srcs.items():
            progenitor.add_edge(
                pydot.Edge(src=src, dst=dst_id, lhead=wrapper.get_name(), label=f" {', '.join(labels)} "))
    elif isinstance(op, (OneOf, OneOfT)) and op.ops:
        wrapper = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=str(id(op)))
        wrapper.set('label', 'One Of:')
        wrapper.set('labeljust', 'l')
        # Only the first alternative is drawn with edges=True; the remaining alternatives get edges=False
        Traceability._add_node(progenitor, wrapper, op.ops[0], label_last_seen, edges=True)
        for sub_op in op.ops[1:]:
            Traceability._add_node(progenitor, wrapper, sub_op, label_last_seen, edges=False)
        diagram.add_subgraph(wrapper)
    elif isinstance(op, (Fuse, FuseT)) and op.ops:
        # Fused ops are handed to _draw_subgraph under a 'Fuse:' label
        Traceability._draw_subgraph(progenitor, diagram, label_last_seen, f'Fuse:', op.ops)
    elif isinstance(op, (Repeat, RepeatT)) and op.op:
        wrapper = pydot.Cluster(style='dotted', color='darkgreen', graph_name=str(id(op)))
        wrapper.set('label', f'Repeat:')
        wrapper.set('labeljust', 'l')
        # Counter node (keyed by the Repeat op's own id) showing the repeat count, or '?' when it isn't an int
        wrapper.add_node(
            pydot.Node(
                node_id,
                label=f'{op.repeat if isinstance(op.repeat, int) else "?"}',
                shape='doublecircle',
                width=0.1))
        # Self-loop on the counter node (ne port -> w port) to indicate repetition
        # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
        progenitor.add_edge(
            pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
        Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
        # Add repeat edges
        # These run into the counter node from whichever node last generated each key. The keys are outputs that
        # are also inputs (unless listed in repeat_inputs), plus every repeat_input.
        edge_srcs = defaultdict(lambda: [])
        for out in op.outputs:
            if out in op.inputs and out not in op.repeat_inputs:
                edge_srcs[label_last_seen[out]].append(out)
        for inp in op.repeat_inputs:
            edge_srcs[label_last_seen[inp]].append(inp)
        for src, labels in edge_srcs.items():
            # constraint=False: Graphviz will not use these back-edges when ranking nodes
            progenitor.add_edge(
                pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(labels)} "))
        diagram.add_subgraph(wrapper)
    else:
        if isinstance(op, ModelOp):
            label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
            model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'),
                                 text=NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) +
                                 NoEscape('}')).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
        else:
            label = f"{op.__class__.__name__} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        if isinstance(op, (Op, Trace)) and edges:
            # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
            Traceability._add_edge(progenitor, op, label_last_seen)
def _document_models(self) -> None:
    """Add model summaries to the traceability document.

    For every tf.keras / torch model in the network (sorted by name), this writes a subsection containing:
    - a text summary;
    - a link to the model's FE id;
    - an architecture figure, when rendering succeeds.
    The figure PDFs are written into `self.resource_dir`.
    """
    with self.doc.create(Section("Models")):
        for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
            if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                continue
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection(f"{model.model_name}")):
                if isinstance(model, tf.keras.Model):
                    # Text Summary
                    summary = []
                    model.summary(line_length=92, print_fn=lambda x: summary.append(x))
                    summary = "\n".join(summary)
                    self.doc.append(Verbatim(summary))
                    with self.doc.create(Center()):
                        self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                    # Visual Summary
                    # noinspection PyBroadException
                    try:
                        file_path = os.path.join(self.resource_dir,
                                                 "{}_{}.pdf".format(self.report_name, model.model_name))
                        dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                        # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                        # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                        # set the limit lower (100 inches) to leave some wiggle room.
                        dot.set('size', '100')
                        dot.write(file_path, format='pdf')
                    except Exception:
                        file_path = None
                        print(
                            f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                        )
                elif isinstance(model, torch.nn.Module):
                    if hasattr(model, 'fe_input_spec'):
                        # Text Summary
                        # noinspection PyUnresolvedReferences
                        inputs = model.fe_input_spec.get_dummy_input()
                        self.doc.append(
                            Verbatim(
                                pms.summary(model.module if self.system.num_devices > 1 else model,
                                            inputs,
                                            print_summary=False)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # Import has to be done while matplotlib is using the Agg backend
                        old_backend = matplotlib.get_backend() or 'Agg'
                        matplotlib.use('Agg')
                        # noinspection PyBroadException
                        try:
                            # Fake the IPython import when user isn't running from Jupyter
                            sys.modules.setdefault('IPython', MagicMock())
                            sys.modules.setdefault('IPython.display', MagicMock())
                            import hiddenlayer as hl
                            with Suppressor():
                                graph = hl.build_graph(model.module if self.system.num_devices > 1 else model,
                                                       inputs)
                            graph = graph.build_dot()
                            graph.attr(rankdir='TB')  # Switch it to Top-to-Bottom instead of Left-to-Right
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                            # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                            # height, so set the limit lower (100 inches) to leave some wiggle room.
                            graph.attr(size="100,100")
                            graph.attr(margin='0')
                            file_path = graph.render(filename="{}_{}".format(self.report_name, model.model_name),
                                                     directory=self.resource_dir,
                                                     format='pdf',
                                                     cleanup=True)
                        except Exception:
                            file_path = None
                            print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                model.model_name))
                        finally:
                            matplotlib.use(old_backend)
                    else:
                        file_path = None
                        self.doc.append("This model was not used by the Network during training.")
                if file_path:
                    with self.doc.create(Figure(position='ht!')) as fig:
                        fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                        fig.add_image(os.path.relpath(file_path, start=self.save_dir),
                                      width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                        fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value to a FESummaryTable table representation

    How each kind of input is handled:
    - Simple literals are returned (LaTeX-escaped where needed) so they can be inlined.
    - Lists, tuples, sets and dicts are converted element-wise.
    - Anything that deserves its own table (functions, classes, models, tensors, arbitrary objects) is registered in
      `tables` and replaced by a hyperlink to that table.

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New tables
            are added to this mapping in place.
        ret_ref: A flag to indicate that _trace_value is returning a reference (this is used to figure out whether
            functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        An FESummaryTable representation of the input.
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have its summary instantiated yet, so it will
        # fall through to the class check at the end to get its id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                var_names = code.co_varnames
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some
                # other function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            # NOTE(review): assumes _Function exposes its display name as `.name` -- confirm
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                # Render as `<owner>.<method>` followed by the argument list built above
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        # Build a filtered copy rather than popping 'self', so the caller's BoundArguments is left untouched
        args = {key: value for key, value in inp.arguments.items() if key != 'self'}
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data=set([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        inp_type = type(inp)
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, (tf.Tensor, torch.Tensor, tf.Variable)):
                if isinstance(inp, torch.Tensor):
                    # Keep the numpy result so shape/values are reported the same way as for tf tensors
                    inp = inp.cpu().detach().numpy()
                # In the elif here we're sure to be tf
                elif inp.dtype != tf.dtypes.variant:
                    inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                kwargs = _trace_value({k: v
                                       for k, v in inp.__dict__.items() if not k.startswith("_")},
                                      tables,
                                      ret_ref,
                                      wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")
def _document_models(self) -> None:
    """Add model summaries to the traceability document.

    For every tf.keras / torch model in the network (sorted by name), this writes a subsection containing:
    - a text summary;
    - a link to the model's FE id;
    - an architecture figure, when rendering succeeds.
    The figure PDFs are written into `self.figure_dir`.
    """
    with self.doc.create(Section("Models")):
        for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
            if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                continue
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection(f"{model.model_name}")):
                if isinstance(model, tf.keras.Model):
                    # Text Summary
                    summary = []
                    model.summary(line_length=92, print_fn=lambda x: summary.append(x))
                    summary = "\n".join(summary)
                    self.doc.append(Verbatim(summary))
                    with self.doc.create(Center()):
                        self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                    # Visual Summary
                    # noinspection PyBroadException
                    try:
                        file_path = os.path.join(self.figure_dir, f"FE_Model_{model.model_name}.pdf")
                        tf.keras.utils.plot_model(model, to_file=file_path, show_shapes=True, expand_nested=True)
                        # TODO - cap output image size like in the pytorch implementation in case of huge network
                        # TODO - save raw .dot file in case system lacks graphviz
                    except Exception:
                        file_path = None
                        print(
                            f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                        )
                elif isinstance(model, torch.nn.Module):
                    if hasattr(model, 'fe_input_spec'):
                        # Text Summary
                        # noinspection PyUnresolvedReferences
                        inputs = model.fe_input_spec.get_dummy_input()
                        self.doc.append(Verbatim(pms.summary(model, inputs)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # Import has to be done while matplotlib is using the Agg backend
                        old_backend = matplotlib.get_backend() or 'Agg'
                        matplotlib.use('Agg')
                        # noinspection PyBroadException
                        try:
                            # Fake the IPython import when user isn't running from Jupyter
                            sys.modules.setdefault('IPython', MagicMock())
                            sys.modules.setdefault('IPython.display', MagicMock())
                            import hiddenlayer as hl
                            with Suppressor():
                                graph = hl.build_graph(model, inputs)
                            graph = graph.build_dot()
                            graph.attr(rankdir='TB')  # Switch it to Top-to-Bottom instead of Left-to-Right
                            graph.attr(size="200,200")  # LaTeX \maxdim is around 575cm (226 inches)
                            graph.attr(margin='0')
                            # TODO - save raw .dot file in case system lacks graphviz
                            file_path = graph.render(filename=f"FE_Model_{model.model_name}",
                                                     directory=self.figure_dir,
                                                     format='pdf',
                                                     cleanup=True)
                        except Exception:
                            file_path = None
                            print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                model.model_name))
                        finally:
                            matplotlib.use(old_backend)
                    else:
                        # Reset so the check below never sees an unbound name or a previous model's figure
                        file_path = None
                        self.doc.append("This model was not used by the Network during training.")
                if file_path:
                    with self.doc.create(Figure(position='ht!')) as fig:
                        fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                        fig.add_image(os.path.relpath(file_path, start=self.save_dir),
                                      width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                        fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))