Ejemplo n.º 1
0
                def _loop_body(t, ta_z_prior, ta_z_post, ta_kl):
                    """
                        Loop body; one iteration per trading day `t`.

                        It returns t + 1 together with the updated TensorArrays,
                        so it is presumably the body of a tf.while_loop -- confirm
                        at the call site. Besides its arguments it reads the
                        enclosing-scope tensors x, h_s and y_ (all indexed by t)
                        and `self`.

                        For step t it builds:
                          * a prior over z_t conditioned on x[t], h_s[t] and the
                            previous posterior sample z_{t-1};
                          * a posterior over z_t conditioned on the same inputs
                            plus y_[t];
                          * the KL divergence KL(posterior || prior).

                        Args:
                            t: loop counter; compared with 1 and used as the
                               TensorArray / time index.
                            ta_z_prior: TensorArray receiving the prior sample at t.
                            ta_z_post: TensorArray receiving the posterior sample
                               at t (and read back at t - 1).
                            ta_kl: TensorArray receiving the KL term at t.

                        Returns:
                            (t + 1, ta_z_prior, ta_z_post, ta_kl) with index t written.
                    """
                    # AUTO_REUSE: variables created on the first iteration are
                    # reused by every later iteration (shared weights over time).
                    with tf.variable_scope('iter_body', reuse=tf.AUTO_REUSE):

                        # t == 0: there is no previous posterior yet, so a
                        # tf.random_normal draw (default mean 0, stddev 1) of shape
                        # [batch_size, z_size] stands in for z_{t-1}.
                        init = lambda: tf.random_normal(shape=[self.batch_size, self.z_size], name='z_post_t_1')
                        # t >= 1: reuse the posterior sample written on the
                        # previous iteration, reshaped to [batch_size, z_size].
                        subsequent = lambda: tf.reshape(ta_z_post.read(t-1), [self.batch_size, self.z_size])

                        z_post_t_1 = tf.cond(t >= 1, subsequent, init)

                        # Prior network: inputs exclude the target y_[t].
                        with tf.variable_scope('h_z_prior'):
                            h_z_prior_t = self._linear([x[t], h_s[t], z_post_t_1], self.z_size, 'tanh')
                        with tf.variable_scope('z_prior'):
                            z_prior_t, z_prior_t_pdf = self._z(h_z_prior_t, is_prior=True)

                        # Posterior network: same inputs plus the target y_[t].
                        with tf.variable_scope('h_z_post'):
                            h_z_post_t = self._linear([x[t], h_s[t], y_[t], z_post_t_1], self.z_size, 'tanh')
                        with tf.variable_scope('z_post'):
                            z_post_t, z_post_t_pdf = self._z(h_z_post_t, is_prior=False)

                    # NOTE(review): the KL shape depends on the distribution type
                    # returned by self._z; the original comments disagree
                    # (batch_size * z_size here vs batch_size * 1 at the write
                    # below) -- TODO confirm.
                    kl_t = ds.kl_divergence(z_post_t_pdf, z_prior_t_pdf)

                    ta_z_prior = ta_z_prior.write(t, z_prior_t)  # write: batch_size * z_size (assumed from self._z)
                    ta_z_post = ta_z_post.write(t, z_post_t)  # write: batch_size * z_size (assumed from self._z)
                    ta_kl = ta_kl.write(t, kl_t)  # write: KL for step t (shape: see NOTE above)

                    return t + 1, ta_z_prior, ta_z_post, ta_kl
Ejemplo n.º 2
0
def gumbel_reparmeterization(logits_z, tau, rnd_sample=None,
                             hard=True, eps=1e-9):
    '''
    The gumbel-softmax reparameterization.

    Builds a relaxed one-hot categorical posterior q(z) from `logits_z` and
    the analytic KL from the exact (non-relaxed) categorical to a uniform
    categorical prior.

    Args:
        logits_z: unnormalized log-probabilities; the LAST axis indexes the
            categories (it is the event axis of OneHotCategorical).
        tau: temperature of the RelaxedOneHotCategorical.
        rnd_sample, hard, eps: accepted for interface compatibility; they are
            not used by this function.

    Returns:
        [q_z, kl] where q_z is a StochasticTensor (value = a sample) over the
        RelaxedOneHotCategorical, and kl is KL(OneHotCategorical(logits_z) ||
        uniform) summed over every non-batch axis (returned unchanged when it
        is already rank 1).
    '''
    # Categories live on the last axis; reading axis 1 built a prior of the
    # wrong size for inputs of rank > 2.
    latent_size = logits_z.get_shape().as_list()[-1]

    # Uniform categorical prior over `latent_size` categories.
    p_z = d.OneHotCategorical(probs=tf.constant(1.0/latent_size,
                                                shape=[latent_size]))

    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(d.RelaxedOneHotCategorical(temperature=tau,
                                                             logits=logits_z))
        # Exact categorical with the same logits; used only for the KL.
        q_z_full = st.StochasticTensor(d.OneHotCategorical(logits=logits_z))

    kl = d.kl_divergence(q_z_full.distribution, p_z, allow_nan_stats=False)
    # The KL has one axis fewer than the logits; sum over all axes after the
    # batch axis (the old fixed [1, 2] index was out of range for 3-D logits).
    kl_rank = len(kl.get_shape().as_list())
    if kl_rank > 1:
        return [q_z, tf.reduce_sum(kl, list(range(1, kl_rank)))]
    return [q_z, kl]
Ejemplo n.º 3
0
  def build_loss_and_gradients(self, var_list):
    """Build the alpha-power variational loss and its gradients.

    For each of ``self.n_samples`` Monte Carlo draws:

    * the variational factors q(z) are copied to obtain fresh posterior
      samples, and ``q_log_prob`` accumulates ``log q(z) - log p(z)`` (scaled
      by ``self.scale``), i.e. a Monte Carlo estimate of the KL term;
    * for every observed random variable x, ``p_log_prob`` accumulates
      ``(alpha + 1) / alpha * sum(p(x | z) ** alpha) * N / M`` minus
      ``sum(m ** (alpha + 1) + (1 - m) ** (alpha + 1)) * N / M`` where
      ``m = mean(x | z)``, ``N = self.tot`` and ``M = self.size``.
      NOTE(review): the second term looks like the normalizer of a Bernoulli
      power likelihood -- confirm against the model.

    No analytic KL is built: the KL enters only through ``q_log_prob``.

    Args:
      var_list: variables to differentiate the loss with respect to.

    Returns:
      Tuple ``(loss, grads_and_vars)`` with
      ``loss = -(mean(p_log_prob) - mean(q_log_prob))`` and
      ``grads_and_vars`` a list of ``(gradient, variable)`` pairs.
    """
    cof = tf.constant(self.alpha, tf.float32)
    cof2 = tf.constant(self.alpha + 1, tf.float32)
    M = tf.constant(self.size, tf.float32)
    N = tf.constant(self.tot, tf.float32)

    p_log_prob = [0.0] * self.n_samples
    q_log_prob = [0.0] * self.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(self.n_samples):
      # Form dictionary in order to replace conditioning on prior or
      # observed variable with conditioning on a specific value.
      scope = base_scope + tf.get_default_graph().unique_name("sample")
      dict_swap = {}
      for x, qx in six.iteritems(self.data):
        if isinstance(x, RandomVariable):
          if isinstance(qx, RandomVariable):
            qx_copy = copy(qx, scope=scope)
            dict_swap[x] = qx_copy.value()
          else:
            dict_swap[x] = qx

      for z, qz in six.iteritems(self.latent_vars):
        # Copy q(z) to obtain new set of posterior samples.
        qz_copy = copy(qz, scope=scope)
        dict_swap[z] = qz_copy.value()
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

      # Runs after dict_swap holds every posterior sample, so each prior
      # copy is conditioned on the sampled values: log q(z) - log p(z).
      for z in six.iterkeys(self.latent_vars):
        z_copy = copy(z, dict_swap, scope=scope)
        q_log_prob[s] -= tf.reduce_sum(
            self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

      for x in six.iterkeys(self.data):
        if isinstance(x, RandomVariable):
          x_copy = copy(x, dict_swap, scope=scope)
          # sum(p(x | z) ** alpha), rescaled from the minibatch (M) to the
          # full data set (N).
          power_term = cof2 / cof * tf.reduce_sum(
              tf.exp(x_copy.log_prob(dict_swap[x]) * cof)) * N / M
          # Written before as exp(logsumexp(log(v))), which equals
          # reduce_sum(v) for positive v; the direct sum avoids the
          # log/exp round trip.
          x_mean = x_copy.mean()
          correction = tf.reduce_sum(
              x_mean**cof2 + (1 - x_mean)**cof2) * N / M
          p_log_prob[s] += power_term - correction

    p_log_prob = tf.reduce_mean(p_log_prob)
    q_log_prob = tf.reduce_mean(q_log_prob)

    if self.logging:
      tf.summary.scalar("loss/p_log_prob", p_log_prob,
                        collections=[self._summary_key])
      tf.summary.scalar("loss/q_log_prob", q_log_prob,
                        collections=[self._summary_key])

    loss = -(p_log_prob - q_log_prob)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
Ejemplo n.º 4
0
def build_reparam_kl_loss_and_gradients(inference, var_list):
  r"""Build loss function. Its automatic differentiation
  is a stochastic gradient of

  .. math::

    -\text{ELBO} =  - ( \mathbb{E}_{q(z; \lambda)} [ \log p(x \mid z) ]
          - \text{KL}(q(z; \lambda) \| p(z)) )

  based on the reparameterization trick (Kingma and Welling, 2014).

  It assumes the KL is analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.

  Args:
    inference: object providing ``n_samples``, ``data``, ``latent_vars``,
      ``scale``, ``kl_scaling``, ``logging`` and ``_summary_key``.
    var_list: variables to differentiate the loss with respect to.

  Returns:
    Tuple ``(loss, grads_and_vars)``: ``loss`` is the negative ELBO estimate
    and ``grads_and_vars`` a list of ``(gradient, variable)`` pairs.
  """
  p_log_lik = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_lik[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_lik = tf.reduce_mean(p_log_lik)

  # Scale before summing so a tensor-valued kl_scaling is applied elementwise
  # to the KL; for a scalar scaling this equals scaling the summed KL.
  kl_penalty = tf.reduce_sum([
      tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
      for z, qz in six.iteritems(inference.latent_vars)])

  if inference.logging:
    tf.summary.scalar("loss/p_log_lik", p_log_lik,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/kl_penalty", kl_penalty,
                      collections=[inference._summary_key])

  loss = -(p_log_lik - kl_penalty)

  grads = tf.gradients(loss, var_list)
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
def build_reparam_kl_loss_and_gradients(inference, var_list):
  r"""Build loss function. Its automatic differentiation
  is a stochastic gradient of

  .. math::

    -\text{ELBO} =  - ( \mathbb{E}_{q(z; \lambda)} [ \log p(x \mid z) ]
          - \text{KL}(q(z; \lambda) \| p(z)) )

  based on the reparameterization trick [@kingma2014auto].

  It assumes the KL is analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.

  Args:
    inference: object providing ``n_samples``, ``data``, ``latent_vars``,
      ``scale``, ``kl_scaling``, ``logging`` and ``_summary_key``.
    var_list: variables to differentiate the loss with respect to.

  Returns:
    Tuple ``(loss, grads_and_vars)``: ``loss`` is the negative ELBO estimate
    and ``grads_and_vars`` a list of ``(gradient, variable)`` pairs.
  """
  p_log_lik = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_lik[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_lik = tf.reduce_mean(p_log_lik)

  # kl_scaling multiplies the elementwise KL before summation, so a
  # tensor-valued scaling is supported.
  kl_penalty = tf.reduce_sum([
      tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
      for z, qz in six.iteritems(inference.latent_vars)])

  if inference.logging:
    tf.summary.scalar("loss/p_log_lik", p_log_lik,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/kl_penalty", kl_penalty,
                      collections=[inference._summary_key])

  loss = -(p_log_lik - kl_penalty)

  grads = tf.gradients(loss, var_list)
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
Ejemplo n.º 6
0
 def sample_qH(self, H):
     """Sample h from q(h) and return it with KL(q(h) || N(0, 1)).

     The first ``self.dim_h`` columns of ``H`` are the mean of q(h); the
     remaining columns are exponentiated, so they are treated as the log of
     the variance. The standard deviation passed to ``dist.Normal`` is the
     square root of that variance.

     Returns:
         (h_sample, kl_h): a reparameterized sample from q(h) and the
         elementwise KL against a standard normal prior of matching shape.
     """
     split_at = self.dim_h
     mean = H[:, :split_at]
     variance = tf.exp(H[:, split_at:])
     posterior = dist.Normal(mean, tf.sqrt(variance))
     standard_prior = dist.Normal(tf.zeros_like(mean), tf.ones_like(variance))
     kl_h = dist.kl_divergence(posterior, standard_prior)
     return posterior.sample(), kl_h
Ejemplo n.º 7
0
def build_score_kl_loss_and_gradients(inference, var_list):
  r"""Build loss function and gradients based on the score function
  estimator (Paisley et al., 2012).

  It assumes the KL is analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.

  Args:
    inference: object providing ``n_samples``, ``data``, ``latent_vars``,
      ``scale``, ``kl_scaling``, ``logging`` and ``_summary_key``.
    var_list: variables to differentiate with respect to.

  Returns:
    Tuple ``(loss, grads_and_vars)``. ``loss`` is the negative ELBO estimate;
    the gradients are NOT gradients of ``loss`` but of the score-function
    surrogate ``-(mean(log q(z) * stop_gradient(log p(x | z))) - KL)``.
  """
  p_log_lik = [0.0] * inference.n_samples
  q_log_prob = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      # Score function: differentiate log q only through its parameters,
      # not through the sample.
      q_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_lik[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_lik = tf.stack(p_log_lik)
  q_log_prob = tf.stack(q_log_prob)
  mean_p_log_lik = tf.reduce_mean(p_log_lik)

  # kl_scaling multiplies the elementwise KL before summation, so a
  # tensor-valued scaling is supported; identical for scalar scaling.
  kl_penalty = tf.reduce_sum([
      tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
      for z, qz in six.iteritems(inference.latent_vars)])

  if inference.logging:
    tf.summary.scalar("loss/p_log_lik", mean_p_log_lik,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/kl_penalty", kl_penalty,
                      collections=[inference._summary_key])

  loss = -(mean_p_log_lik - kl_penalty)
  grads = tf.gradients(
      -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty),
      var_list)
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
Ejemplo n.º 8
0
  def build_loss_and_gradients(self, var_list):
    """Build the alpha-power variational loss and its gradients.

    For each of ``self.n_samples`` Monte Carlo draws:

    * the variational factors q(z) are copied to obtain fresh posterior
      samples, and ``q_log_prob`` accumulates ``log q(z) - log p(z)`` (scaled
      by ``self.scale``), i.e. a Monte Carlo estimate of the KL term;
    * for every observed random variable x, ``p_log_prob`` accumulates
      ``(alpha + 1) / alpha * sum(scale * p(x | z) ** alpha)``.

    No analytic KL is built: the KL enters only through ``q_log_prob``.

    Args:
      var_list: variables to differentiate the loss with respect to.

    Returns:
      Tuple ``(loss, grads_and_vars)`` with
      ``loss = -(mean(p_log_prob) - mean(q_log_prob))`` and
      ``grads_and_vars`` a list of ``(gradient, variable)`` pairs.
    """
    cof = tf.constant(self.alpha, tf.float32)
    cof2 = tf.constant(self.alpha + 1, tf.float32)

    p_log_prob = [0.0] * self.n_samples
    q_log_prob = [0.0] * self.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(self.n_samples):
      # Form dictionary in order to replace conditioning on prior or
      # observed variable with conditioning on a specific value.
      scope = base_scope + tf.get_default_graph().unique_name("sample")
      dict_swap = {}
      for x, qx in six.iteritems(self.data):
        if isinstance(x, RandomVariable):
          if isinstance(qx, RandomVariable):
            qx_copy = copy(qx, scope=scope)
            dict_swap[x] = qx_copy.value()
          else:
            dict_swap[x] = qx

      for z, qz in six.iteritems(self.latent_vars):
        # Copy q(z) to obtain new set of posterior samples.
        qz_copy = copy(qz, scope=scope)
        dict_swap[z] = qz_copy.value()
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

      # Runs after dict_swap holds every posterior sample, so each prior
      # copy is conditioned on the sampled values: log q(z) - log p(z).
      for z in six.iterkeys(self.latent_vars):
        z_copy = copy(z, dict_swap, scope=scope)
        q_log_prob[s] -= tf.reduce_sum(
            self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

      for x in six.iterkeys(self.data):
        if isinstance(x, RandomVariable):
          x_copy = copy(x, dict_swap, scope=scope)
          # The constant normalizing term of the power likelihood is omitted:
          # for the regression problem it does not depend on the variational
          # parameters, so it vanishes when the gradient is taken.
          p_log_prob[s] += cof2 / cof * tf.reduce_sum(
              self.scale.get(x, 1.0) *
              tf.exp(x_copy.log_prob(dict_swap[x]) * cof))

    p_log_prob = tf.reduce_mean(p_log_prob)
    q_log_prob = tf.reduce_mean(q_log_prob)

    if self.logging:
      tf.summary.scalar("loss/p_log_prob", p_log_prob,
                        collections=[self._summary_key])
      tf.summary.scalar("loss/q_log_prob", q_log_prob,
                        collections=[self._summary_key])
    loss = -(p_log_prob - q_log_prob)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
    def __init__(self, batch_size=1000, latent_dim=25, epochs=50,
                 plink_file=('/plink_tensorflow/data/test/'
                             'scz_easy-access_wave2.no_trio.bgn'),
                 scratch_dir='/plink_tensorflow/data/test/'):
        """Build the genotype input pipeline and the VAE graph.

        Args:
            batch_size: batch size handed to the dataset builders; the batch
                counts below assume one example per train/test file.
            latent_dim: dimensionality of the latent code z.
            epochs: stored on ``self.epochs``.
            plink_file: PLINK dataset path passed to ``SingleDataset``
                (defaults to the previously hard-coded test file).
            scratch_dir: scratch directory passed to ``SingleDataset``.
        """
        self.epochs = epochs
        self.latent_dim = latent_dim

        # Data input
        plink_dataset = SingleDataset(
            plink_file=plink_file,
            scratch_dir=scratch_dir,
            overwrite=False)
        self.m_variants = plink_dataset.bim.shape[0]
        # Ceiling division: a trailing partial batch counts as one batch, but
        # an exact multiple no longer gets an extra (empty) batch.
        self.total_train_batches = -(-len(plink_dataset.train_files) //
                                     batch_size)
        self.total_test_batches = -(-len(plink_dataset.test_files) //
                                    batch_size)

        print('\nTraining Summary:')
        print('\tTraining files: {}'.format(len(plink_dataset.train_files)))
        print('\tTesting  files: {}'.format(len(plink_dataset.test_files)))
        print('\tTraining batches: {}'.format(self.total_train_batches))
        print('\tTesting  batches: {}'.format(self.total_test_batches))

        print('\nBuilding computational graph...')
        # Input pipeline
        test_dataset = self.build_test_dataset(plink_dataset, batch_size)
        training_dataset = self.build_training_dataset(plink_dataset,
                                                       batch_size)

        # A feedable iterator: the string handle selects which dataset
        # (training or test) get_next() pulls from at run time.
        self.handle = tf.placeholder(tf.string, shape=[])
        self.iterator = tf.data.Iterator.from_string_handle(
            self.handle, training_dataset.output_types,
            training_dataset.output_shapes)

        self.training_iterator = training_dataset.make_initializable_iterator()
        self.test_iterator = test_dataset.make_initializable_iterator()

        genotypes = self.iterator.get_next()

        genotypes = tf.cast(genotypes, tf.float32, name='cast_genotypes')
        genotypes.set_shape([None, self.m_variants])

        # Define the model.
        prior = self.make_prior(latent_dim=self.latent_dim)
        # make_template shares variables across every call of the template.
        make_encoder = tf.make_template('encoder', self.make_encoder)
        posterior = make_encoder(genotypes, latent_dim=self.latent_dim)
        self.latent_z = posterior.sample()

        # Define the loss: ELBO = E_q[log p(x | z)] - KL(q(z | x) || p(z)).
        make_decoder = tf.make_template('decoder', self.make_decoder)
        likelihood = make_decoder(self.latent_z).log_prob(genotypes)
        divergence = tfd.kl_divergence(posterior, prior)
        self.elbo = tf.reduce_mean(likelihood - divergence)
        self.optimizer = tf.train.AdamOptimizer(0.001).minimize(-self.elbo)
        print('Done')
Ejemplo n.º 10
0
    def __init__(
        self,
        s_dim,
        a_dim,
        kl_target,
    ):
        """Build a PPO actor-critic graph with a KL-penalty actor loss.

        Args:
            s_dim: state dimensionality (width of the ``self.tfs`` input).
            a_dim: action dimensionality, stored on ``self.a_dim``.
            kl_target: stored on ``self.kl_target``; presumably compared with
                ``self.kl`` by the caller to adapt the penalty weight fed
                through ``self.tflam`` -- confirm at the call site.
        """
        self.a_dim = a_dim
        self.s_dim = s_dim
        self.kl_target = kl_target
        # The summary writer and initializer below go through self.sess, which
        # this constructor never created; create one unless the class already
        # provides it.
        if getattr(self, 'sess', None) is None:
            self.sess = tf.Session()

        self.tfs = tf.placeholder(tf.float32, [None, s_dim])

        # critic
        with tf.variable_scope('critic'):
            l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
            self.v = tf.layers.dense(l1, 1)
            self.tfdc_r = tf.placeholder(tf.float32, [
                None,
            ])
            self.advantage = self.tfdc_r - tf.squeeze(self.v)
            self.closs = tf.reduce_mean(tf.square(self.advantage))
            self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

        # actor
        pi, pi_params = self._build_anet('pi', trainable=True)
        oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)

        with tf.variable_scope('update_oldpi'):
            self.update_oldpi_op = [
                oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)
            ]

        self.sample_op = pi.sample(1)
        self.tfa = tf.placeholder(tf.float32, [
            None,
        ], 'action')
        with tf.variable_scope('ratio'):
            # exp of a log-prob difference instead of prob / prob: a quotient of
            # two tiny densities underflows to 0 / 0 = NaN.
            ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
        with tf.variable_scope('kl'):
            # Differentiable KL(oldpi || pi), used in the actor loss.
            kl_mean = tf.reduce_mean(kl_divergence(oldpi, pi))
            # Gradient-free copy exposed for monitoring.
            self.kl = tf.stop_gradient(kl_mean)
        self.tflam = tf.placeholder(tf.float32, None, 'lambda')
        self.tfadv = tf.placeholder(tf.float32, [
            None,
        ], 'advantage')
        with tf.variable_scope('loss'):
            # The penalty must use the differentiable KL; penalizing with the
            # stop_gradient'ed self.kl gave it zero gradient, so it never
            # constrained the policy update.
            self.aloss = -(tf.reduce_mean(ratio * self.tfadv) -
                           self.tflam * kl_mean)
        with tf.variable_scope('atrain'):
            self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)

        tf.summary.FileWriter("log/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())
Ejemplo n.º 11
0
    def __init__(
        self,
        s_dim,
        a_dim,
    ):
        """Build a PPO actor-critic graph (KL-penalty or clipped surrogate).

        The actor objective is selected by the module-level ``METHOD`` dict:
        ``METHOD['name'] == 'kl_pen'`` uses a KL penalty weighted by the
        ``self.tflam`` placeholder; anything else uses the clipped surrogate
        with ``METHOD['epsilon']``.

        Args:
            s_dim: state dimensionality.
            a_dim: action dimensionality.
        """
        self.a_dim = a_dim
        self.s_dim = s_dim
        self.sess = tf.Session()

        self.tfs = tf.placeholder(tf.float32, [None, s_dim], 'state')

        # critic
        with tf.variable_scope('critic'):
            l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
            self.v = tf.layers.dense(l1, 1)
            self.tfdc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r')
            self.advantage = self.tfdc_r - self.v
            self.closs = tf.reduce_mean(tf.square(self.advantage))
            self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

        # actor
        pi, pi_params = self._build_anet('pi', trainable=True)
        oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)
        self.sample_op = tf.squeeze(pi.sample(1), axis=0)  # choosing action
        with tf.variable_scope('update_oldpi'):
            self.update_oldpi_op = [
                oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)
            ]

        self.tfa = tf.placeholder(tf.float32, [None, a_dim], 'action')
        self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')
        with tf.variable_scope('surrogate'):
            # exp of a log-prob difference instead of prob / prob: a quotient of
            # two tiny densities underflows to 0 / 0 = NaN.
            ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
            surr = ratio * self.tfadv
        if METHOD['name'] == 'kl_pen':
            self.tflam = tf.placeholder(tf.float32, None, 'lambda')
            with tf.variable_scope('loss'):
                # Keep the KL differentiable: wrapping it in stop_gradient made
                # the penalty contribute zero gradient to the actor.
                kl = kl_divergence(oldpi, pi)
                self.kl_mean = tf.reduce_mean(kl)
                self.aloss = -(tf.reduce_mean(surr - self.tflam * kl))
        else:  # clipping method, find this is better
            with tf.variable_scope('loss'):
                self.aloss = -tf.reduce_mean(
                    tf.minimum(
                        surr,
                        tf.clip_by_value(ratio, 1. - METHOD['epsilon'],
                                         1. + METHOD['epsilon']) * self.tfadv))

        with tf.variable_scope('atrain'):
            self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)

        tf.summary.FileWriter("log/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())
Ejemplo n.º 12
0
 def kl_q_p(self):
     """Return the scalar KL(q || p) term.

     The elementwise KL between ``self.q_dist`` and ``self.p_dist`` is
     averaged over axis 0 (the axis is kept), its sum is logged to the
     "true_kl" summary, each remaining entry is optionally floored at
     ``self.kl_min`` (a free-bits style lower bound, only when
     ``self.kl_min > 0``), and the result is summed to a scalar.
     """
     kl_elementwise = distributions.kl_divergence(
         self.q_dist, self.p_dist)  # [bs/nbs, N]
     kl_avg = tf.reduce_mean(kl_elementwise, 0, keep_dims=True)  # [1, N]
     tf.summary.scalar("true_kl", tf.reduce_sum(kl_avg))
     if self.kl_min > 0:
         kl_avg = tf.maximum(kl_avg, self.kl_min)
     return tf.reduce_sum(kl_avg)  # [], i.e., scalar
Ejemplo n.º 13
0
                def _loop_body(t, ta_h_s, ta_z_prior, ta_z_post, ta_kl):
                    """Loop body for time step `t`: GRU state update plus latent z.

                    It returns t + 1 together with the updated TensorArrays, so it
                    is presumably the body of a tf.while_loop -- confirm at the call
                    site. Besides its arguments it reads the enclosing-scope tensors
                    x and y_ (indexed by t) and `self`.

                    Args:
                        t: loop counter; compared with 1 and used as the index.
                        ta_h_s: TensorArray of hidden states (read at t - 1,
                            written at t).
                        ta_z_prior: TensorArray receiving the prior sample at t.
                        ta_z_post: TensorArray of posterior samples (read at
                            t - 1, written at t).
                        ta_kl: TensorArray receiving the KL term at t.

                    Returns:
                        (t + 1, ta_h_s, ta_z_prior, ta_z_post, ta_kl).
                    """

                    # AUTO_REUSE: variables created on the first iteration are
                    # reused by every later iteration (shared weights over time).
                    with tf.variable_scope('iter_body', reuse=tf.AUTO_REUSE):

                        def _init():
                            # t == 0: random tanh-squashed initial hidden state, and
                            # an initial z drawn through the posterior head self._z.
                            # NOTE(review): self._z is called here under 'iter_body'
                            # (not 'iter_body/z_post') and inside a tf.cond branch; if
                            # it creates variables they are separate from the
                            # posterior's, and TF1 may reject variable creation inside
                            # control flow -- confirm.
                            h_s_init = tf.nn.tanh(tf.random_normal(shape=[self.batch_size, self.h_size]))
                            h_z_init = tf.nn.tanh(tf.random_normal(shape=[self.batch_size, self.z_size]))

                            z_init, _ = self._z(arg=h_z_init, is_prior=False)

                            return h_s_init, z_init

                        def _subsequent():
                            # t >= 1: previous hidden state and posterior sample.
                            h_s_t_1 = tf.reshape(ta_h_s.read(t-1), [self.batch_size, self.h_size])
                            z_t_1 = tf.reshape(ta_z_post.read(t-1), [self.batch_size, self.z_size])

                            return h_s_t_1, z_t_1

                        h_s_t_1, z_t_1 = tf.cond(t >= 1, _subsequent, _init)

                        # GRU gates conditioned on the input, previous state and
                        # previous latent sample.
                        gate_args = [x[t], h_s_t_1, z_t_1]

                        with tf.variable_scope('gru_r'):
                            r = self._linear(gate_args, self.h_size, 'sigmoid')  # reset gate
                        with tf.variable_scope('gru_u'):
                            u = self._linear(gate_args, self.h_size, 'sigmoid')  # update gate

                        # Candidate state sees the reset-gated previous state.
                        h_args = [x[t], tf.multiply(r, h_s_t_1), z_t_1]

                        with tf.variable_scope('gru_h'):
                            h_tilde = self._linear(h_args, self.h_size, 'tanh')

                        # Interpolate between previous state and candidate.
                        h_s_t = tf.multiply(1 - u, h_s_t_1) + tf.multiply(u, h_tilde)

                        # Prior over z_t: inputs exclude the target y_[t].
                        with tf.variable_scope('h_z_prior'):
                            h_z_prior_t = self._linear([x[t], h_s_t], self.z_size, 'tanh')
                        with tf.variable_scope('z_prior'):
                            z_prior_t, z_prior_t_pdf = self._z(h_z_prior_t, is_prior=True)

                        # Posterior over z_t: same inputs plus the target y_[t].
                        with tf.variable_scope('h_z_post'):
                            h_z_post_t = self._linear([x[t], h_s_t, y_[t]], self.z_size, 'tanh')
                        with tf.variable_scope('z_post'):
                            z_post_t, z_post_t_pdf = self._z(h_z_post_t, is_prior=False)

                    # KL(posterior || prior); its shape depends on the distribution
                    # type returned by self._z -- TODO confirm.
                    kl_t = ds.kl_divergence(z_post_t_pdf, z_prior_t_pdf)

                    # write
                    ta_h_s = ta_h_s.write(t, h_s_t)
                    ta_z_prior = ta_z_prior.write(t, z_prior_t)  # write: batch_size * z_size (assumed from self._z)
                    ta_z_post = ta_z_post.write(t, z_post_t)  # write: batch_size * z_size (assumed from self._z)
                    ta_kl = ta_kl.write(t, kl_t)  # write: KL for step t (original note: batch_size * 1; unverified)

                    return t + 1, ta_h_s, ta_z_prior, ta_z_post, ta_kl
Ejemplo n.º 14
0
    def _build(self, inputs, hvar_labels, n_samples=10, analytic_kl=True):
        """Build the VAE graph with a relaxed-categorical and a continuous latent.

        The encoder output parameterizes:

        * ``self.hvar_posterior``: an ``ExpRelaxedOneHotCategorical`` (samples
          are log-space relaxed one-hot vectors), whose prior
          ``self.hvar_prior`` uses ``hvar_labels`` as logits;
        * ``self.latent_posterior``: built by ``self._latent_posterior_fn``
          from ``self._loc`` / ``self._scale``.

        Both samples are concatenated on the last axis and decoded into an
        ``Independent(Bernoulli)`` over one datum.

        Args:
            inputs: data batch; every axis after the first belongs to a datum.
            hvar_labels: logits of the categorical prior, also the labels of
                ``self.hvar_cross_entropy``.
            n_samples: number of posterior samples drawn.
            analytic_kl: use the closed-form KL for the continuous latent;
                only honoured when ``n_samples == 1``.

        Sets ``elbo``, ``importance_weighted_elbo``, ``hvar_sample``,
        ``hvar_cross_entropy``, ``hvar_labels``, ``distortion``, ``rate`` and
        ``hrate`` on ``self``.
        """
        datum_shape = inputs.get_shape().as_list()[1:]
        enc_repr = self._encoder(inputs)

        self.hvar_prior = tfd.ExpRelaxedOneHotCategorical(
            temperature=self._temperature, logits=hvar_labels)
        self.hvar_posterior = tfd.ExpRelaxedOneHotCategorical(
            temperature=self._temperature, logits=self._hvar(enc_repr))
        # Presumably only pins the static shape [n_samples, *batch, *event].
        hvar_sample_shape = ([n_samples] +
                             self.hvar_posterior.batch_shape.as_list() +
                             self.hvar_posterior.event_shape.as_list())
        hvar_sample = tf.reshape(self.hvar_posterior.sample(n_samples),
                                 hvar_sample_shape)

        self.latent_posterior = self._latent_posterior_fn(
            self._loc(enc_repr), self._scale(enc_repr))
        latent_posterior_sample = self.latent_posterior.sample(n_samples)

        joint_sample = tf.concat([hvar_sample, latent_posterior_sample],
                                 axis=-1)

        # Apply the decoder to each of the n_samples slices.
        sample_decoder = snt.BatchApply(self._decoder)
        self.output_distribution = tfd.Independent(
            tfd.Bernoulli(logits=sample_decoder(joint_sample)),
            reinterpreted_batch_ndims=len(datum_shape))

        distortion = -self.output_distribution.log_prob(inputs)
        if analytic_kl and n_samples == 1:
            rate = tfd.kl_divergence(self.latent_posterior, self.latent_prior)
        else:
            # Single-sample Monte Carlo estimate of the KL per draw.
            rate = (self.latent_posterior.log_prob(latent_posterior_sample) -
                    self.latent_prior.log_prob(latent_posterior_sample))
        hrate = self.hvar_posterior.log_prob(
            hvar_sample) - self.hvar_prior.log_prob(hvar_sample)
        elbo_local = -(rate + hrate + distortion)
        self.elbo = tf.reduce_mean(elbo_local)
        # log-mean-exp over the sample axis (axis 0).
        self.importance_weighted_elbo = tf.reduce_mean(
            tf.reduce_logsumexp(elbo_local, axis=0) -
            tf.log(tf.to_float(n_samples)))

        # First of the n_samples draws (leading axis of size 1); built once and
        # shared by the exported sample and the cross-entropy.
        first_hvar_sample = tf.split(hvar_sample, n_samples)[0]
        self.hvar_sample = tf.exp(first_hvar_sample)
        self.hvar_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=hvar_labels, logits=first_hvar_sample)
        self.hvar_labels = hvar_labels
        self.distortion = distortion
        self.rate = rate
        self.hrate = hrate
