def Zheng07Cens(halo_mvir,
                logMmin=ed.Deterministic(11.35, name='logMmin'),
                sigma_logM=ed.Deterministic(0.25, name='sigma_logM'),
                temperature=0.2,
                name='zheng07Cens', **kwargs):
  """Relaxed-Bernoulli occupation of central galaxies as a function of halo mass.

  The occupation probability is
  `0.5 * (1 + erf((log10(halo_mvir) - logMmin) / sigma_logM))`,
  clipped to `[1e-4, 1 - 1e-4]`.

  Args:
    halo_mvir: Halo mass in linear units; the base-10 logarithm is taken here.
    logMmin: Log-mass at which the occupation probability equals 0.5.
    sigma_logM: Width of the erf transition in log-mass.
    temperature: Temperature of the relaxed Bernoulli distribution.
    name: Name given to the returned random variable.
    **kwargs: Accepted for call-site compatibility; not used.

  Returns:
    An `ed.RelaxedBernoulli` random variable parameterised by the clipped
    occupation probability.
  """
  clip_eps = 1.e-4
  # Base-10 logarithm obtained by change of base from the natural log.
  log10_mvir = tf.math.log(halo_mvir) / tf.math.log(10.)
  # Mean number of centrals: a smooth erf step in log-mass.
  scaled_offset = (log10_mvir - logMmin) / sigma_logM
  mean_ncen = 0.5 * (1 + tf.math.erf(scaled_offset))
  # Keep the probability away from exactly 0 and 1.
  probs = tf.clip_by_value(mean_ncen, clip_eps, 1 - clip_eps)
  return ed.RelaxedBernoulli(temperature, probs=probs, name=name)
def Zheng07SatsPoisson(halo_mvir,
                n_cen,
                logM0=ed.Deterministic(11.2, name='logM0'),
                logM1=ed.Deterministic(12.4, name='logM1'),
                alpha=ed.Deterministic(0.83, name='alpha'),
                name='zheng07Sats', **kwargs):
  """Poisson occupation of satellite galaxies as a function of halo mass.

  The Poisson rate is
  `n_cen.distribution.probs * ((halo_mvir - M0) / M1) ** alpha`
  with `M0 = 10**logM0` and `M1 = 10**logM1`. Halos with `halo_mvir < M0`
  get a fixed rate of `1e-4` instead.

  Args:
    halo_mvir: Halo mass in linear units (compared directly against `10**logM0`).
    n_cen: Random variable exposing `.distribution.probs`; presumably the
      output of `Zheng07Cens` -- confirm against callers.
    logM0: Base-10 log of the cut-off mass below which the rate is floored.
    logM1: Base-10 log of the mass normalising the power law.
    alpha: Power-law exponent.
    name: Name given to the returned random variable.
    **kwargs: Accepted for call-site compatibility; not used.

  Returns:
    An `ed.Poisson` random variable with the rate described above.
  """
  M0 = 10.**logM0
  M1 = 10.**logM1
  below_cutoff = halo_mvir < M0
  mass_excess = halo_mvir - M0
  # A negative base raised to a fractional power is NaN. Masking the result
  # afterwards hides it in the forward pass only: in the backward pass the
  # zero upstream gradient is multiplied by the NaN local derivative and
  # every parameter gradient becomes NaN. Swap in a harmless base (1) for
  # the masked entries *before* the power so both passes stay finite.
  safe_excess = tf.where(below_cutoff, tf.ones_like(mass_excess), mass_excess)
  rate = n_cen.distribution.probs * (safe_excess/M1)**alpha
  # Floor the rate for halos below the cut-off mass.
  rate = tf.where(below_cutoff, 1e-4, rate)
  return ed.Poisson(rate=rate, name=name)
def Zheng07SatsRelaxedBernoulli(halo_mvir,
                n_cen,
                sample_shape,
                logM0=ed.Deterministic(11.2, name='logM0'),
                logM1=ed.Deterministic(12.4, name='logM1'),
                alpha=ed.Deterministic(0.83, name='alpha'),
                temperature=0.2,
                name='zheng07Sats', **kwargs):
  """Relaxed-Bernoulli occupation of satellite galaxies as a function of halo mass.

  The mean satellite rate is
  `n_cen.distribution.probs * (relu(halo_mvir - M0) / M1) ** alpha`
  with `M0 = 10**logM0` and `M1 = 10**logM1`. It is divided by
  `sample_shape[0]` and clipped to `[1e-5, 1 - 1e-4]` to give the
  per-draw probability.

  Args:
    halo_mvir: Halo mass in linear units (compared directly against `10**logM0`).
    n_cen: Random variable exposing `.distribution.probs`; presumably the
      output of `Zheng07Cens` -- confirm against callers.
    sample_shape: Sample shape of the returned random variable. Its first
      entry also divides the rate; presumably it is the maximum number of
      satellite slots per halo -- confirm against callers.
    logM0: Base-10 log of the cut-off mass; the mass excess is zero below it.
    logM1: Base-10 log of the mass normalising the power law.
    alpha: Power-law exponent.
    temperature: Temperature of the relaxed Bernoulli distribution.
    name: Name given to the returned random variable.
    **kwargs: Accepted for call-site compatibility; not used.

  Returns:
    An `ed.RelaxedBernoulli` random variable drawn with `sample_shape`.
  """
  M0 = 10.**logM0
  M1 = 10.**logM1
  # relu zeroes the mass excess for halos below the cut-off mass.
  rate = n_cen.distribution.probs * (tf.nn.relu(halo_mvir - M0)/M1)**alpha
  # Spread the rate evenly across the sample_shape[0] draws, then clip.
  probs = tf.clip_by_value(rate/sample_shape[0], 1.e-5, 1-1e-4)
  # Forward `name` so the variable is labelled as requested; it was
  # previously accepted but never passed on.
  return ed.RelaxedBernoulli(temperature=temperature,
                             probs=probs,
                             sample_shape=sample_shape,
                             name=name)
def trainable_positive_deterministic(shape, min_loc=1e-3, name=None):
  """Build a learnable Deterministic random variable over positive reals.

  A variable named "unconstrained_loc" is created inside a fresh variable
  scope, passed through softplus, and floored at `min_loc`.

  Args:
    shape: Shape passed to `tf.get_variable` for the unconstrained variable.
    min_loc: Lower bound applied to the location after the softplus.
    name: Name given to the returned random variable.

  Returns:
    An `ed.Deterministic` random variable located at the constrained value.
  """
  with tf.variable_scope(None, default_name="trainable_positive_deterministic"):
    raw_loc = tf.get_variable("unconstrained_loc", shape)
    # softplus maps the raw variable to positive values; the floor keeps it
    # at or above min_loc.
    positive_loc = tf.maximum(tf.nn.softplus(raw_loc), min_loc)
    return ed.Deterministic(loc=positive_loc, name=name)