Example #1
def get_dummy_batched_obs_and_actions(
    environment_spec):
  """Generates dummy batched (batch_size=1) obs and actions."""
  dummy_observation = utils.tile_nested(
      utils.zeros_like(environment_spec.observations), 1)
  dummy_action = utils.tile_nested(
      utils.zeros_like(environment_spec.actions), 1)
  return dummy_observation, dummy_action
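These batched dummies are typically fed straight to a Haiku init call to build parameters. A minimal sketch under that assumption; `spec` and `critic` below are placeholder names mirroring the later examples in this listing:

# Sketch only: `spec` is an EnvironmentSpec and `critic` an hk.Transformed
# module taking (obs, action), as in the other examples here.
dummy_obs, dummy_action = get_dummy_batched_obs_and_actions(spec)
critic_params = critic.init(jax.random.PRNGKey(0), dummy_obs, dummy_action)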
Example #2
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)
) -> SACNetworks:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes),
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu,
                        activate_final=True),
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network1 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        network2 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return SACNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_obs), policy.apply),
        q_network=networks_lib.FeedForwardNetwork(
            lambda key: critic.init(key, dummy_obs, dummy_action),
            critic.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode())
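A short usage sketch of the returned SACNetworks; `spec` is assumed to come from specs.make_environment_spec as in the test examples below, and `obs_batch` is an assumed batched observation:

# Hypothetical usage; `obs_batch` matches spec.observations with a leading
# batch dimension.
networks = make_networks(spec)
key = jax.random.PRNGKey(0)
policy_params = networks.policy_network.init(key)
dist = networks.policy_network.apply(policy_params, obs_batch)
action = networks.sample(dist, key)
log_prob = networks.log_prob(dist, action)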
Example #3
def init_state(nest: types.Nest) -> RunningStatisticsState:
  """Initializes the running statistics for the given nested structure."""
  dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32

  return RunningStatisticsState(
      count=0.,
      mean=utils.zeros_like(nest, dtype=dtype),
      summed_variance=utils.zeros_like(nest, dtype=dtype),
      # Initialize with ones to make sure normalization works correctly
      # in the initial state.
      std=utils.ones_like(nest, dtype=dtype))
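The fields are laid out for element-wise normalization, which is why `std` starts at ones. A sketch of how they would be consumed; `spec` and `observation` are assumed names:

# Sketch only: normalize an observation nest with the running statistics.
state = init_state(spec.observations)
normalized = jax.tree_util.tree_map(
    lambda x, mean, std: (x - mean) / std, observation, state.mean, state.std)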
Example #4
def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec
    discriminator_transformed: Haiku transformed of the discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output)/(exp(output) + pi(a|s))
      # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)
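Usage sketch: because `discriminator_transformed` is an hk.TransformedWithState, the wrapped `init` returns params and state together, and `apply_fn` threads the Haiku state and policy params through. `policy_params`, `transitions`, and `key` are assumed names:

# Hypothetical usage of the returned FeedForwardNetwork.
discriminator = make_discriminator(
    environment_spec, discriminator_transformed, logpi_fn)
key = jax.random.PRNGKey(0)
disc_params, disc_state = discriminator.init(key)
logits, disc_state = discriminator.apply(
    disc_params, policy_params, disc_state, transitions, True, key)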
Example #5
def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
    """Creates networks used by the agent."""
    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    if isinstance(spec.actions, specs.DiscreteArray):
        num_dimensions = spec.actions.num_values
        # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
    else:
        num_dimensions = np.prod(spec.actions.shape, dtype=int)

    _actor_fn = build_actor_fn(num_dimensions)

    if img_encoder_fn is not None:
        img_encoder = hk.without_apply_rng(hk.transform(img_encoder_fn))
        key = jax.random.PRNGKey(seed=42)
        temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
        dummy_hidden = img_encoder.apply(temp_encoder_params,
                                         dummy_obs['state_image'])
        img_encoder_network = networks_lib.FeedForwardNetwork(
            lambda key: img_encoder.init(key, dummy_obs['state_image']),
            img_encoder.apply)
        dummy_policy_input = dict(
            state_image=dummy_hidden,
            state_dense=dummy_obs['state_dense'],
        )
    else:
        img_encoder_fn = None
        dummy_policy_input = dummy_obs
        img_encoder_network = None

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    return BCNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_policy_input), policy.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode(),
        img_encoder=img_encoder_network,
    )
Example #6
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
    """Creates networks used by the agent."""

    action_spec = spec.actions

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

    def _actor_fn(obs):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
            networks_lib.NearZeroInitializedLinear(num_dimensions),
            networks_lib.TanhToSpec(action_spec),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(
                layer_sizes=[*critic_layer_sizes, num_atoms]),
        ])
        value = network([obs, action])
        return value, critic_atoms

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return D4PGNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda rng: policy.init(rng, dummy_obs), policy.apply),
        critic_network=networks_lib.FeedForwardNetwork(
            lambda rng: critic.init(rng, dummy_obs, dummy_action),
            critic.apply))
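The critic above returns per-atom outputs together with the fixed atom locations. A sketch of turning that into a scalar value, assuming the outputs are read as categorical logits over `critic_atoms` (the usual distributional-critic reading); `networks`, `critic_params`, `obs_batch`, and `action_batch` are assumed names:

# Sketch only: expected value from the distributional critic output.
logits, atoms = networks.critic_network.apply(
    critic_params, obs_batch, action_batch)
probs = jax.nn.softmax(logits, axis=-1)
expected_q = jnp.sum(probs * atoms, axis=-1)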
Example #7
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Example #8
def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional r2d2 network from recurrent model definitions."""

  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)
  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)
  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)
  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)
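A sketch of the intended init order, following unroll_init_fn above: the recurrent state is created first and then fed to the unroll init. `networks`, `rng`, and `batch_size` are assumed names:

# Hypothetical initialization sequence for the returned R2D2Networks.
rng = jax.random.PRNGKey(0)
state_params = networks.initial_state.init(rng, batch_size)
initial_state = networks.initial_state.apply(state_params, rng, batch_size)
unroll_params = networks.unroll.init(rng, initial_state)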
Example #9
def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512, ),
    use_conv: bool = True,
) -> PPONetworks:
    """Creates networks used by the agent for discrete action environments.

    Args:
      environment_spec: Environment spec used to define number of actions.
      hidden_layer_sizes: Network definition.
      use_conv: Whether to use a conv or MLP feature extractor.

    Returns:
      PPONetworks
    """

    num_actions = environment_spec.actions.num_values

    def forward_fn(inputs):
        layers = []
        if use_conv:
            layers.extend([networks_lib.AtariTorso()])
        layers.extend([
            hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
            networks_lib.CategoricalValueHead(num_values=num_actions)
        ])
        policy_value_network = hk.Sequential(layers)
        return policy_value_network(inputs)

    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Example #10
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make network purely functional
        network_hk = hk.without_apply_rng(hk.transform(network))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #11
def make_haiku_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            networks_lib.CategoricalHead(num_actions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Example #12
File: models.py  Project: deepmind/acme
def make_ensemble_policy_prior(
    policy_prior_network: mbop_networks.PolicyPriorNetwork,
    spec: specs.EnvironmentSpec,
    use_round_robin: bool = True) -> PolicyPrior:
  """Creates an ensemble policy prior from its network.

  Args:
    policy_prior_network: The policy prior network.
    spec: Environment specification.
    use_round_robin: Whether to use round robin or mean to calculate the policy
      prior over the ensemble members.

  Returns:
    A policy prior.
  """

  def _policy_prior(params: networks.Params, key: networks.PRNGKey,
                    observation_t: networks.Observation,
                    action_tm1: networks.Action) -> networks.Action:
    # Regressor policies are deterministic.
    del key
    apply_fn = (
        ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean)
    return apply_fn(
        policy_prior_network.apply,
        params,
        observation_t=observation_t,
        action_tm1=action_tm1)

  dummy_action = utils.zeros_like(spec.actions)
  dummy_action = utils.add_batch_dim(dummy_action)

  return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action)
Example #13
def make_haiku_networks(
        env_spec: specs.EnvironmentSpec, forward_fn: Any,
        initial_state_fn: Any,
        unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]:
    """Builds functional impala network from recurrent model definitions."""
    # Make networks purely functional.
    forward_hk = hk.without_apply_rng(hk.transform(forward_fn))
    initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))
    unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))

    # Define networks init functions.
    def initial_state_init_fn(rng: networks_lib.PRNGKey) -> hk.Params:
        return initial_state_hk.init(rng)

    # Note: batch axis is not needed for the actors.
    dummy_obs = utils.zeros_like(env_spec.observations)
    dummy_obs_sequence = utils.add_batch_dim(dummy_obs)

    def unroll_init_fn(rng: networks_lib.PRNGKey,
                       initial_state: types.RecurrentState) -> hk.Params:
        return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

    return IMPALANetworks(forward_fn=forward_hk.apply,
                          unroll_init_fn=unroll_init_fn,
                          unroll_fn=unroll_hk.apply,
                          initial_state_init_fn=initial_state_init_fn,
                          initial_state_fn=initial_state_hk.apply)
Example #14
def make_networks(
        spec: specs.EnvironmentSpec,
        discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""

    if discrete_actions:
        final_layer_size = spec.actions.num_values
    else:
        final_layer_size = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs, is_training=False, key=None):
        # is_training and key allow defining train/test-dependent modules
        # like dropout.
        del is_training
        del key
        if discrete_actions:
            network = hk.nets.MLP([64, 64, final_layer_size])
        else:
            network = hk.Sequential([
                networks_lib.LayerNormMLP([64, 64], activate_final=True),
                networks_lib.NormalTanhDistribution(final_layer_size),
            ])
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)
    return network
Example #15
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #16
  def _update_spec(self, base_spec):
    dummy_obs = utils.zeros_like(base_spec)
    emb, _ = self._distance_fn(dummy_obs['state'], dummy_obs['goal'])
    full_spec = dict(base_spec)
    full_spec['embeddings'] = dm_env_specs.Array(
        shape=emb.shape, dtype=emb.dtype)
    return full_spec
Example #17
 def make_initial_state(key: jnp.ndarray) -> TrainingState:
   """Initialises the training state (parameters and optimiser state)."""
   dummy_obs = utils.zeros_like(obs_spec)
   dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
   initial_state = initial_state_fn.apply(None)
   initial_params = unroll_fn.init(key, dummy_obs, initial_state)
   initial_opt_state = optimizer.init(initial_params)
   return TrainingState(params=initial_params, opt_state=initial_opt_state)
Example #18
def make_network_from_module(
        module: hk.Transformed,
        spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
    """Creates a network with dummy init arguments using the specified module.

    Args:
      module: Module that expects one batch axis and one features axis for its
        inputs.
      spec: EnvironmentSpec shapes to derive dummy inputs.

    Returns:
      FeedForwardNetwork whose `init` method only takes a random key, and
      `apply` takes an observation and action and produces an output.
    """
    dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
    dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
    return networks.FeedForwardNetwork(
        lambda key: module.init(key, dummy_obs, dummy_action), module.apply)
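Typical usage, sketched under the assumption of a critic-style module that takes a batched observation and action, as the docstring requires; `spec` is assumed:

# Hypothetical module wrapped via make_network_from_module.
def _critic_fn(obs, action):
  return hk.nets.MLP([256, 256, 1])(jnp.concatenate([obs, action], axis=-1))

critic_module = hk.without_apply_rng(hk.transform(_critic_fn))
critic_network = make_network_from_module(critic_module, spec)
critic_params = critic_network.init(jax.random.PRNGKey(0))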
Example #19
  def __init__(self,
               network: hk.Transformed,
               obs_spec: specs.Array,
               optimizer: optax.GradientTransformation,
               rng: hk.PRNGSequence,
               dataset: tf.data.Dataset,
               loss_fn: LossFn = _sparse_categorical_cross_entropy,
               counter: counting.Counter = None,
               logger: loggers.Logger = None):
    """Initializes the learner."""

    def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
      # Pull out the data needed for updates.
      o_tm1, a_tm1, r_t, d_t, o_t = sample.data
      del r_t, d_t, o_t
      logits = network.apply(params, o_tm1)
      return jnp.mean(loss_fn(a_tm1, logits))

    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
      """Do a step of SGD."""
      grad_fn = jax.value_and_grad(loss)
      loss_value, gradients = grad_fn(state.params, sample)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      new_state = TrainingState(
          params=new_params, opt_state=new_opt_state, steps=steps)

      # Compute the global norm of the gradients for logging.
      global_gradient_norm = optax.global_norm(gradients)
      fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

      return new_state, fetches

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Get an iterator over the dataset.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
    # TODO(b/155086959): Fix type stubs and remove.

    # Initialise parameters and optimiser state.
    initial_params = network.init(
        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
    initial_opt_state = optimizer.init(initial_params)

    self._state = TrainingState(
        params=initial_params, opt_state=initial_opt_state, steps=0)

    self._sgd_step = jax.jit(sgd_step)
Example #20
def make_networks(
    spec,
    actor_fn_build_fn = build_mlp_actor_fn,
    actor_hidden_layer_sizes = (256, 256),
    critic_fn_build_fn = build_hk_batch_ensemble_mlp_critic_fn,
    # critic_fn_build_fn: Callable = build_mlp_critic_fn,
    critic_hidden_layer_sizes = (256, 256),
    use_double_q = False,
    ):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = actor_fn_build_fn(num_dimensions, actor_hidden_layer_sizes)
  policy = hk.without_apply_rng(hk.transform(_actor_fn))

  _critic_fn = critic_fn_build_fn(critic_hidden_layer_sizes, use_double_q)
  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_ensemble_init = ensemble_utils.transform_init_for_ensemble(
      critic.init, init_same=False)
  critic_ensemble_member_apply = (
      ensemble_utils.transform_apply_for_ensemble_member(critic.apply))
  critic_same_batch_ensemble_apply = (
      ensemble_utils.build_same_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))
  critic_diff_batch_ensemble_apply = (
      ensemble_utils.build_different_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return BatchEnsembleMSGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_ensemble_init=lambda ensemble_size, key: critic_ensemble_init(
          ensemble_size, key, dummy_obs, dummy_action),
      q_ensemble_member_apply=critic_ensemble_member_apply,
      q_ensemble_same_batch_apply=critic_same_batch_ensemble_apply,
      q_ensemble_different_batch_apply=critic_diff_batch_ensemble_apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Example #21
def default_models_to_snapshot(networks: SACNetworks,
                               spec: specs.EnvironmentSpec):
    """Defines default models to be snapshotted."""
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.zeros_like(spec.actions)
    dummy_key = jax.random.PRNGKey(0)

    def q_network(source: core.VariableSource) -> types.ModelToSnapshot:
        params = source.get_variables(['critic'])[0]
        return types.ModelToSnapshot(networks.q_network.apply, params, {
            'obs': dummy_obs,
            'action': dummy_action
        })

    def default_training_actor(
            source: core.VariableSource) -> types.ModelToSnapshot:
        params = source.get_variables(['policy'])[0]
        return types.ModelToSnapshot(apply_policy_and_sample(networks, False),
                                     params, {
                                         'key': dummy_key,
                                         'obs': dummy_obs
                                     })

    def default_eval_actor(
            source: core.VariableSource) -> types.ModelToSnapshot:
        params = source.get_variables(['policy'])[0]
        return types.ModelToSnapshot(apply_policy_and_sample(networks, True),
                                     params, {
                                         'key': dummy_key,
                                         'obs': dummy_obs
                                     })

    return {
        'q_network': q_network,
        'default_training_actor': default_training_actor,
        'default_eval_actor': default_eval_actor,
    }
Example #22
def make_ppo_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    forward_fn = functools.partial(ppo_forward_fn, num_actions=num_actions)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Example #23
def make_dqn_atari_network(
        environment_spec: specs.EnvironmentSpec
) -> networks.FeedForwardNetwork:
    """Creates networks for training DQN on Atari."""
    def network(inputs):
        model = hk.Sequential([
            networks.AtariTorso(),
            hk.nets.MLP([512, environment_spec.actions.num_values]),
        ])
        return model(inputs)

    network_hk = hk.without_apply_rng(hk.transform(network))
    obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations))
    return networks.FeedForwardNetwork(
        init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply)
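Usage sketch: `init` only needs a key, and `apply` maps a batched observation to per-action Q-values. `obs_batch` is an assumed batched Atari observation:

# Hypothetical usage of make_dqn_atari_network.
network = make_dqn_atari_network(environment_spec)
params = network.init(jax.random.PRNGKey(0))
q_values = network.apply(params, obs_batch)
greedy_action = jnp.argmax(q_values, axis=-1)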
Example #24
        def make_initial_state(key: jnp.ndarray) -> TrainingState:
            """Initialises the training state (parameters and optimiser state)."""
            dummy_obs = utils.zeros_like(obs_spec)
            dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.

            key, key_initial_state = jax.random.split(key)
            params = initial_state_init_fn(key_initial_state)
            # TODO(jferret): as it stands, we do not yet support
            # training the initial state params.
            initial_state = initial_state_fn(params)

            initial_params = unroll_init_fn(key, dummy_obs, initial_state)
            initial_opt_state = optimizer.init(initial_params)
            return TrainingState(params=initial_params,
                                 opt_state=initial_opt_state)
Example #25
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return (actions, (action_values, )), core_state
            else:
                return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      jax.random.PRNGKey(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #26
def make_networks(
    spec: specs.EnvironmentSpec,
    direct_rl_networks: DirectRLNetworks,
    layer_sizes: Tuple[int, ...] = (256, 256),
    intrinsic_reward_coefficient: float = 1.0,
    extrinsic_reward_coefficient: float = 0.0,
) -> RNDNetworks[DirectRLNetworks]:
    """Creates networks used by the agent and returns RNDNetworks.

    Args:
      spec: Environment spec.
      direct_rl_networks: Networks used by a direct rl algorithm.
      layer_sizes: Layer sizes.
      intrinsic_reward_coefficient: Multiplier on intrinsic reward.
      extrinsic_reward_coefficient: Multiplier on extrinsic reward.

    Returns:
      The RND networks.
    """
    def _rnd_fn(obs, act):
        # RND does not use the action but other variants like RED do.
        del act
        network = networks_lib.LayerNormMLP(list(layer_sizes))
        return network(obs)

    target = hk.without_apply_rng(hk.transform(_rnd_fn))
    predictor = hk.without_apply_rng(hk.transform(_rnd_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return RNDNetworks(
        target=networks_lib.FeedForwardNetwork(
            lambda key: target.init(key, dummy_obs, ()), target.apply),
        predictor=networks_lib.FeedForwardNetwork(
            lambda key: predictor.init(key, dummy_obs, ()), predictor.apply),
        direct_rl_networks=direct_rl_networks,
        get_reward=functools.partial(
            rnd_reward_fn,
            intrinsic_reward_coefficient=intrinsic_reward_coefficient,
            extrinsic_reward_coefficient=extrinsic_reward_coefficient))
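For reference, the standard RND reading of the two networks above: the predictor chases a fixed random target, and the prediction error drives the intrinsic reward (the actual combination lives in rnd_reward_fn, which is not shown here). A sketch under that assumption; `networks`, the parameter names, and `obs_batch` are placeholders:

# Sketch only: intrinsic reward as predictor/target mismatch.
target_out = networks.target.apply(target_params, obs_batch, ())
predictor_out = networks.predictor.apply(predictor_params, obs_batch, ())
intrinsic_reward = jnp.mean(jnp.square(target_out - predictor_out), axis=-1)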
Example #27
def make_flax_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates FLAX networks to be used by the agent."""

    num_actions = spec.actions.num_values

    class MLP(flax.deprecated.nn.Module):
        """MLP module."""
        def apply(self,
                  data: jnp.ndarray,
                  layer_sizes: Tuple[int],
                  activation: Callable[[jnp.ndarray],
                                       jnp.ndarray] = flax.deprecated.nn.relu,
                  kernel_init: object = jax.nn.initializers.lecun_uniform(),
                  activate_final: bool = False,
                  bias: bool = True):
            hidden = data
            for i, hidden_size in enumerate(layer_sizes):
                hidden = flax.deprecated.nn.Dense(hidden,
                                                  hidden_size,
                                                  name=f'hidden_{i}',
                                                  kernel_init=kernel_init,
                                                  bias=bias)
                if i != len(layer_sizes) - 1 or activate_final:
                    hidden = activation(hidden)
            return hidden

    class PolicyValueModule(flax.deprecated.nn.Module):
        """MLP module."""
        def apply(self, inputs: jnp.ndarray):
            inputs = utils.batch_concat(inputs)
            logits = MLP(inputs, [64, 64, num_actions])
            value = MLP(inputs, [64, 64, 1])
            value = jnp.squeeze(value, axis=-1)
            return tfd.Categorical(logits=logits), value

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: PolicyValueModule.init(rng, dummy_obs)[1],
        PolicyValueModule.call)
Example #28
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #29
def make_network(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""
    num_actions = spec.actions.num_values

    def actor_fn(obs, is_training=True, key=None):
        # is_training and key allow using train/test-dependent modules
        # like dropout.
        del is_training
        del key
        mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
        return mlp(obs)

    policy = hk.without_apply_rng(hk.transform(actor_fn))

    # Create dummy observations to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)

    return network
Example #30
  def test_recurrent(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

    @hk.transform
    def initial_state(batch_size: int):
      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
      return network.initial_state(batch_size)

    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: jnp.ndarray,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
      del key  # Unused for test-case deterministic policy.
      action_values, core_state = network.apply(params, observation, core_state)
      return jnp.argmax(action_values, axis=-1), core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)