Ejemplo n.º 15
0
    def testKL(self):
        """Diagonal of pair_kl_divergence must match TF's per-pair Normal KL.

        TF's elementwise KL is summed over the last axis before comparing it
        with ``np.diag(pair_kl)``.
        """
        # A fixed seed keeps the test reproducible; shifting the scales away
        # from 0 avoids degenerate, near-zero standard deviations.
        rng = np.random.RandomState(0)
        mu1, sd1, mu2, sd2 = [rng.rand(4, 6) for _ in range(4)]
        sd1 = sd1 + 0.1
        sd2 = sd2 + 0.1
        pair_kl = pair_kl_divergence(mu1, sd1, mu2, sd2)

        dist1 = distributions.Normal(mu1, sd1)
        dist2 = distributions.Normal(mu2, sd2)
        kl_tf = distributions.kl_divergence(dist1, dist2)

        with tf.Session() as sess:
            kl_val = sess.run(kl_tf)
            kl_val = kl_val.sum(axis=-1)

        self.assertAllClose(np.diag(pair_kl), kl_val)
Ejemplo n.º 16
0
def compute_loss(inf_mean_list, inf_var_list, gen_mean_list, gen_var_list,
                 q_log_discrete, log_px, batch_size):
    """Combine reconstruction and KL terms of a hierarchical VAE.

    The Gaussian KLs pair the inference layers with the generative layers in
    reverse order. The discrete KL compares ``OneHotCategorical(q_log_discrete)``
    with a uniform categorical: the constant logits make every category
    equally likely, whatever their number.

    NOTE(review): ``MultivariateNormalDiag``'s second argument is
    ``scale_diag`` (a standard deviation). If the ``*_var_list`` entries are
    variances, a ``tf.sqrt`` is missing -- confirm.

    Args:
        inf_mean_list, inf_var_list: per-layer inference parameters.
        gen_mean_list, gen_var_list: per-layer generative parameters.
        q_log_discrete: logits of the discrete posterior.
        log_px: per-example reconstruction log-likelihood.
        batch_size: number of examples; used to pin the KL matrix shape.

    Returns:
        (loss, mean_rec, mean_KL) where
        ``loss = mean(log_px - 0.5 * KL)``, ``mean_rec = mean(log_px)`` and
        ``mean_KL = mean(KL)`` with KL the per-example total divergence.
    """
    gaussian_div = [
        dist.kl_divergence(dist.MultivariateNormalDiag(mean0, var0),
                           dist.MultivariateNormalDiag(mean1, var1))
        for mean0, var0, mean1, var1 in zip(inf_mean_list, inf_var_list,
                                            reversed(gen_mean_list),
                                            reversed(gen_var_list))]

    # One column per layer -> [batch_size, n_layers]. Concatenating on axis 0
    # then reshaping interleaved layers and examples within each row.
    kl_gauss = tf.reshape(tf.stack(gaussian_div, axis=1),
                          [batch_size, len(gaussian_div)])
    kl_dis = dist.kl_divergence(
        dist.OneHotCategorical(logits=q_log_discrete),
        dist.OneHotCategorical(
            logits=tf.log(tf.ones_like(q_log_discrete) * 1 / 10)))
    total_kl = tf.reduce_sum(kl_gauss, axis=1) + kl_dis
    mean_KL = tf.reduce_mean(total_kl)
    mean_rec = tf.reduce_mean(log_px)
    loss = tf.reduce_mean(log_px - 0.5 * total_kl)

    return loss, mean_rec, mean_KL
