def wrapped(self, *args, **kwargs):
    """Run the wrapped method inside this module's frame and name scope.

    The call is routed through `run_interceptors` and, when the module has a
    name, through a profiler/name-stack `named_call`. After the method
    returns, every other module currently on the frame's module stack has
    this module's name added to its `_submodules` set.

    Raises:
      ValueError: if there is no active frame (`base.frame_stack` is empty).
    """
    if not base.frame_stack:
        raise ValueError(
            "All `hk.Module`s must be initialized inside an `hk.transform`.")

    # The name under which submodules are filed may be overridden via
    # `@name_like("other_method")`. Interceptors and custom getters still see
    # the real method name; the override only affects submodule naming.
    scope_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

    frame = base.current_frame()
    state = base.ModuleState(module=self, method_name=scope_method_name)
    with frame.module(state), _module_method_call(self, method_name):
        # hk.Module enters the module name scope for all methods.
        module_name = getattr(self, "module_name", None)

        bound = functools.partial(unbound_method, self)
        call = functools.partial(run_interceptors, bound, method_name, self)

        # Read the flag first so it is consulted regardless of whether the
        # module already has a name.
        use_name_stack = jax.config.jax_experimental_name_stack
        if module_name:
            if use_name_stack:
                leaf_name = module_name.rsplit("/", 1)[-1]
                call = jax.named_call(call, name=leaf_name)
                if method_name != "__call__":
                    call = jax.named_call(call, name=method_name)
            else:
                # TODO(lenamartens): remove this branch once
                # jax_experimental_name_stack flag is removed.
                cfg = config.get_config()
                if cfg.profiler_name_scopes and method_name == "__call__":
                    leaf_name = module_name.rsplit("/", 1)[-1]
                    call = stateful.named_call(call, name=leaf_name)

        out = call(*args, **kwargs)

        # The constructor is what assigns the module name, so when the method
        # being run is `__init__` the name only exists *after* the call. Every
        # other method needs the name beforehand to wrap itself in
        # `named_call`, hence the second lookup here.
        if module_name is None:
            module_name = getattr(self, "module_name", None)

        # Tell every enclosing module that we exist.
        if module_name is not None:
            for ancestor_state in frame.module_stack:
                ancestor = ancestor_state.module
                if ancestor is not self:
                    ancestor._submodules.add(module_name)  # pylint: disable=protected-access
    return out
def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               prevent_cse: bool, differentiated: bool,
               policy,
               is_gpu_platform: bool = False,
               concrete: bool = False,
               name: str = "checkpoint"):
    """Evaluate a rematerialized jaxpr under a `named_call` scope.

    Accepts the jaxpr either as `jaxpr` (remat2) or as `call_jaxpr` (remat);
    `jaxpr` wins when both are given. `name` is not passed by remat2 and so
    falls back to "checkpoint". `concrete` and `policy` are accepted but
    unused.

    When `differentiated` and `prevent_cse` are both truthy, one of the
    dedicated remat translation rules is selected; otherwise the jaxpr is
    simply evaluated with `core.eval_jaxpr`.
    """
    # TODO: remove call_jaxpr once we drop the remat call primitive
    body = call_jaxpr if jaxpr is None else jaxpr
    assert body is not None
    assert not body.constvars

    del concrete, policy  # Unused.

    if not (differentiated and prevent_cse):
        def rule(*operands, jaxpr):
            return core.eval_jaxpr(jaxpr, (), *operands)
    elif config.jax_remat_opt_barrier:
        rule = _remat_translation_using_opt_barrier
    elif is_gpu_platform:
        rule = _remat_translation_using_while
    else:
        rule = _remat_translation_using_cond

    scope_name = wrap_name(name, "remat")
    named_rule = jax.named_call(rule, name=scope_name)
    return named_rule(*args, jaxpr=body)
