Example 1
def test_ensemble_continuous_q_function(feature_size, action_size, batch_size,
                                        gamma, ensemble_size, q_func_type,
                                        n_quantiles, embed_size, bootstrap):
    """Exercise EnsembleContinuousQFunction across all member Q-function types.

    Checks output shapes, compute_target, ensemble reductions, TD-error
    computation (with and without bootstrapped member masking), and that all
    parameters receive gradients.
    """
    # Build one member Q-function per ensemble slot; type chosen by fixture.
    q_funcs = []
    for _ in range(ensemble_size):
        encoder = DummyEncoder(feature_size, action_size, concat=True)
        if q_func_type == 'mean':
            q_func = ContinuousQFunction(encoder)
        elif q_func_type == 'qr':
            q_func = ContinuousQRQFunction(encoder, n_quantiles)
        elif q_func_type == 'iqn':
            q_func = ContinuousIQNQFunction(encoder, n_quantiles, embed_size)
        elif q_func_type == 'fqf':
            q_func = ContinuousFQFQFunction(encoder, n_quantiles, embed_size)
        q_funcs.append(q_func)

    q_func = EnsembleContinuousQFunction(q_funcs, bootstrap)

    # check output shape: reduction 'none' keeps one value per ensemble member
    x = torch.rand(batch_size, feature_size)
    action = torch.rand(batch_size, action_size)
    values = q_func(x, action, 'none')
    assert values.shape == (ensemble_size, batch_size, 1)

    # check compute_target
    target = q_func.compute_target(x, action)
    if q_func_type == 'mean':
        assert target.shape == (batch_size, 1)
        # for 'mean' the target is the element-wise minimum over members
        min_values = values.min(dim=0).values
        assert (target == min_values).all()
    else:
        # distributional variants return one target per quantile
        assert target.shape == (batch_size, n_quantiles)

    # check reductions
    # NOTE(review): iqn is skipped — presumably its sampled taus make repeated
    # forward passes stochastic, so outputs are not reproducible; confirm.
    if q_func_type != 'iqn':
        assert torch.allclose(values.min(dim=0)[0], q_func(x, action, 'min'))
        assert torch.allclose(values.max(dim=0)[0], q_func(x, action, 'max'))
        assert torch.allclose(values.mean(dim=0), q_func(x, action, 'mean'))

    # check td computation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.rand(batch_size, action_size)
    rew_tp1 = torch.rand(batch_size, 1)
    if q_func_type == 'mean':
        q_tp1 = torch.rand(batch_size, 1)
    else:
        q_tp1 = torch.rand(batch_size, n_quantiles)
    # sum the per-member errors as a reference for the ensemble error
    ref_td_sum = 0.0
    for i in range(ensemble_size):
        f = q_func.q_funcs[i]
        ref_td_sum += f.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    if bootstrap:
        # with bootstrap each member sees a masked subset of transitions,
        # so the ensemble loss must differ from the plain sum
        assert not torch.allclose(ref_td_sum, loss)
    elif q_func_type != 'iqn':
        assert torch.allclose(ref_td_sum, loss)

    # check layer connection: every parameter must receive a gradient
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1))
Example 2
def test_continuous_qr_q_function(feature_size, action_size, n_quantiles,
                                  batch_size, gamma):
    """QR Q-function: output shapes, target==quantiles, loss vs. NumPy reference."""
    encoder = DummyEncoder(feature_size, action_size, concat=True)
    q_func = ContinuousQRQFunction(encoder, n_quantiles)

    # forward pass yields one scalar value per sample
    obs = torch.rand(batch_size, feature_size)
    act = torch.rand(batch_size, action_size)
    value = q_func(obs, act)
    assert value.shape == (batch_size, 1)

    # compute_target must be exactly the raw quantile outputs
    target_out = q_func.compute_target(obs, act)
    quantile_out = q_func(obs, act, as_quantiles=True)
    assert target_out.shape == (batch_size, n_quantiles)
    assert (target_out == quantile_out).all()

    # quantile huber loss against a NumPy reference implementation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.rand(batch_size, action_size)
    rew_tp1 = torch.rand(batch_size, 1)
    q_tp1 = torch.rand(batch_size, n_quantiles)

    # unreduced loss keeps a per-sample shape
    elementwise = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1,
                                       reduction='none')
    assert elementwise.shape == (batch_size, 1)

    # default reduction: mean loss
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1)

    td_target = rew_tp1.numpy() + gamma * q_tp1.numpy()
    predicted = q_func(obs_t, act_t, as_quantiles=True).detach().numpy()
    taus = _make_taus_prime(n_quantiles, 'cpu:0').numpy()

    ref_loss = ref_quantile_huber_loss(
        predicted.reshape((batch_size, 1, -1)),
        td_target.reshape((batch_size, -1, 1)),
        taus.reshape((1, 1, -1)),
        n_quantiles,
    )
    assert np.allclose(loss.cpu().detach(), ref_loss.mean())

    # every parameter should receive gradients
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1))
Example 3
def test_continuous_qr_q_function(feature_size, action_size, n_quantiles,
                                  batch_size, gamma):
    """QR Q-function with termination flags: tau spacing, shapes, reference loss."""
    encoder = DummyEncoder(feature_size, action_size, concat=True)
    q_func = ContinuousQRQFunction(encoder, n_quantiles)

    # forward pass yields one scalar value per sample
    obs = torch.rand(batch_size, feature_size)
    act = torch.rand(batch_size, action_size)
    value = q_func(obs, act)
    assert value.shape == (batch_size, 1)

    # taus should sit at the quantile midpoints: i/N + 1/(2N)
    taus = q_func._make_taus(encoder(obs, act))
    step = 1 / n_quantiles
    for idx in range(n_quantiles):
        assert np.allclose(taus[0][idx].numpy(), idx * step + step / 2.0)

    target = q_func.compute_target(obs, act)
    assert target.shape == (batch_size, n_quantiles)

    # quantile huber loss against a NumPy reference implementation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.rand(batch_size, action_size)
    rew_tp1 = torch.rand(batch_size, 1)
    q_tp1 = torch.rand(batch_size, n_quantiles)
    ter_tp1 = torch.randint(2, size=(batch_size, 1))

    # unreduced loss keeps a per-sample shape
    elementwise = q_func.compute_error(obs_t,
                                       act_t,
                                       rew_tp1,
                                       q_tp1,
                                       ter_tp1,
                                       reduction="none")
    assert elementwise.shape == (batch_size, 1)

    # default reduction: mean loss
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1)

    # terminal transitions zero out the bootstrapped return
    target = rew_tp1.numpy() + gamma * q_tp1.numpy() * (1 - ter_tp1.numpy())
    predicted = q_func._compute_quantiles(encoder(obs_t, act_t),
                                          taus).detach().numpy()

    ref_loss = ref_quantile_huber_loss(
        predicted.reshape((batch_size, 1, -1)),
        target.reshape((batch_size, -1, 1)),
        taus.reshape((1, 1, -1)),
        n_quantiles,
    )
    assert np.allclose(loss.cpu().detach(), ref_loss.mean())

    # every parameter should receive gradients
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1))
Example 4
def test_ensemble_continuous_q_function(
    feature_size,
    action_size,
    batch_size,
    gamma,
    ensemble_size,
    q_func_factory,
    n_quantiles,
    embed_size,
    bootstrap,
    use_independent_target,
):
    """Exercise EnsembleContinuousQFunction over every Q-function factory.

    Covers output shapes, compute_target, ensemble reductions, TD error with
    shared vs. per-member targets, masked error computation, and that all
    parameters receive gradients.
    """
    # Build one member Q-function per ensemble slot; type chosen by factory.
    q_funcs = []
    for _ in range(ensemble_size):
        encoder = DummyEncoder(feature_size, action_size, concat=True)
        if q_func_factory == "mean":
            q_func = ContinuousMeanQFunction(encoder)
        elif q_func_factory == "qr":
            q_func = ContinuousQRQFunction(encoder, n_quantiles)
        elif q_func_factory == "iqn":
            q_func = ContinuousIQNQFunction(encoder, n_quantiles, n_quantiles,
                                            embed_size)
        elif q_func_factory == "fqf":
            q_func = ContinuousFQFQFunction(encoder, n_quantiles, embed_size)
        q_funcs.append(q_func)

    q_func = EnsembleContinuousQFunction(q_funcs, bootstrap)

    # check output shape: reduction "none" keeps one value per ensemble member
    x = torch.rand(batch_size, feature_size)
    action = torch.rand(batch_size, action_size)
    values = q_func(x, action, "none")
    assert values.shape == (ensemble_size, batch_size, 1)

    # check compute_target
    target = q_func.compute_target(x, action)
    if q_func_factory == "mean":
        assert target.shape == (batch_size, 1)
        # for "mean" the target is the element-wise minimum over members
        min_values = values.min(dim=0).values
        assert (target == min_values).all()
    else:
        # distributional variants return one target per quantile
        assert target.shape == (batch_size, n_quantiles)

    # check reductions
    # NOTE(review): iqn is skipped — presumably its sampled taus make repeated
    # forward passes stochastic, so outputs are not reproducible; confirm.
    if q_func_factory != "iqn":
        assert torch.allclose(values.min(dim=0)[0], q_func(x, action, "min"))
        assert torch.allclose(values.max(dim=0)[0], q_func(x, action, "max"))
        assert torch.allclose(values.mean(dim=0), q_func(x, action, "mean"))

    # check td computation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.rand(batch_size, action_size)
    rew_tp1 = torch.rand(batch_size, 1)
    ter_tp1 = torch.randint(2, size=(batch_size, 1))
    # an independent target carries a leading ensemble dimension (one target
    # per member); a shared target is broadcast to all members
    if q_func_factory == "mean":
        if use_independent_target:
            q_tp1 = torch.rand(ensemble_size, batch_size, 1)
        else:
            q_tp1 = torch.rand(batch_size, 1)
    else:
        if use_independent_target:
            q_tp1 = torch.rand(ensemble_size, batch_size, n_quantiles)
        else:
            q_tp1 = torch.rand(batch_size, n_quantiles)
    # sum the per-member errors as a reference for the ensemble error
    ref_td_sum = 0.0
    for i in range(ensemble_size):
        if use_independent_target:
            target = q_tp1[i]
        else:
            target = q_tp1
        f = q_func.q_funcs[i]
        ref_td_sum += f.compute_error(obs_t, act_t, rew_tp1, target, ter_tp1,
                                      gamma)
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma,
                                use_independent_target)
    if bootstrap:
        # with bootstrap each member sees a masked subset of transitions,
        # so the ensemble loss must differ from the plain sum
        assert not torch.allclose(ref_td_sum, loss)
    elif q_func_factory != "iqn":
        assert torch.allclose(ref_td_sum, loss)

    # with mask: only checks that a per-member mask is accepted, not its value
    mask = torch.rand(ensemble_size, batch_size, 1)
    loss = q_func.compute_error(
        obs_t,
        act_t,
        rew_tp1,
        q_tp1,
        ter_tp1,
        gamma,
        use_independent_target,
        mask,
    )

    # check layer connection: every parameter must receive a gradient
    check_parameter_updates(
        q_func,
        (obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma, use_independent_target),
    )