Пример #1
0
                def register(cls):
                    """Registers a class as a JAX pytree node.

                    Flattening splits ``linop.parameters``: entries named in
                    ``non_shape_params[cls.__name__]`` become the pytree
                    children (ordered by parameter name), and every other entry
                    is carried as static metadata. Unflattening rebuilds the
                    object with ``cls(**parameters)``.
                    """
                    def flatten(linop):
                        param_names = set(non_shape_params[cls.__name__])
                        parameters = linop.parameters
                        components = {
                            param_name: value
                            for param_name, value in parameters.items()
                            if param_name in param_names
                        }
                        # JAX expects pytree aux data to be hashable, so the
                        # metadata is kept as (name, value) pairs sorted by name
                        # (order-independent, like the dict it replaces) rather
                        # than as a dict.
                        metadata = tuple(sorted(
                            ((param_name, value)
                             for param_name, value in parameters.items()
                             if param_name not in param_names),
                            key=lambda item: item[0]))
                        if components:
                            keys, values = zip(*sorted(components.items(),
                                                       key=lambda item: item[0]))
                        else:
                            # zip(*()) yields nothing to unpack into two names.
                            keys, values = (), ()
                        return values, (keys, metadata)

                    def unflatten(info, xs):
                        keys, metadata = info
                        parameters = dict(zip(keys, xs), **dict(metadata))
                        return cls(**parameters)

                    tree_util.register_pytree_node(cls, flatten, unflatten)
Пример #2
0
def differentiable(cls: Type[T]) -> Type[T]:
    """Register ``cls`` as a JAX pytree node and return it unchanged.

    ``_get_keys(cls)`` supplies the attribute names. Those under
    ``_DIFFERENTIABLE`` become the pytree children, packed into a single dict.
    Those under ``_NON_DIFFERENTIABLE`` become static auxiliary data. Instances
    are rebuilt by passing both groups to ``cls`` as keyword arguments.

    Args:
        cls: The class to register.

    Returns:
        ``cls`` itself, so this can be used as a class decorator.
    """
    keys = _get_keys(cls)

    def _tree_flatten(
        node: Module,
    ) -> Tuple[Tuple[Dict[str, Any]], Tuple[Tuple[str, Any], ...]]:
        children = {key: getattr(node, key) for key in keys[_DIFFERENTIABLE]}
        # JAX expects pytree aux data to be hashable, so the static attributes
        # are stored as (name, value) pairs in a tuple rather than in a dict.
        aux_data = tuple(
            (key, getattr(node, key)) for key in keys[_NON_DIFFERENTIABLE])

        logger.debug('=' * 50)
        logger.debug('flatten: %s', cls)
        logger.debug('aux_data: %s', aux_data)
        logger.debug('children: %s', children)
        return (children,), aux_data

    def _tree_unflatten(aux_data: Tuple[Tuple[str, Any], ...],
                        children: Tuple[Dict[str, Any]]) -> Module:
        logger.debug('=' * 50)
        logger.debug('unflatten: %s', cls)
        logger.debug('aux_data: %s', aux_data)
        logger.debug('children: %s', children)
        # children is the 1-tuple produced by _tree_flatten.
        return cls(**dict(aux_data), **children[0])  # type: ignore

    register_pytree_node(cls, _tree_flatten, _tree_unflatten)

    return cls
Пример #3
0
def register_pytree_namedtuple(cls):
    """Register a namedtuple-like class as a JAX pytree and return it.

    The instance's items, in iteration order, are the children. No auxiliary
    data is kept, and instances are rebuilt positionally with ``cls(*items)``.
    """
    def flatten(instance):
        # Tell JAX how to unpack.
        return tuple(instance), None

    def unflatten(_aux, items):
        # Tell JAX how to pack back.
        return cls(*items)

    register_pytree_node(cls, flatten, unflatten)
    return cls
Пример #4
0
 def __init_subclass__(cls, **kwargs):
   """Register every subclass as a JAX pytree node.

   The subclass's own `flatten` and `unflatten` serve as the pytree
   flatten/unflatten functions.
   """
   super().__init_subclass__(**kwargs)
   flatten_fn = cls.flatten
   # Pytype incorrectly thinks that cls.unflatten accepts three arguments.
   unflatten_fn = cls.unflatten  # type: ignore
   tree_util.register_pytree_node(cls, flatten_fn, unflatten_fn)
