예제 #1
0
    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        """Build the distribution from a base distribution and skewness values.

        :param base_dist: distribution whose ``event_shape`` must equal the
            last dimension of ``skewness``; it is expanded to the broadcast
            batch shape and stored as ``self.base_dist``.
        :param skewness: array whose last axis is used as the event shape and
            whose leading axes are broadcast against ``base_dist.batch_shape``.
        :param validate_args: passed through unchanged to the parent
            constructor.
        :raises AssertionError: if ``base_dist.event_shape`` differs from the
            trailing dimension of ``skewness``.
        """
        skew_shape = skewness.shape
        assert (
            base_dist.event_shape == skew_shape[-1:]
        ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

        # Split the skewness shape into its batch (leading) and event
        # (trailing) parts, then reconcile the batch part with the base
        # distribution's batch shape.
        leading_dims, trailing_dims = skew_shape[:-1], skew_shape[-1:]
        full_batch = jnp.broadcast_shapes(base_dist.batch_shape, leading_dims)

        # Both stored members share the same broadcast batch shape.
        self.skewness = jnp.broadcast_to(skewness, full_batch + trailing_dims)
        self.base_dist = base_dist.expand(full_batch)
        super().__init__(full_batch, trailing_dims, validate_args=validate_args)
예제 #2
0
    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        """Build the distribution from a base distribution and skewness values.

        :param base_dist: distribution expanded to the broadcast batch shape
            and stored as ``self.base_dist``.
        :param skewness: array whose last axis is used as the event shape and
            whose leading axes are broadcast against ``base_dist.batch_shape``.
        :param validate_args: passed through unchanged to the parent
            constructor.
        :raises ValueError: when validation is enabled and the device of
            ``base_dist.mean`` differs from the device of ``skewness``.
        """
        leading_dims, trailing_dims = skewness.shape[:-1], skewness.shape[-1:]
        full_batch = jnp.broadcast_shapes(base_dist.batch_shape, leading_dims)

        self.skewness = jnp.broadcast_to(skewness, full_batch + trailing_dims)
        self.base_dist = base_dist.expand(full_batch)
        super().__init__(full_batch, trailing_dims, validate_args=validate_args)

        # The device check comes after the parent constructor has run because
        # it reads ``self._validate_args`` (presumably set there -- confirm
        # against the parent class).
        if self._validate_args:
            if base_dist.mean.device != skewness.device:
                raise ValueError(
                    f"base_density: {base_dist.__class__.__name__} and SineSkewed "
                    f"must be on same device.")
예제 #3
0
 def __init__(self):
     """Initialise the instance with a kernel function that is zero everywhere.

     The stored kernel takes two array-like arguments and returns zeros with
     the shape obtained by broadcasting ``x.shape`` against ``y.shape``.
     Both derivability bounds are set to ``(sys.maxsize, sys.maxsize)`` and
     ``initargs`` is set to ``None``.
     """
     # Same pair of values is used for the lower and the upper bound.
     no_limit = (sys.maxsize, sys.maxsize)

     self._kernel = lambda x, y: jnp.broadcast_to(
         0., jnp.broadcast_shapes(x.shape, y.shape)
     )
     self._minderivable = no_limit
     self._maxderivable = no_limit
     self.initargs = None