Ejemplo n.º 17
0
 def get_loss(self, config):
     """Return mean KL(posterior || prior) plus the sequence cross-entropy.

     Both terms are also stored on ``self`` (``divergence`` and
     ``crossent``). The cross-entropy is averaged over time steps and batch.
     ``config`` is accepted for interface compatibility and is not read.
     """
     kl_terms = tfd.kl_divergence(self.posterior, self.prior)
     self.divergence = tf.reduce_mean(kl_terms)
     self.crossent = tf.contrib.seq2seq.sequence_loss(
         self.logits, self.targets, self.target_weights,
         average_across_timesteps=True, average_across_batch=True)
     return self.divergence + self.crossent
Ejemplo n.º 18
0
def kl_divergence(distribution_a, distribution_b,
                  average_across_latent_dim=False,
                  average_across_batch=True):
    """Reduce the elementwise KL(a || b) to a scalar.

    The elementwise KL is reduced over axis 1 (mean if
    ``average_across_latent_dim`` else sum) and then over axis 0 (mean if
    ``average_across_batch`` else sum). It assumes a rank-2 KL, presumably
    [batch, latent] -- confirm with the callers.
    """
    latent_reduce = tf.reduce_mean if average_across_latent_dim else tf.reduce_sum
    batch_reduce = tf.reduce_mean if average_across_batch else tf.reduce_sum

    per_element = distributions.kl_divergence(distribution_a, distribution_b)
    per_example = latent_reduce(per_element, axis=1)     # [b]
    return batch_reduce(per_example, axis=0)
