def test_max_bounded(self):
    # Checks `QuadraticActionValue.max` when action bounds are supplied:
    # rows whose mu lies strictly inside the bounds must yield v, and rows
    # whose mu lies clearly outside must yield something smaller than v.
    batch_size = 20
    action_size = 3
    lower, upper = -1.3, 1.3
    margin = 1e-2

    # Keep the draw order (mu first, then v) so results under a seeded
    # global RNG are unchanged.
    mu = np.random.randn(batch_size, action_size).astype(np.float32)
    identity = np.eye(action_size, dtype=np.float32)
    # One identity matrix per batch row.
    mat = np.broadcast_to(
        identity[np.newaxis], (batch_size, action_size, action_size))
    v = np.random.randn(batch_size).astype(np.float32)

    q_out = action_value.QuadraticActionValue(
        chainer.Variable(mu),
        chainer.Variable(mat),
        chainer.Variable(v),
        lower,
        upper,
    )

    max_var = q_out.max
    self.assertIsInstance(max_var, chainer.Variable)
    max_values = max_var.array

    # If mu[i] is a valid action, max_values[i] should be v[i].
    inside = np.all((lower < mu) & (mu < upper), axis=1)
    np.testing.assert_almost_equal(max_values[inside], v[inside])

    # Otherwise, max_values[i] should be less than v[i]. Rows whose mu is
    # within `margin` of a bound are left out of this strict comparison
    # (presumably to keep near-boundary rows from making it flaky).
    within_margin = np.all(
        (lower - margin < mu) & (mu < upper + margin), axis=1)
    outside = ~within_margin
    np.testing.assert_array_less(max_values[outside], v[outside])
def test_max_unbounded(self):
    # Checks `QuadraticActionValue.max` when no action bounds are passed:
    # the result must be a chainer.Variable whose values match v.
    batch_size = 7
    action_size = 3

    # Keep the draw order (mu first, then v) so results under a seeded
    # global RNG are unchanged.
    mu = np.random.randn(batch_size, action_size).astype(np.float32)
    identity = np.eye(action_size, dtype=np.float32)
    # One identity matrix per batch row.
    mat = np.broadcast_to(
        identity[np.newaxis], (batch_size, action_size, action_size))
    v = np.random.randn(batch_size).astype(np.float32)

    q_out = action_value.QuadraticActionValue(
        chainer.Variable(mu),
        chainer.Variable(mat),
        chainer.Variable(v),
    )

    max_var = q_out.max
    self.assertIsInstance(max_var, chainer.Variable)
    np.testing.assert_almost_equal(max_var.array, v)
def test_getitem(self):
    # Checks that slicing a QuadraticActionValue slices mu, mat and v along
    # the batch axis and carries the action bounds over unchanged.
    batch_size = 7
    action_size = 3
    lower, upper = -1, 1

    # Keep the draw order (mu first, then v) so results under a seeded
    # global RNG are unchanged.
    mu = np.random.randn(batch_size, action_size).astype(np.float32)
    identity = np.eye(action_size, dtype=np.float32)
    # One identity matrix per batch row.
    mat = np.broadcast_to(
        identity[np.newaxis], (batch_size, action_size, action_size))
    v = np.random.randn(batch_size).astype(np.float32)

    qout = action_value.QuadraticActionValue(
        chainer.Variable(mu),
        chainer.Variable(mat),
        chainer.Variable(v),
        lower,
        upper,
    )

    # Same slice object for the object under test and the expected arrays.
    head = slice(None, 3)
    sliced = qout[head]

    np.testing.assert_equal(sliced.mu.array, mu[head])
    np.testing.assert_equal(sliced.mat.array, mat[head])
    np.testing.assert_equal(sliced.v.array, v[head])
    np.testing.assert_equal(sliced.min_action, lower)
    np.testing.assert_equal(sliced.max_action, upper)