def _document_init_params(self) -> None:
    """Write the "Parameters" section of the traceability document.

    Every table in `self.config_tables` is rendered into `self.doc`. Tables whose id matches one of the network's
    models get a hyperlinked title, and tables whose id matches one of the pipeline's train/eval/test datasets get
    a mode-specific title plus (for FEDatasets) extra rows taken from the dataset summary.
    """
    simple_types = (int, float, str, bool, type(None))
    with self.doc.create(Section("Parameters")):
        model_ids = set()
        for model in self.system.network.models:
            if isinstance(model, (tf.keras.Model, torch.nn.Module)):
                model_ids.add(FEID(id(model)))

        # Map the id of each mode's dataset onto (mode, dataset). A later mode overwrites an earlier one if
        # both modes share the same object.
        datasets = {}
        for mode in ['train', 'eval', 'test']:
            mode_data = self.system.pipeline.data.get(mode, None)
            datasets[FEID(id(mode_data))] = (mode, mode_data)

        for tbl in self.config_tables:
            name_override = None
            toc_ref = None
            extra_rows = None

            # Later matches deliberately win over earlier ones
            for base_class, heading in ((Estimator, "Estimator"), (BaseNetwork, "Network"), (Pipeline, "Pipeline")):
                if issubclass(tbl.type, base_class):
                    toc_ref = heading

            if tbl.fe_id in model_ids:
                # Link to a later detailed model description
                link_text = NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}')
                name_override = Hyperref(Marker(name=str(tbl.name), prefix="subsec"), text=link_text)
                toc_ref = tbl.name

            if tbl.fe_id in datasets:
                title, dataset = datasets[tbl.fe_id]
                name_override = bold(f'{tbl.name} ({title.capitalize()})')
                toc_ref = f"{title.capitalize()} Dataset"
                # Enhance the dataset summary
                if isinstance(dataset, FEDataset):
                    extra_rows = []
                    for key, val in dataset.summary().__getstate__().items():
                        row_label = f"{prettify_metric_name(key)}:"
                        if isinstance(val, dict) and val:
                            first_value = next(iter(val.values()))
                            if isinstance(first_value, simple_types):
                                # A flat dictionary can be displayed directly as text
                                val = jsonpickle.dumps(val, unpicklable=False)
                            else:
                                # Otherwise give every entry its own row in a nested table
                                subtable = Tabular('l|l')
                                for k, v in val.items():
                                    if hasattr(v, '__getstate__'):
                                        v = jsonpickle.dumps(v, unpicklable=False)
                                    subtable.add_row((k, v))
                                val = subtable
                        extra_rows.append((row_label, val))

            tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
def _document_init_params(self) -> None:
    """Write the "Parameters" section of the traceability document.

    The network's models and the pipeline's datasets (including any held inside Schedulers) are located first,
    then the config tables are rendered group by group through `self._loop_tables`.
    """
    from fastestimator.estimator import Estimator  # Avoid circular import

    with self.doc.create(Section("Parameters")):
        model_ids = set()
        for model in self.system.network.models:
            if isinstance(model, (tf.keras.Model, torch.nn.Module)):
                model_ids.add(FEID(id(model)))

        # Locate the datasets in order to provide extra details about them later in the summary.
        # Each entry maps a dataset id onto ({modes which use it}, the dataset object).
        datasets = {}
        for mode in ['train', 'eval', 'test']:
            pending = to_list(self.system.pipeline.data.get(mode, None))
            cursor = 0
            # `pending` may grow while it is being walked, so an index is used rather than an iterator
            while cursor < len(pending):
                candidate = pending[cursor]
                cursor += 1
                if not candidate:
                    continue
                feid = FEID(id(candidate))
                if feid in datasets:
                    datasets[feid][0].add(mode)
                else:
                    datasets[feid] = ({mode}, candidate)
                if isinstance(candidate, Scheduler):
                    # Also examine everything which the scheduler could ever return
                    pending.extend(candidate.get_all_values())

        # Parse the config tables, one group after another. Each call returns the index at which the next
        # group should start.
        groups = (
            ((Estimator, BaseNetwork, Pipeline), "Base Classes"),
            (Scheduler, "Schedulers"),
            (Trace, "Traces"),
            (Op, "Ops"),
            ((Dataset, tf.data.Dataset), "Datasets"),
            ((tf.keras.Model, torch.nn.Module), "Models"),
            (types.FunctionType, "Functions"),
            ((np.ndarray, tf.Tensor, tf.Variable, torch.Tensor), "Tensors"),
            (Any, "Miscellaneous"),
        )
        start = 0
        for group_classes, group_name in groups:
            start = self._loop_tables(start,
                                      classes=group_classes,
                                      name=group_name,
                                      model_ids=model_ids,
                                      datasets=datasets)