Пример #5
0
def _register_dataclass_type(data_class):
  """Register dataclass so JAX knows how to handle it."""
  flatten = lambda d: jax.tree_flatten(d.__dict__)
  unflatten = lambda s, xs: data_class(**s.unflatten(xs))
  try:
    tree_util.register_pytree_node(
        nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
  except ValueError:
    logging.info('%s is already registered as JAX PyTree node.', data_class)
Пример #6
0
def py_tree_registered_dataclass(cls, *args, **kwargs):
    """Creates a new dataclass type and registers it as a pytree node."""
    dcls = dataclasses.dataclass(cls, *args, **kwargs)
    tree_util.register_pytree_node(
        dcls,
        lambda instance: (  # pylint: disable=g-long-lambda
            [getattr(instance, f.name)
             for f in dataclasses.fields(instance)], None),
        lambda _, instance_args: dcls(*instance_args))
    return dcls
Пример #7
0
def register_graph_as_jax_pytree(cls: Type[T]) -> None:
    """Register the graph class ``cls`` as a JAX pytree node.

    A graph flattens into two children and no aux data:
      - ``dict(graph.nodes)``: node -> node data.
      - ``dict(graph.edges)``: (source, target) -> edge data.

    Unflattening creates an empty ``cls()`` and re-adds the nodes and edges.
    Assumes a networkx-style graph API (``nodes``, ``edges``,
    ``add_nodes_from``, ``add_edges_from``); confirm for other classes.
    """
    def tree_unflatten(hashed: Hashable, trees: Sequence[PyTree]) -> T:
        node_dicts, edge_dicts = trees

        # Fail with a descriptive message if the children are not the two
        # dicts produced by tree_flatten.
        if not isinstance(node_dicts, dict):
            raise TypeError('expected node data to be a dict, got '
                            f'{type(node_dicts).__name__}')
        if not isinstance(edge_dicts, dict):
            raise TypeError('expected edge data to be a dict, got '
                            f'{type(edge_dicts).__name__}')

        graph = cls()
        graph.add_nodes_from(node_dicts.items())
        graph.add_edges_from([(source, target, data)
                              for (source, target), data in edge_dicts.items()
                              ])

        return graph

    def tree_flatten(graph: T) -> Tuple[Sequence[PyTree], Hashable]:
        return ((dict(graph.nodes), dict(graph.edges)), None)

    register_pytree_node(cls, tree_flatten, tree_unflatten)
Пример #8
0
    def __dir__(self):
        if isinstance(self._data, dict):
            return list(self._data.keys())
        elif isinstance(self._data, FrozenDict):
            return list(self._data._dict.keys())
        else:
            return []

    def __repr__(self):
        return f'{self._data}'

    def __hash__(self):
        # Note: will only work when wrapping FrozenDict.
        return hash(self._data)

    def copy(self, **kwargs):
        """Copy the wrapped data, returning it as the wrapped data's own class.

        ``kwargs`` are forwarded to the wrapped object's ``copy`` method. The
        result is not re-wrapped in this class.
        """
        wrapped = self._data
        duplicate = wrapped.copy(**kwargs)
        return wrapped.__class__(duplicate)


# Flatten to a single child (the wrapped data) with empty aux data. Unflattening
# returns that child directly, so a round trip yields the raw wrapped data
# rather than a DotGetter.
tree_util.register_pytree_node(
    DotGetter,
    lambda getter: ((getter._data,), ()),  # pylint: disable=protected-access
    lambda _aux, children: children[0],
)

# Note: restores as raw dict, intentionally.
serialization.register_serialization_state(
    DotGetter,
    serialization._dict_state_dict,  # pylint: disable=protected-access
    serialization._restore_dict)  # pylint: disable=protected-access
Пример #9
0
# Abstract evaluation is the identity: the output aval equals the input aval.
zeros_like_p.def_abstract_eval(lambda aval: aval)


class Zero:
    """Placeholder for a zero value, described only by its abstract value."""

    __slots__ = ['aval']

    def __init__(self, aval):
        self.aval = aval

    def __repr__(self):
        return f'Zero({self.aval})'

    @staticmethod
    def from_value(val):
        """Build a Zero from the shaped abstract value of ``val``."""
        shaped_aval = raise_to_shaped(get_aval(val))
        return Zero(shaped_aval)


register_pytree_node(
    Zero,
    lambda zero: ((), zero.aval),  # no children; the aval is the aux data
    lambda aval, _children: Zero(aval),
)


def _stop_gradient_impl(x):
    """Concrete rule for stop_gradient: identity on valid JAX values."""
    if valid_jaxtype(x):
        return x
    raise TypeError("stop_gradient only works on valid JAX arrays, but "
                    f"input argument is: {x}")


stop_gradient_p: Primitive = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
# Abstract evaluation is the identity: the output aval equals the input aval.
stop_gradient_p.def_abstract_eval(lambda aval: aval)
Пример #10
0
Файл: ad.py Проект: jbampton/jax
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in):
  """Run `backward_pass` on a ClosedJaxpr, supplying its inner jaxpr and consts."""
  inner_jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
  return backward_pass(inner_jaxpr, reduce_axes, transform_stack, consts,
                       primals_in, cotangents_in)


class UndefinedPrimal:
  """Marker for a primal input whose value is unknown; holds only its aval."""
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return f'UndefinedPrimal({self.aval})'

