Example #1
0
  def wrapped(self, *args, **kwargs):
    """Calls the enclosing `unbound_method` on `self` inside its module scope.

    Pushes a `base.ModuleState` for `self` onto the current frame, enters
    `_module_method_call(self, method_name)`, routes the call through
    `run_interceptors`, and (when `self` already has a `module_name`) wraps
    it in `named_call` scopes. After the call it records `self`'s module name
    in the `_submodules` set of every other module on the frame's module
    stack.

    Args:
      self: The module instance whose method is being invoked.
      *args: Positional arguments forwarded to the method.
      **kwargs: Keyword arguments forwarded to the method.

    Returns:
      Whatever the (intercepted) method call returns.

    Raises:
      ValueError: If `base.frame_stack` is empty, i.e. (per the message) the
        module is used outside an `hk.transform`.
    """
    if not base.frame_stack:
      raise ValueError(
          "All `hk.Module`s must be initialized inside an `hk.transform`.")

    # Submodules are associated with this method. We allow users to associate
    # submodules with a different method than the one being called via
    # `@name_like("other_method")`. Interceptors and custom getters are still
    # provided the actual method name (e.g. "submodule_method_name" is only used
    # for naming submodules).
    submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

    frame = base.current_frame()
    state = base.ModuleState(module=self, method_name=submodule_method_name)
    with frame.module(state), _module_method_call(self, method_name):
      # hk.Module enters the module name scope for all methods.
      # `module_name` may still be None here when the method is the
      # constructor (the name is only assigned while it runs; see below).
      module_name = getattr(self, "module_name", None)
      # Bind `self`, then route the call through the interceptor chain, which
      # receives the real `method_name` (not `submodule_method_name`).
      f = functools.partial(unbound_method, self)
      f = functools.partial(run_interceptors, f, method_name, self)
      if jax.config.jax_experimental_name_stack and module_name:
        # Name-stack path: scope by the last path component of the module
        # name, plus the method name for methods other than `__call__`.
        # NOTE(review): the method-name `named_call` is applied last, so it is
        # the outermost wrapper and is entered first — confirm this nesting
        # order is intended.
        local_module_name = module_name.split("/")[-1]
        f = jax.named_call(f, name=local_module_name)
        if method_name != "__call__":
          f = jax.named_call(f, name=method_name)
      elif module_name:
        # TODO(lenamartens): remove this branch once jax_experimental_name_stack
        # flag is removed.
        # Legacy path: only `__call__` is scoped, and only when the
        # `profiler_name_scopes` config option is enabled.
        cfg = config.get_config()
        if cfg.profiler_name_scopes and method_name == "__call__":
          local_module_name = module_name.split("/")[-1]
          f = stateful.named_call(f, name=local_module_name)

      out = f(*args, **kwargs)

      # Module names are set in the constructor. If `f` is the constructor then
      # its name will only be set **after** `f` has run. For methods other
      # than `__init__` we need the name before running in order to wrap their
      # execution with `named_call`.
      if module_name is None:
        module_name = getattr(self, "module_name", None)

      # Notify parent modules about our existence: every other module on the
      # frame's module stack records this module's name as a submodule.
      if module_name is not None:
        for module_state in frame.module_stack:
          if module_state.module is not self:
            module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
    return out
Example #2
0
def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               prevent_cse: bool,
               differentiated: bool,
               policy,
               is_gpu_platform: bool = False,
               concrete: bool = False,
               name: str = "checkpoint"):
    """Evaluate a remat/checkpoint jaxpr on ``args`` inside a named call.

    Args:
      *args: Operands forwarded to the selected evaluation rule.
      call_jaxpr: Jaxpr supplied by the legacy remat call primitive; only
        consulted when ``jaxpr`` is None.
      jaxpr: Jaxpr supplied by remat2. Must have no constvars.
      prevent_cse: When set together with ``differentiated``, one of the
        ``_remat_translation_using_*`` rules is used instead of plain
        ``core.eval_jaxpr``.
      differentiated: See ``prevent_cse``.
      policy: Accepted for interface compatibility and ignored.
      is_gpu_platform: Picks the while-loop based rule when the
        ``jax_remat_opt_barrier`` flag is off.
      concrete: Accepted for interface compatibility and ignored.
      name: Base name for the named-call scope; remat2 does not pass one, so
        it defaults to "checkpoint".

    Returns:
      The outputs of the selected rule applied to ``args``.
    """
    # remat2 passes `jaxpr`; the legacy remat call primitive passes
    # `call_jaxpr`. TODO: remove call_jaxpr once we drop the remat call
    # primitive.
    body = call_jaxpr if jaxpr is None else jaxpr
    assert body is not None
    assert not body.constvars

    del concrete, policy  # Unused.

    if not (differentiated and prevent_cse):
        # No CSE protection requested for this call: evaluate directly.
        def rule(*operands, jaxpr):
            return core.eval_jaxpr(jaxpr, (), *operands)
    elif config.jax_remat_opt_barrier:
        rule = _remat_translation_using_opt_barrier
    elif is_gpu_platform:
        rule = _remat_translation_using_while
    else:
        rule = _remat_translation_using_cond

    scope_name = wrap_name(name, "remat")
    return jax.named_call(rule, name=scope_name)(*args, jaxpr=body)
