def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions.

    The online model (``self.model``) chooses the greedy next actions.
    The target model (``self.target_model``) provides the return
    distribution for those actions. The Bellman-shifted support ``Tz`` is
    then projected back onto the fixed atoms ``z_values``.

    Args:
        exp_batch (dict): Batch of experiences. Reads ``'next_state'``,
            ``'reward'``, ``'is_state_terminal'`` and ``'discount'``. When
            ``self.recurrent`` is true it also reads
            ``'next_recurrent_state'``.

    Returns:
        The result of ``_apply_categorical_projection(Tz, next_q_max,
        z_values)``.
    """
    next_states = exp_batch['next_state']
    rewards = exp_batch['reward']
    terminals = exp_batch['is_state_terminal']

    # Evaluate both networks with chainer's 'train' config turned off.
    with chainer.using_config('train', False):
        if not self.recurrent:
            target_next_qout = self.target_model(next_states)
            next_qout = self.model(next_states)
        else:
            next_rnn_state = exp_batch['next_recurrent_state']
            target_next_qout, _ = self.target_model.n_step_forward(
                next_states, next_rnn_state, output_mode='concat')
            next_qout, _ = self.model.n_step_forward(
                next_states, next_rnn_state, output_mode='concat')

    batch_size = rewards.shape[0]
    z_values = target_next_qout.z_values
    n_atoms = z_values.size

    # Target-model distribution evaluated at the online model's greedy
    # actions; expected shape is (batch_size, n_atoms).
    greedy_actions = next_qout.greedy_actions.array
    next_q_max = target_next_qout.evaluate_actions_as_distribution(
        greedy_actions).array
    assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

    # Tz: (batch_size, n_atoms). Terminal transitions (terminal flag 1)
    # zero out the discounted support, leaving only the reward.
    not_done = 1.0 - terminals[..., None]
    discounts = self.xp.expand_dims(exp_batch['discount'], 1)
    Tz = rewards[..., None] + not_done * discounts * z_values[None]
    return _apply_categorical_projection(Tz, next_q_max, z_values)
def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions (double-DQN style).

    The online Q-function (``self.q_function``) selects the greedy next
    actions. The target Q-function (``self.target_q_function``) supplies
    the return distribution for exactly those actions. The Bellman-shifted
    support ``Tz`` is then projected back onto the fixed atoms
    ``z_values``.

    Args:
        exp_batch (dict): Batch of experiences. Reads ``'next_state'``,
            ``'reward'``, ``'is_state_terminal'`` and ``'discount'``.

    Returns:
        The result of ``_apply_categorical_projection(Tz, next_q_max,
        z_values)``.
    """
    batch_next_state = exp_batch['next_state']
    batch_rewards = exp_batch['reward']
    batch_terminal = exp_batch['is_state_terminal']

    # ``state_kept`` wraps the extra forward pass through the online
    # Q-function, presumably so that any internal (e.g. recurrent) state
    # is left untouched -- NOTE(review): confirm against state_kept.
    with chainer.using_config('train', False), state_kept(self.q_function):
        next_qout = self.q_function(batch_next_state)
        target_next_qout = self.target_q_function(batch_next_state)

    batch_size = batch_rewards.shape[0]
    z_values = target_next_qout.z_values
    n_atoms = z_values.size

    # next_q_max: (batch_size, n_atoms).
    # Previously this used ``target_next_qout.max_as_distribution``, which
    # takes the target network's own argmax and leaves the online
    # ``next_qout`` computation above unused. Evaluate the target
    # distribution at the online network's greedy actions instead.
    next_q_max = target_next_qout.evaluate_actions_as_distribution(
        next_qout.greedy_actions.array).array
    assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

    # Tz: (batch_size, n_atoms). Terminal transitions keep only the reward.
    Tz = (batch_rewards[..., None]
          + (1.0 - batch_terminal[..., None])
          * self.xp.expand_dims(exp_batch['discount'], 1)
          * z_values[None])
    return _apply_categorical_projection(Tz, next_q_max, z_values)
