Exemple #1
0
 def while_step(t, rnn_state, tas, accs):
     """Runs a single timestep of the IWAE bound computation.

     Reads the inputs and mask for step `t`, runs `cell` once, folds the
     masked KL and the masked incremental log weight into the accumulators,
     and records the accumulated log weights and the log effective sample
     size in the TensorArrays.

     Args:
       t: Index of the current timestep.
       rnn_state: State handed to `cell` for this step.
       tas: TensorArrays that receive, in order, the accumulated log weights
         and the log effective sample size.
       accs: Pair `(log_weights_acc, kl_acc)` of running accumulators.

     Returns:
       A tuple `(t + 1, new_state, new_tas, new_accs)` laid out like the
       arguments.
     """
     weights_total, kl_total = accs
     step_inputs, step_mask = nested.read_tas([inputs_ta, mask_ta], t)
     # Advance the model by one step. The cell returns six values; the last
     # one is not used by this loop body.
     (log_q_z, log_p_z, log_p_x_given_z, kl, next_state, _) = cell(
         step_inputs, rnn_state, step_mask)
     # Masked steps contribute nothing to either accumulator.
     kl_total += kl * step_mask
     masked_alpha = (log_p_x_given_z + log_p_z - log_q_z) * step_mask
     weights_total += tf.reshape(masked_alpha, [num_samples, batch_size])
     # Log effective sample size, reduced over the sample axis:
     # 2 * logsumexp(w) - logsumexp(2 * w).
     log_ess = (2 * tf.reduce_logsumexp(weights_total, axis=0) -
                tf.reduce_logsumexp(2 * weights_total, axis=0))
     # Record this step's values; each TensorArray is paired with its update
     # by position.
     new_tas = []
     for ta, value in zip(tas, (weights_total, log_ess)):
         new_tas.append(ta.write(t, value))
     return t + 1, next_state, new_tas, (weights_total, kl_total)
Exemple #2
0
 def while_step(t, rnn_state, tas, accs):
   """Runs a single timestep of the IWAE bound computation.

   Reads the inputs and mask for step `t`, runs `cell` once, adds the masked
   KL and the masked incremental log weight to the accumulators, and writes
   the accumulated log weights and log effective sample size to `tas`.

   Args:
     t: Index of the current timestep.
     rnn_state: State handed to `cell` for this step.
     tas: TensorArrays that receive, in order, the accumulated log weights
       and the log effective sample size.
     accs: Pair `(log_weights_acc, kl_acc)` of running accumulators.

   Returns:
     A tuple `(t + 1, new_state, new_tas, new_accs)` laid out like the
     arguments.
   """
   running_weights, running_kl = accs
   inputs_t, mask_t = nested.read_tas([inputs_ta, mask_ta], t)
   # Advance the model by one step.
   cell_outputs = cell(inputs_t, rnn_state, mask_t)
   log_q_z, log_p_z, log_p_x_given_z, kl, next_state = cell_outputs
   # Masked steps contribute nothing to either accumulator.
   running_kl += kl * mask_t
   increment = (log_p_x_given_z + log_p_z - log_q_z) * mask_t
   running_weights += tf.reshape(increment, [num_samples, batch_size])
   # Log effective sample size, reduced over the sample axis:
   # 2 * logsumexp(w) - logsumexp(2 * w).
   log_ess = (2 * tf.reduce_logsumexp(running_weights, axis=0) -
              tf.reduce_logsumexp(2 * running_weights, axis=0))
   # Record this step's values; each TensorArray is paired with its update
   # by position.
   written = []
   for ta, value in zip(tas, (running_weights, log_ess)):
     written.append(ta.write(t, value))
   return t + 1, next_state, written, (running_weights, running_kl)
