def train_with_bc(make_demonstrations: Callable[[int],
                                                Iterator[types.Transition]],
                  networks: networks_lib.FeedForwardNetwork,
                  loss: losses.Loss,
                  num_steps: int = 100000) -> networks_lib.Params:
  """Trains the given network with BC and returns the params.

  Args:
    make_demonstrations: A function (batch_size) -> iterator with
      demonstrations to be imitated.
    networks: Network taking (params, obs, is_training, key) as input
    loss: BC loss to use.
    num_steps: number of training steps

  Returns:
    The trained network params.
  """
  demonstration_iterator = make_demonstrations(256)
  prefetching_iterator = utils.sharded_prefetch(
      demonstration_iterator,
      buffer_size=2,
      num_threads=jax.local_device_count())

  learner = learning.BCLearner(
      network=networks,
      random_key=jax.random.PRNGKey(0),
      loss_fn=loss,
      prefetching_iterator=prefetching_iterator,
      optimizer=optax.adam(1e-4),
      num_sgd_steps_per_step=1)

  # Train the agent
  for _ in range(num_steps):
    learner.step()

  return learner.get_variables(['policy'])[0]
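
# The block below is a minimal, self-contained sketch of the same pretraining
# pattern in plain JAX/optax (a toy linear model and a synthetic demonstration
# iterator, not Acme's BCLearner): draw batches from a demonstration iterator,
# take SGD steps on a supervised loss for a fixed number of steps, and return
# the resulting parameters.

import jax
import jax.numpy as jnp
import numpy as np
import optax


def _bc_pretrain_sketch(num_steps: int = 100, batch_size: int = 256):
  """Toy behavioural cloning: regress actions from observations."""

  def demonstrations():
    rng = np.random.default_rng(0)
    target_w = rng.normal(size=(8, 2)).astype(np.float32)
    while True:
      obs = rng.normal(size=(batch_size, 8)).astype(np.float32)
      yield obs, obs @ target_w  # (observation, expert action) batches.

  def apply_fn(params, obs):
    return obs @ params['w'] + params['b']

  def loss_fn(params, obs, act):
    return jnp.mean(jnp.square(apply_fn(params, obs) - act))

  params = {'w': jnp.zeros((8, 2)), 'b': jnp.zeros((2,))}
  optimizer = optax.adam(1e-4)
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, obs, act):
    grads = jax.grad(loss_fn)(params, obs, act)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

  iterator = demonstrations()
  for _ in range(num_steps):
    obs, act = next(iterator)
    params, opt_state = step(params, opt_state, obs, act)
  return params
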
def __init__(self,
             networks: mbop_networks.MBOPNetworks,
             losses: mbop_losses.MBOPLosses,
             iterator: Iterator[types.Transition],
             rng_key: jax_types.PRNGKey,
             logger_fn: LoggerFn,
             make_world_model_learner: MakeWorldModelLearner,
             make_policy_prior_learner: MakePolicyPriorLearner,
             make_n_step_return_learner: MakeNStepReturnLearner,
             counter: Optional[counting.Counter] = None):
  """Creates an MBOP learner.

  Args:
    networks: One network per model.
    losses: One loss per model.
    iterator: An iterator of time-batched transitions used to train the
      networks.
    rng_key: Random key.
    logger_fn: Constructs a logger for a label.
    make_world_model_learner: Function to create the world model learner.
    make_policy_prior_learner: Function to create the policy prior learner.
    make_n_step_return_learner: Function to create the n-step return learner.
    counter: Parent counter object.
  """
  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger_fn('', 'steps')

  # Prepare iterators for the learners, to not split the data (preserve sample
  # efficiency).
  sharded_prefetching_dataset = utils.sharded_prefetch(iterator)
  world_model_iterator, policy_prior_iterator, n_step_return_iterator = (
      itertools.tee(sharded_prefetching_dataset, 3))

  world_model_key, policy_prior_key, n_step_return_key = jax.random.split(
      rng_key, 3)

  self._world_model = make_world_model_learner(logger_fn, self._counter,
                                               world_model_key,
                                               world_model_iterator,
                                               networks.world_model_network,
                                               losses.world_model_loss)
  self._policy_prior = make_policy_prior_learner(
      logger_fn, self._counter, policy_prior_key, policy_prior_iterator,
      networks.policy_prior_network, losses.policy_prior_loss)
  self._n_step_return = make_n_step_return_learner(
      logger_fn, self._counter, n_step_return_key, n_step_return_iterator,
      networks.n_step_return_network, losses.n_step_return_loss)

  # Start recording timestamps after the first learning step to not report
  # "warmup" time.
  self._timestamp = None

  self._learners = {
      'world_model': self._world_model,
      'policy_prior': self._policy_prior,
      'n_step_return': self._n_step_return
  }
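
# A small, self-contained sketch (plain Python, no Acme dependencies) of the
# `itertools.tee` pattern used above: a single prefetched iterator is
# duplicated so the world model, policy prior and n-step return learners each
# see every sample, rather than splitting the stream three ways.

import itertools


def _tee_sketch():
  batches = iter(range(6))  # Stand-in for the sharded prefetching dataset.
  world_model_it, policy_prior_it, n_step_return_it = itertools.tee(batches, 3)

  # Each consumer sees the full stream. Note that tee buffers items until the
  # slowest consumer has read them, so the learners should be stepped in
  # lockstep to keep that buffer small.
  assert list(world_model_it) == list(range(6))
  assert list(policy_prior_it) == list(range(6))
  assert list(n_step_return_it) == list(range(6))
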
def main(_):
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      prefetching_iterator=utils.sharded_prefetch(dataset),
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
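
# The `logp_fn` above computes log pi(a|s) from unnormalised logits: the
# one-hot inner product selects the logit of the taken action, and subtracting
# logsumexp normalises it. A quick self-contained check (plain JAX, toy data)
# that this matches indexing into `jax.nn.log_softmax`:

import jax
import jax.numpy as jnp
from jax.scipy import special


def _logp_check():
  logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
  actions = jnp.array([0, 2])

  picked = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
  logp = picked - special.logsumexp(logits, axis=-1)

  expected = jnp.take_along_axis(
      jax.nn.log_softmax(logits), actions[:, None], axis=-1)[:, 0]
  assert jnp.allclose(logp, expected)
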
def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: networks_lib.FeedForwardNetwork,
    dataset: Iterator[types.Transition],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    *,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
  del environment_spec
  return learning.BCLearner(
      network=networks,
      random_key=random_key,
      loss_fn=self._loss_fn,
      optimizer=optax.adam(learning_rate=self._config.learning_rate),
      prefetching_iterator=utils.sharded_prefetch(dataset),
      num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
      loss_has_aux=self._loss_has_aux,
      logger=logger_fn('learner'),
      counter=counter)
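
# `num_sgd_steps_per_step` lets a single learner step consume a "super-batch"
# and apply several optimiser updates inside one jitted call. The sketch below
# (plain JAX/optax with a toy quadratic loss; an illustration of the idea, not
# the BCLearner internals) shows one way such a knob can be implemented with
# `jax.lax.scan`:

import jax
import jax.numpy as jnp
import optax


def _multi_sgd_sketch(num_sgd_steps_per_step: int = 4):
  optimizer = optax.adam(1e-2)

  def loss_fn(params, batch):
    return jnp.mean(jnp.square(params - batch))

  @jax.jit
  def learner_step(params, opt_state, super_batch):
    # super_batch has shape [num_sgd_steps_per_step, batch_size, ...].
    def one_sgd_step(carry, batch):
      params, opt_state = carry
      grads = jax.grad(loss_fn)(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      return (optax.apply_updates(params, updates), opt_state), None

    (params, opt_state), _ = jax.lax.scan(
        one_sgd_step, (params, opt_state), super_batch)
    return params, opt_state

  params = jnp.zeros(())
  opt_state = optimizer.init(params)
  super_batch = jnp.ones((num_sgd_steps_per_step, 32))
  params, opt_state = learner_step(params, opt_state, super_batch)
  return params
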
def __init__(
    self,
    obs_spec: specs.Array,
    unroll_fn: networks_lib.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    random_key: networks_lib.PRNGKey,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    devices: Optional[Sequence[jax.xla.Device]] = None,
    prefetch_size: int = 2,
    num_prefetch_threads: Optional[int] = None,
):
  self._devices = devices or jax.local_devices()

  # Transform into pure functions.
  unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn, apply_rng=True))
  initial_state_fn = hk.without_apply_rng(
      hk.transform(initial_state_fn, apply_rng=True))

  loss_fn = losses.impala_loss(
      unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState,
      sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)

    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    dummy_obs = utils.zeros_like(obs_spec)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    initial_state = initial_state_fn.apply(None)
    initial_params = unroll_fn.init(key, dummy_obs, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  state = make_initial_state(random_key)
  self._state = utils.replicate_in_all_devices(state, self._devices)

  if num_prefetch_threads is None:
    num_prefetch_threads = len(self._devices)
  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      devices=devices,
      num_threads=num_prefetch_threads,
  )

  self._sgd_step = jax.pmap(
      sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
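
# A self-contained sketch (plain JAX/optax, toy linear model) of the
# data-parallel pattern used by `sgd_step` above: the step function is
# `jax.pmap`-ed over devices with a named axis, and gradients are averaged
# across replicas with `jax.lax.pmean` before the optimiser update, while the
# parameters and optimiser state are replicated on every device.

import jax
import jax.numpy as jnp
import optax

_AXIS = 'data'


def _pmap_sgd_sketch():
  num_devices = jax.local_device_count()
  optimizer = optax.adam(1e-2)

  def loss_fn(params, x, y):
    return jnp.mean(jnp.square(x * params - y))

  def sgd_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name=_AXIS)  # Average across replicas.
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

  p_sgd_step = jax.pmap(sgd_step, axis_name=_AXIS)

  # Replicate params/optimiser state and shard the batch across devices.
  params = jnp.zeros(())
  opt_state = optimizer.init(params)
  r_params = jax.device_put_replicated(params, jax.local_devices())
  r_opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
  x = jnp.ones((num_devices, 8))
  y = 2.0 * x
  r_params, r_opt_state = p_sgd_step(r_params, r_opt_state, x, y)
  return r_params
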
def __init__(self,
             unroll: networks_lib.FeedForwardNetwork,
             initial_state: networks_lib.FeedForwardNetwork,
             batch_size: int,
             random_key: networks_lib.PRNGKey,
             burn_in_length: int,
             discount: float,
             importance_sampling_exponent: float,
             max_priority_weight: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             bootstrap_n: int = 5,
             tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
             clip_rewards: bool = False,
             max_abs_reward: float = 1.,
             use_core_state: bool = True,
             prefetch_size: int = 2,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""

  random_key, key_initial_1, key_initial_2 = jax.random.split(random_key, 3)
  initial_state_params = initial_state.init(key_initial_1, batch_size)
  initial_state = initial_state.apply(initial_state_params, key_initial_2,
                                      batch_size)

  def loss(
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      key_grad: networks_lib.PRNGKey,
      sample: reverb.ReplaySample
  ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
    """Computes mean transformed N-step loss for a batch of sequences."""

    # Convert sample data to sequence-major format [T, B, ...].
    data = utils.batch_to_sequence(sample.data)

    # Get core state & warm it up on observations for a burn-in period.
    if use_core_state:
      # Replay core state.
      online_state = jax.tree_map(lambda x: x[0], data.extras['core_state'])
    else:
      online_state = initial_state
    target_state = online_state

    # Maybe burn the core state in.
    if burn_in_length:
      burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
      key_grad, key1, key2 = jax.random.split(key_grad, 3)
      _, online_state = unroll.apply(params, key1, burn_obs, online_state)
      _, target_state = unroll.apply(target_params, key2, burn_obs,
                                     target_state)

    # Only get data to learn on from after the end of the burn in period.
    data = jax.tree_map(lambda seq: seq[burn_in_length:], data)

    # Unroll on sequences to get online and target Q-Values.
    key1, key2 = jax.random.split(key_grad)
    online_q, _ = unroll.apply(params, key1, data.observation, online_state)
    target_q, _ = unroll.apply(target_params, key2, data.observation,
                               target_state)

    # Get value-selector actions from online Q-values for double Q-learning.
    selector_actions = jnp.argmax(online_q, axis=-1)
    # Preprocess discounts & rewards.
    discounts = (data.discount * discount).astype(online_q.dtype)
    rewards = data.reward
    if clip_rewards:
      rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
    rewards = rewards.astype(online_q.dtype)

    # Get N-step transformed TD error and loss.
    batch_td_error_fn = jax.vmap(
        functools.partial(
            rlax.transformed_n_step_q_learning,
            n=bootstrap_n,
            tx_pair=tx_pair),
        in_axes=1,
        out_axes=1)
    # TODO(b/183945808): when this bug is fixed, truncations of actions,
    # rewards, and discounts will no longer be necessary.
    batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                       target_q[1:], selector_actions[1:],
                                       rewards[:-1], discounts[:-1])
    batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

    # Importance weighting.
    probs = sample.info.probability
    importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)
    mean_loss = jnp.mean(importance_weights * batch_loss)

    # Calculate priorities as a mixture of max and mean sequence errors.
    abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
    max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
    mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0)
    priorities = (max_priority + mean_priority)

    return mean_loss, priorities

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample
  ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Performs an update step, averaging over pmap replicas."""

    # Compute loss and gradients.
    grad_fn = jax.value_and_grad(loss, has_aux=True)
    key, key_grad = jax.random.split(state.random_key)
    (loss_value, priorities), gradients = grad_fn(state.params,
                                                  state.target_params,
                                                  key_grad, samples)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply optimizer updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, self._target_update_period)

    new_state = TrainingState(
        params=new_params,
        target_params=target_params,
        opt_state=new_opt_state,
        steps=steps,
        random_key=key)
    return new_state, priorities, {'loss': loss_value}

  def update_priorities(keys_and_priorities: Tuple[jnp.ndarray, jnp.ndarray]):
    keys, priorities = keys_and_priorities
    keys, priorities = tree.map_structure(
        # Fetch array and combine device and batch dimensions.
        lambda x: utils.fetch_devicearray(x).reshape((-1,) + x.shape[2:]),
        (keys, priorities))
    replay_client.mutate_priorities(  # pytype: disable=attribute-error
        table=adders.DEFAULT_PRIORITY_TABLE,
        updates=dict(zip(keys, priorities)))

  # Internalise components, hyperparameters, logger, counter, and methods.
  self._replay_client = replay_client
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      time_delta=1.)

  self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

  # Initialise and internalise training state (parameters/optimiser state).
  random_key, key_init = jax.random.split(random_key)
  initial_params = unroll.init(key_init, initial_state)
  opt_state = optimizer.init(initial_params)

  state = TrainingState(
      params=initial_params,
      target_params=initial_params,
      opt_state=opt_state,
      steps=jnp.array(0),
      random_key=random_key)
  # Replicate parameters.
  self._state = utils.replicate_in_all_devices(state)

  # Shard multiple inputs with on-device prefetching.
  # We split samples in two outputs, the keys which need to be kept on-host
  # since int64 arrays are not supported in TPUs, and the entire sample
  # separately so it can be sent to the sgd_step method.
  def split_sample(sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
    return utils.PrefetchingSplit(host=sample.info.key, device=sample)

  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      num_threads=jax.local_device_count(),
      split_fn=split_sample)
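
# A small numeric sketch (plain jax.numpy, toy TD errors; not the learner
# code) of the two re-weightings in the loss above: per-sequence priorities as
# a mixture of max and mean absolute TD error, and importance weights
# (1/p)^beta normalised by their maximum.

import jax.numpy as jnp


def _priority_sketch(max_priority_weight: float = 0.9,
                     importance_sampling_exponent: float = 0.6):
  # [T, B] TD errors for a batch of 2 sequences of length 3.
  batch_td_error = jnp.array([[0.5, -2.0],
                              [1.5, 0.0],
                              [-1.0, 1.0]])
  abs_td_error = jnp.abs(batch_td_error)
  priorities = (
      max_priority_weight * jnp.max(abs_td_error, axis=0) +
      (1 - max_priority_weight) * jnp.mean(abs_td_error, axis=0))

  # Replay sampling probabilities for the same 2 sequences.
  probs = jnp.array([0.1, 0.4])
  importance_weights = (1. / (probs + 1e-6)) ** importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)
  return priorities, importance_weights
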
def __init__(
    self,
    networks: impala_networks.IMPALANetworks,
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    random_key: networks_lib.PRNGKey,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    devices: Optional[Sequence[jax.xla.Device]] = None,
    prefetch_size: int = 2,
    num_prefetch_threads: Optional[int] = None,
):
  local_devices = jax.local_devices()
  process_id = jax.process_index()
  logging.info('Learner process id: %s. Devices passed: %s', process_id,
               devices)
  logging.info('Learner process id: %s. Local devices from JAX API: %s',
               process_id, local_devices)
  self._devices = devices or local_devices
  self._local_devices = [d for d in self._devices if d in local_devices]

  loss_fn = losses.impala_loss(
      networks.unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState,
      sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
        'param_norm': optax.global_norm(new_params),
        'param_updates_norm': optax.global_norm(updates),
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)

    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    key, key_initial_state = jax.random.split(key)
    # Note: parameters do not depend on the batch size, so initial_state below
    # does not need a batch dimension.
    params = networks.initial_state_init_fn(key_initial_state)
    # TODO(jferret): as it stands, we do not yet support
    # training the initial state params.
    initial_state = networks.initial_state_fn(params)

    initial_params = networks.unroll_init_fn(key, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  state = make_initial_state(random_key)
  self._state = utils.replicate_in_all_devices(state, self._local_devices)

  if num_prefetch_threads is None:
    num_prefetch_threads = len(self._local_devices)
  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      devices=self._local_devices,
      num_threads=num_prefetch_threads,
  )

  self._sgd_step = jax.pmap(
      sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
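
# The extra metrics in this learner's `sgd_step` rely on `optax.global_norm`,
# i.e. the l2 norm taken over all leaves of a parameter (or update) pytree.
# A tiny self-contained check with a toy pytree:

import jax.numpy as jnp
import optax


def _global_norm_sketch():
  params = {'w': jnp.array([3.0, 0.0]), 'b': jnp.array(4.0)}
  norm = optax.global_norm(params)
  assert jnp.isclose(norm, 5.0)  # sqrt(3^2 + 0^2 + 4^2)
  return norm
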