def _sample_n(self, n, seed):
    """Draws `n` Wishart samples using a Bartlett-style lower-triangular factor.

    A lower-triangular matrix is built with i.i.d. standard normals below the
    diagonal and square roots of gamma (chi-squared) draws on the diagonal.
    It is then left-multiplied by `self.scale_operator`. If
    `self.cholesky_input_output_matrices` is false, the product
    `A @ adjoint(A)` is returned; otherwise the factor `A` itself is returned.

    NOTE: statements are kept in the original op-creation order on purpose.
    With a graph-level seed and `seed=None`, TF1 derives op seeds from the
    number of ops already in the graph.

    Args:
      n: Number of samples to draw.
      seed: Seed for the normal draws. The gamma draws use
        `distribution_util.gen_new_seed(seed, "wishart")`.

    Returns:
      Tensor of shape `[n] + batch_shape + event_shape`.
    """
    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape_tensor()
    batch_rank = tf.shape(batch_shape)[0]
    total_rank = batch_rank + 3  # one sample dim + two event (matrix) dims

    # O(nbk**2): a standard normal for every matrix entry.
    normal_shape = tf.concat([[n], batch_shape, event_shape], 0)
    factor = tf.random_normal(
        shape=normal_shape,
        mean=0.,
        stddev=1.,
        dtype=self.dtype,
        seed=seed)

    # O(nbk): diagonal draws. ChiSquared(k) == Gamma(alpha=k/2, beta=1/2).
    df_broadcast = self.df * tf.ones(
        self.scale_operator.batch_shape_tensor(),
        dtype=self.df.dtype.base_dtype)
    chi2_draws = tf.random_gamma(
        shape=[n],
        alpha=self._multi_gamma_sequence(0.5 * df_broadcast, self.dimension),
        beta=0.5,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "wishart"))

    # O(nbk**2) to keep the lower triangle, then O(nbk) to set the diagonal.
    factor = tf.matrix_band_part(factor, -1, 0)
    factor = tf.matrix_set_diag(factor, tf.sqrt(chi2_draws))

    # O(nbk**2): rotate the sample axis to the end and fold it into the
    # columns, giving [batch_shape, k, k*n], so one batched matmul serves
    # every sample.
    to_back = tf.concat([tf.range(1, total_rank), [0]], 0)
    factor = tf.transpose(factor, to_back)
    folded_shape = tf.concat([batch_shape, [event_shape[0]], [-1]], 0)
    factor = tf.reshape(factor, folded_shape)

    # O(nbM), where M is the cost of applying `scale_operator` to one
    # k-vector: O(nbk**2) for a diagonal operator, O(nbk**3) for a dense
    # lower-triangular one.
    factor = self.scale_operator.matmul(factor)

    # O(nbk**2): unfold and rotate the sample axis back to the front.
    unfolded_shape = tf.concat([batch_shape, event_shape, [n]], 0)
    factor = tf.reshape(factor, unfolded_shape)
    to_front = tf.concat([[total_rank - 1], tf.range(0, total_rank - 1)], 0)
    factor = tf.transpose(factor, to_front)

    if self.cholesky_input_output_matrices:
        return factor
    # O(nbk**3): expand the factor into the full matrix.
    return tf.matmul(factor, factor, adjoint_b=True)
