예제 #1
0
def generate_lqr_demos(xs, x_goal, lqr):
    """Label a batch of states with the actions of the LQR controller.

    Args:
      xs: batch of states at which to query the controller; also returned
        unchanged as ``Demos.xs``.
      x_goal: goal state handed to ``util.policy``.
      lqr: pytree describing the LQR problem; every leaf is converted to a
        float64 array before solving.

    Returns:
      ``Demos(xs=xs, us=...)``, where ``us`` holds the policy output for each
      state in ``xs``, evaluated with a scalar int32 zero as the second input.
    """
    def _as_float64(leaf):
        # Presumably host NumPy (``onp``) so the solve runs in double precision.
        return onp.array(leaf).astype(np.float64)

    lqr_f64 = tree_util.tree_map(_as_float64, lqr)
    # Named kmat upstream; presumably the feedback gain matrix — verify.
    gain = solve_discrete_lqr(lqr_f64)

    # '(i),()->(j)': one length-i state plus one scalar in, one length-j
    # action out; extra leading dimensions of ``xs`` are mapped over.
    batched_policy = vectorize.vectorize("(i),()->(j)")(util.policy(gain, x_goal))

    time_index = np.zeros((), dtype=np.int32)
    actions = batched_policy(xs, time_index)
    return Demos(xs=xs, us=actions)
예제 #2
0
    def batch_loss(params, data):
        """Score the LQR policy implied by ``params`` against demonstrated actions.

        ``get_lqr``, ``x_goal`` and ``loss`` are captured from the enclosing
        scope (not visible here).

        Args:
          params: parameters from which ``get_lqr`` builds ``(pmat, lqr)``.
          data: batch exposing ``xs`` (states) and ``us`` (target actions).

        Returns:
          ``loss(data.us, predicted_actions)``.
        """
        pmat, lqr = get_lqr(params)
        gain = discrete.gain_matrix(pmat, lqr)

        # Map the single-state policy over the batch: '(i),()->(j)' takes one
        # state vector and one scalar, and returns one action vector.
        batched_policy = vectorize.vectorize("(i),()->(j)")(util.policy(gain, x_goal))
        step = np.zeros((), dtype=np.int32)
        predicted_actions = batched_policy(data.xs, step)

        return loss(data.us, predicted_actions)
예제 #3
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Vectorize library."""

from absl.testing import absltest
from absl.testing import parameterized

from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental.vectorize import vectorize

from jax.config import config
# Parse JAX's configuration flags through absl at import time, before any
# test in this module runs.
config.parse_flags_with_absl()

# gufunc-style wrappers around jnp.dot, one per operand-rank combination.
# Each signature names the core dimensions; extra leading dimensions on the
# inputs are broadcast as batch (loop) dimensions.
matmat, matvec, vecmat, vecvec = (
    vectorize(signature)(jnp.dot)
    for signature in (
        '(n,m),(m,k)->(n,k)',  # matrix @ matrix
        '(n,m),(m)->(n)',      # matrix @ vector
        '(m),(m,k)->(k)',      # vector @ matrix
        '(m),(m)->()',         # vector . vector
    )
)


@vectorize('(n)->()')
def magnitude(x):
    """Return the squared Euclidean length of a single vector.

    Despite the name, no square root is taken: the result is ``x . x``.
    The ``'(n)->()'`` signature reduces each length-``n`` core vector to a
    scalar, mapping over any leading dimensions of the input.
    """
    squared_length = jnp.dot(x, x)
    return squared_length


# Vectorized jnp.mean: '(n)->()' reduces each length-n core vector to a
# scalar, with any leading input dimensions treated as batch dimensions.
mean = vectorize('(n)->()')(jnp.mean)


@vectorize('()->(n)')
def stack_plus_minus(x):