def is_undefined_primal(x):
  """Return True iff `x` is exactly an UndefinedPrimal (subclasses excluded)."""
  return UndefinedPrimal is type(x)

register_pytree_node(
    UndefinedPrimal,
    lambda undefined: ((), undefined.aval),  # no children; aval is aux data
    lambda aval, _children: UndefinedPrimal(aval),
)

def get_primitive_transpose(p):
  """Look up the transpose rule for primitive `p`.

  Raises:
    NotImplementedError: if no rule is registered in `primitive_transposes`.
  """
  try:
    rule = primitive_transposes[p]
  except KeyError as err:
    message = ("Transpose rule (for reverse-mode differentiation) for '{}' "
               "not implemented".format(p))
    raise NotImplementedError(message) from err
  return rule

@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
  """Forward the call and report which output tangents are not `Zero`.

  The wrapped function's result is unpacked as a pair whose second element is
  the output tangents (presumably `(primals_out, tangents_out)`). The aux output
  is a list of booleans, True where the tangent's type is not exactly `Zero`.
  """
  results = yield args, kwargs
  _, tangents_out = results
  nonzero_flags = [type(t) is not Zero for t in tangents_out]
  yield results, nonzero_flags
Пример #11
0
from vmap import grad as fx_grad
import torch.fx as fx
from jax import grad, jit, partial, random, value_and_grad, lax
from jax.flatten_util import ravel_pytree
import jax.numpy as np
from jax import random
from jax.tree_util import register_pytree_node, tree_multimap

# State of the integrator: position q (param value), momentum p, the potential
# energy, and the gradient with respect to q.
IntegratorState = namedtuple("IntegratorState",
                             ["q", "p", "potential_energy", "q_grad"])

# Register IntegratorState as a pytree so JAX transformations treat it as a
# container of its fields (https://jax.readthedocs.io/en/latest/pytrees.html).
register_pytree_node(
    IntegratorState,
    lambda state: (tuple(state), None),
    lambda _aux, fields: IntegratorState(*fields),
)


def leapfrog(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the leapfrog algorithm
    for position `q` and momentum `p`.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(q, p):
Пример #12
0
def register_pytree_namedtuple(cls):
    """Register a namedtuple-like class as a JAX pytree node.

    Children are the instance's items in iteration order, with no aux data.
    Instances are rebuilt positionally with ``cls(*items)``.

    Returns:
        ``cls`` itself, so this works as a class decorator.
    """
    register_pytree_node(cls, lambda xs: (tuple(xs), None),
                         lambda _, xs: cls(*xs))
    return cls
Пример #13
0
    def scatter(self, indices, value, name=None):
        """Not supported by this TensorArray; always raises NotImplementedError."""
        del indices, value, name  # Unused.
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')

    def split(self, value, lengths, name=None):
        """Not supported by this TensorArray; always raises NotImplementedError."""
        del value, lengths, name  # Unused.
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')

    def __repr__(self):
        """Summarize the backend, dtype, sizes, element shape and item count."""
        backend = 'jax' if JAX_MODE else 'numpy'
        template = ('tf2{}.TensorArray(dtype={}, size={}, dynamic_size={}, '
                    'element_shape={}, len(data)={})')
        return template.format(backend, self._dtype, self._size,
                               self._dynamic_size, self._element_shape,
                               len(self._data))


if JAX_MODE:
    from jax import tree_util  # pylint: disable=g-import-not-at-top

    def flatten(val):
        """Split a TensorArray into its data (the one child) and static metadata.

        The metadata is returned as a tuple of (name, value) pairs instead of a
        dict. JAX expects pytree aux data to be hashable, and a tuple is hashable
        whenever its values are, whereas a dict never is.
        """
        vals = (val._data, )  # pylint: disable=protected-access
        aux = (('dtype', val.dtype),
               ('element_shape', val.element_shape),
               ('dynamic_size', val.dynamic_size))
        return vals, aux

    def unflatten(aux, vals):
        """Rebuild a TensorArray from `flatten`'s metadata pairs and child."""
        return TensorArray(data=vals[0], **dict(aux))

    tree_util.register_pytree_node(TensorArray, flatten, unflatten)
Пример #14
0
        """
        out = TFCDict(self)
        if isinstance(o, dict) or (type(o) is type(self)):
            for key in self._keys:
                out[key] -= o[key]
        elif isinstance(o, np.ndarray):
            o = o.flatten()
            for k in range(self._nKeys):
                out[self._keys[k]] -= o[self._slices[k]]
        return out


# Register TFCDict as a JAX pytree. The values are the children and the keys are
# the aux data. The keys are stored as a tuple because JAX expects aux data to
# be hashable (a list is not).
register_pytree_node(
    TFCDict,
    lambda x: (list(x.values()), tuple(x.keys())),
    lambda keys, values: TFCDict(safe_zip(keys, values)),
)


class TFCDictRobust(OrderedDict):
    """This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays."""

    def __init__(self, *args):
        """Initialize TFCDictRobust using the OrderedDict method."""
        super().__init__(*args)
        # Snapshot the key order at construction; keeping it fixed makes
        # adding and subtracting repeatable.
        keys = list(self.keys())
        self._keys = keys
        self._nKeys = len(keys)
Пример #15
0
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')

    def grad(self, source, flow=None, name=None):
        """Not supported by this TensorArray; always raises NotImplementedError."""
        del source, flow, name  # Unused.
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')

    def scatter(self, indices, value, name=None):
        """Not supported by this TensorArray; always raises NotImplementedError."""
        del indices, value, name  # Unused.
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')

    def split(self, value, lengths, name=None):
        """Not supported by this TensorArray; always raises NotImplementedError."""
        del value, lengths, name  # Unused.
        raise NotImplementedError('If you need this feature, please email '
                                  '`[email protected]`.')


if JAX_MODE:
    from jax import tree_util  # pylint: disable=g-import-not-at-top

    def to_tree(val):
        """Split a TensorArray into its data (the one child) and static metadata.

        The metadata is returned as a tuple of (name, value) pairs instead of a
        dict. JAX expects pytree aux data to be hashable, and a tuple is hashable
        whenever its values are, whereas a dict never is.
        """
        vals = (val._data, )  # pylint: disable=protected-access
        aux = (('dtype', val.dtype),
               ('element_shape', val.element_shape),
               ('dynamic_size', val.dynamic_size))
        return vals, aux

    def from_tree(aux, vals):
        """Rebuild a TensorArray from `to_tree`'s metadata pairs and child."""
        return TensorArray(data=vals[0], **dict(aux))

    tree_util.register_pytree_node(TensorArray, to_tree, from_tree)
Пример #16
0
                f'at mapped index {", ".join(map(str, idx))}: '  # type: ignore
                f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}'  # type: ignore
                for idx, e in np.ndenumerate(self.err) if e) or None
        return None

    def throw(self):
        """Raise ValueError with the error message, if an error happened."""
        message = self.get()
        if not message:
            return
        raise ValueError(message)


