Exemplo n.º 1
0
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
) -> Dict[str, snt.Module]:
  """Builds the policy and critic networks for the agent.

  Args:
    action_spec: nested spec describing the action space; its `shape` gives
      the flat action dimensionality.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.

  Returns:
    Dict with 'policy' and 'critic' Sonnet modules.
  """
  # Flatten the action spec's shape into a single dimension count.
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_head = networks.MultivariateNormalDiagHead(
      action_size,
      tanh_mean=True,
      min_scale=0.3,
      init_scale=0.7,
      fixed_scale=False,
      use_tfd_independent=False)
  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      policy_head,
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {'policy': policy_network, 'critic': critic_network}
Exemplo n.º 2
0
def make_mpo_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    policy_init_std=1e-9,
    obs_network=None):
  """Builds policy, critic and observation networks for an MPO agent.

  Args:
    action_spec: spec whose `shape` determines the action dimensionality.
    policy_layer_sizes: hidden-layer widths of the policy torso.
    critic_layer_sizes: hidden-layer widths of the critic torso; a final
      scalar output layer is appended automatically.
    policy_init_std: initial scale of the policy's diagonal Gaussian head.
    obs_network: optional observation transform; defaults to batch_concat.

  Returns:
    Dict with 'policy', 'critic' and 'observation' entries.
  """
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(
          action_size,
          init_scale=policy_init_std,
          min_scale=1e-10),
  ])

  # The critic outputs a single scalar value, hence the trailing [1]; the
  # multiplexer concatenates observations with actions clipped to the spec.
  critic_network = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(list(critic_layer_sizes) + [1]),
      action_network=networks.ClipToSpec(action_spec))

  if obs_network is None:
    obs_network = tf_utils.batch_concat

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': obs_network,
  }
Exemplo n.º 3
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, ),
    critic_layer_sizes: Sequence[int] = (50, ),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
    """Builds policy and distributional-critic networks for the agent.

    Args:
      action_spec: bounded action spec; its shape gives the action size.
      policy_layer_sizes: hidden-layer widths of the policy MLP.
      critic_layer_sizes: hidden-layer widths of the critic MLP.
      vmin: lower bound of the discrete value distribution support.
      vmax: upper bound of the discrete value distribution support.
      num_atoms: number of atoms in the discrete value distribution.

    Returns:
      Dict with 'policy', 'critic' and 'observation' entries.
    """
    action_size = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(
            action_size,
            tanh_mean=True,
            init_scale=0.3,
            fixed_scale=True,
            use_tfd_independent=False),
    ])

    # Concatenate (maybe transformed) observations with clipped actions, then
    # map through the critic torso to a discrete value distribution.
    critic_torso = networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP(
            critic_layer_sizes, activate_final=True),
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        critic_torso,
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Exemplo n.º 4
0
def make_networks(action_spec: specs.BoundedArray):
    """Creates simple networks for testing.

    Args:
      action_spec: bounded action spec; its shape gives the action size.

    Returns:
      Dict with 'policy', 'critic', 'observation' and 'evaluator' entries.
    """
    action_size = np.prod(action_spec.shape, dtype=int)

    # Shared flatten-and-concat observation network.
    observation_network = tf2_utils.batch_concat

    # Deterministic policy head squashed into the action spec's bounds.
    policy_network = snt.Sequential([
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(action_size),
        networks.TanhToSpec(action_spec),
    ])
    # The evaluator runs observations through the torso then the policy.
    evaluator_network = snt.Sequential([observation_network, policy_network])

    # Critic concatenates observations/actions and predicts a scalar value.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP([50], activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
        'evaluator': evaluator_network,
    }
 def __init__(self, layer_sizes: Sequence[int]):
     """Builds the terminate-prediction network.

     Args:
       layer_sizes: hidden-layer widths; a final scalar sigmoid output
         layer is appended internally.
     """
     super(TerminatePredictor, self).__init__()
     # `[*layer_sizes, 1]` accepts any Sequence (the annotated type);
     # the previous `layer_sizes + [1]` raised TypeError for tuples.
     self._net = snt.Sequential([
         networks.CriticMultiplexer(),
         networks.LayerNormMLP([*layer_sizes, 1], activate_final=False),
         tf.sigmoid
     ])
