예제 #1
0
    def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
        """Implements one timestep of the particle filter.

        After every step the running log weights are renormalized so that their
        logsumexp over particles is zero, which keeps them bounded. The log
        normalizer removed by that renormalization is added to `log_z_hat_acc`
        so the log Z hat estimate is the same as if the weights had never been
        renormalized.

        Args:
          t: Index of the current timestep.
          state: Pair of (particle state, loop state) carried by the loop.
          tas: TensorArrays that receive, at index t, the accumulated log
            weights and the float resampling indicators (in that order).
          log_weights_acc: Log weights accumulated since each filter last
            resampled, combined with a [num_particles, batch_size] increment.
          log_z_hat_acc: Running log Z hat estimate, one entry per filter.

        Returns:
          (t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc), the loop
          variables for the next iteration.
        """
        particle_state, loop_state = state
        cur_mask = nested.read_tas(mask_ta, t)
        # Propagate the particles one step.
        log_alpha, new_particle_state, loop_args = transition(
            particle_state, t)
        # Update the current weights with the incremental weights; masked-out
        # entries contribute no incremental weight.
        log_alpha *= cur_mask
        log_alpha = tf.reshape(log_alpha, [num_particles, batch_size])
        log_weights_acc += log_alpha

        should_resample = resampling_criterion(log_weights_acc, t)

        if resampling_criterion == never_resample_criterion:
            resampled = tf.to_float(should_resample)
        else:
            # Compute the states as if we did resample.
            normalized_log_weights_acc = log_weights_acc - tf.reduce_logsumexp(
                log_weights_acc, axis=0, keepdims=True)
            resampled_states = resampling_fn(normalized_log_weights_acc,
                                             new_particle_state, num_particles,
                                             batch_size)
            # Decide whether or not we should resample; don't resample if we are past
            # the end of a sequence.
            should_resample = tf.logical_and(should_resample,
                                             cur_mask[:batch_size] > 0.)
            float_should_resample = tf.to_float(should_resample)
            new_particle_state = nested.where_tensors(
                tf.tile(should_resample, [num_particles]), resampled_states,
                new_particle_state)
            resampled = float_should_resample

        new_loop_state = loop_fn(loop_state, loop_args, new_particle_state,
                                 log_weights_acc, resampled, cur_mask, t)
        # Update log Z hat.
        log_z_hat_update = tf.reduce_logsumexp(
            log_weights_acc, axis=0) - tf.log(tf.to_float(num_particles))
        not_last_step = t < max_num_steps - 1
        # If it is the last timestep, always add the update.
        log_z_hat_acc += tf.cond(not_last_step,
                                 lambda: log_z_hat_update * resampled,
                                 lambda: log_z_hat_update)
        # Update the TensorArrays before we reset the weights so that we capture
        # the incremental weights and not zeros.
        ta_updates = [log_weights_acc, resampled]
        new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
        # For the particle filters that resampled, reset weights to zero.
        log_weights_acc *= (
            1. - tf.tile(resampled[tf.newaxis, :], [num_particles, 1]))
        # Renormalize each filter's log weights (logsumexp over particles == 0).
        # The removed normalizer must be banked in log Z hat. Otherwise the
        # contributions of steps without resampling are lost, and after a reset
        # the weights would start at -log N instead of 0, so log N would be
        # subtracted twice. On the final step the full update has already been
        # added above, so banking the normalizer there would double count it.
        log_normalizer = tf.reduce_logsumexp(log_weights_acc, axis=0)
        log_weights_acc -= log_normalizer[tf.newaxis, :]
        log_z_hat_acc += tf.cond(not_last_step,
                                 lambda: log_normalizer,
                                 lambda: tf.zeros_like(log_normalizer))

        new_state = (new_particle_state, new_loop_state)
        return t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc
예제 #2
0
파일: smc.py 프로젝트: 812864539/models
  def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
    """Advances every particle filter in the batch by a single timestep.

    Args:
      t: Index of the current timestep.
      state: Pair of (particle state, loop state) carried by the loop.
      tas: TensorArrays that receive, at index t, the accumulated log weights
        and the float resampling indicators (in that order).
      log_weights_acc: Log weights accumulated since each filter last
        resampled, combined with a [num_particles, batch_size] increment.
      log_z_hat_acc: Running log Z hat estimate, one entry per filter.

    Returns:
      (t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc), the loop
      variables for the next iteration.
    """
    prev_particle_state, prev_loop_state = state
    step_mask = nested.read_tas(mask_ta, t)
    # Move the particles forward and obtain their incremental log weights.
    log_incr, particle_state, extra_loop_args = transition(
        prev_particle_state, t)
    # Masked-out entries get no incremental weight; fold the rest into the
    # running per-filter log weights.
    log_incr = tf.reshape(log_incr * step_mask, [num_particles, batch_size])
    log_weights_acc = log_weights_acc + log_incr

    resample_now = resampling_criterion(log_weights_acc, t)

    if resampling_criterion == never_resample_criterion:
      resampled = tf.to_float(resample_now)
    else:
      # Build the resampled states unconditionally; each filter picks between
      # them and its current states below.
      candidate_states = resampling_fn(
          log_weights_acc, particle_state, num_particles, batch_size)
      # A filter whose sequence has already ended must not resample.
      resample_now = tf.logical_and(resample_now, step_mask[:batch_size] > 0.)
      resampled = tf.to_float(resample_now)
      particle_state = nested.where_tensors(
          tf.tile(resample_now, [num_particles]),
          candidate_states,
          particle_state)

    loop_state = loop_fn(prev_loop_state, extra_loop_args, particle_state,
                         log_weights_acc, resampled, step_mask, t)
    # Log of the mean particle weight for each filter.
    log_mean_weight = (tf.reduce_logsumexp(log_weights_acc, axis=0) -
                       tf.log(tf.to_float(num_particles)))
    # Only filters that just resampled bank their estimate, except on the last
    # timestep, where every filter does.
    log_z_hat_acc = log_z_hat_acc + tf.cond(
        t < max_num_steps - 1,
        lambda: log_mean_weight * resampled,
        lambda: log_mean_weight)
    # Record the weights before the reset below so the arrays hold the
    # accumulated values rather than zeros.
    new_tas = [ta.write(t, value)
               for ta, value in zip(tas, (log_weights_acc, resampled))]
    # Filters that resampled start accumulating again from zero.
    keep_mask = 1. - tf.tile(resampled[tf.newaxis, :], [num_particles, 1])
    log_weights_acc = log_weights_acc * keep_mask
    return (t + 1, (particle_state, loop_state), new_tas, log_weights_acc,
            log_z_hat_acc)