Example #1
0
            logging.info('%s: %s %s', prefix, type(inputs),
                         return_shapes(inputs))
        return inputs

    return init_fun, apply_fun


# Staxlayer binding to python variables
# ------------------------------------------------------------------------------
# Stax params-tree leaf type used to mark references to bound subtrees.
class _TreeMarker(dict):
    pass


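# Register this marker with JAX's pytree utilities as a node with no children:
# flattening yields no leaves for it, and unflattening rebuilds an empty marker.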
_register_pytree_node(
    _TreeMarker,
    # Flatten: a marker contributes no children and no auxiliary data.
    lambda marker: ((), None),
    # Unflatten: ignore aux data / children and rebuild an empty marker.
    lambda aux_data, children: _TreeMarker(),
)


# TODO(levskaya, rsepassi): abstract away tuple-subclassing to StaxLayer?
class Share(tuple):
    """Layer parameter caching function to allow weight sharing.

  Args:
    A staxlayer: an (init_fun, apply_fun) pair.

  Returns:
    A 'parameter-bound' staxlayer that can be assigned to a python variable.
  Wherever this value is needed elsewhere in the stax tree, call this bound
  variable and all occurrences will share parameters that will automatically
  be updated by Stax optimizers.
  """
Example #2
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from jax.tree_util import register_pytree_node as _register_pytree_node


# Staxlayer binding to python variables
# ------------------------------------------------------------------------------
# Stax params-tree leaf type used to mark references to bound subtrees.
class _TreeMarker(dict):
    pass


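# Register this marker with JAX's pytree utilities as a node with no children:
# flattening yields no leaves for it, and unflattening rebuilds an empty marker.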
# Pytree hooks for _TreeMarker: it flattens to zero children (no aux data)
# and is always reconstructed as a fresh, empty marker.
_register_pytree_node(
    _TreeMarker,
    lambda _node: ((), None),
    lambda _aux, _children: _TreeMarker(),
)


# TODO(lukaszkaiser): make this the base layer class (share by object).
class Share(tuple):
    """Layer parameter caching function to allow weight sharing.

  Args:
    A staxlayer: an (init_fun, apply_fun) pair.

  Returns:
    A 'parameter-bound' staxlayer that can be assigned to a python variable.
  Wherever this value is needed elsewhere in the stax tree, call this bound
  variable and all occurrences will share parameters that will automatically
  be updated by Stax optimizers.
  """