def fe_summary(self) -> List[FeSummaryTable]:
    """Return a summary of how this class was instantiated (for traceability).

    The tables stored in `self._fe_traceability_summary` are sorted by the kind of object they describe, and the
    FEID translation dictionary is updated so that ids are displayed as "@FE0", "@FE1", ... in that order.

    Args:
        self: The bound class instance.

    Returns:
        A summary of the instance.
    """
    # Delayed imports to avoid circular dependency
    from fastestimator.estimator import Estimator
    from fastestimator.network import TFNetwork, TorchNetwork
    from fastestimator.pipeline import Pipeline
    from fastestimator.op.op import Op
    from fastestimator.trace.trace import Trace
    from fastestimator.schedule.schedule import Scheduler
    from torch.utils.data import Dataset

    # Display order of the tables: the first matching entry determines the rank
    display_order = (
        Estimator,
        (TFNetwork, TorchNetwork),
        Pipeline,
        Scheduler,
        Trace,
        Op,
        (Dataset, tf.data.Dataset),
        (tf.keras.Model, torch.nn.Module),
        types.FunctionType,
        (np.ndarray, tf.Tensor, tf.Variable, torch.Tensor),
    )

    def _rank(entry):
        table_type = entry[1].type
        for position, classes in enumerate(display_order):
            if issubclass(table_type, classes):
                return position
        return len(display_order)  # Everything else goes last

    # re-number the references for nicer viewing
    ordered_items = sorted(self._fe_traceability_summary.items(), key=_rank)
    key_mapping = {}
    for idx, (fe_id, _) in enumerate(ordered_items):
        key_mapping[fe_id] = f"@FE{idx}"
    FEID.set_translation_dict(key_mapping)
    return [table for _, table in ordered_items]
def fix_split_traceabilty(cls,
                          parent: 'FEDataset',
                          children: List['FEDataset'],
                          fractions: Tuple[Union[float, int, Iterable[int]], ...],
                          seed: Optional[int],
                          stratify: Optional[str]) -> None:
    """A method to fix traceability information after invoking the dataset .split() method.

    Note that the default implementation of the .split() function invokes this already, so this only needs to be
    invoked if you override the .split() method when defining a subclass (ex. BatchDataset).

    Nothing happens if the `parent` has no `_fe_traceability_summary` attribute.

    Args:
        parent: The parent dataset on which .split() was invoked.
        children: The datasets generated by performing the split.
        fractions: The fraction arguments used to generate the children (should be one-to-one with the children).
        seed: The random seed used to generate the split.
        stratify: The stratify key used to generate the split.
    """
    if not hasattr(parent, '_fe_traceability_summary'):
        return

    parent_id = FEID(id(parent))
    # Convert every split argument into display text
    fraction_labels = []
    for frac in fractions:
        if isinstance(frac, range):
            fraction_labels.append(f"range({frac.start}, {frac.stop}, {frac.step})")
        else:
            fraction_labels.append(f"{frac}")

    for child, label in zip(children, fraction_labels):
        # noinspection PyProtectedMember
        summaries = deepcopy(child._fe_traceability_summary)
        # Update the ID if necessary
        child_id = FEID(id(child))
        if child_id in summaries:
            child_table = summaries[child_id]
        else:
            # The child was created without invoking its __init__ method, so its internal summary will have the
            # wrong id
            child_table = summaries.pop(parent_id)
            child_table.fe_id = child_id
            summaries[child_id] = child_table
        child_split = child_table.fields.get('split', FeSplitSummary())
        child_split.add_split(parent=parent_id, fraction=label, seed=seed, stratify=stratify)
        child_table.fields['split'] = child_split
        child._fe_traceability_summary = summaries

    # Record on the parent that data was removed from it
    # noinspection PyUnresolvedReferences
    parent_table = parent._fe_traceability_summary.get(parent_id)
    parent_split = parent_table.fields.get('split', FeSplitSummary())
    removed = ", ".join(f"-{label}" for label in fraction_labels)
    parent_split.add_split(parent='self', fraction=removed, seed=seed, stratify=stratify)
    parent_table.fields['split'] = parent_split

    # Put the new parent summary into the child table to ensure it will always exist in the final set of tables
    for child in children:
        child._fe_traceability_summary[parent_id] = deepcopy(parent_table)
