def _sample_n(self, n, seed):
  """Draws `n` samples by scaling a random lower-triangular factor.

  A lower-triangular matrix of Normal(0, 1) draws has its diagonal replaced by
  square roots of Gamma draws and is then multiplied by `self.scale_operator`.
  Unless `self.cholesky_input_output_matrices` is truthy, the factor is
  multiplied by its own adjoint before being returned.

  Args:
    n: Number of samples; becomes the leading dimension of the result.
    seed: Seed for the normal draw. The Gamma draw uses a seed derived from it
      with the salt "wishart".

  Returns:
    A tensor shaped `[n]` + batch shape + event shape.
  """
  batch_dims = self.batch_shape_tensor()
  event_dims = self.event_shape_tensor()
  num_batch_axes = array_ops.shape(batch_dims)[0]

  # One sample axis and two event axes accompany the batch axes.
  total_rank = num_batch_axes + 3
  normal_dims = array_ops.concat([[n], batch_dims, event_dims], 0)

  # Cost: O(nbk**2).
  factor = random_ops.random_normal(
      shape=normal_dims, mean=0., stddev=1., dtype=self.dtype, seed=seed)

  # Cost: O(nbk).
  # The Gamma sampler stands in for Chi2 because
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2).
  broadcast_df = self.df * array_ops.ones(
      self.scale_operator.batch_shape_tensor(),
      dtype=self.df.dtype.base_dtype)
  gamma_alpha = self._multi_gamma_sequence(0.5 * broadcast_df, self.dimension)
  gamma_seed = distribution_util.gen_new_seed(seed, "wishart")
  diag_squared = random_ops.random_gamma(
      shape=[n],
      alpha=gamma_alpha,
      beta=0.5,
      dtype=self.dtype,
      seed=gamma_seed)

  # Cost: O(nbk**2). Keep only the lower triangle.
  factor = array_ops.matrix_band_part(factor, -1, 0)

  # Cost: O(nbk).
  factor = array_ops.matrix_set_diag(factor, math_ops.sqrt(diag_squared))

  # Rotate the sample axis to the end and fold it into the trailing axis so a
  # single operator matmul covers every sample. Cost: O(nbk**2).
  to_back = array_ops.concat([math_ops.range(1, total_rank), [0]], 0)
  factor = array_ops.transpose(factor, to_back)
  folded_dims = array_ops.concat([batch_dims, [event_dims[0]], [-1]], 0)
  factor = array_ops.reshape(factor, folded_dims)

  # Cost: O(nbM), with M the cost of the operator solving a vector system.
  # For LinearOperatorDiag each matmul is O(k**2), giving O(nbk**2); for
  # LinearOperatorLowerTriangular each matmul is O(k^3), giving O(nbk^3).
  factor = self.scale_operator.matmul(factor)

  # Unfold the sample axis and rotate it back to the front. Cost: O(nbk**2).
  unfolded_dims = array_ops.concat([batch_dims, event_dims, [n]], 0)
  factor = array_ops.reshape(factor, unfolded_dims)
  to_front = array_ops.concat(
      [[total_rank - 1], math_ops.range(0, total_rank - 1)], 0)
  factor = array_ops.transpose(factor, to_front)

  if self.cholesky_input_output_matrices:
    return factor

  # Cost: O(nbk^3).
  return math_ops.matmul(factor, factor, adjoint_b=True)