def _sample_n(self, n, seed):
    """Draws `n` Wishart samples from a scaled lower-triangular factor.

    The factor has standard normals strictly below the diagonal and square
    roots of gamma (chi-squared) draws on the diagonal. It is multiplied on
    the left by `self.scale_operator`. When `self.input_output_cholesky` is
    false, the factor is expanded to `L @ adjoint(L)` before returning.

    NOTE: the original op-creation order is preserved deliberately. TF1
    derives op-level seeds from graph op ids when `seed=None` and a graph
    seed is set.

    Args:
      n: Number of samples to draw.
      seed: Seed for the normal draws. The gamma draws use
        `distribution_util.gen_new_seed(seed, "wishart")`.

    Returns:
      Tensor of shape `[n] + batch_shape + event_shape`.
    """
    batch_dims = self.batch_shape_tensor()
    event_dims = self.event_shape_tensor()
    num_batch_axes = tf.shape(batch_dims)[0]
    num_axes = num_batch_axes + 3  # sample_ndims=1, event_ndims=2

    # O(nbk**2)
    lower = tf.random_normal(
        shape=tf.concat([[n], batch_dims, event_dims], 0),
        mean=0.,
        stddev=1.,
        dtype=self.dtype,
        seed=seed)

    # O(nbk): Gamma(alpha=k/2, beta=1/2) is ChiSquared(k).
    df_full = self.df * tf.ones(
        self.scale_operator.batch_shape_tensor(),
        dtype=self.df.dtype.base_dtype)
    diag_draws = tf.random_gamma(
        shape=[n],
        alpha=self._multi_gamma_sequence(0.5 * df_full, self.dimension),
        beta=0.5,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "wishart"))

    lower = tf.matrix_band_part(lower, -1, 0)  # O(nbk**2): lower triangle
    lower = tf.matrix_set_diag(lower, tf.sqrt(diag_draws))  # O(nbk)

    # O(nbk**2): move samples last and merge them into the column axis.
    lower = tf.transpose(lower, tf.concat([tf.range(1, num_axes), [0]], 0))
    lower = tf.reshape(
        lower, tf.concat([batch_dims, [event_dims[0]], [-1]], 0))

    # O(nbM), with M the cost of one operator-times-vector product. For a
    # LinearOperatorLowerTriangular this step is O(nbk**3).
    lower = self.scale_operator.matmul(lower)

    # O(nbk**2): split the columns again and move samples back to the front.
    lower = tf.reshape(lower, tf.concat([batch_dims, event_dims, [n]], 0))
    lower = tf.transpose(
        lower, tf.concat([[num_axes - 1], tf.range(0, num_axes - 1)], 0))

    if not self.input_output_cholesky:
        lower = tf.matmul(lower, lower, adjoint_b=True)  # O(nbk**3)
    return lower
def _sample_n(self, n, seed=None):
    """Draws `n` Beta samples as a ratio of two independent gamma draws.

    With G1 ~ Gamma(concentration1) and G0 ~ Gamma(concentration0), the
    ratio G1 / (G1 + G0) is Beta(concentration1, concentration0).

    Args:
      n: Number of samples to draw.
      seed: Seed for the first gamma draw. The second uses
        `util.gen_new_seed(seed, "beta")`.

    Returns:
      Tensor of shape `[n]` followed by the shape of
      `ones_like(total_concentration) * concentration1`.
    """
    # Broadcast each concentration against `total_concentration` so both
    # gamma draws share the full batch shape.
    alpha_one = tf.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    alpha_zero = tf.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0

    draw_one = tf.random_gamma(
        shape=[n], alpha=alpha_one, dtype=self.dtype, seed=seed)
    draw_zero = tf.random_gamma(
        shape=[n],
        alpha=alpha_zero,
        dtype=self.dtype,
        seed=util.gen_new_seed(seed, "beta"))
    return draw_one / (draw_one + draw_zero)
def _sample_n(self, n, seed=None):
    """Draws `n` Dirichlet-multinomial count vectors.

    Each class probability vector comes from gamma draws. Taking their log
    and using it as multinomial logits is equivalent to normalizing them,
    which is a Dirichlet sample. `total_count` class indices are then drawn
    per row and turned into per-class counts.

    NOTE(review): `tf.multinomial` takes a scalar `num_samples`, so this
    appears to assume a scalar `total_count` -- confirm.

    Args:
      n: Number of samples to draw.
      seed: Seed for the gamma draws. The multinomial draw uses
        `distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial")`.

    Returns:
      Tensor of dtype `self.dtype` with shape `[n] + batch_shape + [k]`.
    """
    num_trials = tf.cast(self.total_count, dtype=tf.int32)
    num_classes = self.event_shape_tensor()[0]

    gamma_draws = tf.random_gamma(
        shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
    # One row of logits per (sample, batch member).
    logits = tf.reshape(tf.log(gamma_draws), shape=[-1, num_classes])

    class_ids = tf.multinomial(
        logits=logits,
        num_samples=num_trials,
        seed=distribution_util.gen_new_seed(
            seed, salt="dirichlet_multinomial"))
    # Sum one-hot rows over the trial axis to get per-class counts.
    counts = tf.reduce_sum(tf.one_hot(class_ids, depth=num_classes), -2)

    out_shape = tf.concat([[n], self.batch_shape_tensor(), [num_classes]], 0)
    return tf.cast(tf.reshape(counts, out_shape), self.dtype)
def _sample_n(self, n, seed=None):
    """Draws `n` Student's t samples as loc + scale * Z / sqrt(V / df).

    Z ~ Normal(0, 1) and V ~ Chi2(df), drawn as Gamma(df/2, beta=1/2); the
    ratio Z / sqrt(V / df) is StudentT(df).

    Args:
      n: Number of samples to draw.
      seed: Seed for the normal draw. The gamma draw uses
        `distribution_util.gen_new_seed(seed, salt="student_t")`.

    Returns:
      Tensor of shape `[n] + batch_shape`, shifted by `self.loc` and scaled
      by `self.scale`.
    """
    normal_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
    z = tf.random_normal(normal_shape, dtype=self.dtype, seed=seed)

    # Broadcast df to the full batch shape so the gamma draw matches `z`.
    df_full = self.df * tf.ones(self.batch_shape_tensor(), dtype=self.dtype)
    chi2 = tf.random_gamma(
        [n],
        0.5 * df_full,
        beta=0.5,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, salt="student_t"))

    standard_t = z * tf.rsqrt(chi2 / df_full)
    # The raw (possibly negative) scale is used deliberately, not abs(scale).
    return standard_t * self.scale + self.loc