def _add_node(diagram: Union[pydot.Dot, pydot.Cluster], op: Union[Op, Trace], node_id: str) -> None:
    """Draw a node onto a diagram based on a given op.

    Sometimes and OneOf ops which wrap other ops are drawn as a labelled cluster around the nodes of the wrapped
    ops. Anything else is drawn as a single node.

    Args:
        diagram: The diagram to be appended to.
        op: The op (or trace) to be visualized.
        node_id: The id to use as the node label.
    """
    if isinstance(op, Sometimes) and op.numpy_op:
        cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
        cluster.set('label', f'Sometimes ({op.prob}):')
        cluster.set('labeljust', 'r')
        Traceability._add_node(cluster, op.numpy_op, node_id)
        diagram.add_subgraph(cluster)
        return

    if isinstance(op, OneOf) and op.numpy_ops:
        cluster = pydot.Cluster(style='loosely dotted', graph_name=str(id(op)))
        cluster.set('label', 'One Of:')
        cluster.set('labeljust', 'r')
        for position, sub_op in enumerate(op.numpy_ops):
            # The first candidate takes over the id of the wrapper, the others use their own ids
            sub_id = node_id if position == 0 else str(id(sub_op))
            Traceability._add_node(cluster, sub_op, sub_id)
        diagram.add_subgraph(cluster)
        return

    op_name = op.__class__.__name__
    label = f"{op_name} ({FEID(id(op))})"
    texlbl = HrefFEID(FEID(id(op)), name=op_name).dumps()
    if isinstance(op, ModelOp):
        # Model ops additionally display (and link to) the model which they run
        model_name = op.model.model_name
        label = f"{label}: {model_name}"
        link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
        model_ref = Hyperref(Marker(name=str(model_name), prefix='subsec'), text=link_text).dumps()
        texlbl = f"{texlbl}: {model_ref}"
    diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
def trace_model(model: Model, model_idx: int, model_fn: Any, optimizer_fn: Any, weights_path: Any) -> Model:
    """A function to add traceability information to an FE-compiled model.

    Args:
        model: The model to be made traceable.
        model_idx: Which of the return values from the `model_fn` is this model (or -1 if only a single return value).
        model_fn: The function used to generate this model.
        optimizer_fn: The thing used to define this model's optimizer. If this is a list, the entry at `model_idx`
            is the one which gets recorded.
        weights_path: The path to the weights for this model.

    Returns:
        The `model`, but now with an fe_summary() method.
    """
    tables = {}
    description = {'definition': _trace_value(model_fn, tables, ret_ref=Flag())}
    if model_idx != -1:
        description['index'] = model_idx

    # Pick out the optimizer definition which belongs to this model. The previous combined condition
    # (`a or b and c`) indexed into empty lists and never skipped `None` entries of non-empty lists.
    if isinstance(optimizer_fn, list):
        optimizer = optimizer_fn[model_idx] if optimizer_fn else None
    else:
        optimizer = optimizer_fn or None
    if optimizer is not None:
        description['optimizer'] = _trace_value(optimizer, tables, ret_ref=Flag())

    if weights_path:
        description['weights'] = _trace_value(weights_path, tables, ret_ref=Flag())

    fe_id = FEID(id(model))
    tbl = FeSummaryTable(name=model.model_name, fe_id=fe_id, target_type=type(model), **description)
    tables[fe_id] = tbl
    # Have to put this in a ChainMap b/c dict gets put into model._layers automatically somehow
    model._fe_traceability_summary = ChainMap(tables)

    # Use MethodType to bind the method to the class instance
    setattr(model, 'fe_summary', types.MethodType(fe_summary, model))
    return model