Ejemplo n.º 19
0
    def one_step(self, a, x):
        """Advance the latent state one step and return (z_step, kl).

        ``a[0]`` is the previous latent z and ``x`` unpacks to
        ``(u, enc)``; the signature suggests use as a tf.scan step function
        (accumulator, element) -- confirm at the call site.

        The posterior transition q uses (z, enc, u) and the prior transition
        p uses (z, u). Both return a mean and a variance; the variance is
        square-rooted into ``scale_diag``. The new state is sampled from q and
        returned with KL(q || p).
        """
        prev_z = a[0]
        control, encoded = x

        q_mean, q_var = self.q_transition(prev_z, encoded, control)
        p_mean, p_var = self.p_transition(prev_z, control)

        posterior = MultivariateNormalDiag(q_mean, tf.sqrt(q_var))
        prior = MultivariateNormalDiag(p_mean, tf.sqrt(p_var))

        return posterior.sample(), kl_divergence(posterior, prior)
Ejemplo n.º 20
0
    def kl_categorical(p=None, q=None, p_logits=None, q_logits=None, eps=1e-6):
        '''
        Return KL(P || Q) between two categorical distributions.

        Give EITHER both logits (p_logits, q_logits) OR both probability
        vectors (p, q); the logits pair takes precedence if both are given.
        For probability inputs, eps is added to every entry to avoid
        divide-by-zero / log(0) issues.

        Raises:
            ValueError: if neither a complete logits pair nor a complete
                probability pair is provided. It subclasses Exception, so
                existing `except Exception` handlers still catch it.
        '''
        if p_logits is not None and q_logits is not None:
            Q = distributions.Categorical(logits=q_logits, dtype=tf.float32)
            P = distributions.Categorical(logits=p_logits, dtype=tf.float32)
        elif p is not None and q is not None:
            # Single-argument print behaves the same on Python 2 and 3 (the
            # Python 2 print statement is a SyntaxError on Python 3).
            print('p shp = {} | q shp = {}'.format(
                p.get_shape().as_list(), q.get_shape().as_list()))
            Q = distributions.Categorical(probs=q + eps, dtype=tf.float32)
            P = distributions.Categorical(probs=p + eps, dtype=tf.float32)
        else:
            raise ValueError("please provide either logits or dists")

        return distributions.kl_divergence(P, Q)
Ejemplo n.º 21
0
def divergence(q, p, metric='kl', n_monte_carlo_samples=1000):
    """Compute a divergence measure between two probability distributions.

    Args:
        q, p: probability distributions.
        metric: ``'kl'`` for ``kl_divergence(q, p, allow_nan_stats=False)``,
            or ``'dotproduct'`` for a Monte Carlo estimate of
            E_q[q - p] - E_p[q - p], which equals the integral of
            (q(x) - p(x))^2.
        n_monte_carlo_samples: number of samples drawn from each of q and p
            for the ``'dotproduct'`` estimate.

    Returns:
        Scalar tensor with the divergence.

    Raises:
        NotImplementedError: for any other metric (``'gradkl'`` included).
    """
    if metric == 'kl':
        return kl_divergence(q, p, allow_nan_stats=False)
    if metric != 'dotproduct':
        # 'gradkl' is reserved but unsupported; it shares the generic error.
        raise NotImplementedError('Metric not supported %s' % metric)

    def _mean_density_gap(samples):
        # Average of q(x) - p(x) over the given samples.
        return tf.reduce_mean(q.prob(samples) - p.prob(samples))

    gap_under_q = _mean_density_gap(q.sample([n_monte_carlo_samples]))
    gap_under_p = _mean_density_gap(p.sample([n_monte_carlo_samples]))
    return gap_under_q - gap_under_p
Ejemplo n.º 22
0
def gaussian_reparmeterization(logits_z, rnd_sample=None):
    '''
    The vanilla gaussian reparameterization from Kingma et. al

    z = mu + sigma * N(0, I)

    Args:
        logits_z: tensor whose axis 1 has an even size 2 * k. The first k
            entries parameterise sigma (softplus, floored at 1e-6). The
            last k entries are used directly as mu.
        rnd_sample: unused; kept for signature compatibility.

    Returns:
        ``[q_z, kl]``:
        - ``q_z`` is a StochasticTensor over Normal(mu, sigma) whose value
          type is a sample.
        - ``kl`` is KL(q_z || N(0, I)), summed over axis 1 (over axes 1 and 2
          for rank-3 inputs).
    '''
    zshp = logits_z.get_shape().as_list()
    assert zshp[1] % 2 == 0
    # Floor division keeps the split point an int under Python 3 as well;
    # with `/` the slice bounds and tf.zeros/tf.ones shapes would be floats.
    half = zshp[1] // 2
    q_sigma = 1e-6 + tf.nn.softplus(logits_z[:, 0:half])
    q_mu = logits_z[:, half:]

    # Prior: standard normal over the k latent dimensions.
    p_z = d.Normal(loc=tf.zeros(half),
                   scale=tf.ones(half))

    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(d.Normal(loc=q_mu, scale=q_sigma))

    reduce_index = [1] if len(zshp) == 2 else [1, 2]
    kl = d.kl_divergence(q_z.distribution, p_z, allow_nan_stats=False)
    return [q_z, tf.reduce_sum(kl, reduce_index)]