# Children are err, code and payload. The messages dict is static aux data,
# stored as a sorted tuple of its items (presumably to keep it hashable).
register_pytree_node(
    Error,
    lambda error: ((error.err, error.code, error.payload),
                   tuple(sorted(error.msgs.items()))),
    lambda msgs, children: Error(
        children[0],
        children[1],  # type: ignore
        dict(msgs),
        children[2]))  # type: ignore

init_error = Error(False, 0, {})
next_code = it.count(1).__next__  # globally unique ids, could be uuid4


def assert_func(error: Error, pred: Bool, msg: str,
                payload: Optional[Payload]) -> Error:
    code = next_code()
    payload = init_payload if payload is None else payload
    out_err = error.err | jnp.logical_not(pred)
    out_code = lax.select(error.err, error.code, code)
    out_payload = lax.select(error.err, error.payload, payload)
Пример #17
0
    Return
    ------
    g : mean.shape array
        The new gvars.
    """
    cov = gvar.gvar.cov
    mean = np.asarray(mean)
    shape = mean.shape
    mean = mean.flat
    jac = np.array(jac)  # TODO patches gvar issue #27
    jac = jac.reshape(len(mean), len(indices))
    g = np.zeros(len(mean), object)
    for i, jacrow in enumerate(jac):
        der = gvar.svec(len(indices))
        der._assign(jac[i], indices)
        g[i] = gvar.GVar(mean[i], der, cov)
    return g.reshape(shape)


def bufferdict_flatten(bd):
    """Split a BufferDict into (values, keys), both tuples in iteration order."""
    values = tuple(bd.values())
    keys = tuple(bd.keys())
    return values, keys


def bufferdict_unflatten(keys, values):
    """Rebuild a gvar.BufferDict by pairing `keys` with `values` in order."""
    pairs = zip(keys, values)
    return gvar.BufferDict(pairs)


# Register gvar.BufferDict as a JAX pytree: values are the children and the
# key tuple is the aux data.
tree_util.register_pytree_node(
    gvar.BufferDict,
    bufferdict_flatten,
    bufferdict_unflatten,
)
Пример #18
0
        nn_params = nn_module.init(rng_key, jnp.ones(input_shape))
        # haiku init returns an immutable dict
        nn_params = haiku.data_structures.to_mutable_dict(nn_params)
        # we cast it to a mutable one to be able to set priors for parameters
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.apply, nn_params, None)


# An "empty" parameter that only records its shape. It has no pytree leaves,
# so an optimizer skips it, while priors can still read the shape information.
ParamShape = namedtuple("ParamShape", ["shape"])
register_pytree_node(
    ParamShape,
    # A single `None` child contributes no leaves; the shape rides as aux data.
    lambda param_shape: ((None,), param_shape.shape),
    lambda shape, _children: ParamShape(shape),
)


def _update_params(params, new_params, prior, prefix=''):
    """
    A helper to recursively set prior to new_params.
    """
    for name, item in params.items():
        flatten_name = ".".join([prefix, name]) if prefix else name
        if isinstance(item, dict):
            assert not isinstance(prior, dict) or flatten_name not in prior
            new_item = new_params[name]
            _update_params(item, new_item, prior, prefix=flatten_name)
        elif (not isinstance(prior, dict)) or flatten_name in prior:
            d = prior[flatten_name] if isinstance(prior, dict) else prior
            if isinstance(params[name], ParamShape):
Пример #19
0
@_add_doc(optimizers.sm3)
class SM3(_NumPyroOptim):
    # Thin wrapper binding `optimizers.sm3` as the underlying optimizer.
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.sm3, *args, **kwargs)


# TODO: currently, jax.scipy.optimize.minimize only supports 1D input,
# so we need to add the following mechanism to transform params to flat_params
# and pass `unravel_fn` arround.
# When arbitrary pytree is supported in JAX, we can just simply use
# identity functions for `init_fn` and `get_params`.
_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])
register_pytree_node(
    _MinimizeState,
    lambda state: ((state.flat_params, ), (state.unravel_fn, )),
    lambda data, xs: _MinimizeState(xs[0], data[0]),
)


def _minimize_wrapper():
    def init_fn(params):
        flat_params, unravel_fn = ravel_pytree(params)
        return _MinimizeState(flat_params, unravel_fn)

    def update_fn(i, grad_tree, opt_state):
        # we don't use update_fn in Minimize, so let it do nothing
        return opt_state

    def get_params(opt_state):
        flat_params, unravel_fn = opt_state
Пример #20
0
 - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
 - **accept_prob** - Acceptance probability of the proposal. Note that ``z``
   does not correspond to the proposal if it is rejected.
 - **mean_accept_prob** - Mean acceptance probability until current iteration
   during warmup adaptation or sampling (for diagnostics).
 - **step_size** - Step size to be used by the integrator in the next iteration.
   This is adapted during warmup.
 - **inverse_mass_matrix** - The inverse mass matrix to be be used for the next
   iteration. This is adapted during warmup.
 - **rng** - random number generator seed used for the iteration.
"""


