Exemplo n.º 1
0
    def testArgRenames(self):
        """Checks reductions and expand_dims called with their legacy argument names.

        Most calls pass the axes through `reduction_indices=` (or `dim=` for
        expand_dims); a few pass them positionally.
        """
        with self.test_session():

            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
            flags = [[True, False, False], [False, True, True]]
            # Both index lists contain the single index 1.
            any_axes = [1]
            all_axes = [1]

            check_equal = self.assertAllEqual
            check_close = self.assertAllClose

            # Boolean reductions.
            check_equal(tf.reduce_any(flags, reduction_indices=any_axes).eval(), [True, True])
            check_equal(tf.reduce_all(flags, reduction_indices=[0]).eval(), [False, False, False])
            check_equal(tf.reduce_all(flags, reduction_indices=all_axes).eval(), [False, False])
            # Sum.
            check_equal(tf.reduce_sum(values, reduction_indices=[1]).eval(), [6.0, 15.0])
            check_equal(tf.reduce_sum(values, reduction_indices=[0, 1]).eval(), 21.0)
            check_equal(tf.reduce_sum(values, [0, 1]).eval(), 21.0)
            # Product.
            check_equal(tf.reduce_prod(values, reduction_indices=[1]).eval(), [6.0, 120.0])
            check_equal(tf.reduce_prod(values, reduction_indices=[0, 1]).eval(), 720.0)
            check_equal(tf.reduce_prod(values, [0, 1]).eval(), 720.0)
            # Mean.
            check_equal(tf.reduce_mean(values, reduction_indices=[1]).eval(), [2.0, 5.0])
            check_equal(tf.reduce_mean(values, reduction_indices=[0, 1]).eval(), 3.5)
            check_equal(tf.reduce_mean(values, [0, 1]).eval(), 3.5)
            # Min.
            check_equal(tf.reduce_min(values, reduction_indices=[1]).eval(), [1.0, 4.0])
            check_equal(tf.reduce_min(values, reduction_indices=[0, 1]).eval(), 1.0)
            check_equal(tf.reduce_min(values, [0, 1]).eval(), 1.0)
            # Max.
            check_equal(tf.reduce_max(values, reduction_indices=[1]).eval(), [3.0, 6.0])
            check_equal(tf.reduce_max(values, reduction_indices=[0, 1]).eval(), 6.0)
            check_equal(tf.reduce_max(values, [0, 1]).eval(), 6.0)
            # Log-sum-exp is compared approximately rather than exactly.
            check_close(tf.reduce_logsumexp(values, reduction_indices=[1]).eval(), [3.40760589, 6.40760612])
            check_close(tf.reduce_logsumexp(values, reduction_indices=[0, 1]).eval(), 6.45619344711)
            check_close(tf.reduce_logsumexp(values, [0, 1]).eval(), 6.45619344711)
            # expand_dims through the `dim` keyword.
            check_equal(tf.expand_dims([[1, 2], [3, 4]], dim=1).eval(), [[[1, 2]], [[3, 4]]])
Exemplo n.º 2
0
 def while_step(t, rnn_state, tas, accs):
   """Implements one timestep of FIVO computation.

   Besides its arguments, this closure reads `inputs_ta`, `mask_ta`, `cell`,
   `num_samples`, `batch_size`, `random_seed` and `resampling_criterion` from
   the enclosing scope.

   Args:
     t: The current step; used to index the TensorArrays and passed on to
       `resampling_criterion`.
     rnn_state: The state handed to `cell` for this step.
     tas: TensorArrays that receive, in order, the accumulated log weights,
       the log effective sample size and the float resampling indicator.
     accs: A tuple `(log_weights_acc, log_p_hat_acc, kl_acc)`.

   Returns:
     A tuple `(t + 1, new_state, new_tas, new_accs)` where `new_accs` has the
     same three-element layout as `accs`.
   """
   log_weights_acc, log_p_hat_acc, kl_acc = accs
   cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
   # Run the cell for one step.
   log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
       cur_inputs,
       rnn_state,
       cur_mask,
   )
   # Compute the incremental weight and use it to update the current
   # accumulated weight.
   kl_acc += kl * cur_mask
   log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
   log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
   log_weights_acc += log_alpha
   # Calculate the effective sample size.
   # In log space: log ESS = 2 * log(sum w) - log(sum w^2).
   ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
   ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
   log_ess = ess_num - ess_denom
   # Calculate the ancestor indices via resampling. Because we maintain the
   # log unnormalized weights, we pass the weights in as logits, allowing
   # the distribution object to apply a softmax and normalize them.
   resampling_dist = tf.contrib.distributions.Categorical(
       logits=tf.transpose(log_weights_acc, perm=[1, 0]))
   # The sampled indices are wrapped in stop_gradient, so no gradient flows
   # through the resampling draw.
   ancestor_inds = tf.stop_gradient(
       resampling_dist.sample(sample_shape=num_samples, seed=random_seed))
   # Because the batch is flattened and laid out as discussed
   # above, we must modify ancestor_inds to index the proper samples.
   # The particles in the ith filter are distributed every batch_size rows
   # in the batch, and offset i rows from the top. So, to correct the indices
   # we multiply by the batch_size and add the proper offset. Crucially,
   # when ancestor_inds is flattened the layout of the batch is maintained.
   offset = tf.expand_dims(tf.range(batch_size), 0)
   ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset, [-1])
   # Identity indices, used for the filters that do not resample.
   noresample_inds = tf.range(num_samples * batch_size)
   # Decide whether or not we should resample; don't resample if we are past
   # the end of a sequence.
   should_resample = resampling_criterion(num_samples, log_ess, t)
   should_resample = tf.logical_and(should_resample,
                                    cur_mask[:batch_size] > 0.)
   float_should_resample = tf.to_float(should_resample)
   # Tile the per-filter decision so it lines up with the flattened indices.
   ancestor_inds = tf.where(
       tf.tile(should_resample, [num_samples]),
       ancestor_inds,
       noresample_inds)
   new_state = nested.gather_tensors(new_state, ancestor_inds)
   # Update the TensorArrays before we reset the weights so that we capture
   # the incremental weights and not zeros.
   ta_updates = [log_weights_acc, log_ess, float_should_resample]
   new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
   # For the particle filters that resampled, update log_p_hat and
   # reset weights to zero.
   log_p_hat_update = tf.reduce_logsumexp(
       log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
   log_p_hat_acc += log_p_hat_update * float_should_resample
   log_weights_acc *= (1. - tf.tile(float_should_resample[tf.newaxis, :],
                                    [num_samples, 1]))
   new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)
   return t + 1, new_state, new_tas, new_accs
Exemplo n.º 3
0
 def while_step(t, rnn_state, tas, accs):
   """Advances the IWAE computation by a single timestep.

   Reads `inputs_ta`, `mask_ta`, `cell`, `num_samples` and `batch_size` from
   the enclosing scope.

   Args:
     t: The current step; used to index the TensorArrays.
     rnn_state: The state handed to `cell` for this step.
     tas: TensorArrays that receive, in order, the accumulated log weights
       and the log effective sample size.
     accs: A tuple `(log_weights_acc, kl_acc)`.

   Returns:
     A tuple `(t + 1, new_state, new_tas, new_accs)`.
   """
   weights_so_far, kl_so_far = accs
   step_inputs, step_mask = nested.read_tas([inputs_ta, mask_ta], t)
   # Advance the cell by one step.
   log_q_z, log_p_z, log_p_x_given_z, kl, next_state = cell(
       step_inputs, rnn_state, step_mask)
   # Fold the masked KL and the masked incremental weight into the
   # running accumulators.
   kl_so_far += kl * step_mask
   masked_increment = (log_p_x_given_z + log_p_z - log_q_z) * step_mask
   weights_so_far += tf.reshape(masked_increment, [num_samples, batch_size])
   # Log effective sample size: 2 * log(sum w) - log(sum w^2).
   log_ess = (2 * tf.reduce_logsumexp(weights_so_far, axis=0) -
              tf.reduce_logsumexp(2 * weights_so_far, axis=0))
   # Record this step's values in the TensorArrays.
   new_tas = [
       ta.write(t, value)
       for ta, value in zip(tas, [weights_so_far, log_ess])
   ]
   return t + 1, next_state, new_tas, (weights_so_far, kl_so_far)
Exemplo n.º 4
0
def ess_criterion(log_weights, unused_t):
  """Resampling criterion driven by the effective sample size (ESS).

  Args:
    log_weights: Log weights; the particle count is taken from the size of
      the leading dimension and the ESS is reduced over that dimension.
    unused_t: Ignored.

  Returns:
    A boolean tensor that is true where the log ESS is at most
    log(num_particles / 2).
  """
  num_particles = tf.shape(log_weights)[0]
  # log ESS = 2 * log(sum w) - log(sum w^2), computed in log space.
  log_ess = (2 * tf.reduce_logsumexp(log_weights, axis=0) -
             tf.reduce_logsumexp(2 * log_weights, axis=0))
  log_threshold = tf.log(tf.to_float(num_particles) / 2.0)
  return log_ess <= log_threshold
Exemplo n.º 5
0
    def __init__(self, env_spec, expert_trajs=None,
                 discrim_arch=relu_net,
                 discrim_arch_args=None,
                 score_using_discrim=False,
                 l2_reg=0,
                 name='gcl'):
        """Builds the discriminator graph, its loss and its Adam training op.

        Args:
            env_spec: Provides `observation_space.flat_dim` and
                `action_space.flat_dim`, stored as `self.dO` / `self.dU`.
            expert_trajs: Optional expert trajectories. When truthy they are
                stored and also run through `self.extract_paths`; otherwise
                neither attribute is assigned here.
            discrim_arch: Callable that builds the energy network; invoked as
                `discrim_arch(obs_act, dout=self.dU, **discrim_arch_args)`.
            discrim_arch_args: Optional dict of extra keyword arguments for
                `discrim_arch`. `None` (the default) means no extra arguments.
            score_using_discrim: Stored on the instance; not otherwise used in
                this constructor.
            l2_reg: Weight of the L2 penalty over the variables found in the
                'reg_vars' collection under the discriminator scope. The
                penalty is skipped unless this is positive.
            name: Name of the enclosing variable scope.
        """
        super(AIRLDiscrete, self).__init__()
        # A dict literal as a default value would be shared by every call, so
        # the empty dict is created per call instead.
        if discrim_arch_args is None:
            discrim_arch_args = {}
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        self.score_using_discrim = score_using_discrim
        if expert_trajs:
            self.expert_trajs = expert_trajs
            self.expert_trajs_extracted = self.extract_paths(expert_trajs)

        # build energy model
        with tf.variable_scope(name) as _vs:
            # Placeholders are rank 2 with an unspecified leading dimension:
            # [None, dO] observations, [None, dU] actions, [None, 1] labels
            # and log-probabilities.
            self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
            self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            obs_act = tf.concat([self.obs_t, self.act_t], axis=1)
            with tf.variable_scope('discrim') as dvs:
                with tf.variable_scope('energy'):
                    energy = discrim_arch(obs_act, dout=self.dU, **discrim_arch_args)

                # Log-partition over the dU energy outputs.
                self.value_fn = tf.reduce_logsumexp(-energy, axis=1, keep_dims=True)
                # Select the energy of the taken action
                # (presumably act_t is one-hot; verify against the caller).
                self.energy = tf.reduce_sum(energy*self.act_t, axis=1, keep_dims=True)

                log_p_tau = -self.energy - self.value_fn
                discrim_vars = tf.get_collection('reg_vars', scope=dvs.name)

            log_q_tau = self.lprobs

            if l2_reg > 0:
                reg_loss = l2_reg*tf.reduce_sum([tf.reduce_sum(tf.square(var)) for var in discrim_vars])
            else:
                reg_loss = 0

            # log(p + q), used to normalise the discriminator output p / (p + q).
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.d_tau = tf.exp(log_p_tau-log_pq)
            # Binary cross-entropy between the labels and p / (p + q).
            cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))

            self.loss = cent_loss + reg_loss
            self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
            self._make_param_ops(_vs)
Exemplo n.º 6
0
  def testCrfLogNorm(self):
    """Checks crf_log_norm against a brute-force log-sum-exp over all taggings."""
    unary_scores = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transitions = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    num_words, num_tags = unary_scores.shape
    sequence_lengths = np.array(3, dtype=np.int32)
    # Tags appended after the real sequence to pad it out to num_words.
    padding = [0] * (num_words - sequence_lengths)
    with self.test_session() as sess:
      scores_per_tagging = []

      # Score every possible tagging of the first `sequence_lengths` words.
      for prefix in itertools.product(
          range(num_tags), repeat=sequence_lengths):
        padded_tags = list(prefix)
        padded_tags.extend(padding)
        score = tf.contrib.crf.crf_sequence_score(
            inputs=tf.expand_dims(unary_scores, 0),
            tag_indices=tf.expand_dims(padded_tags, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transitions))
        scores_per_tagging.append(score)

      expected_log_norm = tf.reduce_logsumexp(scores_per_tagging)
      # The dynamic-programming result under test.
      actual_log_norm = tf.squeeze(
          tf.contrib.crf.crf_log_norm(
              inputs=tf.expand_dims(unary_scores, 0),
              sequence_lengths=tf.expand_dims(sequence_lengths, 0),
              transition_params=tf.constant(transitions)),
          [0])
      expected_value, actual_value = sess.run(
          [expected_log_norm, actual_log_norm])

      self.assertAllClose(actual_value, expected_value)
Exemplo n.º 7
0
  def log_alpha_likelihood_ratio(self, activation_fn=tf.nn.relu):
    """Monte Carlo estimate built from `num_mc_nn_samples` network samples.

    For each sampled network the per-example log likelihood minus that
    sample's log f is scaled by `self.alpha`; the result is log-mean-exp'd
    over the network samples and summed over the batch.

    Args:
      activation_fn: Activation passed to `self.sample_neural_network`.

    Returns:
      The sum over the batch of
      `logsumexp_k(alpha * (log_lk[k, n] - log_f[k])) - log(k)`.
    """
    # each nn sample returns (log f, log likelihoods)
    nn_samples = [
        self.sample_neural_network(activation_fn)
        for _ in range(self.num_mc_nn_samples)
    ]
    nn_log_f_samples = [elt[0] for elt in nn_samples]
    nn_log_lk_samples = [elt[1] for elt in nn_samples]

    # we stack the (log f, log likelihoods) from the k nn samples
    nn_log_f_stack = tf.stack(nn_log_f_samples)      # k
    nn_log_lk_stack = tf.stack(nn_log_lk_samples)    # k x N
    # Repeat each sample's log f along its own row. Tiling the flat length-k
    # vector N times and then reshaping to [k, N] would interleave the values
    # (f0, f1, ..., f0, f1, ...) across rows, pairing likelihoods with the
    # wrong sample's log f.
    nn_f_tile = tf.tile(
        tf.reshape(nn_log_f_stack, [self.num_mc_nn_samples, 1]),
        [1, self.batch_size])

    # now both the log f and log likelihood terms have shape: k x N
    # apply formula in https://www.overleaf.com/12837696kwzjxkyhdytk#/49028744/
    nn_log_ratio = nn_log_lk_stack - nn_f_tile
    nn_log_ratio = self.alpha * tf.transpose(nn_log_ratio)   # N x k
    logsumexp_value = tf.reduce_logsumexp(nn_log_ratio, -1)
    log_k_scalar = tf.log(tf.cast(self.num_mc_nn_samples, tf.float32))
    log_k = log_k_scalar * tf.ones([self.batch_size])

    return tf.reduce_sum(logsumexp_value - log_k, -1)
Exemplo n.º 8
0
  def testCrfLogLikelihood(self):
    """Checks that the likelihoods of all possible taggings sum to one."""
    unary_scores = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transitions = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    sequence_lengths = np.array(3, dtype=np.int32)
    num_words, num_tags = unary_scores.shape
    # Tags appended after the real sequence to pad it out to num_words.
    padding = [0] * (num_words - sequence_lengths)
    with self.test_session() as sess:
      log_likelihoods = []

      # Enumerate every tagging of the first `sequence_lengths` words.
      for prefix in itertools.product(
          range(num_tags), repeat=sequence_lengths):
        padded_tags = list(prefix)
        padded_tags.extend(padding)
        log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
            inputs=tf.expand_dims(unary_scores, 0),
            tag_indices=tf.expand_dims(padded_tags, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transitions))
        log_likelihoods.append(log_likelihood)
      # A total probability of one corresponds to a log-sum-exp of zero.
      total_value = sess.run(tf.reduce_logsumexp(log_likelihoods))
      self.assertAllClose(total_value, 0.0)
Exemplo n.º 9
0
 def _log_variance(self):
   """Log of the mixture variance, via the law of total variance.

   Var[Z] = E[Var[Z | V]] + Var[E[Z | V]], where

     Z|v ~ interpolate_affine[v](distribution)
     V ~ mixture_distribution

   so that

     E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
     Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }

   Both sums are accumulated in log space with a single log-sum-exp.
   """
   # log(Var[d]); the component variance equals its rate, so this is log_rate.
   log_component_variance = self.distribution.log_rate
   # log((Mean[d] - Mean)**2)
   mean_deviation = tf.abs(
       self.distribution.mean() - self._mean()[..., tf.newaxis])
   log_squared_deviation = 2. * tf.log(mean_deviation)
   stacked_terms = tf.stack(
       [log_component_variance, log_squared_deviation], axis=-1)
   # Weight both terms by the mixture logits and reduce over the component
   # axis and the two-term axis together.
   weighted_terms = (
       self.mixture_distribution.logits[..., tf.newaxis] + stacked_terms)
   return tf.reduce_logsumexp(weighted_terms, axis=[-2, -1])
Exemplo n.º 10
0
 def _log_prob(self, x):
   """Mixture log-density: log-sum-exp over components of component + mixing log-probs."""
   with tf.control_dependencies(self._runtime_assertions):
     padded_x = self._pad_sample_dims(x)
     component_log_probs = self.components_distribution.log_prob(
         padded_x)  # [S, B, k]
     mixing_log_probs = tf.nn.log_softmax(
         self.mixture_distribution.logits, axis=-1)  # [B, k]
     joint_log_probs = component_log_probs + mixing_log_probs
     return tf.reduce_logsumexp(joint_log_probs, axis=-1)  # [S, B]
 def _assert_valid_sample(self, x):
   """Returns `x`, gated on validity assertions when `validate_args` is set.

   The assertions require every entry of `x` to be non-positive and the
   log-sum-exp over the last axis to be near zero.
   """
   if not self.validate_args:
     return x
   entries_non_positive = tf.assert_non_positive(x)
   zero = tf.zeros([], dtype=self.dtype)
   normalized = tf.assert_near(zero, tf.reduce_logsumexp(x, axis=[-1]))
   return control_flow_ops.with_dependencies(
       [entries_non_positive, normalized], x)
