def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = 0.,
    vmax: float = 1.,
    num_atoms: int = 10,
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent.

  Args:
    action_spec: spec for the action space; its shape determines the
      dimensionality of the policy head's output.
    policy_layer_sizes: hidden-layer sizes of the policy MLP.
    critic_layer_sizes: hidden-layer sizes of the critic MLP.
    vmin: lower bound of the critic's discrete value support
      (previously hard-coded to 0.).
    vmax: upper bound of the critic's discrete value support
      (previously hard-coded to 1.).
    num_atoms: number of atoms in the critic's discrete value support
      (previously hard-coded to 10).

  Returns:
    A dict with 'policy' and 'critic' network modules.
  """
  # Total number of action dimensions, flattened from the (possibly
  # multi-dimensional) action spec shape.
  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_layer_sizes = list(critic_layer_sizes)

  # Stochastic policy: layer-norm MLP feeding a diagonal-Gaussian head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions),
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes)),
      # Distributional (categorical) critic head over [vmin, vmax].
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
  }
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""
  # Flatten the action spec's shape into a single output dimensionality.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Append the output layers: the policy emits one value per action
  # dimension; the critic emits a single scalar value estimate.
  policy_sizes = [*policy_layer_sizes, action_size]
  critic_sizes = [*critic_layer_sizes, 1]

  # Deterministic policy: layer-norm MLP squashed by tanh.
  policy_network = snt.Sequential(
      [networks.LayerNormMLP(policy_sizes), tf.tanh])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_sizes))

  return {'policy': policy_network, 'critic': critic_network}
def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent.

  Args:
    action_spec: spec for the action space; its shape determines the
      dimensionality of the policy head's output. (Annotated as
      `specs.Array` to match the sibling factories — confirm the
      callers pass that type.)
    policy_layer_sizes: hidden-layer sizes of the policy MLP.
    critic_layer_sizes: hidden-layer sizes of the critic MLP.

  Returns:
    A dict with 'policy' and 'critic' network modules.
  """
  num_dimensions = np.prod(action_spec.shape, dtype=int)
  # Append a scalar output layer for the critic's value estimate.
  critic_layer_sizes = list(critic_layer_sizes) + [1]

  # Stochastic policy: layer-norm MLP feeding a diagonal-Gaussian head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes))

  return {
      'policy': policy_network,
      'critic': critic_network,
  }
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, types.TensorTransformation]:
  """Creates networks used by the agent."""
  # Flattened action dimensionality, taken from the action spec's shape.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Observations are simply flattened and concatenated; no learned encoder.
  observation_network = tf2_utils.batch_concat

  # Policy: layer-norm MLP with a diagonal-Gaussian distribution head.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size)
  ])

  # Critic: clip actions into spec bounds, concatenate them with the
  # observations, run the result through an MLP, and emit a categorical
  # value distribution over [vmin, vmax] with num_atoms bins.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes),
          action_network=networks.ClipToSpec(action_spec)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': observation_network,
  }
# Get total number of action dimensions from action spec num_dimensions = np.prod(environment_spec.actions.shape, dtype=int) # Create shared observation network: observation_network = tf2_utils.batch_concat policy_network = snt.Sequential([ networks.LayerNormMLP((256, 256, 256), activate_final=True), networks.NearZeroInitializedLinear(num_dimensions), networks.TanhToSpec(environment_spec.actions) ]) # Create the distributional critic network: critic_network = snt.Sequential([ networks.CriticMultiplexer(), networks.LayerNormMLP((512, 512, 256), activate_final=True), networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51) ]) # Create logger for agent diagnostics: agent_logger = loggers.TerminalLogger(label='agent', time_delta=10) # Create D4PG Agent: agent = d4pg.D4PG(environment_spec=environment_spec, policy_network=policy_network, critic_network=critic_network, observation_network=observation_network, logger=agent_logger, checkpoint=False)