Example #3
0
 def inner(scope_fn, repack_fn, variable_groups, rng_groups, args, kwargs):
   """Run the enclosing `fn` on a fresh scope inside `jax.named_call`.

   The scope is built from `variable_groups` / `rng_groups` via `scope_fn`,
   and the call is scoped under the enclosing `name`.

   Returns:
     A `(y, repacked)` pair: `fn(scope, *args, **kwargs)` and
     `repack_fn(scope)` for that same scope.
   """
   def named(variable_groups, rng_groups):
     scope = scope_fn(variable_groups, rng_groups)
     return fn(scope, *args, **kwargs), repack_fn(scope)

   named_call_fn = jax.named_call(functools.wraps(fn)(named), name=name)
   return named_call_fn(variable_groups, rng_groups)
Example #4
0
 def wrapper(*args, **kwargs):
     """Call the enclosing `fun` under `jax.named_call(name=name)`.

     Inside a Haiku transform, `jax.named_call` is first adapted with
     `thread_hk_state_in_kwargs` so the Haiku state is threaded through the
     call; otherwise plain `jax.named_call` is used.
     """
     make_named = (thread_hk_state_in_kwargs(jax.named_call)
                   if base.inside_transform() else jax.named_call)
     return make_named(fun, name=name)(*args, **kwargs)
Example #5
0
    def wrapper(*args, **kwargs):
        """Call the enclosing `fun` under `jax.named_call(name=name)`.

        With `jax_experimental_name_stack` enabled, `fun` is wrapped in
        `jax.named_call` directly. Otherwise `fun` is first wrapped by
        `hide_non_jaxtype_outputs`, which fills `side_channel` with
        `"non_jaxtypes"` and `"treedef"`. The named call then returns a flat
        sequence in which (presumably) non-JAX leaves are replaced by `None`.
        The original output pytree is rebuilt from the two streams
        afterwards.

        Returns:
          The output of `fun(*args, **kwargs)`.
        """
        if jax.config.jax_experimental_name_stack:
            return jax.named_call(fun, name=name)(*args, **kwargs)

        side_channel = {"non_jaxtypes": [], "treedef": None}
        wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
        if base.inside_transform():
            # Inside a Haiku transform, thread Haiku state through kwargs.
            wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(
                wrapped_fun, name=name)
        else:
            wrapped_fun = jax.named_call(wrapped_fun, name=name)

        jax_types = wrapped_fun(*args, **kwargs)

        # Merge the two leaf streams: wherever `jax_types` holds `None`, take
        # the corresponding value stashed in the side channel instead.
        non_jaxtypes = side_channel["non_jaxtypes"]
        out_leaves = [
            y if x is None else x for x, y in zip(jax_types, non_jaxtypes)
        ]
        # Use `jax.tree_util.tree_unflatten`: the top-level
        # `jax.tree_unflatten` alias is deprecated and removed in newer JAX.
        return jax.tree_util.tree_unflatten(side_channel["treedef"], out_leaves)
Example #6
0
 def f_caller(x):
     """Return sin(f_callee(tanh(x))), with f_callee scoped as "callee"."""
     hidden = jnp.tanh(x)
     callee = jax.named_call(f_callee, name="callee")
     return jnp.sin(callee(hidden))
Example #7
0
 def g(x):
     """Pass `x` through `lax.cond` whose two branches are one named identity."""
     identity = jax.named_call(lambda v: v)
     return jax.lax.cond(True, identity, identity, x)
Example #8
0
def named_call(f=None, name=None):
    """Add a name to a function for profiling purposes.

    Usable directly (``named_call(fn, name=...)``), as a bare decorator
    (``@named_call``), or as a decorator factory (``@named_call(name=...)``).

    Args:
      f: Function to wrap, or None to obtain a decorator.
      name: Name forwarded to ``jax.named_call``.

    Returns:
      ``jax.named_call(f, name=name)`` when ``f`` is given, otherwise a
      ``functools.partial`` of this function with ``name`` bound.
    """
    if f is not None:
        return jax.named_call(f, name=name)
    return functools.partial(named_call, name=name)