def _sample_n(self, n, seed=None):
  """Draws `n` Poisson samples with rates picked by sampled quadrature ids.

  Ids are sampled from `self._mixture_distribution`, shifted so that they index
  the flattened `self.distribution.rate`, and the gathered rates parameterize
  the Poisson draw.

  Args:
    n: Number of samples; the leading dimension of the gathered rates.
    seed: Seed for the Poisson draw. The id draw uses a seed derived from it
      with the salt "poisson_lognormal_quadrature_compound".

  Returns:
    The result of `random_ops.random_poisson` for rates reshaped to `[n]` +
    batch shape.
  """
  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  if self.batch_shape.is_fully_defined():
    batch_size = np.prod(self.batch_shape.as_list(), dtype=np.int32)
  else:
    batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

  per_sample_dims = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), [batch_size])
  ids_seed = distribution_util.gen_new_seed(
      seed, "poisson_lognormal_quadrature_compound")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], per_sample_dims), seed=ids_seed)

  # Step by `quadrature_size`, once per batch member.
  offset = math_ops.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  shifted_ids = ids + offset

  flat_rates = array_ops.reshape(self.distribution.rate, shape=[-1])
  chosen_rates = array_ops.gather(flat_rates, shifted_ids)
  rate_dims = concat_vectors([n], self.batch_shape_tensor())
  chosen_rates = array_ops.reshape(chosen_rates, shape=rate_dims)

  return random_ops.random_poisson(
      lam=chosen_rates, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
  """Draws `n` count vectors via Gamma-derived logits and multinomial draws.

  Logs of Gamma draws (alpha = `self.alpha`) are used as unnormalized logits
  for `random_ops.multinomial`; the drawn class indices are one-hot encoded
  and summed into counts.

  Args:
    n: Number of samples; the leading dimension of the result.
    seed: Seed for the Gamma draw. The multinomial draw uses a seed derived
      from it with the salt "dirichlet_multinomial".

  Returns:
    A tensor reshaped to `[n]` + batch shape + `[k]`, where `k` is the first
    entry of the event shape.

  Raises:
    NotImplementedError: If the static rank of `self.n` is known and nonzero.
  """
  num_trials = math_ops.cast(self.n, dtype=dtypes.int32)

  static_rank = self.n.get_shape().ndims
  if static_rank is None:
    # Rank is unknown statically; only check it at run time when asked to.
    if self.validate_args:
      rank_check = check_ops.assert_rank(
          num_trials, 0,
          message="Sample only supported for scalar number of draws.")
      num_trials = control_flow_ops.with_dependencies(
          [rank_check], num_trials)
  elif static_rank != 0:
    raise NotImplementedError(
        "Sample only supported for scalar number of draws.")

  num_classes = self.event_shape()[0]
  gamma_draws = random_ops.random_gamma(
      shape=[n], alpha=self.alpha, dtype=self.dtype, seed=seed)
  logits_2d = array_ops.reshape(
      math_ops.log(gamma_draws), shape=[-1, num_classes])

  class_draws = random_ops.multinomial(
      logits=logits_2d,
      num_samples=num_trials,
      seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
  counts = math_ops.reduce_sum(
      array_ops.one_hot(class_draws, depth=num_classes),
      reduction_indices=-2)

  result_dims = array_ops.concat([[n], self.batch_shape(), [num_classes]], 0)
  return array_ops.reshape(counts, result_dims)
def _sample_n(self, n, seed):
  """Draws `n` samples by scaling a random lower-triangular factor.

  A lower-triangular matrix of Normal(0, 1) draws has its diagonal replaced by
  square roots of Gamma draws and is then passed through
  `self.scale_operator_pd.sqrt_matmul`. Unless
  `self.cholesky_input_output_matrices` is truthy, the factor is multiplied by
  its own adjoint before being returned.

  Args:
    n: Number of samples; becomes the leading dimension of the result.
    seed: Seed for the normal draw. The Gamma draw uses a seed derived from it
      with the salt "wishart".

  Returns:
    A tensor shaped `(n,)` + batch shape + event shape.
  """
  batch_dims = self.batch_shape()
  event_dims = self.event_shape()
  num_batch_axes = array_ops.shape(batch_dims)[0]

  # One sample axis and two event axes accompany the batch axes.
  total_rank = num_batch_axes + 3
  normal_dims = array_ops.concat(((n,), batch_dims, event_dims), 0)

  # Cost: O(nbk^2).
  factor = random_ops.random_normal(
      shape=normal_dims, mean=0., stddev=1., dtype=self.dtype, seed=seed)

  # Cost: O(nbk).
  # The Gamma sampler stands in for Chi2 because
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2).
  gamma_alpha = self._multi_gamma_sequence(0.5 * self.df, self.dimension)
  gamma_seed = distribution_util.gen_new_seed(seed, "wishart")
  diag_squared = random_ops.random_gamma(
      shape=(n,),
      alpha=gamma_alpha,
      beta=0.5,
      dtype=self.dtype,
      seed=gamma_seed)

  # Cost: O(nbk^2). Keep only the lower triangle.
  factor = array_ops.matrix_band_part(factor, -1, 0)

  # Cost: O(nbk).
  factor = array_ops.matrix_set_diag(factor, math_ops.sqrt(diag_squared))

  # Rotate the sample axis to the end and fold it into the trailing axis so a
  # single operator call covers every sample. Cost: O(nbk^2).
  to_back = array_ops.concat((math_ops.range(1, total_rank), (0,)), 0)
  factor = array_ops.transpose(factor, to_back)
  folded_dims = array_ops.concat((batch_dims, (event_dims[0], -1)), 0)
  factor = array_ops.reshape(factor, folded_dims)

  # Cost: O(nbM), with M the cost of the operator solving a vector system.
  # For OperatorPDDiag each matmul is O(k^2), giving O(nbk^2); for
  # OperatorPDCholesky each matmul is O(k^3), giving O(nbk^3).
  factor = self.scale_operator_pd.sqrt_matmul(factor)

  # Unfold the sample axis and rotate it back to the front. Cost: O(nbk^2).
  unfolded_dims = array_ops.concat((batch_dims, event_dims, (n,)), 0)
  factor = array_ops.reshape(factor, unfolded_dims)
  to_front = array_ops.concat(
      ((total_rank - 1,), math_ops.range(0, total_rank - 1)), 0)
  factor = array_ops.transpose(factor, to_front)

  if self.cholesky_input_output_matrices:
    return factor

  # Cost: O(nbk^3).
  return math_ops.matmul(factor, factor, adjoint_b=True)
def _sample_n(self, n, seed=None):
  """Draws `n` samples as the ratio of one Gamma draw to the sum of two.

  Args:
    n: Number of samples requested from each Gamma draw.
    seed: Seed for the first Gamma draw. The second uses a seed derived from
      it with the salt "beta".

  Returns:
    `g1 / (g1 + g2)`, where `g1` and `g2` are Gamma draws with alpha equal to
    `self.a` and `self.b`, each multiplied by ones shaped like `self.a_b_sum`.
  """
  # Multiply by ones shaped like `a_b_sum` so both parameters share its shape.
  alpha_first = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.a
  alpha_second = array_ops.ones_like(self.a_b_sum, dtype=self.dtype) * self.b

  first_gamma = random_ops.random_gamma(
      [n], alpha_first, dtype=self.dtype, seed=seed)
  second_seed = distribution_util.gen_new_seed(seed, "beta")
  second_gamma = random_ops.random_gamma(
      [n], alpha_second, dtype=self.dtype, seed=second_seed)

  return first_gamma / (first_gamma + second_gamma)
def _sample_n(self, n, seed=None):
  """Draws `n` samples as a scaled and shifted normal-over-root-gamma ratio.

  Relies on the fact that if X ~ Normal(0, 1) and Z ~ Chi2(df), then
  X / sqrt(Z / df) ~ StudentT(df).

  Args:
    n: Number of samples; the leading dimension of the normal draw.
    seed: Seed for the normal draw. The Gamma draw uses a seed derived from it
      with the salt "student_t".

  Returns:
    The ratio described above, multiplied by `self.sigma` and shifted by
    `self.mu`.
  """
  # NOTE(review): `concat` is called with the axis first and the values
  # second; confirm this matches the TensorFlow version this file targets.
  draw_dims = array_ops.concat(0, ([n], self.batch_shape()))
  numerator = random_ops.random_normal(
      draw_dims, dtype=self.dtype, seed=seed)

  one_half = constant_op.constant(0.5, self.dtype)
  broadcast_df = self.df * array_ops.ones(
      self.batch_shape(), dtype=self.dtype)
  gamma_alpha = one_half * broadcast_df
  gamma_seed = distribution_util.gen_new_seed(seed, salt="student_t")
  chi2_draw = random_ops.random_gamma(
      [n], gamma_alpha, beta=one_half, dtype=self.dtype, seed=gamma_seed)

  unit_t = numerator / math_ops.sqrt(chi2_draw / broadcast_df)
  return unit_t * self.sigma + self.mu
def _sample_n(self, n, seed=None):
  """Draws `n` samples as Poisson draws whose rates are Gamma draws.

  Uses the fact that if
    lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
  then X ~ Poisson(lam) is Negative Binomially distributed.

  Args:
    n: Number of Gamma rate samples to draw.
    seed: Seed for the Gamma draw. The Poisson draw uses a seed derived from
      it with the salt "negative_binom".

  Returns:
    The result of `random_ops.random_poisson` for the sampled rates.
  """
  gamma_rate = math_ops.exp(-self.logits)
  poisson_rate = random_ops.random_gamma(
      shape=[n],
      alpha=self.total_count,
      beta=gamma_rate,
      dtype=self.dtype,
      seed=seed)

  poisson_seed = distribution_util.gen_new_seed(seed, "negative_binom")
  return random_ops.random_poisson(
      poisson_rate, shape=[], dtype=self.dtype, seed=poisson_seed)
def _sample_n(self, n, seed=None):
  """Draws `n` samples as a scaled and shifted normal-over-root-gamma ratio.

  Relies on the fact that if
    X ~ Normal(0, 1)
    Z ~ Chi2(df)
    Y = X / sqrt(Z / df)
  then
    Y ~ StudentT(df).

  Args:
    n: Number of samples; the leading dimension of the normal draw.
    seed: Seed for the normal draw. The Gamma draw uses a seed derived from it
      with the salt "student_t".

  Returns:
    The ratio described above, multiplied by `self.sigma` and shifted by
    `self.mu`.
  """
  draw_dims = array_ops.concat_v2([[n], self.batch_shape()], 0)
  numerator = random_ops.random_normal(
      draw_dims, dtype=self.dtype, seed=seed)

  broadcast_df = self.df * array_ops.ones(
      self.batch_shape(), dtype=self.dtype)
  gamma_alpha = 0.5 * broadcast_df
  gamma_seed = distribution_util.gen_new_seed(seed, salt="student_t")
  chi2_draw = random_ops.random_gamma(
      [n], gamma_alpha, beta=0.5, dtype=self.dtype, seed=gamma_seed)

  unit_t = numerator / math_ops.sqrt(chi2_draw / broadcast_df)
  return unit_t * self.sigma + self.mu
def _sample_n(self, n, seed=None):
  """Draws `n` count vectors via Gamma-derived logits and multinomial draws.

  Logs of Gamma draws (alpha = `self.concentration`) are used as unnormalized
  logits for `random_ops.multinomial`; the drawn class indices are one-hot
  encoded and summed into counts.

  Args:
    n: Number of samples; the leading dimension of the result.
    seed: Seed for the Gamma draw. The multinomial draw uses a seed derived
      from it with the salt "dirichlet_multinomial".

  Returns:
    A tensor reshaped to `[n]` + batch shape + `[k]`, where `k` is the first
    entry of the event shape.
  """
  num_trials = math_ops.cast(self.total_count, dtype=dtypes.int32)
  num_classes = self.event_shape_tensor()[0]

  gamma_draws = random_ops.random_gamma(
      shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
  logits_2d = array_ops.reshape(
      math_ops.log(gamma_draws), shape=[-1, num_classes])

  class_draws = random_ops.multinomial(
      logits=logits_2d,
      num_samples=num_trials,
      seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
  counts = math_ops.reduce_sum(
      array_ops.one_hot(class_draws, depth=num_classes), -2)

  result_dims = array_ops.concat(
      [[n], self.batch_shape_tensor(), [num_classes]], 0)
  return array_ops.reshape(counts, result_dims)
def _sample_n(self, n, seed=None):
  """Draws `n` samples as the ratio of one Gamma draw to the sum of two.

  Args:
    n: Number of samples requested from each Gamma draw.
    seed: Seed for the first Gamma draw. The second uses a seed derived from
      it with the salt "beta".

  Returns:
    `g1 / (g1 + g2)`, where `g1` and `g2` are Gamma draws with alpha equal to
    `self.concentration1` and `self.concentration0`, each multiplied by ones
    shaped like `self.total_concentration`.
  """
  # Multiply by ones shaped like `total_concentration` so both parameters
  # share its shape.
  alpha_first = array_ops.ones_like(
      self.total_concentration, dtype=self.dtype) * self.concentration1
  alpha_second = array_ops.ones_like(
      self.total_concentration, dtype=self.dtype) * self.concentration0

  first_gamma = random_ops.random_gamma(
      shape=[n], alpha=alpha_first, dtype=self.dtype, seed=seed)
  second_seed = distribution_util.gen_new_seed(seed, "beta")
  second_gamma = random_ops.random_gamma(
      shape=[n], alpha=alpha_second, dtype=self.dtype, seed=second_seed)

  return first_gamma / (first_gamma + second_gamma)
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  A base sample is pushed through every entry of `self.endpoint_affine`. Ids
  sampled from `self._mixture_distribution` select entries of the flattened
  `self.interpolate_weight`, which blend the two results.

  Args:
    n: Number of samples; the leading dimension of the base draw.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the two affine images.

  Raises:
    NotImplementedError: If `self.endpoint_affine` does not yield exactly two
      results.
  """
  base_dims = concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor())
  base_draw = self.distribution.sample(
      sample_shape=base_dims, seed=seed)  # shape: [n, B, e]
  endpoints = []
  for aff in self.endpoint_affine:
    endpoints.append(aff.forward(base_draw))

  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  batch_size = reduce_prod(self.batch_shape_tensor())
  per_sample_dims = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), [batch_size])
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], per_sample_dims),
      seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

  # Step by the number of quadrature probs, once per batch member.
  num_quadrature = len(self.quadrature_probs)
  offset = math_ops.range(
      start=0,
      limit=batch_size * num_quadrature,
      delta=num_quadrature,
      dtype=ids.dtype)

  flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
  weight = array_ops.gather(flat_weights, ids + offset)
  weight = weight[..., array_ops.newaxis]

  if len(endpoints) != 2:
    # This should already have been raised elsewhere; as a policy the check is
    # repeated wherever the bimixture assumption is exploited.
    raise NotImplementedError("Currently only bimixtures are supported; "
                              "len(scale)={} is not 2.".format(len(endpoints)))

  first, second = endpoints
  # Same value as: weight * first + (1. - weight) * second
  return weight * (first - second) + second
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  `n * batch_size` base samples are drawn, pushed through every entry of
  `self.endpoint_affine` and reshaped to `[-1]` + batch shape + event shape.
  Ids sampled from `self._mixture_distribution` select entries of the
  flattened `self.interpolate_weight`, which blend the first two results.

  Args:
    n: Number of samples.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the first two affine
    images.
  """
  batch_size = reduce_prod(self.batch_shape_tensor())
  base_draw = self.distribution.sample(
      sample_shape=concat_vectors([n * batch_size], self.event_shape_tensor()),
      seed=seed)

  endpoints = []
  for aff in self.endpoint_affine:
    pushed = aff.forward(base_draw)
    target_dims = concat_vectors(
        [-1], self.batch_shape_tensor(), self.event_shape_tensor())
    endpoints.append(array_ops.reshape(pushed, shape=target_dims))

  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  per_sample_dims = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), [batch_size])
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], per_sample_dims),
      seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

  # Step by `self._degree`, once per batch member.
  offset = math_ops.range(
      start=0,
      limit=batch_size * self._degree,
      delta=self._degree,
      dtype=ids.dtype)

  flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
  weight = array_ops.gather(flat_weights, ids + offset)
  weight = weight[..., array_ops.newaxis]

  # Same value as: weight * endpoints[0] + (1. - weight) * endpoints[1]
  return weight * (endpoints[0] - endpoints[1]) + endpoints[1]
