def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)) -> SACNetworks:
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs):
    network = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_dimensions),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network1 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    network2 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    input_ = jnp.concatenate([obs, action], axis=-1)
    value1 = network1(input_)
    value2 = network2(input_)
    return jnp.concatenate([value1, value2], axis=-1)

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return SACNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())

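# Illustrative usage sketch (not from the library source): builds the SAC
# networks above for a hypothetical continuous-control spec and initializes the
# policy parameters. The spec fields follow acme's `EnvironmentSpec`; the
# example shapes and the helper name `_example_make_sac_networks` are
# assumptions for illustration only.
def _example_make_sac_networks():
  example_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(8,), dtype=float),
      actions=specs.BoundedArray(
          shape=(3,), dtype=float, minimum=-1.0, maximum=1.0),
      rewards=specs.Array(shape=(), dtype=float),
      discounts=specs.BoundedArray(
          shape=(), dtype=float, minimum=0.0, maximum=1.0))
  sac_networks = make_networks(example_spec)
  # The FeedForwardNetwork's init only needs a random key; the dummy inputs
  # were baked in above.
  policy_params = sac_networks.policy_network.init(jax.random.PRNGKey(0))
  return sac_networks, policy_params
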
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreRecurrentState[RecurrentState]):
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  rng = state.rng
  rng, policy_rng = jax.random.split(rng)

  observation = utils.add_batch_dim(observation)
  recurrent_state = utils.add_batch_dim(state.recurrent_state)
  action, new_recurrent_state = utils.squeeze_batch_dim(
      recurrent_policy(params, policy_rng, observation, recurrent_state))
  return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)

def critic_mean(
    critic_params: networks_lib.Params,
    observation: types.NestedArray,
    action: types.NestedArray,
) -> jnp.ndarray:
  # We add a batch dimension to make sure batch concat in critic_network
  # works correctly.
  observation = utils.add_batch_dim(observation)
  action = utils.add_batch_dim(action)
  # Computes the mean action-value estimate.
  logits, atoms = critic_network.apply(critic_params, observation, action)
  logits = utils.squeeze_batch_dim(logits)
  probabilities = jax.nn.softmax(logits)
  return jnp.sum(probabilities * atoms, axis=-1)

def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec.
    discriminator_transformed: Haiku-transformed discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output) / (exp(output) + pi(a|s))
      # logit(D) = log(D / (1 - D)) = log(exp(output) / pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)

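# A minimal numerical check of the identity in the comment above (illustrative
# only, not library code): with D = exp(output) / (exp(output) + pi(a|s)), the
# logit of D is simply output - log pi(a|s). Assumes `jnp` is imported as in
# the surrounding code; the helper name is hypothetical.
def _check_discriminator_logit_identity():
  output, logpi = 0.7, -1.2  # Arbitrary example values.
  d = jnp.exp(output) / (jnp.exp(output) + jnp.exp(logpi))
  assert jnp.allclose(jnp.log(d / (1.0 - d)), output - logpi)
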
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
  """Creates networks used by the agent."""

  action_spec = spec.actions

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

  def _actor_fn(obs):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks_lib.NearZeroInitializedLinear(num_dimensions),
        networks_lib.TanhToSpec(action_spec),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(
            layer_sizes=[*critic_layer_sizes, num_atoms]),
    ])
    value = network([obs, action])
    return value, critic_atoms

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return D4PGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda rng: policy.init(rng, dummy_obs), policy.apply),
      critic_network=networks_lib.FeedForwardNetwork(
          lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply))

def test_step(self):
  simple_spec = specs.Array(shape=(), dtype=float)

  spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec,
                               simple_spec)

  discriminator = _make_discriminator(spec)
  ail_network = ail_networks.AILNetworks(
      discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None)

  loss = losses.gail_loss()

  optimizer = optax.adam(.01)

  step = jax.jit(functools.partial(
      ail_learning.ail_update_step,
      optimizer=optimizer,
      ail_network=ail_network,
      loss_fn=loss))

  zero_transition = types.Transition(
      np.array([0.]), np.array([0.]), 0., 0., np.array([0.]))
  zero_transition = utils.add_batch_dim(zero_transition)

  one_transition = types.Transition(
      np.array([1.]), np.array([0.]), 0., 0., np.array([0.]))
  one_transition = utils.add_batch_dim(one_transition)

  key = jax.random.PRNGKey(0)
  discriminator_params, discriminator_state = discriminator.init(key)

  state = ail_learning.DiscriminatorTrainingState(
      optimizer_state=optimizer.init(discriminator_params),
      discriminator_params=discriminator_params,
      discriminator_state=discriminator_state,
      policy_params=None,
      key=key,
      steps=0,
  )

  expected_loss = [1.062, 1.057, 1.052]

  for i in range(3):
    state, loss = step(state, (one_transition, zero_transition))
    self.assertAlmostEqual(loss['total_loss'], expected_loss[i], places=3)

def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())

def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
  """Creates networks used by the agent."""

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  if isinstance(spec.actions, specs.DiscreteArray):
    num_dimensions = spec.actions.num_values
    # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
  else:
    num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = build_actor_fn(num_dimensions)

  if img_encoder_fn is not None:
    img_encoder = hk.without_apply_rng(
        hk.transform(img_encoder_fn, apply_rng=True))
    key = jax.random.PRNGKey(seed=42)
    temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
    dummy_hidden = img_encoder.apply(temp_encoder_params,
                                     dummy_obs['state_image'])
    img_encoder_network = networks_lib.FeedForwardNetwork(
        lambda key: img_encoder.init(key, dummy_hidden), img_encoder.apply)

    dummy_policy_input = dict(
        state_image=dummy_hidden,
        state_dense=dummy_obs['state_dense'],
    )
  else:
    img_encoder_fn = None
    dummy_policy_input = dummy_obs
    img_encoder_network = None

  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

  return BCNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_policy_input), policy.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode(),
      img_encoder=img_encoder_network,
  )

def make_networks(
    spec: specs.EnvironmentSpec,
    discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""

  if discrete_actions:
    final_layer_size = spec.actions.num_values
  else:
    final_layer_size = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs, is_training=False, key=None):
    # is_training and key allow defining train/test-dependent modules
    # like dropout.
    del is_training
    del key
    if discrete_actions:
      network = hk.nets.MLP([64, 64, final_layer_size])
    else:
      network = hk.Sequential([
          networks_lib.LayerNormMLP([64, 64], activate_final=True),
          networks_lib.NormalTanhDistribution(final_layer_size),
      ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

  # Create dummy observations and actions to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network

def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512,),
    use_conv: bool = True,
) -> PPONetworks:
  """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define number of actions.
    hidden_layer_sizes: Network definition.
    use_conv: Whether to use a conv or MLP feature extractor.

  Returns:
    PPONetworks
  """

  num_actions = environment_spec.actions.num_values

  def forward_fn(inputs):
    layers = []
    if use_conv:
      layers.extend([networks_lib.AtariTorso()])
    layers.extend([
        hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
        networks_lib.CategoricalValueHead(num_values=num_actions)
    ])
    policy_value_network = hk.Sequential(layers)
    return policy_value_network(inputs)

  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  network = networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
  # Create PPONetworks to add functionality required by the agent.
  return make_ppo_networks(network)

def make_haiku_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]:
  """Builds functional impala network from recurrent model definitions."""
  # Make networks purely functional.
  forward_hk = hk.without_apply_rng(hk.transform(forward_fn))
  initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))
  unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))

  # Define networks init functions.
  def initial_state_init_fn(rng: networks_lib.PRNGKey) -> hk.Params:
    return initial_state_hk.init(rng)

  # Note: batch axis is not needed for the actors.
  dummy_obs = utils.zeros_like(env_spec.observations)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs)

  def unroll_init_fn(rng: networks_lib.PRNGKey,
                     initial_state: types.RecurrentState) -> hk.Params:
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  return IMPALANetworks(
      forward_fn=forward_hk.apply,
      unroll_init_fn=unroll_init_fn,
      unroll_fn=unroll_hk.apply,
      initial_state_init_fn=initial_state_init_fn,
      initial_state_fn=initial_state_hk.apply)

def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional r2d2 network from recurrent model definitions."""
  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)

  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)

  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)
  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)

def make_ensemble_policy_prior(
    policy_prior_network: mbop_networks.PolicyPriorNetwork,
    spec: specs.EnvironmentSpec,
    use_round_robin: bool = True) -> PolicyPrior:
  """Creates an ensemble policy prior from its network.

  Args:
    policy_prior_network: The policy prior network.
    spec: Environment specification.
    use_round_robin: Whether to use round robin or mean to calculate the policy
      prior over the ensemble members.

  Returns:
    A policy prior.
  """

  def _policy_prior(params: networks.Params, key: networks.PRNGKey,
                    observation_t: networks.Observation,
                    action_tm1: networks.Action) -> networks.Action:
    # Regressor policies are deterministic.
    del key
    apply_fn = (
        ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean)
    return apply_fn(
        policy_prior_network.apply,
        params,
        observation_t=observation_t,
        action_tm1=action_tm1)

  dummy_action = utils.zeros_like(spec.actions)
  dummy_action = utils.add_batch_dim(dummy_action)

  return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action)

def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)

  policy = hk.transform(policy, apply_rng=True)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.FeedForwardActor(
      policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)

def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  key = next(self._rng)
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  action = self._policy(self._client.params, key, observation)
  return utils.to_numpy_squeeze(action)

def batched_policy(
    params: network_types.Params,
    key: RNGKey,
    observation: Observation
) -> Union[Action, Tuple[Action, types.NestedArray]]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  output = policy(params, key, observation)
  return utils.squeeze_batch_dim(output)

def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: PRNGKey):
  rng = state
  rng1, rng2 = jax.random.split(rng)
  observation = utils.add_batch_dim(observation)
  action = utils.squeeze_batch_dim(policy(params, rng1, observation))
  return action, rng2

def make_haiku_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates Haiku networks to be used by the agent."""

  num_actions = spec.actions.num_values

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        networks_lib.CategoricalHead(num_actions)
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1)
    ])

    action_distribution = policy_network(inputs)
    value = value_network(inputs)
    return (action_distribution, value)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.

  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)

def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Make network purely functional.
  network_hk = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  network = networks_lib.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, dummy_obs),
      apply=network_hk.apply)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)

def make_initial_state(key: jnp.ndarray) -> TrainingState:
  """Initialises the training state (parameters and optimiser state)."""
  dummy_obs = utils.zeros_like(obs_spec)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  initial_state = initial_state_fn.apply(None)
  initial_params = unroll_fn.init(key, dummy_obs, initial_state)
  initial_opt_state = optimizer.init(initial_params)
  return TrainingState(params=initial_params, opt_state=initial_opt_state)

def unvectorized_select_action(
    params: networks_lib.Params,
    observations: networks_lib.Observation,
    state: State,
) -> Tuple[networks_lib.Action, State]:
  observations, state = utils.add_batch_dim((observations, state))
  actions, state = actor_core.select_action(params, observations, state)
  return utils.squeeze_batch_dim((actions, state))

def apply_and_sample(params: networks_lib.Params,
                     key: networks_lib.PRNGKey,
                     observation: networks_lib.Observation,
                     epsilon: Epsilon) -> networks_lib.Action:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  action_values = network.apply(params, observation)
  action_values = utils.squeeze_batch_dim(action_values)
  return rlax.epsilon_greedy(epsilon).sample(key, action_values)

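# Illustrative sketch (an assumption, not rlax's implementation): the
# epsilon-greedy rule used above picks argmax(action_values) with probability
# (1 - epsilon) and a uniformly random action otherwise. A hand-rolled
# equivalent, for reference only:
def _example_epsilon_greedy(key, action_values, epsilon):
  explore_key, uniform_key = jax.random.split(key)
  greedy = jnp.argmax(action_values, axis=-1)
  uniform = jax.random.randint(
      uniform_key, greedy.shape, 0, action_values.shape[-1])
  explore = jax.random.bernoulli(explore_key, epsilon, greedy.shape)
  return jnp.where(explore, uniform, greedy)
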
def select_action(params: networks_lib.Params,
                  observation: networks_lib.Observation,
                  state: SimpleActorCoreStateWithExtras):
  rng = state.rng
  rng1, rng2 = jax.random.split(rng)
  observation = utils.add_batch_dim(observation)
  action, extras = utils.squeeze_batch_dim(policy(params, rng1, observation))
  return action, SimpleActorCoreStateWithExtras(rng2, extras)

def make_network_from_module(
    module: hk.Transformed,
    spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
  """Creates a network with dummy init arguments using the specified module.

  Args:
    module: Module that expects one batch axis and one features axis for its
      inputs.
    spec: EnvironmentSpec shapes to derive dummy inputs.

  Returns:
    FeedForwardNetwork whose `init` method only takes a random key, and `apply`
    takes an observation and action and produces an output.
  """
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  return networks.FeedForwardNetwork(
      lambda key: module.init(key, dummy_obs, dummy_action), module.apply)

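# Illustrative usage sketch (not from the library source): any Haiku-transformed
# module taking (observation, action) can be wrapped as above. The names
# `_example_wrap_critic` and `_example_critic` are hypothetical.
def _example_wrap_critic(example_spec: specs.EnvironmentSpec):
  def _example_critic(obs, action):
    # A small MLP critic over the concatenated observation and action.
    return hk.nets.MLP([64, 64, 1])(jnp.concatenate([obs, action], axis=-1))

  critic_module = hk.without_apply_rng(hk.transform(_example_critic))
  return make_network_from_module(critic_module, example_spec)
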
def batched_recurrent_policy(
    params: network_types.Params,
    key: RNGKey,
    observation: Observation,
    core_state: RecurrentState
) -> Tuple[Union[Action, Tuple[Action, types.NestedArray]], RecurrentState]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  output, new_state = recurrent_policy(params, key, observation, core_state)
  return output, new_state

def batched_policy(
    params: network_lib.Params,
    key: network_lib.PRNGKey,
    observation: network_lib.Observation
) -> Tuple[Union[network_lib.Action, Tuple[network_lib.Action,
                                           types.NestedArray]],
           network_lib.PRNGKey]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  key, key2 = jax.random.split(key)
  observation = utils.add_batch_dim(observation)
  output = policy(params, key2, observation)
  return utils.squeeze_batch_dim(output), key

def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  action, new_state = self._recurrent_policy(
      self._client.params,
      key=next(self._rng),
      observation=utils.add_batch_dim(observation),
      core_state=self._state)
  self._prev_state = self._state  # Keep previous state to save in replay.
  self._state = new_state  # Keep new state for next policy call.
  return utils.to_numpy_squeeze(action)

def batched_policy(
    params,
    observation,
    discrete_action,
):
  observation = utils.add_batch_dim(observation)
  action = utils.squeeze_batch_dim(
      policy(params, observation, discrete_action))
  return action

def __init__(self,
             network: hk.Transformed,
             obs_spec: specs.Array,
             optimizer: optax.GradientTransformation,
             rng: hk.PRNGSequence,
             dataset: tf.data.Dataset,
             loss_fn: LossFn = _sparse_categorical_cross_entropy,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""

  def loss(params: hk.Params,
           sample: reverb.ReplaySample) -> jnp.DeviceArray:
    # Pull out the data needed for updates.
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    del r_t, d_t, o_t
    logits = network.apply(params, o_tm1)
    return jnp.mean(loss_fn(a_tm1, logits))

  def sgd_step(
      state: TrainingState,
      sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
    """Do a step of SGD."""
    grad_fn = jax.value_and_grad(loss)
    loss_value, gradients = grad_fn(state.params, sample)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    steps = state.steps + 1
    new_state = TrainingState(
        params=new_params, opt_state=new_opt_state, steps=steps)

    # Compute the global norm of the gradients for logging.
    global_gradient_norm = optax.global_norm(gradients)
    fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

    return new_state, fetches

  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Get an iterator over the dataset.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
  # TODO(b/155086959): Fix type stubs and remove.

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)

  self._state = TrainingState(
      params=initial_params, opt_state=initial_opt_state, steps=0)

  self._sgd_step = jax.jit(sgd_step)

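# Hedged sketch of a sparse categorical cross-entropy of the kind the learner
# above defaults to; the real `_sparse_categorical_cross_entropy` may be
# implemented differently (e.g. via rlax). `labels` are integer actions and
# `logits` are unnormalized per-action scores; the helper name is hypothetical.
def _example_sparse_categorical_cross_entropy(
    labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
  # Per-example cross entropy for integer labels: -log softmax(logits)[label].
  log_probs = jax.nn.log_softmax(logits)
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(one_hot * log_probs, axis=-1)
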
def batched_recurrent_policy(
    params: network_lib.Params,
    key: network_lib.PRNGKey,
    observation: network_lib.Observation,
    core_state: RecurrentState
) -> Tuple[Union[network_lib.Action, Tuple[network_lib.Action,
                                           types.NestedArray]],
           RecurrentState, network_lib.PRNGKey]:
  # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
  observation = utils.add_batch_dim(observation)
  key, key2 = jax.random.split(key)
  output, new_state = recurrent_policy(params, key2, observation, core_state)
  return output, new_state, key