Exemple #3
0
 def while_step(t, rnn_state, tas, accs):
   """Runs a single timestep of the FIVO bound computation.

   Runs `cell` for step `t`, accumulates the masked KL and log weights,
   computes the log effective sample size, and resamples the particle states
   for the batch entries selected by `resampling_criterion`. Entries that
   resample add their log-mean weight to the running log p-hat and have
   their accumulated log weights reset to zero.

   Args:
     t: Index of the current timestep.
     rnn_state: State handed to `cell` for this step.
     tas: TensorArrays that receive, in order, the accumulated log weights,
       the log effective sample size and the float resampling indicator.
     accs: Triple `(log_weights_acc, log_p_hat_acc, kl_acc)` of running
       accumulators.

   Returns:
     A tuple `(t + 1, new_state, new_tas, new_accs)` laid out like the
     arguments.
   """
   weights_acc, p_hat_acc, kl_total = accs
   step_inputs, step_mask = nested.read_tas([inputs_ta, mask_ta], t)
   # Advance the model by one step.
   step_outputs = cell(step_inputs, rnn_state, step_mask)
   log_q_z, log_p_z, log_p_x_given_z, kl, next_state = step_outputs
   # Masked steps contribute nothing to the KL or to the weights.
   kl_total += kl * step_mask
   masked_alpha = (log_p_x_given_z + log_p_z - log_q_z) * step_mask
   weights_acc += tf.reshape(masked_alpha, [num_samples, batch_size])
   # Log effective sample size, reduced over the sample axis:
   # 2 * logsumexp(w) - logsumexp(2 * w).
   log_ess = (2 * tf.reduce_logsumexp(weights_acc, axis=0) -
              tf.reduce_logsumexp(2 * weights_acc, axis=0))
   # Draw ancestor indices. The accumulated weights are unnormalized logs,
   # so they are passed as logits and the categorical normalizes them.
   resample_logits = tf.transpose(weights_acc, perm=[1, 0])
   resampler = tf.contrib.distributions.Categorical(logits=resample_logits)
   drawn = resampler.sample(sample_shape=num_samples, seed=random_seed)
   parents = tf.stop_gradient(drawn)
   # Convert per-filter particle indices into row indices of the flattened
   # batch: the particles of filter i sit every batch_size rows, offset by i
   # rows from the top. Flattening afterwards keeps that layout.
   row_offset = tf.expand_dims(tf.range(batch_size), 0)
   parents = tf.reshape(parents * batch_size + row_offset, [-1])
   identity_rows = tf.range(num_samples * batch_size)
   # Resample only where the criterion fires and the sequence has not ended.
   do_resample = resampling_criterion(num_samples, log_ess, t)
   in_sequence = step_mask[:batch_size] > 0.
   do_resample = tf.logical_and(do_resample, in_sequence)
   resample_flag = tf.to_float(do_resample)
   per_row_choice = tf.tile(do_resample, [num_samples])
   parents = tf.where(per_row_choice, parents, identity_rows)
   next_state = nested.gather_tensors(next_state, parents)
   # Write to the TensorArrays before the reset below, so the stored weights
   # are the accumulated values rather than zeros.
   new_tas = []
   for ta, value in zip(tas, (weights_acc, log_ess, resample_flag)):
     new_tas.append(ta.write(t, value))
   # Filters that resampled contribute their log-mean weight to log p-hat
   # and restart with zeroed log weights.
   log_sum_weights = tf.reduce_logsumexp(weights_acc, axis=0)
   p_hat_step = log_sum_weights - tf.log(tf.to_float(num_samples))
   p_hat_acc += p_hat_step * resample_flag
   keep_factor = 1. - tf.tile(resample_flag[tf.newaxis, :], [num_samples, 1])
   weights_acc *= keep_factor
   return t + 1, next_state, new_tas, (weights_acc, p_hat_acc, kl_total)
