def make_batch_of_event_sample_matrices(
        self, x, expand_batch_dim=True,
        name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `Boolean` scalar. If `True` the batch dims
        will be expanded such that batch_ndims>=1.
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        sample_shape, batch_shape, event_shape = self.get_shape(x)
        # Substitute [1] for an empty event shape.
        event_shape = distribution_util.pick_vector(
            self._event_ndims_is_0, [1], event_shape)
        if expand_batch_dim:
            # Likewise substitute [1] for an empty batch shape.
            batch_shape = distribution_util.pick_vector(
                self._batch_ndims_is_0, [1], batch_shape)
        # The leading -1 folds every sample dimension into a single one.
        flattened = array_ops.reshape(
            x,
            shape=array_ops.concat([[-1], batch_shape, event_shape], 0))
        rotated = distribution_util.rotate_transpose(flattened, shift=-1)
        return rotated, sample_shape
def make_batch_of_event_sample_matrices(
        self, x, expand_batch_dim=True,
        name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `Boolean` scalar. If `True` the batch dims
        will be expanded such that batch_ndims>=1.
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        sample_shape, batch_shape, event_shape = self.get_shape(x)
        # An empty event shape is replaced by [1].
        event_shape = distribution_util.pick_vector(
            self._event_ndims_is_0, [1], event_shape)
        if expand_batch_dim:
            # An empty batch shape is replaced by [1] as well.
            batch_shape = distribution_util.pick_vector(
                self._batch_ndims_is_0, [1], batch_shape)
        # -1 lets reshape infer the product of all sample dimensions.
        target_shape = array_ops.concat_v2(
            [[-1], batch_shape, event_shape], 0)
        reshaped = array_ops.reshape(x, shape=target_shape)
        return (distribution_util.rotate_transpose(reshaped, shift=-1),
                sample_shape)
def make_batch_of_event_sample_matrices(
        self, x, name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        sample_shape, batch_shape, event_shape = self.get_shape(x)
        # Empty event/batch shapes are each replaced by (1,).
        event_shape = distribution_util.pick_vector(
            self._event_ndims_is_0, (1, ), event_shape)
        batch_shape = distribution_util.pick_vector(
            self._batch_ndims_is_0, (1, ), batch_shape)
        # Axis-first `concat` signature, exactly as the original call.
        # The leading -1 folds all sample dimensions into one.
        flattened = array_ops.reshape(
            x,
            shape=array_ops.concat(0, ((-1, ), batch_shape, event_shape)))
        rotated = distribution_util.rotate_transpose(flattened, shift=-1)
        return rotated, sample_shape
def make_batch_of_event_sample_matrices(
        self, x, name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
        x = ops.convert_to_tensor(x, name="x")
        sample_shape, batch_shape, event_shape = self.get_shape(x)
        # Replace an empty event shape, then an empty batch shape, by (1,).
        event_shape = distribution_util.pick_vector(
            self._event_ndims_is_0, (1,), event_shape)
        batch_shape = distribution_util.pick_vector(
            self._batch_ndims_is_0, (1,), batch_shape)
        # `concat` is called axis-first here; -1 absorbs the sample dims.
        target_shape = array_ops.concat(
            0, ((-1,), batch_shape, event_shape))
        reshaped = array_ops.reshape(x, shape=target_shape)
        return (distribution_util.rotate_transpose(reshaped, shift=-1),
                sample_shape)
def _sample_n(self, n, seed=None):
    """Draws `n` samples by interpolating between the two endpoint affines.

    A base sample is pushed through every bijector in
    `self.endpoint_affine`; a per-sample weight is gathered from the
    flattened `self.grid` using ids drawn from `self.mixture_distribution`,
    and the two transformed samples are blended with that weight.

    Args:
      n: Number of samples to draw (leading dimension of the result).
      seed: Optional seed. It is passed to the base distribution as-is and
        to the mixture distribution after `gen_new_seed`.

    Returns:
      The blended sample `weight * (x[0] - x[1]) + x[1]`.

    Raises:
      NotImplementedError: if `self.endpoint_affine` does not hold exactly
        two bijectors.
    """
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
        seed=seed)  # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
        # Static size unknown: compute it in-graph. `reduce_prod` is a
        # `math_ops` function (the original called it on `array_ops`).
        batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
        mix_batch_size = math_ops.reduce_prod(
            self.mixture_distribution.batch_shape_tensor())
    ids = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size // mix_batch_size])),
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = array_ops.reshape(
        ids,
        shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                np.int32([-1]))))

    # Stride `components * quadrature_size` for `batch_size` number of times.
    stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
    if stride is None:
        # Same fallback as for `batch_size`: use `math_ops.reduce_prod`.
        stride = math_ops.reduce_prod(array_ops.shape(self.grid)[-2:])
    offset = math_ops.range(
        start=0,
        limit=batch_size * stride,
        delta=stride,
        dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.grid, shape=[-1]), ids + offset)
    # Trailing singleton so the weight broadcasts against the event dim.
    weight = weight[..., array_ops.newaxis]

    if len(x) != 2:
        # We actually should have already triggered this exception. However
        # as a policy we're putting this exception wherever we exploit the
        # bimixture assumption.
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
def _sample_n(self, n, seed=None):
    """Samples the base distribution and applies the bijector's forward map.

    `[n]` is placed either first or last in the requested sample shape,
    depending on `self._needs_rotation`; `self._maybe_rotate_dims` is then
    applied to the raw sample before the forward transform.
    """
    # Exactly one of these two holds `[n]`; the other is `self._empty`.
    leading = distribution_util.pick_vector(
        self._needs_rotation, self._empty, [n])
    trailing = distribution_util.pick_vector(
        self._needs_rotation, [n], self._empty)
    sample_shape = _concat_vectors(
        leading,
        self._override_batch_shape,
        self._override_event_shape,
        trailing)
    raw = self.distribution.sample(sample_shape=sample_shape, seed=seed)
    return self.bijector.forward(self._maybe_rotate_dims(raw))
