def test_discrete_qr_q_function(feature_size, action_size, n_quantiles, batch_size, gamma):
    """Check DiscreteQRQFunction outputs, taus, targets and quantile Huber loss.

    Covers the forward-pass output shape, the fixed taus returned by
    ``_make_taus``, the shapes returned by ``compute_target`` (with and without
    an action), and ``compute_error`` against ``ref_quantile_huber_loss``.

    Arguments are supplied by pytest (the parametrize decorators / fixtures are
    not visible in this chunk).

    NOTE(review): this module contains another function with this same name;
    Python keeps only the last definition, so pytest collects only one of them.
    """
    encoder = DummyEncoder(feature_size)
    q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles)

    # check output shape
    x = torch.rand(batch_size, feature_size)
    y = q_func(x)
    assert y.shape == (batch_size, action_size)

    # check taus: expected to be the midpoints (i + 0.5) / n_quantiles
    taus = q_func._make_taus(encoder(x))
    step = 1 / n_quantiles
    for i in range(n_quantiles):
        assert np.allclose(taus[0][i].numpy(), i * step + step / 2.0)

    # check compute_target
    action = torch.randint(high=action_size, size=(batch_size, ))
    target = q_func.compute_target(x, action)
    assert target.shape == (batch_size, n_quantiles)

    # check compute_target with action=None
    targets = q_func.compute_target(x)
    assert targets.shape == (batch_size, action_size, n_quantiles)

    # check quantile huber loss
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.randint(action_size, size=(batch_size, ))
    rew_tp1 = torch.rand(batch_size, 1)
    q_tp1 = torch.rand(batch_size, n_quantiles)
    ter_tp1 = torch.randint(2, size=(batch_size, 1))

    # shape check
    # gamma is passed explicitly so the loss uses the same discount as the
    # reference target below instead of compute_error's own default.
    loss = q_func.compute_error(
        obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma, reduction="none"
    )
    assert loss.shape == (batch_size, 1)

    # mean loss
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma)

    # reference TD target; the bootstrap term is zeroed where ter_tp1 == 1
    ref_target = rew_tp1.numpy() + gamma * q_tp1.numpy() * (1 - ter_tp1.numpy())
    picked = _pick_value_by_action(
        q_func._compute_quantiles(encoder(obs_t), taus), act_t
    )

    reshaped_target = np.reshape(ref_target, (batch_size, -1, 1))
    reshaped_y = np.reshape(picked.detach().numpy(), (batch_size, 1, -1))
    # convert to numpy so the reference loss is computed purely in numpy
    reshaped_taus = np.reshape(taus.numpy(), (1, 1, -1))

    ref_loss = ref_quantile_huber_loss(
        reshaped_y, reshaped_target, reshaped_taus, n_quantiles
    )
    assert np.allclose(loss.cpu().detach(), ref_loss.mean())

    # check layer connection
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1, ter_tp1))
def test_discrete_qr_q_function(feature_size, action_size, n_quantiles, batch_size, gamma):
    """Check DiscreteQRQFunction outputs, targets and quantile Huber loss.

    Covers the forward-pass output shape, that ``compute_target(x, action)``
    matches the per-action quantiles from ``q_func(x, as_quantiles=True)``, and
    ``compute_error`` against ``ref_quantile_huber_loss`` using the taus from
    ``_make_taus_prime``.

    Arguments are supplied by pytest (the parametrize decorators / fixtures are
    not visible in this chunk).

    NOTE(review): this module contains another function with this same name;
    Python keeps only the last definition, so pytest collects only one of them.
    """
    encoder = DummyEncoder(feature_size)
    q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles)

    # check output shape
    x = torch.rand(batch_size, feature_size)
    y = q_func(x)
    assert y.shape == (batch_size, action_size)

    # compute_target(x, action) should equal the quantiles of the chosen action
    action = torch.randint(high=action_size, size=(batch_size, ))
    target = q_func.compute_target(x, action)
    quantiles = q_func(x, as_quantiles=True)
    assert target.shape == (batch_size, n_quantiles)
    assert torch.allclose(quantiles[torch.arange(batch_size), action], target)

    # check quantile huber loss
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.randint(action_size, size=(batch_size, ))
    rew_tp1 = torch.rand(batch_size, 1)
    q_tp1 = torch.rand(batch_size, n_quantiles)

    # shape check
    # gamma is passed explicitly so the loss uses the same discount as the
    # reference target below instead of compute_error's own default.
    loss = q_func.compute_error(
        obs_t, act_t, rew_tp1, q_tp1, gamma, reduction='none'
    )
    assert loss.shape == (batch_size, 1)

    # mean loss
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)

    ref_target = rew_tp1.numpy() + gamma * q_tp1.numpy()
    picked = _pick_value_by_action(q_func(obs_t, as_quantiles=True), act_t)
    taus = _make_taus_prime(n_quantiles, 'cpu:0').numpy()

    reshaped_target = np.reshape(ref_target, (batch_size, -1, 1))
    reshaped_y = np.reshape(picked.detach().numpy(), (batch_size, 1, -1))
    reshaped_taus = np.reshape(taus, (1, 1, -1))

    ref_loss = ref_quantile_huber_loss(
        reshaped_y, reshaped_target, reshaped_taus, n_quantiles
    )
    assert np.allclose(loss.cpu().detach(), ref_loss.mean())

    # check layer connection
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1))
def test_ensemble_discrete_q_function(feature_size, action_size, batch_size, gamma,
                                      ensemble_size, q_func_type, n_quantiles,
                                      embed_size, bootstrap):
    """Check EnsembleDiscreteQFunction shapes, reductions and TD error.

    Builds ``ensemble_size`` members of the kind selected by ``q_func_type``
    ('mean', 'qr', 'iqn' or 'fqf'). It then checks the stacked output shape,
    ``compute_target``, the 'min'/'max'/'mean' reductions, and that
    ``compute_error`` matches the sum of the members' errors when bootstrap is
    off.

    Arguments are supplied by pytest (the parametrize decorators / fixtures are
    not visible in this chunk).

    NOTE(review): this module contains another function with this same name;
    Python keeps only the last definition, so pytest collects only one of them.
    """
    q_funcs = []
    for _ in range(ensemble_size):
        encoder = DummyEncoder(feature_size)
        if q_func_type == 'mean':
            q_func = DiscreteQFunction(encoder, action_size)
        elif q_func_type == 'qr':
            q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles)
        elif q_func_type == 'iqn':
            q_func = DiscreteIQNQFunction(encoder, action_size, n_quantiles,
                                          embed_size)
        elif q_func_type == 'fqf':
            q_func = DiscreteFQFQFunction(encoder, action_size, n_quantiles,
                                          embed_size)
        else:
            # Without this, an unknown type would silently reuse the previous
            # member (or hit a NameError on the first iteration).
            raise ValueError(f"unsupported q_func_type: {q_func_type!r}")
        q_funcs.append(q_func)
    q_func = EnsembleDiscreteQFunction(q_funcs, bootstrap)

    # check output shape
    x = torch.rand(batch_size, feature_size)
    values = q_func(x, 'none')
    assert values.shape == (ensemble_size, batch_size, action_size)

    # check compute_target
    action = torch.randint(high=action_size, size=(batch_size, ))
    target = q_func.compute_target(x, action)
    if q_func_type == 'mean':
        assert target.shape == (batch_size, 1)
        min_values = values.min(dim=0).values
        assert torch.allclose(min_values[torch.arange(batch_size), action],
                              target.view(-1))
    else:
        assert target.shape == (batch_size, n_quantiles)

    # check reductions
    # NOTE(review): iqn is skipped, presumably because its outputs differ
    # between calls (sampled taus) -- confirm.
    if q_func_type != 'iqn':
        assert torch.allclose(values.min(dim=0).values, q_func(x, 'min'))
        assert torch.allclose(values.max(dim=0).values, q_func(x, 'max'))
        assert torch.allclose(values.mean(dim=0), q_func(x, 'mean'))

    # check td computation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.randint(0, action_size, size=(batch_size, 1),
                          dtype=torch.int64)
    rew_tp1 = torch.rand(batch_size, 1)
    if q_func_type == 'mean':
        q_tp1 = torch.rand(batch_size, 1)
    else:
        q_tp1 = torch.rand(batch_size, n_quantiles)

    ref_td_sum = 0.0
    for member in q_func.q_funcs:
        ref_td_sum += member.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    loss = q_func.compute_error(obs_t, act_t, rew_tp1, q_tp1, gamma)
    if bootstrap:
        # with bootstrap the ensemble loss is expected to differ from the
        # plain per-member sum
        assert not torch.allclose(ref_td_sum, loss)
    elif q_func_type != 'iqn':
        assert torch.allclose(ref_td_sum, loss)

    # check layer connection
    check_parameter_updates(q_func, (obs_t, act_t, rew_tp1, q_tp1))