Exemplo n.º 6
0
def make_networks(
    environment_spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (256, 256),
    critic_layer_sizes: Sequence[int] = (256, 256),
) -> Mapping[str, types.TensorTransformation]:
    """Creates the networks used by the agent.

    Args:
      environment_spec: full environment spec; the action shape fixes the
        policy head dimensionality.
      policy_layer_sizes: hidden-layer widths of the policy MLP.
      critic_layer_sizes: hidden-layer widths of the critic MLP.

    Returns:
      Mapping with 'policy', 'critic' and 'observation' entries.
    """
    action_size = np.prod(environment_spec.actions.shape, dtype=int)

    # Observations pass through unchanged.
    observation_network = tf.identity

    # Squashed-Gaussian policy head on top of a LayerNorm MLP torso.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        sac.model.SquashedGaussianValueHead(action_size),
    ])

    # Critic concatenates observations/actions and predicts a scalar value.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Exemplo n.º 7
0
def make_networks(
    action_spec: specs.Array,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
) -> Dict[str, snt.Module]:
  """Creates networks used by the agent.

  Args:
    action_spec: action spec; its shape gives the action size.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.

  Returns:
    Dict with 'policy' and 'critic' Sonnet modules.
  """
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size),
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions;
  # the head emits a discrete value distribution over 10 atoms on [0, 1].
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(list(critic_layer_sizes))),
      networks.DiscreteValuedHead(0., 1., 10),
  ])

  return {'policy': policy_network, 'critic': critic_network}
Exemplo n.º 8
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
  """Creates networks used by the agent.

  Args:
    action_spec: bounded action spec; its shape gives the action size.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.

  Returns:
    Dict with 'policy', 'critic' and 'observation' entries.
  """
  action_size = np.prod(action_spec.shape, dtype=int)

  # Deterministic policy squashed into the action spec's bounds.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(action_size),
      networks.TanhToSpec(action_spec),
  ])

  # Critic concatenates observations/actions and predicts a scalar value.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
Exemplo n.º 9
0
    def __init__(self,
                 environment_spec: EnvironmentSpec,
                 layer_sizes: Sequence[int],
                 num_cat=10):
        """Builds the mixture-density model components.

        Args:
          environment_spec: spec providing observation shapes and the number
            of discrete actions.
          layer_sizes: widths of the shared MLP torso.
          num_cat: number of mixture components.
        """
        super(MixtureDensity, self).__init__()

        # Total flattened observation size; dict observation specs are summed
        # across their entries.
        observations = environment_spec.observations
        if isinstance(observations, dict):
            self.obs_size = sum(
                int(np.prod(spec.shape)) for spec in observations.values())
        else:
            self.obs_size = int(np.prod(observations.shape))
        self.num_cat = num_cat

        # Discrete actions are one-hot encoded before concatenation.
        one_hot_actions = functools.partial(
            tf.one_hot, depth=environment_spec.actions.num_values)
        self._net = snt.Sequential([
            networks.CriticMultiplexer(action_network=one_hot_actions),
            snt.nets.MLP(layer_sizes, activate_final=True)
        ])

        # Linear heads: discount logit, mixture weights, and per-component
        # locations/scales sized num_cat * obs_size.
        self._discount_logits = snt.Linear(1)
        self._mix_logits = snt.Linear(num_cat)
        self._locs = snt.Linear(num_cat * self.obs_size)
        self._scales = snt.Linear(num_cat * self.obs_size)