# Every field of HMCState is a child; no aux data.
register_pytree_node(
    HMCState,
    lambda state: (tuple(state), None),
    lambda _aux, fields: HMCState(*fields),
)


# Expose `_replace` (HMCState is presumably a namedtuple) under the name `update`.
HMCState.update = HMCState._replace


def _get_num_steps(step_size, trajectory_length):
    num_steps = np.clip(trajectory_length / step_size, a_min=1)
    # NB: casting to np.int64 does not take effect (returns np.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(np.int64)


def _sample_momentum(unpack_fn, mass_matrix_sqrt, rng):
Пример #21
0
 def __init_subclass__(cls, **kwargs):
     """Register each subclass as a JAX pytree node.

     Uses the subclass's `tree_flatten` / `tree_unflatten` as the node's
     flatten and unflatten functions.
     """
     super().__init_subclass__(**kwargs)
     tree_util.register_pytree_node(
         cls,
         cls.tree_flatten,
         cls.tree_unflatten,
     )
Пример #22
0
def register_pytree(cls):
    """Register ``cls`` as a JAX pytree node at most once, and return it.

    Children are the instance's items in iteration order, with no aux data.
    Instances are rebuilt positionally with ``cls(*items)``.

    Returns:
        ``cls`` itself, so this works as a class decorator.
    """
    # Check only the class's own namespace. A `_registered` flag inherited from
    # an already-registered base class must not stop a subclass from being
    # registered in its own right.
    if not vars(cls).get('_registered', False):
        register_pytree_node(cls, lambda xs: (tuple(xs), None),
                             lambda _, xs: cls(*xs))
    cls._registered = True
    return cls
Пример #23
0
  def __init__(self, x, y, z):
    self.x = x
    self.y = y
    self.z = z

  def __eq__(self, other):
    return self.x == other.x and self.y == other.y and self.z == other.z

  def __hash__(self):
    return hash((self.x, self.y, self.z))

  def __repr__(self):
    return "AnObject({},{},{})".format(self.x, self.y, self.z)

# x and y are the children; z is kept as aux data.
tree_util.register_pytree_node(
    AnObject,
    lambda obj: ((obj.x, obj.y), obj.z),
    lambda z, children: AnObject(children[0], children[1], z))

