Exemplo n.º 1
0
}

# Every column available for the summary table, keyed by the short name used
# to select it (the names in DEFAULT_COLUMNS are keys of this dict).
# Each ``Column`` takes a header string and a per-row formatter returning a
# string. The numeric columns also pass "right" as a third argument, which is
# presumably their alignment; confirm against the ``Column`` definition.
all_columns = {
    "module": Column("Module", format_call_stack),
    "config": Column("Config", lambda row: repr(row.module_details.module)),
    "owned_params": Column("Module params", format_owned_params),
    "input": Column("Input", format_input),
    "output": Column("Output", format_output),
    # Total parameter count, with thousands separators (e.g. "1,234").
    "params_size": Column(
        "Param count",
        lambda row: f"{utils.tree_size(row.module_details.params):,}",
        "right",
    ),
    # Total parameter storage, rendered by utils.format_bytes.
    "params_bytes": Column(
        "Param bytes",
        lambda row: utils.format_bytes(
            utils.tree_bytes(row.module_details.params)),
        "right",
    ),
}

# Column names used by default; every entry is a key of ``all_columns``.
# NOTE(review): presumably the default column selection for ``tabulate``.
# Confirm against its signature.
DEFAULT_COLUMNS = (
    "module",
    "config",
    "owned_params",
    "input",
    "output",
    "params_size",
    "params_bytes",
)
# Names of the filters applied by default. What "has_output" does is defined
# elsewhere in this module.
DEFAULT_FILTERS = ("has_output",)


def tabulate(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
    *,
Exemplo n.º 2
0
 def test_tree_size(self, shape, dtype, container):
     """Checks utils.tree_size(container(arr)) equals the element count of arr.

     ``container`` is applied to an all-ones array of the given shape/dtype;
     presumably it wraps the array in some pytree structure (TODO confirm
     against the parameterization).
     """
     leaf = np.ones(shape, dtype=dtype)
     # A 0-d array holds exactly one element; otherwise count = prod(shape).
     want = 1 if leaf.ndim == 0 else np.prod(leaf.shape)
     got = utils.tree_size(container(leaf))
     self.assertEqual(got, want)