def testTrajectoryCyclicIntegerCounter(self):
    """Roll out a scalar counter that wraps around modulo `num_states`.

    The state starts at 0; each step adds the control and reduces the sum
    modulo `num_states`. Two constant control sequences are checked:
    all ones, where step `k` is expected to hold `k % num_states`, and all
    twos, where the expected states are the running sums of 2 reduced
    modulo `num_states`, prefixed by the initial 0.
    """
    # NOTE(review): this chunk contains a second
    # `def testTrajectoryCyclicIntegerCounter`; if both live in the same class
    # body, the later definition replaces this one and it is never collected.
    num_states = 3
    horizon = 10

    def dynamics(t, x, u):
        # Time-invariant step: add the control, then wrap modulo num_states.
        return (x + u) % num_states

    def rollout(controls):
        return control.trajectory(dynamics, controls, jnp.zeros(1))

    # Unit increments: 0, 1, 2, 0, 1, ...
    states = rollout(jnp.ones((horizon, 1)))
    want = (jnp.arange(horizon + 1) % num_states)[:, None]
    np.testing.assert_allclose(states, want)

    # Increments of 2: running sums 2, 4, 6, ... taken modulo num_states.
    states = rollout(2 * jnp.ones((horizon, 1)))
    wrapped_sums = jnp.cumsum(2 * jnp.ones(horizon)) % num_states
    want = jnp.concatenate((jnp.zeros(1), wrapped_sums))[:, None]
    np.testing.assert_allclose(states, want)
def testTrajectoryCyclicIntegerCounter(self):
    """Roll out a scalar counter that wraps around modulo `num_states`.

    The state starts at 0; each step adds the control and reduces the sum
    modulo `num_states`. With all-ones controls step `k` is expected to
    hold `k % num_states`; with all-twos controls the expected states are
    the running sums of 2 reduced modulo `num_states`, prefixed by 0.
    Dtypes are deliberately not compared (the reference values are built
    from integer and float helpers).
    """
    num_states = 3
    horizon = 10

    def dynamics(t, x, u):
        # Time-invariant step: add the control, then wrap modulo num_states.
        return (x + u) % num_states

    def rollout(controls):
        return control.trajectory(dynamics, controls, np.zeros(1))

    # Unit increments: 0, 1, 2, 0, 1, ...
    states = rollout(np.ones((horizon, 1)))
    want = (np.arange(horizon + 1) % num_states)[:, None]
    self.assertAllClose(states, want, check_dtypes=False)

    # Increments of 2: running sums 2, 4, 6, ... taken modulo num_states.
    states = rollout(2 * np.ones((horizon, 1)))
    wrapped_sums = np.cumsum(2 * np.ones(horizon)) % num_states
    want = np.concatenate((np.zeros(1), wrapped_sums))[:, None]
    self.assertAllClose(states, want, check_dtypes=False)
def testTrajectoryTimeVarying(self):
    """Roll out dynamics that stay switched off until the step index passes `delay`.

    The step is `(x + u) * gate(t)` with `gate(t) = min(1, max(0, t - delay))`,
    which for integer step indices is 0 while `t <= delay` and 1 afterwards.
    With unit controls over `2 * delay` steps and a zero initial state, the
    expected trajectory is `delay + 1` zeros followed by `0, 1, ..., delay - 1`.
    """
    # NOTE(review): this chunk contains a second `def testTrajectoryTimeVarying`;
    # if both live in the same class body, the later definition replaces this
    # one and it is never collected.
    delay = 6

    def dynamics(t, x, u):
        # computes `(x + u) if t > delay else 0`
        gate = jnp.minimum(1, jnp.maximum(0, t - delay))
        return (x + u) * gate

    states = control.trajectory(dynamics, jnp.ones((2 * delay, 1)), jnp.zeros(1))
    want = jnp.concatenate((jnp.zeros(delay + 1), jnp.arange(delay)))[:, None]
    np.testing.assert_allclose(states, want)
def testTrajectoryTimeVarying(self):
    """Roll out dynamics that stay switched off until the step index passes `delay`.

    The step is `(x + u) * gate(t)` with `gate(t) = min(1, max(0, t - delay))`,
    which for integer step indices is 0 while `t <= delay` and 1 afterwards.
    With unit controls over `2 * delay` steps and a zero initial state, the
    expected trajectory is `delay + 1` zeros followed by `0, 1, ..., delay - 1`.
    Dtypes are compared as well as values.
    """
    delay = 6

    def dynamics(t, x, u):
        # computes `(x + u) if t > delay else 0`
        gate = np.minimum(1, np.maximum(0, t - delay))
        return (x + u) * gate

    states = control.trajectory(dynamics, np.ones((2 * delay, 1)), np.zeros(1))
    want = np.concatenate((np.zeros(delay + 1), np.arange(delay)))[:, None]
    self.assertAllClose(states, want, check_dtypes=True)
def testTrajectoryCyclicIndicator(self):
    """Roll out a one-hot state that cycles through the standard basis.

    The state is a one-hot vector of length `num_states`. Each step decodes
    its index, adds the integer control, wraps modulo `num_states`, and
    re-encodes the result as a row of the identity matrix. Starting from
    `e_0` with all-ones controls, step `k` is expected to hold
    `e_{k % num_states}`.
    """
    num_states = 3

    def position(x):
        '''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
        # For e_k, cumsum is 0 before index k and 1 from k on, so `1 - cumsum`
        # has ones exactly at indices 0..k-1 and its sum is k.
        x = jnp.cumsum(x)
        x = 1 - x
        return jnp.sum(x, dtype=jnp.int32)

    def dynamics(t, x, u):
        '''advances the basis-vector index by u[0], wrapping modulo num_states'''
        idx = (position(x) + u[0]) % num_states
        return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]

    T = 8
    U = jnp.ones((T, 1), dtype=jnp.int32)
    X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
    # Derive the expected rows from T instead of hard-coding the number of
    # cycles, so the expectation stays valid if T or num_states change.
    expected = jnp.eye(num_states)[jnp.arange(T + 1) % num_states]
    self.assertAllClose(X, expected, check_dtypes=True)