Example #1
def make_mpo_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    policy_init_std=1e-9,
    obs_network=None):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_layer_sizes = list(critic_layer_sizes) + [1]

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          init_scale=policy_init_std,
          min_scale=1e-10)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  if obs_network is None:
    obs_network = tf_utils.batch_concat

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': obs_network,
  }
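A sketch of how such a network mapping is typically consumed, assuming Acme's TF MPO agent (acme.agents.tf.mpo.MPO) and that an `environment` object is already in hand; the wiring below is illustrative, not part of the original example:

# Hypothetical wiring; `environment` is assumed to exist.
from acme import specs
from acme.agents.tf import mpo

environment_spec = specs.make_environment_spec(environment)
nets = make_mpo_networks(environment_spec.actions)

agent = mpo.MPO(
    environment_spec=environment_spec,
    policy_network=nets['policy'],
    critic_network=nets['critic'],
    observation_network=nets['observation'])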
Example #2
def make_d4pg_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    vmin=-150.,
    vmax=150.,
    num_atoms=201):
  """Creates networks used by the d4pg agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  policy_layer_sizes = list(policy_layer_sizes) + [int(num_dimensions)]

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.TanhToSpec(action_spec)
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(
              critic_layer_sizes, activate_final=True)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms)
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
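The DiscreteValuedHead above makes the critic distributional: it returns a distribution over num_atoms support points in [vmin, vmax] rather than a scalar. A minimal sketch of recovering a scalar Q-value, assuming batched observation and action tensors `o` and `a` (Sonnet's Sequential forwards extra positional arguments to its first module, so the multiplexer receives both):

# Minimal sketch; `o` and `a` are assumed to be batched tensors.
q_dist = critic_network(o, a)  # distribution over [vmin, vmax]
q_value = q_dist.mean()        # expected Q-value per batch element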
Example #3
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes=(300, 200),
    critic_layer_sizes=(400, 300),
    vmin=-150.,
    vmax=150.,
    num_atoms=51,
):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential(
      [critic_network,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }
Example #4
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""
  # Get total number of action dimensions from action spec.
  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          tanh_mean=True,
          min_scale=0.3,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  multiplexer = networks.CriticMultiplexer()
  critic_network = snt.Sequential([
      multiplexer,
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
  }
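Sonnet modules are built lazily, so these networks have no variables until they are first called. Before checkpointing them or handing them to a learner, variables can be created explicitly, as Example #17 below does; a sketch assuming an observation spec `obs_spec`:

# Sketch; `obs_spec` is assumed to describe a single observation.
from acme.tf import utils as tf2_utils

tf2_utils.create_variables(policy_network, [obs_spec])
tf2_utils.create_variables(critic_network, [obs_spec, action_spec])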
Example #5
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, Union[snt.Module, Callable[[tf.Tensor], tf.Tensor]]]:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec)
    ])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Example #6
def make_default_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Mapping[str, types.TensorTransformation]:
  """Creates networks used by the agent."""

  # Get total number of action dimensions from action spec.
  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          tanh_mean=True,
          min_scale=0.3,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  multiplexer = networks.CriticMultiplexer(
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential([
      multiplexer,
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      "policy": policy_network,
      "critic": critic_network,
  }
Example #7
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50,),
    critic_layer_sizes: Sequence[int] = (50,),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            tanh_mean=True,
                                            init_scale=0.3,
                                            fixed_scale=True,
                                            use_tfd_independent=False)
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP(critic_layer_sizes,
                                             activate_final=True),
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential(
        [critic_network,
         networks.DiscreteValuedHead(vmin, vmax, num_atoms)])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Example #8
def make_networks(action_spec: specs.BoundedArray):
    """Creates simple networks for testing.."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    # Create the observation network shared between the policy and critic.
    observation_network = tf2_utils.batch_concat

    # Create the policy network (head) and the evaluation network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec)
    ])
    evaluator_network = snt.Sequential([observation_network, policy_network])

    # Create the critic network.
    critic_network = snt.Sequential([
        # The multiplexer concatenates the observations/actions.
        networks.CriticMultiplexer(),
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
        'evaluator': evaluator_network,
    }
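The 'evaluator' entry composes the observation and policy networks into one module, which can be dropped straight into an actor; a hedged sketch assuming Acme's FeedForwardActor (the same class used in Example #22 below):

# Hypothetical usage; `observation` is assumed to be one timestep's observation.
from acme.agents.tf import actors

eval_actor = actors.FeedForwardActor(policy_network=evaluator_network)
action = eval_actor.select_action(observation)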
Example #9
def make_networks(
    environment_spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (256, 256),
    critic_layer_sizes: Sequence[int] = (256, 256),
) -> Mapping[str, types.TensorTransformation]:
    """Creates the networks used by the agent."""

    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

    # Create the shared observation network; here simply a state-less operation.
    observation_network = tf.identity

    # Create the policy network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        sac.model.SquashedGaussianValueHead(num_dimensions),
    ])

    # Create the critic network.
    critic_network = snt.Sequential([
        # The multiplexer concatenates the observations/actions.
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Example #10
def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions),
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(critic_layer_sizes)),
      networks.DiscreteValuedHead(0., 1., 10),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
  }
Example #11
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            init_scale=0.7,
                                            use_tfd_independent=True)
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Example #12
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions]

    policy_network = snt.Sequential(
        [networks.LayerNormMLP(policy_layer_sizes), tf.tanh])
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
            critic_layer_sizes, activate_final=True)),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms)
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
    }
Example #13
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(num_dimensions),
      networks.TanhToSpec(action_spec)
  ])
  critic_network = snt.Sequential([
      # The multiplexer concatenates the observations/actions.
      networks.CriticMultiplexer(),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
Example #14
def make_networks(
    action_spec: specs.BoundedArray,
    num_critic_heads: int,
    policy_layer_sizes: Sequence[int] = (50,),
    critic_layer_sizes: Sequence[int] = (50,),
    num_layers_shared: int = 1,
    distributional_critic: bool = True,
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            tanh_mean=False,
                                            init_scale=0.69)
    ])

    if not distributional_critic:
        critic_layer_sizes = list(critic_layer_sizes) + [1]

    if not num_layers_shared:
        # No layers are shared
        critic_network_base = None
    else:
        critic_network_base = networks.LayerNormMLP(
            critic_layer_sizes[:num_layers_shared], activate_final=True)
    critic_network_heads = [
        snt.nets.MLP(critic_layer_sizes,
                     activation=tf.nn.elu,
                     activate_final=False) for _ in range(num_critic_heads)
    ]
    if distributional_critic:
        critic_network_heads = [
            snt.Sequential(
                [c, networks.DiscreteValuedHead(vmin, vmax, num_atoms)])
            for c in critic_network_heads
        ]
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=critic_network_base,
            action_network=networks.ClipToSpec(action_spec)),
        networks.Multihead(network_heads=critic_network_heads),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
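With this construction the critic produces one output per head: the CriticMultiplexer runs the shared base, and networks.Multihead fans its output into the per-head MLPs. A minimal sketch, assuming batched tensors `o` and `a`:

# Sketch; returns a sequence of num_critic_heads outputs, each a
# DiscreteValuedHead distribution when distributional_critic=True,
# otherwise a [batch, 1] value tensor.
per_head_outputs = critic_network(o, a)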
Example #15
def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()
    policy_network = snt.Sequential([
        networks.LayerNormMLP(
            layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.MultivariateNormalDiagHead(num_dimensions=action_size)
    ])

    critic_network = snt.Sequential([
        networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
            layer_sizes=[*params.pop('critic_layers'), 1]))
    ])

    observation_network = networks.DeepRNN([
        networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
        networks.LSTM(hidden_size=200)
    ])

    loss_param_keys = list(
        filter(lambda key: key.startswith('loss_'), params.keys()))
    loss_params = dict([(k.replace('loss_', ''), params.pop(k))
                        for k in loss_param_keys])
    policy_loss_module = losses.MPO(**loss_params)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    actor = RecurrentActor(
        networks.DeepRNN([
            observation_network, policy_network,
            networks.StochasticModeHead()
        ]))

    # The learner updates the parameters (and initializes them).
    return RecurrentMPO(environment_spec=env_spec,
                        policy_network=policy_network,
                        critic_network=critic_network,
                        observation_network=observation_network,
                        policy_loss_module=policy_loss_module,
                        policy_optimizer=policy_optimizer,
                        critic_optimizer=critic_optimizer,
                        logger=logger,
                        checkpoint_path=checkpoint_path,
                        **params), actor
Example #16
def make_network_with_prior(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (200, 100),
    critic_layer_sizes: Sequence[int] = (400, 300),
    prior_layer_sizes: Sequence[int] = (200, 100),
    policy_keys: Optional[Sequence[str]] = None,
    prior_keys: Optional[Sequence[str]] = None,
) -> Mapping[str, types.TensorTransformation]:
  """Creates networks used by the agent."""

  # Get total number of action dimensions from action spec.
  num_dimensions = np.prod(action_spec.shape, dtype=int)
  flatten_concat_policy = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=policy_keys)
  flatten_concat_prior = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=prior_keys)

  policy_network = snt.Sequential([
      flatten_concat_policy,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          tanh_mean=True,
          min_scale=0.1,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  multiplexer = networks.CriticMultiplexer(
      observation_network=flatten_concat_policy,
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential([
      multiplexer,
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])
  prior_network = snt.Sequential([
      flatten_concat_prior,
      networks.LayerNormMLP(prior_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          num_dimensions,
          tanh_mean=True,
          min_scale=0.1,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False)
  ])
  return {
      "policy": policy_network,
      "critic": critic_network,
      "prior": prior_network,
  }
Example #17
def make_default_networks(
    environment_spec: specs.EnvironmentSpec,
    *,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    policy_init_scale: float = 0.7,
    critic_init_scale: float = 1e-3,
    critic_num_components: int = 5,
) -> Mapping[str, snt.Module]:
    """Creates networks used by the agent."""

    # Unpack the environment spec to get appropriate shapes, dtypes, etc.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    num_dimensions = np.prod(act_spec.shape, dtype=int)

    # Create the observation network and make sure it's a Sonnet module.
    observation_network = tf2_utils.batch_concat
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create the policy network.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            init_scale=policy_init_scale,
                                            use_tfd_independent=True)
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            action_network=networks.ClipToSpec(act_spec)),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.GaussianMixtureHead(num_dimensions=1,
                                     num_components=critic_num_components,
                                     init_scale=critic_init_scale)
    ])

    # Create network variables.
    # Get embedding spec by creating observation network variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Example #18
def make_bc_network(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    policy_init_std=1e-9,
    binary_grip_action=False):
  """Residual BC network in Sonnet, equivalent to residual policy network."""
  num_dimensions = np.prod(action_spec.shape, dtype=int)
  if policy_layer_sizes:
    policy_network = snt.Sequential([
        tf_utils.batch_concat,
        networks.LayerNormMLP([int(l) for l in policy_layer_sizes]),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ])
  else:
    policy_network = snt.Sequential([
        tf_utils.batch_concat,
        ArmPolicyNormalDiagHead(
            binary_grip_action=binary_grip_action,
            num_dimensions=num_dimensions,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ])
  return {
      # 'observation': tf_utils.batch_concat,
      'policy': policy_network,
  }
Example #19
 def __init__(self, layer_sizes: Sequence[int]):
     super(InstrumentalFeature, self).__init__()
     self._net = snt.Sequential([
         networks.CriticMultiplexer(),
         networks.LayerNormMLP(layer_sizes, activate_final=True)
     ])
     self._feature_dim = layer_sizes[-1] + 1
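The stored feature dimension is layer_sizes[-1] + 1, which suggests the module's call (not shown in this fragment) appends a constant feature to the MLP output. A hedged sketch of what that call might look like; this body is an assumption, not the original source:

 def __call__(self, obs, action):
     # Assumed: append a constant 1 so the feature dim is layer_sizes[-1] + 1.
     feature = self._net(obs, action)
     ones = tf.ones([tf.shape(feature)[0], 1], dtype=feature.dtype)
     return tf.concat([feature, ones], axis=-1)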
Example #20
 def __init__(self, layer_sizes: Sequence[int]):
     super(TerminatePredictor, self).__init__()
     self._net = snt.Sequential([
         networks.CriticMultiplexer(),
         networks.LayerNormMLP([*layer_sizes, 1], activate_final=False),
         tf.sigmoid
     ])
Example #21
    def test_snapshot_distribution(self):
        """Test that snapshotter correctly calls saves/restores snapshots."""
        # Create a test network.
        net1 = snt.Sequential([
            networks.LayerNormMLP([10, 10]),
            networks.MultivariateNormalDiagHead(1)
        ])
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net1, [spec])

        # Save the test network.
        directory = self.get_tempdir()
        objects_to_save = {'net': net1}
        snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                             directory=directory)
        snapshotter.save()

        # Reload the test network.
        net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        with tf.GradientTape() as tape:
            dist1 = net1(inputs)
            loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
            grads1 = tape.gradient(loss1, net1.trainable_variables)

        with tf.GradientTape() as tape:
            dist2 = net2(inputs)
            loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
            grads2 = tape.gradient(loss2, net2.trainable_variables)

        assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
Example #22
def make_d4pg_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                    checkpoint_path: str, hyperparams: Dict):
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()
    policy_network = snt.Sequential([
        networks.LayerNormMLP(layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.NearZeroInitializedLinear(output_size=action_size),
        networks.TanhToSpec(env_spec.actions),
    ])

    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP(layer_sizes=[*params.pop('critic_layers'), 1])
        ),
        networks.DiscreteValuedHead(vmin=-100.0, vmax=100.0, num_atoms=params.pop('atoms'))
    ])

    observation_network = tf.identity

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    actor = FeedForwardActor(policy_network=snt.Sequential([
        observation_network,
        policy_network
    ]))

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    # The learner updates the parameters (and initializes them).
    agent = D4PG(
        environment_spec=env_spec,
        policy_network=policy_network,
        critic_network=critic_network,
        observation_network=observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        logger=logger,
        checkpoint_path=checkpoint_path,
        **params
    )
    agent.eval_actor = actor
    return agent
Example #23
def make_value_func_dm_control(layer_sizes: str = '512,512,256') -> snt.Module:
    layer_sizes = list(map(int, layer_sizes.split(',')))
    value_function = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(layer_sizes, activate_final=True),
        snt.Linear(1)
    ])
    return value_function
Example #24
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    vmin=-150.,
    vmax=150.,
    num_atoms=51,
    policy_init_std=1e-9,
    obs_network=None,
    binary_grip_action=False):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  if policy_layer_sizes:
    policy_network = snt.Sequential([
        networks.LayerNormMLP([int(l) for l in policy_layer_sizes]),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ])
  else:
    # Useful when initializing from a trained BC network.
    policy_network = snt.Sequential([
        ArmPolicyNormalDiagHead(
            binary_grip_action=binary_grip_action,
            num_dimensions=num_dimensions,
            init_scale=policy_init_std,
            min_scale=1e-10)
    ])
  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential(
      [critic_network,
       networks.DiscreteValuedHead(vmin, vmax, num_atoms)])
  if obs_network is None:
    obs_network = tf_utils.batch_concat

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': obs_network,
  }
Example #25
def make_feed_forward_networks(
    action_spec: specs.BoundedArray,
    z_spec: specs.BoundedArray,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    discriminator_layer_sizes: Tuple[int, ...] = (256, 256),
    hierarchical_controller_layer_sizes: Tuple[int, ...] = (256, 256),
    vmin: float = -150.,  # Minimum value for the Critic distribution.
    vmax: float = 150.,  # Maximum value for the Critic distribution.
    num_atoms: int = 51,  # Number of atoms for the discrete value distribution.
) -> Dict[str, types.TensorTransformation]:
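    """Creates the feed-forward networks used by the agent."""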
    num_dimensions = np.prod(action_spec.shape, dtype=int)
    z_dim = np.prod(z_spec.shape, dtype=int)

    observation_network = tf2_utils.batch_concat

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes),
        networks.MultivariateNormalDiagHead(num_dimensions)
    ])

    critic_multiplexer = networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP(critic_layer_sizes),
        action_network=networks.ClipToSpec(action_spec))

    critic_network = snt.Sequential([
        critic_multiplexer,
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    # The discriminator in DIAYN uses the same architecture as the critic.
    discriminator_network = networks.LayerNormMLP(discriminator_layer_sizes +
                                                  (z_dim, ))

    hierarchical_controller_network = networks.LayerNormMLP(
        hierarchical_controller_layer_sizes + (z_dim, ))

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
        'discriminator': discriminator_network,
        'hierarchical_controller': hierarchical_controller_network,
    }
Example #26
def make_value_func_dm_control(
    value_layer_sizes: str = '512,512,256',
    adversarial_layer_sizes: str = '512,512,256',
) -> Tuple[snt.Module, snt.Module]:
    layer_sizes = list(map(int, value_layer_sizes.split(',')))
    value_function = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(layer_sizes, activate_final=True),
        snt.Linear(1)
    ])

    layer_sizes = list(map(int, adversarial_layer_sizes.split(',')))
    adversarial_function = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(layer_sizes, activate_final=True),
        snt.Linear(1)
    ])

    return value_function, adversarial_function
Example #27
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 1024, 1024),
    critic_layer_sizes: Sequence[int] = (50, 1024, 1024),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes,
                              w_init=snt.initializers.Orthogonal(),
                              activation=tf.nn.relu,
                              activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=False,
            init_scale=1.0,
            fixed_scale=False,
            use_tfd_independent=True,
            w_init=snt.initializers.Orthogonal())
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = networks.CriticMultiplexer(
        observation_network=snt.Sequential([
            snt.Linear(critic_layer_sizes[0],
                       w_init=snt.initializers.Orthogonal()),
            snt.LayerNorm(axis=slice(1, None),
                          create_scale=True,
                          create_offset=True), tf.nn.tanh
        ]),
        critic_network=snt.nets.MLP(critic_layer_sizes[1:],
                                    w_init=snt.initializers.Orthogonal(),
                                    activation=tf.nn.relu,
                                    activate_final=True),
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        critic_network,
        networks.DiscreteValuedHead(vmin,
                                    vmax,
                                    num_atoms,
                                    w_init=snt.initializers.Orthogonal())
    ])
    observation_network = networks.DrQTorso()

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Example #28
def make_networks(
        action_spec: types.NestedSpec,
        policy_layer_sizes: Sequence[int] = (10, 10),
        critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    policy_layer_sizes = list(policy_layer_sizes) + [num_dimensions]
    critic_layer_sizes = list(critic_layer_sizes) + [1]

    policy_network = snt.Sequential(
        [networks.LayerNormMLP(policy_layer_sizes), tf.tanh])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP(critic_layer_sizes))

    return {
        'policy': policy_network,
        'critic': critic_network,
    }
Example #29
def make_networks(
    action_spec,
    policy_layer_sizes=(10, 10),
    critic_layer_sizes=(10, 10),
):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(action_spec.shape, dtype=int)
  critic_layer_sizes = list(critic_layer_sizes) + [1]

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(num_dimensions)
  ])
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes))

  return {
      'policy': policy_network,
      'critic': critic_network,
  }
Example #30
def make_networks(action_spec: specs.BoundedArray):
    """Simple networks for testing.."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(num_dimensions),
        networks.TanhToSpec(action_spec)
    ])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP([50], activate_final=True)),
        networks.DiscreteValuedHead(-1., 1., 10)
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }