Exemplo n.º 1
0
 def test_tree_bytes(self, shape, dtype, container):
     """Check that `utils.tree_bytes` matches element count times item size.

     An array of ones with the given `shape` and `dtype` is wrapped with
     `container` and passed to `utils.tree_bytes`; the result is compared
     against a byte count computed directly from the array.
     """
     array = np.ones(shape, dtype=dtype)
     # A 0-d array has an empty shape; count it explicitly as one element.
     num_elements = np.prod(array.shape) if array.ndim else 1
     expected_bytes = num_elements * array.itemsize
     actual_bytes = utils.tree_bytes(container(array))
     self.assertEqual(actual_bytes, expected_bytes)
Exemplo n.º 2
0
    "config":
    Column("Config", lambda r: repr(r.module_details.module)),
    "owned_params":
    Column("Module params", format_owned_params),
    "input":
    Column("Input", format_input),
    "output":
    Column("Output", format_output),
    "params_size":
    Column("Param count",
           lambda r: "{:,}".format(utils.tree_size(r.module_details.params)),
           "right"),
    "params_bytes":
    Column(
        "Param bytes", lambda r: utils.format_bytes(
            utils.tree_bytes(r.module_details.params)), "right"),
}

# Default value of the `columns` argument of `tabulate`.
DEFAULT_COLUMNS = (
    "module",
    "config",
    "owned_params",
    "input",
    "output",
    "params_size",
    "params_bytes",
)
# Default value of the `filters` argument of `tabulate`.
DEFAULT_FILTERS = ("has_output",)


def tabulate(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
    *,
    columns: Optional[Sequence[str]] = DEFAULT_COLUMNS,
    filters: Optional[Sequence[str]] = DEFAULT_FILTERS,
    tabulate_kwargs={"tablefmt": "grid"},
) -> Callable[..., str]:
    # pylint: disable=line-too-long
Exemplo n.º 3
0
    "has_params": Filter(lambda r: bool(r.module_details.params)),
}

# Registry of table columns, keyed by the identifier used to select them.
# Each entry pairs a header string with a callable that formats one row.
# NOTE(review): the optional third `Column` argument ("right") is presumably
# the cell alignment -- confirm against the `Column` definition.
all_columns = {
    "module": Column("Module", format_call_stack),
    "config": Column("Config", lambda row: repr(row.module_details.module)),
    "owned_params": Column("Module params", format_owned_params),
    "input": Column("Input", format_input),
    "output": Column("Output", format_output),
    "params_size": Column(
        "Param count",
        # Thousands separators, e.g. 1,234,567.
        lambda row: "{:,}".format(utils.tree_size(row.module_details.params)),
        "right",
    ),
    "params_bytes": Column(
        "Param bytes",
        lambda row: utils.format_bytes(
            utils.tree_bytes(row.module_details.params)),
        "right",
    ),
}

# The defaults cover every key registered in `all_columns` / `all_filters`,
# in registration order (iterating a dict yields its keys).
DEFAULT_COLUMNS = tuple(all_columns)
DEFAULT_FILTERS = tuple(all_filters)


def tabulate(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
    *,
    columns: Optional[Sequence[str]] = DEFAULT_COLUMNS,
    filters: Optional[Sequence[str]] = DEFAULT_FILTERS,
    tabulate_kwargs={"tablefmt": "grid"},
) -> Callable[..., str]:
  # pylint: disable=line-too-long