Example #1
File: XClone.py Project: huangyh09/xclone
 def set_prior(self, theta_prior=None, gamma_prior=None, Y_prior=None, 
               Z_prior=None):
     """Set prior ditributions
     """
     # Prior distributions for the allelic ratio
     if theta_prior is None:
         self.theta_prior = tfd.Beta(self.cnv_states[:, 0] + 0.01, 
                                     self.cnv_states[:, 1] + 0.01)
     else:
         self.theta_prior = theta_prior
         
     # Prior distributions for the depth ratio
     if gamma_prior is None:
         if self.RDR_cov is None:
             self.set_GP_kernal()
         self.gamma_prior = FullNormal(loc = self.cnv_states.sum(axis=1), 
                                       covariance_matrix = self.RDR_cov)
     else:
         self.gamma_prior = gamma_prior
         
     # Prior distributions for CNV state weights
     if Y_prior is None:
         self.Y_prior = tfd.Multinomial(total_count=1,
                     probs=tf.ones((self.Nb, self.Nk, self.Ns)) / self.Ns)
     else:
         self.Y_prior = Y_prior
         
     # Prior distributions for cell assignment weights
     if Z_prior is None:
         self.Z_prior = tfd.Multinomial(total_count=1,
                     probs=tf.ones((self.Nc, self.Nk)) / self.Nk)
     else:
         self.Z_prior = Z_prior
Example #2
  def prior_fn(shifted_codes):
    """Calculates prior logits on discrete latents.

    Args:
      shifted_codes: A binary `Tensor` of shape
        [num_samples, batch_size, latent_size, num_codes], shifted by one to
        enable autoregressive calculation.

    Returns:
      prior_dist: Multinomial distribution with prior logits coming from
        Transformer applied to shifted input.
    """
    with tf.variable_scope("transformer_prior", reuse=tf.AUTO_REUSE):
      dense_shifted_codes = tf.reduce_sum(
          tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
          shifted_codes[Ellipsis, tf.newaxis], axis=-2)
      transformed_codes = cia.transformer_decoder_layers(
          inputs=dense_shifted_codes,
          encoder_output=None,
          num_layers=hparams.num_layers,
          hparams=hparams,
          attention_type=cia.AttentionType.LOCAL_1D)
      logits = tf.reduce_sum(
          tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
          transformed_codes[Ellipsis, tf.newaxis, :], axis=-1)
      prior_dist = tfd.Multinomial(total_count=1., logits=logits)
    return prior_dist
Example #3
File: core.py Project: debbiemarkslab/BEAR
    def __init__(self,
                 total_count,
                 probs,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='DirichletMultinomialPerm'):

        parameters = dict(locals())

        with tf.name_scope(name) as name:

            self.total_count = total_count
            self.probs = probs
            self.alphabet_size = tf.cast(
                tf.shape(self.probs)[-1] - 1, tf.int32)

            super(tfpMultinomialPerm, self).__init__(
                dtype=self.probs.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)

        self.counts_dist = tfpd.Multinomial(
            self.total_count,
            probs=self.probs,
            validate_args=self.validate_args,
            allow_nan_stats=self.allow_nan_stats,
            name=self.name)
Example #4
        def decoder(state_sample, observation_dist="gaussian"):
            """Compute the data distribution of an observation from its state [1]."""
            check_in(
                "observation_dist",
                observation_dist,
                ("gaussian", "laplace", "bernoulli", "multinomial"),
            )

            timesteps = tf.shape(state_sample)[1]

            if self.pixel_observations:
                # original decoder from [1] for deepmind lab envs
                hidden = tf.layers.dense(state_sample, 1024, None)
                kwargs = dict(strides=2, activation=tf.nn.relu)
                hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1]])
                # 1 x 1
                hidden = tf.layers.conv2d_transpose(hidden, 128, 5, **kwargs)
                # 5 x 5 x 128
                hidden = tf.layers.conv2d_transpose(hidden, 64, 5, **kwargs)
                # 13 x 13 x 64
                hidden = tf.layers.conv2d_transpose(hidden, 32, 6, **kwargs)
                # 30 x 30 x 32
                mean = 255 * tf.layers.conv2d_transpose(
                    hidden, 3, 6, strides=2, activation=tf.nn.sigmoid)
                # 64 x 64 x 3
                assert mean.shape[1:].as_list() == [64, 64, 3], mean.shape
            else:
                # decoder for gridworlds / structured observations
                hidden = state_sample
                d = self._hidden_layer_size
                for _ in range(4):
                    hidden = tf.layers.dense(hidden, d, tf.nn.relu)
                mean = tf.layers.dense(hidden, np.prod(self.data_shape), None)

            mean = tf.reshape(mean, [-1, timesteps] + list(self.data_shape))

            check_in(
                "observation_dist",
                observation_dist,
                ("gaussian", "laplace", "bernoulli", "multinomial"),
            )
            if observation_dist == "gaussian":
                dist = tfd.Normal(mean, self._obs_stddev)
            elif observation_dist == "laplace":
                dist = tfd.Laplace(mean, self._obs_stddev / np.sqrt(2))
            elif observation_dist == "bernoulli":
                dist = tfd.Bernoulli(probs=mean)
            else:
                mean = tf.reshape(mean, [-1, timesteps] +
                                  [np.prod(list(self.data_shape))])
                dist = tfd.Multinomial(total_count=1, probs=mean)
                reshape = tfp.bijectors.Reshape(
                    event_shape_out=list(self.data_shape))
                dist = reshape(dist)
                return dist

            dist = tfd.Independent(dist, len(self.data_shape))
            return dist
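
The "multinomial" branch above flattens the observation, wraps the flattened probabilities in tfd.Multinomial(total_count=1, ...), and restores the event shape with a tfp.bijectors.Reshape bijector. A minimal standalone sketch of just that step, with an assumed data_shape of (3, 4) and random probabilities:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

data_shape = (3, 4)                                   # assumed observation shape
mean = tf.nn.softmax(tf.random.uniform([2, 5, 12]))   # (batch, timesteps, prod(data_shape))

dist = tfd.Multinomial(total_count=1., probs=mean)    # one-hot over the 12 flattened cells
dist = tfp.bijectors.Reshape(event_shape_out=list(data_shape))(dist)

sample = dist.sample()                                # shape (2, 5, 3, 4)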
Example #5
 def log_prob_i(leaves_of_subtree_):
     if len(leaves_of_subtree_) == 1:
         return tf.zeros_like(logits[..., 0])
     with tf.variable_scope(name or self.name):
         output_ = tf.stack([output_list[leaf_idx] for leaf_idx in leaves_of_subtree_], axis=-1)
         logits_ = tf.stack([logits_list[leaf_idx] for leaf_idx in leaves_of_subtree_], axis=-1)
         depth = tf.reduce_sum(output_, axis=-1)
         multinomial = tfd.Multinomial(total_count=depth, logits=logits_,
                                       validate_args=True, allow_nan_stats=False)
         log_prob_i = multinomial.log_prob(output_)
     return log_prob_i
Example #6
        def get_mean_subtree(leaves_of_subtree_):
            if len(leaves_of_subtree_) == 1:
                leaf_idx_ = leaves_of_subtree_[0]
                return [obs_list[leaf_idx_]]

            with tf.variable_scope(name or self.name):
                logits_ = tf.stack([logits_list[leaf_idx_] for leaf_idx_ in leaves_of_subtree_], axis=-1)
                depth = tf.reduce_sum([obs_list[leaf_idx_] for leaf_idx_ in leaves_of_subtree_], axis=0)
                multinomial = tfd.Multinomial(total_count=depth, logits=logits_,
                                              validate_args=True, allow_nan_stats=False)
                mean_subtree = multinomial.mean()
            return tf.unstack(mean_subtree, axis=-1)
Example #7
 def get_multinomial(self, Input, depth):
     """
     :param Input: (T, Dx)
     :param depth: total counts, (T, )
     :return: a tfd.Multinomial parameterized by the transformed logits
     """
     with tf.variable_scope(self.name):
         logits = self.transformation.transform(Input)
         multinomial = tfd.Multinomial(total_count=depth,
                                       logits=logits,
                                       validate_args=True,
                                       allow_nan_stats=False)
         return multinomial
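
Examples #5 through #7 share the same pattern: the observed counts themselves supply total_count (here called depth), and log_prob then scores how that total is split across categories. A minimal standalone sketch with made-up counts:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

counts = tf.constant([[3., 1., 0.],
                      [2., 2., 4.]])                  # (T, Dx), illustrative values
logits = tf.zeros_like(counts)                        # uniform split over categories
depth = tf.reduce_sum(counts, axis=-1)                # per-row total count: [4., 8.]

multinomial = tfd.Multinomial(total_count=depth, logits=logits,
                              validate_args=True, allow_nan_stats=False)
log_prob = multinomial.log_prob(counts)               # shape (2,)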
Example #8
    def set_prior(self, mu_prior=None, sigma_prior=None, ident_prior=None):
        """Set prior ditributions
        """
        # Prior distributions for the means
        if mu_prior is None:
            self.mu_prior = tfd.Normal(tf.zeros((self.Nc, self.Nd)),
                                       tf.ones((self.Nc, self.Nd)))
        else:
            self.mu_prior = mu_prior

        # Prior distributions for the standard deviations
        if sigma_prior is None:
            self.sigma_prior = tfd.Gamma(2 * tf.ones((self.Nc, self.Nd)),
                                         2 * tf.ones((self.Nc, self.Nd)))
        else:
            self.sigma_prior = sigma_prior

        # Prior distributions for sample assignment
        if ident_prior is None:
            self.ident_prior = tfd.Multinomial(
                total_count=1, probs=tf.ones((self.Ns, self.Nc)) / self.Nc)
        else:
            self.ident_prior = ident_prior
Example #9
def main(argv):
  del argv  # unused
  FLAGS.encoder_layers = [int(units) for units
                          in FLAGS.encoder_layers.split(",")]
  FLAGS.decoder_layers = [int(units) for units
                          in FLAGS.decoder_layers.split(",")]
  FLAGS.activation = getattr(tf.nn, FLAGS.activation)
  if tf.gfile.Exists(FLAGS.model_dir):
    tf.logging.warn("Deleting old log directory at {}".format(FLAGS.model_dir))
    tf.gfile.DeleteRecursively(FLAGS.model_dir)
  tf.gfile.MakeDirs(FLAGS.model_dir)

  if FLAGS.fake_data:
    mnist_data = build_fake_data()
  else:
    mnist_data = mnist.read_data_sets(FLAGS.data_dir)

  with tf.Graph().as_default():
    (images, _, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         mnist_data, FLAGS.batch_size, mnist_data.validation.num_examples)

    # Reshape as a pixel image and binarize pixels.
    images = tf.reshape(images, shape=[-1] + IMAGE_SHAPE)
    images = tf.cast(images > 0.5, dtype=tf.int32)

    encoder = make_encoder(FLAGS.encoder_layers,
                           FLAGS.activation,
                           FLAGS.latent_size,
                           FLAGS.code_size)
    decoder = make_decoder(FLAGS.decoder_layers,
                           FLAGS.activation,
                           IMAGE_SHAPE)
    vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

    # Build the model and loss function.
    loss, decoder_distribution = make_vq_vae(
        images, encoder, decoder, vector_quantizer, FLAGS.beta, FLAGS.decay)
    reconstructed_images = decoder_distribution.mean()

    # Decode samples from a uniform prior for visualization.
    prior = tfd.Multinomial(total_count=1., logits=tf.zeros(FLAGS.num_codes))
    prior_samples = tf.reduce_sum(
        tf.expand_dims(prior.sample([10, FLAGS.latent_size]), -1) *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=2)
    decoded_distribution_given_random_prior = decoder(prior_samples)
    random_images = decoded_distribution_given_random_prior.mean()

    # Perform inference by minimizing the loss function.
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    train_op = optimizer.minimize(loss)

    summary = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
      sess.run(init)

      # Run the training loop.
      train_handle = sess.run(training_iterator.string_handle())
      heldout_handle = sess.run(heldout_iterator.string_handle())
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={handle: train_handle})
        duration = time.time() - start_time
        if step % 100 == 0:
          print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
              step, loss_value, duration))

          # Update the events file.
          summary_str = sess.run(summary, feed_dict={handle: train_handle})
          summary_writer.add_summary(summary_str, step)
          summary_writer.flush()

        # Periodically save a checkpoint and visualize model progress.
        if (step + 1) % FLAGS.viz_steps == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_file = os.path.join(FLAGS.model_dir, "model.ckpt")
          saver.save(sess, checkpoint_file, global_step=step)

          # Visualize inputs and model reconstructions from the training set.
          images_val, reconstructions_val, random_images_val = sess.run(
              (images, reconstructed_images, random_images),
              feed_dict={handle: train_handle})
          visualize_training(images_val,
                             reconstructions_val,
                             random_images_val,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_train".format(step))

          # Visualize inputs and model reconstructions from the validation set.
          heldout_images_val, heldout_reconstructions_val = sess.run(
              (images, reconstructed_images),
              feed_dict={handle: heldout_handle})
          visualize_training(heldout_images_val,
                             heldout_reconstructions_val,
                             None,
                             log_dir=FLAGS.model_dir,
                             prefix="step{:05d}_validation".format(step))
Example #10
 def _init_distribution(conditions):
     n, p = conditions["n"], conditions["p"]
     return tfd.Multinomial(total_count=n, probs=p)
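
This helper, like the similar one-liners in Examples #15 and #18 below, is a thin wrapper around the same constructor. A quick usage sketch with made-up arguments, assuming TF2 eager execution:

import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Multinomial(total_count=10., probs=[0.2, 0.3, 0.5])
counts = dist.sample()          # e.g. [2., 3., 5.]; always sums to 10
log_p = dist.log_prob(counts)
mean = dist.mean()              # total_count * probs = [2., 3., 5.]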
Example #11
def main(argv):
    del argv  # unused
    FLAGS.activation = getattr(tf.nn, FLAGS.activation)
    if tf.io.gfile.exists(FLAGS.model_dir):
        tf.compat.v1.logging.warn("Deleting old log directory at {}".format(
            FLAGS.model_dir))
        tf.io.gfile.rmtree(FLAGS.model_dir)
    tf.io.gfile.makedirs(FLAGS.model_dir)

    with tf.Graph().as_default():
        # TODO(b/113163167): Speed up and tune hyperparameters for Bernoulli MNIST.
        (images, _, handle, training_iterator,
         heldout_iterator) = build_input_pipeline(FLAGS.data_dir,
                                                  FLAGS.batch_size,
                                                  heldout_size=10000,
                                                  mnist_type=FLAGS.mnist_type)

        encoder = make_encoder(FLAGS.base_depth, FLAGS.activation,
                               FLAGS.latent_size, FLAGS.code_size)
        decoder = make_decoder(FLAGS.base_depth, FLAGS.activation,
                               FLAGS.latent_size * FLAGS.code_size,
                               IMAGE_SHAPE)
        vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

        codes = encoder(images)
        nearest_codebook_entries, one_hot_assignments = vector_quantizer(codes)
        codes_straight_through = codes + tf.stop_gradient(
            nearest_codebook_entries - codes)
        decoder_distribution = decoder(codes_straight_through)
        reconstructed_images = decoder_distribution.mean()

        reconstruction_loss = -tf.reduce_mean(
            input_tensor=decoder_distribution.log_prob(images))
        commitment_loss = tf.reduce_mean(
            input_tensor=tf.square(codes -
                                   tf.stop_gradient(nearest_codebook_entries)))
        commitment_loss = add_ema_control_dependencies(vector_quantizer,
                                                       one_hot_assignments,
                                                       codes, commitment_loss,
                                                       FLAGS.decay)
        prior_dist = tfd.Multinomial(total_count=1.0,
                                     logits=tf.zeros(
                                         [FLAGS.latent_size, FLAGS.num_codes]))
        prior_loss = -tf.reduce_mean(input_tensor=tf.reduce_sum(
            input_tensor=prior_dist.log_prob(one_hot_assignments), axis=1))

        loss = reconstruction_loss + FLAGS.beta * commitment_loss + prior_loss
        # Upper bound marginal negative log-likelihood as prior loss +
        # reconstruction loss.
        marginal_nll = prior_loss + reconstruction_loss

        tf.compat.v1.summary.scalar("losses/total_loss", loss)
        tf.compat.v1.summary.scalar("losses/reconstruction_loss",
                                    reconstruction_loss)
        tf.compat.v1.summary.scalar("losses/prior_loss", prior_loss)
        tf.compat.v1.summary.scalar("losses/commitment_loss",
                                    FLAGS.beta * commitment_loss)

        # Decode samples from a uniform prior for visualization.
        prior_samples = tf.reduce_sum(
            input_tensor=tf.expand_dims(prior_dist.sample(10), -1) *
            tf.reshape(vector_quantizer.codebook,
                       [1, 1, FLAGS.num_codes, FLAGS.code_size]),
            axis=2)
        decoded_distribution_given_random_prior = decoder(prior_samples)
        random_images = decoded_distribution_given_random_prior.mean()

        # Perform inference by minimizing the loss function.
        optimizer = tf.compat.v1.train.AdamOptimizer(FLAGS.learning_rate)
        train_op = optimizer.minimize(loss)

        summary = tf.compat.v1.summary.merge_all()
        init = tf.compat.v1.global_variables_initializer()
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            summary_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.model_dir, sess.graph)
            sess.run(init)

            # Run the training loop.
            train_handle = sess.run(training_iterator.string_handle())
            heldout_handle = sess.run(heldout_iterator.string_handle())
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict={handle: train_handle})
                duration = time.time() - start_time
                if step % 100 == 0:
                    marginal_nll_val = sess.run(
                        marginal_nll, feed_dict={handle: heldout_handle})
                    print(
                        "Step: {:>3d} Training Loss: {:.3f} Heldout NLL: {:.3f} "
                        "({:.3f} sec)".format(step, loss_value,
                                              marginal_nll_val, duration))

                    # Update the events file.
                    summary_str = sess.run(summary,
                                           feed_dict={handle: train_handle})
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # Periodically save a checkpoint and visualize model progress.
                if (step + 1) % FLAGS.viz_steps == 0 or (step +
                                                         1) == FLAGS.max_steps:
                    checkpoint_file = os.path.join(FLAGS.model_dir,
                                                   "model.ckpt")
                    saver.save(sess, checkpoint_file, global_step=step)

                    # Visualize inputs and model reconstructions from the training set.
                    images_val, reconstructions_val, random_images_val = sess.run(
                        (images, reconstructed_images, random_images),
                        feed_dict={handle: train_handle})
                    visualize_training(images_val,
                                       reconstructions_val,
                                       random_images_val,
                                       log_dir=FLAGS.model_dir,
                                       prefix="step{:05d}_train".format(step))

                    # Visualize inputs and model reconstructions from the validation set.
                    heldout_images_val, heldout_reconstructions_val = sess.run(
                        (images, reconstructed_images),
                        feed_dict={handle: heldout_handle})
                    visualize_training(
                        heldout_images_val,
                        heldout_reconstructions_val,
                        None,
                        log_dir=FLAGS.model_dir,
                        prefix="step{:05d}_validation".format(step))
Example #12
File: XClone.py Project: huangyh09/xclone
 def Y(self):
     """Variational posterior for CNV state"""
     return tfd.Multinomial(total_count=1, logits=self.CNV_logit)
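
With total_count=1, as in Examples #1, #12 and #19, each draw is a one-hot vector, so the Multinomial effectively acts as a one-hot categorical over CNV states (or cell assignments). A small sketch with made-up logits:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

CNV_logit = tf.constant([[0.1, 2.0, -1.0]])    # (1 block, 3 states), illustrative
posterior = tfd.Multinomial(total_count=1., logits=CNV_logit)

one_hot = posterior.sample()                   # e.g. [[0., 1., 0.]]
probs = posterior.mean()                       # softmax(CNV_logit), rows sum to 1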
Example #13
File: simulator.py Project: huangyh09/brie
def simulator(adata,
              Psi=None,
              effLen=None,
              mode="posterior",
              layer_keys=['isoform1', 'isoform2', 'ambiguous'],
              prior_sigma=None):
    """Simulate read counts for BRIE model
    """
    # Check Psi
    if Psi is None and "Psi" not in adata.layers:
        print("Error: no Psi available in adata.layers.")
        exit()
    elif Psi is None:
        if mode == "posterior":
            Psi = adata.layers['Psi'].copy()
        else:
            Psi = np.zeros((adata.shape), np.float32)
            if 'Xc' in adata.obsm and adata.obsm['Xc'].shape[1] > 0:
                Psi += np.dot(adata.obsm['Xc'], adata.varm['cell_coeff'].T)
            if 'Xg' in adata.varm and adata.varm['Xg'].shape[1] > 0:
                Psi += np.dot(adata.obsm['gene_coeff'], adata.varm['Xg'].T)
            if 'intercept' in adata.varm and adata.varm['intercept'].shape[
                    1] > 0:
                Psi += adata.varm['intercept'].T
            if 'intercept' in adata.obsm and adata.obsm['intercept'].shape[
                    1] > 0:
                Psi += adata.obsm['intercept']

            adata.layers['Psi_sim_noNoise'] = expit(Psi)

            if prior_sigma is None:
                _sigma = adata.varm['sigma'].T
            else:
                _sigma = np.ones([1, adata.shape[1]]) * prior_sigma
            _noise = np.random.normal(loc=0.0, scale=_sigma, size=None)
            Psi += _noise

            Psi[Psi > 9] = 9
            Psi[Psi < -9] = -9
            Psi = expit(Psi)
    adata.layers['Psi_sim'] = Psi

    # Check effective length for isoform specific positions
    if effLen is None and 'effLen' not in adata.varm:
        print("Error: no effLen available in adata.varm.")
        exit()
    elif effLen is None:
        effLen = adata.varm['effLen'][:, [0, 4, 5]]
    else:
        effLen = effLen[:, [0, 4, 5]].copy()
    effLen = np.expand_dims(effLen, 0)

    Psi_tensor = np.concatenate(
        (np.expand_dims(Psi, 2), 1 - np.expand_dims(Psi, 2),
         np.ones((Psi.shape[0], Psi.shape[1], 1), np.float32)),
        axis=2)

    Phi = Psi_tensor * effLen
    Phi = Phi / np.sum(Phi, axis=2, keepdims=True)

    adata = adata.copy()
    total_counts = np.zeros(adata.shape, np.float32)
    for i in range(len(layer_keys)):
        total_counts += adata.layers[layer_keys[i]]

    model = tfd.Multinomial(total_counts, probs=Phi)
    count_sim = model.sample().numpy()

    adata.layers[layer_keys[0]] = count_sim[:, :, 0]
    adata.layers[layer_keys[1]] = count_sim[:, :, 1]
    adata.layers[layer_keys[2]] = count_sim[:, :, 2]

    return adata
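
In the simulator above, total_counts has shape (n_cells, n_genes) and Phi has shape (n_cells, n_genes, 3), so the Multinomial has batch shape (n_cells, n_genes) and a length-3 event per entry. A toy-shaped sketch of that batched sampling (values made up):

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

total_counts = np.array([[5., 0.],
                         [3., 7.]], np.float32)   # (2 cells, 2 genes)
Phi = np.full((2, 2, 3), 1. / 3., np.float32)     # uniform isoform probabilities

model = tfd.Multinomial(total_counts, probs=Phi)
count_sim = model.sample().numpy()                # (2, 2, 3); each length-3 vector sums to its total count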
Example #14
def gumbel_softmax_bottleneck(dist,
                              vector_quantizer,
                              temperature=0.5,
                              num_iaf_flows=0,
                              use_transformer_for_iaf_parameters=False,
                              num_samples=1,
                              sum_over_latents=True,
                              summary=True):
  """Gumbel-Softmax discrete bottleneck.

  Args:
    dist: Distances between encoder outputs and codebook entries, to be used as
      categorical logits. A float Tensor of shape [batch_size, latent_size,
      code_size].
    vector_quantizer: An instance of the VectorQuantizer class.
    temperature: Temperature parameter used for Gumbel-Softmax distribution.
    num_iaf_flows: Number of inverse-autoregressive flows to perform.
    use_transformer_for_iaf_parameters: Whether to use a Transformer instead of
      a lower-triangular mat-mul to generate IAF parameters.
    num_samples: Number of categorical samples.
    sum_over_latents: Whether to sum over latent dimension when computing
      entropy.
    summary: Whether to log summary histogram.

  Returns:
    one_hot_assignments: Simplex-valued assignments sampled from categorical.
    neg_q_entropy: Negative entropy of categorical distribution.
  """
  latent_size = dist.shape[1]
  # TODO(vafa): Consider randomly setting high temperature to help training.
  one_hot_assignments = tfd.RelaxedOneHotCategorical(
      temperature=temperature,
      logits=-dist).sample(num_samples)
  one_hot_assignments = tf.clip_by_value(one_hot_assignments, 1e-6, 1-1e-6)

  # Approximate density with multinomial distribution.
  q_dist = tfd.Multinomial(total_count=1., logits=-dist)
  neg_q_entropy = q_dist.log_prob(one_hot_assignments)
  if summary:
    tf.summary.histogram("neg_q_entropy_0",
                         tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1),
                                    [-1]))

  # Perform IAF flows
  for flow_num in range(num_iaf_flows):
    with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
      # Pad the one_hot_assignments by zeroing out the first latent dimension
      # and shifting the rest down by one (and removing the last dimension).
      shifted_codes = shift_assignments(one_hot_assignments)
      if use_transformer_for_iaf_parameters:
        unconstrained_scale = iaf_scale_from_transformer(
            shifted_codes,
            vector_quantizer.code_size,
            name=str(flow_num))
      else:
        unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                    name=str(flow_num))
      # Initialize scale bias to be log(e/2 - 1) so initial scale + scale_bias
      # evaluates to 1.
      initial_scale_bias = tf.fill([latent_size, vector_quantizer.num_codes],
                                   INITIAL_SCALE_BIAS)
      scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                   initializer=initial_scale_bias)

    one_hot_assignments, inverse_log_det_jacobian = iaf_flow(
        one_hot_assignments,
        unconstrained_scale,
        scale_bias,
        summary=summary)
    neg_q_entropy += inverse_log_det_jacobian

  if sum_over_latents:
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)
  return one_hot_assignments, neg_q_entropy
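
The key approximation above is scoring relaxed (Gumbel-Softmax) one-hot samples under tfd.Multinomial(total_count=1., logits=-dist). A minimal standalone sketch of just that step, with random distances standing in for the encoder/codebook distances:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tf.random.uniform([4, 8, 16])               # (batch, latent, codes), illustrative
relaxed = tfd.RelaxedOneHotCategorical(temperature=0.5, logits=-dist).sample(1)
relaxed = tf.clip_by_value(relaxed, 1e-6, 1. - 1e-6)

q_dist = tfd.Multinomial(total_count=1., logits=-dist)
neg_q_entropy = q_dist.log_prob(relaxed)           # shape (1, 4, 8)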
Example #15
 def _base_dist(self, n: IntTensorLike, p: TensorLike, *args, **kwargs):
     return tfd.Multinomial(total_count=n, probs=p, *args, **kwargs)
Example #16
def categorical_bottleneck(dist,
                           vector_quantizer,
                           num_iaf_flows=0,
                           average_categorical_samples=True,
                           use_transformer_for_iaf_parameters=False,
                           sum_over_latents=True,
                           num_samples=1,
                           summary=True):
  """Implements soft EM bottleneck using averaged categorical samples.

  Args:
    dist: Distances between encoder outputs and codebook entries. Negative
      distances are used as categorical logits. A float Tensor of shape
      [batch_size, latent_size, code_size].
    vector_quantizer: An instance of the VectorQuantizer class.
    num_iaf_flows: Number of inverse autoregressive flows.
    average_categorical_samples: Whether to take the average of `num_samples`
      categorical samples as in Roy et al. or to use `num_samples` categorical
      samples to approximate the gradient.
    use_transformer_for_iaf_parameters: Whether to use a Transformer instead of
      a lower-triangular mat-mul for generating IAF parameters.
    sum_over_latents: Whether to sum over latent dimension when computing
      entropy.
    num_samples: Number of categorical samples.
    summary: Whether to log entropy histogram.

  Returns:
    one_hot_assignments: Simplex-valued assignments sampled from categorical.
    neg_q_entropy: Negative entropy of categorical distribution.
  """
  latent_size = dist.shape[1]
  x_means_idx = tf.multinomial(
      logits=tf.reshape(-dist, [-1, FLAGS.num_codes]),
      num_samples=num_samples)
  one_hot_assignments = tf.one_hot(x_means_idx,
                                   depth=vector_quantizer.num_codes)
  one_hot_assignments = tf.reshape(
      one_hot_assignments,
      [-1, latent_size, num_samples, vector_quantizer.num_codes])
  if average_categorical_samples:
    summed_assignments = tf.reduce_sum(one_hot_assignments, axis=2)
    averaged_assignments = tf.reduce_mean(one_hot_assignments, axis=2)
    entropy_dist = tfd.Multinomial(
        total_count=tf.cast(num_samples, tf.float32),
        logits=-dist)
    neg_q_entropy = entropy_dist.log_prob(summed_assignments)
    one_hot_assignments = averaged_assignments[tf.newaxis, Ellipsis]
  else:
    one_hot_assignments = tf.transpose(one_hot_assignments, [2, 0, 1, 3])
    entropy_dist = tfd.OneHotCategorical(logits=-dist, dtype=tf.float32)
    neg_q_entropy = entropy_dist.log_prob(one_hot_assignments)

  if summary:
    tf.summary.histogram("neg_q_entropy_0",
                         tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1),
                                    [-1]))
  # Perform straight-through.
  class_probs = tf.nn.softmax(-dist[tf.newaxis, Ellipsis])
  one_hot_assignments = class_probs + tf.stop_gradient(
      one_hot_assignments - class_probs)

  # Perform IAF flows
  for flow_num in range(num_iaf_flows):
    with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
      # Pad the one_hot_assignments by zeroing out the first latent dimension
      # and shifting the rest down by one (and removing the last dimension).
      shifted_codes = shift_assignments(one_hot_assignments)
      if use_transformer_for_iaf_parameters:
        unconstrained_scale = iaf_scale_from_transformer(
            shifted_codes,
            vector_quantizer.code_size,
            name=str(flow_num))
      else:
        unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                    name=str(flow_num))
      # Initialize scale bias to be log(e/2 - 1) so initial scale + scale_bias
      # is 1.
      initial_scale_bias = tf.fill([latent_size, vector_quantizer.num_codes],
                                   INITIAL_SCALE_BIAS)
      scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                   initializer=initial_scale_bias)

    # Since categorical is discrete, we don't need to add inverse log
    # determinant.
    one_hot_assignments, _ = iaf_flow(one_hot_assignments,
                                      unconstrained_scale,
                                      scale_bias,
                                      summary=summary)

  if sum_over_latents:
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)

  return one_hot_assignments, neg_q_entropy
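
In the soft-EM branch above, num_samples one-hot draws are summed and the sum is scored under a Multinomial with total_count=num_samples. A small standalone sketch of that scoring step (shapes and logits are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_samples = 4
neg_dist = tf.zeros([2, 5, 10])                    # -dist: (batch, latent, codes)
one_hot = tfd.OneHotCategorical(logits=neg_dist, dtype=tf.float32).sample(num_samples)
summed = tf.reduce_sum(one_hot, axis=0)            # per-code counts, each row sums to 4

entropy_dist = tfd.Multinomial(total_count=float(num_samples), logits=neg_dist)
neg_q_entropy = entropy_dist.log_prob(summed)      # shape (2, 5)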
Example #17
 def ident(self):
     return tfd.Multinomial(total_count=1,
                            probs=tf.math.softmax(self.gamma))
Example #18
 def _init_distribution(conditions, **kwargs):
     total_count, probs = conditions["total_count"], conditions["probs"]
     return tfd.Multinomial(total_count=total_count, probs=probs, **kwargs)
Example #19
File: XClone.py Project: huangyh09/xclone
 def Z(self):
     """Variational posterior for cell assignment"""
     return tfd.Multinomial(total_count=1, logits=self.cell_logit)