def step(self, obs):
    """Sample an action from the current stochastic policy.

    A single (unbatched) observation is promoted to a batch of one before
    being fed to the actor network.

    Returns:
        A tuple ``(action, log_prob)`` where ``action`` is the sampled
        action as a NumPy vector (batch dimension stripped) and
        ``log_prob`` is its log-probability under the policy distribution.
    """
    if obs.ndim < 2:
        # Add a leading batch axis for the actor network.
        obs = obs[np.newaxis, :]
    mean = self.actor(obs)
    # Gaussian policy with a fixed (state-independent) covariance factor.
    policy = MultivariateNormalTriL(loc=mean, scale_tril=self.cov_mat)
    sampled = policy.sample()
    log_prob = policy.log_prob(sampled)
    return sampled.numpy()[0], log_prob
def infer(self, obs, act):
    """Evaluate the policy and critic at given observation/action pairs.

    Args:
        obs: observation(s); a single observation is promoted to a batch
            of one.
        act: action(s) whose log-probability is evaluated.

    Returns:
        A tuple ``(log_probs, values, entropy)``: action log-probabilities
        under the current policy, the critic's value estimates (squeezed),
        and the policy distribution's entropy.
    """
    if obs.ndim < 2:
        # Add a leading batch axis for the networks.
        obs = obs[np.newaxis, :]
    mean = self.actor(obs)
    policy = MultivariateNormalTriL(loc=mean, scale_tril=self.cov_mat)
    log_probs = policy.log_prob(act)
    entropy = policy.entropy()
    values = self.critic(obs)
    return log_probs, tf.squeeze(values), entropy
def update(m, P, ell):
    """One Kalman measurement update step.

    Relies on ``H``, ``R``, ``y``, ``mm`` and ``mv`` captured from the
    enclosing scope (observation model, noise covariance, current
    observation, and matmul/matvec helpers).

    Args:
        m: prior mean.
        P: prior covariance.
        ell: running log-likelihood accumulator.

    Returns:
        ``(ell, m, P)`` — the accumulated log-likelihood and the posterior
        mean and covariance after assimilating ``y``.
    """
    # Innovation covariance and predicted observation.
    S = H @ mm(P, H, transpose_b=True) + R
    y_pred = mv(H, m)
    chol_S = tf.linalg.cholesky(S)
    # Marginal likelihood contribution of this observation.
    obs_dist = MultivariateNormalTriL(y_pred, chol_S)
    log_lik = obs_dist.log_prob(y)
    # Transposed Kalman gain: K^T = S^{-1} H P, via the Cholesky factor.
    gain_t = tf.linalg.cholesky_solve(chol_S, H @ P)
    post_mean = m + mv(gain_t, y - y_pred, transpose_a=True)
    post_cov = P - mm(gain_t, S, transpose_a=True) @ gain_t
    return ell + log_lik, post_mean, post_cov
def pkf(lgssm, observations, return_loglikelihood=False, max_parallel=10000):
    """Parallel-in-time Kalman filter via an associative scan.

    Args:
        lgssm: tuple ``(P0, Fs, Qs, H, R)`` — prior covariance, transition
            matrices, process noise covariances, observation matrix and
            observation noise covariance.
        observations: the observation sequence (may contain NaNs for
            missing data; their log-prob contributions are zeroed out).
        return_loglikelihood: if True, additionally return the total data
            log-likelihood.
        max_parallel: upper bound on the sequence length; sets the scan's
            recursion depth.

    Returns:
        ``(filtered_means, filtered_covs)``, plus the scalar
        log-likelihood when ``return_loglikelihood`` is True.
    """
    with tf.name_scope("parallel_filter"):
        P0, Fs, Qs, H, R = lgssm
        dtype = P0.dtype
        # Zero prior mean, matching the state dimension of P0.
        m0 = tf.zeros(tf.shape(P0)[0], dtype=dtype)
        depth = math.ceil(math.log2(max_parallel))
        elems = make_associative_filtering_elements(
            m0, P0, Fs, Qs, H, R, observations)
        scanned = scan_associative(
            filtering_operator, elems, max_num_levels=depth)
        means, covs = scanned[1], scanned[2]
        if not return_loglikelihood:
            return means, covs
        # Shift the filtered moments one step right (prepending the prior)
        # so entry t holds the filtering result *before* observation t.
        prev_means = tf.concat([tf.expand_dims(m0, 0), means[:-1]], axis=0)
        prev_covs = tf.concat([tf.expand_dims(P0, 0), covs[:-1]], axis=0)
        # One-step predictive moments in state and observation space.
        pred_means = mv(Fs, prev_means)
        pred_covs = mm(Fs, mm(prev_covs, Fs, transpose_b=True)) + Qs
        obs_means = mv(H, pred_means)
        obs_covs = mm(H, mm(pred_covs, H, transpose_b=True)) + tf.expand_dims(R, 0)
        dists = MultivariateNormalTriL(obs_means, tf.linalg.cholesky(obs_covs))
        # TODO: some logic could be added here to avoid handling the
        # covariance of non-nan models, but no impact for GPs
        logprobs = dists.log_prob(observations)
        # Missing observations yield NaN log-probs; treat them as zero
        # so they contribute nothing to the total.
        safe_logprobs = tf.where(
            tf.math.is_nan(logprobs), tf.zeros_like(logprobs), logprobs)
        return means, covs, tf.reduce_sum(safe_logprobs)