def fix_split_traceabilty(cls,
                          parent: 'FEDataset',
                          children: List['FEDataset'],
                          fractions: Tuple[Union[float, int, Iterable[int]], ...]) -> None:
    """A method to fix traceability information after invoking the dataset .split() method.

    Note that the default implementation of the .split() function invokes this already, so this only needs to be
    invoked if you override the .split() method when defining a subclass (ex. BatchDataset).

    Nothing happens if the `parent` has no `_fe_traceability_summary` attribute.

    Args:
        parent: The parent dataset on which .split() was invoked.
        children: The datasets generated by performing the split.
        fractions: The fraction arguments used to generate the children (should be one-to-one with the children).
    """
    if not hasattr(parent, '_fe_traceability_summary'):
        return

    parent_id = FEID(id(parent))
    for child, frac in zip(children, fractions):
        # noinspection PyProtectedMember
        summaries = deepcopy(child._fe_traceability_summary)
        # Update the ID if necessary
        child_id = FEID(id(child))
        if child_id in summaries:
            child_table = summaries[child_id]
        else:
            # The child was created without invoking its __init__ method, so its internal summary will have the
            # wrong id
            child_table = summaries.pop(parent_id)
            child_table.fe_id = child_id
            summaries[child_id] = child_table
        # Any earlier split description is kept and appended after the new one
        history = child_table.fields.get('split', '')
        if history:
            history = f" | {history}"
        if isinstance(frac, range):
            frac = f"range({frac.start}, {frac.stop}, {frac.step})"
        child_table.fields['split'] = f"Child with final size: {len(child)}, split param: {frac}{history}"
        child._fe_traceability_summary = summaries

    # noinspection PyUnresolvedReferences
    parent_table = parent._fe_traceability_summary.get(parent_id)
    history = parent_table.fields.get('split', '')
    if history:
        history = f" | {history}"
    parent_table.fields['split'] = f"Parent with final size: {len(parent)}{history}"
def _draw_subgraph(diagram: pydot.Dot,
                   label_last_seen: DefaultDict[str, str],
                   subgraph_name: str,
                   subgraph_ops: List[Union[Op, Trace]]) -> None:
    """Draw a subgraph of ops into an existing `diagram`.

    Each op becomes a node inside of a dashed cluster. The edges are added to the `diagram` itself, running from
    the node which last produced each input key to the op which consumes it.

    Args:
        diagram: The diagram to be appended to.
        label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            This is updated with the outputs of every op that is drawn.
        subgraph_name: The name to be associated with this subgraph.
        subgraph_ops: The ops to be wrapped in this subgraph.
    """
    cluster = pydot.Cluster(style='dashed', graph_name=subgraph_name)
    cluster.set('label', subgraph_name)
    cluster.set('labeljust', 'l')

    previous_id = None
    for idx, op in enumerate(subgraph_ops):
        node_id = str(id(op))
        class_name = op.__class__.__name__
        if isinstance(op, ModelOp):
            model_name = op.model.model_name
            label = f"{class_name} ({FEID(id(op))}): {model_name}"
            link_text = NoEscape(r'\textcolor{blue}{') + bold(model_name) + NoEscape('}')
            model_ref = Hyperref(Marker(name=str(model_name), prefix='subsec'), text=link_text).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=class_name).dumps()}: {model_ref}"
        else:
            label = f"{class_name} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=class_name).dumps()
        cluster.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))

        # Group the input keys by the node which produced them so that each source gets a single edge
        keys_by_source = defaultdict(list)
        for key in op.inputs:
            if key != '*':
                keys_by_source[label_last_seen[key]].append(key)
        for source, keys in keys_by_source.items():
            diagram.add_edge(pydot.Edge(src=source, dst=node_id, label=f" {', '.join(keys)} "))

        for key in op.outputs:
            label_last_seen[key] = node_id

        if isinstance(op, Trace) and idx > 0:
            # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
            diagram.add_edge(pydot.Edge(src=previous_id, dst=node_id, style='invis'))
        previous_id = node_id

    diagram.add_subgraph(cluster)