def _sample_n(self, n, seed=None):
    """Draws from the base distribution, then transforms with the bijector.

    The sample shape is `[n]`-first when `self._needs_rotation` is false and
    `[n]`-last when it is true, with the override batch and event shapes in
    between. The raw draw goes through `self._maybe_rotate_dims` and then
    `self.bijector.forward`.
    """
    n_first = distribution_util.pick_vector(
        self._needs_rotation, self._empty, [n])
    n_last = distribution_util.pick_vector(
        self._needs_rotation, [n], self._empty)
    shape_parts = (
        n_first,
        self._override_batch_shape,
        self._override_event_shape,
        n_last,
    )
    draw = self.distribution.sample(
        sample_shape=_concat_vectors(*shape_parts), seed=seed)
    rotated = self._maybe_rotate_dims(draw)
    return self.bijector.forward(rotated)
def testCorrectlyPicksVector(self):
    """`pick_vector` returns its first vector for a true predicate, else its second."""
    with self.test_session():
        x = np.arange(10, 12)
        y = np.arange(15, 18)
        # Predicates built with `tf.less`: the result is evaluated.
        for (lhs, rhs), expected in (((0, 5), x), ((5, 0), y)):
            self.assertAllEqual(
                expected,
                distribution_util.pick_vector(
                    tf.less(lhs, rhs), x, y).eval())
        # Predicates built with `tf.constant`: compared directly, no eval.
        for flag, expected in ((True, x), (False, y)):
            self.assertAllEqual(
                expected,
                distribution_util.pick_vector(tf.constant(flag), x, y))
def _sample_n(self, n, seed=None):
    """Samples quadrature ids, gathers their rates and draws Poisson samples."""
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    if self.batch_shape.is_fully_defined():
        batch_size = np.prod(self.batch_shape.as_list(), dtype=np.int32)
    else:
        batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
    ids_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), [batch_size]))
    ids = self._mixture_distribution.sample(
        sample_shape=ids_shape,
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))
    # Stride `quadrature_size` for `batch_size` number of times.
    offset = math_ops.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids += offset
    flat_rate = array_ops.reshape(self.distribution.rate, shape=[-1])
    rate = array_ops.gather(flat_rate, ids)
    rate = array_ops.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)
def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D.

    Args:
      x: `Tensor` holding a sample shape; scalar or vector.
      name: Name for the constant created when a statically known scalar is
        wrapped into a one-element vector.

    Returns:
      x: The input as a 1-D `Tensor`.
      prod: Product of the entries of `x`; a numpy value when `x` is
        statically known, otherwise a `Tensor`.

    Raises:
      ValueError: if the static rank of `x` is neither 0 nor 1.
    """
    static_value = tensor_util.constant_value(x)
    if static_value is None:
        prod = math_ops.reduce_prod(x)
    else:
        prod = np.prod(static_value, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims == 1:
        # Already a vector.
        return x, prod

    if ndims is None:
        # Maybe expand_dims.
        dynamic_ndims = array_ops.rank(x)
        expanded_shape = distribution_util.pick_vector(
            math_ops.equal(dynamic_ndims, 0),
            np.array([1], dtype=np.int32),
            array_ops.shape(x))
        return array_ops.reshape(x, expanded_shape), prod

    if ndims != 0:
        raise ValueError("Input is neither scalar nor vector.")

    # Definitely expand_dims.
    if static_value is None:
        return array_ops.reshape(x, [1]), prod
    wrapped = np.array([static_value], dtype=x.dtype.as_numpy_dtype())
    return ops.convert_to_tensor(wrapped, name=name), prod
def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D.

    Returns the (possibly reshaped) 1-D `x` together with the product of its
    entries. The product is computed with numpy when the value of `x` is
    statically known and with `reduce_prod` otherwise.

    Raises:
      ValueError: if the static rank of `x` is known and greater than 1.
    """
    known_value = tensor_util.constant_value(x)
    prod = (math_ops.reduce_prod(x) if known_value is None
            else np.prod(known_value, dtype=x.dtype.as_numpy_dtype()))

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
        # Maybe expand_dims.
        rank = array_ops.rank(x)
        target_shape = distribution_util.pick_vector(
            math_ops.equal(rank, 0),
            np.array([1], dtype=np.int32),
            array_ops.shape(x))
        x = array_ops.reshape(x, target_shape)
    elif ndims == 0:
        # Definitely expand_dims.
        if known_value is None:
            x = array_ops.reshape(x, [1])
        else:
            x = ops.convert_to_tensor(
                np.array([known_value], dtype=x.dtype.as_numpy_dtype()),
                name=name)
    elif ndims != 1:
        raise ValueError("Input is neither scalar nor vector.")
    return x, prod
