示例#1
0
def test_discrete_mean_q_function(feature_size, action_size, batch_size,
                                  gamma):
    """Exercise DiscreteMeanQFunction built on top of a DummyEncoder.

    The test checks, in order:
      * the forward output shape is (batch_size, action_size);
      * compute_target with actions returns a (batch_size, 1) tensor equal
        to the forward values gathered at those actions;
      * compute_target without actions returns (batch_size, action_size);
      * compute_error agrees with a NumPy reference built from
        filter_by_action and ref_huber_loss;
      * check_parameter_updates passes for the same loss inputs.
    """
    q_func = DiscreteMeanQFunction(DummyEncoder(feature_size), action_size)

    # forward pass: one value per (sample, action)
    observations = torch.rand(batch_size, feature_size)
    q_values = q_func(observations)
    assert q_values.shape == (batch_size, action_size)

    # compute_target with explicit actions picks the matching forward values
    actions = torch.randint(high=action_size, size=(batch_size, ))
    picked = q_func.compute_target(observations, actions)
    assert picked.shape == (batch_size, 1)
    rows = torch.arange(batch_size)
    assert torch.allclose(q_values[rows, actions], picked.view(-1))

    # compute_target without actions keeps the action axis
    all_targets = q_func.compute_target(observations)
    assert all_targets.shape == (batch_size, action_size)

    # NumPy reference for the TD target: r + gamma * q' * (1 - terminal)
    column = (batch_size, 1)
    q_tp1_np = np.random.random(column)
    rew_tp1_np = np.random.random(column)
    ter_tp1_np = np.random.randint(2, size=column)
    td_target = rew_tp1_np + gamma * q_tp1_np * (1 - ter_tp1_np)

    # reference loss from the values of the taken actions
    obs_t = torch.rand(batch_size, feature_size)
    act_t_np = np.random.randint(action_size, size=column)
    q_all_np = q_func(obs_t).detach().numpy()
    q_t = filter_by_action(q_all_np, act_t_np, action_size)
    ref_loss = ref_huber_loss(q_t.reshape((-1, 1)), td_target)

    # same inputs as tensors for the implementation under test
    act_t = torch.tensor(act_t_np, dtype=torch.int64)
    rew_tp1, q_tp1, ter_tp1 = (
        torch.tensor(array, dtype=torch.float32)
        for array in (rew_tp1_np, q_tp1_np, ter_tp1_np)
    )
    loss = q_func.compute_error(
        obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma=gamma
    )
    assert np.allclose(loss.detach().numpy(), ref_loss)

    # check layer connection
    loss_inputs = (obs_t, act_t, rew_tp1, q_tp1, ter_tp1)
    check_parameter_updates(q_func, loss_inputs)
示例#2
0
def test_ensemble_discrete_q_function(
    feature_size,
    action_size,
    batch_size,
    gamma,
    ensemble_size,
    q_func_type,
    n_quantiles,
    embed_size,
    bootstrap,
):
    """Exercise EnsembleDiscreteQFunction over several member Q-function types.

    An ensemble of ``ensemble_size`` members is built, each on its own
    DummyEncoder; ``q_func_type`` selects the member class ("mean", "qr",
    "iqn" or "fqf"). The test then checks:
      * the un-reduced forward output shape
        (ensemble_size, batch_size, action_size);
      * compute_target shapes with and without actions (for "mean" the
        target must equal the element-wise minimum over members);
      * the "min", "max" and "mean" reductions (skipped for "iqn");
      * the ensemble loss against the sum of member losses: it must differ
        when ``bootstrap`` is truthy and match otherwise (the equality
        check is skipped for "iqn");
      * check_parameter_updates passes for the same loss inputs.

    Raises:
        ValueError: if ``q_func_type`` is not one of the supported names.
    """
    q_funcs = []
    for _ in range(ensemble_size):
        encoder = DummyEncoder(feature_size)
        if q_func_type == "mean":
            q_func = DiscreteMeanQFunction(encoder, action_size)
        elif q_func_type == "qr":
            q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles)
        elif q_func_type == "iqn":
            q_func = DiscreteIQNQFunction(encoder, action_size, n_quantiles,
                                          n_quantiles, embed_size)
        elif q_func_type == "fqf":
            q_func = DiscreteFQFQFunction(encoder, action_size, n_quantiles,
                                          embed_size)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the append below when the parametrization is misspelled.
            raise ValueError(
                "unsupported q_func_type: {!r}".format(q_func_type))
        q_funcs.append(q_func)
    q_func = EnsembleDiscreteQFunction(q_funcs, bootstrap)

    # check output shape ("none" keeps the per-member axis)
    x = torch.rand(batch_size, feature_size)
    values = q_func(x, "none")
    assert values.shape == (ensemble_size, batch_size, action_size)

    # check compute_target
    action = torch.randint(high=action_size, size=(batch_size, ))
    target = q_func.compute_target(x, action)
    if q_func_type == "mean":
        assert target.shape == (batch_size, 1)
        # the target for the taken action is the minimum across members
        min_values = values.min(dim=0).values
        assert torch.allclose(min_values[torch.arange(batch_size), action],
                              target.view(-1))
    else:
        assert target.shape == (batch_size, n_quantiles)

    # check compute_target with action=None
    targets = q_func.compute_target(x)
    if q_func_type == "mean":
        assert targets.shape == (batch_size, action_size)
    else:
        assert targets.shape == (batch_size, action_size, n_quantiles)

    # check reductions against the un-reduced values computed above
    if q_func_type != "iqn":
        assert torch.allclose(values.min(dim=0).values, q_func(x, "min"))
        assert torch.allclose(values.max(dim=0).values, q_func(x, "max"))
        assert torch.allclose(values.mean(dim=0), q_func(x, "mean"))

    # check td computation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.randint(0,
                          action_size,
                          size=(batch_size, 1),
                          dtype=torch.int64)
    rew_tp1 = torch.rand(batch_size, 1)
    if q_func_type == "mean":
        q_tp1 = torch.rand(batch_size, 1)
    else:
        q_tp1 = torch.rand(batch_size, n_quantiles)
    # reference: plain sum of each member's own loss
    ref_td_sum = 0.0
    for i in range(ensemble_size):
        f = q_func.q_funcs[i]
        ref_td_sum += f.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    if bootstrap:
        # with bootstrapping the ensemble loss must not equal the plain sum
        assert not torch.allclose(ref_td_sum, loss)
    elif q_func_type != "iqn":
        assert torch.allclose(ref_td_sum, loss)

    # check layer connection
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1))