Exemplo n.º 10
0
def _make_networks(
    actions_dim: int,
    state_dim: int,
    policy_layers: Sequence[int] = (5, 5),
    critic_layers: Sequence[int] = (5, 5),
):
    """Builds policy, critic and observation-normalization networks.

    Args:
      actions_dim: number of action dimensions.
      state_dim: shape of the state, used to size input standardization.
      policy_layers: hidden-layer widths of the policy MLP.
      critic_layers: hidden-layer widths of the critic MLP; a final scalar
        output layer is appended automatically.

    Returns:
      Dict with 'policy', 'critic' and 'observation' entries.
    """
    # Defaults are tuples (not lists) to avoid the shared mutable-default
    # pitfall; both sequences are only read, never mutated.
    # Normalize to a list to preserve the type previously passed through.
    policy_network = PolicyMLP(list(policy_layers), actions_dim)

    # Critic: concatenate observations/actions and predict a scalar value.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        snt.nets.MLP(list(critic_layers) + [1], activate_final=False),
    ])

    observation_network = InputStandardization(shape=state_dim)

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Exemplo n.º 11
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Dict[str, types.TensorTransformation]:
    """Creates networks used by the agent.

    Args:
      action_spec: bounded action spec; its shape gives the action size.
      policy_layer_sizes: hidden-layer widths of the policy MLP.
      critic_layer_sizes: hidden-layer widths of the critic MLP.

    Returns:
      Dict with 'policy', 'critic' and 'observation' entries.
    """
    action_size = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(action_size,
                                            init_scale=0.7,
                                            use_tfd_independent=True),
    ])

    # Critic: concatenate observations with actions clipped to the spec,
    # then predict a scalar value.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            action_network=networks.ClipToSpec(action_spec)),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Exemplo n.º 12
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, Union[snt.Module, Callable[[tf.Tensor], tf.Tensor]]]:
    """Creates networks used by the agent.

    Args:
      action_spec: bounded action spec; its shape gives the action size.
      policy_layer_sizes: hidden-layer widths of the policy MLP.
      critic_layer_sizes: hidden-layer widths of the critic MLP.
      vmin: lower bound of the discrete value distribution support.
      vmax: upper bound of the discrete value distribution support.
      num_atoms: number of atoms in the discrete value distribution.

    Returns:
      Dict with 'policy', 'critic' and 'observation' entries.
    """
    action_size = np.prod(action_spec.shape, dtype=int)

    # Deterministic policy squashed into the action spec's bounds.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(action_size),
        networks.TanhToSpec(action_spec),
    ])

    # Distributional critic over [vmin, vmax] with num_atoms atoms.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Exemplo n.º 13
0
def make_d4pg_networks(
    action_spec,
    policy_layer_sizes=(256, 256, 256),
    critic_layer_sizes=(512, 512, 256),
    vmin=-150.,
    vmax=150.,
    num_atoms=201):
  """Creates networks used by the d4pg agent.

  Args:
    action_spec: spec whose `shape` determines the action dimensionality.
    policy_layer_sizes: hidden-layer widths of the policy MLP; an action-
      sized output layer is appended automatically.
    critic_layer_sizes: hidden-layer widths of the critic MLP.
    vmin: lower bound of the discrete value distribution support.
    vmax: upper bound of the discrete value distribution support.
    num_atoms: number of atoms in the discrete value distribution.

  Returns:
    Dict with 'policy', 'critic' and 'observation' entries.
  """
  action_size = np.prod(action_spec.shape, dtype=int)

  # Policy: the MLP's last layer emits the raw action, squashed to the spec.
  policy_network = snt.Sequential([
      networks.LayerNormMLP(list(policy_layer_sizes) + [int(action_size)]),
      networks.TanhToSpec(action_spec),
  ])

  # Distributional critic over [vmin, vmax] with num_atoms atoms; the
  # multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          critic_network=networks.LayerNormMLP(
              critic_layer_sizes, activate_final=True)),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf2_utils.batch_concat,
  }
Exemplo n.º 14
0
def make_dmpo_networks(
    action_spec,
    policy_layer_sizes = (300, 200),
    critic_layer_sizes = (400, 300),
    vmin = -150.,
    vmax = 150.,
    num_atoms = 51,
):
  """Creates networks used by the agent.

  Args:
    action_spec: spec whose `shape` determines the action dimensionality.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.
    vmin: lower bound of the discrete value distribution support.
    vmax: upper bound of the discrete value distribution support.
    num_atoms: number of atoms in the discrete value distribution.

  Returns:
    Dict with 'policy', 'critic' and 'observation' entries.
  """
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      networks.LayerNormMLP(policy_layer_sizes),
      networks.MultivariateNormalDiagHead(action_size),
  ])

  # Critic: concatenate observations with clipped actions, then emit a
  # discrete value distribution.
  critic_torso = networks.CriticMultiplexer(
      critic_network=networks.LayerNormMLP(critic_layer_sizes),
      action_network=networks.ClipToSpec(action_spec))
  critic_network = snt.Sequential([
      critic_torso,
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ])

  return {
      'policy': policy_network,
      'critic': critic_network,
      'observation': tf_utils.batch_concat,
  }
