def sample(key):
    """Draws a sample from the enclosing scope's `dist` using `key`.

    The sampling function `_sample_distribution` is wrapped with
    `primitive.initial_style_bind` on `random_variable_p`. The wrapper receives
    `batch_ndims=0` and the distribution's class name as primitive parameters.
    It is then invoked as `(key, dist)`.

    NOTE(review): `dist` is a free variable taken from the enclosing function;
    `batch_ndims=0` presumably means "no batch dimensions" — confirm against
    the primitive's rules.
    """
    bound_sampler = primitive.initial_style_bind(
        random_variable_p,
        batch_ndims=0,
        distribution_name=dist.__class__.__name__,
    )(_sample_distribution)
    return bound_sampler(key, dist)
def bijector_bind(bijector, x, **kwargs):
    """Applies `bijector` to `x` through the `bijector_p` primitive.

    `_bijector` is wrapped with `primitive.initial_style_bind`. The wrapper
    receives three primitive parameters:

    * `direction`: read from `kwargs['direction']`, so callers must supply it
      (a `KeyError` is raised otherwise).
    * `num_bijector`: the number of pytree leaves in `bijector`.
    * `bijector_name`: the bijector's class name.

    The wrapped function is then called with `(bijector, x, **kwargs)`.
    """
    direction = kwargs['direction']
    leaf_count = len(tree_util.tree_leaves(bijector))
    class_name = bijector.__class__.__name__
    bind = primitive.initial_style_bind(
        bijector_p,
        direction=direction,
        num_bijector=leaf_count,
        bijector_name=class_name,
    )
    return bind(_bijector)(bijector, x, **kwargs)
def wrapped(key):
    """Samples from the enclosing scope's `dist` and optionally names the result.

    `_sample_distribution` is bound to `random_variable_p` via
    `primitive.initial_style_bind`, with the distribution's class name as a
    parameter, and called as `(key, dist)`. If the enclosing scope's `name` is
    not None, the sample is passed through `ppl.random_variable(..., name=name)`
    before being returned.

    NOTE(review): `dist` and `name` are free variables from the enclosing
    function.
    """
    sampler = primitive.initial_style_bind(
        random_variable_p,
        distribution_name=dist.__class__.__name__,
    )(_sample_distribution)
    drawn = sampler(key, dist)
    if name is None:
        return drawn
    return ppl.random_variable(drawn, name=name)
def call_and_update(self, *args, rng=None, **kwargs):
    """Uses the `layer_cau` primitive to call `self._call_and_update`.

    Args:
      *args: positional arguments forwarded to `Layer._call_and_update` after
        `self`.
      rng: optional RNG. When provided, it is prepended to `args`, and
        `has_rng=True` is added to the forwarded keyword arguments. Otherwise
        `has_rng=False` is added.
      **kwargs: keyword arguments forwarded to `Layer._call_and_update`. They
        are also passed to the primitive binding as its `kwargs` parameter.

    Returns:
      Whatever the bound `Layer._call_and_update` returns.
    """
    has_rng = rng is not None
    if has_rng:
        # The layer_cau primitive expects the RNG as the first positional
        # argument whenever the `has_rng` kwarg is True.
        args = (rng, *args)
    forwarded = {**kwargs, 'has_rng': has_rng}
    bound = primitive.initial_style_bind(layer_cau_p, kwargs=forwarded)(
        Layer._call_and_update)
    return bound(self, *args, **forwarded)
def init(self, init_key, *args, name=None, **kwargs):
    """Initializes a Template into a Layer.

    Args:
      init_key: forwarded as the first positional argument to
        `_template_build` (presumably a PRNG key — confirm against
        `_template_build`).
      *args: example inputs. Each pytree leaf is converted with
        `state.make_array_spec`, and the resulting structure is forwarded as
        `specs`.
      name: optional name, also forwarded to `_template_build`. When it is not
        None, the built layer as a whole is wrapped with
        `state.variable(layer, name=name)`. Otherwise each entry of
        `layer.variables()` is wrapped with `state.variable` under its own key,
        and the layer is rebuilt with `layer.replace`.
      **kwargs: accepted but not used by this method.
        NOTE(review): the original discarded these too; confirm whether they
        should be forwarded to `_template_build`.

    Returns:
      The built layer, wrapped as described under `name`.
    """
    # `jax.tree_map` was deprecated in JAX 0.4.26 and later removed;
    # `jax.tree_util.tree_map` is the equivalent, long-standing API.
    specs = jax.tree_util.tree_map(state.make_array_spec, args)
    # Use a distinct name so the caller's **kwargs is not silently shadowed.
    build_kwargs = dict(
        cls=self.cls,
        specs=specs,
        init_args=self.init_args,
        init_kwargs=self.init_kwargs,
    )
    layer = primitive.initial_style_bind(template_init_p)(_template_build)(
        init_key, name=name, **build_kwargs)
    if name is not None:
        return state.variable(layer, name=name)
    layer_params = {
        k: state.variable(v, name=k) for k, v in layer.variables().items()
    }
    return layer.replace(**layer_params)
def __call__(self, *args, **kwargs):
    """Binds `self.prim` around `self.func` and invokes it with the arguments.

    `primitive.initial_style_bind(self.prim)` wraps `self.func`. The wrapped
    callable receives `*args` and `**kwargs` unchanged, and its result is
    returned.
    """
    bound_func = primitive.initial_style_bind(self.prim)(self.func)
    return bound_func(*args, **kwargs)
def bijector_bind(bijector, x, **kwargs):
    """Applies `bijector` to `x` through the `bijector_p` primitive.

    `_bijector` is wrapped with `primitive.initial_style_bind`. The wrapper
    receives two primitive parameters:

    * `direction`: read from `kwargs['direction']`, so callers must supply it
      (a `KeyError` is raised otherwise).
    * `bijector_name`: the bijector's class name.

    The wrapped function is then called with `(bijector, x, **kwargs)`.
    """
    direction = kwargs['direction']
    class_name = bijector.__class__.__name__
    bind = primitive.initial_style_bind(
        bijector_p,
        direction=direction,
        bijector_name=class_name,
    )
    return bind(_bijector)(bijector, x, **kwargs)