def trace_model(model: Model, model_idx: int, model_fn: Any, optimizer_fn: Any, weights_path: Any) -> Model:
    """A function to add traceability information to an FE-compiled model.

    Args:
        model: The model to be made traceable.
        model_idx: Which of the return values from the `model_fn` is this model (or -1 if only a single return value).
        model_fn: The function used to generate this model.
        optimizer_fn: The thing used to define this model's optimizer. If this is a list, the entry at `model_idx` is
            the one belonging to this model. A falsy value (None, an empty list, or a None entry in the list) is
            treated as "no optimizer" and omitted from the summary.
        weights_path: The path to the weights for this model.

    Returns:
        The `model`, but now with an fe_summary() method.
    """
    tables = {}
    description = {'definition': _trace_value(model_fn, tables, ret_ref=Flag())}
    if model_idx != -1:
        description['index'] = model_idx
    # Resolve this model's optimizer before testing it: indexing into an empty list would raise, and a None entry
    # in a per-model list means this particular model has nothing to report.
    optimizer = optimizer_fn
    if isinstance(optimizer_fn, list):
        optimizer = optimizer_fn[model_idx] if optimizer_fn else None
    if optimizer:
        description['optimizer'] = _trace_value(optimizer, tables, ret_ref=Flag())
    if weights_path:
        description['weights'] = _trace_value(weights_path, tables, ret_ref=Flag())
    fe_id = FEID(id(model))
    tbl = FeSummaryTable(name=model.model_name, fe_id=fe_id, target_type=type(model), **description)
    tables[fe_id] = tbl
    # Have to put this in a ChainMap b/c dict gets put into model._layers automatically somehow
    model._fe_traceability_summary = ChainMap(tables)

    # Use MethodType to bind the method to the class instance
    setattr(model, 'fe_summary', types.MethodType(fe_summary, model))
    return model