def _make_columnar(self, x):
    """Gives rank-1 input a leading singleton dimension.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.
      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.
      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `x` itself, or `x` reshaped from `[k]` to `[1, k]` when
        it is rank 1.
    """
    static_ndims = x.get_shape().ndims
    if static_ndims is not None:
        # Rank known statically: only a vector needs the extra dimension.
        return x[array_ops.newaxis, :] if static_ndims == 1 else x

    # Rank unknown until run time: splice a `1` in front of the last
    # dimension only when the rank turns out to be 1.
    shape = array_ops.shape(x)
    maybe_expanded_shape = array_ops.concat([
        shape[:-1],
        distribution_util.pick_vector(
            math_ops.equal(array_ops.rank(x), 1),
            [1],
            np.array([], dtype=np.int32)),
        shape[-1:],
    ], 0)
    return array_ops.reshape(x, maybe_expanded_shape)
def _sample_n(self, n, seed=None):
    """Draws `n` Poisson samples with rates picked by sampled quadrature ids."""
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = (
        np.prod(self.batch_shape.as_list(), dtype=np.int32)
        if self.batch_shape.is_fully_defined()
        else math_ops.reduce_prod(self.batch_shape_tensor()))
    per_draw_dims = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), [batch_size])
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], per_draw_dims),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))
    # Stride `quadrature_size` for `batch_size` number of times.
    quadrature_size = self._quadrature_size
    offset = math_ops.range(
        start=0,
        limit=batch_size * quadrature_size,
        delta=quadrature_size,
        dtype=ids.dtype)
    ids += offset
    rate = array_ops.gather(
        array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
    out_shape = concat_vectors([n], self.batch_shape_tensor())
    rate = array_ops.reshape(rate, shape=out_shape)
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)
def testCorrectlyPicksVector(self):
    """`pick_vector` yields `x` for a true predicate and `y` for a false one."""
    with self.test_session():
        x = np.arange(10, 12)
        y = np.arange(15, 18)
        # Predicates built with `math_ops.less`: the result is evaluated.
        for (lhs, rhs), expected in (((0, 5), x), ((5, 0), y)):
            picked = distribution_util.pick_vector(
                math_ops.less(lhs, rhs), x, y)
            self.assertAllEqual(expected, picked.eval())
        # Predicates built with `constant_op.constant`: no eval.
        for flag, expected in ((True, x), (False, y)):
            picked = distribution_util.pick_vector(
                constant_op.constant(flag), x, y)
            self.assertAllEqual(expected, picked)
def testChooseVector(self):
    """Checks `pick_vector` against `tf.less` and `tf.constant` predicates."""
    with self.test_session():
        x = np.arange(10, 12)
        y = np.arange(15, 18)

        def pick(cond):
            return distribution_util.pick_vector(cond, x, y)

        # Results for `tf.less` predicates are evaluated before comparing.
        self.assertAllEqual(x, pick(tf.less(0, 5)).eval())
        self.assertAllEqual(y, pick(tf.less(5, 0)).eval())
        # Results for `tf.constant` predicates are compared directly.
        self.assertAllEqual(x, pick(tf.constant(True)))  # No eval.
        self.assertAllEqual(y, pick(tf.constant(False)))  # No eval.
