def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional R2D2 network from recurrent model definitions."""

  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)

  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)

  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)
  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)

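# Example (a minimal sketch, not from the original source): the kind of
# recurrent model definitions `make_networks` above expects. The helper name
# `make_r2d2_model_fns`, the hk.LSTM core, and the layer sizes are
# illustrative assumptions; the three functions are returned in the order the
# factory takes them (forward_fn, initial_state_fn, unroll_fn).
def make_r2d2_model_fns(num_actions: int):

  def make_core():
    # Torso MLP -> LSTM -> Q-value head.
    return hk.DeepRNN([
        hk.nets.MLP([64]),
        hk.LSTM(64),
        hk.nets.MLP([num_actions]),
    ])

  def forward_fn(observations, state):
    # One-step forward pass: Q-values and the next recurrent state.
    return make_core()(observations, state)

  def initial_state_fn(batch_size=None):
    return make_core().initial_state(batch_size)

  def unroll_fn(observations, state):
    # Applies the core over a [T, B, ...] observation sequence.
    return hk.static_unroll(make_core(), observations, state)

  return forward_fn, initial_state_fn, unroll_fn
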
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)) -> SACNetworks:
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs):
    network = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_dimensions),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network1 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    network2 = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu),
    ])
    input_ = jnp.concatenate([obs, action], axis=-1)
    value1 = network1(input_)
    value2 = network2(input_)
    return jnp.concatenate([value1, value2], axis=-1)

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return SACNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())

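# Example usage (a sketch, not from the original source): builds the SAC
# networks from a fake continuous-action environment and initializes both
# parameter sets. Assumes `fakes` is acme.testing.fakes; the dimensions are
# arbitrary.
def example_sac_networks():
  environment = fakes.ContinuousEnvironment(
      action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)
  sac_networks = make_networks(spec)
  key = jax.random.PRNGKey(0)
  policy_params = sac_networks.policy_network.init(key)
  q_params = sac_networks.q_network.init(key)
  return policy_params, q_params
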
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())

def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
  """Creates networks used by the agent."""

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  if isinstance(spec.actions, specs.DiscreteArray):
    num_dimensions = spec.actions.num_values
    # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
  else:
    num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = build_actor_fn(num_dimensions)

  if img_encoder_fn is not None:
    img_encoder = hk.without_apply_rng(hk.transform(img_encoder_fn))
    key = jax.random.PRNGKey(seed=42)
    temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
    dummy_hidden = img_encoder.apply(temp_encoder_params,
                                     dummy_obs['state_image'])
    # The encoder is initialized on the raw image input, matching
    # temp_encoder_params above.
    img_encoder_network = networks_lib.FeedForwardNetwork(
        lambda key: img_encoder.init(key, dummy_obs['state_image']),
        img_encoder.apply)
    dummy_policy_input = dict(
        state_image=dummy_hidden,
        state_dense=dummy_obs['state_dense'],
    )
  else:
    dummy_policy_input = dummy_obs
    img_encoder_network = None

  policy = hk.without_apply_rng(hk.transform(_actor_fn))

  return BCNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_policy_input), policy.apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode(),
      img_encoder=img_encoder_network,
  )

def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
  """Creates networks used by the agent."""

  action_spec = spec.actions

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

  def _actor_fn(obs):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks_lib.NearZeroInitializedLinear(num_dimensions),
        networks_lib.TanhToSpec(action_spec),
    ])
    return network(obs)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        utils.batch_concat,
        networks_lib.LayerNormMLP(
            layer_sizes=[*critic_layer_sizes, num_atoms]),
    ])
    value = network([obs, action])
    return value, critic_atoms

  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return D4PGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda rng: policy.init(rng, dummy_obs), policy.apply),
      critic_network=networks_lib.FeedForwardNetwork(
          lambda rng: critic.init(rng, dummy_obs, dummy_action), critic.apply))

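# Example (a sketch, not from the original source): the distributional critic
# above returns (logits, atoms); a scalar Q-value is recovered as the
# probability-weighted sum over atoms, as in C51/D4PG.
def expected_q_value(d4pg_networks: D4PGNetworks, critic_params, obs, action):
  logits, atoms = d4pg_networks.critic_network.apply(
      critic_params, obs, action)
  probabilities = jax.nn.softmax(logits, axis=-1)
  return jnp.sum(probabilities * atoms, axis=-1)  # Shape [batch].
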
def make_networks(
    spec: specs.EnvironmentSpec,
    discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""

  if discrete_actions:
    final_layer_size = spec.actions.num_values
  else:
    final_layer_size = np.prod(spec.actions.shape, dtype=int)

  def _actor_fn(obs, is_training=False, key=None):
    # is_training and key allow defining train/test-dependent modules
    # like dropout.
    del is_training
    del key
    if discrete_actions:
      network = hk.nets.MLP([64, 64, final_layer_size])
    else:
      network = hk.Sequential([
          networks_lib.LayerNormMLP([64, 64], activate_final=True),
          networks_lib.NormalTanhDistribution(final_layer_size),
      ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_actor_fn))

  # Create dummy observations to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network

def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Make the network purely functional.
  network_hk = hk.without_apply_rng(hk.transform(network))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  network = networks_lib.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, dummy_obs),
      apply=network_hk.apply)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)

def make_haiku_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates Haiku networks to be used by the agent."""

  num_actions = spec.actions.num_values

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        networks_lib.CategoricalHead(num_actions)
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP([64, 64]),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1)
    ])
    action_distribution = policy_network(inputs)
    value = value_network(inputs)
    return (action_distribution, value)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)

def test_mean_random(self):
  x = jnp.ones(10)
  bx = jnp.ones((9, 10))
  ffn = RandomFFN()
  wrapped_ffn = networks.FeedForwardNetwork(
      init=functools.partial(ffn.init, x=x), apply=ffn.apply)
  mean_ensemble = ensemble.make_ensemble(
      wrapped_ffn, ensemble.apply_mean, num_networks=3)
  key = jax.random.PRNGKey(0)
  params = mean_ensemble.init(key)
  single_output = mean_ensemble.apply(params, x)
  self.assertEqual(single_output.shape, (15,))
  batch_output = mean_ensemble.apply(params, bx)
  # Make sure all rows are equal:
  np.testing.assert_allclose(
      jnp.broadcast_to(batch_output[0], batch_output.shape),
      batch_output,
      atol=1E-5,
      rtol=1E-5)
  # Check results explicitly:
  all_members = jnp.concatenate([
      jnp.expand_dims(
          ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0)
      for i in range(3)
  ])
  batch_means = jnp.mean(all_members, axis=0)
  np.testing.assert_allclose(batch_output, batch_means, atol=1E-5, rtol=1E-5)

def test_round_robin_random(self):
  x = jnp.ones(10)  # Base input.
  bx = jnp.ones((9, 10))  # Batched input.
  ffn = RandomFFN()
  wrapped_ffn = networks.FeedForwardNetwork(
      init=functools.partial(ffn.init, x=x), apply=ffn.apply)
  rr_ensemble = ensemble.make_ensemble(
      wrapped_ffn, ensemble.apply_round_robin, num_networks=3)
  key = jax.random.PRNGKey(0)
  params = rr_ensemble.init(key)
  out = rr_ensemble.apply(params, bx)
  # The output should be the same every 3 rows:
  blocks = jnp.split(out, 3, axis=0)
  np.testing.assert_array_equal(blocks[0], blocks[1])
  np.testing.assert_array_equal(blocks[0], blocks[2])
  self.assertTrue((out[0] != out[1]).any())
  for i in range(9):
    np.testing.assert_allclose(
        out[i],
        ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]),
        atol=1E-5,
        rtol=1E-5)

def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512,),
    use_conv: bool = True,
) -> PPONetworks:
  """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define the number of actions.
    hidden_layer_sizes: Sizes of the hidden layers in the MLP head.
    use_conv: Whether to use a conv or MLP feature extractor.

  Returns:
    PPONetworks.
  """

  num_actions = environment_spec.actions.num_values

  def forward_fn(inputs):
    layers = []
    if use_conv:
      layers.extend([networks_lib.AtariTorso()])
    layers.extend([
        hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
        networks_lib.CategoricalValueHead(num_values=num_actions)
    ])
    policy_value_network = hk.Sequential(layers)
    return policy_value_network(inputs)

  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  network = networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
  # Create PPONetworks to add functionality required by the agent.
  return make_ppo_networks(network)

def get_fake_world_model() -> networks_lib.FeedForwardNetwork:

  def apply(params: Any, observation_t: jnp.ndarray, action_t: jnp.ndarray):
    del params
    return observation_t, jnp.ones((action_t.shape[0], 1))

  # The fake model has no parameters, so init ignores the key and returns
  # None.
  return networks_lib.FeedForwardNetwork(init=lambda rng: None, apply=apply)

def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork:
  """Like params_adding_ffn, but with pytree inputs, preserves structure."""

  def init_fn(key, sx=sx):
    return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)

  def apply_fn(params, x):
    return jax.tree_map(lambda p, v: p + v, params, x)

  return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn)

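# Example usage (a sketch, not from the original source): with a
# dict-structured input, the returned network draws one random parameter leaf
# per input leaf and `apply` adds them elementwise, preserving the structure.
def example_struct_ffn():
  sx = {'pos': jnp.zeros(3), 'vel': jnp.zeros(2)}
  ffn = struct_params_adding_ffn(sx)
  params = ffn.init(jax.random.PRNGKey(0))
  return ffn.apply(params, sx)  # A dict with the same 'pos'/'vel' structure.
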
def make_networks(
    spec: specs.EnvironmentSpec,
    direct_rl_networks: DirectRLNetworks,
    layer_sizes: Tuple[int, ...] = (256, 256),
    intrinsic_reward_coefficient: float = 1.0,
    extrinsic_reward_coefficient: float = 0.0,
) -> RNDNetworks[DirectRLNetworks]:
  """Creates networks used by the agent and returns RNDNetworks.

  Args:
    spec: Environment spec.
    direct_rl_networks: Networks used by a direct RL algorithm.
    layer_sizes: Layer sizes.
    intrinsic_reward_coefficient: Multiplier on intrinsic reward.
    extrinsic_reward_coefficient: Multiplier on extrinsic reward.

  Returns:
    The RND networks.
  """

  def _rnd_fn(obs, act):
    # RND does not use the action, but other variants like RED do.
    del act
    network = networks_lib.LayerNormMLP(list(layer_sizes))
    return network(obs)

  target = hk.without_apply_rng(hk.transform(_rnd_fn))
  predictor = hk.without_apply_rng(hk.transform(_rnd_fn))

  # Create dummy observations to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return RNDNetworks(
      target=networks_lib.FeedForwardNetwork(
          lambda key: target.init(key, dummy_obs, ()), target.apply),
      predictor=networks_lib.FeedForwardNetwork(
          lambda key: predictor.init(key, dummy_obs, ()), predictor.apply),
      direct_rl_networks=direct_rl_networks,
      get_reward=functools.partial(
          rnd_reward_fn,
          intrinsic_reward_coefficient=intrinsic_reward_coefficient,
          extrinsic_reward_coefficient=extrinsic_reward_coefficient))

def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec.
    discriminator_transformed: Haiku transformed of the discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output) / (exp(output) + pi(a|s))
      # logit(D) = log(D / (1 - D)) = log(exp(output) / pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)

def make_dqn_atari_network(
    environment_spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
  """Creates networks for training DQN on Atari."""

  def network(inputs):
    model = hk.Sequential([
        networks.AtariTorso(),
        hk.nets.MLP([512, environment_spec.actions.num_values]),
    ])
    return model(inputs)

  network_hk = hk.without_apply_rng(hk.transform(network))
  obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations))
  return networks.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply)

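# Example usage (a sketch, not from the original source): initializing the
# Atari DQN network from a hand-built spec. The 84x84x4 observation shape is
# an assumption matching standard Atari preprocessing; the action count is
# arbitrary.
def example_dqn_atari_params():
  environment_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(84, 84, 4), dtype=np.float32),
      actions=specs.DiscreteArray(num_values=6),
      rewards=specs.Array(shape=(), dtype=np.float32),
      discounts=specs.BoundedArray(
          shape=(), dtype=np.float32, minimum=0., maximum=1.))
  network = make_dqn_atari_network(environment_spec)
  return network.init(jax.random.PRNGKey(0))
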
def make_ppo_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates Haiku networks to be used by the agent."""

  num_actions = spec.actions.num_values
  forward_fn = functools.partial(ppo_forward_fn, num_actions=num_actions)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)

def add_bc_pretraining(sac_networks: sac.SACNetworks) -> sac.SACNetworks:
  """Augments `sac_networks` to run BC pretraining in policy_network.init."""
  make_demonstrations = functools.partial(
      helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
  bc_network = bc.pretraining.convert_to_bc_network(sac_networks.policy_network)
  loss = bc.logp(sac_networks.log_prob)

  def bc_init(*unused_args):
    return bc.pretraining.train_with_bc(make_demonstrations, bc_network, loss)

  return dataclasses.replace(
      sac_networks,
      policy_network=networks_lib.FeedForwardNetwork(
          bc_init, sac_networks.policy_network.apply))

def convert_policy_value_to_bc_network(
    policy_value_network: networks_lib.FeedForwardNetwork
) -> networks_lib.FeedForwardNetwork:
  """Converts a network from e.g. PPO into a BC policy network.

  Args:
    policy_value_network: FeedForwardNetwork taking the observation as input.

  Returns:
    The BC policy network taking observation, is_training, key as input.
  """

  def apply(params, obs, is_training=False, key=None):
    del is_training, key
    actions, _ = policy_value_network.apply(params, obs)
    return actions

  return networks_lib.FeedForwardNetwork(policy_value_network.init, apply)

def convert_to_bc_network(
    policy_network: networks_lib.FeedForwardNetwork
) -> networks_lib.FeedForwardNetwork:
  """Converts a policy_network from SAC/TD3/D4PG/.. into a BC policy network.

  Args:
    policy_network: FeedForwardNetwork taking the observation as input and
      returning an action representation compatible with one of the BC losses.

  Returns:
    The BC policy network taking observation, is_training, key as input.
  """

  def apply(params, obs, is_training=False, key=None):
    del is_training, key
    return policy_network.apply(params, obs)

  return networks_lib.FeedForwardNetwork(policy_network.init, apply)

def add_bc_pretraining(td3_networks: td3.TD3Networks) -> td3.TD3Networks:
  """Augments `td3_networks` to run BC pretraining in policy_network.init."""
  make_demonstrations = functools.partial(
      helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
  bc_network = bc.pretraining.convert_to_bc_network(td3_networks.policy_network)
  # TODO(lukstafi): consider passing noised policy.
  loss = bc.mse(lambda x, key: x)

  def bc_init(*unused_args):
    return bc.pretraining.train_with_bc(make_demonstrations, bc_network, loss)

  return dataclasses.replace(
      td3_networks,
      policy_network=networks_lib.FeedForwardNetwork(
          bc_init, td3_networks.policy_network.apply))

def add_bc_pretraining(ppo_networks: ppo.PPONetworks) -> ppo.PPONetworks:
  """Augments `ppo_networks` to run BC pretraining in policy_network.init."""
  make_demonstrations = functools.partial(
      helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
  bc_network = bc.pretraining.convert_policy_value_to_bc_network(
      ppo_networks.network)
  loss = bc.logp(ppo_networks.log_prob)

  # Note: despite only training the policy network, this will also include the
  # initial value network params.
  def bc_init(*unused_args):
    return bc.pretraining.train_with_bc(make_demonstrations, bc_network, loss)

  return dataclasses.replace(
      ppo_networks,
      network=networks_lib.FeedForwardNetwork(
          bc_init, ppo_networks.network.apply))

def make_network_from_module(
    module: hk.Transformed,
    spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
  """Creates a network with dummy init arguments using the specified module.

  Args:
    module: Module that expects one batch axis and one features axis for its
      inputs.
    spec: EnvironmentSpec shapes to derive dummy inputs.

  Returns:
    FeedForwardNetwork whose `init` method only takes a random key, and whose
    `apply` takes an observation and action and produces an output.
  """
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  return networks.FeedForwardNetwork(
      lambda key: module.init(key, dummy_obs, dummy_action), module.apply)

def make_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent.

  The model used by the ARS paper is a simple clipped linear model.

  Args:
    spec: An environment spec.

  Returns:
    A FeedForwardNetwork network.
  """
  obs_size = spec.observations.shape[0]
  act_size = spec.actions.shape[0]
  return networks_lib.FeedForwardNetwork(
      init=lambda _: jnp.zeros((obs_size, act_size)),
      apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1))

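# Example usage (a sketch, not from the original source): the ARS "network"
# is a zero-initialized linear map, so `init` ignores its key argument and
# `apply` is a clipped matrix product.
def example_ars_policy(spec: specs.EnvironmentSpec):
  network = make_networks(spec)
  matrix = network.init(jax.random.PRNGKey(0))  # The key is unused.
  obs = jnp.zeros(spec.observations.shape)
  return network.apply(matrix, obs)  # All zeros at initialization.
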
def make_flax_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates FLAX networks to be used by the agent."""

  num_actions = spec.actions.num_values

  class MLP(flax.deprecated.nn.Module):
    """MLP module."""

    def apply(self,
              data: jnp.ndarray,
              layer_sizes: Tuple[int, ...],
              activation: Callable[[jnp.ndarray],
                                   jnp.ndarray] = flax.deprecated.nn.relu,
              kernel_init: object = jax.nn.initializers.lecun_uniform(),
              activate_final: bool = False,
              bias: bool = True):
      hidden = data
      for i, hidden_size in enumerate(layer_sizes):
        hidden = flax.deprecated.nn.Dense(
            hidden,
            hidden_size,
            name=f'hidden_{i}',
            kernel_init=kernel_init,
            bias=bias)
        if i != len(layer_sizes) - 1 or activate_final:
          hidden = activation(hidden)
      return hidden

  class PolicyValueModule(flax.deprecated.nn.Module):
    """Policy-value module."""

    def apply(self, inputs: jnp.ndarray):
      inputs = utils.batch_concat(inputs)
      logits = MLP(inputs, [64, 64, num_actions])
      value = MLP(inputs, [64, 64, 1])
      value = jnp.squeeze(value, axis=-1)
      return tfd.Categorical(logits=logits), value

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  return networks_lib.FeedForwardNetwork(
      lambda rng: PolicyValueModule.init(rng, dummy_obs)[1],
      PolicyValueModule.call)

def make_networks(
    spec,
    actor_fn_build_fn=build_mlp_actor_fn,
    actor_hidden_layer_sizes=(256, 256),
    critic_fn_build_fn=build_hk_batch_ensemble_mlp_critic_fn,
    # critic_fn_build_fn: Callable = build_mlp_critic_fn,
    critic_hidden_layer_sizes=(256, 256),
    use_double_q=False,
):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = actor_fn_build_fn(num_dimensions, actor_hidden_layer_sizes)
  policy = hk.without_apply_rng(hk.transform(_actor_fn))

  _critic_fn = critic_fn_build_fn(critic_hidden_layer_sizes, use_double_q)
  critic = hk.without_apply_rng(hk.transform(_critic_fn))

  critic_ensemble_init = ensemble_utils.transform_init_for_ensemble(
      critic.init, init_same=False)
  critic_ensemble_member_apply = (
      ensemble_utils.transform_apply_for_ensemble_member(critic.apply))
  critic_same_batch_ensemble_apply = (
      ensemble_utils.build_same_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))
  critic_diff_batch_ensemble_apply = (
      ensemble_utils.build_different_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return BatchEnsembleMSGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_ensemble_init=lambda ensemble_size, key: critic_ensemble_init(
          ensemble_size, key, dummy_obs, dummy_action),
      q_ensemble_member_apply=critic_ensemble_member_apply,
      q_ensemble_same_batch_apply=critic_same_batch_ensemble_apply,
      q_ensemble_different_batch_apply=critic_diff_batch_ensemble_apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())

def make_network(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""
  num_actions = spec.actions.num_values

  def actor_fn(obs, is_training=True, key=None):
    # is_training and key allow using train/test-dependent modules
    # like dropout.
    del is_training
    del key
    mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
    return mlp(obs)

  policy = hk.without_apply_rng(hk.transform(actor_fn))

  # Create dummy observations to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network

def make_action_candidates_network(
    spec,
    num_actions,
    discrete_rl_networks,
    torso_layer_sizes=(256,),
    head_layer_sizes=(256,),
    input_dropout_rate=0.1,
    hidden_dropout_rate=0.1):
  """Creates networks used by the agent and wraps them into a Flax model.

  Args:
    spec: Environment spec.
    num_actions: The number of actions proposed by the multi-modal model.
    discrete_rl_networks: Direct RL algorithm networks.
    torso_layer_sizes: Layer sizes of the torso.
    head_layer_sizes: Layer sizes of the heads.
    input_dropout_rate: Dropout rate applied to the input.
    hidden_dropout_rate: Dropout rate applied to the hidden layers.

  Returns:
    The Flax model.
  """
  dummy_obs, _ = get_dummy_batched_obs_and_actions(spec)
  encoder_module = Encoder(
      action_dim=np.prod(spec.actions.shape, dtype=int),
      num_actions=num_actions,
      torso_layer_sizes=torso_layer_sizes,
      head_layer_sizes=head_layer_sizes,
      input_dropout_rate=input_dropout_rate,
      hidden_dropout_rate=hidden_dropout_rate)
  encoder = networks_lib.FeedForwardNetwork(
      lambda key: encoder_module.init(key, dummy_obs, is_training=False),
      encoder_module.apply)
  return AquademNetworks(
      encoder=encoder, discrete_rl_networks=discrete_rl_networks)

def make_continuous_networks(
    environment_spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (64, 64),
    value_layer_sizes: Sequence[int] = (64, 64),
) -> PPONetworks:
  """Creates PPONetworks to be used for continuous action environments."""

  # Get total number of action dimensions from action spec.
  num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP(policy_layer_sizes, activation=jnp.tanh),
        # Note: we don't respect bounded action specs here and instead
        # rely on CanonicalSpecWrapper to clip actions accordingly.
        networks_lib.MultivariateNormalDiagHead(num_dimensions)
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP(value_layer_sizes, activation=jnp.tanh),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1)
    ])
    action_distribution = policy_network(inputs)
    value = value_network(inputs)
    return (action_distribution, value)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  network = networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
  # Create PPONetworks to add functionality required by the agent.
  return make_ppo_networks(network)

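# Example (a sketch, not from the original source): using the continuous PPO
# networks built above. Assumes PPONetworks exposes the wrapped network as
# `.network` (as `add_bc_pretraining` above does for PPO) and that the policy
# head is a distribution supporting `sample`/`log_prob`.
def example_ppo_forward(ppo_networks: PPONetworks, params, obs, key):
  distribution, value = ppo_networks.network.apply(params, obs)
  action = distribution.sample(seed=key)
  log_prob = distribution.log_prob(action)
  return action, log_prob, value
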
def make_q_network(
    spec,
    hidden_layer_sizes=(512, 512, 256),
    architecture='LayerNorm'):
  """DQN network for the Aquadem algorithm."""

  def _q_fn(obs):
    if architecture == 'MLP':  # AQuaOff architecture.
      network_fn = hk.nets.MLP
    elif architecture == 'LayerNorm':  # Original AQuaDem architecture.
      network_fn = networks_lib.LayerNormMLP
    else:
      raise ValueError('Architecture not recognized')
    network = network_fn(list(hidden_layer_sizes) + [spec.actions.num_values])
    value = network(obs)
    return value

  critic = hk.without_apply_rng(hk.transform(_q_fn))
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs), critic.apply)
  return critic_network

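# Example usage (a sketch, not from the original source): computing greedy
# actions over the discrete action candidates with the Q-network above.
def example_greedy_actions(spec, observations):
  q_network = make_q_network(spec)
  q_params = q_network.init(jax.random.PRNGKey(0))
  q_values = q_network.apply(q_params, observations)  # [batch, num_actions].
  return jnp.argmax(q_values, axis=-1)
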