def _sample_n(self, n, seed=None):
    """Draws `n` samples from a Poisson compound over quadrature points.

    For each sample and batch member, a quadrature index is drawn from the
    mixture distribution. The matching entry of `self.distribution.rate` is
    gathered, and a Poisson variate is drawn at that rate.

    Args:
      n: Number of samples to draw.
      seed: Seed for the Poisson draw. The index draw uses
        `distribution_util.gen_new_seed(seed,
        "poisson_lognormal_quadrature_compound")`.

    Returns:
      Tensor of dtype `self.dtype` with shape `[n] + batch_shape`.
    """
    # Indices come out as [n, batch_size], or [n] when batch_shape is [].
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
        batch_size = tf.reduce_prod(self.batch_shape_tensor())

    # A scalar-batch mixture shares one probs vector across the batch, so
    # draw `batch_size` extra indices per sample. Otherwise the mixture
    # already has one probs vector per batch member. No other broadcasting is
    # supported.
    extra_dims = distribution_util.pick_vector(
        self.mixture_distribution.is_scalar_batch(),
        [batch_size],
        np.int32([]))
    quad_ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors([n], extra_dims),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))

    # Flatten any batch dims the mixture distribution carries itself.
    flat_id_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), np.int32([-1])))
    quad_ids = tf.reshape(quad_ids, shape=flat_id_shape)

    # Batch member b's rates start at b * quadrature_size in the flat tensor.
    batch_offsets = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=quad_ids.dtype)
    quad_ids += batch_offsets

    flat_rates = tf.reshape(self.distribution.rate, shape=[-1])
    picked_rate = tf.reshape(
        tf.gather(flat_rates, quad_ids),
        shape=concat_vectors([n], self.batch_shape_tensor()))
    return tf.random_poisson(
        lam=picked_rate, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
    """Samples a Poisson variate whose rate is picked by a quadrature index.

    Steps:
      1. Draw a mixture index per (sample, batch member).
      2. Offset it into the flattened `self.distribution.rate`.
      3. Draw Poisson(rate).

    Args:
      n: Number of samples to draw.
      seed: Seed for the Poisson draw. The index draw uses
        `distribution_util.gen_new_seed(seed,
        "poisson_lognormal_quadrature_compound")`.

    Returns:
      Tensor of dtype `self.dtype` with shape `[n] + batch_shape`.
    """
    num_batch = self.batch_shape.num_elements()
    if num_batch is None:
        num_batch = tf.reduce_prod(self.batch_shape_tensor())

    # Only two layouts are supported: one probs vector for the whole batch
    # (sample `num_batch` extra ids) or one probs vector per batch member.
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(),
                [num_batch],
                np.int32([]))),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))

    # Collapse the mixture's own batch dims into a single axis
    # (or none for a scalar batch).
    ids = tf.reshape(
        ids,
        shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(), np.int32([]), np.int32([-1]))))

    # Stride by `quadrature_size` once per batch member.
    ids += tf.range(
        start=0,
        limit=num_batch * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)

    rate_per_draw = tf.gather(
        tf.reshape(self.distribution.rate, shape=[-1]), ids)
    rate_per_draw = tf.reshape(
        rate_per_draw, shape=concat_vectors([n], self.batch_shape_tensor()))
    return tf.random_poisson(
        lam=rate_per_draw, shape=[], dtype=self.dtype, seed=seed)