def _expand_sample_shape(self, sample_shape):
    """Helper to `sample` which ensures sample_shape is 1D.

    Args:
      sample_shape: `Tensor` holding the requested sample shape.

    Returns:
      sample_shape: The sample shape as a 1-D `Tensor`.
      total: Product of its entries; a numpy value when the shape is
        statically known, otherwise a `Tensor`.

    Raises:
      ValueError: if the value is statically known but its rank is not.
    """
    static_val = tensor_util.constant_value(sample_shape)
    ndims = sample_shape.get_shape().ndims

    if static_val is not None:
        if ndims is None:
            raise ValueError(
                "Shouldn't be here; ndims cannot be none when we have a "
                "tf.constant shape.")
        if ndims == 0:
            # NOTE(review): the conversion is assumed to belong to this
            # scalar-only branch -- confirm against the upstream layout.
            static_val = np.reshape(static_val, [1])
            sample_shape = ops.convert_to_tensor(
                static_val, dtype=dtypes.int32, name="sample_shape")
        total = np.prod(static_val, dtype=dtypes.int32.as_numpy_dtype())
        return sample_shape, total

    if ndims is None or not sample_shape.get_shape().is_fully_defined():
        ndims = array_ops.rank(sample_shape)
    expanded_shape = distribution_util.pick_vector(
        math_ops.equal(ndims, 0),
        np.array((1,), dtype=dtypes.int32.as_numpy_dtype()),
        array_ops.shape(sample_shape))
    sample_shape = array_ops.reshape(sample_shape, expanded_shape)
    total = math_ops.reduce_prod(sample_shape)  # reduce_prod([]) == 1
    return sample_shape, total
def _expand_sample_shape(self, sample_shape):
    """Helper to `sample` which ensures sample_shape is 1D.

    Returns the 1-D sample shape and the product of its entries. The
    product is a numpy value when the shape is statically known and a
    `Tensor` otherwise.

    Raises:
      ValueError: if the value is statically known but its rank is not.
    """
    sample_shape_static_val = tensor_util.constant_value(sample_shape)
    static_shape = sample_shape.get_shape()
    ndims = static_shape.ndims
    int32_np = dtypes.int32.as_numpy_dtype()

    if sample_shape_static_val is None:
        rank_unknown = ndims is None or not static_shape.is_fully_defined()
        if rank_unknown:
            ndims = array_ops.rank(sample_shape)
        expanded_shape = distribution_util.pick_vector(
            math_ops.equal(ndims, 0),
            np.array((1, ), dtype=int32_np),
            array_ops.shape(sample_shape))
        sample_shape = array_ops.reshape(sample_shape, expanded_shape)
        total = math_ops.reduce_prod(sample_shape)  # reduce_prod([]) == 1
    else:
        if ndims is None:
            raise ValueError(
                "Shouldn't be here; ndims cannot be none when we have a "
                "tf.constant shape.")
        if ndims == 0:
            # NOTE(review): the conversion is assumed to belong to this
            # scalar-only branch -- confirm against the upstream layout.
            sample_shape_static_val = np.reshape(
                sample_shape_static_val, [1])
            sample_shape = ops.convert_to_tensor(
                sample_shape_static_val,
                dtype=dtypes.int32,
                name="sample_shape")
        total = np.prod(sample_shape_static_val, dtype=int32_np)
    return sample_shape, total
def _forward_log_det_jacobian(self, x):
    """Returns the forward log-det-Jacobian of the scale transform.

    When the scale is only an identity multiplier, the result is
    `log|scale| * event_size`; otherwise it is the scale operator's
    `log_abs_determinant()`.
    """
    if not self._is_only_identity_multiplier:
        return self.scale.log_abs_determinant()

    # We don't pad in this case and instead let the fldj be applied
    # via broadcast.
    shape_or_one = distribution_util.pick_vector(
        math_ops.equal(self._shaper.event_ndims, 0),
        [1],
        array_ops.shape(x))
    # Last entry: 1 for scalar events, else the size of x's last dimension.
    event_size = math_ops.cast(shape_or_one[-1], dtype=self._scale.dtype)
    return math_ops.log(math_ops.abs(self._scale)) * event_size
def _forward_log_det_jacobian(self, x):
    """Computes the forward log-det-Jacobian.

    For an identity-multiplier-only scale this is `log|scale|` times the
    event size, where the event size is 1 for scalar events and the last
    dimension of `x` otherwise. For any other scale it defers to
    `self.scale.log_abs_determinant()`.
    """
    if self._is_only_identity_multiplier:
        # We don't pad in this case and instead let the fldj be applied
        # via broadcast.
        is_scalar_event = math_ops.equal(self._shaper.event_ndims, 0)
        event_size = distribution_util.pick_vector(
            is_scalar_event, [1], array_ops.shape(x))[-1]
        event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
        log_abs_scale = math_ops.log(math_ops.abs(self._scale))
        return log_abs_scale * event_size
    return self.scale.log_abs_determinant()
