def trace_model(model: Model, model_idx: int, model_fn: Any, optimizer_fn: Any, weights_path: Any) -> Model:
    """A function to add traceability information to an FE-compiled model.

    Args:
        model: The model to be made traceable.
        model_idx: Which of the return values from the `model_fn` is this model (or -1 if only a single return value).
        model_fn: The function used to generate this model.
        optimizer_fn: The thing used to define this model's optimizer. If this is a list, the entry at `model_idx` is
            the one which gets recorded. An empty list, a `None` entry, or a falsy non-list value results in no
            optimizer being recorded.
        weights_path: The path to the weights for this model.

    Returns:
        The `model`, but now with an fe_summary() method.
    """
    tables = {}
    description = {'definition': _trace_value(model_fn, tables, ret_ref=Flag())}
    if model_idx != -1:
        description['index'] = model_idx
    # Pick out the optimizer which belongs to this particular model before deciding whether to record it. Checking the
    # truthiness of the list itself would let [None] through and would index into an empty list.
    if isinstance(optimizer_fn, list):
        optimizer = optimizer_fn[model_idx] if optimizer_fn else None
    else:
        optimizer = optimizer_fn or None
    if optimizer is not None:
        description['optimizer'] = _trace_value(optimizer, tables, ret_ref=Flag())
    if weights_path:
        description['weights'] = _trace_value(weights_path, tables, ret_ref=Flag())
    fe_id = FEID(id(model))
    tbl = FeSummaryTable(name=model.model_name, fe_id=fe_id, target_type=type(model), **description)
    tables[fe_id] = tbl
    # Have to put this in a ChainMap b/c dict gets put into model._layers automatically somehow
    model._fe_traceability_summary = ChainMap(tables)

    # Use MethodType to bind the method to the class instance
    setattr(model, 'fe_summary', types.MethodType(fe_summary, model))
    return model
def test_lambda_simple(self):
    """Check the fallback parser's summary of a lambda that wraps one call and reads the local `epochs`."""
    tables = {}
    ret_ref = Flag()
    epochs = 8
    # The lambda statement is kept in its original layout since the fallback parser may inspect source text
    resp = _parse_lambda_fallback(
        lambda step: cosine_decay(step, cycle_length=3750, init_lr=1e-3 + 1 if epochs > 2 else 1e-4),
        tables,
        ret_ref)
    expected_function = r"cosine\_decay(step, cycle\_length=3750, init\_lr=1e{-}3 + 1 if epochs > 2 else 1e{-}4)"
    expected_kwargs = {NoEscape('epochs'): NoEscape(r'\seqsplit{8}')}

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")

    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])

    self.assertIn('kwargs', resp, msg="response should contain kwargs")
    self.assertIsInstance(resp['kwargs'], dict, msg="kwargs should be a dictionary")
    self.assertDictEqual(expected_kwargs, resp['kwargs'])
