Exemplo n.º 1
0
    def test_max_bounded(self):
        """Check ``QuadraticActionValue.max`` when actions are box-bounded.

        The test builds an identity ``mat`` per batch element and bounds of
        ``(-1.3, 1.3)``. It asserts two things:
        ``max`` equals ``v`` for every row whose ``mu`` lies strictly inside
        the bounds, and ``max`` is strictly smaller than ``v`` for every row
        whose ``mu`` lies clearly outside them.
        """
        n_batch = 20
        ndim_action = 3
        mu = np.random.randn(n_batch, ndim_action).astype(np.float32)
        # One identity matrix per batch element: shape (n_batch, ndim, ndim).
        mat = np.broadcast_to(
            np.eye(ndim_action, dtype=np.float32)[None],
            (n_batch, ndim_action, ndim_action),
        )
        v = np.random.randn(n_batch).astype(np.float32)
        min_action, max_action = -1.3, 1.3
        q_out = action_value.QuadraticActionValue(torch.tensor(mu),
                                                  torch.tensor(mat),
                                                  torch.tensor(v), min_action,
                                                  max_action)

        v_out = q_out.max
        self.assertIsInstance(v_out, torch.Tensor)
        v_out = v_out.detach().numpy()

        # If mu[i] is a valid action, v_out[i] should be v[i].
        # Use logical `&` (not arithmetic `*`) to combine the boolean masks.
        mu_is_allowed = np.all((min_action < mu) & (mu < max_action), axis=1)
        np.testing.assert_almost_equal(v_out[mu_is_allowed], v[mu_is_allowed])

        # Otherwise, v_out[i] should be less than v[i].
        # Rows whose mu lies within `margin` of a bound are checked by
        # neither assertion. Presumably this is because the expected drop
        # below v is tiny there, and float32 rounding could make the
        # strict comparison flaky.
        margin = 1e-2
        mu_is_not_allowed = ~np.all(
            (min_action - margin < mu) & (mu < max_action + margin), axis=1)
        np.testing.assert_array_less(v_out[mu_is_not_allowed],
                                     v[mu_is_not_allowed])
Exemplo n.º 2
0
    def test_max_unbounded(self):
        """When no action bounds are passed, ``max`` should match ``v``."""
        batch_size, action_dim = 7, 3
        mu = np.random.randn(batch_size, action_dim).astype(np.float32)
        identity = np.eye(action_dim, dtype=np.float32)
        # Repeat the identity once per batch entry as a broadcast view.
        mat = np.broadcast_to(identity[np.newaxis],
                              (batch_size, action_dim, action_dim))
        v = np.random.randn(batch_size).astype(np.float32)

        mu_t, mat_t, v_t = torch.tensor(mu), torch.tensor(mat), torch.tensor(v)
        q_out = action_value.QuadraticActionValue(mu_t, mat_t, v_t)

        max_value = q_out.max
        self.assertIsInstance(max_value, torch.Tensor)
        np.testing.assert_almost_equal(max_value.detach().numpy(), v)
Exemplo n.º 3
0
 def test_getitem(self):
     """Slicing a QuadraticActionValue should slice each of its components."""
     n_batch, ndim_action = 7, 3
     mu = np.random.randn(n_batch, ndim_action).astype(np.float32)
     mat = np.broadcast_to(
         np.eye(ndim_action, dtype=np.float32)[np.newaxis],
         (n_batch, ndim_action, ndim_action),
     )
     v = np.random.randn(n_batch).astype(np.float32)
     lo, hi = -1, 1
     qout = action_value.QuadraticActionValue(
         torch.tensor(mu), torch.tensor(mat), torch.tensor(v), lo, hi,
     )

     head = slice(None, 3)
     sliced = qout[head]
     # Each per-batch component should be truncated to the first three rows.
     for name, full in (("mu", mu), ("mat", mat), ("v", v)):
         torch_assert_allclose(getattr(sliced, name), full[head])
     # The first element of each sliced bound should equal the scalar bound
     # that was passed to the constructor.
     torch_assert_allclose(sliced.min_action[0], lo)
     torch_assert_allclose(sliced.max_action[0], hi)