def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
    """Draw a summary diagram of the FastEstimator Ops / Traces.

    The diagram starts with a node for the dataset of the given `mode`, followed by one subgraph each for the
    pipeline ops, the network ops, and the traces which are active during the given `mode` and `epoch`.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    dataset = self.system.pipeline.data[mode]
    if isinstance(dataset, Scheduler):
        dataset = dataset.get_current_value(epoch)

    # Pipeline ops are only drawn when the data source is an FEDataset
    if isinstance(dataset, FEDataset):
        pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch)
    else:
        pipe_ops = []
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
    traces = get_current_items(self.system.traces, run_modes=mode, epoch=epoch)

    diagram = pydot.Dot()
    diagram.set('rankdir', 'TB')
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='record')

    dataset_node_id = str(id(dataset))
    dataset_name = dataset.__class__.__name__
    dataset_node = pydot.Node(dataset_node_id,
                              label=f'{dataset_name} ({FEID(id(dataset))})',
                              texlbl=HrefFEID(FEID(id(dataset)), name=dataset_name).dumps())
    diagram.add_node(dataset_node)

    # Where was this key last generated (keys which have not been seen yet are attributed to the dataset)
    label_last_seen = defaultdict(lambda: dataset_node_id)

    for section_name, section_ops in (('Pipeline', pipe_ops), ('Network', net_ops), ('Traces', traces)):
        self._draw_subgraph(diagram, label_last_seen, section_name, section_ops)
    return diagram
def _add_node(progenitor: pydot.Dot,
              diagram: Union[pydot.Dot, pydot.Cluster],
              op: Union[Op, Trace, Any],
              label_last_seen: DefaultDict[str, str],
              edges: bool = True) -> None:
    """Draw a node onto a diagram based on a given op.

    Ops which wrap other ops (Sometimes, OneOf, Fuse, Repeat and their tensor variants) are drawn as a cluster
    containing the wrapped op(s), which are added by recursive calls. Any other object is drawn as a single node.

    Args:
        progenitor: The very top level diagram onto which Edges should be written.
        diagram: The diagram to be appended to.
        op: The op (or trace) to be visualized.
        label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
        edges: Whether to write Edges to/from this Node.
    """
    node_id = str(id(op))
    if isinstance(op, (Sometimes, SometimesT)) and op.op:
        # Draw the wrapped op inside of a red dotted cluster labelled with the probability
        wrapper = pydot.Cluster(style='dotted', color='red', graph_name=str(id(op)))
        wrapper.set('label', f'Sometimes ({op.prob}):')
        wrapper.set('labeljust', 'l')
        # Look up where the extra inputs come from before the wrapped op gets drawn
        edge_srcs = defaultdict(lambda: [])
        if op.extra_inputs:
            for inp in op.extra_inputs:
                if inp == '*':
                    continue
                edge_srcs[label_last_seen[inp]].append(inp)
        Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
        diagram.add_subgraph(wrapper)
        # The extra-input edges target the first node found in the cluster, with 'lhead' set to the cluster name
        dst_id = Traceability._get_all_nodes(wrapper)[0].get_name()
        for src, labels in edge_srcs.items():
            progenitor.add_edge(
                pydot.Edge(src=src, dst=dst_id, lhead=wrapper.get_name(), label=f" {', '.join(labels)} "))
    elif isinstance(op, (OneOf, OneOfT)) and op.ops:
        wrapper = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=str(id(op)))
        wrapper.set('label', 'One Of:')
        wrapper.set('labeljust', 'l')
        # Only the first candidate op is allowed to draw edges, the remaining ones are drawn without any
        Traceability._add_node(progenitor, wrapper, op.ops[0], label_last_seen, edges=True)
        for sub_op in op.ops[1:]:
            Traceability._add_node(progenitor, wrapper, sub_op, label_last_seen, edges=False)
        diagram.add_subgraph(wrapper)
    elif isinstance(op, (Fuse, FuseT)) and op.ops:
        # The fused ops are drawn as their own subgraph
        Traceability._draw_subgraph(progenitor, diagram, label_last_seen, f'Fuse:', op.ops)
    elif isinstance(op, (Repeat, RepeatT)) and op.op:
        wrapper = pydot.Cluster(style='dotted', color='darkgreen', graph_name=str(id(op)))
        wrapper.set('label', f'Repeat:')
        wrapper.set('labeljust', 'l')
        # A small node displaying the repeat count, or "?" when `op.repeat` is not an int
        wrapper.add_node(
            pydot.Node(node_id,
                       label=f'{op.repeat if isinstance(op.repeat, int) else "?"}',
                       shape='doublecircle',
                       width=0.1))
        # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
        progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
        Traceability._add_node(progenitor, wrapper, op.op, label_last_seen)
        # Add repeat edges
        # These are looked up after the wrapped op has been drawn, and point at the repeat count node
        edge_srcs = defaultdict(lambda: [])
        for out in op.outputs:
            if out in op.inputs and out not in op.repeat_inputs:
                edge_srcs[label_last_seen[out]].append(out)
        for inp in op.repeat_inputs:
            edge_srcs[label_last_seen[inp]].append(inp)
        for src, labels in edge_srcs.items():
            progenitor.add_edge(pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(labels)} "))
        diagram.add_subgraph(wrapper)
    else:
        # A regular node. Model ops additionally display (and link to) the name of their model.
        if isinstance(op, ModelOp):
            label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
            model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'),
                                 text=NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) +
                                 NoEscape('}')).dumps()
            texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
        else:
            label = f"{op.__class__.__name__} ({FEID(id(op))})"
            texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()
        diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
        if isinstance(op, (Op, Trace)) and edges:
            # Need the instance check since subgraph_ops might contain a tf dataset or torch dataloader
            Traceability._add_edge(progenitor, op, label_last_seen)
def _document_models(self) -> None:
    """Add model summaries to the traceability document.

    Every TensorFlow / PyTorch model of the network gets its own subsection holding a text summary and, if the
    model graph could be rendered to a pdf file, a figure of the architecture.
    """
    with self.doc.create(Section("Models")):
        for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
            if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                continue
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection(f"{model.model_name}")):
                if isinstance(model, tf.keras.Model):
                    # Text Summary
                    summary_lines = []
                    model.summary(line_length=92, print_fn=lambda line: summary_lines.append(line))
                    self.doc.append(Verbatim("\n".join(summary_lines)))
                    with self.doc.create(Center()):
                        self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                    # Visual Summary
                    # noinspection PyBroadException
                    try:
                        file_path = os.path.join(self.resource_dir,
                                                 "{}_{}.pdf".format(self.report_name, model.model_name))
                        dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                        # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                        # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                        # set the limit lower (100 inches) to leave some wiggle room.
                        dot.set('size', '100')
                        dot.write(file_path, format='pdf')
                    except Exception:
                        file_path = None
                        print(f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability")
                elif isinstance(model, torch.nn.Module):
                    if not hasattr(model, 'fe_input_spec'):
                        file_path = None
                        self.doc.append("This model was not used by the Network during training.")
                    else:
                        # Text Summary
                        # noinspection PyUnresolvedReferences
                        inputs = model.fe_input_spec.get_dummy_input()
                        # With more than one device, the summaries use `model.module` rather than the model
                        target = model.module if self.system.num_devices > 1 else model
                        self.doc.append(Verbatim(pms.summary(target, inputs, print_summary=False)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # Import has to be done while matplotlib is using the Agg backend
                        previous_backend = matplotlib.get_backend() or 'Agg'
                        matplotlib.use('Agg')
                        # noinspection PyBroadException
                        try:
                            # Fake the IPython import when user isn't running from Jupyter
                            sys.modules.setdefault('IPython', MagicMock())
                            sys.modules.setdefault('IPython.display', MagicMock())
                            import hiddenlayer as hl
                            with Suppressor():
                                graph = hl.build_graph(target, inputs)
                            graph = graph.build_dot()
                            graph.attr(rankdir='TB')  # Switch it to Top-to-Bottom instead of Left-to-Right
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                            # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                            # height, so set the limit lower (100 inches) to leave some wiggle room.
                            graph.attr(size="100,100")
                            graph.attr(margin='0')
                            file_path = graph.render(filename="{}_{}".format(self.report_name, model.model_name),
                                                     directory=self.resource_dir,
                                                     format='pdf',
                                                     cleanup=True)
                        except Exception:
                            file_path = None
                            print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                model.model_name))
                        finally:
                            matplotlib.use(previous_backend)

                if file_path:
                    # Embed the rendered graph, scaled so that it fits onto a single page
                    with self.doc.create(Figure(position='ht!')) as fig:
                        fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                        image_path = os.path.relpath(file_path, start=self.save_dir)
                        fig.add_image(image_path,
                                      width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                        fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value to a FESummaryTable table representation

    Simple values (strings, numbers, None, ...) are returned in a printable form. Containers are converted
    recursively. Most other objects get their own entry in `tables` and are represented by an HrefFEID which
    points at that entry.

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New
            tables are added to this collection in-place.
        ret_ref: A flag to indicate that _trace_value is returning a reference (this is used to figure out whether
            functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        An FESummaryTable representation of the input.
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have it's summary instantiated yet, so it will
        # fall through to the class check at the end to get it's id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        # Bound methods are displayed as <owner>.<method name>
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                var_names = code.co_varnames
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some other
                # function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        # A separate flag is used here in order to find out whether any of the arguments required a reference
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        args = inp.arguments
        args.pop('self', None)  # Note that this removes 'self' from the arguments of `inp` itself
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data=set([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        inp_type = type(inp)
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, (tf.Tensor, torch.Tensor, tf.Variable)):
                if isinstance(inp, torch.Tensor):
                    # Keep the converted array (the conversion result used to be thrown away), so that torch
                    # tensors are summarized in the same numpy form as tf tensors
                    inp = inp.cpu().detach().numpy()
                # In the elif here we're sure to be tf
                elif inp.dtype != tf.dtypes.variant:
                    inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            # Only scalars and short vectors have their values displayed
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                kwargs = _trace_value({k: v
                                       for k, v in inp.__dict__.items() if not k.startswith("_")},
                                      tables,
                                      ret_ref,
                                      wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")
def _document_models(self) -> None:
    """Add model summaries to the traceability document.

    Every TensorFlow / PyTorch model of the network gets its own subsection holding a text summary and, if the
    model graph could be rendered to a pdf file, a figure of the architecture.
    """
    with self.doc.create(Section("Models")):
        for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
            if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                continue
            # Reset the figure path for every model. Otherwise a torch model which was not used by the Network
            # would either re-use the figure of the previous model or hit an unassigned variable.
            file_path = None
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection(f"{model.model_name}")):
                if isinstance(model, tf.keras.Model):
                    # Text Summary
                    summary = []
                    model.summary(line_length=92, print_fn=lambda x: summary.append(x))
                    summary = "\n".join(summary)
                    self.doc.append(Verbatim(summary))
                    with self.doc.create(Center()):
                        self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                    # Visual Summary
                    # noinspection PyBroadException
                    try:
                        file_path = os.path.join(self.figure_dir, f"FE_Model_{model.model_name}.pdf")
                        tf.keras.utils.plot_model(model, to_file=file_path, show_shapes=True, expand_nested=True)
                        # TODO - cap output image size like in the pytorch implementation in case of huge network
                        # TODO - save raw .dot file in case system lacks graphviz
                    except Exception:
                        file_path = None
                        print(f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability")
                elif isinstance(model, torch.nn.Module):
                    if hasattr(model, 'fe_input_spec'):
                        # Text Summary
                        # noinspection PyUnresolvedReferences
                        inputs = model.fe_input_spec.get_dummy_input()
                        self.doc.append(Verbatim(pms.summary(model, inputs)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # Import has to be done while matplotlib is using the Agg backend
                        old_backend = matplotlib.get_backend() or 'Agg'
                        matplotlib.use('Agg')
                        # noinspection PyBroadException
                        try:
                            # Fake the IPython import when user isn't running from Jupyter
                            sys.modules.setdefault('IPython', MagicMock())
                            sys.modules.setdefault('IPython.display', MagicMock())
                            import hiddenlayer as hl
                            with Suppressor():
                                graph = hl.build_graph(model, inputs)
                            graph = graph.build_dot()
                            graph.attr(rankdir='TB')  # Switch it to Top-to-Bottom instead of Left-to-Right
                            graph.attr(size="200,200")  # LaTeX \maxdim is around 575cm (226 inches)
                            graph.attr(margin='0')
                            # TODO - save raw .dot file in case system lacks graphviz
                            file_path = graph.render(filename=f"FE_Model_{model.model_name}",
                                                     directory=self.figure_dir,
                                                     format='pdf',
                                                     cleanup=True)
                        except Exception:
                            file_path = None
                            print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                model.model_name))
                        finally:
                            # Always switch back to the backend which was active before
                            matplotlib.use(old_backend)
                    else:
                        self.doc.append("This model was not used by the Network during training.")
                if file_path:
                    # Embed the rendered graph, scaled so that it fits onto a single page
                    with self.doc.create(Figure(position='ht!')) as fig:
                        fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                        fig.add_image(os.path.relpath(file_path, start=self.save_dir),
                                      width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                        fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))