Exemplo n.º 12
0
 def log_prob(self, x):
   """Log-density of `x` under a two-component Normal mixture, summed to a scalar.

   Each component's log-density is summed over the last axis and shifted by
   the log of its mixing weight (`pi` and `1 - pi`); the two are combined by
   log-sum-exp and the result is summed over the remaining entries.
   """
   first_component = tf.contrib.distributions.Normal(self.mu, self.sigma1)
   second_component = tf.contrib.distributions.Normal(self.mu, self.sigma2)
   first_term = (
       tf.reduce_sum(first_component.log_prob(x), -1) + tf.log(self.pi))
   second_term = (
       tf.reduce_sum(second_component.log_prob(x), -1) +
       tf.log(np.float32(1.0 - self.pi)))
   mixture_terms = tf.stack([first_term, second_term])
   return tf.reduce_sum(tf.reduce_logsumexp(mixture_terms, [0]))
Exemplo n.º 13
0
 def eval_func(func):
     """Averages `func` over the S samples, in log space when `logspace` is set.

     Reads `mc_Xr`, `Ys`, `S`, `N`, `logspace` and `settings` from the
     enclosing scope.
     """
     samples = tf.reshape(func(mc_Xr, **Ys), (S, N, -1))
     if not logspace:
         return tf.reduce_mean(samples, axis=0)
     # log-mean-exp: log-sum-exp over the sample axis minus log(S).
     log_S = tf.log(tf.cast(S, settings.float_type))
     return tf.reduce_logsumexp(samples, axis=0) - log_S  # N x D
Exemplo n.º 14
0
 def eval_func(f):
     """Combines evaluations of `f` with the weights `gh_w` and reshapes to `shape`.

     Reads `Xs`, `Ys`, `gh_w`, `logspace` and `shape` from the enclosing scope.
     """
     values = f(*Xs, **Ys)  # f should be elementwise: return shape N x H**Din
     if not logspace:
         # Plain weighted sum, expressed as a matrix-vector product.
         weighted = tf.matmul(values, gh_w.reshape(-1, 1))
         return tf.reshape(weighted, shape)
     # Log-space weighted sum: add the log weights, then log-sum-exp.
     log_weights = np.log(gh_w.reshape(1, -1))
     weighted = tf.reduce_logsumexp(values + log_weights, axis=1)
     return tf.reshape(weighted, shape)