def test_ensemble_discrete_q_function(
    feature_size,
    action_size,
    batch_size,
    gamma,
    ensemble_size,
    q_func_factory,
    n_quantiles,
    embed_size,
    bootstrap,
    use_independent_target,
):
    """Check EnsembleDiscreteQFunction shapes, reductions and TD error.

    Builds ``ensemble_size`` members of the kind selected by ``q_func_factory``
    ("mean", "qr", "iqn" or "fqf"). It then checks:

    - the stacked output shape;
    - ``compute_target`` with and without an action;
    - the "min"/"max"/"mean" reductions;
    - that ``compute_error`` matches the sum of the members' errors when
      bootstrap is off.

    When ``use_independent_target`` is set, each member gets its own slice
    ``q_tp1[i]`` of a stacked target.

    Arguments are supplied by pytest (the parametrize decorators / fixtures are
    not visible in this chunk).

    NOTE(review): this module contains another function with this same name;
    Python keeps only the last definition, so pytest collects only one of them.
    """
    q_funcs = []
    for _ in range(ensemble_size):
        encoder = DummyEncoder(feature_size)
        if q_func_factory == "mean":
            q_func = DiscreteMeanQFunction(encoder, action_size)
        elif q_func_factory == "qr":
            q_func = DiscreteQRQFunction(encoder, action_size, n_quantiles)
        elif q_func_factory == "iqn":
            q_func = DiscreteIQNQFunction(
                encoder, action_size, n_quantiles, n_quantiles, embed_size
            )
        elif q_func_factory == "fqf":
            q_func = DiscreteFQFQFunction(
                encoder, action_size, n_quantiles, embed_size
            )
        else:
            # Without this, an unknown factory would silently reuse the
            # previous member (or hit a NameError on the first iteration).
            raise ValueError(f"unsupported q_func_factory: {q_func_factory!r}")
        q_funcs.append(q_func)
    q_func = EnsembleDiscreteQFunction(q_funcs, bootstrap)

    # check output shape
    x = torch.rand(batch_size, feature_size)
    values = q_func(x, "none")
    assert values.shape == (ensemble_size, batch_size, action_size)

    # check compute_target
    action = torch.randint(high=action_size, size=(batch_size, ))
    target = q_func.compute_target(x, action)
    if q_func_factory == "mean":
        assert target.shape == (batch_size, 1)
        min_values = values.min(dim=0).values
        assert torch.allclose(
            min_values[torch.arange(batch_size), action], target.view(-1)
        )
    else:
        assert target.shape == (batch_size, n_quantiles)

    # check compute_target with action=None
    targets = q_func.compute_target(x)
    if q_func_factory == "mean":
        assert targets.shape == (batch_size, action_size)
    else:
        assert targets.shape == (batch_size, action_size, n_quantiles)

    # check reductions
    # NOTE(review): iqn is skipped, presumably because its outputs differ
    # between calls (sampled taus) -- confirm.
    if q_func_factory != "iqn":
        assert torch.allclose(values.min(dim=0).values, q_func(x, "min"))
        assert torch.allclose(values.max(dim=0).values, q_func(x, "max"))
        assert torch.allclose(values.mean(dim=0), q_func(x, "mean"))

    # check td computation
    obs_t = torch.rand(batch_size, feature_size)
    act_t = torch.randint(
        0, action_size, size=(batch_size, 1), dtype=torch.int64
    )
    rew_tp1 = torch.rand(batch_size, 1)
    ter_tp1 = torch.randint(2, size=(batch_size, 1))
    # with independent targets, q_tp1 gains a leading ensemble dimension
    target_width = 1 if q_func_factory == "mean" else n_quantiles
    if use_independent_target:
        q_tp1 = torch.rand(ensemble_size, batch_size, target_width)
    else:
        q_tp1 = torch.rand(batch_size, target_width)

    ref_td_sum = 0.0
    for i in range(ensemble_size):
        member = q_func.q_funcs[i]
        member_q_tp1 = q_tp1[i] if use_independent_target else q_tp1
        ref_td_sum += member.compute_error(
            obs_t, act_t, rew_tp1, member_q_tp1, ter_tp1, gamma
        )
    loss = q_func.compute_error(
        obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma, use_independent_target
    )
    if bootstrap:
        # with bootstrap the ensemble loss is expected to differ from the
        # plain per-member sum
        assert not torch.allclose(ref_td_sum, loss)
    elif q_func_factory != "iqn":
        assert torch.allclose(ref_td_sum, loss)

    # with mask
    # NOTE(review): the masked loss is computed but never asserted, so this
    # only checks that the masked call runs; consider pinning its value.
    mask = torch.rand(ensemble_size, batch_size, 1)
    loss = q_func.compute_error(
        obs_t,
        act_t,
        rew_tp1,
        q_tp1,
        ter_tp1,
        gamma,
        use_independent_target,
        mask,
    )

    # check layer connection
    check_parameter_updates(
        q_func,
        (obs_t, act_t, rew_tp1, q_tp1, ter_tp1, gamma, use_independent_target),
    )