Example #1
0
# limitations under the License.
# ============================================================================
"""Tests for tensorflow_probability.spinoffs.oryx.core.ppl.effect_handler."""
from absl.testing import absltest
import jax
from jax import abstract_arrays
from jax import random
import jax.numpy as np

from oryx.core import primitive
from oryx.core.ppl import effect_handler
from oryx.internal import test_util

# Define a random normal primitive so we can use it for custom interpreter
# rules. Its concrete implementation and abstract-evaluation rule are attached
# further down with the `def_impl` / `def_abstract_eval` decorators. Both rules
# return a one-element list of outputs, which `random_normal` unwraps with
# `[0]`.
random_normal_p = primitive.FlatPrimitive('random_normal')


def random_normal(key, loc=0., scale=1.):
  """Draws a normal sample by binding `random_normal_p`.

  Going through the primitive instead of sampling directly is what lets custom
  interpreter rules see (and intercept) this draw.

  Args:
    key: PRNG key forwarded to the primitive.
    loc: value added to the scaled standard-normal draw.
    scale: multiplier applied to the standard-normal draw.

  Returns:
    The first element of the primitive's output list.
  """
  outputs = random_normal_p.bind(key, loc, scale)
  return outputs[0]


@random_normal_p.def_impl
def _random_normal_impl(key, loc, scale):
  """Concrete rule for `random_normal_p`: a scaled and shifted standard normal.

  Returns a one-element list, mirroring the list returned by the abstract rule.
  """
  standard_draw = random.normal(key)
  shifted = standard_draw * scale + loc
  return [shifted]


@random_normal_p.def_abstract_eval
def _random_normal_abstract(key, loc, scale):
  """Abstract-evaluation rule for `random_normal_p`.

  The output aval is fixed: one float32 scalar, independent of the input avals.

  NOTE(review): the impl rule computes `random.normal(key) * scale + loc`,
  which broadcasts against non-scalar `loc`/`scale`; this fixed scalar aval
  only agrees with it for scalar inputs -- confirm callers only pass scalars.
  """
  del key, loc, scale  # The result aval does not depend on the inputs.
  scalar_aval = abstract_arrays.ShapedArray((), np.float32)
  return [scalar_aval]
Example #2
0
# Module-level aliases for `jax_core.safe_map` and `jax_core.safe_zip`.
safe_map, safe_zip = jax_core.safe_map, jax_core.safe_zip


def _flat_layer_cau(*flat_args, in_tree, kwargs, **params):
    """Concrete implementation of the `flat_layer_cau` primitive.

    Rebuilds the layer and its positional arguments from the flat primitive
    inputs, calls `layer.call_and_update`, and flattens the result.

    Args:
      *flat_args: pytree leaves that `in_tree` unflattens into a sequence whose
        first element is the layer and whose remaining elements are the call's
        positional arguments (led by an rng when `has_rng` is set).
      in_tree: treedef used to unflatten `flat_args`.
      kwargs: keyword-argument items (anything `dict` accepts). A truthy
        `'has_rng'` entry means the first positional argument is an rng, which
        is forwarded to `call_and_update` as the `rng` keyword instead.
      **params: any other primitive params; ignored.

    Returns:
      The list of pytree leaves of `layer.call_and_update(...)`'s result.
    """
    del params  # Extra primitive params are not needed to run the layer.
    layer, *call_args = tree_util.tree_unflatten(in_tree, flat_args)
    call_kwargs = dict(kwargs)  # Copy so the caller's items are not mutated.
    if call_kwargs.pop('has_rng', False):
        # The rng travels as the leading positional input; move it to kwargs.
        call_kwargs['rng'] = call_args[0]
        call_args = call_args[1:]
    result = layer.call_and_update(*call_args, **call_kwargs)
    return tree_util.tree_leaves(result)


# Primitive whose concrete implementation unflattens its inputs into
# (layer, *args), runs `layer.call_and_update`, and returns the flattened
# result leaves.
flat_layer_cau_p = primitive.FlatPrimitive('flat_layer_cau')
# Register `_flat_layer_cau` as the primitive's concrete (impl) rule.
flat_layer_cau_p.def_impl(_flat_layer_cau)


def flat_layer_cau_kwargs_rule(*flat_args, in_tree, kwargs, **_):
    """Custom kwargs rule for flat_layer_cau primitive.

    This rule evaluates the layer call exactly like the primitive's concrete
    implementation does. It therefore delegates to `_flat_layer_cau` instead of
    keeping a second copy of the unflatten / `has_rng` / flatten logic, so the
    two cannot drift apart.

    Args:
      *flat_args: pytree leaves that `in_tree` unflattens into the layer
        followed by its positional arguments (led by an rng when `has_rng` is
        set).
      in_tree: treedef used to unflatten `flat_args`.
      kwargs: keyword-argument items (anything `dict` accepts), possibly
        including a `'has_rng'` flag.
      **_: any other primitive params; ignored.

    Returns:
      The list of pytree leaves of `layer.call_and_update(...)`'s result.
    """
    return _flat_layer_cau(*flat_args, in_tree=in_tree, kwargs=kwargs)