def test_example(self):
    """EarlyStopping on ``loss`` cuts ``fit`` short once the loss plateaus.

    A one-layer MLP is fit for up to 10 epochs with an EarlyStopping callback
    (``patience=3``); the test asserts that only 7 epochs end up recorded in
    the returned history.
    """

    class MLP(elegy.Module):
        def __apply__(self, input):
            mlp = hk.Sequential([
                hk.Linear(10),
            ])
            return mlp(input)

    # This callback will stop the training when there is no improvement in
    # the loss for three consecutive epochs.
    callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)

    model = elegy.Model(
        module=MLP.defer(),
        loss=elegy.losses.MeanSquaredError(),
        optimizer=optix.rmsprop(0.01),
    )

    history = model.fit(
        np.arange(100).reshape(5, 20).astype(np.float32),
        np.zeros(5),
        epochs=10,
        batch_size=1,
        callbacks=[callback],
        verbose=0,
    )

    assert len(history.history["loss"]) == 7  # Only 7 epochs are run.
def setup_learner():
    """Build and return a Learner for the distributed setting.

    A Catch environment is instantiated only so the agent can be built from
    its action and observation specs. Replace ``catch.Catch`` below with your
    own environment constructor to train on something else.
    """
    # Substitute your environment here! This instance is only used for specs.
    spec_env = catch.Catch()
    agent = agent_lib.Agent(
        spec_env.action_spec().num_values,
        spec_env.observation_spec(),
        haiku_nets.CatchNet,
    )

    optimizer = optix.rmsprop(1e-1, decay=0.99, eps=0.1)

    return learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        optimizer,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )
def main(_):
    """Train on Catch with NUM_ACTORS actor threads feeding a single learner.

    Each actor runs ``run_actor`` on its own thread and shares the one-element
    ``stop_signal`` list with it. The learner runs on the calling thread for
    ``MAX_ENV_FRAMES // FRAMES_PER_ITER`` updates. The actors are always
    signalled and joined afterwards, even if the learner raises.
    """
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the optimizer.
    opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Number of learner updates. Floor division is exact and avoids the float
    # round-trip of int(a / b); int() keeps the argument integral even if the
    # constants are floats.
    max_updates = int(MAX_ENV_FRAMES // FRAMES_PER_ITER)

    # Construct the actors on different threads.
    # stop_signal is in a list so the reference is shared with every actor.
    stop_signal = [False]
    actor_threads = []
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        actor_threads.append(
            threading.Thread(target=run_actor, args=(actor, stop_signal)))

    # Start the actors, then run the learner on this thread.
    for t in actor_threads:
        t.start()
    try:
        learner.run(max_updates)
    finally:
        # Stop and join the actors even when the learner raises. The threads
        # are non-daemon, so skipping this would keep the process alive
        # (presumably run_actor loops until stop_signal[0] is set).
        stop_signal[0] = True
        for t in actor_threads:
            t.join()
def main(_):
    """Single-process actor/learner training loop on Catch.

    Two actor threads call ``run_actor`` with a parameter accessor and
    ``q.put``. The calling thread repeatedly pops batches of two trajectories
    from the bounded queue and applies ``learner.update``.
    """
    # A sample environment supplies the action spec and a template timestep.
    env = catch.Catch()
    num_actions = env.action_spec().num_values
    net = hk.transform(lambda ts: SimpleNet(num_actions)(ts))  # pylint: disable=unnecessary-lambda

    agent = Agent(net.apply)
    opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)
    learner = Learner(agent, opt.update)

    # Initialise parameters from a batch-of-one template, then the optimizer.
    template = preprocess_step(env.reset())
    batched_template = jax.tree_map(lambda leaf: np.expand_dims(leaf, 0), template)
    params = jax.jit(net.init)(jax.random.PRNGKey(428), batched_template)
    opt_state = opt.init(params)

    batch_size = 2
    q = queue.Queue(maxsize=batch_size)

    def current_params():
        # Late-binding closure: returns whatever `params` the learner loop
        # below has most recently assigned.
        return params

    def dequeue():
        # Pull `batch_size` trajectories and stack each leaf along axis 1.
        trajectories = [q.get() for _ in range(batch_size)]
        stacked = jax.tree_multimap(
            lambda *leaves: np.stack(leaves, axis=1), *trajectories)
        return jax.device_put(stacked)

    num_actors = 2
    trajectories_per_actor = 500
    unroll_len = 20
    for seed in range(num_actors):
        rng = jax.random.PRNGKey(seed)
        worker_args = (agent, rng, current_params, q.put, unroll_len,
                       trajectories_per_actor)
        threading.Thread(target=run_actor, args=worker_args).start()

    # num_steps * batch_size equals num_actors * trajectories_per_actor here,
    # so (assuming each actor enqueues exactly trajectories_per_actor unrolls)
    # the learner drains everything the actors produce.
    num_steps = num_actors * trajectories_per_actor // batch_size
    for _ in range(num_steps):
        batch = dequeue()
        params, opt_state = learner.update(params, opt_state, batch)
