Exemplo n.º 1
0
 def sample(key):
     """Samples the closed-over `dist` through the `random_variable_p` primitive.

     The primitive is bound with `batch_ndims=0` and the distribution's class
     name recorded as `distribution_name`.

     Args:
       key: forwarded, together with `dist`, to `_sample_distribution`
         (presumably a PRNG key — confirm against callers).

     Returns:
       The output of the bound `_sample_distribution(key, dist)` call.
     """
     bind = primitive.initial_style_bind(
         random_variable_p,
         batch_ndims=0,
         distribution_name=dist.__class__.__name__,
     )
     bound_sampler = bind(_sample_distribution)
     return bound_sampler(key, dist)
Exemplo n.º 2
0
 def bijector_bind(bijector, x, **kwargs):
     """Applies `_bijector` to `x` by binding it to the `bijector_p` primitive.

     The primitive records the requested direction, the number of pytree
     leaves in `bijector`, and the bijector's class name.

     Args:
       bijector: the bijector object; its leaves are counted and its class
         name is recorded as a primitive parameter.
       x: input forwarded to `_bijector`.
       **kwargs: forwarded to `_bijector`. Must contain `'direction'`; a
         missing entry raises `KeyError`.

     Returns:
       The result of the bound `_bijector(bijector, x, **kwargs)` call.
     """
     direction = kwargs['direction']
     leaf_count = len(tree_util.tree_leaves(bijector))
     class_name = bijector.__class__.__name__
     binder = primitive.initial_style_bind(
         bijector_p,
         direction=direction,
         num_bijector=leaf_count,
         bijector_name=class_name,
     )
     return binder(_bijector)(bijector, x, **kwargs)
Exemplo n.º 3
0
 def wrapped(key):
     """Samples the closed-over `dist` and optionally tags it as a named variable.

     Args:
       key: forwarded, together with `dist`, to `_sample_distribution`
         (presumably a PRNG key — confirm against callers).

     Returns:
       The sampled value. If the closed-over `name` is not None, the value is
       first passed through `ppl.random_variable(..., name=name)`.
     """
     sampled = primitive.initial_style_bind(
         random_variable_p,
         distribution_name=dist.__class__.__name__,
     )(_sample_distribution)(key, dist)
     if name is None:
         return sampled
     return ppl.random_variable(sampled, name=name)
Exemplo n.º 4
0
 def call_and_update(self, *args, rng=None, **kwargs):
     """Calls `Layer._call_and_update` through the `layer_cau` primitive.

     Args:
       *args: positional arguments forwarded after `self`.
       rng: optional RNG. When given, it is prepended to the positional
         arguments and `has_rng=True` is recorded.
       **kwargs: keyword arguments forwarded to the call. A `has_rng` entry is
         added, overriding any caller-supplied value.

     Returns:
       The result of the bound `Layer._call_and_update` call.
     """
     has_rng = rng is not None
     if has_rng:
         # The layer_cau primitive expects the RNG as the first positional
         # argument whenever the `has_rng` kwarg is True.
         args = (rng,) + args
     # The same dict is handed to the primitive as a parameter and unpacked
     # into the call itself.
     call_kwargs = {**kwargs, 'has_rng': has_rng}
     binder = primitive.initial_style_bind(layer_cau_p, kwargs=call_kwargs)
     return binder(Layer._call_and_update)(self, *args, **call_kwargs)
Exemplo n.º 5
0
 def init(self, init_key, *args, name=None, **kwargs):
   """Initializes a Template into a Layer.

   Args:
     init_key: forwarded to `_template_build` as its first argument.
     *args: example inputs. Each leaf is mapped through
       `state.make_array_spec` to build the `specs` passed to the build.
     name: optional name. When given, the built layer is wrapped as a single
       `state.variable` under that name. Otherwise each entry of
       `layer.variables()` is wrapped individually under its own key.
     **kwargs: accepted by the signature but not used by this method.

   Returns:
     The initialized layer.
   """
   # `jax.tree_map` is deprecated (and removed in newer JAX releases);
   # `jax.tree_util.tree_map` is the supported spelling with identical behavior.
   specs = jax.tree_util.tree_map(state.make_array_spec, args)
   # NOTE(review): caller-supplied **kwargs are never used here. Confirm
   # whether they were meant to reach `_template_build` or `init_kwargs`.
   build_kwargs = dict(
       cls=self.cls,
       specs=specs,
       init_args=self.init_args,
       init_kwargs=self.init_kwargs,
   )
   layer = primitive.initial_style_bind(template_init_p)(
       _template_build)(init_key, name=name, **build_kwargs)
   if name is not None:
     return state.variable(layer, name=name)
   # Unnamed: wrap each of the layer's variables individually, keyed by name.
   layer_params = {
       k: state.variable(v, name=k) for k, v in layer.variables().items()
   }
   return layer.replace(**layer_params)
Exemplo n.º 6
0
 def __call__(self, *args, **kwargs):
     """Invokes `self.func` bound to the primitive stored in `self.prim`.

     All positional and keyword arguments are forwarded unchanged.
     """
     bind = primitive.initial_style_bind(self.prim)
     bound_func = bind(self.func)
     return bound_func(*args, **kwargs)
Exemplo n.º 7
0
 def bijector_bind(bijector, x, **kwargs):
     """Applies `_bijector` to `x` by binding it to the `bijector_p` primitive.

     Args:
       bijector: the bijector object; its class name is recorded as a
         primitive parameter.
       x: input forwarded to `_bijector`.
       **kwargs: forwarded to `_bijector`. Must contain `'direction'`; a
         missing entry raises `KeyError`.

     Returns:
       The result of the bound `_bijector(bijector, x, **kwargs)` call.
     """
     direction = kwargs['direction']
     class_name = bijector.__class__.__name__
     binder = primitive.initial_style_bind(
         bijector_p, direction=direction, bijector_name=class_name)
     return binder(_bijector)(bijector, x, **kwargs)