Exemple #4
0
 def while_step(t, rnn_state, tas, accs):
     """Runs a single timestep of the FIVO bound computation.

     Runs `cell` for step `t`, accumulates the masked KL and log weights,
     computes the log effective sample size, and resamples the particle
     states for the batch entries selected by `resampling_criterion`.
     Entries that resample add their log-mean weight to the running log
     p-hat and have their accumulated log weights reset to zero.

     Args:
       t: Index of the current timestep.
       rnn_state: State handed to `cell` for this step.
       tas: TensorArrays that receive, in order, the accumulated log weights,
         the log effective sample size and the float resampling indicator.
       accs: Triple `(log_weights_acc, log_p_hat_acc, kl_acc)` of running
         accumulators.

     Returns:
       A tuple `(t + 1, new_state, new_tas, new_accs)` laid out like the
       arguments.
     """
     running_weights, running_p_hat, running_kl = accs
     inputs_t, mask_t = nested.read_tas([inputs_ta, mask_ta], t)
     # Advance the model by one step.
     cell_outputs = cell(inputs_t, rnn_state, mask_t)
     log_q_z, log_p_z, log_p_x_given_z, kl, next_state = cell_outputs
     # Masked steps contribute nothing to the KL or to the weights.
     running_kl += kl * mask_t
     increment = (log_p_x_given_z + log_p_z - log_q_z) * mask_t
     running_weights += tf.reshape(increment, [num_samples, batch_size])
     # Log effective sample size, reduced over the sample axis:
     # 2 * logsumexp(w) - logsumexp(2 * w).
     log_ess = (2 * tf.reduce_logsumexp(running_weights, axis=0) -
                tf.reduce_logsumexp(2 * running_weights, axis=0))
     # Draw ancestor indices. The accumulated weights are unnormalized logs,
     # so they are passed as logits and the categorical normalizes them.
     logits = tf.transpose(running_weights, perm=[1, 0])
     categorical = tf.contrib.distributions.Categorical(logits=logits)
     samples = categorical.sample(sample_shape=num_samples, seed=random_seed)
     ancestors = tf.stop_gradient(samples)
     # Convert per-filter particle indices into row indices of the flattened
     # batch: the particles of filter i sit every batch_size rows, offset by
     # i rows from the top. Flattening afterwards keeps that layout.
     filter_offset = tf.expand_dims(tf.range(batch_size), 0)
     ancestors = tf.reshape(ancestors * batch_size + filter_offset, [-1])
     unchanged_rows = tf.range(num_samples * batch_size)
     # Resample only where the criterion fires and the sequence is still
     # running.
     resample_now = resampling_criterion(num_samples, log_ess, t)
     not_finished = mask_t[:batch_size] > 0.
     resample_now = tf.logical_and(resample_now, not_finished)
     resample_now_f = tf.to_float(resample_now)
     row_selector = tf.tile(resample_now, [num_samples])
     ancestors = tf.where(row_selector, ancestors, unchanged_rows)
     next_state = nested.gather_tensors(next_state, ancestors)
     # Write to the TensorArrays before the reset below, so the stored
     # weights are the accumulated values rather than zeros.
     written = []
     for ta, value in zip(tas, (running_weights, log_ess, resample_now_f)):
         written.append(ta.write(t, value))
     # Filters that resampled contribute their log-mean weight to log p-hat
     # and restart with zeroed log weights.
     log_total = tf.reduce_logsumexp(running_weights, axis=0)
     p_hat_increment = log_total - tf.log(tf.to_float(num_samples))
     running_p_hat += p_hat_increment * resample_now_f
     survive = 1. - tf.tile(resample_now_f[tf.newaxis, :], [num_samples, 1])
     running_weights *= survive
     return (t + 1, next_state, written,
             (running_weights, running_p_hat, running_kl))