@tree_util.register_pytree_node_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    return ((self.x, self.y), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
Пример #24
0
def make_wrapper_type(cls):
    """Creates a flattenable Distribution type.

    The wrapper is built once per ``(cls.__module__, cls.__name__)`` and cached
    in ``_registry``. It is a subclass of ``cls`` that:
      - records its constructor arguments,
      - forwards attribute access to a real ``cls`` instance, and
      - is registered as a JAX pytree node through its ``_type_spec``.

    Args:
      cls: The distribution class to wrap.

    Returns:
      The wrapper subclass cached for ``cls``.
    """

    clsid = (cls.__module__, cls.__name__)

    if clsid not in _registry:

        class _WrapperType(cls):
            """Oryx distribution wrapper type."""
            def __init__(self, *args, **kwargs):
                # Keep the raw constructor arguments (read by `_type_spec`) and
                # build a real `cls` instance that attribute lookups go to.
                self._args = args
                self._kwargs = kwargs
                self._instance = object.__new__(cls)
                cls.__init__(self._instance, *self._args, **self._kwargs)

            def __getattr__(self, key):
                # Only reached when normal lookup fails. Forward to the wrapped
                # instance, except for the wrapper's own bookkeeping names.
                if key not in ('_args', '_kwargs', '_type_spec', '_instance'):
                    return getattr(self._instance, key)
                return object.__getattribute__(self, key)

            @property
            def _type_spec(self):
                # Describe the constructor kwargs:
                # - params listed by `_params_event_ndims` are summarized
                #   ((shape, dtype) for tensors, otherwise their type);
                # - nested Distributions are moved into param_specs as-is;
                # - the remaining kwargs are passed through.
                kwargs = dict(self._kwargs)
                param_specs = {}
                try:
                    event_ndims = self._params_event_ndims()
                except NotImplementedError:
                    event_ndims = {}
                for k in event_ndims:
                    if k in kwargs and kwargs[k] is not None:
                        elem = kwargs.pop(k)
                        if type(elem) == object:  # pylint: disable=unidiomatic-typecheck
                            param_specs[k] = object
                        elif tf.is_tensor(elem):
                            param_specs[k] = (elem.shape, elem.dtype)
                        else:
                            param_specs[k] = type(elem)
                for k, v in list(kwargs.items()):
                    if isinstance(v, tfd.Distribution):
                        param_specs[k] = kwargs.pop(k)
                return _JaxDistributionTypeSpec(clsid, param_specs, kwargs)

            def __str__(self):
                return repr(self)

            def __repr__(self):
                return '{}()'.format(self.__class__.__name__)

        _WrapperType.__name__ = cls.__name__ + 'Wrapper'

        def to_tree(obj):
            """Flatten to component values sorted by key, plus (keys, type_spec)."""
            type_spec = obj._type_spec  # pylint: disable=protected-access
            components = type_spec._to_components(obj)  # pylint: disable=protected-access
            if components:
                keys, values = zip(*sorted(components.items()))
            else:
                # zip(*[]) yields nothing, which cannot be unpacked into two
                # names; an object with no components has no children.
                keys, values = (), ()
            return values, (keys, type_spec)

        def from_tree(info, xs):
            """Rebuild an object from `to_tree`'s (keys, type_spec) and values."""
            keys, type_spec = info
            components = dict(zip(keys, xs))
            return type_spec._from_components(components)  # pylint: disable=protected-access

        tree_util.register_pytree_node(_WrapperType, to_tree, from_tree)

        _registry[clsid] = _WrapperType
    return _registry[clsid]
Пример #25
0
def make_wrapper_type(cls):
    """Creates new Bijector type that can be flattened/unflattened and is lazy.

    The wrapper is built once per ``(cls.__module__, cls.__name__)`` and cached
    in ``_registry``. It is a subclass of ``cls`` that:
      - stores its constructor arguments and builds a fresh ``cls`` instance on
        demand for attribute access;
      - routes ``forward``/``inverse`` through ``bijector_p`` unless
        ``use_primitive=False`` was passed;
      - is registered as a JAX pytree node through its ``_type_spec``.

    Args:
      cls: The bijector class to wrap.

    Returns:
      The wrapper subclass cached for ``cls``.
    """

    clsid = (cls.__module__, cls.__name__)

    def bijector_bind(bijector, x, **kwargs):
        return core.call_bind(
            bijector_p, direction=kwargs['direction'])(_bijector)(bijector, x,
                                                                  **kwargs)

    def _bijector(bij, x, **kwargs):
        # Dispatch to the unwrapped class's forward/inverse by `direction`.
        direction = kwargs.pop('direction', 'forward')
        if direction == 'forward':
            return cls.forward(bij, x, **kwargs)
        elif direction == 'inverse':
            return cls.inverse(bij, x, **kwargs)
        else:
            raise ValueError(
                'Bijector direction must be "forward" or "inverse".')

    if clsid not in _registry:

        class _WrapperType(cls):
            """Oryx bijector wrapper type."""
            def __init__(self, *args, **kwargs):
                # Construction is deferred: only the arguments are stored.
                self.use_primitive = kwargs.pop('use_primitive', True)
                self._args = args
                self._kwargs = kwargs

            def forward(self, x, **kwargs):
                if self.use_primitive:
                    return bijector_bind(self,
                                         x,
                                         direction='forward',
                                         **kwargs)
                return cls.forward(self, x, **kwargs)

            def inverse(self, x, **kwargs):
                if self.use_primitive:
                    return bijector_bind(self,
                                         x,
                                         direction='inverse',
                                         **kwargs)
                return cls.inverse(self, x, **kwargs)

            def _get_instance(self):
                # Build a new `cls` instance from the stored arguments.
                obj = object.__new__(cls)
                cls.__init__(obj, *self._args, **self._kwargs)
                return obj

            def __getattr__(self, key):
                # Only reached when normal lookup fails. Forward to a freshly
                # built instance, except for the wrapper's own names.
                if key not in ('_args', '_kwargs', 'parameters', '_type_spec'):
                    return getattr(self._get_instance(), key)
                return object.__getattribute__(self, key)

            @property
            def parameters(self):
                return self._get_instance().parameters

            @property
            def _type_spec(self):
                kwargs = dict(self._kwargs)
                param_specs = {}
                # NOTE(review): event_ndims is always empty here, so the loop
                # below never runs; it mirrors the Distribution wrapper.
                event_ndims = {}
                for k in event_ndims:
                    if k in kwargs and kwargs[k] is not None:
                        elem = kwargs.pop(k)
                        if type(elem) == object:  # pylint: disable=unidiomatic-typecheck
                            param_specs[k] = object
                        elif tf.is_tensor(elem):
                            param_specs[k] = (elem.shape, elem.dtype)
                        else:
                            param_specs[k] = type(elem)
                # Nested Bijectors move into param_specs; the rest pass through.
                for k, v in list(kwargs.items()):
                    if isinstance(v, tfb.Bijector):
                        param_specs[k] = kwargs.pop(k)
                return _JaxBijectorTypeSpec(clsid, param_specs, kwargs)

            def __str__(self):
                return repr(self)

            def __repr__(self):
                return '{}()'.format(self.__class__.__name__)

        _WrapperType.__name__ = cls.__name__ + 'Wrapper'

        def to_tree(obj):
            """Flatten to component values sorted by key, plus (keys, type_spec)."""
            type_spec = obj._type_spec  # pylint: disable=protected-access
            components = type_spec._to_components(obj)  # pylint: disable=protected-access
            if components:
                keys, values = zip(*sorted(components.items()))
            else:
                # zip(*[]) yields nothing, which cannot be unpacked into two
                # names; an object with no components has no children.
                keys, values = (), ()
            return values, (keys, type_spec)

        def from_tree(info, xs):
            """Rebuild an object from `to_tree`'s (keys, type_spec) and values."""
            keys, type_spec = info
            components = dict(zip(keys, xs))
            return type_spec._from_components(components)  # pylint: disable=protected-access

        tree_util.register_pytree_node(_WrapperType, to_tree, from_tree)

        _registry[clsid] = _WrapperType
    return _registry[clsid]
Пример #26
0
    """Dataclass for storing parameters of a Linear RNN."""
    A: jnp.array  # Input weights.      pylint: disable=invalid-name
    W: jnp.array  # Recurrent weights.  pylint: disable=invalid-name
    b: jnp.array  # Bias.

    def apply(self, x, h) -> jnp.ndarray:
        """Linear RNN Update.

        Returns ``A @ x + W @ h + b`` for input ``x`` and hidden state ``h``.
        """
        # The return annotation was `jnp.array`, which is a function, not a
        # type; `jnp.ndarray` is the array type.
        return self.A @ x + self.W @ h + self.b

    def flatten(self):
        """Return the parameters as the tuple ``(A, W, b)``."""
        return self.A, self.W, self.b


# Register LinearRNN as a pytree so it can be passed directly to JAX functions
# (optimizers, flatten, etc.). Its children are (A, W, b); no aux data.
register_pytree_node(
    LinearRNN,
    lambda rnn: (rnn.flatten(), None),
    lambda _aux, params: LinearRNN(*params),
)


class RNNCell:
    """Base class for all RNN Cells.

  An RNNCell must implement the following methods:
    init(PRNGKey, input_shape) -> output_shape, rnn_params
    apply(params, inputs, state) -> next_state
  """
    def __init__(self, num_units, h_init=zeros):
        """Initializes an RNNCell."""
        self.num_units = num_units
        self.h_init = h_init

        # Compute RNN Jacobians.
Пример #27
0
Файл: jet.py Проект: nhanwei/jax
        # TODO(mattjj): don't just ignore custom jvp rules?
        del primitive, jvp  # Unused.
        return fun.call_wrapped(*tracers)

    def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                                out_trees):
        """Ignore the custom VJP rule and evaluate `fun` on the tracers."""
        del primitive, fwd, bwd, out_trees  # Unused.
        result = fun.call_wrapped(*tracers)
        return result