def inner(scope_fn, repack_fn, variable_groups, rng_groups, args, kwargs):
    """Invoke `fn` on a freshly built scope inside a `jax.named_call`.

    The scope is built from `scope_fn(variable_groups, rng_groups)`, and the
    result is the pair `(fn(scope, *args, **kwargs), repack_fn(scope))`.
    """
    @functools.wraps(fn)
    def scoped_call(variables, rngs):
        scope = scope_fn(variables, rngs)
        result = fn(scope, *args, **kwargs)
        return result, repack_fn(scope)

    return jax.named_call(scoped_call, name=name)(variable_groups, rng_groups)
def wrapper(*args, **kwargs):
    """Call `fun` under a `jax.named_call` scope called `name`.

    Inside a transform, `jax.named_call` is first passed through
    `thread_hk_state_in_kwargs`; outside one it is used directly.
    """
    if base.inside_transform():
        make_named = thread_hk_state_in_kwargs(jax.named_call)
    else:
        make_named = jax.named_call
    return make_named(fun, name=name)(*args, **kwargs)
def wrapper(*args, **kwargs):
    """Call `fun` under a `jax.named_call` scope called `name`.

    With `jax_experimental_name_stack` enabled, `fun` is wrapped directly.
    Otherwise `fun` is first passed through `hide_non_jaxtype_outputs` along
    with a side-channel dict. After the call, the final output is rebuilt
    from the returned leaves, the side channel's `"non_jaxtypes"` list and
    its `"treedef"`.

    Returns:
      The output of `fun`, rebuilt with its original tree structure.
    """
    if jax.config.jax_experimental_name_stack:
        return jax.named_call(fun, name=name)(*args, **kwargs)

    # NOTE(review): presumably `hide_non_jaxtype_outputs` fills this dict
    # while the wrapped function runs -- confirm against its definition.
    side_channel = {"non_jaxtypes": [], "treedef": None}
    wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
    if base.inside_transform():
        wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(
            wrapped_fun, name=name)
    else:
        wrapped_fun = jax.named_call(wrapped_fun, name=name)

    jax_types = wrapped_fun(*args, **kwargs)
    non_jaxtypes = side_channel["non_jaxtypes"]

    # A `None` among the returned leaves means "take the value at the same
    # position from the side channel instead".
    out_leaves = [
        y if x is None else x for x, y in zip(jax_types, non_jaxtypes)
    ]

    # `jax.tree_unflatten` is a deprecated top-level alias; use the stable
    # `jax.tree_util` entry point.
    return jax.tree_util.tree_unflatten(side_channel["treedef"], out_leaves)
def f_caller(x):
    """Return `sin(callee(tanh(x)))`, with `f_callee` run under the name "callee"."""
    named_callee = jax.named_call(f_callee, name="callee")
    inner_value = named_callee(jnp.tanh(x))
    return jnp.sin(inner_value)
def g(x):
    """Pass `x` through `lax.cond` whose two branches are the same named identity."""
    # Kept as a lambda so the default call name stays "<lambda>".
    identity_branch = jax.named_call(lambda operand: operand)
    return jax.lax.cond(True, identity_branch, identity_branch, x)
def named_call(f=None, name=None):
    """Attach a profiling name to a function.

    Works both as a direct call, `named_call(fn, name="x")`, and as a
    decorator factory, `@named_call(name="x")`: when `f` is omitted, a
    partial of this function with `name` bound is returned.
    """
    if f is not None:
        return jax.named_call(f, name=name)
    return functools.partial(named_call, name=name)
