Exemple #1
0
def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
) -> Dict[str, snt.Module]:
  """Builds the policy and critic networks for the agent.

  Args:
    action_spec: Spec of the action; only its `shape` is read here, to size
      the policy head.
    policy_layer_sizes: Layer sizes handed to the policy's `LayerNormMLP`.
    critic_layer_sizes: Layer sizes handed to the critic's `LayerNormMLP`
      (copied into a list first).

  Returns:
    A dict with two entries:
      'policy': `LayerNormMLP` followed by a `MultivariateNormalDiagHead`
        with one dimension per action element.
      'critic': `CriticMultiplexer` (wrapping a `LayerNormMLP`) followed by a
        `DiscreteValuedHead` fixed at vmin=0, vmax=1 and 10 atoms.
  """
  # Total number of scalar action elements (product over the spec's shape).
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_torso = networks.LayerNormMLP(policy_layer_sizes)
  policy_head = networks.MultivariateNormalDiagHead(action_size)
  policy_network = snt.Sequential([policy_torso, policy_head])

  # The multiplexer concatenates the (maybe transformed) observations/actions
  # before feeding them to the critic torso.
  critic_torso = networks.LayerNormMLP(list(critic_layer_sizes))
  critic_multiplexer = networks.CriticMultiplexer(critic_network=critic_torso)
  value_head = networks.DiscreteValuedHead(vmin=0., vmax=1., num_atoms=10)
  critic_network = snt.Sequential([critic_multiplexer, value_head])

  return {'policy': policy_network, 'critic': critic_network}
Exemple #2
0
def make_networks(
        action_spec: types.NestedSpec,
        policy_layer_sizes: Sequence[int] = (10, 10),
        critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
    """Builds a tanh-squashed policy network and a scalar-output critic.

    Args:
        action_spec: Spec of the action. It must expose a `shape` attribute;
            the product of that shape becomes the policy's output width.
        policy_layer_sizes: Hidden layer sizes of the policy MLP. The action
            dimensionality is appended as a final layer.
        critic_layer_sizes: Hidden layer sizes of the critic MLP. A final
            layer of size 1 is appended.

    Returns:
        A dict with two entries:
            'policy': `LayerNormMLP` followed by `tf.tanh`.
            'critic': `CriticMultiplexer` wrapping a `LayerNormMLP`.
    """
    # Total number of scalar action elements (product over the spec's shape).
    action_size = np.prod(action_spec.shape, dtype=int)

    # Extend the hidden sizes with each network's output layer.
    policy_sizes = [*policy_layer_sizes, action_size]
    critic_sizes = [*critic_layer_sizes, 1]

    policy_mlp = networks.LayerNormMLP(policy_sizes)
    policy_network = snt.Sequential([policy_mlp, tf.tanh])

    # The multiplexer concatenates the (maybe transformed)
    # observations/actions before feeding them to the critic MLP.
    critic_mlp = networks.LayerNormMLP(critic_sizes)
    critic_network = networks.CriticMultiplexer(critic_network=critic_mlp)

    return {'policy': policy_network, 'critic': critic_network}
Exemple #3
0
def make_networks(
        action_spec,
        policy_layer_sizes=(10, 10),
        critic_layer_sizes=(10, 10),
):
    """Builds a Gaussian-head policy network and a scalar-output critic.

    Args:
        action_spec: Spec of the action; only its `shape` is read, to size the
            policy head.
        policy_layer_sizes: Layer sizes handed to the policy's `LayerNormMLP`.
        critic_layer_sizes: Hidden layer sizes of the critic MLP. A final
            layer of size 1 is appended.

    Returns:
        A dict with two entries:
            'policy': `LayerNormMLP` followed by a
                `MultivariateNormalDiagHead` with one dimension per action
                element.
            'critic': `CriticMultiplexer` wrapping a `LayerNormMLP`.
    """
    # Total number of scalar action elements (product over the spec's shape).
    action_size = np.prod(action_spec.shape, dtype=int)

    # Extend the critic's hidden sizes with its single-unit output layer.
    critic_sizes = [*critic_layer_sizes, 1]

    policy_torso = networks.LayerNormMLP(policy_layer_sizes)
    policy_head = networks.MultivariateNormalDiagHead(action_size)
    policy_network = snt.Sequential([policy_torso, policy_head])

    critic_mlp = networks.LayerNormMLP(critic_sizes)
    critic_network = networks.CriticMultiplexer(critic_network=critic_mlp)

    return {'policy': policy_network, 'critic': critic_network}
Exemple #4
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, types.TensorTransformation]:
    """Builds the policy, critic and observation networks for the agent.

    Args:
        action_spec: Bounded spec of the action. Its `shape` sizes the policy
            head, and the spec itself is given to `ClipToSpec` for the
            critic's action input.
        policy_layer_sizes: Layer sizes handed to the policy's `LayerNormMLP`.
        critic_layer_sizes: Layer sizes handed to the critic's `LayerNormMLP`.
        vmin: First argument forwarded to `DiscreteValuedHead`.
        vmax: Second argument forwarded to `DiscreteValuedHead`.
        num_atoms: Third argument forwarded to `DiscreteValuedHead`.

    Returns:
        A dict with three entries:
            'policy': `LayerNormMLP` followed by a
                `MultivariateNormalDiagHead`.
            'critic': `CriticMultiplexer` followed by a `DiscreteValuedHead`.
            'observation': `tf2_utils.batch_concat`, shared as a state-less
                observation transformation.
    """
    # Total number of scalar action elements (product over the spec's shape).
    action_size = np.prod(action_spec.shape, dtype=int)

    # The shared observation network is simply a state-less operation.
    observation_network = tf2_utils.batch_concat

    # Policy: MLP torso with a diagonal multivariate normal head.
    policy_torso = networks.LayerNormMLP(policy_layer_sizes)
    policy_head = networks.MultivariateNormalDiagHead(action_size)
    policy_network = snt.Sequential([policy_torso, policy_head])

    # Critic: the multiplexer combines observations with actions (the latter
    # passed through ClipToSpec) and feeds them to the critic torso; the
    # discrete-valued head sits on top.
    critic_torso = networks.LayerNormMLP(critic_layer_sizes)
    action_clipper = networks.ClipToSpec(action_spec)
    multiplexer = networks.CriticMultiplexer(
        critic_network=critic_torso,
        action_network=action_clipper)
    value_head = networks.DiscreteValuedHead(vmin, vmax, num_atoms)
    critic_network = snt.Sequential([multiplexer, value_head])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Exemple #5
0
# Total number of scalar action elements, taken from the action spec's shape.
num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

# Shared observation network: a plain function, no module state.
observation_network = tf2_utils.batch_concat

# Policy network: MLP torso, a linear layer sized to the action dimensions,
# then a TanhToSpec layer built from the action spec.
policy_layers = [
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(num_dimensions),
    networks.TanhToSpec(environment_spec.actions),
]
policy_network = snt.Sequential(policy_layers)

# Distributional critic network: multiplexer, MLP torso, and a
# discrete-valued head with 51 atoms between -150 and 150.
critic_layers = [
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
]
critic_network = snt.Sequential(critic_layers)

# Terminal logger for agent diagnostics.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10)

# Assemble the D4PG agent from the pieces above.
# NOTE(review): checkpoint=False presumably disables checkpointing — confirm.
agent_options = {
    'environment_spec': environment_spec,
    'policy_network': policy_network,
    'critic_network': critic_network,
    'observation_network': observation_network,
    'logger': agent_logger,
    'checkpoint': False,
}
agent = d4pg.D4PG(**agent_options)