Example 1
 def test_integration(self):
     env = catch.Catch()
     action_spec = env.action_spec()
     num_actions = action_spec.num_values
     obs_spec = env.observation_spec()
     agent = agent_lib.Agent(
         num_actions=num_actions,
         obs_spec=obs_spec,
         net_factory=haiku_nets.CatchNet,
     )
     unroll_length = 20
     learner = learner_lib.Learner(
         agent=agent,
         rng_key=jax.random.PRNGKey(42),
         opt=optix.sgd(1e-2),
         batch_size=1,
         discount_factor=0.99,
         frames_per_iter=unroll_length,
     )
     actor = actor_lib.Actor(
         agent=agent,
         env=env,
         unroll_length=unroll_length,
     )
     # Hand the learner's current parameters to the actor, generate one
     # unroll, and feed it back for a single learner update.
     frame_count, params = learner.params_for_actor()
     act_out = actor.unroll_and_push(frame_count=frame_count, params=params)
     learner.enqueue_traj(act_out)
     learner.run(max_iterations=1)
Example 2
 def test_encode_weights(self):
     env = catch.Catch()
     action_spec = env.action_spec()
     num_actions = action_spec.num_values
     obs_spec = env.observation_spec()
     agent = agent_lib.Agent(
         num_actions=num_actions,
         obs_spec=obs_spec,
         net_factory=haiku_nets.CatchNet,
     )
     unroll_length = 20
     learner = learner_lib.Learner(
         agent=agent,
         rng_key=jax.random.PRNGKey(42),
         opt=optix.sgd(1e-2),
         batch_size=1,
         discount_factor=0.99,
         frames_per_iter=unroll_length,
     )
     actor = actor_lib.Actor(
         agent=agent,
         env=env,
         unroll_length=unroll_length,
     )
     frame_count, params = learner.params_for_actor()
     # Round-trip the learner weights through the proto3 encoder/decoder.
     proto_weight = util.proto3_weight_encoder(frame_count, params)
     decoded_frame_count, decoded_params = \
       util.proto3_weight_decoder(proto_weight)
     self.assertEqual(frame_count, decoded_frame_count)
     np.testing.assert_almost_equal(decoded_params["catch_net/linear"]["w"],
                                    params["catch_net/linear"]["w"])
     act_out = actor.unroll_and_push(frame_count, params)
Example 3
def test_optimize(lr, w, x):
    net = hk.without_apply_rng(
        hk.transform(lambda x: hk.Linear(
            1, with_bias=False, w_init=hk.initializers.Constant(w))(x)))
    params = net.init(next(hk.PRNGSequence(0)), jnp.zeros((1, 1)))
    opt_init, opt = optix.sgd(lr)
    opt_state = opt_init(params)

    def _loss(params, x):
        return net.apply(params, x).mean(), None

    opt_state, params, loss, _ = optimize(_loss,
                                          opt,
                                          opt_state,
                                          params,
                                          None,
                                          x=jnp.ones((1, 1)) * x)
    assert np.isclose(loss, w * x)
    assert np.isclose(params["linear"]["w"], w - lr * x)
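
A note on the two assertions above: the transformed network is a single 1x1 linear layer with no bias, initialised to the constant w, so its mean output on the input x is simply w * x. The gradient of that loss with respect to w is x, and one SGD step with learning rate lr therefore moves the weight to w - lr * x, which is exactly what the test checks. A minimal stand-alone sketch of that arithmetic, using hypothetical values rather than anything taken from the original test:

import jax
import jax.numpy as jnp

w, lr, x = 2.0, 0.1, 3.0                       # hypothetical values for illustration
loss_fn = lambda w_: w_ * x                    # the 1x1 no-bias linear net reduces to w * x
grad = jax.grad(loss_fn)(w)                    # d(w * x)/dw == x
assert jnp.isclose(grad, x)
assert jnp.isclose(w - lr * grad, w - lr * x)  # one SGD step, as asserted in the test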
Example 4
    def test_sgd(self):

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = optimizers.sgd(LR)
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # experimental/optix.py
        optix_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state = sgd.init(optix_params)
        for _ in range(STEPS):
            updates, state = sgd.update(self.per_step_updates, state)
            optix_params = optix.apply_updates(optix_params, updates)

        # Check equivalence.
        for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
            np.testing.assert_allclose(x, y, rtol=1e-5)
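
Both halves of the comparison above follow the same basic loop; the optix half can be reduced to a few lines. This is a minimal sketch, assuming the optix module used in these examples is jax.experimental.optix (the imports are not shown in the snippets) and using made-up parameter and gradient trees:

import jax.numpy as jnp
from jax.experimental import optix

params = {"w": jnp.array([1.0, 2.0])}          # toy parameter tree
grads = {"w": jnp.array([0.5, -0.5])}          # matching gradient tree

sgd = optix.sgd(0.1)                           # gradient transformation with init/update
state = sgd.init(params)                       # optimizer state shaped like params
updates, state = sgd.update(grads, state)      # turn raw gradients into updates
params = optix.apply_updates(params, updates)  # leaf-wise params + updates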
Example 5
    def test_apply_every(self):
        # The frequency of the application of sgd
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # experimental/optix.py sgd
        optix_sgd_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state_sgd = sgd.init(optix_sgd_params)

        # experimental/optix.py sgd apply every
        optix_sgd_apply_every_params = self.init_params
        sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                      optix.trace(decay=0, nesterov=False),
                                      optix.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optix_sgd_apply_every_params)
        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optix_sgd_params = optix.apply_updates(optix_sgd_params,
                                                   updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every)
            optix_sgd_apply_every_params = optix.apply_updates(
                optix_sgd_apply_every_params, updates_sgd_apply_every)
            if i % k == k - 1:
                # Check equivalence.
                for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                                tree_leaves(optix_sgd_params)):
                    np.testing.assert_allclose(x, y, atol=1e-6, rtol=100)
            else:
                # Check the update is zero.
                for x, y in zip(tree_leaves(updates_sgd_apply_every),
                                tree_leaves(zero_update)):
                    np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
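
The branching inside the loop above reflects how apply_every behaves: it accumulates the incoming gradients and only emits the accumulated value on every k-th call, producing zero updates in between, while the chained trace/scale optimizer matches what optix.sgd(LR, 0.0) does (that is what the equivalence branch verifies). With a constant per-step gradient, the update released when i % k == k - 1 equals the sum of the k plain-SGD updates, so the two parameter trees coincide exactly on those steps. A small sketch of the accumulation behaviour, again assuming jax.experimental.optix and a hypothetical scalar gradient:

import jax.numpy as jnp
from jax.experimental import optix

grad = {"w": jnp.array(1.0)}          # constant per-step gradient, as in the test
acc = optix.apply_every(k=4)
state = acc.init(grad)

for step in range(8):
    updates, state = acc.update(grad, state)
    # updates["w"] should be zero on most steps and carry the accumulated
    # sum (4.0 here) whenever step % 4 == 3.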
Example 6
        # images=train_images,
        # labels=train_labels,
        tasks=tasks,
        num_tasks=num_tasks_per_step,
        num_samples=num_inner_samples,
        shuffle=True,
    )

    outer_loop_sampler = partial(
        random_samples,
        # images=flatten(train_images, 1),
        # labels=flatten(train_labels, 1),
        num_samples=num_outer_samples,
    )

    inner_opt = optix.chain(optix.sgd(args.inner_lr))
    inner_loop_fn = make_inner_loop_fn(loss_acc_fn, inner_opt.update)
    outer_loop_loss_fn = make_outer_loop_loss_fn(loss_acc_fn, inner_opt.init,
                                                 inner_loop_fn)

    rng, rng_net = split(rng)
    (out_shape), params = net_init(rng_net, (-1, size, size, 1))

    rln_params, pln_params = (
        params[:args.num_rln_layers],
        params[args.num_rln_layers:],
    )

    outer_opt_init, outer_opt_update, outer_get_params = optimizers.adam(
        step_size=args.outer_lr)
    outer_opt_state = outer_opt_init((rln_params, pln_params))