def _test_inexact_delta_z(self, device):
    """Check the projection when the atom spacing is not exactly representable.

    Four atoms on [-1, 1] are 2/3 apart, a spacing that has no exact
    floating-point representation. The result of
    ``categorical_dqn._apply_categorical_projection`` is compared against
    hand-written expected probabilities with an absolute tolerance of 1e-5.

    Args:
        device: Device on which the input tensors are created.
    """
    support_lo, support_hi = -1, 1
    atom_count = 4
    # delta_z = 2/3 = 0.66666... is not exact
    atoms = np.linspace(
        support_lo, support_hi, num=atom_count, dtype=np.float32
    )

    samples = np.asarray(
        [
            [-1, -1, 1, 1],
            [-1, 0, 1, 1],
        ],
        dtype=np.float32,
    )
    sample_probs = np.asarray(
        [
            [0.5, 0.1, 0.1, 0.3],
            [0.5, 0.2, 0.0, 0.3],
        ],
        dtype=np.float32,
    )
    expected = np.asarray(
        [
            [0.6, 0.0, 0.0, 0.4],
            [0.5, 0.1, 0.1, 0.3],
        ],
        dtype=np.float32,
    )

    # Move all three inputs to the requested device in the argument order
    # the projection expects: sample points, their probabilities, atoms.
    tensors = [
        torch.as_tensor(array, device=device)
        for array in (samples, sample_probs, atoms)
    ]
    projected = categorical_dqn._apply_categorical_projection(*tensors)
    actual = projected.detach().cpu().numpy()

    np.testing.assert_allclose(actual, expected, atol=1e-5)
def _test(self, device):
    """Compare the batched projection with the naive reference on random data.

    Reads ``self.v_range``, ``self.n_atoms`` and ``self.batch_size``. Sample
    points are drawn from a standard normal and their probabilities from a
    flat Dirichlet, both through NumPy's global random state. The naive
    result and the batched result must each have rows summing to one, and
    must match each other, all within an absolute tolerance of 1e-5.

    Args:
        device: Device on which the input tensors are created.
    """

    def assert_rows_sum_to_one(dist):
        # Projected probabilities should sum to one
        np.testing.assert_allclose(
            dist.sum(axis=1),
            np.ones(self.batch_size, dtype=np.float32),
            atol=1e-5,
        )

    lo, hi = self.v_range
    atoms = np.linspace(lo, hi, num=self.n_atoms, dtype=np.float32)

    # Keep the draw order (normal first, then Dirichlet) so that a seeded
    # global RNG yields the same inputs as before.
    samples = np.random.normal(size=(self.batch_size, self.n_atoms))
    samples = samples.astype(np.float32)
    sample_probs = np.random.dirichlet(
        alpha=np.ones(self.n_atoms), size=self.batch_size
    ).astype(np.float32)

    # Naive implementation as ground truths
    expected = _apply_categorical_projection_naive(
        samples, sample_probs, atoms
    )
    assert_rows_sum_to_one(expected)

    # Batch implementation to test
    tensors = [
        torch.as_tensor(array, device=device)
        for array in (samples, sample_probs, atoms)
    ]
    actual = (
        categorical_dqn._apply_categorical_projection(*tensors)
        .detach()
        .cpu()
        .numpy()
    )
    assert_rows_sum_to_one(actual)

    # Both should be equal
    np.testing.assert_allclose(actual, expected, atol=1e-5)
def _test(self, device):
    """Check the projection against hand-computed cases on three atoms.

    The atoms are -1, 0 and 1. Every case uses the same input probabilities
    ``[0.5, 0.2, 0.3]`` and differs only in the sample points they are
    attached to. The result of
    ``categorical_dqn._apply_categorical_projection`` must match the
    expected rows within an absolute tolerance of 1e-5.

    Args:
        device: Device on which the input tensors are created.
    """
    atoms = np.linspace(-1, 1, num=3, dtype=np.float32)
    shared_probs = [0.5, 0.2, 0.3]

    # Each entry pairs the sample points with the expected projection.
    cases = [
        # Sample points coincide with the atoms, in order.
        ([-1, 0, 1], [0.5, 0.2, 0.3]),
        # Sample points coincide with the atoms, permuted.
        ([1, -1, 0], [0.2, 0.3, 0.5]),
        # All sample points sit on a single atom.
        ([1, 1, 1], [0.0, 0.0, 1.0]),
        ([-1, -1, -1], [1.0, 0.0, 0.0]),
        ([0, 0, 0], [0.0, 1.0, 0.0]),
        # Sample points midway between atoms: the expected rows split that
        # mass evenly between the two neighbouring atoms.
        ([-0.5, 0, 1], [0.25, 0.45, 0.3]),
        ([-0.5, 0, 0.5], [0.25, 0.6, 0.15]),
    ]

    samples = np.asarray([points for points, _ in cases], dtype=np.float32)
    sample_probs = np.asarray([shared_probs] * len(cases), dtype=np.float32)
    expected = np.asarray([row for _, row in cases], dtype=np.float32)

    tensors = [
        torch.as_tensor(array, device=device)
        for array in (samples, sample_probs, atoms)
    ]
    actual = (
        categorical_dqn._apply_categorical_projection(*tensors)
        .detach()
        .cpu()
        .numpy()
    )

    np.testing.assert_allclose(actual, expected, atol=1e-5)
def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions.

    The next state is evaluated by both ``self.model`` and
    ``self.target_model``. Greedy actions are taken from ``self.model``'s
    output, and the target model's distribution for those actions is
    projected onto the target model's atoms after shifting them by the
    reward and the discount.

    Args:
        exp_batch: Mapping that provides ``"next_state"``, ``"reward"``,
            ``"is_state_terminal"`` and ``"discount"``, plus
            ``"next_recurrent_state"`` when ``self.recurrent`` is true.

    Returns:
        The result of ``_apply_categorical_projection`` for the shifted
        atoms, the selected next-state distribution and the atom values.
    """
    next_states = exp_batch["next_state"]
    rewards = exp_batch["reward"]
    terminals = exp_batch["is_state_terminal"]

    def forward(network):
        # Recurrent models also need the stored recurrent state; the
        # recurrent state they return is not used here.
        if self.recurrent:
            output, _ = pack_and_forward(
                network,
                next_states,
                exp_batch["next_recurrent_state"],
            )
            return output
        return network(next_states)

    with pfrl.utils.evaluating(self.target_model), pfrl.utils.evaluating(
        self.model
    ):
        target_out = forward(self.target_model)
        online_out = forward(self.model)

    batch_size = rewards.shape[0]
    z_values = target_out.z_values
    n_atoms = z_values.numel()

    # next_dist: (batch_size, n_atoms)
    # Actions are chosen by the online model, evaluated by the target model.
    chosen_actions = online_out.greedy_actions.detach()
    next_dist = target_out.evaluate_actions_as_distribution(chosen_actions)
    assert next_dist.shape == (batch_size, n_atoms), next_dist.shape

    # shifted_atoms: (batch_size, n_atoms)
    # Trailing axes are added so per-sample terms broadcast over the atoms.
    reward_term = rewards[..., None]
    not_terminal = 1.0 - terminals[..., None]
    discount_term = exp_batch["discount"][..., None]
    shifted_atoms = (
        reward_term + not_terminal * discount_term * z_values[None]
    )

    return _apply_categorical_projection(shifted_atoms, next_dist, z_values)