def _sample_n(self, n, seed=None):
  """Draws `n` Poisson samples with rates picked by sampled quadrature ids.

  Ids are sampled from `self._mixture_distribution`, flattened over batch
  dims, shifted so that they index the flattened `self.distribution.rate`, and
  the gathered rates parameterize the Poisson draw.

  Args:
    n: Number of samples; the leading dimension of the gathered rates.
    seed: Seed for the Poisson draw. The id draw uses a seed derived from it
      with the salt "poisson_lognormal_quadrature_compound".

  Returns:
    The result of `random_ops.random_poisson` for rates reshaped to `[n]` +
    batch shape.
  """
  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    batch_size = math_ops.reduce_prod(self.batch_shape_tensor())

  # When the mixture distribution has a scalar batch it does not carry a probs
  # vector per batch coordinate, so `batch_size` extra draws are requested.
  # Only this reduced broadcasting is supported: one probs vector for all
  # batch dims, or one for each.
  extra_dims = distribution_util.pick_vector(
      self.mixture_distribution.is_scalar_batch(),
      [batch_size],
      np.int32([]))
  ids_seed = distribution_util.gen_new_seed(
      seed, "poisson_lognormal_quadrature_compound")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], extra_dims), seed=ids_seed)

  # Flatten batch dims in case mixture_distribution has its own batch dims.
  flat_tail = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), np.int32([-1]))
  ids = array_ops.reshape(ids, shape=concat_vectors([n], flat_tail))

  # Step by `quadrature_size`, once per batch member.
  offset = math_ops.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  shifted_ids = ids + offset

  flat_rates = array_ops.reshape(self.distribution.rate, shape=[-1])
  chosen_rates = array_ops.gather(flat_rates, shifted_ids)
  rate_dims = concat_vectors([n], self.batch_shape_tensor())
  chosen_rates = array_ops.reshape(chosen_rates, shape=rate_dims)

  return random_ops.random_poisson(
      lam=chosen_rates, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  A base sample is pushed through every entry of `self.endpoint_affine`. Ids
  sampled from `self._mixture_distribution` select entries of the flattened
  `self.interpolate_weight`, which blend the two results.

  Args:
    n: Number of samples; the leading dimension of the base draw.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the two affine images.

  Raises:
    NotImplementedError: If `self.endpoint_affine` does not yield exactly two
      results.
  """
  base_dims = concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor())
  base_draw = self.distribution.sample(
      sample_shape=base_dims, seed=seed)  # shape: [n, B, e]
  endpoints = []
  for aff in self.endpoint_affine:
    endpoints.append(aff.forward(base_draw))

  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  batch_size = reduce_prod(self.batch_shape_tensor())
  per_sample_dims = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), [batch_size])
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], per_sample_dims),
      seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))

  # Step by `quadrature_size`, once per batch member.
  offset = math_ops.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)

  flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
  weight = array_ops.gather(flat_weights, ids + offset)
  weight = weight[..., array_ops.newaxis]

  if len(endpoints) != 2:
    # This should already have been raised elsewhere; as a policy the check is
    # repeated wherever the bimixture assumption is exploited.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(endpoints)))

  first, second = endpoints
  # Same value as: weight * first + (1. - weight) * second
  return weight * (first - second) + second
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  `n * batch_size` base samples are drawn, pushed through every entry of
  `self.endpoint_affine` and reshaped to `[-1]` + batch shape + event shape.
  Ids sampled from `self._mixture_distribution` select entries of the
  flattened `self.interpolate_weight`, which blend the first two results.

  Args:
    n: Number of samples.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the first two affine
    images.
  """
  batch_size = reduce_prod(self.batch_shape_tensor())
  flat_dims = concat_vectors([n * batch_size], self.event_shape_tensor())
  base_draw = self.distribution.sample(sample_shape=flat_dims, seed=seed)

  endpoints = []
  for aff in self.endpoint_affine:
    pushed = aff.forward(base_draw)
    target_dims = concat_vectors(
        [-1], self.batch_shape_tensor(), self.event_shape_tensor())
    endpoints.append(array_ops.reshape(pushed, shape=target_dims))

  # Ids form a [n, batch_size] matrix, or a [n] vector when batch_shape=[].
  per_sample_dims = distribution_util.pick_vector(
      self.is_scalar_batch(), np.int32([]), [batch_size])
  ids_seed = distribution_util.gen_new_seed(seed, "vector_diffeomixture")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors([n], per_sample_dims), seed=ids_seed)

  # Step by `self._degree`, once per batch member.
  offset = math_ops.range(
      start=0,
      limit=batch_size * self._degree,
      delta=self._degree,
      dtype=ids.dtype)

  flat_weights = array_ops.reshape(self.interpolate_weight, shape=[-1])
  weight = array_ops.gather(flat_weights, ids + offset)
  weight = weight[..., array_ops.newaxis]

  # Same value as: weight * endpoints[0] + (1. - weight) * endpoints[1]
  return weight * (endpoints[0] - endpoints[1]) + endpoints[1]
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  A base sample is pushed through every entry of `self.endpoint_affine`. Ids
  sampled from `self.mixture_distribution` select entries of the flattened
  `self.grid`, which are used as interpolation weights between the two
  results.

  Args:
    n: Number of samples; the leading dimension of the base draw.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the two affine images.

  Raises:
    NotImplementedError: If `self.endpoint_affine` does not yield exactly two
      results.
  """
  x = self.distribution.sample(
      sample_shape=concat_vectors(
          [n],
          self.batch_shape_tensor(),
          self.event_shape_tensor()),
      seed=seed)   # shape: [n, B, e]
  x = [aff.forward(x) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    # `reduce_prod` is provided by `math_ops`, not `array_ops`.
    batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
  if mix_batch_size is None:
    mix_batch_size = math_ops.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              [batch_size // mix_batch_size])),
      seed=distribution_util.gen_new_seed(
          seed, "vector_diffeomixture"))
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = array_ops.reshape(ids, shape=concat_vectors(
      [n],
      distribution_util.pick_vector(
          self.is_scalar_batch(),
          np.int32([]),
          np.int32([-1]))))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = self.grid.shape.with_rank_at_least(
      2)[-2:].num_elements()
  if stride is None:
    # Static size unknown: compute the product of the last two dims at run
    # time. `reduce_prod` is provided by `math_ops`, not `array_ops`.
    stride = math_ops.reduce_prod(
        array_ops.shape(self.grid)[-2:])
  offset = math_ops.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

  weight = array_ops.gather(
      array_ops.reshape(self.grid, shape=[-1]),
      ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if self.batch_shape.is_fully_defined():
    new_shape = [-1] + self.batch_shape.as_list() + [1]
  else:
    new_shape = array_ops.concat(
        ([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = array_ops.reshape(weight, shape=new_shape)

  if len(x) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError("Currently only bimixtures are supported; "
                              "len(scale)={} is not 2.".format(len(x)))

  # Alternatively:
  # x = weight * x[0] + (1. - weight) * x[1]
  x = weight * (x[0] - x[1]) + x[1]

  return x
def _sample_n(self, n, seed=None):
  """Draws `n` samples by sampling a component id, then that component.

  Component ids are sampled from `self.cat`. Flat sample indices and batch
  indices are partitioned by component id, each component is sampled as many
  times as it was selected, the matching batch entries are gathered, and the
  per-component results are stitched back into the original sample order.

  Args:
    n: Number of samples. Converted to a tensor; when its value is statically
      known it is turned back into a Python int.
    seed: Seed for the component-id draw. Each component draw uses a seed
      obtained by repeatedly re-deriving from it with the salt "mixture".

  Returns:
    A tensor reshaped to the shape of the id samples followed by the event
    shape, with the corresponding static shape set.
  """
  # Every op created below is made to depend on `self._assertions`.
  with ops.control_dependencies(self._assertions):
    n = ops.convert_to_tensor(n, name="n")
    static_n = tensor_util.constant_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    # Prefer static shapes and sizes; fall back to tensors when they are not
    # fully defined.
    static_samples_shape = cat_samples.get_shape()
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = array_ops.shape(cat_samples)
      samples_size = array_ops.size(cat_samples)
    static_batch_shape = self.batch_shape
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = math_ops.reduce_prod(batch_shape)
    static_event_shape = self.event_shape
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = array_ops.reshape(
        math_ops.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = data_flow_ops.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = array_ops.reshape(
        array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [0 1] and [0 1 0 1]
    # (The counts below follow the original note.)
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 1, 0, 1
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = data_flow_ops.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    for c in range(self.num_components):
      n_class = array_ops.size(partitioned_samples_indices[c])
      # Re-derive the seed on every iteration so that each component draws
      # with a different seed (when `seed` is not None).
      seed = distribution_util.gen_new_seed(seed, "mixture")
      samples_class_c = self.components[c].sample(n_class, seed=seed)

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * math_ops.range(n_class) +
          partitioned_batch_indices[c])
      samples_class_c = array_ops.reshape(
          samples_class_c,
          array_ops.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = array_ops.gather(
          samples_class_c, lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = data_flow_ops.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = array_ops.reshape(
        lhs_flat_ret,
        array_ops.concat(
            [samples_shape, self.event_shape_tensor()], 0))
    # Restore the static shape information.
    ret.set_shape(
        tensor_shape.TensorShape(static_samples_shape).concatenate(
            self.event_shape))
    return ret
def _sample_n(self, n, seed=None):
  """Draws `n` samples by interpolating two affine images of a base draw.

  A base sample is pushed through every entry of `self.endpoint_affine`. Ids
  sampled from `self.mixture_distribution` select entries of the flattened
  `self.grid`, which are used as interpolation weights between the two
  results.

  Args:
    n: Number of samples; the leading dimension of the base draw.
    seed: Seed for the base draw. The id draw uses a seed derived from it with
      the salt "vector_diffeomixture".

  Returns:
    `weight * (x0 - x1) + x1`, where `x0` and `x1` are the two affine images.

  Raises:
    NotImplementedError: If `self.endpoint_affine` does not yield exactly two
      results.
  """
  x = self.distribution.sample(sample_shape=concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                               seed=seed)  # shape: [n, B, e]
  x = [aff.forward(x) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    # `reduce_prod` is provided by `math_ops`, not `array_ops`.
    batch_size = math_ops.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
  if mix_batch_size is None:
    mix_batch_size = math_ops.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                        [batch_size // mix_batch_size])),
      seed=distribution_util.gen_new_seed(seed, "vector_diffeomixture"))
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = array_ops.reshape(ids, shape=concat_vectors(
      [n], distribution_util.pick_vector(
          self.is_scalar_batch(), np.int32([]), np.int32([-1]))))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = self.grid.shape.with_rank_at_least(2)[-2:].num_elements()
  if stride is None:
    # Static size unknown: compute the product of the last two dims at run
    # time. `reduce_prod` is provided by `math_ops`, not `array_ops`.
    stride = math_ops.reduce_prod(array_ops.shape(self.grid)[-2:])
  offset = math_ops.range(start=0, limit=batch_size * stride, delta=stride,
                          dtype=ids.dtype)

  weight = array_ops.gather(array_ops.reshape(self.grid, shape=[-1]),
                            ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if self.batch_shape.is_fully_defined():
    new_shape = [-1] + self.batch_shape.as_list() + [1]
  else:
    new_shape = array_ops.concat(
        ([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = array_ops.reshape(weight, shape=new_shape)

  if len(x) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(x)))

  # Alternatively:
  # x = weight * x[0] + (1. - weight) * x[1]
  x = weight * (x[0] - x[1]) + x[1]

  return x
def testOnlyNoneReturnsNone(self):
  """Checks `gen_new_seed` gives non-None for seed 0 and None for seed None."""
  # assertIsNotNone/assertIsNone report the offending value on failure, which
  # assertTrue/assertFalse over an `is None` expression cannot do.
  self.assertIsNotNone(distribution_util.gen_new_seed(0, "salt"))
  self.assertIsNone(distribution_util.gen_new_seed(None, "salt"))
def _sample_n(self, n, seed=None):
  """Draws `n` samples by sampling a component id, then that component.

  Component ids are sampled from `self.cat`. Flat sample indices and batch
  indices are partitioned by component id, each component is sampled as many
  times as it was selected, the matching batch entries are gathered, and the
  per-component results are stitched back into the original sample order.

  Args:
    n: Number of samples. Converted to a tensor; when its value is statically
      known it is turned back into a Python int.
    seed: Seed for the component-id draw. Each component draw uses a seed
      obtained by repeatedly re-deriving from it with the salt "mixture".

  Returns:
    A tensor reshaped to the shape of the id samples followed by the event
    shape, with the corresponding static shape set.
  """
  # Every op created below is made to depend on `self._assertions`.
  with ops.control_dependencies(self._assertions):
    n = ops.convert_to_tensor(n, name="n")
    static_n = tensor_util.constant_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    # Prefer static shapes and sizes; fall back to tensors when they are not
    # fully defined.
    static_samples_shape = cat_samples.get_shape()
    if static_samples_shape.is_fully_defined():
      samples_shape = static_samples_shape.as_list()
      samples_size = static_samples_shape.num_elements()
    else:
      samples_shape = array_ops.shape(cat_samples)
      samples_size = array_ops.size(cat_samples)
    static_batch_shape = self.get_batch_shape()
    if static_batch_shape.is_fully_defined():
      batch_shape = static_batch_shape.as_list()
      batch_size = static_batch_shape.num_elements()
    else:
      batch_shape = self.batch_shape()
      # `reduce_prod` is provided by `math_ops`, not `array_ops`.
      batch_size = math_ops.reduce_prod(batch_shape)
    static_event_shape = self.get_event_shape()
    if static_event_shape.is_fully_defined():
      event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
    else:
      event_shape = self.event_shape()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = array_ops.reshape(
        math_ops.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = data_flow_ops.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = array_ops.reshape(
        array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [0 1] and [0 1 0 1]
    # (The counts below follow the original note.)
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 1, 0, 1
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = data_flow_ops.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    for c in range(self.num_components):
      n_class = array_ops.size(partitioned_samples_indices[c])
      # Re-derive the seed on every iteration so that each component draws
      # with a different seed (when `seed` is not None).
      seed = distribution_util.gen_new_seed(seed, "mixture")
      samples_class_c = self.components[c].sample(n_class, seed=seed)

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * math_ops.range(n_class) +
          partitioned_batch_indices[c])
      samples_class_c = array_ops.reshape(
          samples_class_c,
          array_ops.concat(([n_class * batch_size], event_shape), 0))
      samples_class_c = array_ops.gather(
          samples_class_c, lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = data_flow_ops.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = array_ops.reshape(lhs_flat_ret,
                            array_ops.concat((samples_shape,
                                              self.event_shape()), 0))
    # Restore the static shape information.
    ret.set_shape(
        tensor_shape.TensorShape(static_samples_shape).concatenate(
            self.get_event_shape()))
    return ret