def _sample_n(self, n, seed=None):
    """Samples quadrature ids, gathers their rates and draws Poisson samples."""
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
        batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

    # We need to "sample extra" from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is
    # exactly one probs vector for all batch dims or one for each.
    extra_dims = distribution_util.pick_vector(
        self.mixture_distribution.is_scalar_batch(),
        [batch_size],
        np.int32([]))
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], extra_dims),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))

    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    flat_dims = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), np.int32([-1]))
    ids = array_ops.reshape(ids, shape=concat_vectors([n], flat_dims))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = math_ops.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids += offset

    flat_rate = array_ops.reshape(self.distribution.rate, shape=[-1])
    rate = array_ops.gather(flat_rate, ids)
    rate = array_ops.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
    """Draws `n` Poisson samples with rates picked by sampled quadrature ids."""
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
        batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

    # We need to "sample extra" from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is
    # exactly one probs vector for all batch dims or one for each.
    mixture_sample_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.mixture_distribution.is_scalar_batch(),
            [batch_size],
            np.int32([])))
    mixture_seed = distribution_util.gen_new_seed(
        seed, "poisson_lognormal_quadrature_compound")
    ids = self._mixture_distribution.sample(
        sample_shape=mixture_sample_shape, seed=mixture_seed)

    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = array_ops.reshape(
        ids,
        shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    quadrature_size = self._quadrature_size
    offset = math_ops.range(
        start=0,
        limit=batch_size * quadrature_size,
        delta=quadrature_size,
        dtype=ids.dtype)
    ids += offset

    rate = array_ops.gather(
        array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
    rate = array_ops.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
    """Draws `n` samples by blending the two endpoint affine transforms.

    The blend weight is gathered from the flattened
    `self.interpolate_weight` using ids sampled from the mixture
    distribution.

    Raises:
      NotImplementedError: if there are not exactly two endpoint affines.
    """
    base = self.distribution.sample(
        sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
        seed=seed)  # shape: [n, B, e]
    x = [aff.forward(base) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = reduce_prod(self.batch_shape_tensor())
    ids_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), [batch_size]))
    ids = self._mixture_distribution.sample(
        sample_shape=ids_shape,
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

    # Stride `quadrature_degree` for `batch_size` number of times.
    offset = math_ops.range(
        start=0,
        limit=batch_size * len(self.quadrature_probs),
        delta=len(self.quadrature_probs),
        dtype=ids.dtype)

    flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
    weight = array_ops.gather(flat_weights, ids + offset)
    # Append a singleton dimension to the gathered weights.
    weight = weight[..., array_ops.newaxis]

    if len(x) != 2:
        # We actually should have already triggered this exception. However
        # as a policy we're putting this exception wherever we exploit the
        # bimixture assumption.
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    return weight * (x[0] - x[1]) + x[1]
def _sample_n(self, n, seed=None):
    """Draws `n` samples by blending the two endpoint affine transforms.

    The blend weight is gathered from the flattened
    `self.interpolate_weight`, indexed by mixture ids offset in strides of
    `self._degree`.

    Raises:
      NotImplementedError: if there are not exactly two endpoint affines.
    """
    base = self.distribution.sample(
        sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
        seed=seed)  # shape: [n, B, e]
    x = [aff.forward(base) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    batch_size = reduce_prod(self.batch_shape_tensor())
    per_draw_dims = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), [batch_size])
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], per_draw_dims),
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

    # Stride `self._degree` for `batch_size` number of times.
    degree = self._degree
    offset = math_ops.range(
        start=0,
        limit=batch_size * degree,
        delta=degree,
        dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.interpolate_weight, shape=[-1]),
        ids + offset)
    # Append a singleton dimension to the gathered weights.
    weight = weight[..., array_ops.newaxis]

    if len(x) != 2:
        # We actually should have already triggered this exception. However
        # as a policy we're putting this exception wherever we exploit the
        # bimixture assumption.
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    blended = weight * (x[0] - x[1]) + x[1]
    return blended