def _test(self, xp):
    """Check ``_apply_categorical_projection`` on hand-computed cases.

    The support is the 3 atoms [-1, 0, 1] (spacing 1). Every row uses
    the probabilities [0.5, 0.2, 0.3], and values between two atoms
    are split linearly between them.
    """
    v_min, v_max = -1, 1
    n_atoms = 3
    z = xp.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)

    cases = [
        # (values to project, expected probabilities on the atoms)
        ([-1, 0, 1], [0.5, 0.2, 0.3]),        # already on the atoms
        ([1, -1, 0], [0.2, 0.3, 0.5]),        # permuted atoms
        ([1, 1, 1], [0.0, 0.0, 1.0]),         # all mass on the top atom
        ([-1, -1, -1], [1.0, 0.0, 0.0]),      # all mass on the bottom atom
        ([0, 0, 0], [0.0, 1.0, 0.0]),         # all mass on the middle atom
        ([-0.5, 0, 1], [0.25, 0.45, 0.3]),    # -0.5 splits between -1 and 0
        ([-0.5, 0, 0.5], [0.25, 0.6, 0.15]),  # both halves split
    ]
    y = xp.asarray([values for values, _ in cases], dtype=np.float32)
    y_probs = xp.asarray([[0.5, 0.2, 0.3]] * len(cases), dtype=np.float32)
    proj_gt = xp.asarray([expected for _, expected in cases],
                         dtype=np.float32)

    proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
    xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)
def _test(self, xp):
    """Compare the batched projection with the naive reference on random data.

    The values to project are random normals and the probabilities are
    Dirichlet samples. Both implementations must produce rows summing to
    one and must agree to within ``atol=1e-5``.
    """
    v_min, v_max = self.v_range
    batch_size, n_atoms = self.batch_size, self.n_atoms
    z = xp.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)
    y = xp.random.normal(size=(batch_size, n_atoms)).astype(np.float32)
    # Probability vectors are sampled on the host with NumPy, then moved
    # to ``xp``.
    dirichlet_samples = np.random.dirichlet(alpha=np.ones(n_atoms),
                                            size=batch_size)
    y_probs = xp.asarray(dirichlet_samples.astype(np.float32))

    ones = xp.ones(batch_size, dtype=np.float32)

    def check_rows_sum_to_one(probs):
        xp.testing.assert_allclose(probs.sum(axis=1), ones, atol=1e-5)

    # Reference result from the naive implementation.
    proj_gt = _apply_categorical_projection_naive(y, y_probs, z)
    check_rows_sum_to_one(proj_gt)

    # Batched implementation under test.
    proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
    check_rows_sum_to_one(proj)

    xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)
def _test_inexact_delta_z(self, xp):
    """Check the projection when the atom spacing is not exact in float32.

    Four atoms on [-1, 1] give a spacing of 2/3, which float32 cannot
    represent exactly. Values that land on atoms (or exactly midway
    between them) must still be assigned as in ``cases`` below.
    """
    v_min, v_max = -1, 1
    n_atoms = 4
    z = xp.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)

    cases = [
        # (values to project, their probabilities, expected projection)
        ([-1, -1, 1, 1], [0.5, 0.1, 0.1, 0.3], [0.6, 0.0, 0.0, 0.4]),
        # 0 sits exactly midway between -1/3 and 1/3, so its 0.2 splits.
        ([-1, 0, 1, 1], [0.5, 0.2, 0.0, 0.3], [0.5, 0.1, 0.1, 0.3]),
    ]
    y = xp.asarray([values for values, _, _ in cases], dtype=np.float32)
    y_probs = xp.asarray([probs for _, probs, _ in cases], dtype=np.float32)
    proj_gt = xp.asarray([expected for _, _, expected in cases],
                         dtype=np.float32)

    proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
    xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)