def testOnlyNoneReturnsNone(self):
    """`gen_new_seed` returns None for a None seed and a value otherwise."""
    salt = 'salt'
    self.assertIsNotNone(distribution_util.gen_new_seed(0, salt))
    self.assertIsNone(distribution_util.gen_new_seed(None, salt))
def _sample_n(self, n, seed=None):
    """Draws `n` samples by blending two affine-transformed base samples.

    A base sample is pushed through every affine in `self.endpoint_affine`.
    A mixing weight is looked up in `self.grid` at an index drawn from
    `self.mixture_distribution`, and the two endpoint samples are blended as
    `weight * first + (1 - weight) * second`.

    Args:
      n: Number of samples to draw.
      seed: Seed for the base distribution. The index draw uses
        `distribution_util.gen_new_seed(seed, "vector_diffeomixture")`.

    Returns:
      The blended samples.

    Raises:
      NotImplementedError: if `self.endpoint_affine` does not produce exactly
        two transformed samples.
    """
    base = self.distribution.sample(
        sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
        seed=seed)  # shape: [n, B, e]
    endpoints = [affine.forward(base) for affine in self.endpoint_affine]

    # Ids come out as [n, batch_size], or [n] when batch_shape is [].
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
        batch_size = tf.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
        mix_batch_size = tf.reduce_prod(
            self.mixture_distribution.batch_shape_tensor())

    id_sample_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(),
            np.int32([]),
            [batch_size // mix_batch_size]))
    grid_ids = self.mixture_distribution.sample(
        sample_shape=id_sample_shape,
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

    # Fold any batch dims the mixture distribution carries itself.
    flat_id_shape = concat_vectors(
        [n],
        distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), np.int32([-1])))
    grid_ids = tf.reshape(grid_ids, shape=flat_id_shape)

    # Each batch member owns a contiguous `components * quadrature_size`
    # slice of the flattened grid.
    stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
    if stride is None:
        stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
    batch_offsets = tf.range(
        start=0,
        limit=batch_size * stride,
        delta=stride,
        dtype=grid_ids.dtype)
    flat_grid = tf.reshape(self.grid, shape=[-1])
    weight = tf.gather(flat_grid, grid_ids + batch_offsets)

    # `weight` has every batch dim merged into one. Restore them and append
    # a trailing singleton so it broadcasts against the event dim.
    weight_shape = (
        [-1] + self.batch_shape.as_list() + [1]
        if self.batch_shape.is_fully_defined()
        else tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0))
    weight = tf.reshape(weight, shape=weight_shape)

    if len(endpoints) != 2:
        # Should already have been rejected earlier. The check is repeated
        # here because the blend below relies on exactly two endpoints.
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "len(scale)={} is not 2.".format(len(endpoints)))
    first, second = endpoints
    # Algebraically equal to weight * first + (1. - weight) * second.
    return weight * (first - second) + second
