def test_double_dqn_loss(setup):
    """Double-DQN loss from the algorithm should match a manual computation."""
    algo, env, buff, _, batch_size = setup

    algo._double_q = True
    trainer = Trainer(snapshot_config)
    trainer.setup(algo, env)

    paths = trainer.obtain_episodes(0, batch_size=batch_size)
    buff.add_episode_batch(paths)
    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
    timesteps_copy = copy.deepcopy(timesteps)

    observations = np_to_torch(timesteps.observations)
    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
    actions = np_to_torch(timesteps.actions)
    next_observations = np_to_torch(timesteps.next_observations)
    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # Double-Q loss: the online network picks the greedy next action...
        selected_actions = torch.argmax(algo._qf(next_inputs), dim=1)
        # ...and the target network evaluates the chosen action.
        selected_actions = selected_actions.long().unsqueeze(1)
        best_qvals = torch.gather(algo._target_qf(next_inputs),
                                  dim=1,
                                  index=selected_actions)

    rewards_clipped = rewards
    y_target = (rewards_clipped +
                (1.0 - terminals) * algo._discount * best_qvals)
    y_target = y_target.squeeze(1)

    # Recompute the TD loss the same way the algorithm's optimizer does.
    qvals = algo._qf(inputs)
    selected_qs = torch.sum(qvals * actions, dim=1)
    qval_loss = F.smooth_l1_loss(selected_qs, y_target)

    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
        timesteps_copy)
    env.close()

    assert (qval_loss.detach() == algo_loss).all()
    assert (y_target == algo_targets).all()
    assert (selected_qs == algo_selected_qs).all()
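# The two tests in this file differ only in how the bootstrap target is
# formed. A minimal side-by-side sketch of the two target rules: `qf` and
# `target_qf` stand for any discrete Q-networks mapping a batch of
# observations to per-action Q-values. This helper is illustrative only and
# not part of the garage API.
def _target_rule_sketch(qf, target_qf, next_obs):
    # Double DQN (van Hasselt et al., 2016): the online network selects the
    # greedy action; the target network evaluates it.
    selected = torch.argmax(qf(next_obs), dim=1, keepdim=True)
    double_q = torch.gather(target_qf(next_obs), dim=1, index=selected)

    # Vanilla DQN: the target network both selects and evaluates.
    vanilla, _ = torch.max(target_qf(next_obs), dim=1)
    return double_q.squeeze(1), vanilla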
def test_dqn_loss(setup):
    """Vanilla DQN loss from the algorithm should match a manual computation."""
    algo, env, buff, _, batch_size = setup

    trainer = Trainer(snapshot_config)
    trainer.setup(algo, env)

    paths = trainer.obtain_episodes(0, batch_size=batch_size)
    buff.add_episode_batch(paths)
    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
    timesteps_copy = copy.deepcopy(timesteps)

    observations = np_to_torch(timesteps.observations)
    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
    actions = np_to_torch(timesteps.actions)
    next_observations = np_to_torch(timesteps.next_observations)
    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # Vanilla DQN: the target network both selects and evaluates the
        # best next action.
        target_qvals = algo._target_qf(next_inputs)
        best_qvals, _ = torch.max(target_qvals, 1)
        best_qvals = best_qvals.unsqueeze(1)

    rewards_clipped = rewards
    y_target = (rewards_clipped +
                (1.0 - terminals) * algo._discount * best_qvals)
    y_target = y_target.squeeze(1)

    # Recompute the TD loss the same way the algorithm's optimizer does.
    qvals = algo._qf(inputs)
    selected_qs = torch.sum(qvals * actions, dim=1)
    qval_loss = F.smooth_l1_loss(selected_qs, y_target)

    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
        timesteps_copy)
    env.close()

    assert (qval_loss.detach() == algo_loss).all()
    assert (y_target == algo_targets).all()
    assert (selected_qs == algo_selected_qs).all()
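# For reference: these tests assume module-level imports and a `setup`
# fixture defined elsewhere in the file. The sketch below shows one way to
# provide them using garage's public API. The environment, network sizes,
# and hyperparameters are illustrative assumptions, not values taken from
# this file; in the real module the imports would sit at the top.
import copy

import pytest
import torch
import torch.nn.functional as F

from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import LocalSampler
from garage.torch import np_to_torch
from garage.torch.algos import DQN
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction
from garage.trainer import Trainer

from tests.fixtures import snapshot_config


@pytest.fixture
def setup():
    """Yield (algo, env, replay_buffer, exploration_policy, batch_size)."""
    set_seed(24)
    batch_size = 512
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=100 * batch_size,
                                             max_epsilon=1.0,
                                             min_epsilon=0.02)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length)
    # double_q defaults off here so test_double_dqn_loss can flip it on.
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               replay_buffer=replay_buffer,
               sampler=sampler,
               exploration_policy=exploration_policy,
               double_q=False,
               steps_per_epoch=10,
               buffer_batch_size=64,
               min_buffer_size=int(1e3),
               discount=0.99)
    return algo, env, replay_buffer, exploration_policy, batch_size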