class ZeroTerm:
    """Marker type for a zero term; used through the `zero_term` instance."""


zero_term = ZeroTerm()
# ZeroTerm has no children; unflattening always returns the shared instance.
register_pytree_node(ZeroTerm,
                     lambda _term: ((), None),
                     lambda _aux, _children: zero_term)


class ZeroSeries:
    """Marker type for a zero series; used through the `zero_series` instance."""


zero_series = ZeroSeries()
# ZeroSeries has no children; unflattening always returns the shared instance.
register_pytree_node(ZeroSeries,
                     lambda _series: ((), None),
                     lambda _aux, _children: zero_series)

call_param_updaters = {}


def _xla_call_param_updater(params, num_inputs):
    donated_invars = params['donated_invars']
Пример #28
0
        self.size += 1

    def as_array(self):
        """Stack the stored pytrees leaf-wise into arrays.

        ``self.data`` is treated as a sequence of pytrees with a common
        structure. The result has that structure, and each leaf is
        ``np.array`` of the corresponding leaves in order.
        """
        # `tree_util.tree_multimap` was removed from JAX; `tree_map` accepts
        # several trees and maps over their leaves in lockstep.
        return tree_util.tree_map(lambda *leaves: np.array(list(leaves)),
                                  *self.data)

    def flatten(self):
        """Pytree flatten: children ``(data,)``, aux data ``(size, idx)``."""
        children = (self.data,)
        aux = (self.size, self.idx)
        return children, aux

    @classmethod
    def unflatten(cls, data, xs):
        """Rebuild a HarvestList from `flatten`'s aux data and children."""
        size, idx = data
        items = xs[0]
        return HarvestList(items, size=size, idx=idx)