def _sample_n(self, n, seed=None):
    """Draws `n` samples by blending the two endpoint affine transforms.

    `n * batch_size` base draws are taken with a flat leading dimension,
    pushed through each endpoint affine, and reshaped to
    `[-1] + batch_shape + event_shape` before blending.
    """
    batch_size = reduce_prod(self.batch_shape_tensor())
    base = self.distribution.sample(
        sample_shape=concat_vectors(
            [n * batch_size], self.event_shape_tensor()),
        seed=seed)
    # The target shape is rebuilt for every affine, as in the original.
    x = [
        array_ops.reshape(
            aff.forward(base),
            shape=concat_vectors(
                [-1],
                self.batch_shape_tensor(),
                self.event_shape_tensor()))
        for aff in self.endpoint_affine
    ]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    ids_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), [batch_size]))
    ids = self._mixture_distribution.sample(
        sample_shape=ids_shape,
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

    # Stride `self._degree` for `batch_size` number of times.
    offset = math_ops.range(
        start=0,
        limit=batch_size * self._degree,
        delta=self._degree,
        dtype=ids.dtype)

    flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
    weight = array_ops.gather(flat_weights, ids + offset)
    # Append a singleton dimension to the gathered weights.
    weight = weight[..., array_ops.newaxis]

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    return weight * (x[0] - x[1]) + x[1]
def _sample_n(self, n, seed=None):
    """Draws `n` samples by blending the two endpoint affine transforms.

    Base draws use a flat `[n * batch_size]` leading dimension and are
    reshaped to `[-1] + batch_shape + event_shape` after each affine. The
    second endpoint is multiplied by `ones_like(x[0])` in the final blend.
    """
    batch_size = reduce_prod(self.batch_shape_tensor())
    flat_sample_shape = concat_vectors(
        [n * batch_size], self.event_shape_tensor())
    base = self.distribution.sample(
        sample_shape=flat_sample_shape, seed=seed)
    # The target shape is rebuilt for every affine, as in the original.
    x = [
        array_ops.reshape(
            aff.forward(base),
            shape=concat_vectors(
                [-1],
                self.batch_shape_tensor(),
                self.event_shape_tensor()))
        for aff in self.endpoint_affine
    ]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
    # get ids as a [n]-shaped vector.
    per_draw_dims = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), [batch_size])
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], per_draw_dims),
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

    # Stride `self._degree` for `batch_size` number of times.
    degree = self._degree
    offset = math_ops.range(
        start=0,
        limit=batch_size * degree,
        delta=degree,
        dtype=ids.dtype)

    weight = array_ops.gather(
        array_ops.reshape(self.interpolate_weight, shape=[-1]),
        ids + offset)
    # Append a singleton dimension to the gathered weights.
    weight = weight[..., array_ops.newaxis]

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    blended = weight * (x[0] - x[1]) + array_ops.ones_like(x[0]) * x[1]
    return blended
def _batch_shape_tensor(self):
    """Returns the override batch shape if set, else the base distribution's."""
    base_batch_shape = self.distribution.batch_shape_tensor()
    return distribution_util.pick_vector(
        self._is_batch_override,
        self._override_batch_shape,
        base_batch_shape)