Ejemplo n.º 23
0
def plot_objective():
    """Plot KL(q_gamma || p) against the mixture weight gamma.

    The target ``p`` is a 3-component 1-D Gaussian mixture with weights
    ``[0.5 * 0.6, 0.5 * 0.4, 0.5]``. Two approximation families are swept
    over ``gamma`` in [0, 1) with step 0.02:

    * "exact": the same three components as ``p`` with weights
      ``[(1 - gamma) * 0.6, (1 - gamma) * 0.4, gamma]``, which coincides with
      ``p`` at gamma = 0.5;
    * "inexact": four different components with weights
      ``[(1 - gamma) * w for w in p's weights] + [gamma]``.

    Both KL curves are logged per gamma and shown with matplotlib.
    """
    weights_q = [0.6, 0.4]
    # weights_s = gamma is what we iterate on
    gammas = np.arange(0., 1., 0.02)
    # for exact gamma
    mus = [2., -1., 0.]
    stds = [.6, .4, 0.5]

    # for inexact approx
    mus2 = [-1., 1., 0., 2.0]
    stds2 = [3.3, 0.9, 0.5, 0.4]

    def _normals(locs, scales):
        # One scalar float32 Normal per (loc, scale) pair.
        return [
            Normal(loc=tf.convert_to_tensor(loc, dtype=tf.float32),
                   scale=tf.convert_to_tensor(scale, dtype=tf.float32))
            for loc, scale in zip(locs, scales)
        ]

    def _mixture(weights, components):
        # Mixture whose categorical weights are fixed constants.
        return Mixture(cat=Categorical(probs=tf.convert_to_tensor(weights)),
                       components=components)

    objective_exact = []
    objective_inexact = []
    g = tf.Graph()
    with g.as_default():
        sess = tf.InteractiveSession()
        try:
            with sess.as_default():
                comps = _normals(mus, stds)
                comps2 = _normals(mus2, stds2)
                # p = pi[0] * N(mus[0], stds[0]) + ... + pi[2] * N(mus[2], stds[2])
                weight_s = 0.5
                logger.info('true gamma for exact mixture %.2f' % (weight_s))
                final_weights = [(1 - weight_s) * w for w in weights_q]
                final_weights.append(weight_s)
                p = _mixture(final_weights, comps)

                for gamma in gammas:
                    new_weights = [(1 - gamma) * w for w in weights_q]
                    new_weights.append(gamma)
                    q = _mixture(new_weights, comps)
                    objective = kl_divergence(
                        q, p, allow_nan_stats=False).eval()
                    objective_exact.append(objective)

                    new_weights2 = [(1 - gamma) * w for w in final_weights]
                    new_weights2.append(gamma)
                    q2 = _mixture(new_weights2, comps2)
                    objective2 = kl_divergence(
                        q2, p, allow_nan_stats=False).eval()
                    objective_inexact.append(objective2)

                    logger.info(
                        'gamma = %.2f, D_kl_exact = %.5f, D_kl_inexact = %.5f' %
                        (gamma, objective, objective2))
        finally:
            # Release the interactive session once the curves are computed
            # (or if evaluation fails) so repeated calls do not leak it.
            sess.close()
    plt.plot(gammas,
             objective_exact,
             '-',
             color='r',
             linewidth=2.0,
             label='exact mixture')
    plt.plot(gammas,
             objective_inexact,
             '-',
             color='b',
             linewidth=2.0,
             label='inexact mixture')
    plt.legend()
    plt.xlabel('gamma')
    plt.ylabel('kl divergence of mixture')
    plt.show()
Ejemplo n.º 24
0
    return tfd.Independent(tfd.Bernoulli(logit), 2)


# Wrap the builder functions in templates so repeated calls reuse the same
# variables instead of building a new network each time. The names are
# rebound to the templated versions.
make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

# Input batch of shape [batch, input_size, 1].
data = tf.placeholder(tf.float32, [None, input_size, 1])

prior = make_prior(code_size)
# The encoder returns the approximate posterior plus its loc/scale tensors.
posterior, loc, scale = make_encoder(data, code_size)
# Draw 10 posterior samples and average them over the sample axis (0).
# NOTE(review): the decoder is evaluated on the *mean* of the samples rather
# than averaging the log-likelihood over samples, which differs from the usual
# Monte Carlo ELBO estimate -- confirm this is intended.
# (`reduction_indices` is the deprecated alias of `axis`.)
code_all = posterior.sample(10)
code = tf.reduce_mean(code_all, reduction_indices=0)

# ELBO = log-likelihood of the data minus KL(posterior || prior), then
# averaged with reduce_mean.
likelihood = make_decoder(code, [input_size, 1]).log_prob(data)
divergence = tfd.kl_divergence(posterior, prior)
elbo = tf.reduce_mean(likelihood - divergence)

# Maximise the ELBO by minimising its negative.
optimize = tf.train.AdamOptimizer(lr).minimize(-elbo)
# Decoder means for 10 draws from the prior.
samples = make_decoder(prior.sample(10), [input_size, 1]).mean()

# NOTE: module-level side effects -- a session is created and all variables
# are initialised as soon as this module is imported.
init1 = tf.global_variables_initializer()
sess1 = tf.Session()
sess1.run(init1)

saver = tf.train.Saver()
if __name__ == '__main__':

    for epoch in range(5):
Ejemplo n.º 25
0
def main(_):
    """Train a VAE with a 2-D latent code on MNIST, logging every epoch.

    Each epoch:
    1. Evaluate the model on the full MNIST test set (ELBO, latent codes,
       prior samples, likelihood and KL).
    2. Plot codes and samples into row ``epoch`` of a matplotlib grid.
    3. Run 599 Adam steps on training batches of 100 images, writing
       TensorBoard summaries to ``./logdir``.

    Args:
        _: unused positional argument (presumably argv from the app runner).
    """
    epoch_size = 20
    logdir = './logdir'

    make_encoder = tf.make_template('encoder', _make_encoder)
    # In TensorFlow, if you call a network function twice,
    # it will create two separate networks.
    # TensorFlow templates allow you to wrap a function
    # so that multiple calls to it will reuse the same network parameters.
    make_decoder = tf.make_template('decoder', _make_decoder)

    data = tf.placeholder(tf.float32, [None, 28, 28])

    prior = make_prior(code_size=2)
    posterior = make_encoder(data, code_size=2)
    code = posterior.sample()

    likelihood = make_decoder(code, [28, 28]).log_prob(data)
    divergence = tfd.kl_divergence(posterior, prior)
    elbo = tf.reduce_mean(likelihood - divergence)

    optimizer = tf.train.AdamOptimizer(0.001).minimize(-elbo)
    samples = make_decoder(prior.sample(10), [28, 28]).mean()

    mnist = input_data.read_data_sets('/tmp/MNIST_data/')
    fig, ax = plt.subplots(nrows=epoch_size, ncols=11, figsize=(10, 20))

    # Merged all summaries.
    _summary('likelihood', likelihood)
    _summary('divergence', divergence)
    _summary('elbo', elbo)
    _summary('samples', samples)
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()

    _global_step = tf.get_variable('global_step', [],
                                   dtype=tf.int32,
                                   trainable=False)
    global_step_op = tf.assign_add(_global_step, 1)
    with tf.train.MonitoredSession() as sess:
        writer = tf.summary.FileWriter(logdir, sess.graph)
        try:
            for epoch in range(epoch_size):
                feed = {data: mnist.test.images.reshape([-1, 28, 28])}
                # One pass over the test set, so every reported metric comes
                # from the same posterior draw (a second run would resample
                # `code`).
                (test_elbo, test_codes, test_samples,
                 test_likelihood, test_divergence) = sess.run(
                     [elbo, code, samples, likelihood, divergence], feed)
                print('likeli {}, divergence {}'.format(
                    test_likelihood.mean(), test_divergence.mean()))

                # Plot codes and samples
                ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
                plot_codes(ax[epoch, 0], test_codes, mnist.test.labels)
                plot_samples(ax[epoch, 1:], test_samples)
                print(
                    '\rEpoch {}, elbo {}, labes {}, test_codes {}, test_samples {}'
                    .format(epoch, test_elbo, mnist.test.labels.shape,
                            test_codes.shape, test_samples.shape),
                    end='',
                    flush=True)

                for step in range(1, 600):
                    feed = {
                        data: mnist.train.next_batch(100)[0].reshape(
                            [-1, 28, 28])
                    }
                    _, summary, global_step = sess.run(
                        [optimizer, merged, global_step_op], feed)
                    writer.add_summary(summary, global_step=global_step)
        finally:
            # Flush buffered summaries and release the event file even if
            # training is interrupted.
            writer.close()
 def f(gamma):
     """Return KL(q_gamma || qt) for a two-weight mixture over ``comps``.

     The mixture weights are ``[1 - gamma, gamma]``. ``comps`` (a list of
     MultivariateNormalDiag kwargs dicts) and ``qt`` come from the enclosing
     scope, which is not visible here. The KL op is evaluated in the current
     default session.
     """
     mixture = Mixture(
         cat=Categorical(probs=tf.convert_to_tensor([(1 - gamma), gamma])),
         components=[MultivariateNormalDiag(**params) for params in comps])
     return kl_divergence(mixture, qt).eval()
