def _sample_n(self, n, seed=None):
    """Draw beta-binomial samples using two log-space gamma draws.

    Two gamma variates are drawn in log space, one per concentration
    parameter, and their difference is used as the logits of a `Binomial`
    over `total_count`.

    Args:
      n: Number of samples; used as `shape=[n]` for each gamma draw.
      seed: Seed forwarded to `samplers.split_seed`, which yields three
        seeds (two gamma draws and the final binomial draw).

    Returns:
      The result of `Binomial(...).sample(seed=...)` built from the
      sampled logits.
    """
    seed_gamma1, seed_gamma0, seed_binomial = samplers.split_seed(
        seed, n=3, salt='beta_binomial')

    params = self._params_list_as_tensors()
    total_count, concentration1, concentration0 = params

    full_batch_shape = self._batch_shape_tensor(
        total_count=total_count,
        concentration1=concentration1,
        concentration0=concentration0)
    # Broadcast both concentrations to the batch shape before sampling.
    conc1_full = tf.broadcast_to(concentration1, full_batch_shape)
    conc0_full = tf.broadcast_to(concentration0, full_batch_shape)

    # With g1 ~ Gamma(concentration1) and g2 ~ Gamma(concentration0):
    #   probs  = g1 / (g1 + g2)
    #   logits = log(probs) - log(1 - probs)
    #          = log(g1 / (g1 + g2)) - log(1 - g1 / (g1 + g2))
    #          = log(g1) - log(g1 + g2) - log(((g1 + g2) - g1) / (g1 + g2))
    #          = log(g1) - log(g1 + g2) - (log(g1 + g2 - g1) - log(g1 + g2))
    #          = log(g1) - log(g1 + g2) - log(g2) + log(g1 + g2)
    #          = log(g1) - log(g2)
    # so the logits are simply the difference of the two log-gamma draws.
    log_g1 = gamma_lib.random_gamma(
        shape=[n],
        concentration=conc1_full,
        seed=seed_gamma1,
        log_space=True)
    log_g0 = gamma_lib.random_gamma(
        shape=[n],
        concentration=conc0_full,
        seed=seed_gamma0,
        log_space=True)

    binomial_dist = binomial.Binomial(
        total_count,
        logits=log_g1 - log_g0,
        validate_args=self.validate_args)
    # `sample` is called without a sample count; presumably the leading
    # sample dimension comes from the `[n]`-shaped gamma draws.
    return binomial_dist.sample(seed=seed_binomial)
def fn(i, num_trials, consumed_prob, accum):
    """Sample the counts for one class using binomial.

    Reads `probs`, `seeds` and `dtype` from the enclosing scope (not
    defined in this function).

    Args:
      i: Index of the class handled in this step; used to gather from
        `probs` (last axis) and `seeds` (first axis) and as the write
        position in `accum`.
      num_trials: Trials still available; used as the binomial
        `total_count`.
      consumed_prob: Running sum of the probabilities of the classes
        processed so far.
      accum: Accumulator exposing `write(index, value)`.

    Returns:
      A tuple `(i + 1, num_trials - sample, consumed_prob + probs_i,
      accum)` with the accumulator returned by `write`.
    """
    class_prob = tf.gather(probs, i, axis=-1)
    # Rescale by the probability mass not yet consumed, clipped to [0, 1].
    conditional_prob = tf.clip_by_value(
        class_prob / (1. - consumed_prob), 0, 1)
    class_seed = tf.gather(seeds, i, axis=0)
    class_binomial = binomial.Binomial(
        total_count=num_trials, probs=conditional_prob)
    # Not passing `num_samples` to `sample`, as it is already in
    # `num_trials.shape`.
    draw = class_binomial.sample(seed=class_seed)
    updated_accum = accum.write(i, tf.cast(draw, dtype=dtype))
    return (
        i + 1,
        num_trials - draw,
        consumed_prob + class_prob,
        updated_accum,
    )
def _sample_n(self, n, seed=None):
    """Draw beta-binomial samples by sampling a Beta, then a Binomial.

    Args:
      n: Number of samples requested from the intermediate `Beta`.
      seed: Seed used to build a `SeedStream` salted with
        `'beta_binomial'`; the stream is called once for the Beta draw
        and once for the Binomial draw, in that order.

    Returns:
      The result of `Binomial(total_count, probs=...).sample(seed=...)`,
      where `probs` are the Beta samples.
    """
    stream = SeedStream(seed, 'beta_binomial')

    params = self._params_list_as_tensors()
    total_count, concentration1, concentration0 = params
    batch_shape = self.batch_shape_tensor()

    # Only `concentration1` is explicitly broadcast to the batch shape;
    # `concentration0` is handed to `Beta` unchanged.
    beta_dist = beta.Beta(
        tf.broadcast_to(concentration1, batch_shape),
        concentration0,
        validate_args=self.validate_args)
    success_probs = beta_dist.sample(n, seed=stream())

    binomial_dist = binomial.Binomial(
        total_count,
        probs=success_probs,
        validate_args=self.validate_args)
    # No sample count is passed here; the Beta draw above already took `n`.
    return binomial_dist.sample(seed=stream())
def _sample_n(self, n, seed=None):
    """Draw beta-binomial samples by sampling a Beta, then a Binomial.

    Args:
      n: Number of samples requested from the intermediate `Beta`.
      seed: Seed forwarded to `samplers.split_seed` (salt
        `'beta_binomial'`), which yields one seed for the Beta draw and
        one for the Binomial draw.

    Returns:
      The result of `Binomial(total_count, probs=...).sample(seed=...)`,
      where `probs` are the Beta samples.
    """
    seed_for_beta, seed_for_binomial = samplers.split_seed(
        seed, salt='beta_binomial')

    param_tensors = self._params_list_as_tensors()
    full_batch_shape = self._batch_shape_tensor(params=param_tensors)
    total_count, concentration1, concentration0 = param_tensors

    # Only `concentration1` is explicitly broadcast to the batch shape;
    # `concentration0` is handed to `Beta` unchanged.
    beta_dist = beta.Beta(
        tf.broadcast_to(concentration1, full_batch_shape),
        concentration0,
        validate_args=self.validate_args)
    success_probs = beta_dist.sample(n, seed=seed_for_beta)

    binomial_dist = binomial.Binomial(
        total_count,
        probs=success_probs,
        validate_args=self.validate_args)
    # No sample count is passed here; the Beta draw above already took `n`.
    return binomial_dist.sample(seed=seed_for_binomial)