def get_dummy_batched_obs_and_actions(environment_spec):
  """Generates dummy batched (batch_size=1) obs and actions."""
  dummy_observation = utils.tile_nested(
      utils.zeros_like(environment_spec.observations), 1)
  dummy_action = utils.tile_nested(
      utils.zeros_like(environment_spec.actions), 1)
  return dummy_observation, dummy_action
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)) -> SACNetworks:
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs):
    network = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_dimensions),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network1 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    network2 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    input_ = jnp.concatenate([obs, action], axis=-1)
    value1 = network1(input_)
    value2 = network2(input_)
    return jnp.concatenate([value1, value2], axis=-1)

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return SACNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
def init_state(nest: types.Nest) -> RunningStatisticsState:
  """Initializes the running statistics for the given nested structure."""
  dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32

  return RunningStatisticsState(
      count=0.,
      mean=utils.zeros_like(nest, dtype=dtype),
      summed_variance=utils.zeros_like(nest, dtype=dtype),
      # Initialize with ones to make sure normalization works correctly
      # in the initial state.
      std=utils.ones_like(nest, dtype=dtype))
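# Illustrative check only (not part of the library; relies on the same jnp
# import as the code above, and all names below are local to this sketch):
# with the state returned by `init_state`, normalization is the identity,
# because (x - mean) / std reduces to (x - 0) / 1.
def _normalization_is_identity_initially():
  x = jnp.array([0.5, -1.0, 2.0])
  state_mean = jnp.zeros_like(x)  # mirrors utils.zeros_like(nest) above
  state_std = jnp.ones_like(x)    # mirrors utils.ones_like(nest) above
  assert jnp.allclose((x - state_mean) / state_std, x)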
def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec.
    discriminator_transformed: Haiku-transformed discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output)/(exp(output) + pi(a|s))
      # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)
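# Illustrative numeric check of the "Quick Maths" comment above (not part of
# the agent): for a scalar `output` and action probability `pi`, the logit of
# D = exp(output) / (exp(output) + pi) equals output - log(pi). The values
# below are arbitrary.
def _check_discriminator_logit_identity(output=0.7, pi=0.2):
  d = jnp.exp(output) / (jnp.exp(output) + pi)
  logit_d = jnp.log(d / (1.0 - d))
  assert jnp.allclose(logit_d, output - jnp.log(pi))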
def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
  """Creates networks used by the agent."""
  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  if isinstance(spec.actions, specs.DiscreteArray):
    num_dimensions = spec.actions.num_values
    # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
  else:
    num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = build_actor_fn(num_dimensions)

  if img_encoder_fn is not None:
    img_encoder = hk.without_apply_rng(
        hk.transform(img_encoder_fn, apply_rng=True))
    key = jax.random.PRNGKey(seed=42)
    temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
    dummy_hidden = img_encoder.apply(temp_encoder_params,
                                     dummy_obs['state_image'])
    # Initialize the encoder with the image input it receives at apply time.
    img_encoder_network = networks_lib.FeedForwardNetwork(
        lambda key: img_encoder.init(key, dummy_obs['state_image']),
        img_encoder.apply)
    dummy_policy_input = dict(
        state_image=dummy_hidden,
        state_dense=dummy_obs['state_dense'],
    )
  else:
    img_encoder_fn = None
    dummy_policy_input = dummy_obs
    img_encoder_network = None

  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

  return BCNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_policy_input), policy.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode(),
      img_encoder=img_encoder_network,
  )
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
  """Creates networks used by the agent."""

  action_spec = spec.actions

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

  def _actor_fn(obs):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks_lib.NearZeroInitializedLinear(num_dimensions),
        networks_lib.TanhToSpec(action_spec),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(
            layer_sizes=[*critic_layer_sizes, num_atoms]),
    ])
    value = network([obs, action])
    return value, critic_atoms

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return D4PGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda rng: policy.init(rng, dummy_obs), policy.apply),
      critic_network=networks_lib.FeedForwardNetwork(
          lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply))
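# Conceptual sketch only (not part of the agent above): the distributional
# critic returns per-atom logits together with the fixed support
# `critic_atoms`; a scalar Q-value is the expectation of the induced
# categorical distribution. Relies on the same jax/jnp imports as the code
# above; the helper name is made up.
def _expected_q_value(logits, atoms):
  probabilities = jax.nn.softmax(logits, axis=-1)
  return jnp.sum(probabilities * atoms, axis=-1)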
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional r2d2 network from recurrent model definitions."""

  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)

  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)

  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)

  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)
def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512,),
    use_conv: bool = True,
) -> PPONetworks:
  """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define number of actions.
    hidden_layer_sizes: Network definition.
    use_conv: Whether to use a conv or MLP feature extractor.

  Returns:
    PPONetworks
  """

  num_actions = environment_spec.actions.num_values

  def forward_fn(inputs):
    layers = []
    if use_conv:
      layers.extend([networks_lib.AtariTorso()])
    layers.extend([
        hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
        networks_lib.CategoricalValueHead(num_values=num_actions)
    ])
    policy_value_network = hk.Sequential(layers)
    return policy_value_network(inputs)

  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  network = networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
  # Create PPONetworks to add functionality required by the agent.
  return make_ppo_networks(network)
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Make network purely functional.
  network_hk = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  network = networks_lib.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, dummy_obs),
      apply=network_hk.apply)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def make_haiku_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates Haiku networks to be used by the agent."""
  num_actions = spec.actions.num_values

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        networks_lib.CategoricalHead(num_actions)
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1)
    ])
    action_distribution = policy_network(inputs)
    value = value_network(inputs)
    return (action_distribution, value)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
def make_ensemble_policy_prior(
    policy_prior_network: mbop_networks.PolicyPriorNetwork,
    spec: specs.EnvironmentSpec,
    use_round_robin: bool = True) -> PolicyPrior:
  """Creates an ensemble policy prior from its network.

  Args:
    policy_prior_network: The policy prior network.
    spec: Environment specification.
    use_round_robin: Whether to use round robin or mean to calculate the
      policy prior over the ensemble members.

  Returns:
    A policy prior.
  """

  def _policy_prior(params: networks.Params, key: networks.PRNGKey,
                    observation_t: networks.Observation,
                    action_tm1: networks.Action) -> networks.Action:
    # Regressor policies are deterministic.
    del key
    apply_fn = (
        ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean)
    return apply_fn(
        policy_prior_network.apply,
        params,
        observation_t=observation_t,
        action_tm1=action_tm1)

  dummy_action = utils.zeros_like(spec.actions)
  dummy_action = utils.add_batch_dim(dummy_action)

  return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action)
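# Toy illustration only (independent of the `ensemble` module used above) of
# the two aggregation modes named in the docstring: `mean` averages every
# member's output for each input, while round robin routes batch element i to
# one member in turn. The array and helper name below are hypothetical.
def _ensemble_aggregation_example():
  # member_outputs[m, i] is member m's output for batch element i.
  member_outputs = jnp.array([[1.0, 1.0, 1.0],
                              [2.0, 2.0, 2.0],
                              [3.0, 3.0, 3.0]])
  mean_aggregate = member_outputs.mean(axis=0)                # [2., 2., 2.]
  batch_index = jnp.arange(member_outputs.shape[1])
  round_robin = member_outputs[batch_index % 3, batch_index]  # [1., 2., 3.]
  return mean_aggregate, round_robin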
def make_haiku_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]:
  """Builds functional impala network from recurrent model definitions."""
  # Make networks purely functional.
  forward_hk = hk.without_apply_rng(hk.transform(forward_fn))
  initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))
  unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))

  # Define networks init functions.
  def initial_state_init_fn(rng: networks_lib.PRNGKey) -> hk.Params:
    return initial_state_hk.init(rng)

  # Note: batch axis is not needed for the actors.
  dummy_obs = utils.zeros_like(env_spec.observations)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs)

  def unroll_init_fn(rng: networks_lib.PRNGKey,
                     initial_state: types.RecurrentState) -> hk.Params:
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  return IMPALANetworks(
      forward_fn=forward_hk.apply,
      unroll_init_fn=unroll_init_fn,
      unroll_fn=unroll_hk.apply,
      initial_state_init_fn=initial_state_init_fn,
      initial_state_fn=initial_state_hk.apply)
def make_networks(
    spec: specs.EnvironmentSpec,
    discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""

  if discrete_actions:
    final_layer_size = spec.actions.num_values
  else:
    final_layer_size = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs, is_training=False, key=None):
    # is_training and key allow defining train/test dependent modules
    # like dropout.
    del is_training
    del key
    if discrete_actions:
      network = hk.nets.MLP([64, 64, final_layer_size])
    else:
      network = hk.Sequential([
          networks_lib.LayerNormMLP([64, 64], activate_final=True),
          networks_lib.NormalTanhDistribution(final_layer_size),
      ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

  # Create dummy observations and actions to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network
def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)

  policy = hk.transform(policy, apply_rng=True)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.FeedForwardActor(
      policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def _update_spec(self, base_spec):
  dummy_obs = utils.zeros_like(base_spec)
  emb, _ = self._distance_fn(dummy_obs['state'], dummy_obs['goal'])
  full_spec = dict(base_spec)
  full_spec['embeddings'] = dm_env_specs.Array(
      shape=emb.shape, dtype=emb.dtype)
  return full_spec
def make_initial_state(key: jnp.ndarray) -> TrainingState: """Initialises the training state (parameters and optimiser state).""" dummy_obs = utils.zeros_like(obs_spec) dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. initial_state = initial_state_fn.apply(None) initial_params = unroll_fn.init(key, dummy_obs, initial_state) initial_opt_state = optimizer.init(initial_params) return TrainingState(params=initial_params, opt_state=initial_opt_state)
def make_network_from_module(
    module: hk.Transformed,
    spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
  """Creates a network with dummy init arguments using the specified module.

  Args:
    module: Module that expects one batch axis and one features axis for its
      inputs.
    spec: EnvironmentSpec shapes to derive dummy inputs.

  Returns:
    FeedForwardNetwork whose `init` method only takes a random key, and
    `apply` takes an observation and action and produces an output.
  """
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  return networks.FeedForwardNetwork(
      lambda key: module.init(key, dummy_obs, dummy_action), module.apply)
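# Hedged usage sketch (assumes the same acme/haiku/jax imports as the code
# above; the toy module and spec shapes are made up for illustration): wrap a
# simple critic-style Haiku module so that `init` only needs a random key.
def _make_network_from_module_example():
  def _toy_critic(obs, action):
    return hk.nets.MLP([32, 1])(jnp.concatenate([obs, action], axis=-1))

  toy_module = hk.without_apply_rng(hk.transform(_toy_critic))
  toy_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(4,), dtype=np.float32),
      actions=specs.Array(shape=(2,), dtype=np.float32),
      rewards=specs.Array(shape=(), dtype=np.float32),
      discounts=specs.BoundedArray(
          shape=(), dtype=np.float32, minimum=0., maximum=1.))
  network = make_network_from_module(toy_module, toy_spec)
  params = network.init(jax.random.PRNGKey(0))
  return network.apply(
      params,
      utils.add_batch_dim(utils.zeros_like(toy_spec.observations)),
      utils.add_batch_dim(utils.zeros_like(toy_spec.actions)))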
def __init__(self,
             network: hk.Transformed,
             obs_spec: specs.Array,
             optimizer: optax.GradientTransformation,
             rng: hk.PRNGSequence,
             dataset: tf.data.Dataset,
             loss_fn: LossFn = _sparse_categorical_cross_entropy,
             counter: counting.Counter = None,
             logger: loggers.Logger = None):
  """Initializes the learner."""

  def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
    # Pull out the data needed for updates.
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    del r_t, d_t, o_t
    logits = network.apply(params, o_tm1)
    return jnp.mean(loss_fn(a_tm1, logits))

  def sgd_step(
      state: TrainingState,
      sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
    """Do a step of SGD."""
    grad_fn = jax.value_and_grad(loss)
    loss_value, gradients = grad_fn(state.params, sample)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    steps = state.steps + 1
    new_state = TrainingState(
        params=new_params, opt_state=new_opt_state, steps=steps)

    # Compute the global norm of the gradients for logging.
    global_gradient_norm = optax.global_norm(gradients)
    fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

    return new_state, fetches

  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Get an iterator over the dataset.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
  # TODO(b/155086959): Fix type stubs and remove.

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)
  self._state = TrainingState(
      params=initial_params, opt_state=initial_opt_state, steps=0)

  self._sgd_step = jax.jit(sgd_step)
def make_networks(
    spec,
    actor_fn_build_fn=build_mlp_actor_fn,
    actor_hidden_layer_sizes=(256, 256),
    critic_fn_build_fn=build_hk_batch_ensemble_mlp_critic_fn,
    # critic_fn_build_fn: Callable = build_mlp_critic_fn,
    critic_hidden_layer_sizes=(256, 256),
    use_double_q=False,
):
  """Creates networks used by the agent."""
  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = actor_fn_build_fn(num_dimensions, actor_hidden_layer_sizes)
  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

  _critic_fn = critic_fn_build_fn(critic_hidden_layer_sizes, use_double_q)
  critic = hk.without_apply_rng(hk.transform(_critic_fn, apply_rng=True))
  critic_ensemble_init = ensemble_utils.transform_init_for_ensemble(
      critic.init, init_same=False)
  critic_ensemble_member_apply = (
      ensemble_utils.transform_apply_for_ensemble_member(critic.apply))
  critic_same_batch_ensemble_apply = (
      ensemble_utils.build_same_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))
  critic_diff_batch_ensemble_apply = (
      ensemble_utils.build_different_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return BatchEnsembleMSGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_ensemble_init=lambda ensemble_size, key: critic_ensemble_init(
          ensemble_size, key, dummy_obs, dummy_action),
      q_ensemble_member_apply=critic_ensemble_member_apply,
      q_ensemble_same_batch_apply=critic_same_batch_ensemble_apply,
      q_ensemble_different_batch_apply=critic_diff_batch_ensemble_apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
def default_models_to_snapshot(
    networks: SACNetworks,
    spec: specs.EnvironmentSpec):
  """Defines default models to be snapshotted."""
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.zeros_like(spec.actions)
  dummy_key = jax.random.PRNGKey(0)

  def q_network(source: core.VariableSource) -> types.ModelToSnapshot:
    params = source.get_variables(['critic'])[0]
    return types.ModelToSnapshot(networks.q_network.apply, params, {
        'obs': dummy_obs,
        'action': dummy_action
    })

  def default_training_actor(
      source: core.VariableSource) -> types.ModelToSnapshot:
    params = source.get_variables(['policy'])[0]
    return types.ModelToSnapshot(
        apply_policy_and_sample(networks, False), params, {
            'key': dummy_key,
            'obs': dummy_obs
        })

  def default_eval_actor(
      source: core.VariableSource) -> types.ModelToSnapshot:
    params = source.get_variables(['policy'])[0]
    return types.ModelToSnapshot(
        apply_policy_and_sample(networks, True), params, {
            'key': dummy_key,
            'obs': dummy_obs
        })

  return {
      'q_network': q_network,
      'default_training_actor': default_training_actor,
      'default_eval_actor': default_eval_actor,
  }
def make_ppo_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates Haiku networks to be used by the agent."""
  num_actions = spec.actions.num_values

  forward_fn = functools.partial(ppo_forward_fn, num_actions=num_actions)
  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
def make_dqn_atari_network(
    environment_spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
  """Creates networks for training DQN on Atari."""

  def network(inputs):
    model = hk.Sequential([
        networks.AtariTorso(),
        hk.nets.MLP([512, environment_spec.actions.num_values]),
    ])
    return model(inputs)

  network_hk = hk.without_apply_rng(hk.transform(network))
  obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations))
  return networks.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply)
def make_initial_state(key: jnp.ndarray) -> TrainingState: """Initialises the training state (parameters and optimiser state).""" dummy_obs = utils.zeros_like(obs_spec) dummy_obs = utils.add_batch_dim(dummy_obs) # Dummy 'sequence' dim. key, key_initial_state = jax.random.split(key) params = initial_state_init_fn(key_initial_state) # TODO(jferret): as it stands, we do not yet support # training the initial state params. initial_state = initial_state_fn(params) initial_params = unroll_init_fn(key, dummy_obs, initial_state) initial_opt_state = optimizer.init(initial_params) return TrainingState(params=initial_params, opt_state=initial_opt_state)
def test_recurrent(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  @_transform_without_rng
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])(inputs, state)

  @_transform_without_rng
  def initial_state(batch_size: Optional[int] = None):
    network = hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])
    return network.initial_state(batch_size)

  initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    actions = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return (actions, (action_values,)), core_state
    else:
      return actions, core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.RecurrentActor(
      policy,
      jax.random.PRNGKey(1),
      initial_state,
      variable_client,
      has_extras=has_extras)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def make_networks(
    spec: specs.EnvironmentSpec,
    direct_rl_networks: DirectRLNetworks,
    layer_sizes: Tuple[int, ...] = (256, 256),
    intrinsic_reward_coefficient: float = 1.0,
    extrinsic_reward_coefficient: float = 0.0,
) -> RNDNetworks[DirectRLNetworks]:
  """Creates networks used by the agent and returns RNDNetworks.

  Args:
    spec: Environment spec.
    direct_rl_networks: Networks used by a direct rl algorithm.
    layer_sizes: Layer sizes.
    intrinsic_reward_coefficient: Multiplier on intrinsic reward.
    extrinsic_reward_coefficient: Multiplier on extrinsic reward.

  Returns:
    The RND networks.
  """

  def _rnd_fn(obs, act):
    # RND does not use the action but other variants like RED do.
    del act
    network = networks_lib.LayerNormMLP(list(layer_sizes))
    return network(obs)

  target = hk.without_apply_rng(hk.transform(_rnd_fn))
  predictor = hk.without_apply_rng(hk.transform(_rnd_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return RNDNetworks(
      target=networks_lib.FeedForwardNetwork(
          lambda key: target.init(key, dummy_obs, ()), target.apply),
      predictor=networks_lib.FeedForwardNetwork(
          lambda key: predictor.init(key, dummy_obs, ()), predictor.apply),
      direct_rl_networks=direct_rl_networks,
      get_reward=functools.partial(
          rnd_reward_fn,
          intrinsic_reward_coefficient=intrinsic_reward_coefficient,
          extrinsic_reward_coefficient=extrinsic_reward_coefficient))
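# Conceptual sketch only (not the actual `rnd_reward_fn` referenced above):
# in RND, the intrinsic reward is the prediction error between the trainable
# predictor and the frozen random target, combined linearly with the
# environment reward via the two coefficients above. The helper name and
# argument names below are made up for illustration.
def _illustrative_rnd_reward(predictor_output, target_output, extrinsic_reward,
                             intrinsic_reward_coefficient=1.0,
                             extrinsic_reward_coefficient=0.0):
  intrinsic_reward = jnp.mean(
      jnp.square(predictor_output - target_output), axis=-1)
  return (intrinsic_reward_coefficient * intrinsic_reward +
          extrinsic_reward_coefficient * extrinsic_reward)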
def make_flax_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates FLAX networks to be used by the agent."""

  num_actions = spec.actions.num_values

  class MLP(flax.deprecated.nn.Module):
    """MLP module."""

    def apply(self,
              data: jnp.ndarray,
              layer_sizes: Tuple[int],
              activation: Callable[[jnp.ndarray],
                                   jnp.ndarray] = flax.deprecated.nn.relu,
              kernel_init: object = jax.nn.initializers.lecun_uniform(),
              activate_final: bool = False,
              bias: bool = True):
      hidden = data
      for i, hidden_size in enumerate(layer_sizes):
        hidden = flax.deprecated.nn.Dense(
            hidden,
            hidden_size,
            name=f'hidden_{i}',
            kernel_init=kernel_init,
            bias=bias)
        if i != len(layer_sizes) - 1 or activate_final:
          hidden = activation(hidden)
      return hidden

  class PolicyValueModule(flax.deprecated.nn.Module):
    """MLP module."""

    def apply(self, inputs: jnp.ndarray):
      inputs = utils.batch_concat(inputs)
      logits = MLP(inputs, [64, 64, num_actions])
      value = MLP(inputs, [64, 64, 1])
      value = jnp.squeeze(value, axis=-1)
      return tfd.Categorical(logits=logits), value

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: PolicyValueModule.init(rng, dummy_obs)[1],
      PolicyValueModule.call)
def test_feedforward(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    action_values = hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
    ])(inputs)
    action = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return action, (action_values,)
    else:
      return action

  policy = hk.transform(policy)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  if has_extras:
    actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
        policy.apply)
  else:
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy.apply)
  actor = actors.GenericActor(
      actor_core,
      random_key=jax.random.PRNGKey(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def make_network(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""
  num_actions = spec.actions.num_values

  def actor_fn(obs, is_training=True, key=None):
    # is_training and key allow using train/test dependent modules
    # like dropout.
    del is_training
    del key
    mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
    return mlp(obs)

  policy = hk.without_apply_rng(hk.transform(actor_fn))

  # Create dummy observations to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network
def test_recurrent(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  @hk.transform
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

  @hk.transform
  def initial_state(batch_size: int):
    network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
    return network.initial_state(batch_size)

  initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    return jnp.argmax(action_values, axis=-1), core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.RecurrentActor(
      policy, hk.PRNGSequence(1), initial_state, variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)