Example #2
0
 def test_lambda_simple(self):
     """_parse_lambda_fallback on a lambda spanning several source lines that contains a conditional expression.

     Expects no new tables, the LaTeX-escaped lambda body under 'function', and the referenced closure variable
     `epochs` reported under 'kwargs'.
     """
     tables = {}
     ret_ref = Flag()
     epochs = 8
     # NOTE(review): the expected 'function' string mirrors the lambda's source text, so its layout across the
     # following lines is presumably part of the test input - keep it as is.
     resp = _parse_lambda_fallback(
         lambda step: cosine_decay(step,
                                   cycle_length=3750,
                                   init_lr=1e-3 + 1
                                   if epochs > 2 else 1e-4), tables,
         ret_ref)
     self.assertIsInstance(
         resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual(
         {}, tables,
         "_parse_lambda_fallback should not have generated any tables for this lambda"
     )
     self.assertIn('function', resp,
                   "response should contain a function summary")
     self.assertEqual(
         r"cosine\_decay(step, cycle\_length=3750, init\_lr=1e{-}3 + 1 if epochs > 2 else 1e{-}4)",
         resp['function'])
     self.assertIn('kwargs', resp, "response should contain kwargs")
     self.assertIsInstance(resp['kwargs'], dict,
                           "kwargs should be a dictionary")
     # `epochs` is read from the enclosing test scope, so it should be reported as a kwarg
     self.assertDictEqual({NoEscape('epochs'): NoEscape(r'\seqsplit{8}')},
                          resp['kwargs'])
Example #3
0
 def test_multi_lambda_different_refs(self):
     """Two lambdas on one source line with the same parameter but different bodies (np.log2 vs np.ceil).

     _parse_lambda_fallback must describe the lambda it was actually given (the np.ceil one).
     """
     tables = {}
     ret_ref = Flag()
     # Both lambdas deliberately share this single source line.
     other, resp = lambda x: np.log2(128) + x, _parse_lambda_fallback(lambda x: np.ceil(128) + x, tables, ret_ref)
     self.assertIsInstance(resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertEqual(r"np.ceil(128) + x", resp['function'])
Example #4
0
 def test_multi_lambda_same_fn(self):
     """Two textually identical lambdas on one source line; the parsed body should still be "x + 'x1'"."""
     tables = {}
     ret_ref = Flag()
     # Both lambdas deliberately share this single source line.
     other, resp = lambda x: x + 'x1', _parse_lambda_fallback(lambda x: x + 'x1', tables, ret_ref)
     self.assertIsInstance(resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertEqual(r"x + 'x1'", resp['function'])
Example #5
0
 def test_multi_lambda_different_args(self):
     """Two lambdas on one source line differing only in parameter name (x vs y); the `x` one is parsed."""
     tables = {}
     ret_ref = Flag()
     # Both lambdas deliberately share this single source line.
     resp, other = _parse_lambda_fallback(lambda x: x + 5, tables, ret_ref), lambda y: y + 5
     self.assertIsInstance(resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertEqual(r"x + 5", resp['function'])
Example #6
0
 def test_nested_lambda_different_strings(self):
     """A lambda containing a nested lambda: the outer body is reported with the inner lambda kept verbatim."""
     tables = {}
     ret_ref = Flag()
     resp = _parse_lambda_fallback(lambda x, y: x(lambda x, y: x(y) + 'x' + 'y') + 'x', tables, ret_ref)
     self.assertIsInstance(resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertEqual(r"x(lambda x, y: x(y) + 'x' + 'y') + 'x'", resp['function'])
Example #7
0
 def test_lambda_inlining(self):
     """A lambda with bracket/brace-containing defaults and a set comprehension split over two source lines.

     The expected body has LaTeX escaping applied to brackets ({[}, {]}), braces (\\{, \\}) and minus ({-}).
     """
     tables = {}
     ret_ref = Flag()
     # NOTE(review): the lambda's source text (including the line break inside the set comprehension) is
     # presumably part of the test input - keep the layout as is.
     resp = _parse_lambda_fallback(
         lambda a, b=[0, 1, 2, 3], c={'x': "it's"}: b[0] + a * c['x'] - {0
                                                                         for j in range(5)}, tables, ret_ref)
     self.assertIsInstance(resp, dict, "_parse_lambda_fallback should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertEqual(r"b{[}0{]} + a * c{[}'x'{]} {-} \{0 for j in range(5)\}", resp['function'])
Example #8
0
 def test_simple_lambda_inlining(self):
     """A trivial lambda traced via _trace_value is in-lined as a ContainerList rather than given its own table."""
     tables = {}
     ret_ref = Flag()
     resp = _trace_value(lambda x: x + 5, tables, ret_ref)
     self.assertEqual(
         {}, tables,
         "trace_value should not have generated any tables for this lambda")
     self.assertIsInstance(
         resp, ContainerList,
         "trace_value should return a ContainerList describing the function"
     )
 def init(self, *args, **kwargs):
     """Wrapped constructor: record traceability bookkeeping on the instance, then run the original `base_init`.

     Each bookkeeping attribute is only assigned when the instance does not already have it, so values set by an
     earlier wrapped constructor (presumably a subclass's, before it calls super().__init__) take precedence.
     """
     if not hasattr(self, '_fe_state_whitelist'):
         self._fe_state_whitelist = whitelist
     if not hasattr(self, '_fe_state_blacklist'):
         # The wrapper's own bookkeeping attributes are always appended to the blacklist.
         fe_internal = ('_fe_state_whitelist', '_fe_state_blacklist', '_fe_base_init')
         self._fe_state_blacklist = blacklist + fe_internal
     if not hasattr(self, '_fe_traceability_summary'):
         # Capture the constructor arguments, including defaults the caller did not pass explicitly.
         call_binding = inspect.signature(base_init).bind(self, *args, **kwargs)
         call_binding.apply_defaults()
         summary_tables = {}
         _trace_value(_BoundFn(self, call_binding), summary_tables, ret_ref=Flag())
         # Assign only after tracing: while being traced, `self` must not yet carry a summary attribute.
         self._fe_traceability_summary = summary_tables
     base_init(self, *args, **kwargs)
Example #10
0
 def test_conditional_lambda(self):
     """_parse_lambda on a chain of conditional expressions returning list/tuple/set/dict literals.

     Expects a dict with a 'function' entry that is a ContainerList and no new tables.
     """
     tables = {}
     ret_ref = Flag()
     a = 5
     # NOTE(review): the lambda's multi-line source layout below is presumably part of the test input - keep it.
     resp = _parse_lambda(
         lambda x: [0, 1] if x > 10 else (1, a) if x > 8 else {1, 3} if x > 6 else {1: 5} if x < 0 else {
             'key': 0, 'key2': 1
         },
         tables,
         ret_ref)
     self.assertIsInstance(resp, dict, "_parse_lambda should return a dictionary")
     self.assertEqual({}, tables, "_parse_lambda should not have generated any tables for this lambda")
     self.assertIn('function', resp, "response should contain a function summary")
     self.assertIsInstance(resp['function'], ContainerList, "_parse_lambda should return a ContainerList")
Example #11
0
 def test_traceable_summary(self):
     """A TraceableObject should be summarized as an Href to exactly one table listing its 2 member variables."""
     obj = TraceableObject('x', 11)
     summary_tables = {}
     flag = Flag()
     href = _trace_value(obj, summary_tables, flag)
     self.assertIsInstance(href, HrefFEID, "trace_value should have returned an Href")
     self.assertEqual(1, len(summary_tables), "trace_value should have generated 1 table")
     self.assertIn(href.fe_id, summary_tables, "Object summary table is missing")
     recorded = summary_tables[href.fe_id].kwargs
     self.assertEqual(len(recorded), 2, "trace_value should have found 2 variables to display")
     # Each member should be present and rendered as a \seqsplit-wrapped LaTeX value.
     expected_values = (('a', r"\seqsplit{`x'}"), ('b', r"\seqsplit{11}"))
     for var_name, latex in expected_values:
         self.assertIn(var_name, recorded, f"the variable '{var_name}' should have been found")
         self.assertEqual(NoEscape(latex), recorded[var_name], "member variable value improperly recorded")
def _trace_value(inp: Any, tables: Dict[FEID, FeSummaryTable], ret_ref: Flag, wrap_str: bool = True) -> Any:
    """Convert an input value into a LaTeX-friendly summary representation, registering tables for complex objects.

    Simple values (strings, numbers, containers, simple lambdas and simple function calls) are returned in a form
    which can be rendered in-line. Functions, classes, traceable objects, models, tensors, and other objects get their
    own FeSummaryTable entry in `tables`, and an HrefFEID pointing at that entry is returned instead.

    Args:
        inp: The input value to be converted.
        tables: A collection of tables representing objects which are used by the current stack of inputs. New tables
            are added to this dictionary in place.
        ret_ref: A flag to indicate that _trace_value is returning a reference (this is used to figure out whether
            functions can be in-lined or deserve their own tables).
        wrap_str: Whether literal string values should be wrapped inside extra quote marks.

    Returns:
        A representation of the input: a LaTeX string (NoEscape or str), a PyContainer / ContainerList of traced
        values, a dict of traced args/kwargs, the input itself for other primitive values, or an HrefFEID referencing
        an entry in `tables`.
    """
    if isinstance(inp, str):
        inp = f"`{escape_latex(inp)}'" if wrap_str else escape_latex(inp)
        if wrap_str:
            # Prevent extremely long strings from overflowing the table
            return NoEscape(r'\seqsplit{' + inp + '}')
        return inp
    elif isinstance(inp, (int, float, bool, type(None), HrefFEID, FEID, PyContainer)):
        if isinstance(inp, (int, float)):
            # Prevent extremely long numbers from overflowing the table (bool is a subclass of int, so it lands here)
            return NoEscape(r'\seqsplit{' + str(inp) + '}')
        return inp
    elif hasattr(inp, '_fe_traceability_summary'):
        # The first time a traceable object goes through here it won't have its summary instantiated yet, so it will
        # fall through to the class check at the end to get its id.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        tables.update(inp._fe_traceability_summary)
        inp_id = FEID(id(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, tables[inp_id].name)
    elif inspect.ismethod(inp):
        # Bound method: describe the owning object followed by ".method_name"
        parent = _trace_value(inp.__self__, tables, ret_ref, wrap_str)
        return ContainerList(data=[parent, escape_latex(f".{inp.__name__}")])
    elif inspect.isfunction(inp) or inspect.isclass(inp):
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            if inspect.isfunction(inp) and inp.__name__ == "<lambda>":
                code = inp.__code__
                # co_varnames lists the parameters first (positional, keyword-only, *args, **kwargs) followed by any
                # other locals. Walrus targets and (since Python 3.12) inlined comprehension loop variables are also
                # locals, so keep only the parameter names.
                n_params = code.co_argcount + code.co_kwonlyargcount
                n_params += bool(code.co_flags & inspect.CO_VARARGS) + bool(code.co_flags & inspect.CO_VARKEYWORDS)
                var_names = code.co_varnames[:n_params]
                # Attempt to figure out what the lambda function is doing. If it is being used only to invoke some other
                # function (like one might do with LRScheduler), then the parse should work.
                flag = Flag()
                func_description = _parse_lambda(inp, tables, flag) or {}
                func_description['vars'] = _trace_value(var_names, tables, flag, wrap_str=False)
                name = "lambda"
                path = None
                if not flag and func_description.keys() == {'vars', 'function'}:
                    # This is a simple lambda function, so inline it instead of making a new table
                    raw_vars = func_description['vars'].raw_input
                    formatted_vars = []
                    for var in raw_vars:
                        formatted_vars.append(var)
                        formatted_vars.append(', ')
                    if formatted_vars:
                        formatted_vars.pop()  # remove trailing comma
                    return ContainerList(data=[
                        TextColor('cyan', f"{name} "), *formatted_vars, ": ", func_description.get('function', '')
                    ])
            else:
                name = inp.__name__
                path = f"{inp.__module__}.{inp.__qualname__}"
                func_description = {}
            tables[inp_id] = FeSummaryTable(name=name,
                                            fe_id=inp_id,
                                            target_type=type(inp),
                                            path=path,
                                            **func_description)
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, _Function):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if inspect.ismethod(inp.func):
                path = _trace_value(inp.func, tables, ret_ref, wrap_str)
            elif hasattr(inp.func, '__module__') and hasattr(inp.func, '__qualname__'):
                path = f"{inp.func.__module__}.{inp.func.__qualname__}"
            else:
                path = None
            tables[inp_id] = FeSummaryTable(name=inp.name, fe_id=inp_id, target_type=type(inp.func), path=path)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.name)
    elif isinstance(inp, _PartialBind):
        return {
            "args": _trace_value(inp.args, tables, ret_ref, wrap_str=True),
            "kwargs": _trace_value(inp.kwargs, tables, ret_ref, wrap_str).raw_input  # unwrap kwargs back into a dict
        }
    elif isinstance(inp, _Command):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            escape_latex(inp.command),
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _Condition):
        return ContainerList(data=[
            _trace_value(inp.left, tables, ret_ref, wrap_str),
            " if ",
            _trace_value(inp.condition, tables, ret_ref, wrap_str),
            " else ",
            _trace_value(inp.right, tables, ret_ref, wrap_str)
        ])
    elif isinstance(inp, _BoundFn):
        flag = Flag()
        args = _trace_value(inp.args, tables, flag, wrap_str=False)
        kwargs = {}
        if isinstance(inp.args, _PartialBind):
            kwargs = args["kwargs"]
            args = args["args"]
        elif isinstance(args, dict):
            kwargs = args
            args = None
        if not flag and isinstance(inp.func, _Function):
            # The function args are simple, so inline this function in whatever is above it
            if isinstance(args, PyContainer):
                args = args.raw_input
            if isinstance(kwargs, PyContainer):
                kwargs = kwargs.raw_input
            formatted = ["("]
            args = args or ()
            kwargs = kwargs or {}
            for arg in args:
                formatted.append(arg)
                formatted.append(", ")
            for key, value in kwargs.items():
                formatted.append(key)
                formatted.append("=")
                formatted.append(value)
                formatted.append(", ")
            if len(formatted) > 1:
                formatted.pop()  # Remove trailing comma
            formatted.append(")")
            if inspect.ismethod(inp.func.func):
                container_list = _trace_value(inp.func.func, tables, ret_ref, wrap_str)
                container_list.data.extend(formatted)
                return container_list
            return ContainerList(data=[inp.func.name, *formatted])
        else:
            # The function args are complicated, so use the normal approach
            func_href = _trace_value(inp.func, tables, ret_ref, wrap_str)
            inp_id = func_href.fe_id
            inp_table = tables[inp_id]
            inp_table.args = args
            inp_table.kwargs = kwargs
            ret_ref.set_true()
            return func_href
    elif isinstance(inp, inspect.BoundArguments):
        # Build a filtered copy rather than popping 'self' so the caller's BoundArguments is left untouched
        args = {arg_name: arg_val for arg_name, arg_val in inp.arguments.items() if arg_name != 'self'}
        return _trace_value(args, tables, ret_ref, wrap_str=False).raw_input  # unwrap kwargs back into a dict
    elif isinstance(inp, _VarWrap):
        return inp.var
    elif isinstance(inp, (tf.keras.Model, torch.nn.Module)):
        # FE models should never actually get here since they are given summaries by trace_model() during fe.build()
        inp_id = FEID(id(inp))
        if inp_id in tables:
            name = tables[inp_id].name
        else:
            name = inp.model_name if hasattr(inp, 'model_name') else "<Unknown Model Name>"
            tables[inp_id] = FeSummaryTable(name=name, fe_id=inp_id, target_type=type(inp))
        ret_ref.set_true()
        return HrefFEID(inp_id, name)
    elif isinstance(inp, list):
        return PyContainer(data=[_trace_value(x, tables, ret_ref, wrap_str) for x in inp],
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, tuple):
        return PyContainer(data=tuple([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, set):
        return PyContainer(data=set([_trace_value(x, tables, ret_ref, wrap_str) for x in inp]),
                           truncate=_CollectionSizeLimit)
    elif isinstance(inp, dict):
        return PyContainer(
            data={
                _trace_value(k, tables, ret_ref, wrap_str=wrap_str): _trace_value(v, tables, ret_ref, wrap_str=True)
                for k, v in inp.items()
            },
            truncate=_CollectionSizeLimit)
    elif isinstance(inp, (tf.Tensor, torch.Tensor, np.ndarray, tf.Variable)):
        inp_type = type(inp)
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            if isinstance(inp, (tf.Tensor, torch.Tensor, tf.Variable)):
                if isinstance(inp, torch.Tensor):
                    # Keep the converted array so torch tensors are summarized exactly like tf / numpy ones
                    inp = inp.cpu().detach().numpy()
                # In the elif here we're sure to be tf
                elif inp.dtype != tf.dtypes.variant:
                    inp = inp.numpy()  # The variant dtype can't be cast to numpy()
            rank = inp.ndim
            description = {'shape': inp.shape}
            if rank == 0 or (rank == 1 and inp.shape[0] <= 10):
                description['values'] = str(inp)
            tables[inp_id] = FeSummaryTable(name="tensor", fe_id=inp_id, target_type=inp_type, **description)
        ret_ref.set_true()
        return HrefFEID(inp_id, "tensor")
    # This should be the last elif (every Python object has a __class__, so the final else is only a safety net)
    elif hasattr(inp, '__class__'):
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            kwargs = {}
            path = None
            if hasattr(inp, '__dict__') and '_fe_state_whitelist' not in inp.__dict__:
                # Prevent circular recursion
                tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__, target_type=type(inp), fe_id=inp_id)
                # This object isn't @traceable but does have some stored variables that we can summarize.
                public_vars = {k: v for k, v in inp.__dict__.items() if not k.startswith("_")}
                kwargs = _trace_value(public_vars, tables, ret_ref, wrap_str=False).raw_input
                path = "Not @traceable, so summary is approximate"
            tables[inp_id] = FeSummaryTable(name=inp.__class__.__name__,
                                            target_type=type(inp),
                                            path=path,
                                            fe_id=inp_id,
                                            kwargs=kwargs)
        ret_ref.set_true()
        return HrefFEID(inp_id, inp.__class__.__name__)
    else:
        inp_id = FEID(id(inp))
        if inp_id not in tables:
            tables[inp_id] = FeSummaryTable(name="an object", target_type=type(inp), fe_id=inp_id)
        ret_ref.set_true()
        return HrefFEID(inp_id, "an object")