Example #1
    def call(self, inputs, training=None):
        logits = self.logits(inputs)
        q = distributions.OneHotCategorical(
            logits=logits,
            dtype=tf.float32,
        )
        if self.beta > 0.0:
            kld = q.kl_divergence(
                distributions.OneHotCategorical(
                    logits=tf.zeros_like(logits),
                    dtype=tf.float32,
                ))
            self.add_loss(tf.reduce_mean(kld))
        return q
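A minimal check (not part of the original example) of the KL term added above: against a uniform prior, the divergence equals log(K) minus the entropy of q.

import tensorflow as tf
import tensorflow_probability as tfp
distributions = tfp.distributions

logits = tf.constant([[2.0, 0.5, -1.0]])
q = distributions.OneHotCategorical(logits=logits, dtype=tf.float32)
prior = distributions.OneHotCategorical(logits=tf.zeros_like(logits), dtype=tf.float32)
print(float(q.kl_divergence(prior)[0]))          # KL(q || uniform)
print(float(tf.math.log(3.0) - q.entropy()[0]))  # same value: log(K) - H(q)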
Example #2
def sample_z(x, mu, hyperparams):
    # assignment probabilities proportional to the likelihood of x under each cluster
    tau2 = np.full(shape=mu.shape, fill_value=hyperparams['tau2'], dtype='float32')
    probability = tfd.Normal(loc=mu, scale=tau2).prob(x).numpy()
    if sum(probability) > 0: probability = probability/sum(probability)
    z = tfd.OneHotCategorical(probs=probability).sample().numpy()
    
    return z, probability
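A hypothetical usage sketch (values made up) for the helper above: draw a one-hot cluster assignment for one observation given three candidate cluster means.

import numpy as np
mu = np.array([-2.0, 0.0, 2.0], dtype='float32')
z, probability = sample_z(x=np.float32(0.3), mu=mu, hyperparams={'tau2': 1.0})
print(z, probability)  # e.g. [0 1 0] plus the normalized responsibilities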
Example #3
    def __init__(self, mixing_distribution, components, **kwargs):
        super().__init__(mixing_distribution, components, **kwargs)

        self._components_probs = tf.constant(
            tf.stack([component.probs for component in self.components], -2))
        self._components_distribution = tfd.OneHotCategorical(
            probs=self._components_probs, dtype=self.dtype)
        self.uniform_mask = False
Example #4
    def _compute_kl_loss(self, post_probs, prior_probs):
        """ Compute KL divergence between two OnehotCategorical Distributions

        Notes:
                KL[ Q(z_post) || P(z_prior) ]
                    Q(z_post) := Q(z | h, o)
                    P(z_prior) := P(z | h)

        Scratch Impl.:
                qlogq = post_probs * tf.math.log(post_probs)
                qlogp = post_probs * tf.math.log(prior_probs)
                kl_div = tf.reduce_sum(qlogq - qlogp, [1, 2])

        Inputs:
            prior_probs (L, B, latent_dim, n_atoms)
            post_probs (L, B, latent_dim, n_atoms)
        """

        #: Add small value to prevent inf kl
        post_probs += 1e-5
        prior_probs += 1e-5

        #: KL Balancing: See 2.2 BEHAVIOR LEARNING Algorithm 2
        kl_div1 = tfd.kl_divergence(
            tfd.Independent(
                tfd.OneHotCategorical(probs=tf.stop_gradient(post_probs)),
                reinterpreted_batch_ndims=1),
            tfd.Independent(tfd.OneHotCategorical(probs=prior_probs),
                            reinterpreted_batch_ndims=1))

        kl_div2 = tfd.kl_divergence(
            tfd.Independent(tfd.OneHotCategorical(probs=post_probs),
                            reinterpreted_batch_ndims=1),
            tfd.Independent(
                tfd.OneHotCategorical(probs=tf.stop_gradient(prior_probs)),
                reinterpreted_batch_ndims=1))

        alpha = self.config.kl_alpha

        kl_loss = alpha * kl_div1 + (1. - alpha) * kl_div2

        #: Batch mean
        kl_loss = tf.reduce_mean(kl_loss)

        return kl_loss
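A hedged sketch (shapes made up) confirming that the "Scratch Impl." in the docstring matches tfd.kl_divergence for Independent(OneHotCategorical) factors.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

L, B, latent_dim, n_atoms = 2, 3, 4, 5
post_probs = tf.nn.softmax(tf.random.normal([L, B, latent_dim, n_atoms]))
prior_probs = tf.nn.softmax(tf.random.normal([L, B, latent_dim, n_atoms]))

kl_tfp = tfd.kl_divergence(
    tfd.Independent(tfd.OneHotCategorical(probs=post_probs),
                    reinterpreted_batch_ndims=1),
    tfd.Independent(tfd.OneHotCategorical(probs=prior_probs),
                    reinterpreted_batch_ndims=1))
kl_manual = tf.reduce_sum(
    post_probs * (tf.math.log(post_probs) - tf.math.log(prior_probs)),
    axis=[-1, -2])
print(float(tf.reduce_max(tf.abs(kl_tfp - kl_manual))))  # ~0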
Example #5
    def model_fn(self):
        # regression in latent space
        w = yield JDCRoot(
            Independent(
                tfd.Normal(loc=tf.zeros([self.num_factors, self.k]),
                           scale=tf.fill([self.num_factors, self.k], 10.0))))

        z_scale = yield JDCRoot(
            Independent(
                tfd.HalfCauchy(loc=tf.zeros([self.num_factors, self.k]),
                               scale=1.0)))

        F_test = yield JDCRoot(
            Independent(
                tfd.OneHotCategorical(logits=tf.zeros([
                    self.num_testing_samples, self.num_factors -
                    self.num_confounders
                ]))))

        F_full = tf.concat([tf.expand_dims(self.F, 0), F_test], axis=-2)

        z = yield Independent(
            tfd.Normal(loc=tf.matmul(F_full, w),
                       scale=tf.matmul(F_full, z_scale)))

        x_bias = yield JDCRoot(
            Independent(
                tfd.Normal(loc=tf.fill([self.num_features],
                                       np.float32(self.x_bias_loc0)),
                           scale=np.float32(self.x_bias_scale0))))

        # decoded log-expression space
        x_loc = x_bias + self.decoder(z) - self.sample_scales

        x_scale_concentration_c = yield JDCRoot(
            Independent(
                tfd.HalfCauchy(loc=tf.zeros([self.kernel_regression_degree]),
                               scale=1.0)))

        x_scale_mode_c = yield JDCRoot(
            Independent(
                tfd.HalfCauchy(loc=tf.zeros([self.kernel_regression_degree]),
                               scale=1.0)))

        weights = kernel_regression_weights(self.kernel_regression_bandwidth,
                                            x_bias, self.x_scale_hinges)

        x_scale = yield Independent(
            mean_variance_model(weights, x_scale_concentration_c,
                                x_scale_mode_c))

        # log expression distribution
        x = yield Independent(tfd.StudentT(df=1.0, loc=x_loc, scale=x_scale))

        if not self.use_point_estimates:
            rnaseq_reads = yield tfd.Independent(
                rnaseq_approx_likelihood_from_vars(self.vars, x))
Example #6
    def __call__(self):
        """Get the distribution object from the backend"""
        if get_backend() == 'pytorch':
            import torch.distributions as tod
            return tod.one_hot_categorical.OneHotCategorical(
                logits=self['logits'], probs=self['probs'])
        else:
            from tensorflow_probability import distributions as tfd
            return tfd.OneHotCategorical(logits=self['logits'],
                                         probs=self['probs'])
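A hedged note (not from the original source): both backends expect exactly one of `logits`/`probs` to be non-None, so the other entry in self is assumed to be None.

from tensorflow_probability import distributions as tfd
d = tfd.OneHotCategorical(logits=[0.0, 1.0, 2.0], probs=None)   # valid
# tfd.OneHotCategorical(logits=[0.0], probs=[1.0]) would raise a ValueError
print(d.sample())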
Example #7
    def __init__(self, N, K, B, temperature=1.0, **kwargs):
        super().__init__(**kwargs)
        dtype = self.dtype

        probs = tf.one_hot(np.random.choice(np.arange(K), (N, B)),
                           K,
                           dtype=dtype)  # delta probs
        self.components = tfd.OneHotCategorical(probs=probs,
                                                dtype=dtype)  # N x B x K

        self.logits = tf.Variable(tf.random.normal((N, B, K), dtype=dtype),
                                  name="logits")
        self.temperature = temperature
Example #8
    def decoder(topics: tf.Tensor) -> tfd.OneHotCategorical:
        """
        Map Tensor to a OneHotCategorical instance.

        :param topics: Tensor containing topic values
        :return: OneHotCategorical of topics
        """
        word_probabilities = tf.matmul(topics, topics_words)

        # The observations are bag of words and therefore not one-hot.
        # However, log_prob of OneHotCategorical computes the probability correctly in this case.
        return tfd.OneHotCategorical(probs=word_probabilities,
                                     name="bag_of_words")
Example #9
def gmm(nb_clusters, hyperparams, batch_size):
    alpha = np.full(shape=nb_clusters, fill_value=hyperparams['alpha'], dtype='float32')
    
    theta = tfd.Dirichlet(concentration=alpha).sample() # assignments probability
    
    mu = tfd.Normal(loc=hyperparams['mu0'], scale=hyperparams['sigma0']).sample(nb_clusters) # centroids
    
    z = tfd.OneHotCategorical(probs=theta).sample(batch_size) # assignment indicators
    assignments = tf.argmax(z.numpy(), axis=1)
    
    means = tf.gather(mu, assignments) # mapping 
    stds = tf.fill(dims=[batch_size], value=hyperparams['tau2'])
    x = tfd.Normal(loc=means, scale=stds).sample()

    return mu.numpy(), theta.numpy(), assignments.numpy(), x.numpy()
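A hypothetical usage sketch of gmm (hyperparameter values made up):

hyperparams = {'alpha': 1.0, 'mu0': 0.0, 'sigma0': 5.0, 'tau2': 1.0}
mu, theta, assignments, x = gmm(nb_clusters=3, hyperparams=hyperparams, batch_size=8)
print(mu.shape, theta.shape, assignments.shape, x.shape)  # (3,) (3,) (8,) (8,)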
Example #10
def test_forward_mean():

    a0 = np.array([0.9, 0.08, 0.02])
    a = np.array([[0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.4, 0.4, 0.2]])
    e = np.array([[0.99, 0.01], [0.01, 0.99], [0.5, 0.5]])

    model = tfpd.HiddenMarkovModel(
        tfpd.Categorical(logits=tf.math.log(tf.convert_to_tensor(a0))),
        tfpd.Categorical(logits=tf.math.log(tf.convert_to_tensor(a))),
        tfpd.OneHotCategorical(logits=tf.math.log(tf.convert_to_tensor(e))), 5)

    tst_mean = model.mean()

    chk_mean = mue.hmm_mean(model, 5)

    assert np.allclose(tst_mean.numpy(), chk_mean.numpy())
Example #11
def get_prior(desired_shape, prior_type="normal"):
    """
    Args:
        size: int, the event size
        prior_type: string, the type of prior
    Returns: a prior distribution
    """
    if prior_type is "normal":
        shapes = tfd.Normal.param_shapes(desired_shape)
        loc, scale = tf.zeros(shapes["loc"]), tf.ones(shapes["scale"])
        return tfd.Normal(loc, scale), shapes
    elif "categorical" in prior_type:
        n = get_size(desired_shape)
        probs = [1 / n for _ in range(n)]
        if "one_hot" in prior_type:
            return tfd.OneHotCategorical(probs=probs), n
        else:
            return tfd.Categorical(probs=probs), n
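A hedged usage sketch (it assumes the unshown get_size helper returns the flattened integer size):

prior, n = get_prior(desired_shape=4, prior_type="one_hot_categorical")
print(n, prior.sample())  # a uniform one-hot draw over 4 categories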
Example #12
    def sample_z_prior(self, h):

        x = self.dense_z_prior1(h)
        logits = self.dense_z_prior2(x)
        logits = tf.reshape(logits,
                            [logits.shape[0], self.latent_dim, self.n_atoms])

        z_probs = tf.nn.softmax(logits, axis=2)

        #: batch_shape=[batch_size] event_shape=[32, 32]
        dist = tfd.Independent(tfd.OneHotCategorical(probs=z_probs),
                               reinterpreted_batch_ndims=1)

        z = tf.cast(dist.sample(), tf.float32)

        #: Straight-through gradient trick for the OneHotCategorical sample
        z = z + z_probs - tf.stop_gradient(z_probs)

        return z, z_probs
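A minimal sketch (shapes made up) of the straight-through estimator used above: the sampled one-hot z keeps the value of the sample but carries the gradients of z_probs.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

logits = tf.Variable(tf.random.normal([1, 4]))
with tf.GradientTape() as tape:
    z_probs = tf.nn.softmax(logits)
    z = tf.cast(tfd.OneHotCategorical(probs=z_probs).sample(), tf.float32)
    z = z + z_probs - tf.stop_gradient(z_probs)  # value == sample, gradient == probs
    loss = tf.reduce_sum(z * tf.constant([[1.0, 2.0, 3.0, 4.0]]))
print(tape.gradient(loss, logits))  # non-None thanks to the straight-through term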
Example #13
def encode(x,
           uln0,
           rln0,
           lln0,
           latent_length,
           latent_alphabet_size,
           alphabet_size,
           padded_data_length,
           transfer_mats,
           dtype=tf.float64,
           eps=1e-32):
    """First layer of encoder, using the MuE mean."""

    # Set initial sequence (replace -inf from log(0) with a large negative number)
    vxln = tf.maximum(tf.math.log(x), -1e32)

    # Set insert biases to uniform distribution.
    vcln = -np.log(alphabet_size) * tf.ones_like(vxln)

    # Set deletion and insertion parameters.
    uln = tf.ones((padded_data_length, 2),
                  dtype=dtype) * (uln0 - tf.reduce_logsumexp(uln0))[None, :]
    rln = tf.ones((padded_data_length, 2),
                  dtype=dtype) * (rln0 - tf.reduce_logsumexp(rln0))[None, :]
    lln = lln0 - tf.reduce_logsumexp(lln0, axis=1, keepdims=True)

    # Build HiddenMarkovModel, with one-hot encoded output.
    a0_enc, a_enc, e_enc = make_hmm_params(vxln,
                                           vcln,
                                           uln,
                                           rln,
                                           lln,
                                           transfer_mats,
                                           eps=eps,
                                           dtype=dtype)

    hmm_enc = tfpd.HiddenMarkovModel(tfpd.Categorical(logits=a0_enc),
                                     tfpd.Categorical(logits=a_enc),
                                     tfpd.OneHotCategorical(logits=e_enc),
                                     latent_length)

    return hmm_mean(hmm_enc, latent_length)
Example #14
def RegressMuE(z, latent_dims, latent_length, latent_alphabet_size,
               alphabet_size, seq_len, transfer_mats,
               bt_scale, b0_scale, u_conc, r_conc, l_conc, dtype=tf.float32):

    """Regress MuE model."""

    # Factors.
    bt = Normal(0., bt_scale, sample_shape=[2, latent_dims, latent_length+1,
                                            latent_alphabet_size], name="bt")
    # Offset.
    b0 = Normal(0., b0_scale, sample_shape=[2, latent_length+1,
                                            latent_alphabet_size], name="b0")

    # Ancestral sequence.
    vxln = tf.einsum('j,jkl->kl', z, bt[0, :, :, :]) + b0[0, :, :]
    # Insert biases.
    vcln = tf.einsum('j,jkl->kl', z, bt[1, :, :, :]) + b0[1, :, :]

    # Assemble priors -- in this version, we use a Dirichlet.
    uc, rc, lc = get_prior_conc(
                    latent_length, latent_alphabet_size, alphabet_size,
                    u_conc, r_conc, l_conc, dtype=dtype)
    # Deletion probability.
    u = Dirichlet(uc, name="u")
    # Insertion probability.
    r = Dirichlet(rc, name="r")
    # Substitution probability.
    l = Dirichlet(lc, name="l")

    # Generate data from the MuE.
    a0, a, e = mue.make_hmm_params(
            vxln - tf.reduce_logsumexp(vxln, axis=1, keepdims=True),
            vcln - tf.reduce_logsumexp(vcln, axis=1, keepdims=True),
            tf.math.log(u), tf.math.log(r), tf.math.log(l),
            transfer_mats, eps=eps, dtype=dtype)
    x = HiddenMarkovModel(
            tfpd.Categorical(logits=a0), tfpd.Categorical(logits=a),
            tfpd.OneHotCategorical(logits=e), seq_len, name="x")

    return x
Example #15
    def design_matrix_model_fn(self):
        F_test = yield JDCRoot(
            Independent(
                tfd.OneHotCategorical(logits=tf.zeros([
                    self.num_testing_samples, self.num_factors -
                    self.num_confounders
                ]))))

        F_full = tf.concat([tf.expand_dims(self.F, 0), F_test], axis=-2)

        # F_confounders = yield JDCRoot(Independent(tfd.Normal(
        #     loc=tf.zeros([self.num_samples, self.num_confounders]),
        #     scale=1.0)))

        # F_full = tf.concat([F_full, F_confounders], axis=-1)

        if self.confounders is not None:
            F_full = tf.concat(
                [F_full, tf.expand_dims(self.confounders, 0)], axis=-1)

        F_full = tf.matmul(F_full, self.F_mask)

        return F_full
Example #16
def main(argv):
  del argv  # unused
  FLAGS.activation = getattr(tf.nn, FLAGS.activation)
  if tf.gfile.Exists(FLAGS.model_dir):
    tf.logging.warn("Deleting old log directory at {}".format(FLAGS.model_dir))
    tf.gfile.DeleteRecursively(FLAGS.model_dir)
  tf.gfile.MakeDirs(FLAGS.model_dir)

  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # TODO(b/113163167): Speed up and tune hyperparameters for Bernoulli MNIST.
    (images, _, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         FLAGS.data_dir, FLAGS.batch_size, heldout_size=10000,
         mnist_type=FLAGS.mnist_type)

    encoder = Encoder(FLAGS.base_depth,
                      FLAGS.activation,
                      FLAGS.latent_size,
                      FLAGS.code_size)
    decoder = Decoder(FLAGS.base_depth,
                      FLAGS.activation,
                      FLAGS.latent_size * FLAGS.code_size,
                      IMAGE_SHAPE)
    vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

    codes = encoder(images)
    nearest_codebook_entries, one_hot_assignments, dist = vector_quantizer(
        codes)
    if FLAGS.bottleneck_type == "deterministic":
      one_hot_assignments = one_hot_assignments[tf.newaxis, Ellipsis]
      neg_q_entropy = 0.
      # Perform straight-through.
      class_probs = tf.nn.softmax(-dist[tf.newaxis, Ellipsis])
      one_hot_assignments = class_probs + tf.stop_gradient(
          one_hot_assignments - class_probs)
    elif FLAGS.bottleneck_type == "categorical":
      one_hot_assignments, neg_q_entropy = (
          categorical_bottleneck(dist,
                                 vector_quantizer,
                                 FLAGS.num_iaf_flows,
                                 FLAGS.average_categorical_samples,
                                 FLAGS.use_transformer_for_iaf_parameters,
                                 FLAGS.sum_over_latents,
                                 FLAGS.num_samples))
    elif FLAGS.bottleneck_type == "gumbel_softmax":
      one_hot_assignments, neg_q_entropy = (
          gumbel_softmax_bottleneck(dist,
                                    vector_quantizer,
                                    FLAGS.temperature,
                                    FLAGS.num_iaf_flows,
                                    FLAGS.use_transformer_for_iaf_parameters,
                                    FLAGS.num_samples,
                                    FLAGS.sum_over_latents,
                                    summary=True))
    else:
      raise ValueError("Unknown bottleneck type.")

    bottleneck_output = tf.reduce_sum(
        one_hot_assignments[Ellipsis, tf.newaxis] *
        tf.reshape(
            vector_quantizer.codebook,
            [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=3)

    decoder_distribution = decoder(bottleneck_output)
    reconstructed_images = decoder_distribution.mean()[0]  # get first sample

    reconstruction_loss = -tf.reduce_mean(decoder_distribution.log_prob(images))
    commitment_loss = tf.reduce_mean(
        tf.square(codes[tf.newaxis, Ellipsis] -
                  tf.stop_gradient(nearest_codebook_entries)))
    commitment_loss = add_ema_control_dependencies(
        vector_quantizer,
        tf.reduce_mean(one_hot_assignments, axis=0),  # reduce mean over samples
        codes,
        commitment_loss,
        FLAGS.decay)

    if FLAGS.use_autoregressive_prior:
      prior_fn = make_transformer_prior(FLAGS.num_codes, FLAGS.code_size)
    else:
      prior_fn = make_uniform_prior()
    prior_inputs = one_hot_assignments
    if FLAGS.stop_gradient_for_prior:
      prior_inputs = tf.stop_gradient(one_hot_assignments)
    prior_loss = make_prior_loss(prior_fn,
                                 prior_inputs,
                                 sum_over_latents=FLAGS.sum_over_latents)

    loss = (reconstruction_loss + FLAGS.beta * commitment_loss + prior_loss +
            FLAGS.entropy_scale * neg_q_entropy)

    if FLAGS.bottleneck_type == "deterministic":
      if not FLAGS.sum_over_latents:
        prior_loss = prior_loss * FLAGS.latent_size
      heldout_prior_loss = prior_loss
      heldout_reconstruction_loss = reconstruction_loss
      heldout_neg_q_entropy = tf.constant(0.)
    if FLAGS.bottleneck_type == "categorical":
      # To accurately evaluate heldout NLL, we need to sum over latent dimension
      # and use a single sample for the categorical (and not multinomial) prior.
      (heldout_one_hot_assignments,
       heldout_neg_q_entropy) = categorical_bottleneck(
           dist,
           vector_quantizer,
           FLAGS.num_iaf_flows,
           FLAGS.average_categorical_samples,
           FLAGS.use_transformer_for_iaf_parameters,
           sum_over_latents=True,
           num_samples=1,
           summary=False)
      heldout_bottleneck_output = tf.reduce_sum(
          heldout_one_hot_assignments[Ellipsis, tf.newaxis] *
          tf.reshape(
              vector_quantizer.codebook,
              [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
          axis=3)
      heldout_prior_loss = make_prior_loss(
          prior_fn,
          heldout_one_hot_assignments,
          sum_over_latents=True)
      heldout_decoder_distribution = decoder(heldout_bottleneck_output)
      heldout_reconstruction_loss = -tf.reduce_mean(
          heldout_decoder_distribution.log_prob(images))
    elif FLAGS.bottleneck_type == "gumbel_softmax":
      num_test_samples = 1
      heldout_q_dist = tfd.OneHotCategorical(logits=-dist, dtype=tf.float32)
      heldout_one_hot_assignments = heldout_q_dist.sample(num_test_samples)
      heldout_neg_q_entropy = heldout_q_dist.log_prob(
          heldout_one_hot_assignments)
      for flow_num in range(FLAGS.num_iaf_flows):
        with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
          shifted_codes = shift_assignments(heldout_one_hot_assignments)
          scale_bias = tf.get_variable("scale_bias_" + str(flow_num))
          if FLAGS.use_transformer_for_iaf_parameters:
            unconstrained_scale = iaf_scale_from_transformer(
                shifted_codes,
                FLAGS.code_size,
                name=str(flow_num))
          else:
            unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                        name=str(flow_num))
        # Don't need to add inverse log determinant jacobian since samples are
        # discrete (no change in volume when bijecting discrete variables).
        heldout_one_hot_assignments, _ = iaf_flow(heldout_one_hot_assignments,
                                                  unconstrained_scale,
                                                  scale_bias,
                                                  summary=False)
      heldout_neg_q_entropy = tf.reduce_sum(
          tf.reshape(heldout_neg_q_entropy, [-1, FLAGS.latent_size]), axis=1)
      heldout_neg_q_entropy = tf.reduce_mean(heldout_neg_q_entropy)
      heldout_nearest_codebook_entries = tf.reduce_sum(
          heldout_one_hot_assignments[Ellipsis, tf.newaxis] *
          tf.reshape(
              vector_quantizer.codebook,
              [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
          axis=3)
      # We still evaluate the prior on the transformed samples. But in order
      # for this categorical distribution to be valid, we have to binarize.
      heldout_one_hot_assignments = tf.one_hot(
          tf.argmax(heldout_one_hot_assignments, axis=-1),
          depth=FLAGS.num_codes)

      heldout_prior_loss = make_prior_loss(
          prior_fn,
          heldout_one_hot_assignments,
          sum_over_latents=True)
      heldout_decoder_distribution = decoder(heldout_nearest_codebook_entries)
      heldout_reconstruction_loss = -tf.reduce_mean(
          heldout_decoder_distribution.log_prob(images))

    marginal_nll = (heldout_prior_loss + heldout_reconstruction_loss +
                    heldout_neg_q_entropy)

    tf.summary.scalar("losses/total_loss", loss)
    tf.summary.scalar("losses/neg_q_entropy_loss",
                      neg_q_entropy * FLAGS.entropy_scale)
    tf.summary.scalar("losses/reconstruction_loss", reconstruction_loss)
    tf.summary.scalar("losses/prior_loss", prior_loss)
    tf.summary.scalar("losses/commitment_loss", FLAGS.beta * commitment_loss)

    tf.summary.scalar("heldout/neg_q_entropy_loss",
                      heldout_neg_q_entropy,
                      collections=["heldout"])
    tf.summary.scalar("heldout/reconstruction_loss",
                      heldout_reconstruction_loss,
                      collections=["heldout"])
    tf.summary.scalar("heldout/prior_loss",
                      heldout_prior_loss,
                      collections=["heldout"])

    # Decode 10 samples from prior for visualization.
    if FLAGS.use_autoregressive_prior:
      assignments = tf.zeros([10, 1, FLAGS.num_codes])
      # Decode autoregressively.
      for d in range(FLAGS.latent_size):
        logits = prior_fn(assignments).logits_parameter()
        latent_dim_logit = logits[0, :, tf.newaxis, d, :]
        sample = tfd.OneHotCategorical(
            logits=latent_dim_logit, dtype=tf.float32).sample()
        assignments = tf.concat([assignments, sample], axis=1)
      assignments = assignments[:, 1:, :]
    else:
      logits = tf.zeros([10, FLAGS.latent_size, FLAGS.num_codes])
      assignments = tf.reduce_mean(tfd.OneHotCategorical(
          logits=logits, dtype=tf.float32).sample(1), axis=0)

    prior_samples = tf.reduce_sum(
        assignments[Ellipsis, tf.newaxis] *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=2)
    prior_samples = prior_samples[tf.newaxis, Ellipsis]
    decoded_distribution_given_random_prior = decoder(prior_samples)
    random_images = decoded_distribution_given_random_prior.mean()[0]

    # Save summaries.
    tf.summary.image("train_inputs",
                     tf.cast(images, tf.float32),
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("train_reconstructions",
                     reconstructed_images,
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("train_prior_samples",
                     tf.cast(random_images, tf.float32),
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("heldout_inputs",
                     tf.cast(images, tf.float32),
                     max_outputs=10,
                     collections=["heldout_image"])
    tf.summary.image("heldout_reconstructions",
                     reconstructed_images,
                     max_outputs=10,
                     collections=["heldout_image"])
    tf.summary.scalar("heldout/marginal_loss",
                      marginal_nll,
                      collections=["heldout"])

    # Perform inference by minimizing the loss function.
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    if FLAGS.num_iaf_flows > 0:
      encoder_variables = encoder.variables
      iaf_variables = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope="iaf_variables")
      grads_and_vars = optimizer.compute_gradients(loss)
      grads_and_vars_except_encoder = [
          x for x in grads_and_vars if x[1] not in encoder_variables]
      grads_and_vars_except_iaf = [
          x for x in grads_and_vars if x[1] not in iaf_variables]

      def train_op_except_iaf():
        return optimizer.apply_gradients(
            grads_and_vars_except_iaf,
            global_step=global_step)

      def train_op_except_encoder():
        return optimizer.apply_gradients(
            grads_and_vars_except_encoder,
            global_step=global_step)

      def train_op_all():
        return optimizer.apply_gradients(
            grads_and_vars,
            global_step=global_step)

      if FLAGS.stop_training_encoder_after_startup:
        after_startup_train_op = train_op_except_encoder
      else:
        after_startup_train_op = train_op_all

      train_op = tf.cond(
          global_step < FLAGS.iaf_startup_steps,
          true_fn=train_op_except_iaf,
          false_fn=after_startup_train_op)
    else:
      train_op = optimizer.minimize(loss, global_step=global_step)

    summary = tf.summary.merge_all()
    heldout_summary = tf.summary.merge_all(key="heldout")
    train_image_summary = tf.summary.merge_all(key="train_image")
    heldout_image_summary = tf.summary.merge_all(key="heldout_image")
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session(FLAGS.master) as sess:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
      sess.run(init)

      # Run the training loop.
      train_handle = sess.run(training_iterator.string_handle())
      heldout_handle = sess.run(heldout_iterator.string_handle())
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={handle: train_handle})
        duration = time.time() - start_time
        if step % 100 == 0:
          marginal_nll_val = sess.run(marginal_nll,
                                      feed_dict={handle: heldout_handle})
          print("Step: {:>3d} Training Loss: {:.3f} Heldout NLL: {:.3f} "
                "({:.3f} sec)".format(step, loss_value, marginal_nll_val,
                                      duration))

          # Update the events file.
          summary_str = sess.run(summary, feed_dict={handle: train_handle})
          summary_writer.add_summary(summary_str, step)
          summary_writer.flush()

          summary_str_heldout = sess.run(heldout_summary,
                                         feed_dict={handle: heldout_handle})
          summary_writer.add_summary(summary_str_heldout, step)
          summary_writer.flush()

        # Periodically save a checkpoint and visualize model progress.
        if (step + 1) % FLAGS.viz_steps == 0:
          summary_str_train_images = sess.run(
              train_image_summary,
              feed_dict={handle: train_handle})
          summary_str_heldout_images = sess.run(
              heldout_image_summary,
              feed_dict={handle: heldout_handle})
          summary_writer.add_summary(summary_str_train_images, step)
          summary_writer.add_summary(summary_str_heldout_images, step)
          checkpoint_file = os.path.join(FLAGS.model_dir, "model.ckpt")
          saver.save(sess, checkpoint_file, global_step=step)
Example #17
def categorical_bottleneck(dist,
                           vector_quantizer,
                           num_iaf_flows=0,
                           average_categorical_samples=True,
                           use_transformer_for_iaf_parameters=False,
                           sum_over_latents=True,
                           num_samples=1,
                           summary=True):
  """Implements soft EM bottleneck using averaged categorical samples.

  Args:
    dist: Distances between encoder outputs and codebook entries. Negative
      distances are used as categorical logits. A float Tensor of shape
      [batch_size, latent_size, code_size].
    vector_quantizer: An instance of the VectorQuantizer class.
    num_iaf_flows: Number of inverse autoregressive flows.
    average_categorical_samples: Whether to take the average of `num_samples'
      categorical samples as in Roy et al. or to use `num_samples` categorical
      samples to approximate gradient.
    use_transformer_for_iaf_parameters: Whether to use a Transformer instead of
      a lower-triangular mat-mul for generating IAF parameters.
    sum_over_latents: Whether to sum over latent dimension when computing
      entropy.
    num_samples: Number of categorical samples.
    summary: Whether to log entropy histogram.

  Returns:
    one_hot_assignments: Simplex-valued assignments sampled from categorical.
    neg_q_entropy: Negative entropy of categorical distribution.
  """
  latent_size = dist.shape[1]
  x_means_idx = tf.multinomial(
      logits=tf.reshape(-dist, [-1, FLAGS.num_codes]),
      num_samples=num_samples)
  one_hot_assignments = tf.one_hot(x_means_idx,
                                   depth=vector_quantizer.num_codes)
  one_hot_assignments = tf.reshape(
      one_hot_assignments,
      [-1, latent_size, num_samples, vector_quantizer.num_codes])
  if average_categorical_samples:
    summed_assignments = tf.reduce_sum(one_hot_assignments, axis=2)
    averaged_assignments = tf.reduce_mean(one_hot_assignments, axis=2)
    entropy_dist = tfd.Multinomial(
        total_count=tf.cast(num_samples, tf.float32),
        logits=-dist)
    neg_q_entropy = entropy_dist.log_prob(summed_assignments)
    one_hot_assignments = averaged_assignments[tf.newaxis, Ellipsis]
  else:
    one_hot_assignments = tf.transpose(one_hot_assignments, [2, 0, 1, 3])
    entropy_dist = tfd.OneHotCategorical(logits=-dist, dtype=tf.float32)
    neg_q_entropy = entropy_dist.log_prob(one_hot_assignments)

  if summary:
    tf.summary.histogram("neg_q_entropy_0",
                         tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1),
                                    [-1]))
  # Perform straight-through.
  class_probs = tf.nn.softmax(-dist[tf.newaxis, Ellipsis])
  one_hot_assignments = class_probs + tf.stop_gradient(
      one_hot_assignments - class_probs)

  # Perform IAF flows
  for flow_num in range(num_iaf_flows):
    with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
      # Pad the one_hot_assignments by zeroing out the first latent dimension
      # and shifting the rest down by one (and removing the last dimension).
      shifted_codes = shift_assignments(one_hot_assignments)
      if use_transformer_for_iaf_parameters:
        unconstrained_scale = iaf_scale_from_transformer(
            shifted_codes,
            vector_quantizer.code_size,
            name=str(flow_num))
      else:
        unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                    name=str(flow_num))
      # Initialize scale bias to be log(e/2 - 1) so initial scale + scale_bias
      # is 1.
      initial_scale_bias = tf.fill([latent_size, vector_quantizer.num_codes],
                                   INITIAL_SCALE_BIAS)
      scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                   initializer=initial_scale_bias)

    # Since categorical is discrete, we don't need to add inverse log
    # determinant.
    one_hot_assignments, _ = iaf_flow(one_hot_assignments,
                                      unconstrained_scale,
                                      scale_bias,
                                      summary=summary)

  if sum_over_latents:
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)

  return one_hot_assignments, neg_q_entropy
Example #18
def prior_fn(codes):
  logits = tf.zeros_like(codes)
  prior_dist = tfd.OneHotCategorical(logits=logits, dtype=tf.float32)
  return prior_dist
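A hedged usage sketch (the codes below are made up): the uniform prior assigns log(1/num_codes) to any assignment.

import tensorflow as tf
codes = tf.one_hot([[0, 2]], depth=4, dtype=tf.float32)
print(prior_fn(codes).log_prob(codes))  # each entry equals log(1/4)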
Example #19
    def update_actor_critic(self, trajectory, batch_size=512, strategy="PPO"):
        """ Actor-Critic update using PPO & Generalized Advantage Estimator
        """

        #: adv: (L*B, 1)
        targets, weights = self.compute_target(trajectory['state'],
                                               trajectory['reward'],
                                               trajectory['next_state'],
                                               trajectory['discount'])
        #: (H, L*B, ...)
        states = trajectory['state']
        selected_actions = trajectory['action']

        N = weights.shape[0] * weights.shape[1]
        states = tf.reshape(states, [N, -1])
        selected_actions = tf.reshape(selected_actions, [N, -1])
        targets = tf.reshape(targets, [N, -1])
        weights = tf.reshape(weights, [N, -1])
        _, old_action_probs = self.policy(states)
        old_logprobs = tf.math.log(old_action_probs + 1e-5)

        for _ in range(10):

            indices = np.random.choice(N, batch_size)

            _states = tf.gather(states, indices)
            _targets = tf.gather(targets, indices)
            _selected_actions = tf.gather(selected_actions, indices)
            _old_logprobs = tf.gather(old_logprobs, indices)
            _weights = tf.gather(weights, indices)

            #: Update value network
            with tf.GradientTape() as tape1:
                v_pred = self.value(_states)
                advantages = _targets - v_pred
                value_loss = 0.5 * tf.square(advantages)
                discount_value_loss = tf.reduce_mean(value_loss * _weights)

            grads = tape1.gradient(discount_value_loss,
                                   self.value.trainable_variables)
            self.value_optimizer.apply_gradients(
                zip(grads, self.value.trainable_variables))

            #: Update policy network
            if strategy == "VanillaPG":

                with tf.GradientTape() as tape2:
                    _, action_probs = self.policy(_states)
                    action_probs += 1e-5

                    selected_action_logprobs = tf.reduce_sum(
                        _selected_actions * tf.math.log(action_probs),
                        axis=1,
                        keepdims=True)

                    objective = selected_action_logprobs * advantages

                    dist = tfd.Independent(
                        tfd.OneHotCategorical(probs=action_probs),
                        reinterpreted_batch_ndims=0)
                    ent = dist.entropy()

                    policy_loss = objective + self.config.ent_scale * ent[...,
                                                                          None]
                    policy_loss *= -1
                    discounted_policy_loss = tf.reduce_mean(policy_loss *
                                                            _weights)

            elif strategy == "PPO":

                with tf.GradientTape() as tape2:
                    _, action_probs = self.policy(_states)
                    action_probs += 1e-5
                    new_logprobs = tf.math.log(action_probs)

                    ratio = tf.reduce_sum(_selected_actions *
                                          tf.exp(new_logprobs - _old_logprobs),
                                          axis=1,
                                          keepdims=True)
                    ratio_clipped = tf.clip_by_value(ratio, 0.9, 1.1)

                    obj_unclipped = ratio * advantages
                    obj_clipped = ratio_clipped * advantages

                    objective = tf.minimum(obj_unclipped, obj_clipped)

                    dist = tfd.Independent(
                        tfd.OneHotCategorical(probs=action_probs),
                        reinterpreted_batch_ndims=0)
                    ent = dist.entropy()

                    policy_loss = objective + self.config.ent_scale * ent[...,
                                                                          None]
                    policy_loss *= -1
                    discounted_policy_loss = tf.reduce_mean(policy_loss *
                                                            _weights)

            grads = tape2.gradient(discounted_policy_loss,
                                   self.policy.trainable_variables)
            self.policy_optimizer.apply_gradients(
                zip(grads, self.policy.trainable_variables))

        info = {
            "policy_loss": tf.reduce_mean(policy_loss),
            "objective": tf.reduce_mean(objective),
            "actor_entropy": tf.reduce_mean(ent),
            "value_loss": tf.reduce_mean(value_loss),
            "target_0": tf.reduce_mean(_targets),
        }

        return info
Example #20
    def sample(self, feat):
        logits, probs = self(feat)
        dist = tfd.Independent(tfd.OneHotCategorical(probs=probs),
                               reinterpreted_batch_ndims=0)
        actions_onehot = dist.sample()
        return actions_onehot
Example #21
    def construct_model(self, learning_rate=None):

        if learning_rate is None:
            learning_rate = self.learning_rate

        with self.graph.as_default():

            self.sess.close()
            self.sess = tf.compat.v1.InteractiveSession()
            self.sess.as_default()

            self.x = tf.convert_to_tensor(self.rescaled_features, dtype=tf.float32)
            self.y = tf.convert_to_tensor(self.targets, dtype=tf.float32)

            # construct precisness
            self.tau_rescaling = np.zeros((self.num_obs, self.bnn_output_size))
            kernel_ranges      = self.config.kernel_ranges
            for obs_index in range(self.num_obs):
                self.tau_rescaling[obs_index] += kernel_ranges
            self.tau_rescaling = self.tau_rescaling**2

            # construct weight and bias shapes
            activations = [tf.nn.tanh]
            weight_shapes, bias_shapes = [[self.feature_size, self.hidden_shape]], [[self.hidden_shape]]
            for _ in range(1, self.num_layers - 1):
                activations.append(tf.nn.tanh)
                weight_shapes.append([self.hidden_shape, self.hidden_shape])
                bias_shapes.append([self.hidden_shape])
            activations.append(lambda x: x)
            weight_shapes.append([self.hidden_shape, self.bnn_output_size])
            bias_shapes.append([self.bnn_output_size])

            # ---------------
            # construct prior
            # ---------------
            self.prior_layer_outputs = [self.x]
            self.priors = {}
            for layer_index in range(self.num_layers):
                weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
                activation = activations[layer_index]

                weight = tfd.Normal(loc=tf.zeros(weight_shape) + self.weight_loc, scale=tf.zeros(weight_shape) + self.weight_scale)
                bias = tfd.Normal(loc=tf.zeros(bias_shape) + self.bias_loc, scale=tf.zeros(bias_shape) + self.bias_scale)
                self.priors['weight_%d' % layer_index] = weight
                self.priors['bias_%d' % layer_index] = bias

                prior_layer_output = activation(tf.matmul(self.prior_layer_outputs[-1], weight.sample()) + bias.sample())
                self.prior_layer_outputs.append(prior_layer_output)

            self.prior_bnn_output = self.prior_layer_outputs[-1]
            # draw precisions from gamma distribution
            self.prior_tau_normed = tfd.Gamma(
                            12*(self.num_obs/self.frac_feas)**2 + tf.zeros((self.num_obs, self.bnn_output_size)),
                            tf.ones((self.num_obs, self.bnn_output_size)),
                        )
            self.prior_tau        = self.prior_tau_normed.sample() / self.tau_rescaling
            self.prior_scale      = tfd.Deterministic(1. / tf.sqrt(self.prior_tau))

            # -------------------
            # construct posterior
            # -------------------
            self.post_layer_outputs = [self.x]
            self.posteriors = {}
            for layer_index in range(self.num_layers):
                weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
                activation = activations[layer_index]

                weight = tfd.Normal(loc=tf.Variable(tf.random.normal(weight_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(weight_shape))))
                bias = tfd.Normal(loc=tf.Variable(tf.random.normal(bias_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(bias_shape))))

                self.posteriors['weight_%d' % layer_index] = weight
                self.posteriors['bias_%d' % layer_index] = bias

                post_layer_output = activation(tf.matmul(self.post_layer_outputs[-1], weight.sample()) + bias.sample())
                self.post_layer_outputs.append(post_layer_output)

            self.post_bnn_output = self.post_layer_outputs[-1]
            self.post_tau_normed = tfd.Gamma(
                                12*(self.num_obs/self.frac_feas)**2 + tf.Variable(tf.zeros((self.num_obs, self.bnn_output_size))),
                                tf.nn.softplus(tf.Variable(tf.ones((self.num_obs, self.bnn_output_size)))),
                            )
            self.post_tau        = self.post_tau_normed.sample() / self.tau_rescaling
            self.post_sqrt_tau   = tf.sqrt(self.post_tau)
            self.post_scale      = tfd.Deterministic(1. / self.post_sqrt_tau)

            # map bnn output to prediction
            post_kernels = {}
            targets_dict = {}
            inferences = []

            target_element_index = 0
            kernel_element_index = 0

            while kernel_element_index < len(self.config.kernel_names):

                kernel_type = self.config.kernel_types[kernel_element_index]
                kernel_size = self.config.kernel_sizes[kernel_element_index]

                feature_begin, feature_end = target_element_index, target_element_index + 1
                kernel_begin, kernel_end   = kernel_element_index, kernel_element_index + kernel_size

                prior_relevant = self.prior_bnn_output[:, kernel_begin: kernel_end]
                post_relevant  = self.post_bnn_output[:,  kernel_begin: kernel_end]

                if kernel_type == 'continuous':

                    target = self.y[:, kernel_begin: kernel_end]
                    lowers, uppers = self.config.kernel_lowers[kernel_begin: kernel_end], self.config.kernel_uppers[kernel_begin : kernel_end]

                    prior_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(prior_relevant) - 0.1) + lowers
                    post_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(post_relevant) - 0.1) + lowers

                    prior_predict = tfd.Normal(prior_support, self.prior_scale[:, kernel_begin: kernel_end].sample())
                    post_predict = tfd.Normal(post_support,  self.post_scale[:,  kernel_begin: kernel_end].sample())

                    targets_dict[prior_predict] = target
                    post_kernels['param_%d' % target_element_index] = {
                        'loc':       tfd.Deterministic(post_support),
                        'sqrt_prec': tfd.Deterministic(self.post_sqrt_tau[:, kernel_begin: kernel_end]),
                        'scale':     tfd.Deterministic(self.post_scale[:, kernel_begin: kernel_end].sample())}

                    inference = {'pred': post_predict, 'target': target}
                    inferences.append(inference)

                elif kernel_type in ['categorical', 'discrete']:
                    target = tf.cast(self.y[:, kernel_begin: kernel_end], tf.int32)

                    prior_temperature = 0.5 + 10.0 / (self.num_obs / self.frac_feas)
                    #prior_temperature = 1.0
                    post_temperature = prior_temperature

                    prior_support = prior_relevant
                    post_support = post_relevant

                    prior_predict_relaxed = tfd.RelaxedOneHotCategorical(prior_temperature, prior_support)
                    prior_predict = tfd.OneHotCategorical(probs=prior_predict_relaxed.sample())

                    post_predict_relaxed = tfd.RelaxedOneHotCategorical(post_temperature, post_support)
                    post_predict = tfd.OneHotCategorical(probs=post_predict_relaxed.sample())

                    targets_dict[prior_predict] = target
                    post_kernels['param_%d' % target_element_index] = {'probs': post_predict_relaxed}

                    inference = {'pred': post_predict, 'target': target}
                    inferences.append(inference)

                    '''
                        Temperature annealing schedule:
                            - temperature of 100   yields 1e-2 deviation from uniform
                            - temperature of  10   yields 1e-1 deviation from uniform
                            - temperature of   1   yields *almost* perfect agreement with expectation
                            - temperature of   0.1 yields perfect agreement with expectation
                    '''

                else:
                    raise GryffinUnknownSettingsError(f'did not understand kernel type: {kernel_type}')

                target_element_index += 1
                kernel_element_index += kernel_size

            self.post_kernels = post_kernels
            self.targets_dict = targets_dict

            self.loss = 0.
            for inference in inferences:
                self.loss += - tf.reduce_sum(inference['pred'].log_prob(inference['target']))

            self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
            self.train_op = self.optimizer.minimize(self.loss)

            tf.compat.v1.global_variables_initializer().run()
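A hedged sketch (shapes and values made up) of the pattern used for categorical kernels above: sample a relaxed (Gumbel-softmax) simplex point, then use it as the probs of a OneHotCategorical when scoring one-hot targets.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

temperature = 1.0
support = tf.random.normal([5, 3])  # stand-in for the BNN output slice
predict_relaxed = tfd.RelaxedOneHotCategorical(temperature, support)
predict = tfd.OneHotCategorical(probs=predict_relaxed.sample())
target = tf.one_hot([0, 2, 1, 1, 0], depth=3, dtype=tf.float32)
print(float(tf.reduce_sum(predict.log_prob(target))))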
Example #22
    def __init__(self, temp, logits=None, probs=None, dtype=None):
        self._sample_dtype = dtype or prec.global_policy().compute_dtype
        self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
        super().__init__(temp, logits=logits, probs=probs)
Example #23
    def sample_posterior_zs(self, r, nsamples=1000):
        posterior_z = tfd.OneHotCategorical(probs=r, dtype=r.dtype)
        zs = posterior_z.sample(nsamples)
        return posterior_z, zs
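A hedged usage sketch (the responsibilities below are made up) mirroring the method body: r holds per-point cluster responsibilities and zs are one-hot draws from them.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

r = tf.constant([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
posterior_z = tfd.OneHotCategorical(probs=r, dtype=r.dtype)
zs = posterior_z.sample(5)
print(zs.shape)  # (5, 2, 3): 5 one-hot assignments for each of 2 points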