Esempio n. 1
0
    def test_max_bounded(self):
        """Check ``max`` of a QuadraticActionValue built with action bounds.

        For batch rows whose ``mu`` lies strictly inside the bounds, the
        maximum must match ``v``; for rows whose ``mu`` is outside the
        bounds by at least a small margin, the maximum must be strictly
        smaller than ``v``.
        """
        batch_size = 20
        action_size = 3
        low, high = -1.3, 1.3
        # Rows within this distance of a bound are excluded from the
        # strict "less than" check below.
        margin = 1e-2

        # mu is drawn before v, so the sequence of random draws is fixed.
        mu = np.random.randn(batch_size, action_size).astype(np.float32)
        identity = np.eye(action_size, dtype=np.float32)
        mat = np.broadcast_to(
            identity[np.newaxis], (batch_size, action_size, action_size))
        v = np.random.randn(batch_size).astype(np.float32)

        q_out = action_value.QuadraticActionValue(
            chainer.Variable(mu),
            chainer.Variable(mat),
            chainer.Variable(v),
            low,
            high)

        max_var = q_out.max
        self.assertIsInstance(max_var, chainer.Variable)
        max_values = max_var.array

        # Where mu[i] is a valid action, the maximum should equal v[i].
        inside = ((mu > low) & (mu < high)).all(axis=1)
        np.testing.assert_almost_equal(max_values[inside], v[inside])

        # Where mu[i] is clearly invalid, the maximum should be below v[i].
        near_or_inside = (
            (mu > low - margin) & (mu < high + margin)).all(axis=1)
        outside = ~near_or_inside
        np.testing.assert_array_less(max_values[outside], v[outside])
Esempio n. 2
0
    def test_max_unbounded(self):
        """Check ``max`` of a QuadraticActionValue built without bounds.

        With no action bounds passed, the maximum is expected to be a
        ``chainer.Variable`` whose values match ``v`` for every batch row.
        """
        batch_size = 7
        action_size = 3

        # mu is drawn before v, so the sequence of random draws is fixed.
        mu = np.random.randn(batch_size, action_size).astype(np.float32)
        identity = np.eye(action_size, dtype=np.float32)
        mat = np.broadcast_to(
            identity[np.newaxis], (batch_size, action_size, action_size))
        v = np.random.randn(batch_size).astype(np.float32)

        q_out = action_value.QuadraticActionValue(
            chainer.Variable(mu),
            chainer.Variable(mat),
            chainer.Variable(v))

        max_var = q_out.max
        self.assertIsInstance(max_var, chainer.Variable)

        np.testing.assert_almost_equal(max_var.array, v)
Esempio n. 3
0
 def test_getitem(self):
     """Check that slicing a QuadraticActionValue slices its components.

     After taking the first three batch rows, ``mu``, ``mat`` and ``v``
     must equal the matching slices of the inputs, and the action bounds
     must be carried over unchanged.
     """
     batch_size = 7
     action_size = 3
     low, high = -1, 1
     head = slice(None, 3)

     # mu is drawn before v, so the sequence of random draws is fixed.
     mu = np.random.randn(batch_size, action_size).astype(np.float32)
     identity = np.eye(action_size, dtype=np.float32)
     mat = np.broadcast_to(
         identity[np.newaxis], (batch_size, action_size, action_size))
     v = np.random.randn(batch_size).astype(np.float32)

     qout = action_value.QuadraticActionValue(
         chainer.Variable(mu),
         chainer.Variable(mat),
         chainer.Variable(v),
         low,
         high,
     )
     sliced = qout[head]

     # Array-valued components are compared against the sliced inputs.
     expected_arrays = (("mu", mu[head]), ("mat", mat[head]), ("v", v[head]))
     for attr, expected in expected_arrays:
         np.testing.assert_equal(getattr(sliced, attr).array, expected)

     # The bounds are compared against the values passed in.
     expected_bounds = (("min_action", low), ("max_action", high))
     for attr, expected in expected_bounds:
         np.testing.assert_equal(getattr(sliced, attr), expected)