Exemplo n.º 15
0
 def __init__(self, layer_sizes: Sequence[int]):
     """Builds the instrumental-feature network.

     Args:
       layer_sizes: widths of the MLP feature layers; the last entry sets
         the learned feature dimension.
     """
     super(InstrumentalFeature, self).__init__()
     torso = networks.LayerNormMLP(layer_sizes, activate_final=True)
     self._net = snt.Sequential([networks.CriticMultiplexer(), torso])
     # One extra dimension beyond the final layer width (presumably a
     # constant/bias feature appended downstream — confirm with caller).
     self._feature_dim = layer_sizes[-1] + 1
Exemplo n.º 16
0
def make_networks(
    action_spec: types.NestedSpec,
    policy_layer_sizes: Sequence[int] = (10, 10),
    critic_layer_sizes: Sequence[int] = (10, 10),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent.

    Args:
      action_spec: spec whose `shape` determines the action dimensionality.
      policy_layer_sizes: hidden-layer widths of the policy MLP; an action-
        sized output layer is appended automatically.
      critic_layer_sizes: hidden-layer widths of the critic MLP.
      vmin: lower bound of the discrete value distribution support.
      vmax: upper bound of the discrete value distribution support.
      num_atoms: number of atoms in the discrete value distribution.

    Returns:
      Dict with 'policy' and 'critic' Sonnet modules.
    """
    action_size = np.prod(action_spec.shape, dtype=int)

    # Policy: the MLP's last layer emits the action, squashed by tanh.
    policy_network = snt.Sequential([
        networks.LayerNormMLP(list(policy_layer_sizes) + [action_size]),
        tf.tanh,
    ])

    # Distributional critic over [vmin, vmax] with num_atoms atoms.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=networks.LayerNormMLP(
                critic_layer_sizes, activate_final=True)),
        networks.DiscreteValuedHead(vmin, vmax, num_atoms),
    ])

    return {'policy': policy_network, 'critic': critic_network}
Exemplo n.º 17
0
def make_default_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
) -> Mapping[str, types.TensorTransformation]:
  """Creates networks used by the agent.

  Args:
    action_spec: bounded action spec; its shape gives the action size.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.

  Returns:
    Mapping with "policy" and "critic" entries.
  """
  # Flatten the action spec's shape into a single dimension count.
  action_size = np.prod(action_spec.shape, dtype=int)

  policy_network = snt.Sequential([
      tf2_utils.batch_concat,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      networks.MultivariateNormalDiagHead(
          action_size,
          tanh_mean=True,
          min_scale=0.3,
          init_scale=0.7,
          fixed_scale=False,
          use_tfd_independent=False),
  ])

  # Critic: concatenate observations with actions clipped to the spec, then
  # predict a scalar value.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  return {"policy": policy_network, "critic": critic_network}
Exemplo n.º 18
0
 def __init__(self, environment_spec, layer_sizes: Sequence[int]):
     """Builds the value-feature network for discrete-action environments.

     Args:
       environment_spec: spec providing the number of discrete actions.
       layer_sizes: widths of the MLP feature layers.
     """
     super(ValueFeature, self).__init__()
     # Discrete actions are one-hot encoded before concatenation.
     one_hot = functools.partial(
         tf.one_hot, depth=environment_spec.actions.num_values)
     self._net = snt.Sequential([
         networks.CriticMultiplexer(action_network=one_hot),
         snt.nets.MLP(layer_sizes, activate_final=True),
     ])
Exemplo n.º 19
0
def make_value_func_dm_control(layer_sizes: str = '512,512,256') -> snt.Module:
    """Builds a scalar value function for dm_control tasks.

    Args:
      layer_sizes: comma-separated hidden-layer widths, e.g. '512,512,256'.

    Returns:
      A Sonnet module mapping (observation, action) to a scalar.
    """
    sizes = [int(s) for s in layer_sizes.split(',')]
    return snt.Sequential([
        networks.CriticMultiplexer(),
        networks.LayerNormMLP(sizes, activate_final=True),
        snt.Linear(1),
    ])
Exemplo n.º 20
0
 def __init__(self, environment_spec, layer_sizes: Sequence[int]):
     """Builds the terminate-prediction network for discrete actions.

     Args:
       environment_spec: spec providing the number of discrete actions.
       layer_sizes: hidden-layer widths; a final scalar sigmoid output
         layer is appended internally.
     """
     super(TerminatePredictor, self).__init__()
     # Discrete actions are one-hot encoded before concatenation.
     action_network = functools.partial(
         tf.one_hot, depth=environment_spec.actions.num_values)
     # `[*layer_sizes, 1]` accepts any Sequence (the annotated type);
     # the previous `layer_sizes + [1]` raised TypeError for tuples.
     self._net = snt.Sequential([
         networks.CriticMultiplexer(action_network=action_network),
         snt.nets.MLP([*layer_sizes, 1], activate_final=False), tf.sigmoid
     ])
Exemplo n.º 21
0
 def __init__(self, environment_spec, n_component, gamma):
     """Builds a random-Fourier value-feature network.

     Args:
       environment_spec: spec providing the number of discrete actions.
       n_component: number of random Fourier components (the feature dim).
       gamma: bandwidth parameter passed to RandomFourierFeature.
     """
     super(ValueFeature, self).__init__()
     # Discrete actions are one-hot encoded before concatenation.
     one_hot = functools.partial(
         tf.one_hot, depth=environment_spec.actions.num_values)
     self._net = snt.Sequential([
         networks.CriticMultiplexer(action_network=one_hot),
         RandomFourierFeature(n_component=n_component, gamma=gamma),
     ])
     self._feature_dim = n_component
Exemplo n.º 22
0
def make_value_func_dm_control(
    value_layer_sizes: str = '512,512,256',
    adversarial_layer_sizes: str = '512,512,256',
) -> Tuple[snt.Module, snt.Module]:
    """Builds the value and adversarial networks for dm_control tasks.

    Args:
      value_layer_sizes: comma-separated widths for the value network.
      adversarial_layer_sizes: comma-separated widths for the adversarial
        network.

    Returns:
      Tuple (value_function, adversarial_function).
    """

    def _build(spec: str) -> snt.Module:
        # Parse 'a,b,c' into [a, b, c] and build the shared architecture.
        sizes = list(map(int, spec.split(',')))
        return snt.Sequential([
            networks.CriticMultiplexer(),
            networks.LayerNormMLP(sizes, activate_final=True),
            snt.Linear(1),
        ])

    # Fixed: the second local was previously misspelled 'advsarial_function';
    # the duplicated construction is now shared via _build.
    value_function = _build(value_layer_sizes)
    adversarial_function = _build(adversarial_layer_sizes)
    return value_function, adversarial_function
Exemplo n.º 23
0
def get_dm_control_median(dataset):
    """Returns 1 / median pairwise squared distance over one batch.

    Takes a single batch from the dataset, concatenates its observations and
    actions, and inverts the median pairwise squared Euclidean distance —
    a common RBF-kernel bandwidth heuristic.
    """
    observations, actions = next(iter(dataset)).data[:2]

    multiplexer = networks.CriticMultiplexer()
    features = multiplexer(observations, actions).numpy()

    pairwise = cdist(features, features, "sqeuclidean")
    return 1.0 / np.median(pairwise)
Exemplo n.º 24
0
def make_networks(
    action_spec: specs.BoundedArray,
    num_critic_heads: int,
    policy_layer_sizes: Sequence[int] = (50, ),
    critic_layer_sizes: Sequence[int] = (50, ),
    num_layers_shared: int = 1,
    distributional_critic: bool = True,
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
    """Creates networks used by the agent.

    Builds a Gaussian policy and a multi-headed critic in which the first
    `num_layers_shared` critic layers form a shared base.

    Args:
      action_spec: bounded action spec; its shape gives the action size.
      num_critic_heads: number of independent critic heads.
      critic_layer_sizes: critic layer widths; its leading entries also size
        the shared base when `num_layers_shared` > 0.
      num_layers_shared: number of leading critic layers shared across heads;
        0 disables the shared base.
      distributional_critic: if True each head emits a discrete value
        distribution, otherwise a scalar (an extra size-1 layer is appended).
      vmin: lower bound of the discrete value distribution support.
      vmax: upper bound of the discrete value distribution support.
      num_atoms: number of atoms in the discrete value distribution.

    Returns:
      Dict with 'policy', 'critic' and 'observation' entries.
    """

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            tanh_mean=False,
                                            init_scale=0.69)
    ])

    if not distributional_critic:
        # Scalar critic: append a single output unit per head.
        critic_layer_sizes = list(critic_layer_sizes) + [1]

    if not num_layers_shared:
        # No layers are shared
        critic_network_base = None
    else:
        critic_network_base = networks.LayerNormMLP(
            critic_layer_sizes[:num_layers_shared], activate_final=True)
    # NOTE(review): each head receives the full `critic_layer_sizes`, not just
    # the layers after the shared base — confirm the overlap is intended.
    critic_network_heads = [
        snt.nets.MLP(critic_layer_sizes,
                     activation=tf.nn.elu,
                     activate_final=False) for _ in range(num_critic_heads)
    ]
    if distributional_critic:
        critic_network_heads = [
            snt.Sequential(
                [c, networks.DiscreteValuedHead(vmin, vmax, num_atoms)])
            for c in critic_network_heads
        ]
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            critic_network=critic_network_base,
            action_network=networks.ClipToSpec(action_spec)),
        networks.Multihead(network_heads=critic_network_heads),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': tf2_utils.batch_concat,
    }
Exemplo n.º 25
0
def make_value_func_bsuite(environment_spec: EnvironmentSpec,
                           layer_sizes: str = '50,50',
                           ) -> snt.Module:
    """Builds a scalar value function for bsuite's discrete-action tasks.

    Args:
      environment_spec: spec providing the number of discrete actions.
      layer_sizes: comma-separated hidden-layer widths, e.g. '50,50'.

    Returns:
      A Sonnet module mapping (observation, action) to a scalar value.
    """
    sizes = [int(s) for s in layer_sizes.split(',')]
    # Discrete actions are one-hot encoded before concatenation.
    one_hot = functools.partial(
        tf.one_hot, depth=environment_spec.actions.num_values)
    return snt.Sequential([
        networks.CriticMultiplexer(action_network=one_hot),
        snt.nets.MLP(sizes, activate_final=True),
        snt.Linear(1),
    ])
Exemplo n.º 26
0
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 1024, 1024),
    critic_layer_sizes: Sequence[int] = (50, 1024, 1024),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> Dict[str, snt.Module]:
    """Creates networks used by the agent.

    Builds an orthogonally-initialized Gaussian policy, a distributional
    critic whose first layer projects and layer-normalizes the observation
    embedding, and a DrQ-style observation torso (presumably for pixel
    observations — confirm with caller).

    Args:
      action_spec: bounded action spec; its shape gives the action size.
      policy_layer_sizes: policy MLP layer widths.
      critic_layer_sizes: critic widths; the first entry sizes the input
        projection, the remainder form the critic MLP.
      vmin: lower bound of the discrete value distribution support.
      vmax: upper bound of the discrete value distribution support.
      num_atoms: number of atoms in the discrete value distribution.

    Returns:
      Dict with 'policy', 'critic' and 'observation' modules.
    """

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes,
                              w_init=snt.initializers.Orthogonal(),
                              activation=tf.nn.relu,
                              activate_final=True),
        networks.MultivariateNormalDiagHead(
            num_dimensions,
            tanh_mean=False,
            init_scale=1.0,
            fixed_scale=False,
            use_tfd_independent=True,
            w_init=snt.initializers.Orthogonal())
    ])

    # The multiplexer concatenates the (maybe transformed) observations/actions.
    # Its observation branch projects to critic_layer_sizes[0] and normalizes
    # over all non-batch axes (axis=slice(1, None)) before a tanh.
    critic_network = networks.CriticMultiplexer(
        observation_network=snt.Sequential([
            snt.Linear(critic_layer_sizes[0],
                       w_init=snt.initializers.Orthogonal()),
            snt.LayerNorm(axis=slice(1, None),
                          create_scale=True,
                          create_offset=True), tf.nn.tanh
        ]),
        critic_network=snt.nets.MLP(critic_layer_sizes[1:],
                                    w_init=snt.initializers.Orthogonal(),
                                    activation=tf.nn.relu,
                                    activate_final=True),
        action_network=networks.ClipToSpec(action_spec))
    # Wrap the multiplexer with a discrete value distribution head.
    critic_network = snt.Sequential([
        critic_network,
        networks.DiscreteValuedHead(vmin,
                                    vmax,
                                    num_atoms,
                                    w_init=snt.initializers.Orthogonal())
    ])
    observation_network = networks.DrQTorso()

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }
Exemplo n.º 27
0
def make_lstm_mpo_agent(env_spec: specs.EnvironmentSpec, logger: Logger,
                        hyperparams: Dict, checkpoint_path: str):
    """Builds a recurrent (LSTM) MPO learner and its acting policy.

    Args:
      env_spec: environment spec used to size the networks.
      logger: logger handed to the learner.
      hyperparams: overrides merged over DEFAULT_PARAMS. Keys consumed here
        ('policy_layers', 'critic_layers', 'observation_layers',
        'policy_lr', 'critic_lr' and any 'loss_*' keys) are popped; the
        remainder is forwarded as kwargs to RecurrentMPO.
      checkpoint_path: directory where the learner stores checkpoints.

    Returns:
      A (RecurrentMPO, RecurrentActor) tuple.
    """
    params = DEFAULT_PARAMS.copy()
    params.update(hyperparams)
    # Flat action size from the (possibly multi-dimensional) action shape.
    action_size = np.prod(env_spec.actions.shape, dtype=int).item()
    policy_network = snt.Sequential([
        networks.LayerNormMLP(
            layer_sizes=[*params.pop('policy_layers'), action_size]),
        networks.MultivariateNormalDiagHead(num_dimensions=action_size)
    ])

    # Critic MLP ends in a single scalar output, hence the trailing 1.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(critic_network=networks.LayerNormMLP(
            layer_sizes=[*params.pop('critic_layers'), 1]))
    ])

    # Recurrent observation torso: MLP followed by an LSTM.
    observation_network = networks.DeepRNN([
        networks.LayerNormMLP(layer_sizes=params.pop('observation_layers')),
        networks.LSTM(hidden_size=200)
    ])

    # Keys prefixed 'loss_' configure the MPO loss; strip the prefix and pop
    # them so they are not forwarded to RecurrentMPO below.
    loss_param_keys = list(
        filter(lambda key: key.startswith('loss_'), params.keys()))
    loss_params = dict([(k.replace('loss_', ''), params.pop(k))
                        for k in loss_param_keys])
    policy_loss_module = losses.MPO(**loss_params)

    # Make sure observation network is a Sonnet Module.
    observation_network = tf2_utils.to_sonnet_module(observation_network)

    # Create optimizers.
    policy_optimizer = Adam(params.pop('policy_lr'))
    critic_optimizer = Adam(params.pop('critic_lr'))

    # The actor follows the stochastic mode of the recurrent policy.
    actor = RecurrentActor(
        networks.DeepRNN([
            observation_network, policy_network,
            networks.StochasticModeHead()
        ]))

    # The learner updates the parameters (and initializes them).
    return RecurrentMPO(environment_spec=env_spec,
                        policy_network=policy_network,
                        critic_network=critic_network,
                        observation_network=observation_network,
                        policy_loss_module=policy_loss_module,
                        policy_optimizer=policy_optimizer,
                        critic_optimizer=critic_optimizer,
                        logger=logger,
                        checkpoint_path=checkpoint_path,
                        **params), actor
Exemplo n.º 28
0
def make_network_with_prior(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (200, 100),
    critic_layer_sizes: Sequence[int] = (400, 300),
    prior_layer_sizes: Sequence[int] = (200, 100),
    policy_keys: Optional[Sequence[str]] = None,
    prior_keys: Optional[Sequence[str]] = None,
) -> Mapping[str, types.TensorTransformation]:
  """Builds policy, critic and prior networks for the agent.

  Args:
    action_spec: bounded action spec; its shape gives the action size.
    policy_layer_sizes: hidden-layer widths of the policy MLP.
    critic_layer_sizes: hidden-layer widths of the critic MLP.
    prior_layer_sizes: hidden-layer widths of the prior MLP.
    policy_keys: observation keys the policy/critic front end concatenates;
      None selects everything.
    prior_keys: observation keys the prior front end concatenates; None
      selects everything.

  Returns:
    Mapping with "policy", "critic" and "prior" entries.
  """
  num_dimensions = np.prod(action_spec.shape, dtype=int)

  # Key-filtered flatten-and-concat front ends for policy/critic and prior.
  flatten_concat_policy = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=policy_keys)
  flatten_concat_prior = functools.partial(
      svg0_utils.batch_concat_selection, concat_keys=prior_keys)

  def _gaussian_head():
    # Policy and prior use identically configured (but distinct) heads.
    return networks.MultivariateNormalDiagHead(
        num_dimensions,
        tanh_mean=True,
        min_scale=0.1,
        init_scale=0.7,
        fixed_scale=False,
        use_tfd_independent=False)

  policy_network = snt.Sequential([
      flatten_concat_policy,
      networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
      _gaussian_head(),
  ])

  # The multiplexer concatenates the (maybe transformed) observations/actions.
  critic_network = snt.Sequential([
      networks.CriticMultiplexer(
          observation_network=flatten_concat_policy,
          action_network=networks.ClipToSpec(action_spec)),
      networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
      networks.NearZeroInitializedLinear(1),
  ])

  prior_network = snt.Sequential([
      flatten_concat_prior,
      networks.LayerNormMLP(prior_layer_sizes, activate_final=True),
      _gaussian_head(),
  ])

  return {
      "policy": policy_network,
      "critic": critic_network,
      "prior": prior_network,
  }
Exemplo n.º 29
0
def get_bsuite_median(environment_spec, dataset):
    """Returns 1 / median pairwise squared distance over one batch.

    One-hot encodes actions, concatenates them with observations from a
    single batch, and inverts the median pairwise squared Euclidean
    distance — a common RBF-kernel bandwidth heuristic.
    """
    observations, actions = next(iter(dataset)).data[:2]

    # Discrete actions are one-hot encoded before concatenation.
    one_hot = functools.partial(
        tf.one_hot, depth=environment_spec.actions.num_values)
    multiplexer = networks.CriticMultiplexer(action_network=one_hot)
    features = multiplexer(observations, actions).numpy()

    pairwise = cdist(features, features, "sqeuclidean")
    return 1.0 / np.median(pairwise)
Exemplo n.º 30
0
def make_default_networks(
    environment_spec: specs.EnvironmentSpec,
    *,
    policy_layer_sizes: Sequence[int] = (256, 256, 256),
    critic_layer_sizes: Sequence[int] = (512, 512, 256),
    policy_init_scale: float = 0.7,
    critic_init_scale: float = 1e-3,
    critic_num_components: int = 5,
) -> Mapping[str, snt.Module]:
    """Builds and variable-initializes the default networks for the agent.

    Args:
      environment_spec: full environment spec; the action shape fixes the
        policy head dimensionality.
      policy_layer_sizes: hidden-layer widths of the policy MLP.
      critic_layer_sizes: hidden-layer widths of the critic MLP.
      policy_init_scale: initial scale of the policy's Gaussian head.
      critic_init_scale: initial scale of the critic's mixture head.
      critic_num_components: number of critic mixture components.

    Returns:
      Mapping with 'policy', 'critic' and 'observation' Sonnet modules,
      with their variables already created.
    """
    # Unpack the specs; the action spec fixes the policy head size.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations
    action_size = np.prod(act_spec.shape, dtype=int)

    # Observation network is a simple flatten-and-concat, wrapped so that it
    # is a proper Sonnet module.
    observation_network = tf2_utils.to_sonnet_module(tf2_utils.batch_concat)

    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(action_size,
                                            init_scale=policy_init_scale,
                                            use_tfd_independent=True)
    ])

    # Critic: concatenate the embedding with the clipped action, then predict
    # a scalar-valued Gaussian mixture.
    critic_network = snt.Sequential([
        networks.CriticMultiplexer(
            action_network=networks.ClipToSpec(act_spec)),
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.GaussianMixtureHead(num_dimensions=1,
                                     num_components=critic_num_components,
                                     init_scale=critic_init_scale)
    ])

    # Create variables now; the observation network's output spec feeds the
    # policy and critic.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
    }