  def test_always_resampling(self):
    """Test always_resample_criterion makes smc always resample.

    Past a sequence end the filter should not resample, however. Also check
    that weights and log_z_hat estimate are correct.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      outs = smc.smc(
          _simple_transition_fn,
          num_steps=tf.convert_to_tensor([5, 3]),
          num_particles=2,
          resampling_criterion=smc.always_resample_criterion)
      log_z_hat, weights, resampled = sess.run(outs[0:3])
      gt_weights = np.array(
          [[[5, 1], [4, .5]],
           [[5, 1], [4, .5]],
           [[5, 1], [4, .5]],
           [[5, 1], [0., 0.]],
           [[5, 1], [0., 0.]]],
          dtype=np.float32)
      gt_log_z_hat = np.array(
          [5 * lse([5, 1]) - 5 * np.log(2),
           3 * lse([4, .5]) - 3 * np.log(2)],
          dtype=np.float32)
      gt_resampled = np.array(
          [[1, 1], [1, 1], [1, 1], [1, 0], [1, 0]],
          dtype=np.float32)
      self.assertAllClose(gt_log_z_hat, log_z_hat)
      self.assertAllClose(gt_weights, weights)
      self.assertAllEqual(gt_resampled, resampled)
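  # Note (not from the original file): a worked check of the ground truth
  # above. With always_resample_criterion the accumulated weights are reset to
  # zero after every step, so each step contributes
  #   lse(per_step_weights) - log(num_particles)
  # to log_z_hat. The first filter runs 5 steps with per-step weights [5, 1]
  # and the second runs 3 steps with [4, .5], giving
  #   5 * lse([5, 1]) - 5 * log(2)   and   3 * lse([4, .5]) - 3 * log(2),
  # which is exactly gt_log_z_hat. (Here lse is assumed to be a logsumexp
  # helper such as scipy.special.logsumexp.)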
  def test_resampling_on_max_num_steps(self):
    """Test that everything is correct when resampling on step max_num_steps.

    When resampling on step max_num_steps (i.e. the last step of the longest
    sequence), ensure that there are no off-by-one errors preventing
    resampling and also that the weights are not updated.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      outs = smc.smc(
          _simple_transition_fn,
          num_steps=tf.convert_to_tensor([4, 2]),
          num_particles=2,
          resampling_criterion=_resample_at_step_criterion(3))
      log_z_hat, weights, resampled = sess.run(outs[0:3])
      gt_log_z_hat = np.array(
          [lse([20, 4]) - np.log(2),
           lse([8, 1]) - np.log(2)],
          dtype=np.float32)
      # Ensure that we only resample on timestep 3 (the final step of the
      # longer sequence) and that the second filter doesn't resample at all
      # because it is only run for 2 steps.
      gt_resampled = np.array(
          [[0, 0], [0, 0], [0, 0], [1, 0]],
          dtype=np.float32)
      gt_weights = np.array(
          [[[5, 1], [4, .5]],
           [[10, 2], [8, 1]],
           [[15, 3], [8, 1]],
           [[20, 4], [8, 1]]],
          dtype=np.float32)
      self.assertAllClose(gt_log_z_hat, log_z_hat)
      self.assertAllEqual(gt_resampled, resampled)
      self.assertAllEqual(gt_weights, weights)
  def test_never_resampling(self):
    """Test that never_resample_criterion makes smc not resample.

    Also test that the weights and log_z_hat are computed correctly when
    never resampling.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      outs = smc.smc(
          _simple_transition_fn,
          num_steps=tf.convert_to_tensor([5, 3]),
          num_particles=2,
          resampling_criterion=smc.never_resample_criterion)
      log_z_hat, weights, resampled = sess.run(outs[0:3])
      gt_weights = np.array(
          [[[5, 1], [4, .5]],
           [[10, 2], [8, 1]],
           [[15, 3], [12, 1.5]],
           [[20, 4], [12, 1.5]],
           [[25, 5], [12, 1.5]]],
          dtype=np.float32)
      gt_log_z_hat = np.array(
          [lse([25, 5]) - np.log(2),
           lse([12, 1.5]) - np.log(2)],
          dtype=np.float32)
      self.assertAllClose(gt_log_z_hat, log_z_hat)
      self.assertAllClose(gt_weights, weights)
      self.assertAllEqual(np.zeros_like(resampled), resampled)
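  # Note (not from the original file): without resampling, the weights simply
  # accumulate over the sequence, so log_z_hat reduces to the IWAE-style
  # estimate
  #   lse(final_accumulated_weights) - log(num_particles)
  # taken at each sequence's last valid step: lse([25, 5]) - log(2) for the
  # 5-step filter and lse([12, 1.5]) - log(2) for the 3-step filter, matching
  # gt_log_z_hat above.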
  def test_weights_reset_when_resampling_at_sequence_end(self):
    """Test that the weights are reset when resampling at the sequence end.

    When resampling happens on the last timestep of a sequence the weights
    should be set to zero on the next timestep and remain zero afterwards.
    """
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      outs = smc.smc(
          _simple_transition_fn,
          num_steps=tf.convert_to_tensor([5, 3]),
          num_particles=2,
          resampling_criterion=_resample_at_step_criterion(2))
      log_z_hat, weights, resampled = sess.run(outs[0:3])
      gt_log_z = np.array(
          [lse([15, 3]) + lse([10, 2]) - 2 * np.log(2),
           lse([12, 1.5]) - np.log(2)],
          dtype=np.float32)
      gt_resampled = np.array(
          [[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]],
          dtype=np.float32)
      gt_weights = np.array(
          [[[5, 1], [4, .5]],
           [[10, 2], [8, 1]],
           [[15, 3], [12, 1.5]],
           [[5, 1], [0, 0]],
           [[10, 2], [0, 0]]],
          dtype=np.float32)
      self.assertAllClose(gt_log_z, log_z_hat)
      self.assertAllEqual(gt_resampled, resampled)
      self.assertAllEqual(gt_weights, weights)
  def test_weights_not_updated_past_sequence_end(self):
    """Test that non-zero weights are not updated past the end of a sequence."""
    tf.set_random_seed(1234)
    with self.test_session() as sess:
      outs = smc.smc(
          _simple_transition_fn,
          num_steps=tf.convert_to_tensor([6, 4]),
          num_particles=2,
          resampling_criterion=_resample_at_step_criterion(1))
      log_z_hat, weights, resampled = sess.run(outs[0:3])
      gt_log_z_hat = np.array(
          [lse([10, 2]) + lse([20, 4]) - 2 * np.log(2),
           lse([8, 1]) + lse([8, 1]) - 2 * np.log(2)],
          dtype=np.float32)
      # Ensure that we only resample on timestep 1 (the 2nd step).
      gt_resampled = np.array(
          [[0, 0], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0]],
          dtype=np.float32)
      # Ensure that the weights after the end of a sequence don't change, and
      # that the weights after resampling but before the end of a sequence do
      # change.
      gt_weights = np.array(
          [[[5, 1], [4, .5]],
           [[10, 2], [8, 1]],
           [[5, 1], [4, .5]],
           [[10, 2], [8, 1]],
           [[15, 3], [8, 1]],
           [[20, 4], [8, 1]]],
          dtype=np.float32)
      self.assertAllClose(gt_log_z_hat, log_z_hat)
      self.assertAllEqual(gt_resampled, resampled)
      self.assertAllEqual(gt_weights, weights)
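# For reference -- a hedged sketch, not the actual helper defined elsewhere in
# this test file -- of what _resample_at_step_criterion(step) might look like.
# It follows the criterion signature documented in fivo() below: it takes the
# number of particles, the current log weights (shape assumed to be
# [num_particles, batch_size]), and the timestep, and returns a [batch_size]
# boolean Tensor that is True only when t == step.
def _resample_at_step_criterion_sketch(step):
  """Illustrative criterion that resamples every filter only on `step`."""
  def criterion(unused_num_particles, log_weights, t):
    # Assumed shape: log_weights is [num_particles, batch_size].
    batch_size = tf.shape(log_weights)[1]
    return tf.tile(tf.reshape(tf.equal(t, step), [1]), [batch_size])
  return criterion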
def fivo(model,
         observations,
         seq_lengths,
         num_samples=1,
         resampling_criterion=smc.ess_criterion,
         resampling_type='multinomial',
         sinkhorn_regularization=0.01,
         relaxed_resampling_temperature=0.5,
         parallel_iterations=30,
         swap_memory=True,
         random_seed=None):
  """Computes the FIVO lower bound on the log marginal probability.

  This method accepts a stochastic latent variable model and some observations
  and computes a stochastic lower bound on the log marginal probability of the
  observations. The lower bound is defined by a particle filter's unbiased
  estimate of the marginal probability of the observations. For more details
  see "Filtering Variational Objectives" by Maddison et al.
  https://arxiv.org/abs/1705.09279.

  When the resampling criterion is "never resample", this bound becomes IWAE.

  Args:
    model: A subclass of ELBOTrainableSequenceModel that implements one
      timestep of the model. See models/vrnn.py for an example.
    observations: The inputs to the model. A potentially nested list or tuple
      of Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors
      must have a rank at least two and have matching shapes in the first two
      dimensions, which represent time and the batch respectively. The model
      will be provided with the observations before computing the bound.
    seq_lengths: A [batch_size] Tensor of ints encoding the length of each
      sequence in the batch (sequences can be padded to a common length).
    num_samples: The number of particles to use in each particle filter.
    resampling_criterion: The resampling criterion to use for this particle
      filter. Must accept the number of samples, the current log weights, and
      the current timestep and return a boolean Tensor of shape [batch_size]
      indicating whether each particle filter should resample. See
      ess_criterion and related functions for examples. When
      resampling_criterion is never_resample_criterion, resampling_fn is
      ignored and never called.
    resampling_type: The type of resampling, one of "multinomial", "relaxed",
      or "differentiable".
    sinkhorn_regularization: Regularization strength for the optimal-transport
      (Sinkhorn) resampling used when resampling_type is "differentiable".
    relaxed_resampling_temperature: A positive temperature only used for
      relaxed resampling.
    parallel_iterations: The number of parallel iterations to use for the
      internal while loop. Note that values greater than 1 can introduce
      non-determinism even when random_seed is provided.
    swap_memory: Whether GPU-CPU memory swapping should be enabled for the
      internal while loop.
    random_seed: The random seed to pass to the resampling operations in the
      particle filter. Mainly useful for testing.

  Returns:
    log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of
      the log marginal probability of the observations.
    log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
      containing the log weights at each timestep of the particle filter. Note
      that on timesteps when a resampling operation is performed the log
      weights are reset to 0. Will not be valid for timesteps past the end of
      a sequence.
    resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
      particle filters resampled. Will be 1.0 on timesteps when resampling
      occurred and 0.0 on timesteps when it did not.
    final_state: The final state of the particle filter, as returned by
      smc.smc.
  """
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  # After tiling, each Tensor in observations (e.g. a tuple of inputs and
  # targets) has shape [max_seq_len, batch_size * num_samples, ...].
  observations = nested.tile_tensors(observations, [1, num_samples])
  # tiled_seq_lengths has shape [batch_size * num_samples].
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  elif resampling_type == 'differentiable':
    resampling_fn = get_transport_fun(sinkhorn_regularization, 1e-2, 100)
  else:
    raise NotImplementedError(
        'Unsupported resampling_type: %s' % resampling_type)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      return model.zero_state(batch_size * num_samples, tf.float32)
    return model.propose_and_weight(prev_state, t)

  log_p_hat, log_weights, resampled, final_state, _ = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  return log_p_hat, log_weights, resampled, final_state
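# A hedged usage sketch (not part of the original module): training on the
# FIVO bound. `model`, `inputs`, `targets`, and `seq_lengths` are assumed to
# be built elsewhere, e.g. a VRNN from models/vrnn.py and a data pipeline
# producing [max_seq_len, batch_size, data_dim] observations.
#
#   log_p_hat, _, _, _ = fivo(
#       model, (inputs, targets), seq_lengths,
#       num_samples=4,
#       resampling_criterion=smc.ess_criterion,
#       resampling_type='multinomial')
#   # log_p_hat is a [batch_size] lower bound on log p(x); maximize its
#   # length-normalized mean.
#   loss = -tf.reduce_mean(log_p_hat / tf.to_float(seq_lengths))
#   train_op = tf.train.AdamOptimizer(3e-4).minimize(loss)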
def fivo_aux_td(model,
                observations,
                seq_lengths,
                num_samples=1,
                resampling_criterion=smc.ess_criterion,
                resampling_type='multinomial',
                relaxed_resampling_temperature=0.5,
                parallel_iterations=30,
                swap_memory=True,
                random_seed=None):
  """Experimental."""
  # batch_size is the number of particle filters running in parallel.
  batch_size = tf.shape(seq_lengths)[0]
  max_seq_len = tf.reduce_max(seq_lengths)

  # Each sequence in the batch will be the input data for a different
  # particle filter. The batch will be laid out as:
  #   particle 1 of particle filter 1
  #   particle 1 of particle filter 2
  #   ...
  #   particle 1 of particle filter batch_size
  #   particle 2 of particle filter 1
  #   ...
  #   particle num_samples of particle filter batch_size
  observations = nested.tile_tensors(observations, [1, num_samples])
  tiled_seq_lengths = tf.tile(seq_lengths, [num_samples])
  model.set_observations(observations, tiled_seq_lengths)

  if resampling_type == 'multinomial':
    resampling_fn = smc.multinomial_resampling
  elif resampling_type == 'relaxed':
    resampling_fn = functools.partial(
        smc.relaxed_resampling, temperature=relaxed_resampling_temperature)
  resampling_fn = functools.partial(resampling_fn, random_seed=random_seed)

  def transition_fn(prev_state, t):
    if prev_state is None:
      model_init_state = model.zero_state(batch_size * num_samples, tf.float32)
      return (tf.zeros([num_samples * batch_size], dtype=tf.float32),
              (tf.zeros([num_samples * batch_size, model.latent_size],
                        dtype=tf.float32),
               tf.zeros([num_samples * batch_size, model.latent_size],
                        dtype=tf.float32)),
              model_init_state)

    prev_log_r, prev_log_r_tilde, prev_model_state = prev_state
    (new_model_state, zt, log_q_zt, log_p_zt,
     log_p_x_given_z, log_r_tilde, p_ztplus1) = model(prev_model_state, t)
    r_tilde_mu, r_tilde_sigma_sq = log_r_tilde
    # Compute the weight without r.
    log_weight = log_p_zt + log_p_x_given_z - log_q_zt
    # Compute log_r and log_r_tilde.
    p_mu = tf.stop_gradient(p_ztplus1.mean())
    p_sigma_sq = tf.stop_gradient(p_ztplus1.variance())
    log_r = (tf.log(r_tilde_sigma_sq) -
             tf.log(r_tilde_sigma_sq + p_sigma_sq) -
             tf.square(r_tilde_mu - p_mu) / (r_tilde_sigma_sq + p_sigma_sq))
    # log_r is [num_samples*batch_size, latent_size]. We sum it along the last
    # dimension to compute log r.
    log_r = 0.5 * tf.reduce_sum(log_r, axis=-1)
    # Compute prev_log_r_tilde.
    prev_r_tilde_mu, prev_r_tilde_sigma_sq = prev_log_r_tilde
    prev_log_r_tilde = -0.5 * tf.reduce_sum(
        tf.square(tf.stop_gradient(zt) - prev_r_tilde_mu) /
        prev_r_tilde_sigma_sq,
        axis=-1)
    # If the sequence is on its last timestep, log_r and log_r_tilde are just
    # zeros.
    last_timestep = t >= (tiled_seq_lengths - 1)
    log_r = tf.where(last_timestep, tf.zeros_like(log_r), log_r)
    prev_log_r_tilde = tf.where(last_timestep,
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    log_weight += tf.stop_gradient(log_r - prev_log_r)
    new_state = (log_r, log_r_tilde, new_model_state)
    loop_fn_args = (log_r, prev_log_r_tilde, log_p_x_given_z,
                    log_r - prev_log_r)
    return log_weight, new_state, loop_fn_args

  def loop_fn(loop_state, loop_args, unused_model_state, log_weights,
              resampled, mask, t):
    if loop_state is None:
      return (tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([batch_size], dtype=tf.float32),
              tf.zeros([num_samples, batch_size], dtype=tf.float32))

    log_p_hat_acc, bellman_loss_acc, log_r_diff_acc = loop_state
    log_r, prev_log_r_tilde, log_p_x_given_z, log_r_diff = loop_args
    # Compute the log_p_hat update.
    log_p_hat_update = tf.reduce_logsumexp(
        log_weights, axis=0) - tf.log(tf.to_float(num_samples))
    # If it is the last timestep, we always add the update.
    log_p_hat_acc += tf.cond(t >= max_seq_len - 1,
                             lambda: log_p_hat_update,
                             lambda: log_p_hat_update * resampled)
    # Compute the Bellman update.
    log_r = tf.reshape(log_r, [num_samples, batch_size])
    prev_log_r_tilde = tf.reshape(prev_log_r_tilde, [num_samples, batch_size])
    log_p_x_given_z = tf.reshape(log_p_x_given_z, [num_samples, batch_size])
    mask = tf.reshape(mask, [num_samples, batch_size])
    # On the first timestep there is no Bellman error because there is no
    # prev_log_r_tilde.
    mask = tf.cond(tf.equal(t, 0),
                   lambda: tf.zeros_like(mask),
                   lambda: mask)
    # On the first timestep also fix up prev_log_r_tilde, which will be -inf.
    prev_log_r_tilde = tf.where(tf.is_inf(prev_log_r_tilde),
                                tf.zeros_like(prev_log_r_tilde),
                                prev_log_r_tilde)
    # log_lambda is [1, batch_size] after the mean over the particle axis.
    log_lambda = tf.reduce_mean(prev_log_r_tilde - log_p_x_given_z - log_r,
                                axis=0, keepdims=True)
    bellman_error = mask * tf.square(
        prev_log_r_tilde -
        tf.stop_gradient(log_lambda + log_p_x_given_z + log_r))
    bellman_loss_acc += tf.reduce_mean(bellman_error, axis=0)
    # Compute the log_r_diff update.
    log_r_diff_acc += mask * tf.reshape(log_r_diff, [num_samples, batch_size])
    return (log_p_hat_acc, bellman_loss_acc, log_r_diff_acc)

  log_weights, resampled, accs = smc.smc(
      transition_fn,
      seq_lengths,
      num_particles=num_samples,
      resampling_criterion=resampling_criterion,
      resampling_fn=resampling_fn,
      loop_fn=loop_fn,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory)

  log_p_hat, bellman_loss, log_r_diff = accs
  loss_per_seq = [-log_p_hat, bellman_loss]
  tf.summary.scalar(
      "bellman_loss",
      tf.reduce_mean(bellman_loss / tf.to_float(seq_lengths)))
  tf.summary.scalar(
      "log_r_diff",
      tf.reduce_mean(
          tf.reduce_mean(log_r_diff, axis=0) / tf.to_float(seq_lengths)))
  return loss_per_seq, log_p_hat, log_weights, resampled
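# A hedged usage sketch for the experimental bound above. How the two
# per-sequence losses are combined into a training objective is an
# assumption, not something taken from this module:
#
#   losses, log_p_hat, _, _ = fivo_aux_td(
#       model, (inputs, targets), seq_lengths, num_samples=4)
#   neg_log_p_hat_per_seq, bellman_loss_per_seq = losses
#   # One natural scalar objective: length-normalize both terms and sum them.
#   loss = tf.reduce_mean(
#       (neg_log_p_hat_per_seq + bellman_loss_per_seq) /
#       tf.to_float(seq_lengths))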