def main(argv):
    """Run Frank-Wolfe step-size experiments and log metrics to disk.

    For each of ``FLAGS.n_fw_iter`` runs:
    - build a fresh graph;
    - initialise q to ``(mus[0], stds[0])`` from ``create_target_dist``;
    - take one Frank-Wolfe step towards the atom ``s`` using the step-size
      rule selected by ``FLAGS.fw_variant`` ('fixed', 'line_search' or
      'adafw');
    - append timing, ELBO, KL, step size and variant-specific diagnostics to
      files in ``FLAGS.outdir``.

    Args:
        argv: unused command-line arguments.

    Raises:
        ValueError: if ``FLAGS.init == 'random'`` or ``FLAGS.LMO == 'vi'``
            (both paths are disabled).
        NotImplementedError: for any other ``FLAGS.fw_variant`` (including
            'ada_afw'/'ada_pfw', which still get their log files created).
    """
    del argv

    outdir = FLAGS.outdir
    if '~' in outdir: outdir = os.path.expanduser(outdir)
    os.makedirs(outdir, exist_ok=True)

    # Files to log metrics
    times_filename = os.path.join(outdir, 'times.csv')
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    objective_filename = os.path.join(outdir, 'kl.csv')
    reference_filename = os.path.join(outdir, 'ref_kl.csv')
    step_filename = os.path.join(outdir, 'steps.csv')
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        curvature_filename = os.path.join(outdir, 'curvature.csv')
        gap_filename = os.path.join(outdir, 'gap.csv')
        iter_info_filename = os.path.join(outdir, 'iter_info.txt')
    elif FLAGS.fw_variant == 'line_search':
        goutdir = os.path.join(outdir, 'gradients')

    # empty the files present in the folder already
    open(times_filename, 'w').close()
    open(elbos_filename, 'w').close()
    open(objective_filename, 'w').close()
    open(reference_filename, 'w').close()
    open(step_filename, 'w').close()
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        open(curvature_filename, 'w').close()
        append_to_file(curvature_filename, "c_local,c_global")
        open(gap_filename, 'w').close()
        open(iter_info_filename, 'w').close()
    elif FLAGS.fw_variant == 'line_search':
        os.makedirs(goutdir, exist_ok=True)

    for i in range(FLAGS.n_fw_iter):
        # NOTE: First iteration (t = 0) is initialization
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            try:
                with sess.as_default():
                    p, mus, stds = create_target_dist()

                    # current iterate (solution until now)
                    if FLAGS.init == 'random':
                        muq = np.random.randn(D).astype(np.float32)
                        stdq = softplus(np.random.randn(D).astype(np.float32))
                        raise ValueError(
                            'random initialization is not supported')
                    else:
                        muq = mus[0]
                        stdq = stds[0]

                    # 1 correct LMO
                    t = 1
                    comps = [{'loc': muq, 'scale_diag': stdq}]
                    weights = [1.0]
                    curvature_estimate = opt.adafw_linit()

                    qtx = MultivariateNormalDiag(
                        loc=tf.convert_to_tensor(muq, dtype=tf.float32),
                        scale_diag=tf.convert_to_tensor(stdq,
                                                        dtype=tf.float32))
                    fw_iterates = {p: qtx}

                    # calculate kl-div with 1 component
                    objective_old = kl_divergence(qtx, p).eval()
                    logger.info("kl with init %.4f" % (objective_old))
                    append_to_file(reference_filename, objective_old)

                    # s is the solution to LMO. It is initialized randomly
                    # mu ~ N(0, 1), std ~ softplus(N(0, 1))
                    s = coreutils.construct_multivariatenormaldiag([D], t, 's')

                    sess.run(tf.global_variables_initializer())

                    total_time = 0
                    start_inference_time = time.time()
                    if FLAGS.LMO == 'vi':
                        # we have to iterate over parameter space
                        # NOTE: this path is disabled; the statements after
                        # the raise are kept for reference and never run.
                        raise ValueError('LMO "vi" is disabled')
                        inference = relbo.KLqp({p: s},
                                               fw_iterates=fw_iterates,
                                               fw_iter=t)
                        inference.run(n_iter=FLAGS.LMO_iter)
                    # s now contains solution to LMO
                    end_inference_time = time.time()

                    mu_s = s.mean().eval()
                    cov_s = s.stddev().eval()

                    # NOTE: keep only step size time
                    #total_time += end_inference_time - start_inference_time

                    # compute step size to update the next iterate
                    step_result = {}
                    if FLAGS.fw_variant == 'fixed':
                        gamma = 2. / (t + 2.)
                    elif FLAGS.fw_variant == 'line_search':
                        start_line_search_time = time.time()
                        step_result = opt.line_search_dkl(
                            weights, [c['loc'] for c in comps],
                            [c['scale_diag']
                             for c in comps], qtx, mu_s, cov_s, s, p, t)
                        end_line_search_time = time.time()
                        total_time += (end_line_search_time -
                                       start_line_search_time)
                        gamma = step_result['gamma']
                    elif FLAGS.fw_variant == 'adafw':
                        start_adafw_time = time.time()
                        step_result = opt.adaptive_fw(
                            weights, [c['loc'] for c in comps],
                            [c['scale_diag'] for c in comps], qtx, mu_s,
                            cov_s, s, p, t, curvature_estimate)
                        end_adafw_time = time.time()
                        total_time += end_adafw_time - start_adafw_time
                        gamma = step_result['gamma']
                    else:
                        raise NotImplementedError

                    comps.append({'loc': mu_s, 'scale_diag': cov_s})
                    weights = [(1. - gamma), gamma]

                    c_global = estimate_global_curvature(comps, qtx)

                    q_latest = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[MultivariateNormalDiag(**c)
                                    for c in comps])

                    # Log metrics for current iteration
                    time_t = float(total_time)
                    logger.info('total time %f' % (time_t))
                    append_to_file(times_filename, time_t)

                    elbo_t = elbo(q_latest, p, n_samples=1000)
                    logger.info("iter, %d, elbo, %.2f +/- %.2f" %
                                (t, elbo_t[0], elbo_t[1]))
                    append_to_file(elbos_filename,
                                   "%f,%f" % (elbo_t[0], elbo_t[1]))

                    logger.info('iter %d, gamma %.4f' % (t, gamma))
                    append_to_file(step_filename, gamma)

                    objective_t = kl_divergence(q_latest, p).eval()
                    logger.info("run %d, kl %.4f" % (i, objective_t))
                    append_to_file(objective_filename, objective_t)

                    if FLAGS.fw_variant.startswith('ada'):
                        curvature_estimate = step_result['c_estimate']
                        append_to_file(gap_filename, step_result['gap'])
                        append_to_file(iter_info_filename,
                                       step_result['step_type'])
                        logger.info(
                            'gap = %.3f, ct = %.5f, iter_type = %s' %
                            (step_result['gap'], step_result['c_estimate'],
                             step_result['step_type']))
                        append_to_file(curvature_filename,
                                       '%f,%f' % (curvature_estimate,
                                                  c_global))
                    elif FLAGS.fw_variant == 'line_search':
                        n_line_search_samples = step_result['n_samples']
                        grad_t = step_result['grad_gamma']
                        g_outfile = os.path.join(
                            goutdir, 'line_search_samples_%d.npy.%d' %
                            (n_line_search_samples, t))
                        logger.info('saving line search data to, %s' %
                                    g_outfile)
                        # Close the handle deterministically after saving.
                        with open(g_outfile, 'wb') as g_file:
                            np.save(g_file, grad_t)
            finally:
                # Always release the session, even if a step raises, so the
                # next run does not inherit a stale InteractiveSession.
                sess.close()

        tf.reset_default_graph()
Ejemplo n.º 28
0
 def kl_q_p(self):
     """Return the mean of the element-wise KL(self.q_dist || self.p_dist)."""
     per_entry_kl = distributions.kl_divergence(self.q_dist, self.p_dist)
     return tf.reduce_mean(per_entry_kl)