# HarvestList defines its own flatten/unflatten pair.
tree_util.register_pytree_node(
    HarvestList,
    HarvestList.flatten,
    HarvestList.unflatten,
)


class HarvestTrace(jax_core.Trace):
    """Trace that manages HarvestTracer objects.

    HarvestTracers only wrap known values, so this trace passes those values
    through primitives. The exceptions are `sow` and `nest`, which the active
    HarvestContext handles specially.

    Default primitive logic lives in `process_primitive`; `sow`-specific logic
    lives in `handle_sow`.
    """

    def pure(self, val):
        """Wrap a known value in a HarvestTracer belonging to this trace."""
        tracer = HarvestTracer(self, val)
        return tracer
Пример #29
0
map = safe_map
zip = safe_zip

# The implementation flattens pytrees at two levels. The params pytree defines
# an "outer pytree". Applying init_fun to each of its leaves produces the
# "inner pytrees". Because pytrees can be flattened, this structure is
# isomorphic to a list of lists with no further nesting.

pack = tuple
OptimizerState = namedtuple("OptimizerState",
                            ["packed_state", "tree_def", "subtree_defs"])
# packed_state is the only child; both treedefs are static aux data.
register_pytree_node(
    OptimizerState,
    lambda state: ((state.packed_state,), (state.tree_def, state.subtree_defs)),
    lambda aux, children: OptimizerState(children[0], *aux),
)


def optimizer(opt_maker):
    """Decorator to make an optimizer defined for arrays generalize to containers.

  With this decorator, you can write init, update, and get_params functions that
  each operate only on single arrays, and convert them to corresponding
  functions that operate on pytrees of parameters. See the optimizers defined in
  optimizers.py for examples.

  Args:
    opt_maker: a function that returns an ``(init_fun, update_fun, get_params)``
      triple of functions that might only work with ndarrays, as per
Пример #30
0
                       maxval=high))

    def truncated_normal(self, lower, upper, size, scale=1.):
        """Draw truncated-normal samples in [lower, upper], multiplied by `scale`."""
        key = self.split_key()
        samples = jr.truncated_normal(key,
                                      lower=lower,
                                      upper=upper,
                                      shape=_size2shape(size))
        return JaxArray(samples * scale)

    def bernoulli(self, p, size=None):
        """Draw Bernoulli(p) samples of the requested size."""
        key = self.split_key()
        samples = jr.bernoulli(key, p=p, shape=_size2shape(size))
        return JaxArray(samples)


# RandomState's single child is its `value`; no aux data.
register_pytree_node(
    RandomState,
    lambda state: ((state.value,), None),
    lambda _aux, children: RandomState(*children),
)

DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))


def seed(seed=None):
    """Reseed the module-level DEFAULT RandomState.

    When `seed` is None, a seed is drawn with np.random.randint(0, 100000).
    """
    # DEFAULT is only mutated via its method, never rebound, so no `global`.
    if seed is None:
        seed = np.random.randint(0, 100000)
    DEFAULT.seed(seed)


def rand(*dn):
    """Uniform [0, 1) samples of shape `dn`, drawn from the DEFAULT RandomState."""
    key = DEFAULT.split_key()
    samples = jr.uniform(key, shape=dn, minval=0., maxval=1.)
    return JaxArray(samples)