def copy(self, **override_parameters_kwargs):
  """Creates a new instance of this distribution, optionally with overrides.

  Note: the copy distribution may continue to depend on the original
  initialization arguments; parameter values are reused, not cloned.

  Args:
    **override_parameters_kwargs: String/value dictionary of initialization
      arguments to override with new values.

  Returns:
    distribution: A new instance of `type(self)` initialized from the union
      of self.parameters and override_parameters_kwargs, i.e.,
      `dict(self.parameters, **override_parameters_kwargs)`.
  """
  try:
    # We want to track provenance from the origin variables, so when this
    # distribution supports slicing the copy is routed through `batch_slice`
    # with an identity (`Ellipsis`) slice. See the comment on
    # PROVENANCE_ATTR in slicing.py.
    event_ndims = self._params_event_ndims()
    return slicing.batch_slice(
        self, event_ndims, override_parameters_kwargs, Ellipsis)
  except NotImplementedError:
    # Slicing is unsupported for this distribution: rebuild it directly from
    # the merged parameter dictionary instead.
    merged = dict(self.parameters, **override_parameters_kwargs)
    clone = type(self)(**merged)
    # Store the merged dict verbatim as the clone's parameters. The
    # `_parameters_sanitized` flag presumably tells the base class the dict
    # needs no further processing -- NOTE(review): confirm in Distribution.
    # pylint: disable=protected-access
    clone._parameters = merged
    clone._parameters_sanitized = True
    # pylint: enable=protected-access
    return clone
def __getitem__(self, slices):
  """Slices the batch axes of this distribution, returning a new instance.

  Event dimensions are not affected, as the `mvn` example below shows.

  ```python
  b = tfd.Bernoulli(logits=tf.zeros([3, 5, 7, 9]))
  b.batch_shape  # => [3, 5, 7, 9]
  b2 = b[:, tf.newaxis, ..., -2:, 1::2]
  b2.batch_shape  # => [3, 1, 5, 2, 4]

  x = tf.random.normal([5, 3, 2, 2])
  cov = tf.matmul(x, x, transpose_b=True)
  chol = tf.linalg.cholesky(cov)
  loc = tf.random.normal([4, 1, 3, 1])
  mvn = tfd.MultivariateNormalTriL(loc, chol)
  mvn.batch_shape  # => [4, 5, 3]
  mvn.event_shape  # => [2]
  mvn2 = mvn[:, 3:, ..., ::-1, tf.newaxis]
  mvn2.batch_shape  # => [4, 2, 3, 1]
  mvn2.event_shape  # => [2]
  ```

  Args:
    slices: slices from the [] operator

  Returns:
    dist: A new `tfd.Distribution` instance with sliced parameters.
  """
  params_event_ndims = self._params_event_ndims()
  # Plain slicing never overrides any constructor parameters.
  no_overrides = {}
  return slicing.batch_slice(self, params_event_ndims, no_overrides, slices)