Exemplo n.º 15
0
def reduce_logmeanexp(input_tensor, axis=None, keep_dims=False):
  """Computes log(mean(exp(input_tensor))) over the given axes.

  Args:
    input_tensor: The values to reduce; converted to a tensor.
    axis: The axis or axes to reduce: `None` for all of them, a single
      integer, or a list/tuple of integers.
    keep_dims: Passed on to `tf.reduce_logsumexp`.

  Returns:
    `tf.reduce_logsumexp(input_tensor, axis, keep_dims)` minus the log of the
    number of elements reduced over, taken from the static shape.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  logsumexp = tf.reduce_logsumexp(input_tensor, axis, keep_dims)
  dims = input_tensor.shape.as_list()
  if axis is None:
    reduced_dims = dims
  elif isinstance(axis, (list, tuple)):
    # A Python list cannot be indexed by a list, so pick the reduced
    # dimensions one by one.
    reduced_dims = [dims[i] for i in axis]
  else:
    reduced_dims = dims[axis]
  n = tf.cast(tf.reduce_prod(reduced_dims), logsumexp.dtype)

  return -tf.log(n) + logsumexp
Exemplo n.º 16
0
 def neg_log_likelihood(state):
   """Objective evaluated at `state`; reads `x_data` and `y_data` from the enclosing scope.

   Combines a log(1 + exp(.)) term weighted by `y_data` with half the squared
   norm of `state` minus the sum of the linear outputs.
   """
   state_row = tf.expand_dims(state, 0)
   logits = tf.matmul(state_row, x_data)
   # log-sum-exp over [0, z] is log(1 + exp(z)).
   zero_and_logits = tf.stack([tf.zeros_like(logits), logits], axis=0)
   softplus_logits = tf.reduce_logsumexp(zero_and_logits, axis=0)
   term1 = tf.squeeze(tf.matmul(softplus_logits, y_data), -1)
   half_squared_norm = 0.5 * tf.reduce_sum(state_row * state_row, -1)
   term2 = half_squared_norm - tf.reduce_sum(logits, -1)
   return  tf.squeeze(term1 + term2)
def get_KL_divergence_Sample(shape, mu, sigma, prior, Z):
    """
    Single-sample estimate of the KL divergence between posterior and prior.

    Instead of computing the KL divergence between the variational posterior
    and the prior in closed form, it is evaluated at the sampled weights:

        KL_sample = log q(W | mu, sigma) - log p(W)

    where
        - q is a Normal with parameters `mu` and `sigma`;
        - p combines two zero-mean Normals with scales `prior.sigma1` and
          `prior.sigma2`, weighted by `prior.pi_mix` and `1 - prior.pi_mix`.

    Input:
        - shape: not used by this function.
        - mu, sigma: parameters of the variational posterior.
        - prior: object exposing `sigma1`, `sigma2` and `pi_mix`.
        - Z: the sampled weight values; flattened before evaluation.

    Returns the scalar KL sample.
    """

    # Flatten the sampled weights.
    flat_weights = tf.reshape(Z, [-1])

    # Variational posterior and the two Gaussian components of the prior.
    posterior = Normal(mu, sigma)
    prior_first = Normal(0.0, prior.sigma1)
    prior_second = Normal(0.0, prior.sigma2)

    # sum( log[ q( theta | mu, sigma ) ] )
    posterior_ll = tf.reduce_sum(posterior.log_prob(flat_weights))

    # Log-likelihood under the prior: each component's summed log-density is
    # shifted by its log mixing weight, then the two are combined by
    # log-sum-exp.
    first_term = (tf.reduce_sum(prior_first.log_prob(flat_weights)) +
                  tf.log(prior.pi_mix))
    second_term = (tf.reduce_sum(prior_second.log_prob(flat_weights)) +
                   tf.log(1.0 - prior.pi_mix))
    prior_ll = tf.reduce_logsumexp([first_term, second_term])

    # The KL sample is the difference of the two log-likelihoods.
    return posterior_ll - prior_ll
Exemplo n.º 18
0
 def _log_cdf(self, x):
   """Mixture log-CDF: log-sum-exp over components of category log-prob + component log-CDF."""
   with tf.control_dependencies(self._assertions):
     x = tf.convert_to_tensor(x, name="x")
     component_log_cdfs = [
         component.log_cdf(x) for component in self.components]
     category_log_probs = self._cat_probs(log_probs=True)
     weighted_log_cdfs = [
         category_lp + component_lcdf
         for category_lp, component_lcdf in zip(category_log_probs,
                                                component_log_cdfs)
     ]
     # Stack along a new leading axis and reduce it away.
     return tf.reduce_logsumexp(tf.stack(weighted_log_cdfs, axis=0), [0])
Exemplo n.º 19
0
 def _log_prob(self, x):
   """Mixture log-density: log-sum-exp over components of category log-prob + component log-prob."""
   with tf.control_dependencies(self._assertions):
     x = tf.convert_to_tensor(x, name="x")
     component_log_probs = [
         component.log_prob(x) for component in self.components]
     category_log_probs = self._cat_probs(log_probs=True)
     weighted_log_probs = [
         category_lp + component_lp
         for category_lp, component_lp in zip(category_log_probs,
                                              component_log_probs)
     ]
     # Stack along a new leading axis and reduce it away.
     return tf.reduce_logsumexp(tf.stack(weighted_log_probs, 0), [0])
Exemplo n.º 20
0
 def _forward_log_det_jacobian(self, x):
   """Log-determinant of the forward Jacobian, reduced over the last axis of `x`.

   Args:
     x: The input logits; the last axis is reduced.

   Returns:
     `-log_normalization + sum(x - log_normalization)` over the last axis,
     with that axis squeezed out, where
     `log_normalization = log(1 + sum(exp(x)))`.
   """
   # This code is similar to tf.nn.log_softmax but different because we have
   # an implicit zero column to handle. I.e., instead of:
   #   reduce_sum(logits - log(reduce_sum(exp(logits), dim)))
   # we must do:
   #   log_normalization = log(1 + reduce_sum(exp(logits)))
   #   -log_normalization + reduce_sum(logits - log_normalization)
   # softplus(logsumexp(x)) == log(1 + sum(exp(x))); the 1 is the implicit
   # zero logit.
   # `keepdims` is used in both reductions; `keep_dims` is its deprecated
   # spelling.
   log_normalization = tf.nn.softplus(
       tf.reduce_logsumexp(x, axis=-1, keepdims=True))
   return tf.squeeze(
       (-log_normalization + tf.reduce_sum(
           x - log_normalization, axis=-1, keepdims=True)),
       axis=-1)
Exemplo n.º 21
0
  def marginal_log_prob(self, x, **kwargs):
    """The marginal log probability of the observed variable. Sums out `cat`.

    `kwargs` is accepted but not used.
    """
    # Position (counted from the end) of the num_components dimension once
    # the batch and event dimensions are accounted for.
    component_axis = -1 - (self.event_shape.ndims + self.batch_shape.ndims)
    # expand x to broadcast log probs over num_components dimension
    component_log_probs = self.components.log_prob(
        tf.expand_dims(x, component_axis))

    # Move the last axis of the category probabilities to the front.
    probs_rank = self.cat.probs.shape.ndims
    perm = tf.concat([[probs_rank - 1], tf.range(probs_rank - 1)], 0)
    log_mixing = tf.log(tf.transpose(self.cat.probs, perm))

    return tf.reduce_logsumexp(component_log_probs + log_mixing,
                               component_axis)
Exemplo n.º 22
0
  def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
    """Implements one timestep of the particle filter.

    Besides its arguments, this closure reads `mask_ta`, `transition`,
    `num_particles`, `batch_size`, `resampling_criterion`, `resampling_fn`,
    `loop_fn` and `max_num_steps` from the enclosing scope.

    Args:
      t: The current step; used to index the TensorArrays and passed to
        `transition`, `resampling_criterion` and `loop_fn`.
      state: A pair `(particle_state, loop_state)`.
      tas: TensorArrays that receive, in order, the accumulated log weights
        and the resampling indicator.
      log_weights_acc: The accumulated log weights; reshaped increments of
        shape `[num_particles, batch_size]` are added to it.
      log_z_hat_acc: The accumulated log Z hat estimate.

    Returns:
      A tuple `(t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc)`
      with `new_state = (new_particle_state, new_loop_state)`.
    """
    particle_state, loop_state = state
    cur_mask = nested.read_tas(mask_ta, t)
    # Propagate the particles one step.
    log_alpha, new_particle_state, loop_args = transition(particle_state, t)
    # Update the current weights with the incremental weights.
    log_alpha *= cur_mask
    log_alpha = tf.reshape(log_alpha, [num_particles, batch_size])
    log_weights_acc += log_alpha

    should_resample = resampling_criterion(log_weights_acc, t)

    # With the never-resample criterion the resampled states are not built.
    if resampling_criterion == never_resample_criterion:
      resampled = tf.to_float(should_resample)
    else:
      # Compute the states as if we did resample.
      resampled_states = resampling_fn(
          log_weights_acc,
          new_particle_state,
          num_particles,
          batch_size)
      # Decide whether or not we should resample; don't resample if we are past
      # the end of a sequence.
      should_resample = tf.logical_and(should_resample,
                                       cur_mask[:batch_size] > 0.)
      float_should_resample = tf.to_float(should_resample)
      # Tile the per-filter decision across particles before selecting states.
      new_particle_state = nested.where_tensors(
          tf.tile(should_resample, [num_particles]),
          resampled_states,
          new_particle_state)
      resampled = float_should_resample

    new_loop_state = loop_fn(loop_state, loop_args, new_particle_state,
                             log_weights_acc, resampled, cur_mask, t)
    # Update log Z hat.
    log_z_hat_update = tf.reduce_logsumexp(
        log_weights_acc, axis=0) - tf.log(tf.to_float(num_particles))
    # If it is the last timestep, always add the update.
    log_z_hat_acc += tf.cond(t < max_num_steps - 1,
                             lambda: log_z_hat_update * resampled,
                             lambda: log_z_hat_update)
    # Update the TensorArrays before we reset the weights so that we capture
    # the incremental weights and not zeros.
    ta_updates = [log_weights_acc, resampled]
    new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
    # For the particle filters that resampled, reset weights to zero.
    log_weights_acc *= (1. - tf.tile(resampled[tf.newaxis, :],
                                     [num_particles, 1]))
    new_state = (new_particle_state, new_loop_state)
    return t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc
  def _log_prob(self, y):
    """Log-density of `y`, combining all pre-image fibers when the bijector is not injective."""
    # For caching to work, it is imperative that the bijector is the first to
    # modify the input.
    x = self.bijector.inverse(y)
    event_ndims = self._maybe_get_static_event_ndims()

    ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      # `x` and `ildj` are iterated pairwise, one entry per fiber; the
      # per-fiber log-densities are combined with a log-sum-exp.
      fiber_log_probs = tf.stack([
          self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
          for x_i, ildj_i in zip(x, ildj)])
      return tf.reduce_logsumexp(fiber_log_probs, axis=0)

    return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims)
 def testLogCdf(self):
   """Compares MixtureSameFamily.log_cdf with a direct log-sum-exp over components."""
   mixture = tfd.Categorical(probs=[0.3, 0.7])
   components = tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5])
   gm = tfd.MixtureSameFamily(
       mixture_distribution=mixture,
       components_distribution=components)
   x = gm.sample(10, seed=42)
   actual_log_cdf = gm.log_cdf(x)
   # Reference value: mixture logits plus per-component log-CDFs, reduced
   # over the component axis.
   component_log_cdfs = gm.components_distribution.log_cdf(
       x[..., tf.newaxis])
   expected_log_cdf = tf.reduce_logsumexp(
       gm.mixture_distribution.logits + component_log_cdfs, axis=1)
   actual_value, expected_value = self.evaluate(
       [actual_log_cdf, expected_log_cdf])
   self.assertAllClose(actual_value, expected_value, rtol=1e-6, atol=0.0)
Exemplo n.º 25
0
def iwae(model, observation, num_timesteps, num_samples=1,
         summarize=False):
  """Compute the IWAE evidence lower bound.

  Args:
    model: A callable that computes one timestep of the model. It must also
      provide `zero_state(num_instances)` for the initial state.
    observation: A shape [batch_size*num_samples, state_size] Tensor
      containing z_n, the observation for each sequence in the batch.
    num_timesteps: The number of timesteps in each sequence, an integer.
    num_samples: The number of samples to use to compute the IWAE bound.
    summarize: If true, a scalar summary of the mean weight entropy is
      written for every timestep.
  Returns:
    log_p_hat: The IWAE estimator of the lower bound on the log marginal,
      divided by num_timesteps.
    losses: A single-element list holding a `Loss` named "log_p_hat" whose
      value is the negated mean of log_p_hat.
    maintain_ema_op: A no-op included for compatibility with FIVO.
    states: The sequence of states sampled, without the initial state.
    log_weights: The accumulated log weights after each timestep.
  """
  # Initialization
  num_instances = tf.shape(observation)[0]
  batch_size = tf.cast(num_instances / num_samples, tf.int32)
  prev_state = model.zero_state(num_instances)
  sampled_states = []
  log_weights = []
  log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)

  for t in xrange(num_timesteps):
    # Advance the model by one timestep from the most recent state.
    zt, log_q_zt, log_p_zt, log_p_x_given_z, _ = model(
        prev_state, observation, t)
    # The initial state is never recorded, only the sampled ones.
    sampled_states.append(zt)
    prev_state = zt
    step_log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    log_weight_acc += tf.reshape(step_log_weight, [num_samples, batch_size])
    if summarize:
      weight_dist = tf.contrib.distributions.Categorical(
          logits=tf.transpose(log_weight_acc, perm=[1, 0]),
          allow_nan_stats=False)
      mean_entropy = tf.reduce_mean(weight_dist.entropy())
      tf.summary.scalar("weight_entropy/%d" % t, mean_entropy)
    log_weights.append(log_weight_acc)
  # Lower bound on the log evidence: log-mean-exp of the final weights over
  # the samples, normalised by the number of timesteps.
  log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
               tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
  losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat))]

  # there are no emas for iwae, so we return a noop for that
  return log_p_hat, losses, tf.no_op(), sampled_states, log_weights
Exemplo n.º 26
0
  def _log_prob(self, value):
    """Log-probability of observation sequences, marginalised over hidden states.

    Starting from `self._log_init`, each step folds in the transition
    log-probabilities (`self._log_trans`, through `_log_vector_matrix`) and
    the log-probability of that step's observation. The final per-state values
    are reduced with a log-sum-exp over the state axis.

    Args:
      value: A tensor of sequences of observations.

    Returns:
      A tensor of log-probabilities with shape `working_shape`, the broadcast
      of the observation batch shape with the distribution's batch shape.
    """
    with tf.control_dependencies(self._runtime_assertions):
      # The argument `value` is a tensor of sequences of observations.
      # `observation_batch_shape` is the shape of that tensor with the
      # sequence part removed.
      # `observation_batch_shape` is then broadcast to the full batch shape
      # to give the `working_shape` that defines the shape of the result.

      observation_batch_shape = tf.shape(
          value)[:-1 - self._underlying_event_rank]
      # value :: observation_batch_shape num_steps observation_event_shape
      working_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())
      log_init = tf.broadcast_to(self._log_init,
                                 tf.concat([working_shape,
                                            [self._num_states]], axis=0))
      # log_init :: working_shape num_states
      log_transition = self._log_trans

      # `observation_event_shape` is the shape of each sequence of observations
      # emitted by the model.
      observation_event_shape = tf.shape(
          value)[-1 - self._underlying_event_rank:]
      working_obs = tf.broadcast_to(value,
                                    tf.concat([working_shape,
                                               observation_event_shape],
                                              axis=0))
      # working_obs :: working_shape observation_event_shape
      r = self._underlying_event_rank

      # Move index into sequence of observations to front so we can apply
      # tf.foldl
      working_obs = util.move_dimension(working_obs,
                                        -1 - r, 0)[..., tf.newaxis]
      # working_obs :: num_steps working_shape underlying_event_shape
      observation_probs = (
          self._observation_distribution.log_prob(working_obs))

      # One fold step: combine the previous per-state values with the
      # transitions, then add this step's observation log-probabilities.
      def forward_step(log_prev_step, log_observation):
        return _log_vector_matrix(log_prev_step,
                                  log_transition) + log_observation

      fwd_prob = tf.foldl(forward_step, observation_probs, initializer=log_init)
      # fwd_prob :: working_shape num_states

      log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
      # log_prob :: working_shape

      return log_prob
Exemplo n.º 27
0
def logsumexp(v, reduction_indices=None, keep_dims=False):
    """Computes log(sum(exp(v))) over `reduction_indices`.

    Uses `tf.reduce_logsumexp` (wrapped in `handle_inf`) when the installed
    TensorFlow is newer than 0.10, and a max-shifted manual computation
    otherwise.

    Args:
        v: The tensor to reduce.
        reduction_indices: The axis or axes to reduce; `None` reduces all.
        keep_dims: Whether reduced dimensions are kept with length 1.

    Returns:
        The reduced tensor.
    """
    # Compare (major, minor) as integers. Parsing `tf.__version__[:4]` as a
    # float raises ValueError for versions such as "1.0.0" (the slice is
    # "1.0.") and orders versions incorrectly as decimals.
    major, minor = (int(part) for part in tf.__version__.split('.')[:2])
    if (major, minor) > (0, 10): # reduce_logsumexp does not exist below tfv0.11
        if isinstance(reduction_indices, int): # due to a bug in tfv0.11
            reduction_indices = [reduction_indices]
        return handle_inf(
                 tf.reduce_logsumexp(v,
                  reduction_indices, # this is a bit fragile. reduction_indices got renamed to axis in tfv0.12
                  keep_dims=keep_dims)
                 )
    else:
        # The max must keep its reduced dimensions so that `v - m_kept`
        # broadcasts along the reduced axes.
        m_kept = tf.reduce_max(v, reduction_indices=reduction_indices, keep_dims=True)
        if keep_dims:
            m = m_kept
        else:
            m = tf.reduce_max(v, reduction_indices=reduction_indices, keep_dims=False)
        # Use SMALL_NUMBER to handle v = []
        return m + tf.log(tf.reduce_sum(tf.exp(v - m_kept),
                        reduction_indices=reduction_indices,
                        keep_dims=keep_dims) + SMALL_NUMBER)
Exemplo n.º 28
0
def SampledSoftmaxLoss(features, sampler, num_classes, target_classes,
                       target_params, sampled_classes, sampled_params):
  """Sampled-softmax loss for classifiers over a large label vocabulary.

  The sampled classes are expected to have been drawn already, and the
  parameters of both the target and the sampled classes already fetched.

  Args:
    features: a Tensor with shape [batch_size, hidden_size]
    sampler: a candidate sampler object; its `log_expected_count` method is
      used to correct the logits.
    num_classes: an integer
    target_classes: an integer Tensor with shape [batch_size]
    target_params: a Tensor with shape [batch_size, hidden_size] holding the
      parameters of each example's target class.
    sampled_classes: an integer tensor with shape [num_sampled_classes]
    sampled_params: a Tensor with shape [num_sampled_classes, hidden_size]
      holding the parameters of the sampled classes.

  Returns:
    a Tensor with shape [batch_size]
  """
  # Score every example against every sampled class, then subtract the
  # sampler's log expected count.
  sampled_scores = tf.matmul(features, sampled_params, transpose_b=True)
  sampled_logits = sampled_scores - sampler.log_expected_count(sampled_classes)

  # Score every example against its own target class, with the same correction.
  target_scores = tf.reduce_sum(target_params * features, 1)
  target_logits = target_scores - sampler.log_expected_count(target_classes)

  sampled_log_denominator = tf.reduce_logsumexp(
      sampled_logits, [1], name='SampledLogDenominator')

  # Per-class additive mask: -inf for classes present in the sample, 0
  # elsewhere. Adding it to the target logit removes the target term from the
  # denominator when the target was also sampled.
  neg_inf_per_sample = tf.fill(tf.shape(sampled_classes), float('-inf'))
  sampled_classes_mask = tf.unsorted_segment_sum(
      neg_inf_per_sample, sampled_classes, num_classes)
  target_log_denominator = target_logits + tf.gather(
      sampled_classes_mask, target_classes)

  denominator_terms = tf.stack(
      [sampled_log_denominator, target_log_denominator])
  combined_log_denominator = tf.reduce_logsumexp(denominator_terms, [0])
  return combined_log_denominator - target_logits
Exemplo n.º 29
0
 def _log_prob(self, x):
   """Return the log-probability of `x` under the mixture.

   Each affine in `self.interpolated_affine` contributes one term; the terms
   are combined with the mixture logits through a log-sum-exp over the last
   axis.
   """
   # Grid points live on the right-most axis by convention.
   inverted = []
   for aff in self.interpolated_affine:
     inverted.append(aff.inverse(x))
   y = tf.stack(inverted, axis=-1)
   log_prob = tf.reduce_sum(self.distribution.log_prob(y), axis=-2)

   # An affine transformation has a constant Jacobian, hence
   # `affine.fldj(x) = -affine.ildj(x)`; this does not hold in general.
   per_affine_fldj = []
   for aff in self.interpolated_affine:
     per_affine_fldj.append(
         aff.forward_log_det_jacobian(
             x, event_ndims=tf.rank(self.event_shape_tensor())))
   fldj = tf.stack(per_affine_fldj, axis=-1)

   mixture_terms = self.mixture_distribution.logits - fldj + log_prob
   return tf.reduce_logsumexp(mixture_terms, axis=-1)
Exemplo n.º 30
0
    def body(handle, cost, correct, total, *arrays):
      """Runs the network and advances the state by a step.

      Args:
        handle: state handle passed to the DRAGNN ops that emit oracle labels
          and advance the state.
        cost: running loss accumulator; this step's loss is added to it.
        correct: running count of top-1 correct predictions.
        total: running count of gold labels that were scored.
        *arrays: objects exposing a `.flow` attribute (presumably the
          TensorArrays the layers are written to -- confirm against caller).

      Returns:
        A list `[handle, cost, correct, total]` followed by the updated arrays.
      """

      # Make everything built below depend on the loop-carried inputs.
      with tf.control_dependencies([handle, cost, correct, total] +
                                   [x.flow for x in arrays]):
        # Get a copy of the network inside this while loop.
        # NOTE: `state`, `network_states` and `stride` come from the enclosing
        # scope, which is not visible in this block.
        updated_state = MasterState(handle, state.current_batch_size)
        network_tensors = self._feedforward_unit(
            updated_state, arrays, network_states, stride, during_training=True)

        # Every layer is written to a TensorArray, so that it can be backprop'd.
        next_arrays = update_tensor_arrays(network_tensors, arrays)
        # The loss is only built after the array writes above.
        with tf.control_dependencies([x.flow for x in next_arrays]):
          with tf.name_scope('compute_loss'):
            # A gold label > -1 determines that the sentence is still
            # in a valid state. Otherwise, the sentence has ended.
            #
            # We add only the valid sentences to the loss, in the following way:
            #   1. We compute 'valid_ix', the indices in gold that contain
            #      valid oracle actions.
            #   2. We compute the cost function by comparing logits and gold
            #      only for the valid indices.
            gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
            gold.set_shape([None])
            valid = tf.greater(gold, -1)
            valid_ix = tf.reshape(tf.where(valid), [-1])
            gold = tf.gather(gold, valid_ix)

            # Keep only the logits rows that correspond to valid gold labels.
            logits = self.network.get_logits(network_tensors)
            logits = tf.gather(logits, valid_ix)

            # Summed (not averaged) cross-entropy over the valid rows.
            cost += tf.reduce_sum(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=tf.cast(gold, tf.int64), logits=logits))

            # Optional self-normalization penalty: alpha * l2_loss of the
            # per-row log partition function log_z.
            if (self.eligible_for_self_norm and
                self.master.hyperparams.self_norm_alpha > 0):
              log_z = tf.reduce_logsumexp(logits, [1])
              cost += (self.master.hyperparams.self_norm_alpha *
                       tf.nn.l2_loss(log_z))

            # Count rows whose gold label is the top-1 prediction.
            correct += tf.reduce_sum(
                tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
            total += tf.size(gold)

        # Advance the state only after this step's statistics are computed.
        with tf.control_dependencies([cost, correct, total, gold]):
          handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
        return [handle, cost, correct, total] + next_arrays
Exemplo n.º 31
0
    def _build(self, model):
        """Assemble the training loss and the tensors to be logged.

        The loss is the mean per-example negative log-likelihood plus the sum
        of `model.kl_losses` divided by `model.dataset.n_samples_train`.
        When `self._use_alpha` is set, the targets are tiled `n_samples` times
        and the per-draw log-probabilities are pooled per example: with a
        log-sum-exp scaled by `self._alpha_parameter` when that parameter is
        non-zero, with a plain mean otherwise.

        Args:
            model: object exposing `n_samples_ph`, `y`, `prediction_distr`,
                `kl_losses`, `dataset.n_samples_train` and `prediction_mean`.

        Returns:
            Tuple `(loss, loss_per_sample, nodes_to_log,
            names_of_nodes_to_log, filenames_to_log_to)`; the last three come
            from `create_panels_lists`.
        """
        n_samples = model.n_samples_ph
        y = model.y
        shaper = tf.shape(y)
        distr = model.prediction_distr

        if not self._use_alpha:
            loss_per_sample = -distr.log_prob(y)
            loss_core = loss_per_sample
        else:
            y_tile = tf.tile(y, [n_samples, 1])
            loss_core = -distr.log_prob(y_tile)
            if self._alpha_parameter != 0:
                # Pool alpha-scaled log-probabilities over the draws in log
                # space, then undo the scaling with -1/alpha.
                scaled_log_prob = tf.scalar_mul(self._alpha_parameter,
                                                distr.log_prob(y_tile))
                scaled_by_draw = tf.reshape(scaled_log_prob,
                                            (n_samples, shaper[0]))
                pooled = tf.reduce_logsumexp(scaled_by_draw, axis=0)
                loss_per_sample = tf.scalar_mul(-1. / self._alpha_parameter,
                                                pooled)
            else:
                # alpha == 0: plain average of the negative log-probabilities.
                neg_log_prob = -distr.log_prob(y_tile)
                neg_log_prob_by_draw = tf.reshape(neg_log_prob,
                                                  (n_samples, shaper[0]))
                loss_per_sample = tf.reduce_mean(neg_log_prob_by_draw, axis=0)

        nll = tf.reduce_mean(loss_per_sample, name="nll")
        kl_losses = model.kl_losses
        total_KL = tf.reduce_sum(kl_losses) / model.dataset.n_samples_train
        loss = nll + total_KL
        nll_core = tf.reduce_mean(loss_core, name="nll_core")

        # Average the predicted means over the draws before computing the
        # squared error against the targets.
        means = tf.reshape(model.prediction_mean,
                           (n_samples, shaper[0], shaper[1]))
        means = tf.reduce_mean(means, axis=0)

        mse_per_sample = tf.reduce_sum(tf.square(y - means), axis=1)
        mse = tf.reduce_mean(mse_per_sample)

        # The first panel is the one shown on screen during training.
        panel_specs = (
            (loss, "loss", "loss"),
            (nll, "NLL", "negloglikelihood"),
            (mse, "mse", "mse"),
        )
        first_vpanel = [
            {
                'nodes': [node],
                'names': [label],
                'output': {'fileName': file_name},
            }
            for node, label, file_name in panel_specs
        ]
        list_of_vpanels_of_plots = [first_vpanel]

        nodes_to_log, names_of_nodes_to_log, filenames_to_log_to = (
            create_panels_lists(list_of_vpanels_of_plots))

        return (loss, loss_per_sample, nodes_to_log, names_of_nodes_to_log,
                filenames_to_log_to)
Exemplo n.º 32
0
    def build_loss_and_gradients(self, var_list):
        r"""Build loss function

        $\text{KL}( p(z \mid x) \| q(z) )
          = \mathbb{E}_{p(z \mid x)} [ \log p(z \mid x) - \log q(z; \lambda) ]$

        and stochastic gradients based on importance sampling.

        The loss function can be estimated as

        $\frac{1}{S} \sum_{s=1}^S [
          w_{\text{norm}}(z^s; \lambda) (\log p(x, z^s) - \log q(z^s; \lambda)) ],$

        where for $z^s \sim q(z; \lambda)$,

        $w_{\text{norm}}(z^s; \lambda) =
              w(z^s; \lambda) / \sum_{s=1}^S w(z^s; \lambda)$

        normalizes the importance weights, $w(z^s; \lambda) = p(x,
        z^s) / q(z^s; \lambda)$.

        This provides a gradient,

        $- \frac{1}{S} \sum_{s=1}^S [
          w_{\text{norm}}(z^s; \lambda) \nabla_{\lambda} \log q(z^s; \lambda) ].$

        Args:
          var_list: variables with respect to which the gradients are taken.

        Returns:
          A pair `(loss, grads_and_vars)` where `grads_and_vars` is a list of
          `(gradient, variable)` tuples aligned with `var_list`.
        """
        # One accumulated log-density per sample, for the model (p) and for
        # the approximation (q).
        p_log_prob = [0.0] * self.n_samples
        q_log_prob = [0.0] * self.n_samples
        for s in range(self.n_samples):
            # Form dictionary in order to replace conditioning on prior or
            # observed variable with conditioning on a specific value.
            scope = 'inference_' + str(id(self)) + '/' + str(s)
            dict_swap = {}
            for x, qx in six.iteritems(self.data):
                if isinstance(x, RandomVariable):
                    if isinstance(qx, RandomVariable):
                        # Data bound to another random variable: use the value
                        # of a fresh copy of it.
                        qx_copy = copy(qx, scope=scope)
                        dict_swap[x] = qx_copy.value()
                    else:
                        dict_swap[x] = qx

            for z, qz in six.iteritems(self.latent_vars):
                # Copy q(z) to obtain new set of posterior samples.
                qz_copy = copy(qz, scope=scope)
                dict_swap[z] = qz_copy.value()
                # The sample is treated as a constant inside log q, so
                # gradients do not flow through the sampled value here.
                q_log_prob[s] += tf.reduce_sum(
                    qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

            # Log-density of the latent variables under the model, evaluated
            # at the sampled values.
            for z in six.iterkeys(self.latent_vars):
                z_copy = copy(z, dict_swap, scope=scope)
                p_log_prob[s] += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

            # Log-density of the observed random variables under the model.
            for x in six.iterkeys(self.data):
                if isinstance(x, RandomVariable):
                    x_copy = copy(x, dict_swap, scope=scope)
                    p_log_prob[s] += tf.reduce_sum(
                        x_copy.log_prob(dict_swap[x]))

        # Turn the per-sample Python lists into tensors.
        p_log_prob = tf.stack(p_log_prob)
        q_log_prob = tf.stack(q_log_prob)

        if self.logging:
            summary_key = 'summaries_' + str(id(self))
            tf.summary.scalar("loss/p_log_prob",
                              tf.reduce_mean(p_log_prob),
                              collections=[summary_key])
            tf.summary.scalar("loss/q_log_prob",
                              tf.reduce_mean(q_log_prob),
                              collections=[summary_key])

        # Self-normalized importance weights, computed in log space.
        log_w = p_log_prob - q_log_prob
        log_w_norm = log_w - tf.reduce_logsumexp(log_w)
        w_norm = tf.exp(log_w_norm)

        loss = tf.reduce_mean(w_norm * log_w)
        # The weights are held constant in the gradient surrogate, so only
        # log q is differentiated.
        grads = tf.gradients(
            -tf.reduce_mean(q_log_prob * tf.stop_gradient(w_norm)), var_list)
        grads_and_vars = list(zip(grads, var_list))
        return loss, grads_and_vars
Exemplo n.º 33
0
 def log_posterior(self, x):
     """Return `self.unnor_logpdf(x)` normalised along axis 2.

     The log-sum-exp over axis 2 (kept as a size-1 axis so it broadcasts) is
     subtracted from the unnormalised log-densities.
     """
     unnormalised = self.unnor_logpdf(x)
     log_normaliser = tf.reduce_logsumexp(unnormalised, axis=2, keepdims=True)
     return unnormalised - log_normaliser
Exemplo n.º 34
0
def seq_log_probs_to_word_log_probs(get_beam_outputs,
                                    get_sequence_log_probs,
                                    Nclasses,
                                    index_sequences_elements,
                                    max_targ_length,
                                    padding_value=0):
    '''
    Spread beam-level sequence log probabilities over individual token ids.

    :param get_beam_outputs: (Nsequences x beam_width x max_prediction_length)
    :param get_sequence_log_probs: (Nsequences x beam_width)
    :param Nclasses: scalar
    :param index_sequences_elements: (sum_i^Nsequences seq_len(i) x 2), a list
        of all the (putative) non-zero indices in the tensor of sequences
    :param max_targ_length: scalar tensor
    :param padding_value: constant used to fill the steps added by tf.pad
    :return: score_as_unnorm_log_probs: (sum_i^Nsequences seq_len(i) x Nclasses),
        a tensor of log probabilities for each id, de-sequenced

    A beam search naturally returns the K most probable sequences together
    with their (not necessarily normalized) log probabilities, K=beam_width.
    Here those scores are expanded to cover *all* token ids, not only the ones
    that appear in a beam.  Conceptually: exponentiate the log probabilities
    for each element of each sequence, work out the "leftover" probability of
    the ids outside the beam, share it equally among them, and take the
    elementwise logarithm.  In practice care is needed to avoid over- and
    underflow, so everything is done in log space.

    To keep the computation cheap, the "leftover" probability is approximated:
    every non-selected id gets probability 1/S, with S the total number of
    possible sequences.  This pretends that every non-selected *sequence* is
    equally likely (1/S) and that (certainly falsely) each non-selected token
    at each time step *in each beam* belongs to exactly one such sequence.  So
    even if token 324 occurs in some beam at step t, it still gets probability
    1/S at step t in every beam where it does *not* occur.  This is what makes
    summing across beams straightforward.

    Counting sequences: ignoring end-of-sequence tokens, a vocabulary of size
    N and a maximum length M give N sequences ending at step 1, N^2 ending at
    step 2, and so on up to N^M, i.e.
            N^1 + N^2 + N^3 + ... + N^M
        =   N^0 + N^1 + N^2 + N^3 + ... + N^M - 1
        =   (N^(M+1) - 1)/(N - 1) - 1
        ~=  N^M
    because for any reasonable N or M the -1s are negligible, as is removing
    the K in-beam sequences.  Each out-of-beam sequence therefore has
    probability of about N^-M:
        log(out_beam_prob) = -M*log(N)

    Because of these approximations, and above all because the in-beam
    probabilities are never reduced by the mass handed to out-of-beam ids, the
    logsumexp result is an *unnormalized* log probability.  The values are
    de-sequenced into shape
        (sum_i^Ncases targ_seq_len(i) x Nclasses)
    before being returned.
    '''

    # One-hot encode the beam ids and scale each beam by its sequence score
    #   -> (Ncases x beam_width x max_pred_length x Nclasses)
    # NB: this is *not* a tensor of log probs, since every out-of-beam
    #  location holds a *zero*.
    one_hot_ids = tf.one_hot(get_beam_outputs, Nclasses, axis=-1)
    per_beam_scores = tf.expand_dims(
        tf.expand_dims(get_sequence_log_probs, axis=-1), axis=-1)
    in_beam_log_probs = tf.multiply(one_hot_ids, per_beam_scores)

    # Extend the time axis by max(max_targ_length - max_pred_length, 0) + 1
    # steps, filled with padding_value.
    max_pred_length = common_layers.shape_list(get_beam_outputs)[2]
    extra_steps = tf.maximum(max_targ_length - max_pred_length, 0) + 1
    time_axis_padding = [[0, 0], [0, 0], [0, extra_steps], [0, 0]]
    in_beam_log_probs = tf.pad(
        tensor=in_beam_log_probs,
        paddings=time_axis_padding,
        constant_values=padding_value)
    ###
    # This assumes the pad token=0.  Ideally it would be passed in explicitly
    #  and used as constant_values=<pad value> in tf.pad.
    ###

    # Replace the zeros with the approximate out-of-beam log prob,
    # -M*log(N) (see the docstring).
    negative_max_length = tf.cast(-max_targ_length, tf.float32)
    log_vocab_size = tf.math.log(tf.cast(Nclasses, tf.float32))
    out_beam_log_prob = tf.multiply(negative_max_length, log_vocab_size)
    out_beam_log_probs = tf.fill(common_layers.shape_list(in_beam_log_probs),
                                 out_beam_log_prob)
    is_out_of_beam = tf.equal(in_beam_log_probs, 0)
    beam_log_probs = tf.compat.v1.where(is_out_of_beam, out_beam_log_probs,
                                        in_beam_log_probs)

    # Collapse the beam axis -> (Ncases x max_targ_length x Nclasses)
    collapsed_over_beams = tf.reduce_logsumexp(beam_log_probs, axis=1)

    # De-sequence -> (sum_i^Ncases targ_seq_len(i) x Nclasses)
    return tf.gather_nd(collapsed_over_beams, index_sequences_elements)
Exemplo n.º 35
0
    def __init__(self, s_size, a_size, scope, trainer):
        """Build the network graph under variable scope `scope`.

        A three-layer convolutional trunk followed by a 512-unit dense layer
        feeds two heads of width `a_size`: a softmax `policy` head and a
        linear `q` head. For every scope other than 'global', the TD loss, the
        SVGD surrogate loss, and the op that applies the clipped local
        gradients to the 'global' variables are built as well.

        Args:
            s_size: width of the flat input placeholder. The input is reshaped
                to [-1, 84, 84, 1], so this is presumably 84 * 84 -- confirm
                against caller.
            a_size: number of actions (width of both output heads).
            scope: variable scope name; 'global' builds only the forward pass.
            trainer: optimizer-like object providing `apply_gradients`.

        NOTE(review): `tau` and `k` are read below but are not defined in this
        method; they are expected to exist in an enclosing/module scope.
        """
        with tf.variable_scope(scope):
            # Kernel function; rebound to the kernel *tensor* further down.
            self.kernel = adaptive_isotropic_gaussian_kernel
            self.inputs = tf.placeholder(shape=[None, s_size],
                                         dtype=tf.float32)
            self.imageIn = tf.reshape(self.inputs, shape=[-1, 84, 84, 1])
            # Convolutional trunk.
            self.conv1 = slim.conv2d(activation_fn=tf.nn.relu,
                                     inputs=self.imageIn,
                                     num_outputs=32,
                                     kernel_size=[8, 8],
                                     stride=[4, 4],
                                     padding='VALID')
            self.conv2 = slim.conv2d(activation_fn=tf.nn.relu,
                                     inputs=self.conv1,
                                     num_outputs=64,
                                     kernel_size=[4, 4],
                                     stride=[2, 2],
                                     padding='VALID')
            self.conv3 = slim.conv2d(activation_fn=tf.nn.relu,
                                     inputs=self.conv2,
                                     num_outputs=64,
                                     kernel_size=[3, 3],
                                     stride=[1, 1],
                                     padding='VALID')
            hidden = slim.fully_connected(slim.flatten(self.conv3),
                                          512,
                                          activation_fn=tf.nn.relu)

            # Softmax policy head.
            self.policy = slim.fully_connected(
                hidden,
                a_size,
                activation_fn=tf.nn.softmax,
                weights_initializer=normalized_columns_initializer(0.01),
                biases_initializer=None)
            # Linear Q-value head.
            self.q = slim.fully_connected(
                hidden,
                a_size,
                activation_fn=None,
                weights_initializer=normalized_columns_initializer(0.01),
                biases_initializer=None)

            # Only non-global (worker) networks get losses and update ops.
            if scope != 'global':
                local_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope)
                global_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
                self.actions = tf.placeholder(shape=[None], dtype=tf.int32)
                self.actions_onehot = tf.one_hot(self.actions,
                                                 a_size,
                                                 dtype=tf.float32)
                #self.rewards = tf.placeholder(shape=[None],dtype=tf.float32)

                #  td error update
                q_a = self.q / tau
                # Log-sum-exp of the temperature-scaled Q-values over actions.
                self.v_next = tf.reduce_logsumexp(q_a, axis=1)
                self.q_target = tf.placeholder(shape=[None], dtype=tf.float32)
                # Q-value of the action actually taken.
                self.readout_action = tf.reduce_sum(tf.multiply(
                    self.q, self.actions_onehot),
                                                    axis=1)
                self.td_loss = tf.reduce_mean(
                    0.5 * tf.square(self.q_target - self.readout_action))

                #  svgd update
                # Two sets of k action indices per example.
                self.ai = tf.placeholder(shape=[None, k], dtype=tf.int32)
                self.aj = tf.placeholder(shape=[None, k], dtype=tf.int32)
                self.ai_onehot = tf.one_hot(self.ai, a_size, dtype=tf.float32)
                self.aj_onehot = tf.one_hot(self.aj, a_size, dtype=tf.float32)
                # One-hot actions weighted by the policy probabilities.
                self.readout_actionj = self.aj_onehot * tf.expand_dims(
                    self.policy, axis=1)
                self.readout_actioni = self.ai_onehot * tf.expand_dims(
                    self.policy, axis=1)

                self.Q_soft = tf.expand_dims(self.q,
                                             axis=1) * self.readout_actioni
                # Gradient of Q_soft w.r.t. the i-actions; this must be taken
                # before readout_actioni is replaced by its stop_gradient
                # version below.
                Q_soft_grad = tf.gradients(self.Q_soft,
                                           self.readout_actioni)[0]
                self.Q_soft_grad = tf.expand_dims(Q_soft_grad, axis=2)
                self.Q_soft_grad = tf.stop_gradient(self.Q_soft_grad)

                self.readout_actioni = tf.stop_gradient(self.readout_actioni)
                # From here on self.kernel is the kernel tensor, no longer the
                # kernel function assigned at the top.
                self.kernel, self.kernel_grad = self.kernel(
                    self.readout_actioni, self.readout_actionj)
                self.kernel = tf.expand_dims(self.kernel, axis=3)

                # Mean over axis 1 of (kernel * grad Q + grad kernel).
                self.action_gradient = tf.reduce_mean(
                    self.kernel * self.Q_soft_grad + self.kernel_grad, axis=1)
                self.action_gradients = tf.stop_gradient(self.action_gradient)

                # Surrogate whose gradient w.r.t. the network parameters flows
                # only through readout_actionj (action_gradients is constant).
                self.su = tf.reduce_sum(tf.reduce_sum(self.action_gradients *
                                                      self.readout_actionj,
                                                      axis=2),
                                        axis=1)
                self.surrogate_loss = -tf.reduce_mean(self.su)

                #action_gradients = tf.reduce_mean(self.kernel * self.Q_soft_grad + self.kernel_grad, axis=1)
                #self.gradient = tf.gradients(self.readout_actionj, local_vars, grad_ys=action_gradients)
                #self.surrogate_loss = -tf.reduce_sum(local_vars * tf.stop_gradient(self.gradient))

                # total loss
                self.loss = self.td_loss + self.surrogate_loss
                self.gradients = tf.gradients(self.loss, local_vars)
                self.var_norms = tf.global_norm(local_vars)
                # Clip by global norm, then apply the local gradients to the
                # 'global' variables.
                grads, self.grad_norms = tf.clip_by_global_norm(
                    self.gradients, 40.0)
                self.apply_grads = trainer.apply_gradients(
                    zip(grads, global_vars))
Exemplo n.º 36
0
    def _build_graph(self):
        self.context_word = tf.placeholder(tf.int32,
                                           [self.document_size, None, None])
        self.context_len = tf.placeholder(tf.int32, [self.document_size, None])

        self.context_char = tf.placeholder(
            tf.int32, [self.document_size, None, None, None])
        self.context_word_len = tf.placeholder(
            tf.int32, [self.document_size, None, None])

        self.question_word = tf.placeholder(tf.int32, [None, None])
        self.question_len = tf.placeholder(tf.int32, [None])

        self.question_char = tf.placeholder(tf.int32, [None, None, None])
        self.question_word_len = tf.placeholder(tf.int32, [None, None])

        self.answer_start = tf.placeholder(tf.int32,
                                           [self.document_size, None])
        self.answer_end = tf.placeholder(tf.int32, [self.document_size, None])
        self.abstractive_answer_mask = tf.placeholder(
            tf.int32, [self.document_size, None, self.abstractive_answer_num])
        self.training = tf.placeholder(tf.bool, [])

        self.question_tokens = tf.placeholder(tf.string, [None, None])
        self.context_tokens = tf.placeholder(tf.string,
                                             [self.document_size, None, None])

        # 1. Word encoding
        word_embedding = Embedding(
            pretrained_embedding=self.pretrained_word_embedding,
            embedding_shape=(len(self.vocab.get_word_vocab()) + 1,
                             self.word_embedding_size),
            trainable=self.word_embedding_trainable)
        char_embedding = Embedding(
            embedding_shape=(len(self.vocab.get_char_vocab()) + 1,
                             self.char_embedding_size),
            trainable=True,
            init_scale=0.05)

        # 1.1 Embedding
        dropout = Dropout(self.keep_prob)
        context_word_repr = word_embedding(self.context_word)
        context_char_repr = char_embedding(self.context_char)
        question_word_repr = word_embedding(self.question_word)
        question_char_repr = char_embedding(self.question_char)
        if self.use_elmo:
            elmo_emb = ElmoEmbedding(local_path=self.elmo_local_path)
            context_elmo_repr = elmo_emb(self.context_tokens, self.context_len)
            context_elmo_repr = dropout(context_elmo_repr, self.training)
            question_elmo_repr = elmo_emb(self.question_tokens,
                                          self.question_len)
            question_elmo_repr = dropout(question_elmo_repr, self.training)

        # 1.2 Char convolution
        conv1d = Conv1DAndMaxPooling(self.char_conv_filters,
                                     self.char_conv_kernel_size)
        if self.max_pooling_mask:
            question_char_repr = conv1d(
                dropout(question_char_repr, self.training),
                self.question_word_len)
            context_char_repr = conv1d(
                dropout(context_char_repr, self.training),
                self.context_word_len)
        else:
            question_char_repr = conv1d(
                dropout(question_char_repr, self.training))
            context_char_repr = conv1d(
                dropout(context_char_repr, self.training))

        # 2. Phrase encoding
        context_embs = [context_word_repr, context_char_repr]
        question_embs = [question_word_repr, question_char_repr]
        if self.use_elmo:
            context_embs.append(context_elmo_repr)
            question_embs.append(question_elmo_repr)

        context_repr = tf.concat(context_embs, axis=-1)
        question_repr = tf.concat(question_embs, axis=-1)

        variational_dropout = VariationalDropout(self.keep_prob)
        emb_enc_gru = CudnnBiGRU(self.rnn_hidden_size)

        context_repr = variational_dropout(context_repr, self.training)
        context_repr, _ = emb_enc_gru(context_repr, self.context_len)
        context_repr = variational_dropout(context_repr, self.training)

        question_repr = variational_dropout(question_repr, self.training)
        question_repr, _ = emb_enc_gru(question_repr, self.question_len)
        question_repr = variational_dropout(question_repr, self.training)

        # 3. Bi-Attention
        bi_attention = BiAttention(
            TriLinear(bias=True, name="bi_attention_tri_linear"))
        c2q, q2c = bi_attention(context_repr, question_repr, self.context_len,
                                self.question_len)
        context_repr = tf.concat(
            [context_repr, c2q, context_repr * c2q, context_repr * q2c],
            axis=-1)

        # 4. Self-Attention layer
        dense1 = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                       use_bias=True,
                                       activation=tf.nn.relu)
        gru = CudnnBiGRU(self.rnn_hidden_size)
        dense2 = tf.keras.layers.Dense(self.rnn_hidden_size * 2,
                                       use_bias=True,
                                       activation=tf.nn.relu)
        self_attention = SelfAttention(
            TriLinear(bias=True, name="self_attention_tri_linear"))

        inputs = dense1(context_repr)
        outputs = variational_dropout(inputs, self.training)
        outputs, _ = gru(outputs, self.context_len)
        outputs = variational_dropout(outputs, self.training)
        c2c = self_attention(outputs, self.context_len)
        outputs = tf.concat([c2c, outputs, c2c * outputs],
                            axis=len(c2c.shape) - 1)
        outputs = dense2(outputs)
        context_repr = inputs + outputs
        context_repr = variational_dropout(context_repr, self.training)

        # 5. Modeling layer
        sum_max_encoding = SumMaxEncoder()
        context_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        context_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru1 = CudnnBiGRU(self.rnn_hidden_size)
        question_modeling_gru2 = CudnnBiGRU(self.rnn_hidden_size)
        self.max_context_len = tf.reduce_max(self.context_len)
        self.max_question_len = tf.reduce_max(self.question_len)

        modeled_context1, _ = context_modeling_gru1(context_repr,
                                                    self.context_len)
        modeled_context2, _ = context_modeling_gru2(
            tf.concat([context_repr, modeled_context1], axis=2),
            self.context_len)
        encoded_context = sum_max_encoding(modeled_context1, self.context_len,
                                           self.max_context_len)
        modeled_question1, _ = question_modeling_gru1(question_repr,
                                                      self.question_len)
        modeled_question2, _ = question_modeling_gru2(
            tf.concat([question_repr, modeled_question1], axis=2),
            self.question_len)
        encoded_question = sum_max_encoding(modeled_question2,
                                            self.question_len,
                                            self.max_question_len)

        # 6. Predictions
        start_dense = tf.keras.layers.Dense(1, activation=None)
        start_logits = tf.squeeze(start_dense(modeled_context1),
                                  squeeze_dims=[2])
        start_logits = mask_logits(start_logits, self.context_len)

        end_dense = tf.keras.layers.Dense(1, activation=None)
        end_logits = tf.squeeze(end_dense(modeled_context2), squeeze_dims=[2])
        end_logits = mask_logits(end_logits, self.context_len)

        abstractive_answer_logits = None
        if self.abstractive_answer_num != 0:
            abstractive_answer_logits = []
            for i in range(self.abstractive_answer_num):
                tri_linear = TriLinear(name="cls" + str(i))
                abstractive_answer_logits.append(
                    tf.squeeze(tri_linear(encoded_context, encoded_question),
                               squeeze_dims=[2]))
            abstractive_answer_logits = tf.concat(abstractive_answer_logits,
                                                  axis=-1)

        # 7. Loss and input/output dict
        seq_length = tf.shape(start_logits)[1]
        start_mask = tf.one_hot(self.answer_start,
                                depth=seq_length,
                                dtype=tf.float32)
        end_mask = tf.one_hot(self.answer_end,
                              depth=seq_length,
                              dtype=tf.float32)
        if self.abstractive_answer_num != 0:
            abstractive_answer_mask = tf.cast(self.abstractive_answer_mask,
                                              dtype=tf.float32)
            extractive_mask = 1. - tf.reduce_max(
                abstractive_answer_mask, axis=-1, keepdims=True)
            start_mask = extractive_mask * start_mask
            end_mask = extractive_mask * end_mask

            concated_start_masks = tf.concat(
                [start_mask, abstractive_answer_mask], axis=1)
            concated_end_masks = tf.concat([end_mask, abstractive_answer_mask],
                                           axis=1)

            concated_start_logits = tf.concat(
                [start_logits, abstractive_answer_logits], axis=1)
            concated_end_logits = tf.concat(
                [end_logits, abstractive_answer_logits], axis=1)
        else:
            concated_start_masks = start_mask
            concated_end_masks = end_mask

            concated_start_logits = start_logits
            concated_end_logits = end_logits

        start_log_norm = tf.reduce_logsumexp(concated_start_logits, axis=1)
        start_log_score = tf.reduce_logsumexp(
            concated_start_logits + VERY_NEGATIVE_NUMBER *
            (1 - tf.cast(concated_start_masks, tf.float32)),
            axis=1)
        self.start_loss = tf.reduce_mean(-(start_log_score - start_log_norm))

        end_log_norm = tf.reduce_logsumexp(concated_end_logits, axis=1)
        end_log_score = tf.reduce_logsumexp(
            concated_end_logits + VERY_NEGATIVE_NUMBER *
            (1 - tf.cast(concated_end_masks, tf.float32)),
            axis=1)
        self.end_loss = tf.reduce_mean(-(end_log_score - end_log_norm))

        self.loss = self.start_loss + self.end_loss
        global_step = tf.train.get_or_create_global_step()

        self.input_placeholder_dict = OrderedDict({
            "context_word": self.context_word,
            "question_word": self.question_word,
            "context_char": self.context_char,
            "question_char": self.question_char,
            "context_len": self.context_len,
            "question_len": self.question_len,
            "answer_start": self.answer_start,
            "answer_end": self.answer_end,
            "training": self.training
        })
        if self.max_pooling_mask:
            self.input_placeholder_dict[
                'context_word_len'] = self.context_word_len
            self.input_placeholder_dict[
                'question_word_len'] = self.question_word_len
        if self.use_elmo:
            self.input_placeholder_dict['context_tokens'] = self.context_tokens
            self.input_placeholder_dict[
                'question_tokens'] = self.question_tokens
        if self.abstractive_answer_num != 0:
            self.input_placeholder_dict[
                "abstractive_answer_mask"] = self.abstractive_answer_mask

        self.output_variable_dict = OrderedDict({
            "start_logits": start_logits,
            "end_logits": end_logits,
        })
        if self.abstractive_answer_num != 0:
            self.output_variable_dict[
                "abstractive_answer_logits"] = abstractive_answer_logits

        # 8. Metrics and summary
        with tf.variable_scope("train_metrics"):
            self.train_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.train_update_metrics = tf.group(
            *[op for _, op in self.train_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="train_metrics")
        self.train_metric_init_op = tf.variables_initializer(metric_variables)

        with tf.variable_scope("eval_metrics"):
            self.eval_metrics = {'loss': tf.metrics.mean(self.loss)}

        self.eval_update_metrics = tf.group(
            *[op for _, op in self.eval_metrics.values()])
        metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="eval_metrics")
        self.eval_metric_init_op = tf.variables_initializer(metric_variables)

        tf.summary.scalar('loss', self.loss)
        self.summary_op = tf.summary.merge_all()
 def logsumexp(x, axis=None):
     """Return `log(sum(exp(x), axis=axis))` with improved numerical stability.

     Args:
         x: Tensor (or tensor-like value) to reduce.
         axis: Axis or axes to reduce over, forwarded unchanged to
             `tf.reduce_logsumexp`. The default `None` reduces over every
             axis.

     Returns:
         The tensor produced by `tf.reduce_logsumexp`.
     """
     # Forward `axis` as given. Wrapping it in a list turned the default
     # `None` into `[None]`, which is not a valid axis specification, and
     # nested any list/tuple of axes one level too deep.
     return tf.reduce_logsumexp(x, axis=axis)
Exemplo n.º 38
0
def iwae(cell,
         inputs,
         seq_lengths,
         num_samples=1,
         parallel_iterations=30,
         swap_memory=True):
    """Builds the IWAE lower bound on the log marginal probability.

    Given a stochastic latent variable model (`cell`) and a batch of
    observation sequences, this constructs a stochastic lower bound on the
    log marginal probability of the observations by averaging several
    importance weights. See "Importance Weighted Autoencoders" by Burda
    et al., https://arxiv.org/abs/1509.00519.

    With num_samples = 1 the bound reduces to the evidence lower bound
    (ELBO).

    Args:
        cell: A callable implementing one timestep of the model. It is called
            as `cell(cur_inputs, rnn_state, cur_mask)` and must return
            `(log_q_z, log_p_z, log_p_x_given_z, kl, new_state)`; it must
            also provide `zero_state(batch, dtype)`. See models/vrnn.py for
            an example.
        inputs: The inputs to the model. A potentially nested list or tuple
            of Tensors each of shape [max_seq_len, batch_size, ...]. The
            Tensors must have rank at least two and matching shapes in the
            first two dimensions (time and batch). At every timestep `cell`
            receives one time slice of these Tensors.
        seq_lengths: A [batch_size] Tensor of ints with the length of each
            sequence in the batch (sequences may be padded to a common
            length).
        num_samples: The number of samples to use.
        parallel_iterations: The number of parallel iterations to use for
            the internal while loop.
        swap_memory: Whether GPU-CPU memory swapping should be enabled for
            the internal while loop.

    Returns:
        log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate
            of the log marginal probability of the observations.
        kl: A Tensor of shape [batch_size] containing the kl divergence
            from q(z|x) to p(z), averaged over samples.
        log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
            containing the log weights at each timestep. Not valid for
            timesteps past the end of a sequence.
        log_ess: A Tensor of shape [max_seq_len, batch_size] containing the
            log effective sample size at each timestep. Not valid for
            timesteps past the end of a sequence.
    """
    batch_size = tf.shape(seq_lengths)[0]
    max_seq_len = tf.reduce_max(seq_lengths)

    # Build the padding mask batch-major, then flip it to time-major so it
    # lines up with `inputs`.
    batch_major_mask = tf.sequence_mask(seq_lengths,
                                        maxlen=max_seq_len,
                                        dtype=tf.float32)
    seq_mask = tf.transpose(batch_major_mask, perm=[1, 0])

    # Replicate inputs and mask along the batch axis, once per sample.
    if num_samples > 1:
        inputs, seq_mask = nested.tile_tensors([inputs, seq_mask],
                                               [1, num_samples])
    inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask],
                                                max_seq_len)

    # Initial loop variables: step counter, model state, per-step output
    # TensorArrays and the running (log-weight, kl) accumulators.
    first_step = tf.constant(0, tf.int32)
    start_state = cell.zero_state(batch_size * num_samples, tf.float32)
    output_tas = [
        tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % ta_name)
        for ta_name in ['log_weights', 'log_ess']
    ]
    start_accs = (
        tf.zeros([num_samples, batch_size], dtype=tf.float32),
        tf.zeros([num_samples * batch_size], dtype=tf.float32),
    )

    def _keep_going(t, *unused_args):
        """Loop condition: run until every timestep has been processed."""
        return t < max_seq_len

    def _advance(t, rnn_state, tas, accs):
        """Runs a single timestep of the IWAE computation."""
        running_log_weights, running_kl = accs
        step_inputs, step_mask = nested.read_tas([inputs_ta, mask_ta], t)

        # Advance the model by one timestep.
        log_q_z, log_p_z, log_p_x_given_z, kl, next_state = cell(
            step_inputs, rnn_state, step_mask)

        # Masked entries (past the end of a sequence) contribute nothing to
        # either accumulator.
        running_kl = running_kl + kl * step_mask
        masked_increment = (log_p_x_given_z + log_p_z - log_q_z) * step_mask
        running_log_weights = running_log_weights + tf.reshape(
            masked_increment, [num_samples, batch_size])

        # log ESS = 2 * log(sum_i w_i) - log(sum_i w_i^2), in log space.
        log_sum_sq = 2 * tf.reduce_logsumexp(running_log_weights, axis=0)
        log_sq_sum = tf.reduce_logsumexp(2 * running_log_weights, axis=0)
        step_log_ess = log_sum_sq - log_sq_sum

        # Record this step's outputs and hand the accumulators onward.
        written_tas = [
            ta.write(t, value)
            for ta, value in zip(tas, [running_log_weights, step_log_ess])
        ]
        return (t + 1, next_state, written_tas,
                (running_log_weights, running_kl))

    _, _, filled_tas, final_accs = tf.while_loop(
        _keep_going,
        _advance,
        loop_vars=(first_step, start_state, output_tas, start_accs),
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    log_weights, log_ess = [ta.stack() for ta in filled_tas]
    final_log_weights, kl = final_accs

    # Average the importance weights in log space:
    # log(mean_i w_i) = logsumexp_i(log w_i) - log(num_samples).
    log_p_hat = (tf.reduce_logsumexp(final_log_weights, axis=0) -
                 tf.log(tf.to_float(num_samples)))
    kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
    # Swap the last two axes so the sample axis comes last.
    log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
    return log_p_hat, kl, log_weights, log_ess
Exemplo n.º 39
0
def _log_sum_sq(x, axis=None):
    """Returns log(sum(x**2)), reduced over `axis`, evaluated in log space."""
    # x**2 == exp(2 * log|x|), so a log-sum-exp over the doubled
    # log-magnitudes yields the log of the sum of squares.
    log_of_squares = 2. * tf.log(tf.abs(x))
    return tf.reduce_logsumexp(log_of_squares, axis)
Exemplo n.º 40
0
def fivo(cell,
         inputs,
         seq_lengths,
         num_samples=1,
         resampling_criterion=ess_criterion,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
    """Computes the FIVO lower bound on the log marginal probability.

    This method accepts a stochastic latent variable model and some
    observations and computes a stochastic lower bound on the log marginal
    probability of the observations. The lower bound is defined by a particle
    filter's unbiased estimate of the marginal probability of the
    observations. For more details see "Filtering Variational Objectives" by
    Maddison et al. https://arxiv.org/abs/1705.09279.

    When the particle filters never resample, this bound becomes IWAE.

    NOTE(review): `resampling_criterion` is accepted but never consulted by
    the body below; a hard-coded criterion on the summed log weights is used
    instead (see `while_step`). Confirm whether that is intentional.

    Args:
        cell: A callable that implements one timestep of the model. It is
            called as `cell(cur_inputs, rnn_state, cur_mask)` and must return
            `(log_q_z, log_p_z, log_p_x_given_z, kl, new_state)`; it must
            also provide `zero_state(batch, dtype)`. See models/vrnn.py for
            an example.
        inputs: The inputs to the model. A potentially nested list or tuple
            of Tensors each of shape [max_seq_len, batch_size, ...]. The
            Tensors must have a rank at least two and have matching shapes in
            the first two dimensions, which represent time and the batch
            respectively. At each timestep 'cell' will be called with a slice
            of the Tensors in inputs.
        seq_lengths: A [batch_size] Tensor of ints encoding the length of
            each sequence in the batch (sequences can be padded to a common
            length).
        num_samples: The number of particles to use in each particle filter.
        resampling_criterion: Intended to be the resampling criterion for the
            particle filter (a callable taking the number of samples, the
            effective sample size and the current timestep, and returning a
            boolean Tensor of shape [batch_size]). Currently unused: this
            implementation ignores the argument.
        parallel_iterations: The number of parallel iterations to use for the
            internal while loop. Note that values greater than 1 can
            introduce non-determinism even when random_seed is provided.
        swap_memory: Whether GPU-CPU memory swapping should be enabled for
            the internal while loop.
        random_seed: The random seed to pass to the resampling operations in
            the particle filter. Mainly useful for testing.

    Returns:
        log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate
            of the log marginal probability of the observations.
        kl: A Tensor of shape [batch_size] containing the sum over time of
            the kl divergence from q_t(z_t|x) to p_t(z_t), averaged over
            particles. Note that this includes kl terms from trajectories
            that are culled during resampling steps.
        log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
            containing the log weights at each timestep of the particle
            filter. The value recorded for a timestep is the accumulated
            weight before any reset; after a resampling step the running
            weights restart from 0. Will not be valid for timesteps past the
            end of a sequence.
        log_ess: A Tensor of shape [max_seq_len, batch_size] containing the
            log effective sample size of each particle filter at each
            timestep. Will not be valid for timesteps past the end of a
            sequence.
        resampled: A Tensor of shape [max_seq_len, batch_size] indicating
            when the particle filters resampled. Will be 1.0 on timesteps
            when resampling occurred and 0.0 on timesteps when it did not.
    """
    # batch_size represents the number of particle filters running in parallel.
    batch_size = tf.shape(seq_lengths)[0]
    max_seq_len = tf.reduce_max(seq_lengths)
    # Time-major padding mask: 1.0 inside each sequence, 0.0 past its end.
    seq_mask = tf.transpose(tf.sequence_mask(seq_lengths,
                                             maxlen=max_seq_len,
                                             dtype=tf.float32),
                            perm=[1, 0])

    # Each sequence in the batch will be the input data for a different
    # particle filter. The batch will be laid out as:
    #   particle 1 of particle filter 1
    #   particle 1 of particle filter 2
    #   ...
    #   particle 1 of particle filter batch_size
    #   particle 2 of particle filter 1
    #   ...
    #   particle num_samples of particle filter batch_size
    if num_samples > 1:
        inputs, seq_mask = nested.tile_tensors([inputs, seq_mask],
                                               [1, num_samples])
    inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask],
                                                max_seq_len)

    # Initial loop variables: step counter, model state, one TensorArray per
    # recorded quantity, and the running accumulators.
    t0 = tf.constant(0, tf.int32)
    init_states = cell.zero_state(batch_size * num_samples, tf.float32)
    ta_names = ['log_weights', 'log_ess', 'resampled']
    tas = [
        tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
        for n in ta_names
    ]
    log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
    log_p_hat_acc = tf.zeros([batch_size], dtype=tf.float32)
    kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
    accs = (log_weights_acc, log_p_hat_acc, kl_acc)

    def while_predicate(t, *unused_args):
        """Loop condition: continue until all timesteps are processed."""
        return t < max_seq_len

    def while_step(t, rnn_state, tas, accs):
        """Implements one timestep of FIVO computation."""
        log_weights_acc, log_p_hat_acc, kl_acc = accs
        cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
        # Run the cell for one step.
        log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
            cur_inputs,
            rnn_state,
            cur_mask,
        )
        # Compute the incremental weight and use it to update the current
        # accumulated weight. Masked (padded) entries contribute zero.
        kl_acc += kl * cur_mask
        log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
        log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
        log_weights_acc += log_alpha

        # Calculate the log effective sample size from the unnormalized
        # weights: log ESS = 2 * log(sum_i w_i) - log(sum_i w_i^2).
        ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
        ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
        log_ess = ess_num - ess_denom

        # Calculate the ancestor indices via resampling. Because we maintain the
        # log unnormalized weights, we pass the weights in as logits, allowing
        # the distribution object to apply a softmax and normalize them.
        # stop_gradient keeps gradients from flowing through the sampled
        # indices.
        resampling_dist = tf.contrib.distributions.Categorical(
            logits=tf.transpose(log_weights_acc, perm=[1, 0]))
        ancestor_inds = tf.stop_gradient(
            resampling_dist.sample(sample_shape=num_samples, seed=random_seed))

        # Because the batch is flattened and laid out as discussed
        # above, we must modify ancestor_inds to index the proper samples.
        # The particles in the ith filter are distributed every batch_size rows
        # in the batch, and offset i rows from the top. So, to correct the indices
        # we multiply by the batch_size and add the proper offset. Crucially,
        # when ancestor_inds is flattened the layout of the batch is maintained.
        offset = tf.expand_dims(tf.range(batch_size), 0)
        ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset, [-1])
        # Identity permutation, used for filters that do not resample.
        noresample_inds = tf.range(num_samples * batch_size)

        # Decide whether or not we should resample; don't resample if we are past
        # the end of a sequence.
        # A filter resamples when the negated sum of its particles' log
        # weights is at least num_samples / 2.
        should_resample = -tf.reduce_sum(
            log_weights_acc, axis=0) >= tf.reduce_sum(num_samples / 2.0)
        # NOTE(review): this hard-coded rule replaces the call
        # `resampling_criterion(num_samples, log_ess, t)`, so the
        # `resampling_criterion` argument has no effect — confirm intent.

        should_resample = tf.logical_and(should_resample,
                                         cur_mask[:batch_size] > 0.)
        float_should_resample = tf.to_float(should_resample)
        # Tile the per-filter decision across particles to match the
        # flattened [num_samples * batch_size] index layout.
        ancestor_inds = tf.where(tf.tile(should_resample, [num_samples]),
                                 ancestor_inds, noresample_inds)
        new_state = nested.gather_tensors(new_state, ancestor_inds)

        # Update the TensorArrays before we reset the weights so that we capture
        # the incremental weights and not zeros.
        ta_updates = [log_weights_acc, log_ess, float_should_resample]
        new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]

        # For the particle filters that resampled, update log_p_hat and
        # reset weights to zero.
        log_p_hat_update = tf.reduce_logsumexp(
            log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
        log_p_hat_acc += log_p_hat_update * float_should_resample
        log_weights_acc *= (
            1. -
            tf.tile(float_should_resample[tf.newaxis, :], [num_samples, 1]))
        new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)

        return t + 1, new_state, new_tas, new_accs

    _, _, tas, accs = tf.while_loop(while_predicate,
                                    while_step,
                                    loop_vars=(t0, init_states, tas, accs),
                                    parallel_iterations=parallel_iterations,
                                    swap_memory=swap_memory)

    log_weights, log_ess, resampled = [x.stack() for x in tas]
    final_log_weights, log_p_hat, kl = accs

    # Add in the final weight update to log_p_hat (the objective): weights
    # accumulated since the last resampling step have not been folded in yet.
    log_p_hat += (tf.reduce_logsumexp(final_log_weights, axis=0) -
                  tf.log(tf.to_float(num_samples)))

    # Average the accumulated kl over particles, and move the particle axis
    # of the recorded log weights to the end.
    kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
    log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
    return log_p_hat, kl, log_weights, log_ess, resampled
Exemplo n.º 41
0
    def call(self, X):
        """Compute the log-likelihood of datapoint(s) with M dimensions.

        The likelihood is evaluated as a chain of matrix products of the
        form z . (A (x) B)^T . (C (x) D)^T ..., one factor per data
        dimension. For fewer than 7 dimensions the chain is multiplied in
        the probability domain; otherwise it is accumulated entirely in the
        log domain with log-sum-exp reductions.

        Input
            X   :    Datapoint(s) in M-dimensions.

        Return
            log likelihoods of data

        Raises
            Exception if the second dimension of X differs from the model's M.
        """
        # The data must have as many columns as the model has dimensions.
        if X.shape[1] != self.M:
            raise Exception('Dataset has wrong dimensions')
        X = tf.cast(tf.reshape(X, (-1, self.M)), tf.float32)

        # Turn the stored logits into normalized weights.
        wk0 = tf.nn.softmax(self.wk0_logits, axis=1)  # axis 1 as it is (1, K0)
        W = [tf.nn.softmax(self.W_logits[i], axis=0) for i in range(self.M)]

        # Fetch the (adjusted) distribution parameters.
        params = self.fix_params()

        def _log_factor(dim):
            """Log-weight plus log-density of data column `dim`."""
            column = X[:, tf.newaxis, tf.newaxis, dim]  # broadcast to (n, km, kn)
            return tfm.log(W[dim]) + self.dists[dim](*params[dim]).log_prob(
                column)

        if self.M < 7:
            # Probability-domain product, starting from the identity matrix.
            chain = tf.eye(wk0.shape[1])
            for dim in range(self.M):
                # The factor is formed in the log domain and exponentiated
                # afterwards; the batch axis stays first while the matrix
                # axes are transposed.
                factor = tfm.exp(_log_factor(dim))
                chain = chain @ tf.transpose(factor, perm=[0, 2, 1])

            # (n, 1, k_last) -> (n, k_last), then sum over k_last; the outer
            # squeeze gives (n, ) if n > 1 or () if n == 1.
            mixed = tf.squeeze(wk0 @ chain, axis=1)
            likelihoods = tf.squeeze(tf.reduce_sum(mixed, axis=1))
            # A small constant keeps the log away from nan/-inf.
            return tfm.log(likelihoods + np.finfo(np.float64).eps)

        # Log-domain product: each matrix multiplication becomes a
        # log-sum-exp over the shared inner axis.
        log_chain = tf.transpose(_log_factor(0), perm=[0, 2, 1])
        for dim in range(1, self.M):
            log_step = tf.transpose(_log_factor(dim), perm=[0, 2, 1])
            log_chain = tf.reduce_logsumexp(
                log_chain[:, :, :, tf.newaxis] +
                log_step[:, tf.newaxis, :, :],
                axis=2)

        # Apply the initial weights wk0, then reduce the remaining axis.
        weighted = tfm.reduce_logsumexp(
            tfm.log(wk0[:, :, tf.newaxis]) + log_chain[:, tf.newaxis, :, :],
            axis=2)
        return tf.reduce_logsumexp(tf.squeeze(weighted, axis=1), axis=1)
Exemplo n.º 42
0
    def __init__(
            self,
            *,
            env_spec,  # No good default, but we do need to have it
            expert_trajs=None,
            reward_arch=cnn_net,
            reward_arch_args=None,
            value_fn_arch=cnn_net,
            score_discrim=False,
            discount=1.0,
            state_only=False,
            max_itrs=100,
            fusion=False,
            name='airl',
            drop_framestack=False,
            only_show_scores=False,
            rescore_expert_trajs=True,
            encoder_loc=None):
        """Build the AIRL discriminator graph and record the constructor args.

        All arguments are keyword-only.

        Args:
            env_spec: Environment spec. Its `observation_space` (`flat_dim`,
                `shape`) and `action_space` (`flat_dim`) are read here; the
                action space is asserted to be a `Box`.
            expert_trajs: Expert trajectories, stored on `self.expert_trajs`.
            reward_arch: Callable that builds the reward network. It is
                called with the observation placeholder and `dout=1`, plus
                `actions=` when `state_only` is false.
            reward_arch_args: Extra keyword arguments forwarded to
                `reward_arch`. `None` (the default) means no extra arguments.
            value_fn_arch: Callable that builds the value-function network;
                must not be None.
            score_discrim: Stored on `self.score_discrim`.
            discount: Discount factor gamma used in the shaped reward.
            state_only: If true the reward network sees observations only,
                and the encoder (if any) uses `base_vector` rather than
                `encode`.
            max_itrs: Stored on `self.max_itrs`.
            fusion: If truthy, a `RamFusionDistr(100, subsample_ratio=0.5)`
                is created; otherwise `self.fusion` is None.
            name: Name of the enclosing TensorFlow variable scope.
            drop_framestack: If true, the observation shape must have three
                dimensions and its last dimension is replaced by 1.
            only_show_scores: Stored on `self.only_show_scores`.
            rescore_expert_trajs: Stored on `self.rescore_expert_trajs`.
            encoder_loc: If truthy, location passed to
                `encoding.VariationalAutoEncoder.load`; the loaded encoder
                then defines the observation shape.
        """
        # NOTE(review): this starts the MRO lookup after AIRL, so AIRL's own
        # __init__ is skipped — presumably deliberate; confirm.
        super(AIRL, self).__init__()

        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared by every instance and by `self.init_args`.
        if reward_arch_args is None:
            reward_arch_args = {}

        # Write down everything that we're going to need in order to restore
        # this. All of these arguments are serializable, so it's pretty easy
        self.init_args = dict(model=AtariAIRL,
                              env_spec=env_spec,
                              expert_trajs=expert_trajs,
                              reward_arch=reward_arch,
                              reward_arch_args=reward_arch_args,
                              value_fn_arch=value_fn_arch,
                              score_discrim=score_discrim,
                              discount=discount,
                              state_only=state_only,
                              max_itrs=max_itrs,
                              fusion=fusion,
                              name=name,
                              rescore_expert_trajs=rescore_expert_trajs,
                              drop_framestack=drop_framestack,
                              only_show_scores=only_show_scores,
                              encoder_loc=encoder_loc)

        # Optional encoder: pick the encoding function that matches the
        # state-only / state-action setting.
        self.encoder = None if not encoder_loc else encoding.VariationalAutoEncoder.load(
            encoder_loc)
        self.encode_fn = None
        if self.encoder:
            if state_only:
                self.encode_fn = self.encoder.base_vector
            else:
                self.encode_fn = self.encoder.encode

        if fusion:
            self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
        else:
            self.fusion = None

        # Observation dimensions come from the encoder when one is loaded,
        # otherwise from the environment.
        if self.encoder:
            self.dO = self.encoder.encoding_shape
            self.dOshape = self.encoder.encoding_shape
        else:
            self.dO = env_spec.observation_space.flat_dim
            self.dOshape = env_spec.observation_space.shape

        if drop_framestack:
            assert len(self.dOshape) == 3
            self.dOshape = (*self.dOshape[:-1], 1)

        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env_spec.action_space, Box)
        self.score_discrim = score_discrim
        self.gamma = discount
        assert value_fn_arch is not None
        self.expert_trajs = expert_trajs
        self.state_only = state_only
        self.max_itrs = max_itrs
        self.drop_framestack = drop_framestack
        self.only_show_scores = only_show_scores

        self.expert_cache = None
        self.rescore_expert_trajs = rescore_expert_trajs
        # build energy model
        with tf.variable_scope(name) as _vs:
            # Should be batch_size x T x dO/dU
            obs_dtype = tf.int8 if reward_arch == cnn_net else tf.float32
            self.obs_t = tf.placeholder(obs_dtype,
                                        list((None, ) + self.dOshape),
                                        name='obs')
            self.nobs_t = tf.placeholder(obs_dtype,
                                         list((None, ) + self.dOshape),
                                         name='nobs')
            self.act_t = tf.placeholder(tf.float32, [None, self.dU],
                                        name='act')
            self.nact_t = tf.placeholder(tf.float32, [None, self.dU],
                                         name='nact')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lprobs = tf.placeholder(tf.float32, [None, 1],
                                         name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            with tf.variable_scope('discrim'):
                rew_input = self.obs_t
                with tf.variable_scope('reward'):
                    if self.state_only:
                        self.reward = reward_arch(rew_input,
                                                  dout=1,
                                                  **reward_arch_args)
                    else:
                        print("Not state only", self.act_t)
                        self.reward = reward_arch(rew_input,
                                                  actions=self.act_t,
                                                  dout=1,
                                                  **reward_arch_args)
                # value function shaping: the same 'vfn' variables are used
                # for the next observation and (via reuse) the current one.
                with tf.variable_scope('vfn'):
                    fitted_value_fn_n = value_fn_arch(self.nobs_t, dout=1)
                with tf.variable_scope('vfn', reuse=True):
                    self.value_fn = fitted_value_fn = value_fn_arch(self.obs_t,
                                                                    dout=1)

                # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
                self.qfn = self.reward + self.gamma * fitted_value_fn_n
                log_p_tau = self.reward + self.gamma * fitted_value_fn_n - fitted_value_fn

            log_q_tau = self.lprobs

            # Discriminator output D = p / (p + q), computed in log space.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.discrim_output = tf.exp(log_p_tau - log_pq)
            self.accuracy, self.update_accuracy = tf.metrics.accuracy(
                labels=self.labels, predictions=self.discrim_output > 0.5)
            # Binary cross-entropy between the labels and D.
            self.loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                        (1 - self.labels) *
                                        (log_q_tau - log_pq))
            self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(
                self.loss)
            self._make_param_ops(_vs)

            # Gradients of the reward w.r.t. the observation and action
            # placeholders.
            self.grad_reward = tf.gradients(self.reward,
                                            [self.obs_t, self.act_t])

            self.modify_obs = self.get_ablation_modifiers()

            self.score_mean = 0
            self.score_std = 1
Exemplo n.º 43
0
def _log_sum_sq(x, axis=None):
  """Returns `log(sum(x**2))`, reduced over `axis`, evaluated in log space."""
  # log(x**2) == 2 * log|x|, so a log-sum-exp of that quantity is the log of
  # the sum of squares.
  log_squares = 2. * tf.math.log(tf.abs(x))
  return tf.reduce_logsumexp(input_tensor=log_squares, axis=axis)
Exemplo n.º 44
0
        def loop_step(batch_index, ts, stop_decoder, states, alphas, cand_seqs,
                      cand_scores, completed_scores, completed_scores_scaled,
                      completed_seqs, completed_lens):
            """Run one beam-search decoding step and update all loop variables.

            The step scores every next symbol for each candidate, merges the
            "stop now" continuations into the pool of completed sequences,
            and keeps the `beam_size` best unfinished continuations as the
            candidates for the next step.

            Args:
              batch_index: forwarded unchanged to `self.step`.
              ts (int): current time step (incremented by one on return).
              stop_decoder (bool): stop flag; the incoming value is not read,
                it is recomputed in step 2.4.
              states: tuple of three tensors; on return each one is gathered
                by parent beam and has static shape [beam_size, ...].
              alphas (float): gathered along axis 1 and reshaped to
                [num_encoder_states, beam_size] on return.
              cand_seqs: [beam_size, MAX_STEPS + 2], unfinished candidates.
              cand_scores: [beam_size], score of each candidate.
              completed_scores: unscaled scores of the completed sequences.
              completed_scores_scaled: completed scores divided by `ts + 1`
                at the step where they were completed.
              completed_seqs: [beam_size, MAX_STEPS + 2], completed sequences.
              completed_lens: length recorded for each completed sequence.

            Returns:
              Tuple `(ts, stop_decoder, states, alphas, cand_seqs,
              cand_scores, completed_scores, completed_scores_scaled,
              completed_seqs, completed_lens)` holding the updated values of
              every argument except `batch_index`.

            """
            # 1. get score from one step decoder
            # logits = tf.one_hot(ts, depth=num_symbols, off_value=0.0, dtype=tf.float32)
            if DEBUG: ts = tf.Print(ts, [ts], message='ts: ')
            # Column `ts` of every candidate is the decoder input.
            ys = cand_seqs[:, ts]
            if DEBUG: ys = tf.Print(ys, [ys], message='Y(t-1): ')
            logits, states, alphas = self.step(ys, states, alphas, batch_index)
            if DEBUG: logits = tf.Print(logits, [logits], message='logits: ')
            # Log-softmax: subtract the per-row log partition function.
            Z = tf.reduce_logsumexp(logits, 1, keep_dims=True)
            if DEBUG: Z = tf.Print(Z, [Z], message='Z: ')
            logprobs = tf.subtract(logits, Z)  # [beam_size, num_symbols]
            # Accumulate: score of a continuation = parent score + log-prob.
            new_scores = tf.add(logprobs,
                                tf.expand_dims(cand_scores,
                                               1))  # [beam_size, num_symbols]
            if DEBUG:
                new_scores = tf.Print(new_scores, [new_scores],
                                      message='new_scores: ')
            # The last column is split off as the "sequence ends here" score
            # (presumably the stop symbol is the last symbol id -- TODO
            # confirm); the remaining columns continue the sequence.
            num_unstop_symbols = tf.shape(new_scores)[1] - 1
            new_uncompleted_scores, new_completed_scores = tf.split(
                new_scores, [num_unstop_symbols, 1], 1)
            if DEBUG:
                new_uncompleted_scores = tf.Print(
                    new_uncompleted_scores, [new_uncompleted_scores],
                    message='new_uncompleted_scores: ')

            # 2. Update completed seqs  --------------------------------------
            # 2.1 update scores
            new_completed_scores = tf.squeeze(new_completed_scores,
                                              -1)  # [beam_size]
            all_completed_scores = tf.concat(
                [completed_scores, new_completed_scores], 0)  # [2*beam_size]

            # 2.2 choose top K from scaled_scores
            # Newly completed scores are divided by (ts + 1) before being
            # ranked against the previously completed ones.
            new_completed_scores_scaled = tf.div(new_completed_scores,
                                                 tf.to_float(ts + 1))
            all_scores_scaled = tf.concat(
                [completed_scores_scaled, new_completed_scores_scaled], 0)
            completed_scores_scaled, indices = tf.nn.top_k(all_scores_scaled,
                                                           k=beam_size,
                                                           sorted=False)
            if DEBUG:
                indices = tf.Print(indices, [indices],
                                   message='top K completed indices: ')

            # 2.3 update len
            # Every sequence completed at this step gets length ts + 1.
            new_completed_lens = tf.fill([beam_size], tf.add(ts,
                                                             1))  # [beam_size]
            all_lens = tf.concat([completed_lens, new_completed_lens],
                                 0)  # [2*beam_size]
            completed_lens = tf.gather(all_lens,
                                       indices,
                                       validate_indices=True,
                                       axis=0)  # [beam_size]
            if DEBUG:
                completed_lens = tf.Print(completed_lens, [completed_lens],
                                          message='completed lens',
                                          summarize=5)

            # 2.4 update seqs
            # The current candidates stand in for the newly completed
            # sequences; candidates built by this function are already
            # padded with `stop_symbol` (see `C` in step 3).
            all_completed = tf.concat([completed_seqs, cand_seqs], 0)
            completed_seqs = tf.gather(all_completed,
                                       indices,
                                       validate_indices=True,
                                       axis=0)  # [beam_size, MAX_STEPS + 2]
            if DEBUG:
                completed_seqs = tf.Print(completed_seqs, [completed_seqs],
                                          message='completed seqs: ',
                                          summarize=MAX_STEPS + 2)

            # 2.5 stop decoding loop
            # Stop once the worst kept completed score exceeds the best
            # unfinished score; note both sides are unscaled scores.
            max_uncompleted = tf.reduce_max(new_uncompleted_scores)
            completed_scores = tf.gather(all_completed_scores,
                                         indices,
                                         validate_indices=True,
                                         axis=0)
            min_completed = tf.reduce_min(completed_scores)
            stop_decoder = tf.greater(min_completed, max_uncompleted)

            # End of step 2 ---------------------------------------------------

            # 3. Update uncompleted sequences --------------------------------
            # new_uncompleted_scores: [beam_size, num_symbols-1]
            # top_k: [beam_size]. indices of top k scores
            # At ts == 0 only beam row 0 is expanded (presumably all rows
            # are identical at the first step, so using every row would pick
            # duplicates -- TODO confirm).
            def f0():
                return new_uncompleted_scores[0, :]

            def f1():
                return new_uncompleted_scores

            un_scores = tf.cond(tf.equal(ts, 0), f0, f1)
            new_flat = tf.squeeze(tf.reshape(
                un_scores, [-1, 1]))  # [beam_size*num_unstop_symbols]

            # get top K symbols
            cand_scores, flat_indices = tf.nn.top_k(new_flat,
                                                    k=beam_size,
                                                    sorted=False)
            # Decompose each flat index into (parent beam row, symbol id);
            # assumes tf.div performs integer division for integer inputs
            # -- TODO confirm for the TF version in use.
            cand_parents = tf.div(flat_indices, num_unstop_symbols)
            _ys = tf.mod(flat_indices,
                         num_unstop_symbols)  # [beam_size], y(t) for next step
            # New candidate = parent prefix (A) + chosen symbol (B) + padding
            # with stop_symbol (C).
            A = tf.gather(cand_seqs[:, 0:ts + 1],
                          cand_parents)  #[beam_size, ts+1]
            B = tf.expand_dims(_ys, -1)  # [beam_size, 1]
            C = tf.fill([beam_size, MAX_STEPS + 2 - ts - 2], stop_symbol)
            # (ts + 1) + 1 + (MAX_STEPS - ts) columns in total.
            cand_seqs = tf.concat([A, B, C], 1)  # [beam_size, MAX_STEPS + 2]
            if DEBUG:
                cand_seqs = tf.Print(cand_seqs, [cand_seqs],
                                     message='cand seqs: ',
                                     summarize=MAX_STEPS + 2)
            # Reshape (rather than set_shape) to pin the static shapes of the
            # loop variables.
            # cand_seqs.set_shape([beam_size, MAX_STEPS+2])
            cand_seqs = tf.reshape(cand_seqs, [beam_size, MAX_STEPS + 2])
            cand_scores.set_shape([beam_size])
            # completed_seqs.set_shape([beam_size, MAX_STEPS+2])
            completed_seqs = tf.reshape(completed_seqs,
                                        [beam_size, MAX_STEPS + 2])

            s1_shape = [beam_size, self.attention_cell.state_size]
            s2_shape = [beam_size, self.decoder_cell.state_size]
            s3_shape = [beam_size, self.attn_context.context_size]

            # prepare data for next step
            # Reorder every state tensor so that row i belongs to the parent
            # of the i-th new candidate.
            # states = tf.gather(states, cand_parents, axis=0)
            # states = self.select_states(states, cand_parents)
            states = tuple(tf.gather(el, cand_parents) for el in states)
            states[0].set_shape(s1_shape)
            states[1].set_shape(s2_shape)
            states[2].set_shape(s3_shape)
            # The beam dimension of alphas is axis 1, unlike the states.
            alphas = tf.gather(alphas, cand_parents, axis=1)
            alphas_shape = [self.attn_context.num_encoder_states, beam_size]
            alphas = tf.reshape(alphas, alphas_shape)
            # alphas.set_shape(alphas_shape)
            # End of step 3 ---------------------------------------------------

            ts = tf.add(ts, 1)
            return ts, stop_decoder, states, alphas, cand_seqs, \
                cand_scores, completed_scores, completed_scores_scaled, \
                completed_seqs, completed_lens
Exemplo n.º 45
0
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
  """Helper to `kernel` which computes the log acceptance-correction.

  Background. A sufficient (but not necessary) condition for a Markov chain to
  have a stationary distribution `p(x)` is "detailed balance":

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  Metropolis-Hastings proposes a state from `g(x'|x)` and accepts it with
  probability `a(x'|x)`, so `p(x'|x) = g(x'|x) a(x'|x)`. Substituting into the
  detailed balance equation gives:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  A choice of `a(x'|x)` satisfying (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (Under this definition at most one of `a(x'|x)` and `a(x|x')` differs from
  one, which is why (*) holds.) The bracketed term is called the
  "acceptance correction".

  For UncalibratedHMC the log acceptance-correction is not the log
  proposal-ratio. The state-space is augmented with a momentum `z`; assuming
  standard Gaussian momentums the chain eventually converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  and the proposal is symmetric:

  ```none
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  so the bracketed MH term is `1`. To keep using the general MH framework, the
  momentum probability ratio is placed inside the correction factor instead,
  yielding the acceptance probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: the kinetic energy change actually has to be handled at each leapfrog
  step; the above is only the idea.)

  Concretely, this function returns `0.5 * sum(z**2) - 0.5 * sum(z'**2)`,
  where each sum runs over all state parts and over every dimension to the
  right of the first `independent_chain_ndims` dimensions.

  Args:
    current_momentums: `Tensor`(s) representing the value(s) of the current
      momentum(s) of the state (parts); iterated in lockstep with
      `proposed_momentums`.
    proposed_momentums: `Tensor`(s) representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
  with tf.compat.v1.name_scope(
      name, 'compute_log_acceptance_correction',
      [independent_chain_ndims, current_momentums, proposed_momentums]):

    def _half_sum_of_squares(log_parts):
      """Combines per-part `log(sum(z**2))` terms into `0.5 * sum(z**2)`."""
      stacked = tf.stack(log_parts, axis=-1)
      return 0.5 * tf.exp(tf.reduce_logsumexp(input_tensor=stacked, axis=-1))

    # Per state part: (log kinetic of current, log kinetic of proposed). The
    # reduction axes are derived from the current momentum's rank and shared
    # by both members of the pair.
    log_kinetic_pairs = []
    for current_part, proposed_part in zip(current_momentums,
                                           proposed_momentums):
      event_axes = tf.range(independent_chain_ndims, tf.rank(current_part))
      log_kinetic_pairs.append((_log_sum_sq(current_part, event_axes),
                                _log_sum_sq(proposed_part, event_axes)))

    current_kinetic = _half_sum_of_squares(
        [log_current for log_current, _ in log_kinetic_pairs])
    proposed_kinetic = _half_sum_of_squares(
        [log_proposed for _, log_proposed in log_kinetic_pairs])
    return mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
Exemplo n.º 46
0
def main(unused_argv=()):
    """Build the embedding graph and train it with a hinge-style loss.

    Loads and verifies the default retina data, builds the EI, response and
    stimulus embeddings through the `em` helpers, defines a margin loss
    between a "positive" distance and an aggregated "negative" distance, and
    runs 10000 Adam steps, printing the loss after each one.

    Args:
        unused_argv: ignored; present so the function can be used as an
            app-style entry point.
    """
    # load default data
    data = utils.get_data_retina()

    # verify utils
    utils.verify_data(data)

    #########################################################################

    ## Try some architecture.

    embedx = 20
    embedy = 10
    stim_history = 30
    batch_size = 1000
    batch_neg_resp = 100
    beta = 10
    is_training = True
    with tf.Session() as sess:
        ei_embedding, ei_tf = em.embed_ei(embedx,
                                          embedy,
                                          data['eix'],
                                          data['eiy'],
                                          data['n_elec'],
                                          data['ei_embedding_matrix'],
                                          is_training=is_training)

        responses_embedding, responses_tf = em.embed_responses(
            embedx, embedy, ei_embedding, is_training=is_training)

        stimulus_embedding, stim_tf = em.embed_stimulus(
            embedx,
            embedy,
            data['stimx'],
            data['stimy'],
            stim_history=stim_history,
            is_training=is_training)

        # The responses fed at run time are the positive batch followed by
        # the negative batch (see the np.append in the feed dict below), so
        # the first `batch_size` rows are positives and the next
        # `batch_neg_resp` rows are negatives.
        # `np.int` was removed in NumPy 1.24; the builtin `int` is what it
        # aliased, so the resulting dtype is unchanged.
        responses_embedding_pos = tf.gather(
            responses_embedding,
            np.arange(batch_size).astype(int))

        responses_embedding_neg = tf.gather(
            responses_embedding,
            np.arange(batch_size, batch_neg_resp + batch_size).astype(int))

        # Squared distance between each stimulus and its positive response.
        d_pos = tf.reduce_sum(
            (stimulus_embedding - responses_embedding_pos)**2, [1, 2])
        # Pairwise squared distances between every positive and every
        # negative response embedding.
        d_neg_pairs = tf.reduce_sum(
            (tf.expand_dims(responses_embedding_pos, 1) -
             tf.expand_dims(responses_embedding_neg, 0))**2, [2, 3])
        # Aggregate over the negatives with a negated log-sum-exp of the
        # negated, beta-scaled distances.
        d_neg = -tf.reduce_logsumexp(-d_neg_pairs / beta, 1)
        # Hinge with margin 1.
        loss = tf.reduce_sum(tf.nn.relu(d_pos - d_neg + 1))

        train_op = tf.train.AdamOptimizer(0.01).minimize(loss)
        sess.run(tf.global_variables_initializer())

        # NOTE(review): this drops into an interactive IPython shell before
        # training starts and blocks until the user exits it. Kept as in the
        # original; confirm whether it is still wanted.
        from IPython import embed
        embed()

        for _ in range(10000):
            stim_batch, resp_batch, ei_batch, resp_batch_neg = get_train_batch(
                data,
                batch_size=batch_size,
                batch_neg_resp=batch_neg_resp,
                stim_history=stim_history)
            feed_dict = {
                ei_tf: ei_batch,
                responses_tf: np.append(resp_batch, resp_batch_neg, 0),
                stim_tf: stim_batch
            }
            loss_np, _ = sess.run([loss, train_op], feed_dict=feed_dict)
            print(loss_np)
Exemplo n.º 47
0
 def neglogp(self, x):
     """Return the negative log-probability of `x` under the mixture.

     Each of the `self.n` components contributes its log mixing weight plus
     its Gaussian log-density. The components are combined with a max when
     `self.hard_grad` is set, otherwise with a log-sum-exp.
     """
     component_logps = []
     for i in range(self.n):
         component_logps.append(
             self.log_mixing_probs[:, i] + self.gaussians[i].logp(x))
     p = tf.stack(component_logps)
     # Components are stacked on axis 0, so that is the axis to reduce.
     combine = tf.reduce_max if self.hard_grad else tf.reduce_logsumexp
     return -combine(p, axis=[0])
Exemplo n.º 48
0
 def get_mdn_coef(output):
     """Split `output` into three equal parts along axis 1 and normalise the first.

     Returns `(logmix, mean, logstd)`, where `logmix` has had its
     log-sum-exp over axis 1 subtracted.
     """
     raw_logmix, mean, logstd = tf.split(output, 3, 1)
     log_normalizer = tf.reduce_logsumexp(raw_logmix, 1, keepdims=True)
     return raw_logmix - log_normalizer, mean, logstd
Exemplo n.º 49
0
        def step_fn(inputs):
            """Per-replica evaluation step.

            Runs the model `FLAGS.ensemble_size` times on the batch, records
            per-member metrics, then records metrics for the averaged
            prediction under either the clean or the corrupted metric set,
            depending on `dataset_name`.

            Args:
              inputs: mapping with a 'features' entry (model inputs) and a
                'labels' entry.
            """
            images, labels = inputs['features'], inputs['labels']

            member_logits = []
            member_stddevs = []

            for i in range(FLAGS.ensemble_size):
                outputs = model(images, training=False)
                if isinstance(outputs, (list, tuple)):
                    # The model returned (logits, covmat); use both.
                    logits, covmat = outputs
                else:
                    # No covariance available: fall back to the identity.
                    logits = outputs
                    covmat = tf.eye(logits.shape[0])
                logits = mean_field_logits(
                    logits,
                    covmat,
                    mean_field_factor=FLAGS.gp_mean_field_factor)

                member_stddevs.append(tf.sqrt(tf.linalg.diag_part(covmat)))
                member_logits.append(logits)

                # Per-member metrics.
                member_probs = tf.nn.softmax(logits)
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics[f'{dataset_split}/nll_member_{i}'].update_state(
                    member_loss)
                metrics[f'{dataset_split}/accuracy_member_{i}'].update_state(
                    labels, member_probs)

            # Stacked logits have the ensemble members on axis 0.
            logits_list = tf.stack(member_logits, axis=0)
            stddev = tf.reduce_mean(tf.stack(member_stddevs, axis=0), axis=0)
            probs = tf.reduce_mean(tf.nn.softmax(logits_list), axis=0)

            # Negative log of the member-averaged likelihood:
            # -logsumexp_i(log p_i) + log(ensemble_size), averaged over the
            # batch.
            labels_broadcasted = tf.broadcast_to(
                labels, [FLAGS.ensemble_size, labels.shape[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            summed_log_likelihood = tf.reduce_logsumexp(log_likelihoods,
                                                        axis=[0])
            negative_log_likelihood = tf.reduce_mean(
                -summed_log_likelihood +
                tf.math.log(float(FLAGS.ensemble_size)))

            if dataset_name == 'clean':
                nll_metric = metrics[
                    f'{dataset_split}/negative_log_likelihood']
                nll_metric.update_state(negative_log_likelihood)
                metrics[f'{dataset_split}/accuracy'].update_state(
                    labels, probs)
                metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
                metrics[f'{dataset_split}/stddev'].update_state(stddev)
            else:
                nll_key = 'test/nll_{}'.format(dataset_name)
                corrupt_metrics[nll_key].update_state(negative_log_likelihood)
                accuracy_key = 'test/accuracy_{}'.format(dataset_name)
                corrupt_metrics[accuracy_key].update_state(labels, probs)
                ece_key = 'test/ece_{}'.format(dataset_name)
                corrupt_metrics[ece_key].add_batch(probs, label=labels)
                stddev_key = 'test/stddev_{}'.format(dataset_name)
                corrupt_metrics[stddev_key].update_state(stddev)
Exemplo n.º 50
0
 def test_partition_2_step_incomplete_does_not_add_up_to_one(self):
     """With the last row of the 2-step data removed, the log total is not 0."""
     truncated_data = self.data_2_steps[:-1]
     # One entry of 2 per remaining row (presumably the per-row sequence
     # length -- TODO confirm against `dynamic_spn`'s input contract).
     twos = [2] * (self.data_2_steps.shape[0] - 1)
     log_values = self.dynamic_spn([truncated_data, twos])
     log_total = tf.reduce_logsumexp(log_values)
     self.assertNotEqual(log_total, 0.0)
Exemplo n.º 51
0
    def __init__(
            self,
            train_envs,
            n_act_dim=envs.N_ACT_DIM,
            n_obs_dim=envs.N_OBS_DIM,
            true_dynamics=None,
    ):
        """Build the training graph: per-task Q-networks plus a dynamics model.

        The loss is the negative average log-likelihood of the demonstrated
        actions under a softmax over each task's Q-values, plus a penalty on
        the squared temporal-difference error computed with the learned
        transition model and the environments' known rewards.

        NOTE(review): several names used below (`q_scope`, `q_n_layers`,
        `q_layer_size`, `q_activation`, `q_output_activation`, `im_scope`,
        `n_layers`, `layer_size`, `activation`, `output_activation`, `gamma`,
        `sq_td_err_penalty`, `learning_rate`) are not parameters of this
        method and are not defined in it; they must come from an enclosing
        scope that is not visible here.

        Args:
            train_envs: List of environments. Uses known rewards
                (`env.unwrapped.R`).
            n_act_dim: Number of actions.
            n_obs_dim: Number of observations.
            true_dynamics: If None, use dynamics from first training task
                (argmax of `train_envs[0].unwrapped.T` over axis 2).

        """
        n_train_tasks = len(train_envs)
        # Demonstration batch: observation, action and task index of every
        # transition, plus the batch size as a separately fed scalar.
        demo_obs_t_ph = tf.placeholder(tf.int32, [None])
        demo_act_t_ph = tf.placeholder(tf.int32, [None])
        demo_task_t_ph = tf.placeholder(tf.int32, [None])
        demo_batch_size_ph = tf.placeholder(tf.int32)

        # Column vector [0, 1, ..., batch_size - 1], used to build gather_nd
        # indices below.
        demo_batch_idxes = tf.reshape(
            tf.range(0, demo_batch_size_ph, 1), [demo_batch_size_ph, 1])

        # Q-values of the demonstrated observations under every task's
        # network, stacked with the task on axis 0.
        demo_q_t = tf.stack([
            self._build_mlp(
                self._featurize_obs(demo_obs_t_ph, n_obs_dim),
                n_act_dim,
                q_scope+'-'+str(train_task_idx),
                n_layers=q_n_layers,
                size=q_layer_size,
                activation=q_activation,
                output_activation=q_output_activation,
            ) for train_task_idx in range(n_train_tasks)
        ], axis=0)
        # For each batch element keep only the row of its own task: index
        # pairs are (task, batch position).
        demo_q_t = tf.gather_nd(demo_q_t, tf.concat(
            [tf.expand_dims(demo_task_t_ph, 1), demo_batch_idxes], axis=1))

        # Q-value of the demonstrated action: index pairs are
        # (batch position, action).
        demo_act_idxes = tf.concat([demo_batch_idxes, tf.reshape(
            demo_act_t_ph, [demo_batch_size_ph, 1])], axis=1)
        demo_act_val_t = tf.gather_nd(demo_q_t, demo_act_idxes)
        # Log-softmax over actions: Q(s, a) - logsumexp_a' Q(s, a').
        state_val_t = tf.reduce_logsumexp(demo_q_t, axis=1)
        act_log_likelihoods = demo_act_val_t - state_val_t

        neg_avg_log_likelihood = -tf.reduce_mean(act_log_likelihoods)

        # Enumerate every (observation, action) pair as a flat index k with
        # obs = floor(k / n_act_dim) and act = k mod n_act_dim.
        obs_for_obs_tp1_probs = tf.cast(tf.floor(
            tf.range(0, n_obs_dim*n_act_dim, 1) / n_act_dim), dtype=tf.int32)

        act_for_obs_tp1_probs = tf.floormod(tf.range(
            0, n_obs_dim*n_act_dim, 1), n_act_dim)

        # One-hot encoding of each flat (observation, action) index.
        obs_tp1_probs_in = tf.one_hot(
            obs_for_obs_tp1_probs*n_act_dim+act_for_obs_tp1_probs, n_obs_dim*n_act_dim)

        # Learned transition model: one output vector of size n_obs_dim per
        # (observation, action) pair (presumably next-observation
        # probabilities, given the name and how it is used as weights below
        # -- depends on `output_activation`, TODO confirm).
        obs_tp1_probs = self._build_mlp(
            obs_tp1_probs_in,
            n_obs_dim, im_scope, n_layers=n_layers, size=layer_size,
            activation=activation, output_activation=output_activation
        )
        obs_tp1_probs = tf.reshape(
            obs_tp1_probs, [n_obs_dim, n_act_dim, n_obs_dim])

        # Q-values of every observation under every task's network (the
        # variables created above are reused).
        q_tp1 = tf.stack([
            self._build_mlp(
                self._featurize_obs(tf.range(0, n_obs_dim, 1), n_obs_dim),
                n_act_dim,
                q_scope+'-'+str(train_task_idx),
                n_layers=q_n_layers,
                size=q_layer_size,
                activation=q_activation,
                output_activation=q_output_activation,
                reuse=True,
            ) for train_task_idx in range(n_train_tasks)
        ], axis=0)

        # Soft value of each observation: log-sum-exp over the action axis.
        v_tp1 = tf.reduce_logsumexp(q_tp1, axis=2)

        # Known reward arrays of all training environments, stacked by task.
        all_rew = tf.convert_to_tensor(np.stack(
            [env.unwrapped.R for env in train_envs], axis=0), dtype=tf.float32)

        # Broadcast to [task, obs, act, next_obs] so the transition model can
        # weight next-observation values and rewards.
        v_tp1_broad = tf.reshape(v_tp1, [n_train_tasks, 1, 1, n_obs_dim])
        obs_tp1_probs_broad = tf.expand_dims(obs_tp1_probs, 0)

        # Weighted sums over the next observation, and the resulting target
        # for Q.
        exp_v_tp1 = tf.reduce_sum(obs_tp1_probs_broad * v_tp1_broad, axis=3)
        exp_rew_t = tf.reduce_sum(obs_tp1_probs_broad * all_rew, axis=3)
        target_t = exp_rew_t + gamma * exp_v_tp1

        # Same networks and inputs as `q_tp1`, rebuilt as the prediction side
        # of the TD error.
        q_t = tf.stack([
            self._build_mlp(
                self._featurize_obs(tf.range(0, n_obs_dim, 1), n_obs_dim),
                n_act_dim,
                q_scope+'-'+str(train_task_idx),
                n_layers=q_n_layers,
                size=q_layer_size,
                activation=q_activation,
                output_activation=q_output_activation,
                reuse=True,
            )
            for train_task_idx in range(n_train_tasks)
        ], axis=0)

        # Total loss = demonstration NLL + weighted mean squared TD error.
        td_err = q_t - target_t
        sq_td_err = tf.reduce_mean(td_err**2)
        loss = neg_avg_log_likelihood + sq_td_err_penalty * sq_td_err

        update_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

        self.n_act_dim = n_act_dim
        self.n_obs_dim = n_obs_dim

        self.demo_obs_t_ph = demo_obs_t_ph
        self.demo_act_t_ph = demo_act_t_ph
        self.demo_task_t_ph = demo_task_t_ph
        self.demo_batch_size_ph = demo_batch_size_ph

        self.q_t = q_t
        self.loss = loss
        self.neg_avg_log_likelihood = neg_avg_log_likelihood
        self.sq_td_err = sq_td_err
        self.update_op = update_op
        self.obs_tp1_probs = obs_tp1_probs
        if true_dynamics is None:
            # Take the argmax over axis 2 of the first environment's
            # transition array.
            self.true_dynamics = np.argmax(train_envs[0].unwrapped.T, axis=2)
        else:
            self.true_dynamics = true_dynamics
Exemplo n.º 52
0
 def get_lossfunc(logmix, mean, logstd, y):
     """Return the mean negative log-likelihood of `y` under the mixture.

     Per-component terms are `logmix + tf_lognormal(y, mean, logstd)`; they
     are combined with a log-sum-exp over axis 1 and then averaged.
     """
     component_logp = logmix + tf_lognormal(y, mean, logstd)
     mixture_logp = tf.reduce_logsumexp(component_logp, 1, keepdims=True)
     return -tf.reduce_mean(mixture_logp)
 def add_relevant_loss(loss, relevant_Y, final_loss):
     """Add the log-sum-exp of the selected entries of `loss` to `final_loss`.

     `relevant_Y` holds the indices gathered from `loss`; the gathered values
     are reduced to a single log-sum-exp, which is added to `final_loss` and
     returned.
     """
     selected = tf.gather(params=loss, indices=relevant_Y)
     return final_loss + tf.reduce_logsumexp(selected)
Exemplo n.º 54
0
    def resample(particle_states, particle_weights, alpha):
        """
        Draws a new particle set by (soft-)resampling.
        :param particle_states: tf op (batch, K, 3), particle states
        :param particle_weights: tf op (batch, K), unnormalized particle weights in log space
        :param alpha: float in (0, 1] (asserted), trade-off parameter for soft-resampling.
        alpha == 1 is standard hard-resampling; for alpha < 1 the sampling distribution is the
        mixture alpha * weights + (1 - alpha) * uniform.
        :return: particle_states, particle_weights
        """
        with tf.name_scope('resample'):
            assert 0.0 < alpha <= 1.0
            static_shape = particle_states.get_shape().as_list()
            batch_size, num_particles = static_shape[:2]

            # Normalise the incoming log-weights over the particle axis.
            log_norm = tf.reduce_logsumexp(particle_weights,
                                           axis=-1,
                                           keep_dims=True)
            log_weights = particle_weights - log_norm

            # Log of the uniform distribution over the particles.
            log_uniform = tf.constant(-np.log(num_particles),
                                      shape=(batch_size, num_particles),
                                      dtype=tf.float32)

            # Build the sampling distribution q(s) and the weights the
            # resampled particles will carry.
            if alpha < 1.0:
                # Soft resampling: q is a mixture of the particle weights
                # and the uniform distribution, formed in log space.
                mixture_terms = [
                    log_weights + np.log(alpha),
                    log_uniform + np.log(1.0 - alpha),
                ]
                log_q = tf.reduce_logsumexp(tf.stack(mixture_terms, axis=-1),
                                            axis=-1,
                                            keep_dims=False)
                log_q = log_q - tf.reduce_logsumexp(
                    log_q, axis=-1, keep_dims=True)  # normalized

                # Importance correction w / q; left unnormalized.
                log_weights = log_weights - log_q
            else:
                # Hard resampling. This will produce zero gradients.
                log_q = log_weights
                log_weights = log_uniform

            # Sample particle indices according to q(s).
            sampled = tf.cast(tf.multinomial(log_q, num_particles),
                              tf.int32)  # shape: (batch_size, num_particles)

            # Turn the per-row indices into indices into the flattened
            # (batch * K) arrays by adding each row's offset.
            total_particles = batch_size * num_particles
            row_offsets = tf.range(0,
                                   total_particles,
                                   delta=num_particles,
                                   dtype=tf.int32)  # (batch, )
            flat_indices = sampled + tf.expand_dims(row_offsets, axis=1)

            flat_states = tf.reshape(particle_states, (total_particles, 3))
            particle_states = tf.gather(
                flat_states, indices=flat_indices,
                axis=0)  # (batch_size, num_particles, 3)

            flat_weights = tf.reshape(log_weights, (total_particles, ))
            particle_weights = tf.gather(
                flat_weights, indices=flat_indices,
                axis=0)  # (batch_size, num_particles,)

            return particle_states, particle_weights
Exemplo n.º 55
0
def main():
    """Build the energy-model training graph, then run training and testing.

    The function proceeds in this order:

    1. Create the log directory and a TensorBoard logger on Horovod rank 0
       (other ranks get ``logger = None``).
    2. Based on ``FLAGS.dataset``, create the dataset objects, the input
       placeholders (``X``, ``X_NOISE``, ``LABEL``, ``LABEL_POS``) and the model.
    3. Build one "tower" per ``FLAGS.num_gpus``: positive energies on data,
       a ``tf.while_loop`` of ``FLAGS.num_steps`` sampling steps starting from
       ``X_NOISE``, negative energies on the resulting samples, and (when
       ``FLAGS.train``) the loss and its gradients.
    4. Average tower gradients, create the session, optionally restore a
       checkpoint on rank 0, broadcast variables, then call ``train`` (if
       ``FLAGS.train``) followed by ``test``.

    NOTE(review): there is no ``else`` for an unrecognised ``FLAGS.dataset``;
    in that case ``X``, ``model`` etc. are never assigned and the first use
    raises ``NameError``.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    # Only rank 0 owns the log directory and the TensorBoard writer.
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading Data...")
    # ---- Dataset-specific placeholders and model selection ----
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        # The first matching size flag wins; plain ResNet32 is the fallback.
        if FLAGS.large_model:
            model = ResNet32Large(
                num_channels=channel_num,
                num_filters=128,
                train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(
                num_channels=channel_num,
                num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(
                num_channels=channel_num,
                num_filters=192)
        else:
            model = ResNet32(
                num_channels=channel_num,
                num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        # No `dataset` / `test_dataset` objects are created here; this branch
        # relies on the TFImagenetLoader constructed further below.
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        # The same dataset object is reused as the test set.
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(
            cond_shape=FLAGS.cond_shape,
            cond_size=FLAGS.cond_size,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)
        # The same dataset object is reused as the test set.
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # The label width depends on which latent factor is conditioned on;
        # the flags are checked in priority order and the first true one wins.
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters,
            cond_size=FLAGS.cond_size,
            cond_shape=FLAGS.cond_shape,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

    batch_size = FLAGS.batch_size

    # A single weight set is created and shared by every tower below.
    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    # Each placeholder is split into FLAGS.num_gpus equal pieces, one per tower.
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    # Adam wrapped by Horovod's distributed optimizer.
    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    # ---- Per-tower graph construction ----
    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Evaluate the energy of every image under each of 10 one-hot
            # labels, then sample a label per image from those energies.
            # This branch hard-codes 10 classes and 32x32x3 inputs.
            # NOTE(review): label_tensor has FLAGS.batch_size * 10 rows while
            # x_split has ind_batch_size * 10 rows; these only agree when
            # FLAGS.num_gpus == 1 -- confirm.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(
                    X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            # argmax of (score - log(-log(u))) with u ~ Uniform(0, 1) draws a
            # label index per row; the sampled label replaces LABEL_SPLIT[j].
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            # Positive-phase energy of the real data under its given labels.
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        # Sampling starts from this tower's slice of the noise placeholder.
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        # Energy of the initial samples (exposed later as 'energy_start').
        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        # Loop counter and continuation condition for the sampling while_loop.
        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            # One sampling step: add Gaussian noise, then move x_mod either by
            # a gradient step on the energy or via hmc(), and clip to range.
            # Closes over `j`, `model`, `weights` and LABEL_SPLIT of the
            # enclosing loop iteration; it is consumed by tf.while_loop below
            # within the same iteration.
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            # Only x_grad is used below; label_grad is computed but unused.
            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optional clipping of the gradient: by L2 norm ('l2') or
            # element-wise by value ('li'). Any other type aborts.
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Either delegate the move to hmc() or take a plain gradient step.
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            # Keep samples inside the valid range [0, FLAGS.rescale].
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        # Gradient of the energy w.r.t. the final samples (exposed as 'x_grad').
        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Negative-phase energy: samples are detached with stop_gradient so
        # this term does not backpropagate through the sampling loop.
        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        # Mean absolute difference between final samples and the real batch.
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        # Energy of the (non-detached) samples; added to the loss unless
        # FLAGS.zero_kl is set.
        # NOTE(review): this passes the full LABEL placeholder rather than
        # LABEL_SPLIT[j]; with FLAGS.num_gpus > 1 the batch sizes of x_mod and
        # the label would differ -- confirm intended.
        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        # NOTE: target_vars is re-created for every tower, so after the loop
        # it only holds the tensors of the last tower.
        target_vars = {}

        # Entropy of the empirical label distribution (tower 0 labels only).
        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            # NOTE(review): if FLAGS.objective is none of 'logsumexp', 'cd' or
            # 'softplus', loss_ml is never assigned and the line computing
            # loss_total raises NameError.
            if FLAGS.objective == 'logsumexp':
                # Negative samples are weighted by a normalised, gradient-
                # stopped exp(-temp * energy); the minimum is subtracted
                # before exponentiating.
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                # Mean positive energy minus mean negative energy.
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 penalty on the magnitude of both positive and negative energies.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            # Drop pairs whose gradient is None.
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            # Gradients are only collected here; they are averaged and applied
            # once, after the tower loop.
            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        # Tensors exposed to train()/test() through the target_vars dict.
        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    # ---- Combine towers and create the train op ----
    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    # With several Horovod processes, restrict each one to the GPU whose index
    # equals its local rank.
    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    # Count trainable parameters by multiplying out every variable's shape.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Only rank 0 restores from disk; the broadcast below runs on every rank.
    # Restoring also happens whenever FLAGS.train is false, whatever
    # FLAGS.resume_iter is.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    # These two messages are printed after the broadcast op above has run.
    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    # test() runs unconditionally, also after training.
    test(target_vars, saver, sess, logger, data_loader)
Exemplo n.º 56
0
    def _create_td_update(self):
        """Build the Q-function losses and, if enabled, their training ops.

        Every next-observation is paired with ``self._value_n_particles``
        actions drawn uniformly from [-1, 1]; the element-wise minimum over
        the target Q-functions is combined across particles with a
        log-sum-exp to form the next-state value, which feeds the squared
        TD error of each Q-function.

        Side effects: sets ``self._Q_values`` and ``self._Q_losses``; when
        ``self._train_Q`` is true also sets ``self._Q_optimizers`` and
        appends one grouped op to ``self._training_ops``.
        """
        n_particles = self._value_n_particles

        # Duplicate each next-observation once per particle, then flatten the
        # (batch, particle) axes into one.
        tiled_next_obs = tf.tile(
            self._next_observations_ph[:, tf.newaxis, :],
            (1, n_particles, 1))
        flat_next_obs = tf.reshape(
            tiled_next_obs, (-1, *self._observation_shape))

        # One shared set of uniform actions, repeated for every batch row.
        sampled_actions = tf.random_uniform(
            (1, n_particles, *self._action_shape), -1, 1)
        batch_dim = tf.shape(self._next_observations_ph)[0]
        sampled_actions = tf.tile(sampled_actions, (batch_dim, 1, 1))
        flat_actions = tf.reshape(sampled_actions, (-1, *self._action_shape))

        target_estimates = tuple(
            [target_Q([flat_next_obs, flat_actions])
             for target_Q in self._Q_targets])

        # Element-wise minimum across the target Q-functions.
        flat_min_target = tf.reduce_min(target_estimates, axis=0)
        assert_shape(flat_min_target, (None, 1))

        per_particle_min = tf.reshape(flat_min_target, (-1, n_particles))
        assert_shape(per_particle_min, (None, self._value_n_particles))

        # Equation 10:
        next_value = tf.reduce_logsumexp(
            per_particle_min, keepdims=True, axis=1)
        assert_shape(next_value, [None, 1])

        # Importance weights add just a constant to the value.
        log_n_particles = tf.log(tf.to_float(self._value_n_particles))
        next_value = next_value - log_n_particles
        next_value = next_value + np.prod(self._action_shape) * np.log(2)

        # \hat Q in Equation 11:
        scaled_reward = self._reward_scale * self._rewards_ph
        bootstrap = (1 - self._terminals_ph) * self._discount * next_value
        Q_target = tf.stop_gradient(scaled_reward + bootstrap)
        assert_shape(Q_target, [None, 1])

        current_estimates = tuple(
            [Q([self._observations_ph, self._actions_ph])
             for Q in self._Qs])
        self._Q_values = current_estimates

        for estimate in self._Q_values:
            assert_shape(estimate, [None, 1])

        # Equation 11:
        td_losses = tuple(
            [tf.losses.mean_squared_error(
                labels=Q_target, predictions=estimate, weights=0.5)
             for estimate in current_estimates])
        self._Q_losses = td_losses

        if not self._train_Q:
            return

        # One Adam optimizer per Q-function, named after the function.
        optimizers = []
        for i, Q in enumerate(self._Qs):
            optimizers.append(
                tf.train.AdamOptimizer(
                    learning_rate=self._Q_lr,
                    name='{}_{}_optimizer'.format(Q._name, i)))
        self._Q_optimizers = tuple(optimizers)

        train_ops = []
        for Q, td_loss, Q_optimizer in zip(
                self._Qs, td_losses, self._Q_optimizers):
            train_ops.append(
                tf.contrib.layers.optimize_loss(
                    td_loss,
                    None,
                    learning_rate=self._Q_lr,
                    optimizer=Q_optimizer,
                    variables=Q.trainable_variables,
                    increment_global_step=False,
                    summaries=()))
        Q_training_ops = tuple(train_ops)

        self._training_ops.append(tf.group(Q_training_ops))
def _log_prob_from_logits(logits):
    """Turn logits into log-probabilities normalised along axis 2.

    Subtracts the log-sum-exp over axis 2 (kept as a size-1 axis so it
    broadcasts) from ``logits``.
    """
    log_normalizer = tf.reduce_logsumexp(logits, axis=2, keepdims=True)
    return logits - log_normalizer
Exemplo n.º 58
0
def get_coref_softmax_loss(antecedent_scores, antecedent_labels):
  """Return log-normaliser minus marginalised gold score, reduced over axis 1.

  The labels are cast to float and passed through ``tf.log`` before being
  added to the scores, so an entry whose label casts to 0 contributes
  ``log(0)`` to the gold term. Shape annotations below follow the original
  comments (inputs ``[k, max_ant + 1]``, result ``[k]``).
  """
  log_label_mask = tf.log(tf.to_float(antecedent_labels))  # [k, max_ant + 1]
  masked_scores = antecedent_scores + log_label_mask  # [k, max_ant + 1]
  gold_log_mass = tf.reduce_logsumexp(masked_scores, [1])  # [k]
  total_log_mass = tf.reduce_logsumexp(antecedent_scores, [1])  # [k]
  return total_log_mass - gold_log_mass  # [k]
Exemplo n.º 59
0
    def while_step(t, rnn_state, tas, accs):
        """Implements one timestep of FIVO computation.

        Reads closure variables from the enclosing scope: ``inputs_ta``,
        ``mask_ta``, ``cell``, ``num_samples``, ``batch_size``,
        ``random_seed`` and the ``nested`` helper module.

        Args:
            t: Current timestep index, used to read the input TensorArrays
                and to write the output TensorArrays.
            rnn_state: State passed to ``cell`` for this step.
            tas: TensorArrays that receive, in order, the accumulated log
                weights, the log effective sample size and the resampling
                indicator for this step.
            accs: Tuple ``(log_weights_acc, log_p_hat_acc, kl_acc)`` of
                running accumulators.

        Returns:
            Tuple ``(t + 1, new_state, new_tas, new_accs)``, where
            ``new_accs`` has the same structure as ``accs``.
        """
        log_weights_acc, log_p_hat_acc, kl_acc = accs
        cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
        # Run the cell for one step.
        log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
            cur_inputs,
            rnn_state,
            cur_mask,
        )
        # Compute the incremental weight and use it to update the current
        # accumulated weight.
        kl_acc += kl * cur_mask
        log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
        log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
        log_weights_acc += log_alpha

        # Calculate the effective sample size(ESS for unnormalized weights).
        # In log space: 2 * logsumexp(w) - logsumexp(2 * w), over the particle
        # axis (axis 0).
        ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
        ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
        log_ess = ess_num - ess_denom

        # Calculate the ancestor indices via resampling. Because we maintain the
        # log unnormalized weights, we pass the weights in as logits, allowing
        # the distribution object to apply a softmax and normalize them.
        resampling_dist = tf.contrib.distributions.Categorical(
            logits=tf.transpose(log_weights_acc, perm=[1, 0]))
        ancestor_inds = tf.stop_gradient(
            resampling_dist.sample(sample_shape=num_samples, seed=random_seed))

        # Because the batch is flattened and laid out as discussed
        # above, we must modify ancestor_inds to index the proper samples.
        # The particles in the ith filter are distributed every batch_size rows
        # in the batch, and offset i rows from the top. So, to correct the indices
        # we multiply by the batch_size and add the proper offset. Crucially,
        # when ancestor_inds is flattened the layout of the batch is maintained.
        offset = tf.expand_dims(tf.range(batch_size), 0)
        ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset, [-1])
        # Identity indices, used for filters that do not resample.
        noresample_inds = tf.range(num_samples * batch_size)

        # Decide whether or not we should resample; don't resample if we are past
        # the end of a sequence.
        # The criterion is computed inline here: the negated sum of the
        # accumulated log weights over the particle axis is compared with
        # num_samples / 2.
        # NOTE(review): log_ess computed above is not used in this decision
        # (it is only written to the TensorArrays) -- confirm intended.
        # should_resample = resampling_criterion(log_weights_acc, log_ess, t)
        should_resample = -tf.reduce_sum(
            log_weights_acc, axis=0) >= tf.reduce_sum(num_samples / 2.0)
        ## GIVEN
        # should_resample = resampling_criterion(num_samples, log_ess, t)

        should_resample = tf.logical_and(should_resample,
                                         cur_mask[:batch_size] > 0.)
        float_should_resample = tf.to_float(should_resample)
        # Pick sampled ancestors where resampling is on, identity elsewhere.
        ancestor_inds = tf.where(tf.tile(should_resample, [num_samples]),
                                 ancestor_inds, noresample_inds)
        new_state = nested.gather_tensors(new_state, ancestor_inds)

        # Update the TensorArrays before we reset the weights so that we capture
        # the incremental weights and not zeros.
        ta_updates = [log_weights_acc, log_ess, float_should_resample]
        new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]

        # For the particle filters that resampled, update log_p_hat and
        # reset weights to zero.
        log_p_hat_update = tf.reduce_logsumexp(
            log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
        log_p_hat_acc += log_p_hat_update * float_should_resample
        # Multiplying by (1 - indicator) zeroes the log weights of the filters
        # that resampled and leaves the others unchanged.
        log_weights_acc *= (
            1. -
            tf.tile(float_should_resample[tf.newaxis, :], [num_samples, 1]))
        new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)

        return t + 1, new_state, new_tas, new_accs
Exemplo n.º 60
0
 def test_partition_2_step_adds_up_to_one(self):
     """Check that the log-sum-exp of the two-step SPN outputs equals 0.0.

     Feeds ``self.data_2_steps`` together with a length of 2 for every row
     to ``self.dynamic_spn`` and reduces the result over all elements.
     """
     step_data = self.data_2_steps
     sequence_lengths = [2] * step_data.shape[0]
     log_values = self.dynamic_spn([step_data, sequence_lengths])
     total_log_mass = tf.reduce_logsumexp(log_values)
     self.assertEqual(total_log_mass, 0.0)