def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Adaptive pairwise variant.

    Moves mass from the component v_t of q_t selected by ``argmax_grad_dotp``
    onto the LMO solution s_t. The step size is chosen by backtracking on a
    Lipschitz estimate until the quadratic upper bound on KL(q_{t+1} || p)
    holds.

    Args:
        same as fixed
        weights/locs/diags: weights, locs and scale_diags of q_t's components.
        comps: list of {'loc', 'scale_diag'} dicts for q_t's components.
        mu_s, cov_s: loc and scale_diag of the LMO solution s_t.
        k: Frank-Wolfe iteration, used for the fallback step 2 / (k + 2).
        l_prev: previous Lipschitz estimate.
    Returns:
        dict with keys 'gamma', 'l_estimate', 'weights', 'comps', 'gap' and
        'step_type'. 'step_type' is one of 'adaptive', 'drop', 'fixed' or
        'fixed_adaptive_MAXITER'.
    """
    # NOTE(review): the curvature term uses d(s_t, q_t), while the pairwise
    # direction is s_t - v_t; confirm this distance is the intended one.
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    # v_t itself is not used below; only index_v_t and step_v_t are.
    v_t = qcomps[index_v_t]

    # Pairwise gap
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    if gap_pw < 0: eprint("Pairwise gap is negative")

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        # Plain Frank-Wolfe fallback: scale every weight by (1 - gamma) and
        # append s_t with weight gamma = 2 / (k + 2).
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap_pw,
            'step_type': fail_type
        }

    logger.info('Pairwise gap %.5f' % gap_pw)

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)
    gap = gap_pw
    if gap <= 0:
        return default_fixed_step()
    # A pairwise step cannot move more mass than v_t currently holds.
    gamma_max = weights[index_v_t]
    step_type = 'adaptive'

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t =  kl_divergence(q_t, p, allow_nan_stats=False).eval()
    drop_step = False
    debug('f(q_t) = %.5f' % (f_t))
    # Initial value only so the loop condition is defined on entry.
    gamma = 2. / (k + 2)
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        # Each rejected trial multiplies the estimate by another factor tau.
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        # Quadratic upper bound: f_t - gamma * gap + gamma^2 * l_t * d / 2
        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        # s_t gains gamma; v_t loses gamma (or is zeroed at gamma_max).
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        if gamma == gamma_max:
            # hardcoding to 0 for precision issues
            new_weights[index_v_t] = 0
            drop_step = True
        else:
            new_weights[index_v_t] -= gamma
            drop_step = False

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        # Accept the step once the true KL is below the quadratic bound.
        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = copy.copy(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                # v_t now has zero weight: remove it from the mixture.
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
                step_type = 'drop'
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1
    
    # gamma below MIN_GAMMA
    # (also reached when FLAGS.adafw_MAXITER trials were all rejected)
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step("fixed_adaptive_MAXITER")
def adaptive_fw(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k, l_prev,
                return_gamma=False):
    """Adaptive Frank-Wolfe algorithm.

    Sets step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Starting from ``eta * l_prev``, the Lipschitz estimate is multiplied by
    increasing powers of ``tau`` until the candidate step satisfies
    ``KL(q_{t+1} || p) <= KL(q_t || p) - gamma * gap
    + gamma^2 * l_t * d_t / 2``. Here ``d_t`` is the divergence between s_t
    and q_t. The search gives up after ``FLAGS.adafw_MAXITER`` rejected
    trials (the same cap used by adaptive_pfw / adaptive_afw) and returns a
    zero step.

    Args:
        weights: [k], weights of the mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations (scale_diag) of mixture components
            of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], diagonal scale (std deviations) for LMO solution s;
            used as the ``scale_diag`` of the new component
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe (currently unused here)
        l_prev: previous lipschitz estimate
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of gamma is
        returned. Otherwise a dict with keys 'gamma', 'l_estimate', 'gap' and
        'step_type'. 'step_type' is one of 'adaptive', 'fixed' (negative
        gap, zero step) or 'fixed_adaptive_MAXITER'.
    """

    # Set $q_{t+1}$'s params: q_t's components plus the LMO atom s_t.
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    # create and sample from $s_t, q_t$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap = step_q - step_s
    logger.info('duality gap %.5f' % gap)
    if gap < 0: logger.warning("Duality gap is negative returning 0 step")

    gamma = 0.
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # did the adaptive loop suceed or not
    step_type = "fixed"
    # NOTE: this is from v1 of the paper, new version
    # replaces multiplicative tau with divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    # A negative gap skips the search and keeps gamma = 0, l_t = l_prev.
    while gap >= 0:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)
        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma)w_t, \gamma]$
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            break
        pow_tau *= tau
        i += 1
        # At most FLAGS.adafw_MAXITER trials, matching the sibling variants'
        # `i < FLAGS.adafw_MAXITER` guard (previously `>` allowed one extra).
        if i >= FLAGS.adafw_MAXITER:
            # estimate not good: fall back to a zero step and keep l_prev
            gamma = 0.
            l_t = l_prev
            step_type = "fixed_adaptive_MAXITER"
            break

    if return_gamma: return gamma
    return {
        'gamma': gamma,
        'l_estimate': l_t,
        'gap': gap,
        'step_type': step_type
    }
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Away steps variant

    Compares the Frank-Wolfe gap (towards s_t) with the away gap (away from
    the component v_t selected by ``argmax_grad_dotp``) and steps in the
    direction with the larger gap. The step size is chosen by backtracking on
    a Lipschitz estimate until the quadratic upper bound on KL(q_{t+1} || p)
    holds.

    Args:
        same as fixed
        weights/locs/diags: weights, locs and scale_diags of q_t's components.
        comps: list of {'loc', 'scale_diag'} dicts for q_t's components.
        mu_s, cov_s: loc and scale_diag of the LMO solution s_t.
        k: Frank-Wolfe iteration, used for the fallback step 2 / (k + 2).
        l_prev: previous Lipschitz estimate.
    Returns:
        dict with keys 'gamma', 'l_estimate', 'weights', 'comps', 'gap' and
        'step_type'. 'step_type' is one of 'adaptive' (FW direction), 'away',
        'drop' or 'fixed'.
    """
    # NOTE(review): for away steps the direction is q_t - v_t, but the
    # curvature term still uses d(s_t, q_t); confirm this is intended.
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    # v_t itself is not used below; only index_v_t and step_v_t are.
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    sample_q = q_t.sample([FLAGS.n_monte_carlo_samples])
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0: logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0: eprint('Away gap < 0!!!')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    # With a single component there is nothing to move away from, so the
    # FW direction is forced.
    if (gap_fw >= gap_a) or (len(comps) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        adaptive_step_type = 'fw'
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        # Away direction
        logger.info('Proceeding in Away direction ')
        adaptive_step_type = 'away'
        gap = gap_a
        # Largest gamma keeping v_t's new weight (1 + gamma) * w - gamma >= 0.
        if weights[index_v_t] < 1.0:
            gamma_max = weights[index_v_t] / (1.0 - weights[index_v_t])
        else:
            gamma_max = 100. # Large value when t = 1

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        # Plain Frank-Wolfe fallback (even if the away direction was chosen):
        # scale weights by (1 - gamma) and append s_t with gamma = 2 / (k + 2).
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap,
            'step_type': fail_type
        }
    
    if gap <= 0:
        return default_fixed_step()

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t =  kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    # Initial value only so the loop condition is defined on entry.
    gamma = 2. / (k + 2)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        # Each rejected trial multiplies the estimate by another factor tau.
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        # Quadratic upper bound: f_t - gamma * gap + gamma^2 * l_t * d / 2
        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        if adaptive_step_type == 'fw':
            if gamma == gamma_max:
                # gamma = 1.0, q_{t + 1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = copy.copy(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = copy.copy(weights)
                new_weights = [(1. - gamma) * w for w in new_weights]
                new_weights.append(gamma)
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])
        elif adaptive_step_type == 'away':
            # Away step: scale all weights by (1 + gamma), take gamma off v_t.
            new_weights = copy.copy(weights)
            new_comps = copy.copy(comps)
            if gamma == gamma_max:
                # drop v_t
                # At gamma_max, v_t's weight would be exactly 0, so remove it.
                is_drop_step = True
                logger.info('...drop step')
                del new_weights[index_v_t]
                new_weights = [(1. + gamma) * w for w in new_weights]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                drop_locs = [c['loc'] for c in new_comps]
                drop_diags = [c['scale_diag'] for c in new_comps]
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(drop_locs, drop_diags)
                    ])
            else:
                is_drop_step = False
                new_weights = [(1. + gamma) * w for w in new_weights]
                new_weights[index_v_t] -= gamma
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        # Accept the step once the true KL is below the quadratic bound.
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            if adaptive_step_type == "away": step_type = "away"
            if is_drop_step: step_type = "drop"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # adaptive loop failed, return fixed step size
    # NOTE(review): adaptive_pfw labels this case 'fixed_adaptive_MAXITER';
    # here it is reported as 'fixed'.
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step()