def __init__(self, policy: str, action_dim: int, max_action: float, lr: float,
             discount: float, noise_clip: float, policy_noise: float,
             policy_freq: int, actor_rng: jnp.ndarray, critic_rng: jnp.ndarray,
             sample_state: np.ndarray):
    self.discount = discount
    self.noise_clip = noise_clip
    self.policy_noise = policy_noise
    self.policy_freq = policy_freq
    self.max_action = max_action
    self.td3_update = policy == 'TD3'

    self.actor = hk.transform(lambda x: Actor(action_dim, max_action)(x))
    actor_opt_init, self.actor_opt_update = optix.adam(lr)

    self.critic = hk.transform(lambda x: Critic()(x))
    critic_opt_init, self.critic_opt_update = optix.adam(lr)

    self.actor_params = self.target_actor_params = self.actor.init(
        actor_rng, sample_state)
    self.actor_opt_state = actor_opt_init(self.actor_params)

    action = self.actor.apply(self.actor_params, sample_state)
    self.critic_params = self.target_critic_params = self.critic.init(
        critic_rng, jnp.concatenate((sample_state, action), 0))
    self.critic_opt_state = critic_opt_init(self.critic_params)

    self.updates = 0
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> BootstrappedDqn:
    """Initialize a Bootstrapped DQN agent with default parameters."""
    # Define network.
    prior_scale = 3.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)
    return BootstrappedDqn(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        batch_size=128,
        discount=.99,
        num_ensemble=20,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        optimizer=optimizer,
        mask_prob=0.5,
        noise_scale=0.,
        epsilon_fn=lambda _: 0.,
        seed=seed,
    )
def __init__(
    self,
    num_critics=2,
    fn_error=None,
    lr_error=3e-4,
    units_error=(256, 256, 256),
    d2rl=False,
    init_error=10.0,
):
    if fn_error is None:

        def fn_error(s, a):
            return ContinuousQFunction(
                num_critics=num_critics,
                hidden_units=units_error,
                d2rl=d2rl,
            )(s, a)

    # Error model.
    self.error = hk.without_apply_rng(hk.transform(fn_error))
    self.params_error = self.params_error_target = self.error.init(
        next(self.rng), *self.fake_args_critic)
    opt_init, self.opt_error = optix.adam(lr_error)
    self.opt_state_error = opt_init(self.params_error)

    # Running mean of error.
    self.rm_error_list = [
        jnp.array(init_error, dtype=jnp.float32) for _ in range(num_critics)
    ]
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Creates an actor-critic agent with default hyperparameters."""
    hidden_size = 256
    initial_rnn_state = hk.LSTMState(
        hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
        cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

    def network(inputs: jnp.ndarray,
                state) -> Tuple[Tuple[Logits, Value], LSTMState]:
        flat_inputs = hk.Flatten()(inputs)
        torso = hk.nets.MLP([hidden_size, hidden_size])
        lstm = hk.LSTM(hidden_size)
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)

        embedding = torso(flat_inputs)
        embedding, state = lstm(embedding, state)
        logits = policy_head(embedding)
        value = value_head(embedding)
        return (logits, jnp.squeeze(value, axis=-1)), state

    return ActorCriticRNN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        initial_rnn_state=initial_rnn_state,
        optimizer=optix.adam(3e-3),
        rng=hk.PRNGSequence(seed),
        sequence_length=32,
        discount=0.99,
        td_lambda=0.9,
    )
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Creates an actor-critic agent with default hyperparameters."""

    def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
        flat_inputs = hk.Flatten()(inputs)
        torso = hk.nets.MLP([64, 64])
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)
        embedding = torso(flat_inputs)
        logits = policy_head(embedding)
        value = value_head(embedding)
        return logits, jnp.squeeze(value, axis=-1)

    return ActorCritic(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        optimizer=optix.adam(3e-3),
        rng=hk.PRNGSequence(seed),
        sequence_length=32,
        discount=0.99,
        td_lambda=0.9,
    )
def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optix.chain(optix.clip_by_global_norm(FLAGS.grad_clip_value),
                            optix.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch
        # overheads. Using values from state/metrics too often will block the
        # runahead and can cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    _, optimizer = optix.adam(FLAGS.learning_rate)
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)
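# Note (illustrative sketch, not part of the original snippet): because the
# optix transform and the Haiku-transformed loss are pure, update() above
# rebuilds them on every call. A common alternative builds them once and
# closes over them; `sequence_loss`, `dataset.Batch`, and `TrainingState` are
# the same names used above, and `make_update_fn` is a hypothetical helper.
def make_update_fn(learning_rate: float):
    """Returns a jitted SGD step that reuses a single optimizer instance."""
    optimizer = optix.adam(learning_rate)
    loss_fn = hk.without_apply_rng(hk.transform(sequence_loss)).apply

    @jax.jit
    def update_fn(state: TrainingState, batch: dataset.Batch) -> TrainingState:
        gradients = jax.grad(loss_fn)(state.params, batch)
        updates, new_opt_state = optimizer.update(gradients, state.opt_state)
        new_params = optix.apply_updates(state.params, updates)
        return TrainingState(params=new_params, opt_state=new_opt_state)

    return update_fn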
def __init__(self, obs_spec, action_spec, epsilon, learning_rate):
    self._obs_spec = obs_spec
    self._action_spec = action_spec
    self._epsilon = epsilon

    # Neural net and optimiser.
    self._network = build_network(action_spec.num_values)
    self._optimizer = optix.adam(learning_rate)

    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)
def __init__(
        self,
        observation_spec,
        action_spec,
        params: RlaxRainbowParams = RlaxRainbowParams()):

    if not callable(params.epsilon):
        eps = params.epsilon
        params = params._replace(epsilon=lambda ts: eps)
    if not callable(params.beta_is):
        beta = params.beta_is
        params = params._replace(beta_is=lambda ts: beta)
    self.params = params
    self.rng = hk.PRNGSequence(jax.random.PRNGKey(params.seed))

    # Build and initialize Q-network.
    def build_network(layers: List[int],
                      output_shape: List[int]) -> hk.Transformed:

        def q_net(obs):
            layers_ = tuple(layers) + (onp.prod(output_shape),)
            network = NoisyMLP(layers_)
            return hk.Reshape(output_shape=output_shape)(network(obs))

        return hk.transform(q_net)

    self.network = build_network(params.layers,
                                 (action_spec.num_values, params.n_atoms))
    self.trg_params = self.network.init(
        next(self.rng),
        observation_spec.generate_value().astype(onp.float16))
    self.online_params = self.trg_params

    self.atoms = jnp.tile(
        jnp.linspace(-params.atom_vmax, params.atom_vmax, params.n_atoms),
        (action_spec.num_values, 1))

    # Build and initialize optimizer.
    self.optimizer = optix.adam(params.learning_rate, eps=3.125e-5)
    self.opt_state = self.optimizer.init(self.online_params)
    self.train_step = 0

    self.update_q = DQNLearning.update_q

    if params.use_priority:
        self.experience = PriorityBuffer(
            observation_spec.shape[1], action_spec.num_values, 1,
            params.experience_buffer_size)
    else:
        self.experience = ExperienceBuffer(
            observation_spec.shape[1], action_spec.num_values, 1,
            params.experience_buffer_size)

    self.last_obs = onp.empty(observation_spec.shape)
    self.requires_vectorized_observation = lambda: True
def __init__(self, config: DDPGConfig):
    """Initialize the algorithm."""
    pi_net = partial(mlp_policy_net,
                     output_sizes=config.pi_net_size + (config.action_dim,))
    pi_net = hk.transform(pi_net)
    q_net = partial(mlp_value_net, output_sizes=config.q_net_size + (1,))
    q_net = hk.transform(q_net)
    pi_optimizer = optix.adam(learning_rate=config.learning_rate)
    q_optimizer = optix.adam(learning_rate=config.learning_rate)
    self.func = DDPGFunc(pi_net, q_net, pi_optimizer, q_optimizer,
                         config.gamma, config.state_dim, config.action_dim)

    rng = jax.random.PRNGKey(config.seed)  # random number generator
    state = jnp.zeros(config.state_dim)
    state_action = jnp.zeros(config.state_dim + config.action_dim)
    pi_params = pi_net.init(rng, state)
    q_params = q_net.init(rng, state_action)
    pi_opt_state = pi_optimizer.init(pi_params)
    q_opt_state = q_optimizer.init(q_params)
    self.state = DDPGState(pi_params, q_params, pi_opt_state, q_opt_state)
    return
def main(debug: bool = False, eager: bool = False):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = elegy.Model(
        module=VariationalAutoEncoder.defer(),
        loss=[KLDivergence(), BinaryCrossEntropy(on="logits")],
        optimizer=optix.adam(1e-3),
        run_eagerly=eager,
    )

    epochs = 10

    # Fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=64,
        steps_per_epoch=100,
        validation_data=(X_test,),
        shuffle=True,
    )

    plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results
    plt.figure(figsize=(12, 12))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_sample[i], cmap="gray")
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(y_pred["image"][i], cmap="gray")

    plt.show()
def __init__(self, obs_spec, action_spec, epsilon_cfg, target_period,
             learning_rate):
    self._obs_spec = obs_spec
    self._action_spec = action_spec
    self._target_period = target_period

    # Neural net and optimiser.
    self._network = build_network(action_spec.num_values)
    self._optimizer = optix.adam(learning_rate)
    self._epsilon_by_frame = rlax.polynomial_schedule(**epsilon_cfg)

    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    action_spec = env.action_spec()

    # Define network.
    prior_scale = 5.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)

    agent = boot_dqn.BootstrappedDqn(
        obs_spec=env.observation_spec(),
        action_spec=action_spec,
        network=network,
        optimizer=optimizer,
        num_ensemble=FLAGS.num_ensemble,
        batch_size=128,
        discount=.99,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        mask_prob=1.0,
        noise_scale=0.,
    )

    num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
    experiment.run(
        agent=agent,
        environment=env,
        num_episodes=num_episodes,
        verbose=FLAGS.verbose)

    return bsuite_id
def main():
    net = hk.transform(behavioral_cloning)
    opt = optix.adam(0.001)

    @jax.jit
    def loss(params, batch):
        """The loss criterion for our model."""
        logits = net.apply(params, None, batch)
        return mse_loss(logits, batch[2])

    @jax.jit
    def update(opt_state, params, batch):
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        params = optix.apply_updates(params, updates)
        return params, opt_state

    @jax.jit
    def accuracy(params, batch):
        """Simply report the loss for the current batch."""
        logits = net.apply(params, None, batch)
        return mse_loss(logits, batch[2])

    train_dataset, val_dataset = load_data(MINERL_ENV, batch_size=32, epochs=100)

    rng = jax.random.PRNGKey(2020)
    batch = next(train_dataset)
    params = net.init(rng, batch)
    opt_state = opt.init(params)

    for i, batch in enumerate(train_dataset):
        params, opt_state = update(opt_state, params, batch)
        if i % 1000 == 0:
            print(accuracy(params, val_dataset))
        if i % 10000 == 0:
            with open(PARAMS_FILENAME, 'wb') as fh:
                dill.dump(params, fh)

    with open(PARAMS_FILENAME, 'wb') as fh:
        dill.dump(params, fh)
def main(_):
    model = hk.transform(lambda x: VariationalAutoEncoder()(x))  # pylint: disable=unnecessary-lambda
    optimizer = optix.adam(FLAGS.learning_rate)

    @jax.jit
    def loss_fn(params: hk.Params, rng_key: PRNGKey,
                batch: Batch) -> jnp.ndarray:
        """ELBO loss: E_p[log(x)] - KL(d||q), where p ~ Be(0.5) and q ~ N(0,1)."""
        outputs: VAEOutput = model.apply(params, rng_key, batch["image"])
        log_likelihood = -binary_cross_entropy(batch["image"], outputs.logits)
        kl = kl_gaussian(outputs.mean, outputs.stddev**2)
        elbo = log_likelihood - kl
        return -jnp.mean(elbo)

    @jax.jit
    def update(
        params: hk.Params,
        rng_key: PRNGKey,
        opt_state: OptState,
        batch: Batch,
    ) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, rng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, new_opt_state

    rng_seq = hk.PRNGSequence(42)
    params = model.init(next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(rng_seq), opt_state,
                                   next(train_ds))

        if step % FLAGS.eval_frequency == 0:
            val_loss = loss_fn(params, next(rng_seq), next(valid_ds))
            logging.info("STEP: %5d; Validation ELBO: %.3f", step, -val_loss)
def __init__(
        self,
        observation_spec,
        action_spec,
        target_update_period: int = None,
        discount: float = None,
        epsilon=lambda x: 0.1,
        learning_rate: float = 0.001,
        layers: List[int] = None,
        use_double_q=True,
        use_priority=True,
        seed: int = 1234,
        importance_beta: float = 0.4):

    self.rng = hk.PRNGSequence(jax.random.PRNGKey(seed))

    # Build and initialize Q-network.
    self.layers = layers or (512,)
    self.network = build_network(self.layers, action_spec.num_values)
    # sample_input = env.observation_spec()["observation"].generate_value()
    # sample_input = jnp.zeros(observation_le)
    self.trg_params = self.network.init(
        next(self.rng),
        observation_spec.generate_value().astype(onp.float16))
    self.online_params = self.trg_params  # self.network.init(next(self.rng), sample_input)

    # Build and initialize optimizer.
    self.optimizer = optix.adam(learning_rate)
    self.opt_state = self.optimizer.init(self.online_params)
    self.epsilon = epsilon
    self.train_steps = 0
    self.target_update_period = target_update_period or 500
    self.discount = discount or 0.99

    # if use_double_q:
    #     self.update_q = DQNLearning.update_q_double
    # else:
    #     self.update_q = DQNLearning.update_q
    self.update_q = DQNLearning.update_q

    if use_priority:
        self.experience = PriorityBuffer(
            observation_spec.shape[1], action_spec.num_values, 1, 2**19)
    else:
        self.experience = ExperienceBuffer(
            observation_spec.shape[1], action_spec.num_values, 1, 2**19)

    self.importance_beta = importance_beta
    self.last_obs = onp.empty(observation_spec.shape)
def test_adam(self):
    b1, b2, eps = 0.9, 0.999, 1e-8

    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = optimizers.adam(LR, b1, b2, eps)
    state = opt_init(jax_params)
    for i in range(STEPS):
        state = opt_update(i, self.per_step_updates, state)
        jax_params = get_params(state)

    # experimental/optix.py
    optix_params = self.init_params
    adam = optix.adam(LR, b1, b2, eps)
    state = adam.init(optix_params)
    for _ in range(STEPS):
        updates, state = adam.update(self.per_step_updates, state)
        optix_params = optix.apply_updates(optix_params, updates)

    # Check equivalence.
    for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
        np.testing.assert_allclose(x, y, rtol=1e-4)
def test_graph_network_learning(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    R_key, dr0_key, params_key = random.split(key, 3)

    d, _ = space.free()

    R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
    dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)
    E_gt = vmap(
        lambda R, dr0: np.sum(
            (space.distance(space.map_product(d)(R, R)) - dr0) ** 2))

    cutoff = 0.2

    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(params_key, R[0])

    @jit
    def loss(params, R):
        return np.mean(
            (vmap(energy_fn, (None, 0))(params, R) - E_gt(R, dr0)) ** 2)

    opt = optix.chain(optix.clip_by_global_norm(1.0), optix.adam(1e-4))

    @jit
    def update(params, opt_state, R):
        updates, opt_state = opt.update(grad(loss)(params, R), opt_state)
        return optix.apply_updates(params, updates), opt_state

    opt_state = opt.init(params)

    l0 = loss(params, R)
    for i in range(4):
        params, opt_state = update(params, opt_state, R)

    assert loss(params, R) < l0 * 0.95
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Initialize a DQN agent with default parameters."""

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        flat_inputs = hk.Flatten()(inputs)
        mlp = hk.nets.MLP([64, 64, action_spec.num_values])
        action_values = mlp(flat_inputs)
        return action_values

    return DQN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        optimizer=optix.adam(1e-3),
        batch_size=32,
        discount=0.99,
        replay_capacity=10000,
        min_replay_size=100,
        sgd_period=1,
        target_update_period=4,
        epsilon=0.05,
        rng=hk.PRNGSequence(seed),
    )
def make_optimizer() -> optix.InitUpdate:
    """Defines the optimizer."""
    return optix.adam(FLAGS.learning_rate)
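# Usage sketch (illustrative, not from the original source): how the
# InitUpdate pair returned by make_optimizer() is typically driven, matching
# the init/update/apply_updates pattern seen elsewhere in this collection.
# `loss_fn`, `params`, `opt_state`, and `batch` are hypothetical placeholders.
def sgd_step_sketch(loss_fn, params, opt_state, batch):
    """One optimisation step with the optix init/update/apply_updates API."""
    opt = make_optimizer()
    grads = jax.grad(loss_fn)(params, batch)
    updates, new_opt_state = opt.update(grads, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, new_opt_state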
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.hidden_layer_sizes is None:
        # Cannot pass default arguments as lists due to style requirements, so
        # we override it here if they are not set.
        FLAGS.hidden_layer_sizes = DEFAULT_LAYER_SIZES

    # Make the network.
    net = hk.transform(net_fn)

    # Make the optimiser.
    opt = optix.adam(FLAGS.step_size)

    @jax.jit
    def loss(
        params: Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Cross-entropy loss."""
        assert targets.dtype == np.int32
        log_probs = net.apply(params, inputs)
        return -jnp.mean(one_hot(targets, NUM_ACTIONS) * log_probs)

    @jax.jit
    def accuracy(
        params: Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Classification accuracy."""
        predictions = net.apply(params, inputs)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == targets)

    @jax.jit
    def update(
        params: Params,
        opt_state: OptState,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> Tuple[Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        _, gradient = jax.value_and_grad(loss)(params, inputs, targets)
        updates, opt_state = opt.update(gradient, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state

    def output_samples(params: Params, max_samples: int):
        """Output some cases where the policy disagrees with the dataset action."""
        if max_samples == 0:
            return
        count = 0
        with open(os.path.join(FLAGS.data_path, 'test.txt')) as f:
            lines = list(f)
        np.random.shuffle(lines)
        for line in lines:
            state = GAME.new_initial_state()
            actions = _trajectory(line)
            for action in actions:
                if not state.is_chance_node():
                    observation = np.array(state.information_state_tensor(),
                                           np.float32)
                    policy = np.exp(net.apply(params, observation))
                    probs_actions = [(p, a) for a, p in enumerate(policy)]
                    pred = max(probs_actions)[1]
                    if pred != action:
                        print(state)
                        for p, a in reversed(
                                sorted(probs_actions)[-TOP_K_ACTIONS:]):
                            print('{:7} {:.2f}'.format(
                                state.action_to_string(a), p))
                        print('Ground truth {}\n'.format(
                            state.action_to_string(action)))
                        count += 1
                        break
                state.apply_action(action)
            if count >= max_samples:
                return

    # Store what we need to rebuild the Haiku net.
    if FLAGS.save_path:
        filename = os.path.join(FLAGS.save_path, 'layers.txt')
        with open(filename, 'w') as layer_def_file:
            for s in FLAGS.hidden_layer_sizes:
                layer_def_file.write(f'{s} ')
            layer_def_file.write('\n')

    # Make datasets.
    if FLAGS.data_path is None:
        raise app.UsageError(
            'Please generate your own supervised training data and supply the '
            'local location as --data_path')
    train = batch(make_dataset(os.path.join(FLAGS.data_path, 'train.txt')),
                  FLAGS.train_batch)
    test = batch(make_dataset(os.path.join(FLAGS.data_path, 'test.txt')),
                 FLAGS.eval_batch)

    # Initialize network and optimiser.
    if FLAGS.checkpoint_file:
        with open(FLAGS.checkpoint_file, 'rb') as pkl_file:
            params, opt_state = pickle.load(pkl_file)
    else:
        rng = jax.random.PRNGKey(FLAGS.rng_seed)  # seed used for network weights
        inputs, unused_targets = next(train)
        params = net.init(rng, inputs)
        opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(FLAGS.iterations):
        # Do SGD on a batch of training examples.
        inputs, targets = next(train)
        params, opt_state = update(params, opt_state, inputs, targets)

        # Periodically evaluate classification accuracy on the test set.
        if (1 + step) % FLAGS.eval_every == 0:
            inputs, targets = next(test)
            test_accuracy = accuracy(params, inputs, targets)
            print(f'After {1+step} steps, test accuracy: {test_accuracy}.')
            if FLAGS.save_path:
                filename = os.path.join(FLAGS.save_path,
                                        f'checkpoint-{1 + step}.pkl')
                with open(filename, 'wb') as pkl_file:
                    pickle.dump((params, opt_state), pkl_file)
            output_samples(params, FLAGS.num_examples)
def __init__(
    self,
    num_agent_steps,
    state_space,
    action_space,
    seed,
    max_grad_norm=None,
    gamma=0.99,
    nstep=1,
    num_critics=1,
    buffer_size=10**6,
    use_per=False,
    batch_size=256,
    start_steps=10000,
    update_interval=1,
    tau=5e-3,
    fn_actor=None,
    fn_critic=None,
    lr_actor=1e-3,
    lr_critic=1e-3,
    units_actor=(256, 256),
    units_critic=(256, 256),
    d2rl=False,
    std=0.1,
    update_interval_policy=2,
):
    super(DDPG, self).__init__(
        num_agent_steps=num_agent_steps,
        state_space=state_space,
        action_space=action_space,
        seed=seed,
        max_grad_norm=max_grad_norm,
        gamma=gamma,
        nstep=nstep,
        num_critics=num_critics,
        buffer_size=buffer_size,
        use_per=use_per,
        batch_size=batch_size,
        start_steps=start_steps,
        update_interval=update_interval,
        tau=tau,
    )
    if d2rl:
        self.name += "-D2RL"

    if fn_critic is None:

        def fn_critic(s, a):
            return ContinuousQFunction(
                num_critics=num_critics,
                hidden_units=units_critic,
                d2rl=d2rl,
            )(s, a)

    if fn_actor is None:

        def fn_actor(s):
            return DeterministicPolicy(
                action_space=action_space,
                hidden_units=units_actor,
                d2rl=d2rl,
            )(s)

    # Critic.
    self.critic = hk.without_apply_rng(hk.transform(fn_critic))
    self.params_critic = self.params_critic_target = self.critic.init(
        next(self.rng), *self.fake_args_critic)
    opt_init, self.opt_critic = optix.adam(lr_critic)
    self.opt_state_critic = opt_init(self.params_critic)

    # Actor.
    self.actor = hk.without_apply_rng(hk.transform(fn_actor))
    self.params_actor = self.params_actor_target = self.actor.init(
        next(self.rng), *self.fake_args_actor)
    opt_init, self.opt_actor = optix.adam(lr_actor)
    self.opt_state_actor = opt_init(self.params_actor)

    # Other parameters.
    self.std = std
    self.update_interval_policy = update_interval_policy
def train(batch_size, max_iter, num_test_samples):
    model = make_network(4, 8, output_dim=2)
    optimizer = optix.adam(0.005)

    data_generator = collect_batch(
        dataset.FamilyTreeDataset(task="grandparents", epoch_size=10, nmin=20),
        n=batch_size,
    )

    @jax.jit
    def evaluate_batch(params, data):
        logits = model.apply(params, data.predicates)
        correct = jnp.all(data.targets == jnp.argmax(logits, axis=-1), axis=-1)
        ce_loss = jnp.sum(cross_entropy(data.targets, logits), axis=-1)
        return correct, ce_loss

    def evaluate(params):
        eval_data_gen = collect_batch(
            dataset.FamilyTreeDataset(task="grandparents", epoch_size=10,
                                      nmin=50),
            n=batch_size,
        )
        correct, ce_loss = zip(*[
            evaluate_batch(params, preprocess_batch(data))
            for data in itertools.islice(eval_data_gen, num_test_samples)
        ])
        correct = np.concatenate(correct, axis=0)
        ce_loss = np.concatenate(ce_loss, axis=0)
        return np.mean(correct), np.mean(ce_loss)

    def loss(params, data):
        logits = model.apply(params, data.predicates)
        return jnp.mean(cross_entropy(data.targets, logits))

    @jax.jit
    def update(data, params, opt_state):
        batch_loss, dparam = jax.value_and_grad(loss)(params, data)
        updates, opt_state = optimizer.update(dparam, opt_state)
        params = optix.apply_updates(params, updates)
        return batch_loss, OptimizationState(params, opt_state)

    rng = jax.random.PRNGKey(0)
    init_params = model.init(
        rng,
        preprocess_batch(next(data_generator)).predicates,
    )
    state = OptimizationState(init_params, optimizer.init(init_params))

    for i, batch in enumerate(data_generator):
        if i >= max_iter:
            break
        batch = preprocess_batch(batch)
        batch_loss, state = update(batch, *state)
        if (i + 1) % 100 == 0:
            print(i + 1, evaluate(state.params))
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], networks.RNNState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                               max_size=max_queue_size)
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    extra_spec = {
        'core_state': hk.transform(initial_state_fn).apply(None),
        'logits': np.ones(shape=(num_actions,), dtype=np.float32)
    }
    # Remove batch dimensions.
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    rng = hk.PRNGSequence(seed)

    optimizer = optix.chain(
        optix.clip_by_global_norm(max_gradient_norm),
        optix.adam(learning_rate),
    )
    self._learner = learning.IMPALALearner(
        obs_spec=environment_spec.observations,
        network=network,
        initial_state_fn=initial_state_fn,
        iterator=dataset.as_numpy_iterator(),
        rng=rng,
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=discount,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_abs_reward=max_abs_reward,
    )

    variable_client = jax_variable_utils.VariableClient(self._learner,
                                                        key='policy')
    self._actor = acting.IMPALAActor(
        network=network,
        initial_state_fn=initial_state_fn,
        rng=rng,
        adder=adder,
        variable_client=variable_client,
    )
def main(_):
    # Make the network and optimiser.
    net = hk.transform(net_fn)
    opt = optix.adam(1e-3)

    # Training loss (cross-entropy).
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = jax.nn.one_hot(batch["label"], 10)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]

        return softmax_xent + 1e-4 * l2_loss

    # Evaluation metric (classification accuracy).
    @jax.jit
    def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
        predictions = net.apply(params, batch)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

    @jax.jit
    def update(params: hk.Params, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state

    # We maintain avg_params, the exponential moving average of the "live"
    # params. avg_params is used only for evaluation.
    # For more, see: https://doi.org/10.1137/0330046
    @jax.jit
    def ema_update(avg_params: hk.Params, new_params: hk.Params,
                   epsilon: float = 0.001) -> hk.Params:
        return jax.tree_multimap(
            lambda p1, p2: (1 - epsilon) * p1 + epsilon * p2,
            avg_params, new_params)

    # Make datasets.
    train = load_dataset("train", is_training=True, batch_size=1000)
    train_eval = load_dataset("train", is_training=False, batch_size=10000)
    test_eval = load_dataset("test", is_training=False, batch_size=10000)

    # Initialize network and optimiser; note we draw an input to get shapes.
    params = avg_params = net.init(jax.random.PRNGKey(42), next(train))
    opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(10001):
        if step % 1000 == 0:
            # Periodically evaluate classification accuracy on train & test sets.
            train_accuracy = accuracy(avg_params, next(train_eval))
            test_accuracy = accuracy(avg_params, next(test_eval))
            train_accuracy, test_accuracy = jax.device_get(
                (train_accuracy, test_accuracy))
            print(f"[Step {step}] Train / Test accuracy: "
                  f"{train_accuracy:.3f} / {test_accuracy:.3f}.")

        # Do SGD on a batch of training examples.
        params, opt_state = update(params, opt_state, next(train))
        avg_params = ema_update(avg_params, params)
def __init__(
    self,
    num_agent_steps,
    state_space,
    action_space,
    seed,
    max_grad_norm=None,
    gamma=0.99,
    nstep=1,
    buffer_size=10 ** 6,
    use_per=False,
    batch_size=32,
    start_steps=50000,
    update_interval=4,
    update_interval_target=8000,
    eps=0.01,
    eps_eval=0.001,
    eps_decay_steps=250000,
    loss_type="huber",
    dueling_net=False,
    double_q=False,
    setup_net=True,
    fn=None,
    lr=5e-5,
    lr_cum_p=2.5e-9,
    units=(512,),
    num_quantiles=32,
    num_cosines=64,
):
    super(FQF, self).__init__(
        num_agent_steps=num_agent_steps,
        state_space=state_space,
        action_space=action_space,
        seed=seed,
        max_grad_norm=max_grad_norm,
        gamma=gamma,
        nstep=nstep,
        buffer_size=buffer_size,
        batch_size=batch_size,
        use_per=use_per,
        start_steps=start_steps,
        update_interval=update_interval,
        update_interval_target=update_interval_target,
        eps=eps,
        eps_eval=eps_eval,
        eps_decay_steps=eps_decay_steps,
        loss_type=loss_type,
        dueling_net=dueling_net,
        double_q=double_q,
        setup_net=False,
        num_quantiles=num_quantiles,
    )
    if setup_net:
        if fn is None:

            def fn(s, cum_p):
                return DiscreteImplicitQuantileFunction(
                    num_cosines=num_cosines,
                    action_space=action_space,
                    hidden_units=units,
                    dueling_net=dueling_net,
                )(s, cum_p)

        self.net, self.params, fake_feature = make_quantile_nerwork(
            self.rng, state_space, action_space, fn, num_quantiles)
        self.params_target = self.params
        opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
        self.opt_state = opt_init(self.params)

        # Fraction proposal network.
        self.cum_p_net = hk.without_apply_rng(
            hk.transform(lambda s: CumProbNetwork(num_quantiles=num_quantiles)(s)))
        self.params_cum_p = self.cum_p_net.init(next(self.rng), fake_feature)
        opt_init, self.opt_cum_p = optix.rmsprop(
            lr_cum_p, decay=0.95, eps=1e-5, centered=True)
        self.opt_state_cum_p = opt_init(self.params_cum_p)
def __init__(
    self,
    num_agent_steps,
    state_space,
    action_space,
    seed,
    max_grad_norm=None,
    gamma=0.99,
    nstep=1,
    buffer_size=10 ** 6,
    use_per=False,
    batch_size=32,
    start_steps=50000,
    update_interval=4,
    update_interval_target=8000,
    eps=0.01,
    eps_eval=0.001,
    eps_decay_steps=250000,
    loss_type="huber",
    dueling_net=False,
    double_q=False,
    setup_net=True,
    fn=None,
    lr=2.5e-4,
    units=(512,),
):
    super(DQN, self).__init__(
        num_agent_steps=num_agent_steps,
        state_space=state_space,
        action_space=action_space,
        seed=seed,
        max_grad_norm=max_grad_norm,
        gamma=gamma,
        nstep=nstep,
        buffer_size=buffer_size,
        batch_size=batch_size,
        use_per=use_per,
        start_steps=start_steps,
        update_interval=update_interval,
        update_interval_target=update_interval_target,
        eps=eps,
        eps_eval=eps_eval,
        eps_decay_steps=eps_decay_steps,
        loss_type=loss_type,
        dueling_net=dueling_net,
        double_q=double_q,
    )
    if setup_net:
        if fn is None:

            def fn(s):
                return DiscreteQFunction(
                    action_space=action_space,
                    hidden_units=units,
                    dueling_net=dueling_net,
                )(s)

        self.net = hk.without_apply_rng(hk.transform(fn))
        self.params = self.params_target = self.net.init(
            next(self.rng), *self.fake_args)
        opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
        self.opt_state = opt_init(self.params)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.QNetwork,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
):
    """Initialize the agent."""

    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address), n_step=n_step, discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        client=reverb.TFClient(address),
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    def policy(params: hk.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        action_values = hk.transform(network).apply(params, observation)
        return rlax.epsilon_greedy(epsilon).sample(key, action_values)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        rng=hk.PRNGSequence(1),
        optimizer=optix.adam(learning_rate),
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        target_update_period=target_update_period,
        iterator=dataset.as_numpy_iterator(),
        replay_client=reverb.Client(address),
    )

    variable_client = variable_utils.VariableClient(learner, 'foo')

    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(1),
        variable_client=variable_client,
        adder=adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
def main_loop(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

    # Build and initialize Q-network.
    num_actions = env.action_spec().num_values
    network = build_network(num_actions)
    sample_input = env.observation_spec().generate_value()
    net_params = network.init(next(rng), sample_input)

    # Build and initialize optimizer.
    optimizer = optix.adam(FLAGS.learning_rate)
    opt_state = optimizer.init(net_params)

    @jax.jit
    def policy(net_params, key, obs):
        """Sample action from epsilon-greedy policy."""
        q = network.apply(net_params, obs)
        a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
        return q, a

    @jax.jit
    def eval_policy(net_params, key, obs):
        """Sample action from greedy policy."""
        q = network.apply(net_params, obs)
        return rlax.greedy().sample(key, q)

    @jax.jit
    def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
        """Update network weights wrt Q-learning loss."""

        def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
            q_tm1 = network.apply(net_params, obs_tm1)
            td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
            return rlax.l2_loss(td_error)

        dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1,
                                                 r_t, discount_t, q_t)
        updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
        net_params = optix.apply_updates(net_params, updates)
        return net_params, opt_state

    print(f"Training agent for {FLAGS.train_episodes} episodes")
    print("Returns range [-1.0, 1.0]")

    for episode in range(FLAGS.train_episodes):
        timestep = env.reset()
        obs_tm1 = timestep.observation
        _, a_tm1 = policy(net_params, next(rng), obs_tm1)

        while not timestep.last():
            new_timestep = env.step(int(a_tm1))
            obs_t = new_timestep.observation

            # Sample action from stochastic policy.
            q_t, a_t = policy(net_params, next(rng), obs_t)

            # Update Q-values.
            r_t = new_timestep.reward
            discount_t = FLAGS.discount_factor * new_timestep.discount
            net_params, opt_state = update(net_params, opt_state, obs_tm1,
                                           a_tm1, r_t, discount_t, q_t)

            timestep = new_timestep
            obs_tm1 = obs_t
            a_tm1 = a_t

        if not episode % FLAGS.evaluate_every:
            # Evaluate agent with deterministic policy.
            returns = 0.
            for _ in range(FLAGS.eval_episodes):
                timestep = env.reset()
                obs = timestep.observation

                while not timestep.last():
                    action = eval_policy(net_params, next(rng), obs)
                    timestep = env.step(int(action))
                    obs = timestep.observation

                returns += timestep.reward

            avg_returns = returns / FLAGS.eval_episodes
            print(f"Episode {episode:4d}: Average returns: {avg_returns:.2f}")
def main(_):
    # Make the network and optimiser.
    net = hk.transform(net_fn)
    opt = optix.adam(1e-3)

    # Define layerwise sparsities.
    def module_matching(s):
        def match_func(m, n, k):
            return m.endswith(s) and not sparsity_ignore(m, n, k)
        return match_func

    module_sparsity = ((module_matching("linear"), 0.98),
                       (module_matching("linear_1"), 0.9))

    # Training loss (cross-entropy).
    @jax.jit
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = jax.nn.one_hot(batch["label"], 10)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]

        return softmax_xent + 1e-4 * l2_loss

    # Evaluation metric (classification accuracy).
    @jax.jit
    def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
        predictions = net.apply(params, batch)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

    @jax.jit
    def get_updates(
        params: hk.Params,
        opt_state: OptState,
        batch: Batch,
    ) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        return updates, opt_state

    # We maintain avg_params, the exponential moving average of the "live"
    # params. avg_params is used only for evaluation.
    # For more, see: https://doi.org/10.1137/0330046
    @jax.jit
    def ema_update(
        avg_params: hk.Params,
        new_params: hk.Params,
        epsilon: float = 0.001,
    ) -> hk.Params:
        return jax.tree_multimap(
            lambda p1, p2: (1 - epsilon) * p1 + epsilon * p2,
            avg_params, new_params)

    # Make datasets.
    train = load_dataset("train", is_training=True, batch_size=1000)
    train_eval = load_dataset("train", is_training=False, batch_size=10000)
    test_eval = load_dataset("test", is_training=False, batch_size=10000)

    # Implementation note: it is possible to avoid pruned_params and just use
    # a single params which progressively gets pruned; the updates also would
    # not need to be masked in such an implementation. The current
    # implementation attempts to mimic the current TF implementation, which
    # allows previously deactivated connections to become active again if
    # active values drop below their value.

    # Initialize network and optimiser; note we draw an input to get shapes.
    pruned_params = params = avg_params = net.init(jax.random.PRNGKey(42),
                                                   next(train))
    masks = update_mask(params, 0., module_sparsity)
    opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(10001):
        if step % 1000 == 0:
            # Periodically evaluate classification accuracy on train & test sets.
            avg_params = apply_mask(avg_params, masks, module_sparsity)
            train_accuracy = accuracy(avg_params, next(train_eval))
            test_accuracy = accuracy(avg_params, next(test_eval))
            total_params, total_nnz, per_layer_sparsities = get_sparsity(
                avg_params)
            train_accuracy, test_accuracy, total_nnz, per_layer_sparsities = (
                jax.device_get((train_accuracy, test_accuracy, total_nnz,
                                per_layer_sparsities)))
            print(f"[Step {step}] Train / Test accuracy: "
                  f"{train_accuracy:.3f} / {test_accuracy:.3f}.")
            print(f"Non-zero params / Total: {total_nnz} / {total_params}; "
                  f"Total Sparsity: {1. - total_nnz / total_params:.3f}")

        # Do SGD on a batch of training examples.
        pruned_params = apply_mask(params, masks, module_sparsity)
        updates, opt_state = get_updates(pruned_params, opt_state, next(train))
        # Applying a straight-through estimator here (that is, not masking the
        # updates) leads to much worse performance.
        updates = apply_mask(updates, masks, module_sparsity)
        params = optix.apply_updates(params, updates)
        # We start pruning at iteration 1000 and end at iteration 8000.
        progress = min(max((step - 1000.) / 8000., 0.), 1.)
        if step % 200 == 0:
            sparsity_fraction = zhugupta_func(progress)
            masks = update_mask(params, sparsity_fraction, module_sparsity)
        avg_params = ema_update(avg_params, params)

    print(per_layer_sparsities)