def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture of `self.dist` components.

    Steps:
      1. Draw a component index for every (sample, batch) entry from
         `self.pi`.
      2. Draw, for each component `c`, exactly as many samples from
         `self.dist[c]` as entries chose `c`.
      3. Gather the right batch rows from those draws and stitch everything
         back into the original (sample, batch) order.

    Args:
      n: Number of samples to draw (Python int or scalar integer tensor).
      seed: Seed for the index draw. Each component's seed is chained through
        `distribution_util.gen_new_seed(seed, "ZeroInflated")`.

    Returns:
      Tensor of shape `[n] + batch_shape + event_shape`.
    """
    with ops.control_dependencies(self._assertions):
        n = ops.convert_to_tensor(n, name="n")
        static_n = tensor_util.constant_value(n)
        n = int(static_n) if static_n is not None else n
        pi_samples = self.pi.sample(n, seed=seed)

        # Use static shapes/sizes where known, dynamic tensors otherwise.
        static_samples_shape = pi_samples.get_shape()
        if static_samples_shape.is_fully_defined():
            samples_shape = static_samples_shape.as_list()
            samples_size = static_samples_shape.num_elements()
        else:
            samples_shape = array_ops.shape(pi_samples)
            samples_size = array_ops.size(pi_samples)
        static_batch_shape = self.batch_shape
        if static_batch_shape.is_fully_defined():
            batch_size = static_batch_shape.num_elements()
        else:
            batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
        static_event_shape = self.event_shape
        if static_event_shape.is_fully_defined():
            event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
        else:
            event_shape = self.event_shape_tensor()

        # Get indices into the raw pi sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = array_ops.reshape(
            math_ops.range(0, samples_size), samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = data_flow_ops.dynamic_partition(
            data=samples_raw_indices,
            partitions=pi_samples,
            num_partitions=self.num_dist)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = array_ops.reshape(
            array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are e.g. [0, 1, 0, 1, 0, 1]
        # Suppose partitions are:
        #   [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #   [batch_indices[x] for x in 2, 3]
        #   [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #   [0 1] and [0 1 0 1]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 1, 0, 1
        # (samples 0, 1, 2, 3).
        partitioned_batch_indices = data_flow_ops.dynamic_partition(
            data=batch_raw_indices,
            partitions=pi_samples,
            num_partitions=self.num_dist)
        samples_class = [None for _ in range(self.num_dist)]

        for c in range(self.num_dist):
            n_class = array_ops.size(partitioned_samples_indices[c])
            seed = distribution_util.gen_new_seed(seed, "ZeroInflated")
            samples_class_c = self.dist[c].sample(n_class, seed=seed)

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * math_ops.range(n_class) +
                partitioned_batch_indices[c])
            samples_class_c = array_ops.reshape(
                samples_class_c,
                array_ops.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = array_ops.gather(
                samples_class_c, lookup_partitioned_batch_indices,
                name="samples_class_c_gather")
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the dist.
        lhs_flat_ret = data_flow_ops.dynamic_stitch(
            indices=partitioned_samples_indices, data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = array_ops.reshape(
            lhs_flat_ret,
            array_ops.concat([samples_shape, self.event_shape_tensor()], 0))
        ret.set_shape(
            tensor_shape.TensorShape(static_samples_shape).concatenate(
                self.event_shape))
        return ret
def _sample_n(self, n, seed=None):
    """Samples the bimixture by interpolating between two endpoint samples.

    Two transformed copies of one base sample (one per entry of
    `self.endpoint_affine`) are combined with a weight taken from
    `self.grid`. The weight is indexed by a draw from
    `self.mixture_distribution` plus a per-batch-member offset.

    Args:
      n: Number of samples to draw.
      seed: Seed for the base distribution. The index draw uses
        `distribution_util.gen_new_seed(seed, "vector_diffeomixture")`.

    Returns:
      The interpolated samples.

    Raises:
      NotImplementedError: if the number of endpoint samples is not 2.
    """
    base_draw = self.distribution.sample(
        sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
        seed=seed)  # shape: [n, B, e]
    transformed = [aff.forward(base_draw) for aff in self.endpoint_affine]

    # Indices come out as [n, batch_size], or [n] when batch_shape is [].
    total_batch = self.batch_shape.num_elements()
    if total_batch is None:
        total_batch = tf.reduce_prod(self.batch_shape_tensor())
    mixture_batch = self.mixture_distribution.batch_shape.num_elements()
    if mixture_batch is None:
        mixture_batch = tf.reduce_prod(
            self.mixture_distribution.batch_shape_tensor())

    indices = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [total_batch // mixture_batch])),
        seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
    # Merge any batch dims the mixture distribution has on its own.
    indices = tf.reshape(
        indices,
        shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(), np.int32([]), np.int32([-1]))))

    # Step by `components * quadrature_size` once per batch member.
    grid_stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
    if grid_stride is None:
        grid_stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
    offsets = tf.range(
        start=0,
        limit=total_batch * grid_stride,
        delta=grid_stride,
        dtype=indices.dtype)
    mix_weight = tf.gather(tf.reshape(self.grid, shape=[-1]), indices + offsets)

    # Restore the batch dims merged by the gather and add a trailing
    # singleton so the weight broadcasts against the event dim.
    if not self.batch_shape.is_fully_defined():
        target_shape = tf.concat(
            ([-1], self.batch_shape_tensor(), [1]), axis=0)
    else:
        target_shape = [-1] + self.batch_shape.as_list() + [1]
    mix_weight = tf.reshape(mix_weight, shape=target_shape)

    if len(transformed) != 2:
        # Expected to be caught earlier. The check is repeated wherever the
        # bimixture assumption is used.
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "len(scale)={} is not 2.".format(
                                      len(transformed)))
    # Equivalent to mix_weight * t0 + (1. - mix_weight) * t1.
    return mix_weight * (transformed[0] - transformed[1]) + transformed[1]