Example #9
0
def solve(item_embedding_table, user_history, item_gramian, users_from_batch,
          id_list, device_item_table_size, reg, batch_size, num_devices, cfg):
    """Gather item embeddings and solve for a batch of user embeddings.

    For every user slot ``u`` this builds

        G_u = sum_i e_i e_i^T + cfg.unobserved_weight * item_gramian + reg_u * I
        b_u = sum_i e_i

    where the sums run over the gathered embedding rows whose ``id_list``
    entry is ``u``. It then returns ``G_u^{-1} b_u`` via the solver selected
    by ``cfg.linear_solver``.

    Args:
      item_embedding_table: Item embedding table; ``shape[1]`` is taken as
        the embedding dimension ``d``.
      user_history: Forwarded to ``gather_embeddings`` (presumably the item
        ids per batch row — confirm against that helper).
      item_gramian: Added after scaling by ``cfg.unobserved_weight``; must
        broadcast against ``(num_users, d, d)`` (presumably the ``(d, d)``
        Gramian of all item embeddings).
      users_from_batch: Only its leading dimension is used, as the number of
        user slots ``num_users``.
      id_list: Segment id (user slot) for each gathered batch row.
      device_item_table_size: Forwarded to ``gather_embeddings``.
      reg: Scalar or per-user regularization, broadcast to ``(num_users,)``.
      batch_size: Forwarded to ``gather_embeddings``.
      num_devices: Forwarded to ``gather_embeddings``.
      cfg: Config; reads ``seq_len``, ``unobserved_weight`` and
        ``linear_solver`` ('lu', 'qr', 'cholesky'; anything else uses CG).

    Returns:
      User embeddings of shape ``(num_users, d)``.
    """
    embedding_dim = item_embedding_table.shape[1]
    num_users = users_from_batch.shape[0]

    item_emb = gather_embeddings(item_embedding_table, user_history,
                                 device_item_table_size, batch_size,
                                 cfg.seq_len, embedding_dim, num_devices, cfg)

    # We use lax.convert_element_type directly intentionally instead of
    # tensor.astype, since XLA sometimes fuses it in a worse way such that >50%
    # of TPU time goes in convert_element_type ops.
    item_emb = jax.lax.convert_element_type(item_emb, jnp.float32)

    # Local compute: per batch row, the sum over axis 1 of outer products
    # (b, d, d) and the plain sum of embeddings (b, d).
    lambda_batch = jnp.einsum('bij,bik->bjk', item_emb, item_emb)
    mu_batch = jnp.einsum('bij->bj', item_emb)

    # Local segment sum over sharded batch: fold rows into their user slot.
    segment_ids = jnp.asarray(id_list)
    lambda_batch_summed = jax.ops.segment_sum(lambda_batch, segment_ids,
                                              num_segments=num_users)
    mu_batch_summed = jax.ops.segment_sum(mu_batch, segment_ids,
                                          num_segments=num_users)

    assert lambda_batch_summed.shape == (num_users, embedding_dim,
                                         embedding_dim)

    # Per-user ridge term reg_u * I with shape (num_users, d, d).
    reg = jnp.broadcast_to(reg, [num_users])[..., jnp.newaxis, jnp.newaxis]
    ridge = jnp.expand_dims(jnp.identity(embedding_dim), 0) * reg

    # G_i += unobserved_weight * G, then G_i += reg * I (same add order as
    # before, so results are bit-identical).
    post_lambda = (lambda_batch_summed + cfg.unobserved_weight * item_gramian
                   + ridge)

    @jax.vmap
    def lu_solve(a, b):
        # Solve per user so `b` is 1-D. Passing the batched (n, d, d) and
        # (n, d) arrays straight to jnp.linalg.solve relies on the deprecated
        # "batched 1-D b" rule, which NumPy 2 / newer JAX read as a 2-D solve.
        return jnp.linalg.solve(a, b)

    @jax.vmap
    def cg_solve(a, b):
        x, _ = jax.scipy.sparse.linalg.cg(a, b)
        return x

    @jax.vmap
    def cholesky_solve(a, b):
        factors = jsp.linalg.cho_factor(a, overwrite_a=True)
        return jsp.linalg.cho_solve(factors, b, overwrite_b=True)

    @jax.vmap
    def qr_solve(a, b):
        q, r = jax.lax.linalg.qr(a)
        return jax.lax.linalg.triangular_solve(r,
                                               q.T @ b,
                                               left_side=True,
                                               lower=False)

    solvers = {
        'lu': lu_solve,
        'qr': qr_solve,
        'cholesky': cholesky_solve,
    }
    # Any other `linear_solver` value falls back to conjugate gradient.
    solve_fn = jax.named_call(solvers.get(cfg.linear_solver, cg_solve),
                              name='linear_solver')

    user_embeddings = solve_fn(post_lambda,
                               mu_batch_summed)  # U_i = G_i^{-1} b_i
    return user_embeddings