def generate_lqr_demos(xs, x_goal, lqr):
  """Label states with the actions of a discrete-time LQR controller.

  Every leaf of ``lqr`` is first copied into a host NumPy array and cast to
  float64. The result is handed to ``solve_discrete_lqr`` to obtain a gain
  matrix. ``util.policy(gain, x_goal)`` is then lifted with the
  ``"(i),()->(j)"`` signature (one state vector plus a scalar time index in,
  one action vector out) and evaluated on ``xs`` at time index 0.

  Args:
    xs: States to label. Presumably shaped ``(..., state_dim)`` so that the
      core ``(i)`` dimension is the trailing axis -- TODO confirm with callers.
    x_goal: Target state forwarded to ``util.policy``.
    lqr: Pytree describing the LQR problem. Its leaves are cast to float64
      before solving.

  Returns:
    ``Demos(xs=xs, us=...)``, where ``us`` holds the policy's action for each
    state in ``xs``.
  """
  def _as_float64(leaf):
    # Make a NumPy copy of the leaf, then cast it to float64 before the solver
    # sees it.
    return onp.array(leaf).astype(np.float64)

  lqr_f64 = tree_util.tree_map(_as_float64, lqr)
  gain = solve_discrete_lqr(lqr_f64)

  vectorizer = vectorize.vectorize("(i),()->(j)")
  batched_policy = vectorizer(util.policy(gain, x_goal))

  # A single scalar int32 time index (0) broadcast against every state.
  time_index = np.zeros((), dtype=np.int32)
  actions = batched_policy(xs, time_index)
  return Demos(xs=xs, us=actions)
def batch_loss(params, data):
  """Score the policy induced by ``params`` against the recorded actions.

  ``get_lqr(params)`` supplies a matrix and an LQR problem. From these,
  ``discrete.gain_matrix`` derives a feedback gain. The policy
  ``util.policy(gain, x_goal)`` is lifted with the ``"(i),()->(j)"``
  signature and applied to ``data.xs`` at time index 0. The predicted actions
  are compared with ``data.us`` through ``loss``.

  NOTE(review): ``x_goal``, ``get_lqr`` and ``loss`` come from an enclosing
  or module scope that is not shown here.

  Args:
    params: Parameters accepted by ``get_lqr``.
    data: Object exposing ``xs`` (states) and ``us`` (target actions), e.g. a
      ``Demos`` record.

  Returns:
    Whatever ``loss(data.us, predicted_actions)`` returns.
  """
  p_mat, lqr = get_lqr(params)
  gain = discrete.gain_matrix(p_mat, lqr)

  vectorizer = vectorize.vectorize("(i),()->(j)")
  batched_policy = vectorizer(util.policy(gain, x_goal))

  # Evaluate every state at the same scalar int32 time index (0).
  predicted_actions = batched_policy(data.xs, np.zeros((), dtype=np.int32))
  return loss(data.us, predicted_actions)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for Vectorize library.""" from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp from jax import test_util as jtu from jax.experimental.vectorize import vectorize from jax.config import config config.parse_flags_with_absl() matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot) matvec = vectorize('(n,m),(m)->(n)')(jnp.dot) vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot) vecvec = vectorize('(m),(m)->()')(jnp.dot) @vectorize('(n)->()') def magnitude(x): return jnp.dot(x, x) mean = vectorize('(n)->()')(jnp.mean) @vectorize('()->(n)') def stack_plus_minus(x):