def test_integration(self):
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )
    # Manually drive one actor/learner round trip: fetch the current
    # parameters, generate one unroll, enqueue it, and run one learner step.
    frame_count, params = learner.params_for_actor()
    act_out = actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.enqueue_traj(act_out)
    learner.run(max_iterations=1)
def test_encode_weights(self):
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )
    frame_count, params = learner.params_for_actor()
    # Round-trip the weights through the proto3 encoding and check that the
    # frame count and parameters survive unchanged.
    proto_weight = util.proto3_weight_encoder(frame_count, params)
    decoded_frame_count, decoded_params = util.proto3_weight_decoder(
        proto_weight)
    self.assertEqual(frame_count, decoded_frame_count)
    np.testing.assert_almost_equal(
        decoded_params["catch_net/linear"]["w"],
        params["catch_net/linear"]["w"])
    act_out = actor.unroll_and_push(frame_count, params)
def test_optimize(lr, w, x):
    net = hk.without_apply_rng(
        hk.transform(lambda x: hk.Linear(
            1, with_bias=False, w_init=hk.initializers.Constant(w))(x)))
    params = net.init(next(hk.PRNGSequence(0)), jnp.zeros((1, 1)))
    opt_init, opt = optix.sgd(lr)
    opt_state = opt_init(params)

    def _loss(params, x):
        return net.apply(params, x).mean(), None

    opt_state, params, loss, _ = optimize(
        _loss, opt, opt_state, params, None, x=jnp.ones((1, 1)) * x)
    assert np.isclose(loss, w * x)
    assert np.isclose(params["linear"]["w"], w - lr * x)
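# A minimal sketch of the `optimize` helper exercised above (hypothetical;
# the real helper ships alongside the test and may differ). Assuming the
# loss function returns a (loss, aux) pair, one step differentiates the
# loss, transforms the gradients with the optix update fn, and applies the
# resulting updates to the parameters.
import jax
from jax.experimental import optix


def optimize(loss_fn, opt_update, opt_state, params, aux, **inputs):
    del aux  # Unused in this sketch; kept only to match the call site above.
    (loss, aux_out), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(params, **inputs)
    updates, opt_state = opt_update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    return opt_state, params, loss, aux_out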
def test_sgd(self):
    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = optimizers.sgd(LR)
    state = opt_init(jax_params)
    for i in range(STEPS):
        state = opt_update(i, self.per_step_updates, state)
    jax_params = get_params(state)

    # experimental/optix.py
    optix_params = self.init_params
    sgd = optix.sgd(LR, 0.0)
    state = sgd.init(optix_params)
    for _ in range(STEPS):
        updates, state = sgd.update(self.per_step_updates, state)
        optix_params = optix.apply_updates(optix_params, updates)

    # Check equivalence.
    for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
        np.testing.assert_allclose(x, y, rtol=1e-5)
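# Note: in jax.experimental.optix, `sgd` is itself composed of smaller
# gradient transformations, roughly:
#
#     def sgd(learning_rate, momentum=0., nesterov=False):
#         return chain(trace(decay=momentum, nesterov=nesterov),
#                      scale(-learning_rate))
#
# test_apply_every below relies on this decomposition, splicing
# `apply_every` in front of the trace/scale pair.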
def test_apply_every(self):
    # The frequency of the application of SGD.
    k = 4
    zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

    # experimental/optix.py sgd
    optix_sgd_params = self.init_params
    sgd = optix.sgd(LR, 0.0)
    state_sgd = sgd.init(optix_sgd_params)

    # experimental/optix.py sgd with apply_every
    optix_sgd_apply_every_params = self.init_params
    sgd_apply_every = optix.chain(
        optix.apply_every(k=k),
        optix.trace(decay=0, nesterov=False),
        optix.scale(-LR))
    state_sgd_apply_every = sgd_apply_every.init(
        optix_sgd_apply_every_params)

    for i in range(STEPS):
        # Apply a step of sgd.
        updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
        optix_sgd_params = optix.apply_updates(optix_sgd_params, updates_sgd)

        # Apply a step of sgd_apply_every.
        updates_sgd_apply_every, state_sgd_apply_every = (
            sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every))
        optix_sgd_apply_every_params = optix.apply_updates(
            optix_sgd_apply_every_params, updates_sgd_apply_every)

        if i % k == k - 1:
            # Every k-th step the accumulated updates are flushed; check
            # equivalence with plain sgd.
            for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                            tree_leaves(optix_sgd_params)):
                np.testing.assert_allclose(x, y, atol=1e-6, rtol=100)
        else:
            # On intermediate steps, check the emitted update is zero.
            for x, y in zip(tree_leaves(updates_sgd_apply_every),
                            tree_leaves(zero_update)):
                np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
    # images=train_images,
    # labels=train_labels,
    tasks=tasks,
    num_tasks=num_tasks_per_step,
    num_samples=num_inner_samples,
    shuffle=True,
)
outer_loop_sampler = partial(
    random_samples,
    # images=flatten(train_images, 1),
    # labels=flatten(train_labels, 1),
    num_samples=num_outer_samples,
)

inner_opt = optix.chain(optix.sgd(args.inner_lr))
inner_loop_fn = make_inner_loop_fn(loss_acc_fn, inner_opt.update)
outer_loop_loss_fn = make_outer_loop_loss_fn(
    loss_acc_fn, inner_opt.init, inner_loop_fn)

rng, rng_net = split(rng)
out_shape, params = net_init(rng_net, (-1, size, size, 1))
rln_params, pln_params = (
    params[:args.num_rln_layers],
    params[args.num_rln_layers:],
)

outer_opt_init, outer_opt_update, outer_get_params = optimizers.adam(
    step_size=args.outer_lr)
outer_opt_state = outer_opt_init((rln_params, pln_params))
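# For orientation, a minimal sketch (assumed, not part of the original file)
# of the kind of inner-loop step that `make_inner_loop_fn` could build from
# `loss_acc_fn` and `inner_opt.update`: only the PLN parameters are adapted
# with the inner optimizer, while the RLN parameters stay frozen. The exact
# signature of `loss_acc_fn` is a guess here.
def _inner_step_sketch(rln_params, pln_params, inner_opt_state, batch):
    def task_loss(pln):
        loss, _acc = loss_acc_fn(rln_params + pln, batch)
        return loss

    grads = jax.grad(task_loss)(pln_params)
    updates, inner_opt_state = inner_opt.update(grads, inner_opt_state)
    return optix.apply_updates(pln_params, updates), inner_opt_state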