def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
  """Implements one timestep of the particle filter.

  Relies on names bound in the enclosing scope (not visible here): mask_ta,
  transition, resampling_criterion, never_resample_criterion, resampling_fn,
  loop_fn, num_particles, batch_size and max_num_steps, plus the `nested`
  and `tf` modules.

  Args:
    t: Current timestep index.
    state: Tuple (particle_state, loop_state) carried across iterations.
    tas: TensorArrays; the first receives log_weights_acc and the second the
      per-batch resampling indicator, both written at index t.
    log_weights_acc: Log weights accumulated since each filter last
      resampled. The incremental weights are reshaped to
      [num_particles, batch_size] before being added to it.
    log_z_hat_acc: Running log Z hat estimate; receives a logsumexp over the
      particle axis of log_weights_acc.

  Returns:
    (t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc), where the
    returned log_weights_acc is reset to zero for filters that resampled this
    step and left untouched (not renormalized) for the others.
  """
  particle_state, loop_state = state
  cur_mask = nested.read_tas(mask_ta, t)
  # Propagate the particles one step.
  log_alpha, new_particle_state, loop_args = transition(particle_state, t)
  # Update the current weights with the incremental weights. cur_mask is
  # presumably 0 past the end of a sequence, which zeroes the increment there.
  log_alpha *= cur_mask
  log_alpha = tf.reshape(log_alpha, [num_particles, batch_size])
  log_weights_acc += log_alpha

  should_resample = resampling_criterion(log_weights_acc, t)

  if resampling_criterion == never_resample_criterion:
    resampled = tf.to_float(should_resample)
  else:
    # Normalize over the particle axis before handing the weights to the
    # resampler. This normalized copy is only an input to resampling_fn; the
    # accumulators below keep using the unnormalized log_weights_acc.
    normalized_log_weights_acc = log_weights_acc - tf.reduce_logsumexp(
        log_weights_acc, axis=0, keepdims=True)
    # Compute the states as if we did resample.
    resampled_states = resampling_fn(normalized_log_weights_acc,
                                     new_particle_state, num_particles,
                                     batch_size)
    # Decide whether or not we should resample; don't resample if we are past
    # the end of a sequence.
    # NOTE(review): assumes per-particle tensors are flattened particle-major
    # to [num_particles * batch_size] (hence the [:batch_size] slice and the
    # tile by num_particles) -- confirm against transition().
    should_resample = tf.logical_and(should_resample,
                                     cur_mask[:batch_size] > 0.)
    float_should_resample = tf.to_float(should_resample)
    new_particle_state = nested.where_tensors(
        tf.tile(should_resample, [num_particles]),
        resampled_states,
        new_particle_state)
    resampled = float_should_resample

  new_loop_state = loop_fn(loop_state, loop_args, new_particle_state,
                           log_weights_acc, resampled, cur_mask, t)
  # Update log Z hat with the log-mean (logsumexp - log N) of the weights
  # accumulated since each filter last resampled.
  log_z_hat_update = tf.reduce_logsumexp(
      log_weights_acc, axis=0) - tf.log(tf.to_float(num_particles))
  # If it is the last timestep, always add the update.
  log_z_hat_acc += tf.cond(t < max_num_steps - 1,
                           lambda: log_z_hat_update * resampled,
                           lambda: log_z_hat_update)
  # Update the TensorArrays before we reset the weights so that we capture
  # the incremental weights and not zeros.
  ta_updates = [log_weights_acc, resampled]
  new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
  # For the particle filters that resampled, reset weights to zero. The
  # weights must NOT be renormalized here: log_z_hat_update subtracts
  # log(num_particles) assuming a freshly reset filter restarts from zero log
  # weights, and filters that did not resample only contribute their
  # accumulated mass to log Z hat when they later resample (or at the last
  # step). Renormalizing would subtract log(num_particles) a second time and
  # silently drop that accumulated mass.
  log_weights_acc *= (1. - tf.tile(resampled[tf.newaxis, :],
                                   [num_particles, 1]))
  new_state = (new_particle_state, new_loop_state)
  return t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc
def while_step(t, state, tas, log_weights_acc, log_z_hat_acc):
  """Runs a single step of the particle filter inside the while loop.

  Relies on names bound in the enclosing scope (not visible here): mask_ta,
  transition, resampling_criterion, never_resample_criterion, resampling_fn,
  loop_fn, num_particles, batch_size and max_num_steps, plus the `nested`
  and `tf` modules.

  Args:
    t: Current timestep index.
    state: Tuple (particle_state, loop_state) carried across iterations.
    tas: TensorArrays; the first receives log_weights_acc and the second the
      per-batch resampling indicator, both written at index t.
    log_weights_acc: Log weights accumulated since each filter last
      resampled. The incremental weights are reshaped to
      [num_particles, batch_size] before being added to it.
    log_z_hat_acc: Running log Z hat estimate.

  Returns:
    (t + 1, new_state, new_tas, log_weights_acc, log_z_hat_acc), where the
    returned log_weights_acc is zeroed for filters that resampled this step.
  """
  prev_particle_state, prev_loop_state = state
  step_mask = nested.read_tas(mask_ta, t)

  # Advance every particle; log_incr holds the incremental log weights.
  log_incr, next_particle_state, transition_extras = transition(
      prev_particle_state, t)
  # Mask the increments, lay them out per particle, and fold them into the
  # running weights.
  log_incr = tf.reshape(log_incr * step_mask, [num_particles, batch_size])
  log_weights_acc = log_weights_acc + log_incr

  resample_now = resampling_criterion(log_weights_acc, t)
  if resampling_criterion == never_resample_criterion:
    resampled = tf.to_float(resample_now)
  else:
    # Resampled states are built for every filter; where_tensors then picks,
    # per batch element, between them and the freshly propagated states.
    candidate_states = resampling_fn(
        log_weights_acc, next_particle_state, num_particles, batch_size)
    # Filters whose mask is no longer positive must not resample.
    sequence_active = step_mask[:batch_size] > 0.
    resample_now = tf.logical_and(resample_now, sequence_active)
    resampled = tf.to_float(resample_now)
    next_particle_state = nested.where_tensors(
        tf.tile(resample_now, [num_particles]),
        candidate_states,
        next_particle_state)

  next_loop_state = loop_fn(prev_loop_state, transition_extras,
                            next_particle_state, log_weights_acc, resampled,
                            step_mask, t)

  # Log-mean of the accumulated weights over the particle axis.
  log_z_increment = (tf.reduce_logsumexp(log_weights_acc, axis=0) -
                     tf.log(tf.to_float(num_particles)))

  def _gated_increment():
    return log_z_increment * resampled

  def _full_increment():
    return log_z_increment

  # Before the final step only filters that resampled contribute; on the
  # final step every filter does.
  log_z_hat_acc = log_z_hat_acc + tf.cond(t < max_num_steps - 1,
                                          _gated_increment, _full_increment)

  # Record this step's values while the weights still hold the accumulation
  # (i.e. before the reset below).
  written_tas = [
      ta.write(t, value)
      for ta, value in zip(tas, (log_weights_acc, resampled))
  ]

  # Zero the accumulated weights of every filter that just resampled.
  keep = 1. - tf.tile(resampled[tf.newaxis, :], [num_particles, 1])
  log_weights_acc = log_weights_acc * keep

  return (t + 1, (next_particle_state, next_loop_state), written_tas,
          log_weights_acc, log_z_hat_acc)