Example #1
 def test_tile_tensors(self):
     """Checks that tile_tensors correctly tiles tensors of different ranks."""
     a = tf.range(20)
     b = tf.reshape(a, [2, 10])
     c = tf.reshape(a, [2, 2, 5])
     a_tiled = tf.tile(a, [3])
     b_tiled = tf.tile(b, [3, 1])
     c_tiled = tf.tile(c, [3, 1, 1])
     tensors = [a, (b, ExampleTuple(c, c))]
     expected_tensors = [a_tiled, (b_tiled, ExampleTuple(c_tiled, c_tiled))]
     tiled = nested_utils.tile_tensors(tensors, [3])
     nest.assert_same_structure(expected_tensors, tiled)
     with self.test_session() as sess:
         expected, out = sess.run([expected_tensors, tiled])
         expected = nest.flatten(expected)
         out = nest.flatten(out)
         # Check that the tiling is correct.
         for x, y in zip(expected, out):
             self.assertAllClose(x, y)
Example #2
 def test_tile_tensors(self):
   """Checks that tile_tensors correctly tiles tensors of different ranks."""
   a = tf.range(20)
   b = tf.reshape(a, [2, 10])
   c = tf.reshape(a, [2, 2, 5])
   a_tiled = tf.tile(a, [3])
   b_tiled = tf.tile(b, [3, 1])
   c_tiled = tf.tile(c, [3, 1, 1])
   tensors = [a, (b, ExampleTuple(c, c))]
   expected_tensors = [a_tiled, (b_tiled, ExampleTuple(c_tiled, c_tiled))]
   tiled = nested_utils.tile_tensors(tensors, [3])
   nest.assert_same_structure(expected_tensors, tiled)
   with self.test_session() as sess:
     expected, out = sess.run([expected_tensors, tiled])
     expected = nest.flatten(expected)
     out = nest.flatten(out)
     # Check that the tiling is correct.
     for x, y in zip(expected, out):
       self.assertAllClose(x, y)
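The tile_tensors utility under test is not shown on this page. A minimal sketch of what it presumably does, assuming it maps tf.tile over an arbitrary nest and pads the supplied multiples with 1s up to each tensor's rank (the real nested_utils implementation may differ), is:

import tensorflow as tf
from tensorflow.python.util import nest


def tile_tensors(tensors, multiples):
  """Sketch: tiles every tensor in a nested structure by `multiples`.

  Tensors whose rank exceeds len(multiples) are tiled by 1 in the remaining
  trailing dimensions, matching what the test above expects.
  """
  def tile_fn(tensor):
    # Pad the multiples with 1s so tf.tile receives one entry per dimension.
    padded = list(multiples) + [1] * (tensor.shape.ndims - len(multiples))
    return tf.tile(tensor, padded)

  return nest.map_structure(tile_fn, tensors)

With this definition, tile_tensors([a, (b, ExampleTuple(c, c))], [3]) yields the same nested structure the test builds by hand as expected_tensors.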
Example #3
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         sinkhorn_regularization=0.01,
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
    """Computes the FIVO lower bound on the log marginal probability.

  This method accepts a stochastic latent variable model and some observations
  and computes a stochastic lower bound on the log marginal probability of the
  observations. The lower bound is defined by a particle filter's unbiased
  estimate of the marginal probability of the observations. For more details see
  "Filtering Variational Objectives" by Maddison et al.
  https://arxiv.org/abs/1705.09279.

  When the resampling criterion is "never resample", this bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple of
      Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
      have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of particles to use in each particle filter.
    resampling_criterion: The resampling criterion to use for this particle
      filter. Must accept the number of samples, the current log weights,
      and the current timestep and return a boolean Tensor of shape [batch_size]
      indicating whether each particle filter should resample. See
      ess_criterion and related functions for examples. When
      resampling_criterion is never_resample_criterion, resampling_fn is ignored
      and never called.
    resampling_type: The type of resampling, one of "multinomial", "relaxed",
      or "differentiable".
    sinkhorn_regularization: The Sinkhorn regularization strength used when
      resampling_type is "differentiable".
    relaxed_resampling_temperature: A positive temperature only used for relaxed
      resampling.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop. Note that values greater than 1 can introduce
      non-determinism even when random_seed is provided.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.
    random_seed: The random seed to pass to the resampling operations in
      the particle filter. Mainly useful for testing.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
      log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep of the particle filter. Note
      that on timesteps when a resampling operation is performed the log weights
      are reset to 0. Will not be valid for timesteps past the end of a
      sequence.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
      particle filters resampled. Will be 1.0 on timesteps when resampling
      occurred and 0.0 on timesteps when it did not.
    final_state: The final state of the particle filter, as returned by the
      internal smc.smc call.
  """
    # batch_size is the number of particle filters running in parallel.
    batch_size = tf.shape(seq_lengths)[0]

    # Each sequence in the batch will be the input data for a different
    # particle filter. The batch will be laid out as:
    #   particle 1 of particle filter 1
    #   particle 1 of particle filter 2
    #   ...
    #   particle 1 of particle filter batch_size
    #   particle 2 of particle filter 1
    #   ...
    #   particle num_samples of particle filter batch_size
    observations = nested.tile_tensors(
        observations,
        [1, num_samples])  # tuple of (inputs, targets) -> shape (S, B*N, D).
    tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])  # shape (B*N,).
    model.set_observations(observations, tiled_seq_lengths)

    if resampling_type == 'multinomial':
        resampling_fn = smc.multinomial_resampling
    elif resampling_type == 'relaxed':
        resampling_fn = functools.partial(
            smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
    elif resampling_type == 'differentiable':
        resampling_fn = get_transport_fun(sinkhorn_regularization, 1e-2, 100)
    else:
        raise NotImplementedError
    resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

    def transition_fn(prev_state, t):
        if prev_state is None:
            return model.zero_state(batch_size * num_samples, tf.float32)
        return model.propose_and_weight(prev_state, t)

    log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
        transition_fn,
        seq_lengths,
        num_particles=num_samples,
        resampling_criterion=resampling_criterion,
        resampling_fn=resampling_fn,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    return log_p_hat, log_weights, resampled, final_state
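For orientation, here is a hedged usage sketch showing how the bound is typically turned into a training loss; make_vrnn_model, inputs, and lengths are hypothetical placeholders for whatever model constructor and input pipeline the surrounding project provides, and the per-timestep normalization is illustrative.

# Usage sketch (illustrative only): maximize the FIVO bound with an optimizer.
# make_vrnn_model, inputs, and lengths are hypothetical placeholders.
model = make_vrnn_model()  # any ELBOTrainableSequenceModel subclass
log_p_hat, log_weights, resampled, _ = fivo(
    model, inputs, lengths,
    num_samples=4,
    resampling_criterion=smc.ess_criterion)
# Average the per-sequence bound over the batch, normalize per timestep, and
# flip the sign so that minimizing the loss maximizes the bound.
loss = -tf.reduce_mean(log_p_hat / tf.to_float(lengths))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)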
Example #4
def fivo_aux_td(model,
                observations,
                seq_lengths,
                num_samples=1,
                resampling_criterion=smc.ess_criterion,
                resampling_type='multinomial',
                relaxed_resampling_temperature=0.5,
                parallel_iterations=30,
                swap_memory=True,
                random_seed=None):
    """Experimental."""
    # batch_size is the number of particle filters running in parallel.
    batch_size = tf.shape(seq_lengths)[0]
    max_seq_len = tf.reduce_max(seq_lengths)

    # Each sequence in the batch will be the input data for a different
    # particle filter. The batch will be laid out as:
    #   particle 1 of particle filter 1
    #   particle 1 of particle filter 2
    #   ...
    #   particle 1 of particle filter batch_size
    #   particle 2 of particle filter 1
    #   ...
    #   particle num_samples of particle filter batch_size
    observations = nested.tile_tensors(observations, [1, num_samples])
    tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
    model.set_observations(observations, tiled_seq_lengths)

    if resampling_type == 'multinomial':
        resampling_fn = smc.multinomial_resampling
    elif resampling_type == 'relaxed':
        resampling_fn = functools.partial(
            smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
    resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

    def transition_fn(prev_state, t):
        if prev_state is None:
            model_init_state = model.zero_state(batch_size * num_samples,
                                                tf.float32)
            return (tf.zeros([num_samples * batch_size], dtype=tf.float32),
                    (tf.zeros([num_samples * batch_size, model.latent_size],
                              dtype=tf.float32),
                     tf.zeros([num_samples * batch_size, model.latent_size],
                              dtype=tf.float32)), model_init_state)

        prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
        (new_model_state, zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_tilde,
         p_ztplus1) = model(prev_model_state, t)
        r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
        # Compute the weight without r.
        log_weight = log_p_zt + log_p_x_given_z - log_q_zt
        # Compute log_r and log_r_tilde.
        p_mu = tf.stop_gradient(p_ztplus1.mean())
        p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
        log_r = (tf.log(r_tilde_sigma_sq) -
                 tf.log(r_tilde_sigma_sq + p_sigma_sq) -
                 tf.square(r_tilde_mu - p_mu) /
                 (r_tilde_sigma_sq + p_sigma_sq))
        # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
        # dimension to compute log r.
        log_r = 0.5 * tf.reduce_sum(log_r, axis=-1)
        # Compute prev log r tilde
        prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
        prev_log_r_tilde = -0.5 * tf.reduce_sum(
            tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu) /
            prev_r_tilde_sigma_sq,
            axis=-1)
        # If the sequence is on the last timestep, log_r and log_r_tilde are just zeros.
        last_timestep = t >= (tiled_seq_lengths - 1)
        log_r = tf.where(last_timestep, tf.zeros_like(log_r), log_r)
        prev_log_r_tilde = tf.where(last_timestep,
                                    tf.zeros_like(prev_log_r_tilde),
                                    prev_log_r_tilde)
        log_weight += tf.stop_gradient(log_r - prev_log_r)
        new_state = (log_r, log_r_tilde, new_model_state)
        loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z,
                        log_r - prev_log_r)
        return log_weight, new_state, loop_fn_args

    def loop_fn(loop_state, loop_args, unused_model_state, log_weights,
                resampled, mask, t):
        if loop_state is None:
            return (tf.zeros([batch_size], dtype=tf.float32),
                    tf.zeros([batch_size], dtype=tf.float32),
                    tf.zeros([num_samples, batch_size], dtype=tf.float32))
        log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
        log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
        # Compute the log_p_hat update
        log_p_hat_update = tf.reduce_logsumexp(log_weights, axis=0) - tf.log(
            tf.to_float(num_samples))
        # If it is the last timestep, we always add the update.
        log_p_hat_acc += tf.cond(t >= max_seq_len - 1,
                                 lambda: log_p_hat_update,
                                 lambda: log_p_hat_update * resampled)
        # Compute the Bellman update.
        log_r = tf.reshape(log_r, [num_samples, batch_size])
        prev_log_r_tilde = tf.reshape(prev_log_r_tilde,
                                      [num_samples, batch_size])
        log_p_x_given_z = tf.reshape(log_p_x_given_z,
                                     [num_samples, batch_size])
        mask = tf.reshape(mask, [num_samples, batch_size])
        # On the first timestep there is no bellman error because there is no
        # prev_log_r_tilde.
        mask = tf.cond(tf.equal(t, 0), lambda: tf.zeros_like(mask),
                       lambda: mask)
        # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
        prev_log_r_tilde = tf.where(tf.is_inf(prev_log_r_tilde),
                                    tf.zeros_like(prev_log_r_tilde),
                                    prev_log_r_tilde)
        # log_lambda is [num_samples, batch_size]
        log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                    axis=0,
                                    keepdims=True)
        bellman_error = mask * tf.square(prev_log_r_tilde - tf.stop_gradient(
            log_lambda + log_p_x_given_z + log_r))
        bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
        # Compute the log_r_diff update
        log_r_diff_acc += mask * tf.reshape(log_r_diff,
                                            [num_samples, batch_size])
        return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

    log_weights, resampled, accs = smc.smc(
        transition_fn,
        seq_lengths,
        num_particles=num_samples,
        resampling_criterion=resampling_criterion,
        resampling_fn=resampling_fn,
        loop_fn=loop_fn,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    log_p_hat, bellman_loss, log_r_diff = accs
    loss_per_seq = [-log_p_hat, bellman_loss]
    tf.summary.scalar("bellman_loss",
                      tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
    tf.summary.scalar(
        "log_r_diff",
        tf.reduce_mean(
            tf.reduce_mean(log_r_diff, axis=0) / tf.to_float(seq_lengths)))
    return loss_per_seq, log_p_hat, log_weights, resampled
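One step in transition_fn above is worth unpacking: the log_r expression has the form of an analytic integral of an unnormalized diagonal-Gaussian r-tilde against the Gaussian prior p(z_{t+1}). Under that reading (an interpretation of the code, not something it states), per latent dimension:

  \tilde{r}_t(z) = \exp\!\Big(-\frac{(z - \tilde{\mu})^2}{2\tilde{\sigma}^2}\Big),
  \qquad p(z) = \mathcal{N}(z;\, \mu_p, \sigma_p^2),

  \log \int \tilde{r}_t(z)\, p(z)\, dz
    = \frac{1}{2}\Big(\log\tilde{\sigma}^2 - \log(\tilde{\sigma}^2 + \sigma_p^2)
        - \frac{(\tilde{\mu} - \mu_p)^2}{\tilde{\sigma}^2 + \sigma_p^2}\Big),

which is exactly the term that 0.5 * tf.reduce_sum(log_r, axis=-1) sums over the latent dimensions; likewise prev_log_r_tilde = -0.5 * sum((z_t - \tilde{\mu})^2 / \tilde{\sigma}^2) is \log\tilde{r}_{t-1}(z_t) evaluated at the sampled latent.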
Example #5
File: bounds.py, Project: 812864539/models
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
  """Computes the FIVO lower bound on the log marginal probability.

  This method accepts a stochastic latent variable model and some observations
  and computes a stochastic lower bound on the log marginal probability of the
  observations. The lower bound is defined by a particle filter's unbiased
  estimate of the marginal probability of the observations. For more details see
  "Filtering Variational Objectives" by Maddison et al.
  https://arxiv.org/abs/1705.09279.

  When the resampling criterion is "never resample", this bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple of
      Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
      have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of particles to use in each particle filter.
    resampling_criterion: The resampling criterion to use for this particle
      filter. Must accept the number of samples, the current log weights,
      and the current timestep and return a boolean Tensor of shape [batch_size]
      indicating whether each particle filter should resample. See
      ess_criterion and related functions for examples. When
      resampling_criterion is never_resample_criterion, resampling_fn is ignored
      and never called.
    resampling_type: The type of resampling, one of "multinomial" or "relaxed".
    relaxed_resampling_temperature: A positive temperature only used for relaxed
      resampling.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop. Note that values greater than 1 can introduce
      non-determinism even when random_seed is provided.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.
    random_seed: The random seed to pass to the resampling operations in
      the particle filter. Mainly useful for testing.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
      log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep of the particle filter. Note
      that on timesteps when a resampling operation is performed the log weights
      are reset to 0. Will not be valid for timesteps past the end of a
      sequence.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
      particle filters resampled. Will be 1.0 on timesteps when resampling
      occurred and 0.0 on timesteps when it did not.
    final_state: The final state of the particle filter, as returned by the
      internal smc.smc call.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      return model.zero_state(batch_size * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  return log_p_hat, log_weights, resampled, final_state
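As the docstring notes, this bound reduces to IWAE when the particle filter never resamples. A minimal sketch, assuming the never_resample_criterion referenced in the docstring lives in the same smc module as ess_criterion (model, observations, and seq_lengths are placeholders here):

# IWAE as a special case of FIVO: disable resampling entirely.
# Assumes smc.never_resample_criterion exists (the docstring above refers to
# it); any criterion that always returns False would serve the same purpose.
iwae_log_p_hat, iwae_log_weights, _, _ = fivo(
    model, observations, seq_lengths,
    num_samples=num_samples,
    resampling_criterion=smc.never_resample_criterion)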
Example #6
File: bounds.py, Project: 812864539/models
def fivo_aux_td(
    model,
    observations,
    seq_lengths,
    num_samples=1,
    resampling_criterion=smc.ess_criterion,
    resampling_type='multinomial',
    relaxed_resampling_temperature=0.5,
    parallel_iterations=30,
    swap_memory=True,
    random_seed=None):
  """Experimental."""
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples*batch_size], dtype=tf.float32),
              (tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32),
               tf.zeros([num_samples*batch_size, model.latent_size], dtype=tf.float32)),
              model_init_state)

    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
    # Compute prev log r tilde
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5*tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
    # If the sequence is on the last timestep, log_r and log_r_tilde are just zeros.
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep,
                     tf.zeros_like(log_r),
                     log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z, log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights, resampled, mask, t):
    if loop_state is None:
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))
    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update.
    log_p_hat_acc += tf.cond(t >= max_seq_len-1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update.
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(
        tf.is_inf(prev_log_r_tilde),
        tf.zeros_like(prev_log_r_tilde),
        prev_log_r_tilde)
    # log_lambda is [num_samples, batch_size]
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0, keepdims=True)
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r)
    )
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  log_p_hat, bellman_loss, log_r_diff = accs
  loss_per_seq = [-log_p_hat, bellman_loss]
  tf.summary.scalar("bellman_loss",
                    tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar("log_r_diff",
                    tf.reduce_mean(tf.reduce_mean(log_r_diff, axis=0) / tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
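To close, a hedged sketch of how the two per-sequence losses returned above might be folded into a single training objective, normalizing by sequence length the same way the summaries do; the 0.5 weight on the auxiliary Bellman term is purely illustrative and not taken from the project.

# Illustrative only: combine fivo_aux_td's per-sequence losses into one scalar.
loss_per_seq, log_p_hat, log_weights, resampled = fivo_aux_td(
    model, observations, seq_lengths, num_samples=4)
neg_log_p_hat, bellman_loss = loss_per_seq
lengths = tf.to_float(seq_lengths)
# Normalize each term per timestep and down-weight the auxiliary Bellman loss
# with a hypothetical coefficient.
loss = (tf.reduce_mean(neg_log_p_hat / lengths) +
        0.5 * tf.reduce_mean(bellman_loss / lengths))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)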