Exemple #5
0
def elbo(cell,
         inputs,
         seq_lengths,
         num_samples=1,
         parallel_iterations=30,
         swap_memory=True):
    """Computes a multi-sample (IWAE-style) lower bound over a batch of sequences.

    The sequences are unrolled with `tf.while_loop`. At every timestep `cell`
    is run once and its masked incremental log weight
    `log p(x|z) + log p(z) - log q(z)` is accumulated per sample.

    Args:
      cell: Callable invoked as `cell(inputs_t, state, mask_t)` that returns
        six values `(log_q_z, log_p_z, log_p_x_given_z, kl, new_state,
        rnn_out)`; it must also provide `zero_state(batch, dtype)`.
      inputs: Time-major inputs (possibly nested) that are split into
        per-timestep TensorArrays. NOTE(review): presumably shaped
        [max_seq_len, batch_size, ...] -- confirm against the caller.
      seq_lengths: 1-D tensor with the length of each sequence; its size
        defines `batch_size`.
      num_samples: Python int, number of samples per sequence.
      parallel_iterations: Forwarded to `tf.while_loop`.
      swap_memory: Forwarded to `tf.while_loop`.

    Returns:
      A tuple `(log_p_hat, kl, log_weights, log_ess)`:
        log_p_hat: [batch_size] log-mean-exp over samples of the final
          accumulated log weights.
        kl: [batch_size] accumulated masked KL averaged over samples.
        log_weights: [max_seq_len, batch_size, num_samples] accumulated log
          weights at every timestep.
        log_ess: [max_seq_len, batch_size] log effective sample size at
          every timestep.
    """
    batch_size = tf.shape(seq_lengths)[0]
    max_seq_len = tf.reduce_max(seq_lengths)
    # sequence_mask is [batch, time]; transpose it to time-major.
    seq_mask = tf.transpose(
        tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32),
        perm=[1, 0])
    if num_samples > 1:
        # Replicate inputs and mask num_samples times along the second axis.
        inputs, seq_mask = nested.tile_tensors([inputs, seq_mask],
                                               [1, num_samples])
    inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask],
                                                max_seq_len)

    t0 = tf.constant(0, tf.int32)
    init_states = cell.zero_state(batch_size * num_samples, tf.float32)
    ta_names = ['log_weights', 'log_ess']
    tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
           for n in ta_names]
    log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
    kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
    accs = (log_weights_acc, kl_acc)

    def while_predicate(t, *unused_args):
        """Keeps looping until every timestep has been processed."""
        return t < max_seq_len

    def while_step(t, rnn_state, tas, accs):
        """Implements one timestep of IWAE computation."""
        log_weights_acc, kl_acc = accs
        cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
        # Run the cell for one step; its sixth return value (the RNN output)
        # is not needed here.
        log_q_z, log_p_z, log_p_x_given_z, kl, new_state, _ = cell(
            cur_inputs, rnn_state, cur_mask)
        # Compute the incremental weight and use it to update the current
        # accumulated weight. Masked steps contribute nothing.
        kl_acc += kl * cur_mask
        log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
        log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
        log_weights_acc += log_alpha
        # Calculate the effective sample size.
        ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
        ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
        log_ess = ess_num - ess_denom
        # Update the TensorArrays and accumulators.
        ta_updates = [log_weights_acc, log_ess]
        new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
        new_accs = (log_weights_acc, kl_acc)
        return t + 1, new_state, new_tas, new_accs

    _, _, tas, accs = tf.while_loop(while_predicate,
                                    while_step,
                                    loop_vars=(t0, init_states, tas, accs),
                                    parallel_iterations=parallel_iterations,
                                    swap_memory=swap_memory)

    log_weights, log_ess = [x.stack() for x in tas]
    final_log_weights, kl = accs
    # Log of the average importance weight over samples.
    log_p_hat = (tf.reduce_logsumexp(final_log_weights, axis=0) -
                 tf.log(tf.to_float(num_samples)))
    kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
    # [time, samples, batch] -> [time, batch, samples].
    log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
    return log_p_hat, kl, log_weights, log_ess
