Example #1
0
    def testTrajectoryCyclicIntegerCounter(self):
        """A counter modulo 3 wraps around as constant inputs of 1 and 2 accumulate."""
        num_states = 3
        horizon = 10
        x_init = jnp.zeros(1)

        def dynamics(t, x, u):
            # Time is ignored: advance the counter by the input and wrap it
            # back into [0, num_states).
            return (x + u) % num_states

        # Unit steps: the state visits 0, 1, 2, 0, 1, ... in order.
        unit_inputs = jnp.ones((horizon, 1))
        states = control.trajectory(dynamics, unit_inputs, x_init)
        want = (jnp.arange(horizon + 1) % num_states).reshape(horizon + 1, 1)
        np.testing.assert_allclose(states, want)

        # Double steps: expected states are the running sum of the inputs,
        # reduced modulo num_states, preceded by the initial zero state.
        double_inputs = 2 * jnp.ones((horizon, 1))
        states = control.trajectory(dynamics, double_inputs, x_init)
        running = jnp.cumsum(2 * jnp.ones(horizon)) % num_states
        want = jnp.concatenate((jnp.zeros(1), running)).reshape(horizon + 1, 1)
        np.testing.assert_allclose(states, want)
Example #2
0
    def testTrajectoryCyclicIntegerCounter(self):
        """Constant inputs of 1 and 2 drive a mod-3 counter around its cycle."""
        num_states = 3
        T = 10

        def dynamics(t, x, u):
            # Time is unused; the state is advanced by u and wrapped modulo num_states.
            return (x + u) % num_states

        # Each case pairs a constant input sequence with the flat sequence of
        # states it should produce (the initial zero state included).
        cases = (
            (np.ones((T, 1)), np.arange(T + 1) % num_states),
            (2 * np.ones((T, 1)),
             np.concatenate((np.zeros(1), np.cumsum(2 * np.ones(T)) % num_states))),
        )
        for inputs, flat_expected in cases:
            states = control.trajectory(dynamics, inputs, np.zeros(1))
            self.assertAllClose(states, np.reshape(flat_expected, (T + 1, 1)),
                                check_dtypes=False)
Example #3
0
    def testTrajectoryTimeVarying(self):
        """Time-gated dynamics hold the state at zero, then count up once t > T."""
        T = 6
        steps = 2 * T

        def dynamics(t, x, u):
            # The gate is t - T clamped into [0, 1]. For integer t it is 0
            # while t <= T and 1 afterwards, i.e. `(x + u) if t > T else 0`.
            gate = jnp.minimum(1, jnp.maximum(0, t - T))
            return (x + u) * gate

        inputs = jnp.ones((steps, 1))
        states = control.trajectory(dynamics, inputs, jnp.zeros(1))
        flat = jnp.concatenate((jnp.zeros(T + 1), jnp.arange(T)))
        np.testing.assert_allclose(states, flat.reshape(steps + 1, 1))
Example #4
0
    def testTrajectoryTimeVarying(self):
        """State is held at zero through step T and accumulates inputs afterwards."""
        T = 6
        horizon = 2 * T

        def clip(value, lower, upper):
            # Clamp `value` elementwise into [lower, upper].
            return np.minimum(upper, np.maximum(lower, value))

        def dynamics(t, x, u):
            # computes `(x + u) if t > T else 0`
            return (x + u) * clip(t - T, 0, 1)

        X = control.trajectory(dynamics, np.ones((horizon, 1)), np.zeros(1))

        # T + 1 zeros for the gated-off phase, then a ramp 0, 1, ..., T - 1.
        zero_phase = np.zeros(T + 1)
        ramp_phase = np.arange(T)
        expected = np.reshape(np.concatenate((zero_phase, ramp_phase)),
                              (horizon + 1, 1))
        self.assertAllClose(X, expected, check_dtypes=True)
Example #5
0
  def testTrajectoryCyclicIndicator(self):
    '''A one-hot state steps cyclically through the standard basis vectors.'''
    num_states = 3

    def position(x):
      '''finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1'''
      # For a one-hot x, cumsum(x) is 0 before the hot entry and 1 from it
      # onwards, so summing 1 - cumsum(x) counts the entries preceding it.
      x = jnp.cumsum(x)
      x = 1 - x
      return jnp.sum(x, dtype=jnp.int32)

    def dynamics(t, x, u):
      '''moves to the next standard basis vector, advancing by u[0] cyclically'''
      idx = (position(x) + u[0]) % num_states
      return lax.dynamic_slice_in_dim(jnp.eye(num_states), idx, 1)[0]

    T = 8

    U = jnp.ones((T, 1), dtype=jnp.int32)
    X = control.trajectory(dynamics, U, jnp.eye(num_states, dtype=jnp.int32)[0])
    # Starting from e_0 with unit steps, row t of the trajectory is
    # e_{t mod num_states}. Derive this from T and num_states instead of
    # hard-coding a count of stacked identity blocks (which only matched T = 8).
    expected = jnp.eye(num_states)[jnp.arange(T + 1) % num_states]
    self.assertAllClose(X, expected, check_dtypes=True)