Example #1
# Imports assumed for this snippet; `setup` is a pytest fixture that
# provides the algorithm, environment, replay buffer, and batch size.
import copy

import torch
import torch.nn.functional as F

from garage.torch import np_to_torch
from garage.trainer import Trainer

from tests.fixtures import snapshot_config


def test_double_dqn_loss(setup):
    algo, env, buff, _, batch_size = setup

    algo._double_q = True
    trainer = Trainer(snapshot_config)
    trainer.setup(algo, env)

    paths = trainer.obtain_episodes(0, batch_size=batch_size)
    buff.add_episode_batch(paths)
    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
    timesteps_copy = copy.deepcopy(timesteps)

    observations = np_to_torch(timesteps.observations)
    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
    actions = np_to_torch(timesteps.actions)
    next_observations = np_to_torch(timesteps.next_observations)
    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # double DQN: select greedy actions with the online Q-network...
        selected_actions = torch.argmax(algo._qf(next_inputs), dim=1)
        # ...then evaluate those actions with the target Q-network
        selected_actions = selected_actions.long().unsqueeze(1)
        best_qvals = torch.gather(algo._target_qf(next_inputs),
                                  dim=1,
                                  index=selected_actions)

    # rewards are left unclipped in this test
    rewards_clipped = rewards
    # TD target: r + (1 - done) * discount * Q_target(s', argmax_a' Q(s', a'))
    y_target = (rewards_clipped +
                (1.0 - terminals) * algo._discount * best_qvals)
    y_target = y_target.squeeze(1)

    # recompute the loss the way the algorithm does: actions are one-hot,
    # so the elementwise product picks out the Q-value of the taken action
    qvals = algo._qf(inputs)
    selected_qs = torch.sum(qvals * actions, dim=1)
    qval_loss = F.smooth_l1_loss(selected_qs, y_target)

    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
        timesteps_copy)
    env.close()

    assert (qval_loss.detach() == algo_loss).all()
    assert (y_target == algo_targets).all()
    assert (selected_qs == algo_selected_qs).all()
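
The heart of the double-DQN computation above is that the online network
selects the greedy action while the target network evaluates it. Below is a
minimal, garage-free sketch of just that target computation; `q_net`,
`target_net`, and the tensor arguments are stand-in names, not part of the
test above.

import torch


def double_dqn_targets(q_net, target_net, rewards, terminals, next_obs,
                       discount):
    """TD targets: r + (1 - done) * discount * Q_target(s', a*),
    with a* = argmax_a Q_online(s', a)."""
    with torch.no_grad():
        # select actions greedily with the *online* network...
        best_actions = torch.argmax(q_net(next_obs), dim=1, keepdim=True)
        # ...but score them with the *target* network
        best_qvals = torch.gather(target_net(next_obs), dim=1,
                                  index=best_actions)
    # rewards and terminals are expected with shape (N, 1), as in the test
    return (rewards + (1.0 - terminals) * discount * best_qvals).squeeze(1)
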
Example #2
# Imports assumed for this snippet; `setup` is a pytest fixture that
# provides the algorithm, environment, replay buffer, and batch size.
import copy

import torch
import torch.nn.functional as F

from garage.sampler import LocalSampler
from garage.torch import np_to_torch
from garage.trainer import Trainer

from tests.fixtures import snapshot_config


def test_dqn_loss(setup):
    algo, env, buff, _, batch_size = setup

    trainer = Trainer(snapshot_config)
    trainer.setup(algo, env, sampler_cls=LocalSampler)

    paths = trainer.obtain_episodes(0, batch_size=batch_size)
    buff.add_episode_batch(paths)
    timesteps = buff.sample_timesteps(algo._buffer_batch_size)
    timesteps_copy = copy.deepcopy(timesteps)

    observations = np_to_torch(timesteps.observations)
    rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
    actions = np_to_torch(timesteps.actions)
    next_observations = np_to_torch(timesteps.next_observations)
    terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # vanilla DQN: the target network both selects and evaluates
        target_qvals = algo._target_qf(next_inputs)
        best_qvals, _ = torch.max(target_qvals, dim=1)
        best_qvals = best_qvals.unsqueeze(1)

    # rewards are left unclipped in this test
    rewards_clipped = rewards
    # TD target: r + (1 - done) * discount * max_a' Q_target(s', a')
    y_target = (rewards_clipped +
                (1.0 - terminals) * algo._discount * best_qvals)
    y_target = y_target.squeeze(1)

    # recompute the loss the way the algorithm does: actions are one-hot,
    # so the elementwise product picks out the Q-value of the taken action
    qvals = algo._qf(inputs)
    selected_qs = torch.sum(qvals * actions, dim=1)
    qval_loss = F.smooth_l1_loss(selected_qs, y_target)

    algo_loss, algo_targets, algo_selected_qs = algo._optimize_qf(
        timesteps_copy)
    env.close()

    assert (qval_loss.detach() == algo_loss).all()
    assert (y_target == algo_targets).all()
    assert (selected_qs == algo_selected_qs).all()
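
For contrast, the vanilla DQN target in Example #2 lets the target network
both choose and score the action, which is the source of the overestimation
bias that double DQN removes. A sketch with the same stand-in names as
above:

import torch


def dqn_targets(target_net, rewards, terminals, next_obs, discount):
    """TD targets: r + (1 - done) * discount * max_a Q_target(s', a)."""
    with torch.no_grad():
        # the target network both selects and evaluates the action
        best_qvals, _ = torch.max(target_net(next_obs), dim=1, keepdim=True)
    return (rewards + (1.0 - terminals) * discount * best_qvals).squeeze(1)

The only difference from the double-DQN sketch is which network picks the
action; the Huber loss against the chosen Q-values is identical in both
tests.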