Exemple #6
0
    def while_step(t, rnn_state, tas, accs, while_samples):
        """Implements one timestep of the ELBO/FIVO computation with sampling.

        Runs `model` for step `t`, draws a new sample from the distributions
        it returns, accumulates the masked KL and log weights, and, when
        `config.bound == "fivo"`, resamples the states and the new sample.

        Args:
          t: Index of the current timestep.
          rnn_state: State handed to `model`; elements 0 and 1 of the new
            state are written to the TensorArrays.
          tas: TensorArrays that receive, in order, the incremental log
            weight, the two elements of the new sample, the two elements of
            the new state and the RNN output.
          accs: `(log_weights_acc, kl_acc)` for the "elbo" bound, or
            `(log_weights_acc, log_p_hat_acc, kl_acc)` for "fivo".
          while_samples: Substitute inputs used on the masked-out timesteps
            when `config.missing_data` is set. NOTE(review): presumably the
            sample returned by the previous step -- confirm at the
            `tf.while_loop` call site.

        Returns:
          A tuple `(t + 1, new_state, new_tas, new_accs, new_sample)`.

        Raises:
          ValueError: If `config.bound` is neither "elbo" nor "fivo".
        """
        if config.bound == "elbo":
            log_weights_acc, kl_acc = accs
        elif config.bound == "fivo":
            log_weights_acc, log_p_hat_acc, kl_acc = accs
        else:
            # Fail early with a clear message instead of an unbound-variable
            # error further down.
            raise ValueError("Unsupported bound: %r" % (config.bound,))
        cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)

        if config.missing_data:
            # For max_seq_len - 18 <= t < max_seq_len - 6 the real inputs are
            # replaced by `while_samples`.
            cur_inputs = tf.cond(
                tf.logical_and(t < max_seq_len - 6, t >= max_seq_len - 18),
                lambda: while_samples, lambda: cur_inputs)

        # Run the cell for one step.
        (log_q_z, log_p_z, log_p_x_given_z, kl, new_state, new_rnn_out,
         dists_return) = model(cur_inputs,
                               rnn_state,
                               cur_mask,
                               return_value="probs")

        # Draw a new sample from the returned distributions and pair it with
        # a zero tensor of the same shape.
        new_sample0 = dists.sample_from_probs(dists_return, config.lat_bins,
                                              config.lon_bins, config.sog_bins,
                                              config.cog_bins)
        new_sample0 = tf.cast(new_sample0, tf.float32)
        new_sample_ = (new_sample0, tf.zeros_like(new_sample0,
                                                  dtype=tf.float32))

        # Compute the incremental weight and use it to update the current
        # accumulated weight.
        kl_acc += kl * cur_mask
        log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
        log_alpha = tf.reshape(log_alpha, [config.num_samples, batch_size])
        log_weights_acc += log_alpha
        # Calculate the effective sample size.
        ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
        ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
        log_ess = ess_num - ess_denom
        if config.bound == "fivo":
            # Calculate the ancestor indices via resampling. Because we maintain the
            # log unnormalized weights, we pass the weights in as logits, allowing
            # the distribution object to apply a softmax and normalize them.
            resampling_dist = tf.contrib.distributions.Categorical(
                logits=tf.transpose(log_weights_acc, perm=[1, 0]))
            ancestor_inds = tf.stop_gradient(
                resampling_dist.sample(sample_shape=num_samples,
                                       seed=config.random_seed))
            # The batch is flattened, so ancestor_inds must be converted to
            # row indices. The particles in the ith filter are distributed
            # every batch_size rows in the batch, and offset i rows from the
            # top. So, to correct the indices we multiply by the batch_size
            # and add the proper offset. Crucially, when ancestor_inds is
            # flattened the layout of the batch is maintained.
            offset = tf.expand_dims(tf.range(batch_size), 0)
            ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset,
                                       [-1])
            noresample_inds = tf.range(num_samples * batch_size)
            # Decide whether or not we should resample; don't resample if we are past
            # the end of a sequence.
            should_resample = resampling_criterion(num_samples, log_ess, t)
            should_resample = tf.logical_and(should_resample,
                                             cur_mask[:batch_size] > 0.)
            float_should_resample = tf.to_float(should_resample)
            ancestor_inds = tf.where(tf.tile(should_resample, [num_samples]),
                                     ancestor_inds, noresample_inds)
            new_state = nested.gather_tensors(new_state, ancestor_inds)
            new_sample_ = nested.gather_tensors(new_sample_, ancestor_inds)

        # Update the TensorArrays; each one is paired with its update by
        # position.
        ta_updates = [
            log_alpha, new_sample_[0], new_sample_[1], new_state[0],
            new_state[1], new_rnn_out
        ]
        new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]

        if config.bound == "fivo":
            # For the particle filters that resampled, update log_p_hat and
            # reset weights to zero.
            log_p_hat_update = tf.reduce_logsumexp(
                log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
            log_p_hat_acc += log_p_hat_update * float_should_resample
            log_weights_acc *= (1. - tf.tile(
                float_should_resample[tf.newaxis, :], [num_samples, 1]))
            new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)
        else:
            # The bound was validated above, so this is the "elbo" case.
            new_accs = (log_weights_acc, kl_acc)
        return t + 1, new_state, new_tas, new_accs, new_sample_