def make_networks(action_spec: specs.BoundedArray): """Creates simple networks for testing..""" num_dimensions = np.prod(action_spec.shape, dtype=int) # Create the observation network shared between the policy and critic. observation_network = tf2_utils.batch_concat # Create the policy network (head) and the evaluation network. policy_network = snt.Sequential([ networks.LayerNormMLP([50], activate_final=True), networks.NearZeroInitializedLinear(num_dimensions), networks.TanhToSpec(action_spec) ]) evaluator_network = snt.Sequential([observation_network, policy_network]) # Create the critic network. critic_network = snt.Sequential([ # The multiplexer concatenates the observations/actions. networks.CriticMultiplexer(), networks.LayerNormMLP([50], activate_final=True), networks.NearZeroInitializedLinear(1), ]) return { 'policy': policy_network, 'critic': critic_network, 'observation': observation_network, 'evaluator': evaluator_network, }
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
    """Creates networks used by the agent."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec),
    ])
    critic_network = snt.Sequential([
        # The multiplexer concatenates the observations/actions.
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }

def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, Union[snt.Module, Callable[[tf.Tensor], tf.Tensor]]]:
    """Creates networks used by the agent."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec),
    ])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }

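# A minimal sketch of the distributional critic idea used above, assuming
# DiscreteValuedHead follows the usual C51-style construction: a categorical
# distribution over `num_atoms` fixed support points spanning [vmin, vmax].
# The `logits` tensor below is a stand-in for the head's output and is not
# part of the snippet's API.
import tensorflow as tf

vmin, vmax, num_atoms = -150., 150., 51
support = tf.linspace(vmin, vmax, num_atoms)          # fixed atom locations
logits = tf.random.normal([1, num_atoms])             # placeholder head output
probs = tf.nn.softmax(logits, axis=-1)
expected_q = tf.reduce_sum(probs * support, axis=-1)  # scalar Q estimate
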
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent."""
    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        tf2_utils.batch_concat,
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=True,
            min_scale=0.3,
            init_scale=0.7,
            fixed_scale=False,
            use_tfd_independent=False),
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer()
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
    }

def make_networks(
    environment_spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (256, 256),
    critic_layer_sizes: Sequence[int] = (256, 256),
) -> Mapping[str, types.TensorTransformation]:
    """Creates the networks used by the agent."""
    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

    # Create the shared observation network; here simply a state-less operation.
    observation_network = tf.identity

    # Create the policy network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        sac.model.SquashedGaussianValueHead(num_dimensions),
    ])

    # Create the critic network.
    critic_network = snt.Sequential([
        # The multiplexer concatenates the observations/actions.
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }

def __init__(self, hidden_layer_sizes: Sequence[int], actions_dim: int):
    """Policy network.

    Args:
      hidden_layer_sizes: a sequence of ints specifying the size of each
        hidden layer.
      actions_dim: number of action dimensions.
    """
    super().__init__(name='layer_input_norm_mlp')
    layers = []

    # Hidden layers.
    for layer_size in hidden_layer_sizes:
        layers.append(
            snt.Linear(
                layer_size,
                w_init=tf.initializers.VarianceScaling(
                    distribution='uniform', mode='fan_out', scale=0.333)))
        # layers.append(snt.LayerNorm(
        #     axis=slice(1, None), create_scale=True, create_offset=True))
        layers.append(tf.nn.relu)

    # Last layer: near-zero initialized linear followed by a softmax.
    layers.append(networks.NearZeroInitializedLinear(actions_dim))
    layers.append(tf.nn.softmax)

    self._network = snt.Sequential(layers)

def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
    """Creates networks used by the agent."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions, init_scale=0.7, use_tfd_independent=True),
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }

def make_default_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agent."""
    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        tf2_utils.batch_concat,
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=True,
            min_scale=0.3,
            init_scale=0.7,
            fixed_scale=False,
            use_tfd_independent=False),
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        "policy": policy_network,
        "critic": critic_network,
    }

def make_network_with_prior(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (200, 100),
    critic_layer_sizes: Sequence[int] = (400, 300),
    prior_layer_sizes: Sequence[int] = (200, 100),
    policy_keys: Optional[Sequence[str]] = None,
    prior_keys: Optional[Sequence[str]] = None,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agent."""
    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    flatten_concat_policy = functools.partial(
        svg0_utils.batch_concat_selection, concat_keys=policy_keys)
    flatten_concat_prior = functools.partial(
        svg0_utils.batch_concat_selection, concat_keys=prior_keys)

    policy_network = snt.Sequential([
        flatten_concat_policy,
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=True,
            min_scale=0.1,
            init_scale=0.7,
            fixed_scale=False,
            use_tfd_independent=False),
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        observation_network=flatten_concat_policy,
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    prior_network = snt.Sequential([
        flatten_concat_prior,
        networks.LayerNormMLP(prior_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=True,
            min_scale=0.1,
            init_scale=0.7,
            fixed_scale=False,
            use_tfd_independent=False),
    ])

    return {
        "policy": policy_network,
        "critic": critic_network,
        "prior": prior_network,
    }

def custom_recurrent_network(
    environment_spec: mava_specs.MAEnvironmentSpec,
    q_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (128, 128),
    shared_weights: bool = True,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agents."""
    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(q_networks_layer_sizes, Sequence):
        q_networks_layer_sizes = {
            key: q_networks_layer_sizes for key in specs.keys()
        }

    def action_selector_fn(
        q_values: types.NestedTensor,
        legal_actions: types.NestedTensor,
        epsilon: Optional[tf.Variable] = None,
    ) -> types.NestedTensor:
        return epsilon_greedy_action_selector(
            action_values=q_values,
            legal_actions_mask=legal_actions,
            epsilon=epsilon)

    q_networks = {}
    action_selectors = {}
    for key in specs.keys():
        # Get total number of action dimensions from action spec.
        num_dimensions = specs[key].actions.num_values

        # Create the recurrent Q-network.
        q_network = snt.DeepRNN([
            snt.Linear(q_networks_layer_sizes[key][0]),
            tf.nn.relu,
            snt.GRU(q_networks_layer_sizes[key][1]),
            networks.NearZeroInitializedLinear(num_dimensions),
        ])

        # Epsilon-greedy action selector.
        action_selector = action_selector_fn

        q_networks[key] = q_network
        action_selectors[key] = action_selector

    return {
        "q_networks": q_networks,
        "action_selectors": action_selectors,
    }

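# A minimal, self-contained sketch of the shared-weights keying used by the
# multi-agent factories here: agent ids such as "agent_0" and "agent_1"
# collapse to their type prefix, so all agents of one type share a network.
# The spec values below are placeholder strings, not real agent specs.
specs_by_agent = {"agent_0": "spec_a", "agent_1": "spec_a", "adversary_0": "spec_b"}
type_specs = {key.split("_")[0]: spec for key, spec in specs_by_agent.items()}
assert type_specs == {"agent": "spec_a", "adversary": "spec_b"}
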
def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    q_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (256, 256),
    shared_weights: bool = True,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agents."""
    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(q_networks_layer_sizes, Sequence):
        q_networks_layer_sizes = {
            key: q_networks_layer_sizes for key in specs.keys()
        }

    def action_selector_fn(
        q_values: types.NestedTensor,
        legal_actions: types.NestedTensor,
        epsilon: float,
    ) -> types.NestedTensor:
        return epsilon_greedy_action_selector(
            action_values=q_values,
            legal_actions_mask=legal_actions,
            epsilon=epsilon)

    q_networks = {}
    action_selectors = {}
    for key in specs.keys():
        # Get total number of action dimensions from action spec.
        num_dimensions = specs[key].actions.num_values

        # Create the Q-network.
        q_network = snt.Sequential([
            networks.LayerNormMLP(q_networks_layer_sizes[key],
                                  activate_final=True),
            networks.NearZeroInitializedLinear(num_dimensions),
        ])

        # Epsilon-greedy action selector.
        action_selector = action_selector_fn

        q_networks[key] = q_network
        action_selectors[key] = action_selector

    return {
        "q_networks": q_networks,
        "action_selectors": action_selectors,
    }

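# A minimal sketch of what an epsilon-greedy selector with legal-action
# masking typically does; the real `epsilon_greedy_action_selector` imported
# by the factories above may differ in detail (batching, dtypes, seeding).
# Assumes an unbatched 1-D `q_values` tensor and a 0/1 `legal_mask`.
import tensorflow as tf

def epsilon_greedy_sketch(q_values: tf.Tensor,
                          legal_mask: tf.Tensor,
                          epsilon: float) -> tf.Tensor:
    # Push illegal actions to the smallest representable value before argmax.
    masked_q = tf.where(legal_mask > 0, q_values,
                        tf.fill(tf.shape(q_values), q_values.dtype.min))
    greedy = tf.argmax(masked_q, output_type=tf.int32)
    # Sample uniformly among legal actions: log(0) = -inf removes illegal ones.
    legal_logits = tf.math.log(tf.cast(legal_mask, tf.float32))
    random = tf.random.categorical(
        legal_logits[tf.newaxis, :], 1, dtype=tf.int32)[0, 0]
    explore = tf.random.uniform([]) < epsilon
    return tf.where(explore, random, greedy)
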
def make_d4pg_agent(env_spec: specs.EnvironmentSpec,
                    logger: Logger,
                    checkpoint_path: str,
                    hyperparams: Dict):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()

    policy_network = snt.Sequential([
        networks.LayerNormMLP(
            layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.NearZeroInitializedLinear(output_size=action_size),
        networks.TanhToSpec(env_spec.actions),
    ])
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP(
                layer_sizes=[*params.pop('critic_layers'), 1])),
        networks.DiscreteValuedHead(
            vmin=-100.0, vmax=100.0, num_atoms=params.pop('atoms')),
    ])

    # Make sure the observation network is a Sonnet module.
    observation_network = tf2_utils.to_sonnet_module(tf.identity)

    actor = FeedForwardActor(policy_network=snt.Sequential([
        observation_network,
        policy_network,
    ]))

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    # The learner updates the parameters (and initializes them).
    agent = D4PG(
        environment_spec=env_spec,
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        logger=logger,
        checkpoint_path=checkpoint_path,
        **params,
    )
    agent.eval_actor = actor
    return agent

def make_networks(action_spec: specs.BoundedArray):
    """Simple networks for testing."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec),
    ])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP([50], activate_final=True)),
        networks.DiscreteValuedHead(-1., 1., 10),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }

def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 50),
    critic_layer_sizes: Sequence[int] = (50, 50),
):
    """Creates networks used by the agent."""
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    observation_network = tf2_utils.batch_concat
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=True,
            init_scale=0.3,
            fixed_scale=True,
            use_tfd_independent=False),
    ])
    evaluator_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.StochasticMeanHead(),
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
        'evaluator': evaluator_network,
    }

def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agent."""
    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(action_spec.shape, dtype=int)

    # Create the shared observation network; here simply a state-less operation.
    observation_network = tf2_utils.batch_concat

    # Create the policy network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec),
    ])

    # Create the critic network.
    critic_network = snt.Sequential([
        # The multiplexer concatenates the observations/actions.
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }

# Grab the spec of the environment.
environment_spec = specs.make_environment_spec(environment)

# Build the agent networks for a D4PG agent.

# Get total number of action dimensions from action spec.
num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

# Create the shared observation network; here simply a state-less operation.
observation_network = tf2_utils.batch_concat

# Create the deterministic policy network.
policy_network = snt.Sequential([
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(num_dimensions),
    networks.TanhToSpec(environment_spec.actions),
])

# Create the distributional critic network.
critic_network = snt.Sequential([
    # The multiplexer concatenates the observations/actions.
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
])

# Create a logger for the agent and environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

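# A hedged sketch of how these pieces are typically wired together in the
# Acme-style tutorial this snippet resembles; exact constructor arguments
# vary across Acme versions, so treat the agent call below as an assumption.
from acme import EnvironmentLoop
from acme.agents.tf import d4pg

agent = d4pg.D4PG(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
    observation_network=observation_network,
    logger=agent_logger,
)
env_loop = EnvironmentLoop(environment, agent, logger=env_loop_logger)
env_loop.run(num_episodes=100)
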
def make_default_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
    archecture_type: ArchitectureType = ArchitectureType.feedforward,
) -> Mapping[str, types.TensorTransformation]:
    """Default networks for maddpg.

    Args:
        environment_spec (mava_specs.MAEnvironmentSpec): description of the
            action and observation spaces etc. for each agent in the system.
        policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence],
            optional): sizes of the policy networks. Defaults to
            (256, 256, 256).
        critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence],
            optional): sizes of the critic networks. Defaults to
            (512, 512, 256).
        shared_weights (bool, optional): whether agents should share weights.
            Defaults to True.
        sigma (float, optional): standard deviation of the Gaussian noise
            added for simple exploration. Defaults to 0.3.
        archecture_type (ArchitectureType, optional): architecture used for
            the agent networks; can be feedforward or recurrent. Defaults to
            ArchitectureType.feedforward.

    Returns:
        Mapping[str, types.TensorTransformation]: agent networks.
    """
    # Set the policy network constructor and layer sizes. Note: the recurrent
    # case fixes the policy layer sizes to (128, 128).
    if archecture_type == ArchitectureType.feedforward:
        policy_network_func = snt.Sequential
    elif archecture_type == ArchitectureType.recurrent:
        policy_networks_layer_sizes = (128, 128)
        policy_network_func = snt.DeepRNN

    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():
        # TODO (dries): Make specs[key].actions
        #  return a list of specs for hybrid action space.

        # Treat discrete actions as a bounded continuous vector in [-1, 1].
        agent_act_spec = specs[key].actions
        if isinstance(agent_act_spec, DiscreteArray):
            num_actions = agent_act_spec.num_values
            agent_act_spec = BoundedArray(
                shape=(num_actions,),
                minimum=[-1.0] * num_actions,
                maximum=[1.0] * num_actions,
                dtype="float32",
                name="actions",
            )

        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(agent_act_spec.shape, dtype=int)

        # An optional network to process observations.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Create the policy network.
        if archecture_type == ArchitectureType.feedforward:
            policy_network = [
                networks.LayerNormMLP(policy_networks_layer_sizes[key],
                                      activate_final=True),
            ]
        elif archecture_type == ArchitectureType.recurrent:
            policy_network = [
                networks.LayerNormMLP(policy_networks_layer_sizes[key][:-1],
                                      activate_final=True),
                snt.LSTM(policy_networks_layer_sizes[key][-1]),
            ]
        policy_network += [
            networks.NearZeroInitializedLinear(num_dimensions),
            networks.TanhToSpec(agent_act_spec),
        ]

        # Add Gaussian noise for simple exploration.
        if sigma and sigma > 0.0:
            policy_network += [
                networks.ClippedGaussian(sigma),
                networks.ClipToSpec(agent_act_spec),
            ]
        policy_network = policy_network_func(policy_network)

        # Create the critic network.
        critic_network = snt.Sequential([
            # The multiplexer concatenates the observations/actions.
            networks.CriticMultiplexer(),
            networks.LayerNormMLP(
                list(critic_networks_layer_sizes[key]) + [1],
                activate_final=False),
        ])

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }

def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (512, 512, 256),
    shared_weights: bool = True,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agents."""
    # Create agent_type specs.
    specs = environment_spec.get_agent_specs()
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():
        # Create the shared observation network; here simply a state-less
        # operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Note: the discrete case must be placed first as DiscreteArray
        # inherits from BoundedArray.
        if isinstance(specs[key].actions, dm_env.specs.DiscreteArray):
            # Discrete actions: a categorical head over the logits.
            num_actions = specs[key].actions.num_values
            policy_network = snt.Sequential([
                networks.LayerNormMLP(
                    tuple(policy_networks_layer_sizes[key]) + (num_actions,),
                    activate_final=False),
                tf.keras.layers.Lambda(
                    lambda logits: tfp.distributions.Categorical(logits=logits)),
            ])
        elif isinstance(specs[key].actions, dm_env.specs.BoundedArray):
            # Continuous actions: a squashed Gaussian head.
            num_actions = np.prod(specs[key].actions.shape, dtype=int)
            policy_network = snt.Sequential([
                networks.LayerNormMLP(policy_networks_layer_sizes[key],
                                      activate_final=True),
                networks.MultivariateNormalDiagHead(num_dimensions=num_actions),
                networks.TanhToSpec(specs[key].actions),
            ])
        else:
            raise ValueError(
                f"Unknown action_spec type, got {specs[key].actions}.")

        critic_network = snt.Sequential([
            networks.LayerNormMLP(critic_networks_layer_sizes[key],
                                  activate_final=True),
            networks.NearZeroInitializedLinear(1),
        ])

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }

def make_networks(
    environment_spec: mava_specs.MAEnvironmentSpec,
    policy_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (256, 256, 256),
    critic_networks_layer_sizes: Union[Dict[str, Sequence],
                                       Sequence] = (512, 512, 256),
    shared_weights: bool = True,
    sigma: float = 0.3,
) -> Mapping[str, types.TensorTransformation]:
    """Creates networks used by the agents."""
    specs = environment_spec.get_agent_specs()

    # Create agent_type specs.
    if shared_weights:
        type_specs = {key.split("_")[0]: specs[key] for key in specs.keys()}
        specs = type_specs

    if isinstance(policy_networks_layer_sizes, Sequence):
        policy_networks_layer_sizes = {
            key: policy_networks_layer_sizes for key in specs.keys()
        }
    if isinstance(critic_networks_layer_sizes, Sequence):
        critic_networks_layer_sizes = {
            key: critic_networks_layer_sizes for key in specs.keys()
        }

    observation_networks = {}
    policy_networks = {}
    critic_networks = {}
    for key in specs.keys():
        # Get total number of action dimensions from action spec.
        num_dimensions = np.prod(specs[key].actions.shape, dtype=int)

        # Create the shared observation network; here simply a state-less
        # operation.
        observation_network = tf2_utils.to_sonnet_module(tf.identity)

        # Create the policy network, with clipped Gaussian exploration noise.
        policy_network = snt.Sequential([
            networks.LayerNormMLP(policy_networks_layer_sizes[key],
                                  activate_final=True),
            networks.NearZeroInitializedLinear(num_dimensions),
            networks.TanhToSpec(specs[key].actions),
            networks.ClippedGaussian(sigma),
            networks.ClipToSpec(specs[key].actions),
        ])

        # Create the critic network.
        critic_network = snt.Sequential([
            # The multiplexer concatenates the observations/actions.
            networks.CriticMultiplexer(),
            networks.LayerNormMLP(critic_networks_layer_sizes[key],
                                  activate_final=False),
            snt.Linear(1),
        ])

        observation_networks[key] = observation_network
        policy_networks[key] = policy_network
        critic_networks[key] = critic_network

    return {
        "policies": policy_networks,
        "critics": critic_networks,
        "observations": observation_networks,
    }

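# A minimal numpy sketch of the exploration scheme the policies above use:
# add zero-mean Gaussian noise with std `sigma`, then clip back into the
# action bounds. The exact clipping inside networks.ClippedGaussian may
# differ; this is an illustration, not the library implementation.
import numpy as np

def noisy_clipped_action(action, sigma, low, high, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    noise = rng.normal(0.0, sigma, size=np.shape(action))
    return np.clip(action + noise, low, high)

# Example: a tanh-squashed action in [-1, 1] stays in bounds after noise.
print(noisy_clipped_action(np.zeros(3), sigma=0.3, low=-1.0, high=1.0))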