def solve(item_embedding_table, user_history, item_gramian, users_from_batch,
          id_list, device_item_table_size, reg, batch_size, num_devices, cfg):
    """Gather item embeddings and solve the linear system for a batch of users.

    For every user a system `G_i u_i = b_i` is built, where `G_i` is the
    segment-summed Gramian of that user's item embeddings plus
    `cfg.unobserved_weight * item_gramian` plus `reg * I`, and `b_i` is the
    segment-summed item embedding. The system is then solved with the solver
    selected by `cfg.linear_solver`.

    Args:
      item_embedding_table: array whose second dimension is the embedding
        dimension.
      user_history: forwarded to `gather_embeddings`.
      item_gramian: added (scaled by `cfg.unobserved_weight`) to every
        per-user Gramian; must broadcast against
        `(num_users, embedding_dim, embedding_dim)`.
      users_from_batch: array whose first dimension is the number of users.
      id_list: segment ids mapping each gathered row to a user index.
      device_item_table_size: forwarded to `gather_embeddings`.
      reg: regularization strength; broadcast to `[num_users]`.
      batch_size: forwarded to `gather_embeddings`.
      num_devices: forwarded to `gather_embeddings`.
      cfg: config providing `seq_len`, `unobserved_weight` and
        `linear_solver` ('lu', 'qr', 'cholesky'; anything else selects
        conjugate gradient).

    Returns:
      The solved user embeddings, one row per user.
    """
    embedding_dim = item_embedding_table.shape[1]
    num_users = users_from_batch.shape[0]

    item_emb = gather_embeddings(item_embedding_table, user_history,
                                 device_item_table_size, batch_size,
                                 cfg.seq_len, embedding_dim, num_devices, cfg)

    # We use lax.convert_element_type directly intentionally instead of
    # tensor.astype, since XLA sometimes fuses it in a worse way such that
    # >50% of TPU time goes in convert_element_type ops.
    item_emb = jax.lax.convert_element_type(item_emb, jnp.float32)

    # Local compute.
    lambda_batch = jnp.einsum('bij,bik->bjk', item_emb, item_emb)
    mu_batch = jnp.einsum('bij->bj', item_emb)

    # Local segment sum over sharded batch. The segment ids are converted
    # once and shared by both reductions.
    segment_ids = jnp.asarray(id_list)
    lambda_batch_summed = jax.ops.segment_sum(lambda_batch,
                                              segment_ids,
                                              num_segments=num_users)
    mu_batch_summed = jax.ops.segment_sum(mu_batch,
                                          segment_ids,
                                          num_segments=num_users)
    assert lambda_batch_summed.shape == (num_users, embedding_dim,
                                         embedding_dim)

    reg = jnp.broadcast_to(reg, [num_users])
    reg = reg[Ellipsis, jnp.newaxis, jnp.newaxis]
    print(f'reg: {reg.shape}')

    lambda_batch_summed += cfg.unobserved_weight * item_gramian  # G_i += G
    print(f'lambda_batch_summed: {lambda_batch_summed.shape}')

    process_reg = jnp.expand_dims(jnp.identity(embedding_dim), 0) * reg
    print(f'process_reg: {process_reg.shape}')

    # Reuse the regularizer computed above instead of rebuilding it.
    post_lambda = lambda_batch_summed + process_reg  # G_i += reg * I

    @jax.vmap
    def cg_solve(a, b):
        x, _ = jax.scipy.sparse.linalg.cg(a, b)
        return x

    @jax.vmap
    def cholesky_solve(a, b):
        factors = jsp.linalg.cho_factor(a, overwrite_a=True)
        return jsp.linalg.cho_solve(factors, b, overwrite_b=True)

    @jax.vmap
    def qr_solve(a, b):
        q, r = jax.lax.linalg.qr(a)
        return jax.lax.linalg.triangular_solve(r,
                                               q.T @ b,
                                               left_side=True,
                                               lower=False)

    if cfg.linear_solver == 'lu':
        # vmap like the other solvers, so each call sees one square matrix
        # and a 1-D right-hand side. Passing the stacked
        # (num_users, embedding_dim) right-hand side directly depends on the
        # legacy "batch of vectors" reading of a 2-D `b`.
        solve_fn = jax.vmap(jnp.linalg.solve)
    elif cfg.linear_solver == 'qr':
        solve_fn = qr_solve
    elif cfg.linear_solver == 'cholesky':
        solve_fn = cholesky_solve
    else:
        solve_fn = cg_solve

    solve_fn = jax.named_call(solve_fn, name='linear_solver')
    user_embeddings = solve_fn(post_lambda, mu_batch_summed)  # U_i = G_i^{-1} b_i
    return user_embeddings