def test_rmsprop(self):
    """optix.rmsprop must match experimental optimizers.rmsprop after STEPS updates."""
    decay, eps = .9, 0.1

    # Reference: experimental/optimizers.py (init/update/get_params triple).
    opt_init, opt_update, get_params = optimizers.rmsprop(LR, decay, eps)
    state = opt_init(self.init_params)
    for step in range(STEPS):
        state = opt_update(step, self.per_step_updates, state)
    reference = get_params(state)

    # Candidate: experimental/optix.py (init/update transformation).
    rmsprop = optix.rmsprop(LR, decay, eps)
    opt_state = rmsprop.init(self.init_params)
    candidate = self.init_params
    for _ in range(STEPS):
        updates, opt_state = rmsprop.update(self.per_step_updates, opt_state)
        candidate = optix.apply_updates(candidate, updates)

    # Both parameter trees must agree leaf by leaf.
    for expected, actual in zip(tree_leaves(reference), tree_leaves(candidate)):
        np.testing.assert_allclose(expected, actual, rtol=1e-5)
def main(debug: bool = False, eager: bool = False):
    """Train a LeNet-300-100 MLP classifier on MNIST and plot sample predictions.

    Args:
        debug: if True, wait for a debugpy client on port 5678 before running.
        eager: passed to ``elegy.Model`` as ``run_eagerly``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    X_train, y_train, X_test, y_test = dataget.image.mnist(
        global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    class MLP(elegy.Module):
        """Standard LeNet-300-100 MLP network."""

        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):
            # Rescale pixels (assumed 0-255) to floats in [0, 1].
            image = image.astype(jnp.float32) / 255.0

            mlp = hk.Sequential([
                hk.Flatten(),
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(10),
            ])
            # Keyed output so the loss and metric can select it via on="outputs".
            return dict(outputs=mlp(image))

    model = elegy.Model(
        module=MLP.defer(n1=300, n2=100),
        loss=[
            elegy.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                       on="outputs"),
            elegy.regularizers.GlobalL2(l=1e-4),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy.defer(on="outputs"),
        optimizer=optix.rmsprop(1e-3),
        run_eagerly=eager,
    )

    history = model.fit(
        x=X_train,
        y=dict(outputs=y_train),
        epochs=100,
        steps_per_epoch=200,
        batch_size=64,
        validation_data=(X_test, dict(outputs=y_test)),
        shuffle=True,
    )

    plot_history(history)

    # get random samples; bound by the actual test-set size instead of a
    # hard-coded 10000 so a different split cannot index out of range
    idxs = np.random.randint(0, len(X_test), size=(9, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results: 3x3 grid titled with the arg-max predicted class
    plt.figure(figsize=(12, 12))
    for i in range(3):
        for j in range(3):
            k = 3 * i + j
            plt.subplot(3, 3, k + 1)
            plt.title(f"{np.argmax(y_pred['outputs'][k])}")
            plt.imshow(x_sample[k], cmap="gray")

    plt.show()
def __init__(
    self,
    num_agent_steps,
    state_space,
    action_space,
    seed,
    max_grad_norm=None,
    gamma=0.99,
    nstep=1,
    buffer_size=10 ** 6,
    use_per=False,
    batch_size=32,
    start_steps=50000,
    update_interval=4,
    update_interval_target=8000,
    eps=0.01,
    eps_eval=0.001,
    eps_decay_steps=250000,
    loss_type="huber",
    dueling_net=False,
    double_q=False,
    setup_net=True,
    fn=None,
    lr=5e-5,
    lr_cum_p=2.5e-9,
    units=(512,),
    num_quantiles=32,
    num_cosines=64,
):
    # The parent is always told setup_net=False; the networks and optimizers
    # are built below (only when this constructor's setup_net is true).
    super(FQF, self).__init__(
        num_agent_steps=num_agent_steps,
        state_space=state_space,
        action_space=action_space,
        seed=seed,
        max_grad_norm=max_grad_norm,
        gamma=gamma,
        nstep=nstep,
        buffer_size=buffer_size,
        batch_size=batch_size,
        use_per=use_per,
        start_steps=start_steps,
        update_interval=update_interval,
        update_interval_target=update_interval_target,
        eps=eps,
        eps_eval=eps_eval,
        eps_decay_steps=eps_decay_steps,
        loss_type=loss_type,
        dueling_net=dueling_net,
        double_q=double_q,
        setup_net=False,
        num_quantiles=num_quantiles,
    )
    if not setup_net:
        return

    if fn is None:
        # Default quantile function: DiscreteImplicitQuantileFunction applied
        # to a state `s` and cumulative probabilities `cum_p`.
        def fn(s, cum_p):
            return DiscreteImplicitQuantileFunction(
                num_cosines=num_cosines,
                action_space=action_space,
                hidden_units=units,
                dueling_net=dueling_net,
            )(s, cum_p)

    # Quantile network; its helper also returns a feature used to initialise
    # the fraction proposal network. Must run before next(self.rng) below.
    network = make_quantile_nerwork(self.rng, state_space, action_space, fn,
                                    num_quantiles)
    self.net, self.params, fake_feature = network
    # Target parameters start as the very same pytree as the online ones.
    self.params_target = self.params
    adam_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
    self.opt_state = adam_init(self.params)

    # Fraction proposal network, with its own centered RMSProp optimizer.
    self.cum_p_net = hk.without_apply_rng(
        hk.transform(lambda s: CumProbNetwork(num_quantiles=num_quantiles)(s)))
    self.params_cum_p = self.cum_p_net.init(next(self.rng), fake_feature)
    rmsprop_init, self.opt_cum_p = optix.rmsprop(lr_cum_p, decay=0.95,
                                                 eps=1e-5, centered=True)
    self.opt_state_cum_p = rmsprop_init(self.params_cum_p)
def main(debug: bool = False, eager: bool = False):
    """Train an MLP autoencoder on MNIST images and plot reconstructions.

    Args:
        debug: if True, wait for a debugpy client on port 5678 before running.
        eager: passed to ``elegy.Model`` as ``run_eagerly``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # Labels are not needed for an autoencoder, so they are discarded.
    X_train, _, X_test, _ = dataget.image.mnist(global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    class MLP(elegy.Module):
        """Standard LeNet-300-100 MLP network."""

        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):
            # Rescale pixels (assumed 0-255) to floats in [0, 1].
            image = image.astype(jnp.float32) / 255.0
            x = hk.Flatten()(image)
            # Encoder n1 -> n2, decoder n2 -> n1 -> flattened input size;
            # the sigmoid keeps reconstructions in [0, 1].
            x = hk.Sequential([
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(x.shape[-1]),
                jax.nn.sigmoid,
            ])(x)
            # Back to the input image shape and the original 0-255 scale.
            return x.reshape(image.shape) * 255

    class MeanSquaredError(elegy.Loss):
        # we request `x` instead of `y_true` since we don't require labels in autoencoders
        def call(self, x, y_pred):
            return jnp.mean(jnp.square(x - y_pred), axis=-1)

    model = elegy.Model(
        module=MLP.defer(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optix.rmsprop(0.001),
        run_eagerly=eager,
    )

    # Notice we are not passing `y`
    history = model.fit(
        x=X_train,
        epochs=20,
        batch_size=64,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[elegy.callbacks.TensorBoard()],
    )

    plot_history(history)

    # get random samples; bound by the actual test-set size instead of a
    # hard-coded 10000 so a different split cannot index out of range
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results: inputs on the top row, reconstructions on the bottom row
    plt.figure(figsize=(12, 12))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_sample[i], cmap="gray")
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(y_pred[i], cmap="gray")

    plt.show()