def test_multi_lambda_different_refs(self):
    """Check that the traced lambda is the one summarized when another lambda with a different body shares its line."""
    tables = {}
    ret_ref = Flag()
    # Both lambdas must stay on the same source line: that is the situation this test exercises
    other, resp = lambda x: np.log2(128) + x, _parse_lambda_fallback(lambda x: np.ceil(128) + x, tables, ret_ref)
    expected_function = r"np.ceil(128) + x"

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])
def test_multi_lambda_same_fn(self):
    """Check the fallback parser's output when two textually identical lambdas appear on one line."""
    tables = {}
    ret_ref = Flag()
    # Both lambdas must stay on the same source line: that is the situation this test exercises
    other, resp = lambda x: x + 'x1', _parse_lambda_fallback(lambda x: x + 'x1', tables, ret_ref)
    expected_function = r"x + 'x1'"

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])
def test_multi_lambda_different_args(self):
    """Check that the traced lambda is summarized when a lambda with a different argument name shares its line."""
    tables = {}
    ret_ref = Flag()
    # Both lambdas must stay on the same source line: that is the situation this test exercises
    resp, other = _parse_lambda_fallback(lambda x: x + 5, tables, ret_ref), lambda y: y + 5
    expected_function = r"x + 5"

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])
def test_nested_lambda_different_strings(self):
    """Check the fallback parser's summary of a lambda whose body contains another lambda and string literals."""
    tables = {}
    ret_ref = Flag()
    # The lambda statement is kept in its original layout since the fallback parser may inspect source text
    resp = _parse_lambda_fallback(lambda x, y: x(lambda x, y: x(y) + 'x' + 'y') + 'x', tables, ret_ref)
    expected_function = r"x(lambda x, y: x(y) + 'x' + 'y') + 'x'"

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])
def test_lambda_inlining(self):
    """Check the fallback parser's summary of a lambda with default arguments, indexing, and a set comprehension."""
    tables = {}
    ret_ref = Flag()
    # The lambda statement is kept in its original layout since the fallback parser may inspect source text
    resp = _parse_lambda_fallback(
        lambda a, b=[0, 1, 2, 3], c={'x': "it's"}: b[0] + a * c['x'] - {0 for j in range(5)}, tables, ret_ref)
    expected_function = r"b{[}0{]} + a * c{[}'x'{]} {-} \{0 for j in range(5)\}"

    self.assertIsInstance(resp, dict, msg="_parse_lambda_fallback should return a dictionary")
    self.assertEqual({},
                     tables,
                     msg="_parse_lambda_fallback should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertEqual(expected_function, resp['function'])
def test_simple_lambda_inlining(self):
    """Check that tracing a one-expression lambda yields a ContainerList and creates no tables."""
    tables = {}
    ret_ref = Flag()
    # The lambda statement is kept in its original layout in case the lambda parser inspects source text
    resp = _trace_value(lambda x: x + 5, tables, ret_ref)

    self.assertEqual({}, tables, msg="trace_value should not have generated any tables for this lambda")
    self.assertIsInstance(resp,
                          ContainerList,
                          msg="trace_value should return a ContainerList describing the function")
def init(self, *args, **kwargs):
    """Attach traceability bookkeeping to `self` wherever it is missing, then invoke `base_init`.

    Args:
        *args: Positional arguments which are forwarded to `base_init`.
        **kwargs: Keyword arguments which are forwarded to `base_init`.
    """
    if not hasattr(self, '_fe_state_whitelist'):
        self._fe_state_whitelist = whitelist
    if not hasattr(self, '_fe_state_blacklist'):
        internal_names = ('_fe_state_whitelist', '_fe_state_blacklist', '_fe_base_init')
        self._fe_state_blacklist = blacklist + internal_names
    if not hasattr(self, '_fe_traceability_summary'):
        # Resolve the call against the signature of base_init (filling in defaults) and trace the bound call
        call_args = inspect.signature(base_init).bind(self, *args, **kwargs)
        call_args.apply_defaults()
        summary_tables = {}
        _trace_value(_BoundFn(self, call_args), summary_tables, ret_ref=Flag())
        self._fe_traceability_summary = summary_tables
    base_init(self, *args, **kwargs)
def test_conditional_lambda(self):
    """Check that _parse_lambda summarizes a chain of conditional expressions as a ContainerList without tables."""
    tables = {}
    ret_ref = Flag()
    a = 5
    # The lambda statement is kept in its original layout in case the lambda parser inspects source text
    resp = _parse_lambda(
        lambda x: [0, 1] if x > 10 else (1, a) if x > 8 else {1, 3} if x > 6 else {1: 5} if x < 0 else {
            'key': 0, 'key2': 1
        },
        tables,
        ret_ref)

    self.assertIsInstance(resp, dict, msg="_parse_lambda should return a dictionary")
    self.assertEqual({}, tables, msg="_parse_lambda should not have generated any tables for this lambda")
    self.assertIn('function', resp, msg="response should contain a function summary")
    self.assertIsInstance(resp['function'], ContainerList, msg="_parse_lambda should return a ContainerList")
def test_traceable_summary(self):
    """Check that tracing a TraceableObject returns an Href to a single table holding its 'a' and 'b' values."""
    traced_obj = TraceableObject('x', 11)
    summary_tables = {}
    reference = _trace_value(traced_obj, summary_tables, Flag())

    self.assertIsInstance(reference, HrefFEID, msg="trace_value should have returned an Href")
    self.assertEqual(1, len(summary_tables), msg="trace_value should have generated 1 table")
    self.assertIn(reference.fe_id, summary_tables, msg="Object summary table is missing")

    recorded = summary_tables[reference.fe_id].kwargs
    self.assertEqual(len(recorded), 2, msg="trace_value should have found 2 variables to display")
    expected_members = (
        ('a', NoEscape(r"\seqsplit{`x'}")),
        ('b', NoEscape(r"\seqsplit{11}")),
    )
    for member, expected_value in expected_members:
        self.assertIn(member, recorded, msg=f"the variable '{member}' should have been found")
        self.assertEqual(expected_value, recorded[member], msg="member variable value improperly recorded")
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value to a FESummaryTable table representation

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New tables
            are added to this dictionary in place.
        ret_ref: A flag to indicate that _trace_value is returning a reference (this is used to figure out whether
            functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        An FESummaryTable representation of the input.
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have its summary instantiated yet, so it will
        # fall through to the class check at the end to get its id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                var_names = code.co_varnames
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some other
                # function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        # A private flag is used here so that we can tell whether tracing the args required any references
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        args = inp.arguments
        args.pop('self', None)
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple(_trace_value(x, tables, ret_ref, wrap_str) for x in inp),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data={_trace_value(x, tables, ret_ref, wrap_str)
                                 for x in inp},
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        inp_type = type(inp)
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, (tf.Tensor, torch.Tensor, tf.Variable)):
                if isinstance(inp, torch.Tensor):
                    # Keep the numpy conversion (rather than discarding it) so that torch tensors get summarized in
                    # the same way as tf tensors and numpy arrays
                    inp = inp.cpu().detach().numpy()
                # In the elif here we're sure to be tf
                elif inp.dtype != tf.dtypes.variant:
                    inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            # Only scalars and short vectors have their values printed
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                kwargs = _trace_value({k: v
                                       for k, v in inp.__dict__.items() if not k.startswith("_")},
                                      tables,
                                      ret_ref,
                                      wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")