def __init__(self,
             df,
             shift=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
    """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `shift` and
    `scale_*` arguments.

    The `event_shape` is the event shape of `Affine.event_shape`.

    Args:
      df: Numeric `Tensor`. The degrees of freedom of the distribution(s).
        `df` must contain only positive values. Must be scalar if `shift`,
        `scale_*` imply non-scalar batch_shape or must have the same
        `batch_shape` implied by `shift`, `scale_*`.
      shift: Numeric `Tensor`. If this is set to `None`, no `shift` is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing
        a scaling done to the identity matrix. When
        `scale_identity_multiplier = scale_diag=scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
        added to `scale`.
      scale_diag: Numeric `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ... k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Numeric `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ... k, k], which represents a k x k
        lower triangular matrix. When `None` no `scale_tril` term is added
        to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Numeric `Tensor` representing factor matrix with
        last two dimensions of shape `(k, r)`. When `None`, no rank-r update
        is added to `scale`.
      scale_perturb_diag: Numeric `Tensor` representing the diagonal matrix.
        `scale_perturb_diag` has shape [N1, N2, ... r], which represents an
        r x r Diagonal matrix. When `None` low rank updates will take the
        form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: `Boolean`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for
        any batch member. If `True`, batch members with valid parameters
        leading to undefined statistics will return NaN for this statistic.
        This constructor only records the value in `parameters`; it is not
        forwarded to the base class here.
      name: The name to give Ops created by the initializer.
    """
    # Must stay the first statement so only the constructor arguments are
    # captured.
    parameters = locals()
    parameters.pop("self")
    graph_parents = [
        df, shift, scale_identity_multiplier, scale_diag, scale_tril,
        scale_perturb_factor, scale_perturb_diag
    ]
    with ops.name_scope(name) as ns:
        with ops.name_scope("init", values=graph_parents):
            # The shape of the _VectorStudentT distribution is governed by
            # the relationship between df.batch_shape and
            # affine.batch_shape. In pseudocode the basic procedure is:
            #   if df.batch_shape is scalar:
            #     if affine.batch_shape is not scalar:
            #       # broadcast self._distribution.sample so
            #       # it has affine.batch_shape.
            #     self.batch_shape = affine.batch_shape
            #   else:
            #     if affine.batch_shape is scalar:
            #       # let affine broadcasting do its thing.
            #     self.batch_shape = df.batch_shape
            # All of the above magic is actually handled by
            # TransformedDistribution. Here we really only need to collect
            # the affine.batch_shape and decide what we're going to pass in
            # to TransformedDistribution's (override) batch_shape arg.
            self._distribution = student_t.StudentT(df=df, mu=0., sigma=1.)
            self._affine = bijectors.Affine(
                shift=shift,
                scale_identity_multiplier=scale_identity_multiplier,
                scale_diag=scale_diag,
                scale_tril=scale_tril,
                scale_perturb_factor=scale_perturb_factor,
                scale_perturb_diag=scale_perturb_diag,
                validate_args=validate_args)
            self._batch_shape, self._override_event_shape = _infer_shapes(
                self.scale, self.shift)
            # Override the batch shape only when `df` has a scalar batch.
            self._override_batch_shape = distribution_util.pick_vector(
                self._distribution.is_scalar_batch(),
                self._batch_shape,
                constant_op.constant([], dtype=dtypes.int32))
            super(_VectorStudentT, self).__init__(
                distribution=self._distribution,
                bijector=self._affine,
                batch_shape=self._override_batch_shape,
                event_shape=self._override_event_shape,
                validate_args=validate_args,
                name=ns)
            # Keep the captured constructor arguments; the original computed
            # them and then dropped them.
            self._parameters = parameters
def _event_shape_tensor(self):
    """Returns the bijector's forward event shape of the base event shape.

    The base event shape is the override event shape when
    `self._is_event_override` holds, otherwise the wrapped distribution's.
    """
    base_event_shape = distribution_util.pick_vector(
        self._is_event_override,
        self._override_event_shape,
        self.distribution.event_shape_tensor())
    return self.bijector.forward_event_shape_tensor(base_event_shape)
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
    """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `loc` and
    `scale_*` arguments.

    The `event_shape` is the event shape of `Affine.event_shape`.

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values. Must be
        scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have
        the same `batch_shape` implied by `loc`, `scale_*`.
      loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing
        a scaling done to the identity matrix. When
        `scale_identity_multiplier = scale_diag=scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is
        added to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents
        a k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal
        are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor
        matrix with last two dimensions of shape `(k, r)`. When `None`, no
        rank-r update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
        represents an r x r Diagonal matrix. When `None` low rank updates
        will take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must stay the first statement so only the constructor arguments are
    # captured.
    parameters = locals()
    graph_parents = [
        df,
        loc,
        scale_identity_multiplier,
        scale_diag,
        scale_tril,
        scale_perturb_factor,
        scale_perturb_diag,
    ]
    with ops.name_scope(name):
        with ops.name_scope("init", values=graph_parents):
            # The shape of the _VectorStudentT distribution is governed by
            # the relationship between df.batch_shape and
            # affine.batch_shape. In pseudocode the basic procedure is:
            #   if df.batch_shape is scalar:
            #     if affine.batch_shape is not scalar:
            #       # broadcast distribution.sample so
            #       # it has affine.batch_shape.
            #     self.batch_shape = affine.batch_shape
            #   else:
            #     if affine.batch_shape is scalar:
            #       # let affine broadcasting do its thing.
            #     self.batch_shape = df.batch_shape
            # All of the above magic is actually handled by
            # TransformedDistribution. Here we really only need to collect
            # the affine.batch_shape and decide what we're going to pass in
            # to TransformedDistribution's (override) batch_shape arg.
            affine = bijectors.Affine(
                shift=loc,
                scale_identity_multiplier=scale_identity_multiplier,
                scale_diag=scale_diag,
                scale_tril=scale_tril,
                scale_perturb_factor=scale_perturb_factor,
                scale_perturb_diag=scale_perturb_diag,
                validate_args=validate_args)
            # Base distribution: loc 0 and scale 1 in the affine's dtype.
            distribution = student_t.StudentT(
                df=df,
                loc=array_ops.zeros([], dtype=affine.dtype),
                scale=array_ops.ones([], dtype=affine.dtype))
            batch_shape, override_event_shape = _infer_shapes(
                affine.scale, affine.shift)
            # Override the batch shape only when `df` has a scalar batch.
            override_batch_shape = distribution_util.pick_vector(
                distribution.is_scalar_batch(),
                batch_shape,
                constant_op.constant([], dtype=dtypes.int32))
            super(_VectorStudentT, self).__init__(
                distribution=distribution,
                bijector=affine,
                batch_shape=override_batch_shape,
                event_shape=override_event_shape,
                validate_args=validate_args,
                name=name)
            self._parameters = parameters
def _batch_shape_tensor(self):
  """Returns the batch shape of this distribution as a shape vector.

  Hands the batch-override flag, the override batch shape and the wrapped
  distribution's own batch shape to `distribution_util.pick_vector` and
  returns whatever it yields.
  """
  # presumably `pick_vector` yields the override shape when the flag is set
  # and the wrapped distribution's shape otherwise -- confirm in
  # `distribution_util`.
  use_override = self._is_batch_override
  override_shape = self._override_batch_shape
  wrapped_shape = self.distribution.batch_shape_tensor()
  return distribution_util.pick_vector(
      use_override, override_shape, wrapped_shape)
def _event_shape_tensor(self):
  """Returns the event shape of this distribution as a shape vector.

  A pre-transform event shape is obtained from
  `distribution_util.pick_vector` (given the event-override flag, the
  override event shape and the wrapped distribution's event shape) and is
  then mapped through the bijector's `forward_event_shape_tensor`.
  """
  # presumably `pick_vector` yields the override shape when the flag is set
  # and the wrapped distribution's shape otherwise -- confirm in
  # `distribution_util`.
  use_override = self._is_event_override
  override_shape = self._override_event_shape
  wrapped_shape = self.distribution.event_shape_tensor()
  pre_transform_shape = distribution_util.pick_vector(
      use_override, override_shape, wrapped_shape)
  return self.bijector.forward_event_shape_tensor(pre_transform_shape)
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating between the two endpoint affines.

  A base sample is pushed through every bijector in `self.endpoint_affine`,
  a grid index is sampled from `self.mixture_distribution` for each draw, and
  the matching entry of `self.grid` is used as the interpolation weight
  between the two transformed samples.

  Args:
    n: Number of samples to draw; used as the leading sample dimension.
    seed: Optional seed passed to the base distribution. A seed derived from
      it via `distribution_util.gen_new_seed` is used for the mixture ids.

  Returns:
    `weight * (x[0] - x[1]) + x[1]`, where `x[i]` is the base sample
    transformed by `self.endpoint_affine[i]`.

  Raises:
    NotImplementedError: if `self.endpoint_affine` does not hold exactly two
      bijectors.
  """
  x = self.distribution.sample(
      sample_shape=concat_vectors(
          [n],
          self.batch_shape_tensor(),
          self.event_shape_tensor()),
      seed=seed)   # shape: [n, B, e]
  x = [aff.forward(x) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    # Static size unknown: compute it in-graph. `reduce_prod` is provided by
    # `math_ops` (as used for `mix_batch_size` below), not by `array_ops`.
    batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
  if mix_batch_size is None:
    mix_batch_size = math_ops.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              [batch_size // mix_batch_size])),
      seed=distribution_util.gen_new_seed(
          seed, "vector_diffeomixture"))
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = array_ops.reshape(ids, shape=concat_vectors(
      [n],
      distribution_util.pick_vector(
          self.is_scalar_batch(),
          np.int32([]),
          np.int32([-1]))))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = self.grid.shape.with_rank_at_least(
      2)[-2:].num_elements()
  if stride is None:
    # Same dynamic fallback as above; must also come from `math_ops`.
    stride = math_ops.reduce_prod(
        array_ops.shape(self.grid)[-2:])
  offset = math_ops.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

  # Index the flattened grid: each batch member's block starts at its
  # `offset`, and `ids` selects the entry inside that block.
  weight = array_ops.gather(
      array_ops.reshape(self.grid, shape=[-1]),
      ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if self.batch_shape.is_fully_defined():
    new_shape = [-1] + self.batch_shape.as_list() + [1]
  else:
    new_shape = array_ops.concat(
        ([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = array_ops.reshape(weight, shape=new_shape)

  if len(x) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError("Currently only bimixtures are supported; "
                              "len(scale)={} is not 2.".format(len(x)))

  # Alternatively:
  # x = weight * x[0] + (1. - weight) * x[1]
  x = weight * (x[0] - x[1]) + x[1]

  return x
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is the event shape of `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
      Must be scalar if `loc`, `scale_*` imply non-scalar batch_shape or must
      have the same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
      k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal are
      ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor matrix
      with last two dimensions of shape `(k, r)`. When `None`, no rank-r
      update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r Diagonal matrix. When `None` low rank updates will
      take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = locals()
  graph_parents = [
      df, loc, scale_identity_multiplier, scale_diag, scale_tril,
      scale_perturb_factor, scale_perturb_diag
  ]
  with ops.name_scope(name) as ns:
    with ops.name_scope("init", values=graph_parents):
      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so
      #       # it has affine.batch_shape.
      #     self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #     self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by TransformedDistribution.
      # Here we really only need to collect the affine.batch_shape and decide
      # what we're going to pass in to TransformedDistribution's
      # (override) batch_shape arg.
      affine = bijectors.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args)
      # The base StudentT is the only object that sees `df`, and nothing else
      # in this constructor consumes `allow_nan_stats`, so both flags are
      # forwarded here to make the documented arguments take effect.
      # NOTE(review): assumes the parent class derives its `allow_nan_stats`
      # from the wrapped distribution -- confirm.
      distribution = student_t.StudentT(
          df=df,
          loc=array_ops.zeros([], dtype=affine.dtype),
          scale=array_ops.ones([], dtype=affine.dtype),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)
      batch_shape, override_event_shape = _infer_shapes(
          affine.scale, affine.shift)
      # The affine batch shape is offered as an override only when the base
      # distribution's batch is scalar; otherwise an empty int32 vector is
      # passed instead.
      override_batch_shape = distribution_util.pick_vector(
          distribution.is_scalar_batch(),
          batch_shape,
          constant_op.constant([], dtype=dtypes.int32))
      super(_VectorStudentT, self).__init__(
          distribution=distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=ns)
      self._parameters = parameters