def test_with_rank_ndarray(self):
  """Checks `with_rank` and `with_rank_at_least` on `np.ndarray` shapes."""
  # A rank-1 shape must be rejected when rank 2 is requested.
  too_short = np.array([2], dtype=np.int32)
  with self.assertRaises(ValueError):
    tensorshape_util.with_rank(too_short, 2)

  # Exact rank match: the result must compare equal to the input.
  exact = np.array([2, 3, 4], dtype=np.int32)
  exact_result = tensorshape_util.with_rank(exact, 3)
  self.assertAllEqual(exact, exact_result)

  # More dims than the required minimum is accepted.
  longer = np.array([2, 3, 4, 5], dtype=np.int32)
  longer_result = tensorshape_util.with_rank_at_least(longer, 3)
  self.assertAllEqual(longer, longer_result)
def test_with_rank_list_tuple(self):
  """Checks `with_rank` and `with_rank_at_least` on list and tuple shapes."""
  # A rank-1 shape, given as a list and then as a tuple, must be rejected
  # when rank 2 is requested.
  for container in (list, tuple):
    with self.assertRaises(ValueError):
      tensorshape_util.with_rank(container((2,)), 2)

  # Exact rank match, tuple first then list: the result must compare equal
  # to a separately built copy of the input.
  for container in (tuple, list):
    self.assertAllEqual(
        container((2, 1)),
        tensorshape_util.with_rank(container((2, 1)), 2))

  # Rank above the required minimum, tuple first then list.
  for container in (tuple, list):
    self.assertAllEqual(
        container((2, 3, 4)),
        tensorshape_util.with_rank_at_least(container((2, 3, 4)), 2))
def _event_shape(self):
  """Returns the event shape: the last dimension of the parameter tensor.

  The parameter is `self._probs` when `self._logits` is `None`, otherwise
  `self._logits`. The one-dimensional slice is passed through
  `tensorshape_util.with_rank(..., rank=1)`.
  """
  # We will never broadcast the num_categories with total_count.
  if self._logits is None:
    param = self._probs
  else:
    param = self._logits
  trailing = param.shape[-1:]
  return tensorshape_util.with_rank(trailing, rank=1)
def _event_shape(self):
  """Returns the event shape: the last dimension of `concentration`.

  The slice is passed through `tensorshape_util.with_rank(..., rank=1)`.
  """
  last_dim = self.concentration.shape[-1:]
  return tensorshape_util.with_rank(last_dim, rank=1)
def _event_shape(self):
  """Returns the event shape: the last dimension of the parameter tensor.

  Uses `self._logits` when it is not `None`, falling back to `self._probs`;
  the slice is passed through `tensorshape_util.with_rank(..., rank=1)`.
  """
  if self._logits is not None:
    return tensorshape_util.with_rank(self._logits.shape[-1:], rank=1)
  return tensorshape_util.with_rank(self._probs.shape[-1:], rank=1)
def _event_shape(self):
  """Returns the event shape: the last dimension of `mean_direction`.

  The slice is passed through `tensorshape_util.with_rank(..., rank=1)`.
  """
  trailing_dims = self.mean_direction.shape[-1:]
  return tensorshape_util.with_rank(trailing_dims, rank=1)
def _event_shape(self):
  """Returns the event shape: the last dimension of `concentration`.

  The slice is passed through `tensorshape_util.with_rank(..., rank=1)`.
  """
  # Event shape depends only on concentration, not total_count.
  category_dim = self.concentration.shape[-1:]
  return tensorshape_util.with_rank(category_dim, rank=1)
def _sample_n(self, n, seed=None):
  """Draws `n` samples by mixing two affine transforms of one base draw.

  A single draw from `self.distribution` is pushed through each element of
  `self.endpoint_affine`, giving `x[0]` and `x[1]`. A `weight` is gathered
  from the flattened `self.grid` at ids sampled from
  `self.mixture_distribution` (offset per batch member). The result is
  `weight * x[0] + (1. - weight) * x[1]`.

  Args:
    n: Number of samples to draw.
    seed: Optional seed. It feeds a `SeedStream` with salt
      "VectorDiffeomixture", which is consumed first by the base sample and
      then by the mixture-id sample.

  Returns:
    The mixed samples. The base draw is shaped `[n, B, e]` (per the inline
    comment); presumably the result has the same shape — TODO confirm.

  Raises:
    NotImplementedError: if `self.endpoint_affine` does not have exactly two
      elements.
  """
  # Validate the bimixture assumption before doing any sampling work.
  # We actually should have already triggered this exception. However as a
  # policy we're putting this exception wherever we exploit the bimixture
  # assumption.
  endpoint_affine = list(self.endpoint_affine)
  if len(endpoint_affine) != 2:
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(endpoint_affine)))

  stream = SeedStream(seed, salt="VectorDiffeomixture")
  x = self.distribution.sample(
      sample_shape=concat_vectors(
          [n], self.batch_shape_tensor(), self.event_shape_tensor()),
      seed=stream())  # shape: [n, B, e]
  x = [aff.forward(x) for aff in endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = tensorshape_util.num_elements(
      self.mixture_distribution.batch_shape)
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              [batch_size // mix_batch_size])),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              np.int32([-1]))))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = tensorshape_util.num_elements(
      tensorshape_util.with_rank(self.grid.shape[-2:], rank=2))
  if stride is None:
    stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
  offset = tf.range(
      start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)

  weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if tensorshape_util.is_fully_defined(self.batch_shape):
    new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
  else:
    new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=new_shape)

  # Alternatively:
  #   x = weight * x[0] + (1. - weight) * x[1]
  return weight * (x[0] - x[1]) + x[1]