Example #1
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)
) -> SACNetworks:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes),
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu,
                        activate_final=True),
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network1 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        network2 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return SACNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_obs), policy.apply),
        q_network=networks_lib.FeedForwardNetwork(
            lambda key: critic.init(key, dummy_obs, dummy_action),
            critic.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode())
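
The dummy-input pattern above (and in most of the examples below) leans on two small helpers from acme.jax.utils: zeros_like builds zero-filled arrays from the environment spec, and add_batch_dim prepends a batch axis of size 1 so the networks can be initialized with correctly shaped inputs. A minimal sketch of roughly what they do, assuming the usual tree-map implementation (illustrative, not the library code):

import jax
import jax.numpy as jnp

def zeros_like(nest):
  # One zero array per leaf, matching each spec's shape and dtype.
  return jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), nest)

def add_batch_dim(nest):
  # Prepend a leading batch axis of size 1 to every leaf.
  return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), nest)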
Example #2
 def select_action(params: networks_lib.Params,
                   observation: networks_lib.Observation,
                   state: SimpleActorCoreRecurrentState[RecurrentState]):
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     rng = state.rng
     rng, policy_rng = jax.random.split(rng)
     observation = utils.add_batch_dim(observation)
     recurrent_state = utils.add_batch_dim(state.recurrent_state)
     action, new_recurrent_state = utils.squeeze_batch_dim(
         recurrent_policy(params, policy_rng, observation, recurrent_state))
     return action, SimpleActorCoreRecurrentState(rng, new_recurrent_state)
Example #3
 def critic_mean(
     critic_params: networks_lib.Params,
     observation: types.NestedArray,
     action: types.NestedArray,
 ) -> jnp.ndarray:
   # We add a batch dimension to make sure batch concat in critic_network
   # works correctly.
   observation = utils.add_batch_dim(observation)
   action = utils.add_batch_dim(action)
   # Computes the mean action-value estimate.
   logits, atoms = critic_network.apply(critic_params, observation, action)
   logits = utils.squeeze_batch_dim(logits)
   probabilities = jax.nn.softmax(logits)
   return jnp.sum(probabilities * atoms, axis=-1)
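
The last line of critic_mean takes the expectation of the critic's categorical value distribution: softmax(logits) gives per-atom probabilities and the sum weights them by the atom locations. A tiny numeric check with hypothetical values:

import jax
import jax.numpy as jnp

# Five atoms spanning [-2, 2]; uniform logits put equal mass on each atom,
# so the expected action-value is 0.
atoms = jnp.linspace(-2.0, 2.0, 5)
probabilities = jax.nn.softmax(jnp.zeros(5))
mean_value = jnp.sum(probabilities * atoms, axis=-1)  # ~0.0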
Example #4
def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec.
    discriminator_transformed: The Haiku-transformed discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output)/(exp(output) + pi(a|s))
      # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)
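
The "Quick Maths" comment can be verified directly: with D = exp(output) / (exp(output) + pi(a|s)), the logit of D collapses to output - logpi. A small numeric sanity check with hypothetical values:

import jax.numpy as jnp

output = jnp.asarray(0.3)   # hypothetical discriminator output
logpi = jnp.asarray(-1.2)   # hypothetical log pi(a|s)
d = jnp.exp(output) / (jnp.exp(output) + jnp.exp(logpi))
logit_d = jnp.log(d / (1.0 - d))
assert jnp.allclose(logit_d, output - logpi)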
Example #5
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
    """Creates networks used by the agent."""

    action_spec = spec.actions

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

    def _actor_fn(obs):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
            networks_lib.NearZeroInitializedLinear(num_dimensions),
            networks_lib.TanhToSpec(action_spec),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(
                layer_sizes=[*critic_layer_sizes, num_atoms]),
        ])
        value = network([obs, action])
        return value, critic_atoms

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return D4PGNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda rng: policy.init(rng, dummy_obs), policy.apply),
        critic_network=networks_lib.FeedForwardNetwork(
            lambda rng: critic.init(rng, dummy_obs, dummy_action),
            critic.apply))
Example #6
    def test_step(self):
        simple_spec = specs.Array(shape=(), dtype=float)

        spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec,
                                     simple_spec)

        discriminator = _make_discriminator(spec)
        ail_network = ail_networks.AILNetworks(discriminator,
                                               imitation_reward_fn=lambda x: x,
                                               direct_rl_networks=None)

        loss = losses.gail_loss()

        optimizer = optax.adam(.01)

        step = jax.jit(
            functools.partial(ail_learning.ail_update_step,
                              optimizer=optimizer,
                              ail_network=ail_network,
                              loss_fn=loss))

        zero_transition = types.Transition(np.array([0.]), np.array([0.]), 0.,
                                           0., np.array([0.]))
        zero_transition = utils.add_batch_dim(zero_transition)

        one_transition = types.Transition(np.array([1.]), np.array([0.]), 0.,
                                          0., np.array([0.]))
        one_transition = utils.add_batch_dim(one_transition)

        key = jax.random.PRNGKey(0)
        discriminator_params, discriminator_state = discriminator.init(key)

        state = ail_learning.DiscriminatorTrainingState(
            optimizer_state=optimizer.init(discriminator_params),
            discriminator_params=discriminator_params,
            discriminator_state=discriminator_state,
            policy_params=None,
            key=key,
            steps=0,
        )

        expected_loss = [1.062, 1.057, 1.052]

        for i in range(3):
            state, loss = step(state, (one_transition, zero_transition))
            self.assertAlmostEqual(loss['total_loss'],
                                   expected_loss[i],
                                   places=3)
Example #7
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Example #8
def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
    """Creates networks used by the agent."""
    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    if isinstance(spec.actions, specs.DiscreteArray):
        num_dimensions = spec.actions.num_values
        # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
    else:
        num_dimensions = np.prod(spec.actions.shape, dtype=int)

    _actor_fn = build_actor_fn(num_dimensions)

    if img_encoder_fn is not None:
        img_encoder = hk.without_apply_rng(
            hk.transform(img_encoder_fn, apply_rng=True))
        key = jax.random.PRNGKey(seed=42)
        temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
        dummy_hidden = img_encoder.apply(temp_encoder_params,
                                         dummy_obs['state_image'])
        img_encoder_network = networks_lib.FeedForwardNetwork(
            lambda key: img_encoder.init(key, dummy_hidden), img_encoder.apply)
        dummy_policy_input = dict(
            state_image=dummy_hidden,
            state_dense=dummy_obs['state_dense'],
        )
    else:
        img_encoder_fn = None
        dummy_policy_input = dummy_obs
        img_encoder_network = None

    policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

    return BCNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_policy_input), policy.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode(),
        img_encoder=img_encoder_network,
    )
Example #9
def make_networks(
        spec: specs.EnvironmentSpec,
        discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""

    if discrete_actions:
        final_layer_size = spec.actions.num_values
    else:
        final_layer_size = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs, is_training=False, key=None):
        # is_training and key allow defining train/test-dependent modules
        # like dropout.
        del is_training
        del key
        if discrete_actions:
            network = hk.nets.MLP([64, 64, final_layer_size])
        else:
            network = hk.Sequential([
                networks_lib.LayerNormMLP([64, 64], activate_final=True),
                networks_lib.NormalTanhDistribution(final_layer_size),
            ])
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)
    return network
Example #10
def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512, ),
    use_conv: bool = True,
) -> PPONetworks:
    """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define number of actions.
    hidden_layer_sizes: Network definition.
    use_conv: Whether to use a conv or MLP feature extractor.
  Returns:
    PPONetworks
  """

    num_actions = environment_spec.actions.num_values

    def forward_fn(inputs):
        layers = []
        if use_conv:
            layers.extend([networks_lib.AtariTorso()])
        layers.extend([
            hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
            networks_lib.CategoricalValueHead(num_values=num_actions)
        ])
        policy_value_network = hk.Sequential(layers)
        return policy_value_network(inputs)

    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Example #11
def make_haiku_networks(
        env_spec: specs.EnvironmentSpec, forward_fn: Any,
        initial_state_fn: Any,
        unroll_fn: Any) -> IMPALANetworks[types.RecurrentState]:
    """Builds functional impala network from recurrent model definitions."""
    # Make networks purely functional.
    forward_hk = hk.without_apply_rng(hk.transform(forward_fn))
    initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))
    unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))

    # Define networks init functions.
    def initial_state_init_fn(rng: networks_lib.PRNGKey) -> hk.Params:
        return initial_state_hk.init(rng)

    # Note: batch axis is not needed for the actors.
    dummy_obs = utils.zeros_like(env_spec.observations)
    dummy_obs_sequence = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.

    def unroll_init_fn(rng: networks_lib.PRNGKey,
                       initial_state: types.RecurrentState) -> hk.Params:
        return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

    return IMPALANetworks(forward_fn=forward_hk.apply,
                          unroll_init_fn=unroll_init_fn,
                          unroll_fn=unroll_hk.apply,
                          initial_state_init_fn=initial_state_init_fn,
                          initial_state_fn=initial_state_hk.apply)
Example #12
def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional r2d2 network from recurrent model definitions."""

  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)
  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)
  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)
  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)
Example #13
File: models.py Project: deepmind/acme
def make_ensemble_policy_prior(
    policy_prior_network: mbop_networks.PolicyPriorNetwork,
    spec: specs.EnvironmentSpec,
    use_round_robin: bool = True) -> PolicyPrior:
  """Creates an ensemble policy prior from its network.

  Args:
    policy_prior_network: The policy prior network.
    spec: Environment specification.
    use_round_robin: Whether to use round robin or mean to calculate the policy
      prior over the ensemble members.

  Returns:
    A policy prior.
  """

  def _policy_prior(params: networks.Params, key: networks.PRNGKey,
                    observation_t: networks.Observation,
                    action_tm1: networks.Action) -> networks.Action:
    # Regressor policies are deterministic.
    del key
    apply_fn = (
        ensemble.apply_round_robin if use_round_robin else ensemble.apply_mean)
    return apply_fn(
        policy_prior_network.apply,
        params,
        observation_t=observation_t,
        action_tm1=action_tm1)

  dummy_action = utils.zeros_like(spec.actions)
  dummy_action = utils.add_batch_dim(dummy_action)

  return feed_forward_policy_prior_to_actor_core(_policy_prior, dummy_action)
Example #14
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #15
 def select_action(self,
                   observation: types.NestedArray) -> types.NestedArray:
     key = next(self._rng)
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     observation = utils.add_batch_dim(observation)
     action = self._policy(self._client.params, key, observation)
     return utils.to_numpy_squeeze(action)
Example #16
 def batched_policy(
     params: network_types.Params, key: RNGKey, observation: Observation
 ) -> Union[Action, Tuple[Action, types.NestedArray]]:
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     observation = utils.add_batch_dim(observation)
     output = policy(params, key, observation)
     return utils.squeeze_batch_dim(output)
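
This wrapper is the common add/squeeze round trip: an unbatched observation is promoted to a batch of one, the batched policy is applied, and the batch axis is stripped again. A minimal shape-level sketch, assuming the helpers behave as in acme.jax.utils:

import jax.numpy as jnp
from acme.jax import utils

obs = jnp.zeros((8,))               # unbatched observation, shape (8,)
batched = utils.add_batch_dim(obs)  # shape (1, 8)
restored = utils.squeeze_batch_dim(batched)
assert restored.shape == (8,)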
Example #17
 def select_action(params: networks_lib.Params,
                   observation: networks_lib.Observation, state: PRNGKey):
     rng = state
     rng1, rng2 = jax.random.split(rng)
     observation = utils.add_batch_dim(observation)
     action = utils.squeeze_batch_dim(policy(params, rng1, observation))
     return action, rng2
Example #18
def make_haiku_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            networks_lib.CategoricalHead(num_actions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Example #19
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make network purely functional
        network_hk = hk.without_apply_rng(hk.transform(network,
                                                       apply_rng=True))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #20
 def make_initial_state(key: jnp.ndarray) -> TrainingState:
   """Initialises the training state (parameters and optimiser state)."""
   dummy_obs = utils.zeros_like(obs_spec)
   dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
   initial_state = initial_state_fn.apply(None)
   initial_params = unroll_fn.init(key, dummy_obs, initial_state)
   initial_opt_state = optimizer.init(initial_params)
   return TrainingState(params=initial_params, opt_state=initial_opt_state)
Example #21
 def unvectorized_select_action(
     params: networks_lib.Params,
     observations: networks_lib.Observation,
     state: State,
 ) -> Tuple[networks_lib.Action, State]:
     observations, state = utils.add_batch_dim((observations, state))
     actions, state = actor_core.select_action(params, observations, state)
     return utils.squeeze_batch_dim((actions, state))
Example #22
File: actor.py Project: deepmind/acme
 def apply_and_sample(params: networks_lib.Params, key: networks_lib.PRNGKey,
                      observation: networks_lib.Observation, epsilon: Epsilon
                      ) -> networks_lib.Action:
   # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
   observation = utils.add_batch_dim(observation)
   action_values = network.apply(params, observation)
   action_values = utils.squeeze_batch_dim(action_values)
   return rlax.epsilon_greedy(epsilon).sample(key, action_values)
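
rlax.epsilon_greedy mixes a greedy policy over the squeezed action values with uniform exploration; with epsilon set to 0 the sample is simply the argmax action. A short usage sketch with hypothetical Q-values:

import jax
import jax.numpy as jnp
import rlax

q_values = jnp.array([0.1, 0.5, 0.2])  # hypothetical action values
action = rlax.epsilon_greedy(epsilon=0.0).sample(jax.random.PRNGKey(0), q_values)
# With no exploration this picks action 1, the argmax.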
Example #23
 def select_action(params: networks_lib.Params,
                   observation: networks_lib.Observation,
                   state: SimpleActorCoreStateWithExtras):
     rng = state.rng
     rng1, rng2 = jax.random.split(rng)
     observation = utils.add_batch_dim(observation)
     action, extras = utils.squeeze_batch_dim(
         policy(params, rng1, observation))
     return action, SimpleActorCoreStateWithExtras(rng2, extras)
Example #24
def make_network_from_module(
        module: hk.Transformed,
        spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
    """Creates a network with dummy init arguments using the specified module.

  Args:
    module: Module that expects one batch axis and one features axis for its
      inputs.
    spec: EnvironmentSpec shapes to derive dummy inputs.

  Returns:
    FeedForwardNetwork whose `init` method only takes a random key, and `apply`
    takes an observation and action and produces an output.
  """
    dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
    dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
    return networks.FeedForwardNetwork(
        lambda key: module.init(key, dummy_obs, dummy_action), module.apply)
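
A hedged usage sketch of make_network_from_module with a hypothetical environment spec and a throwaway Haiku module (the spec values and the linear critic here are illustrative, not from the source):

import haiku as hk
import jax
import jax.numpy as jnp
from acme import specs

spec = specs.EnvironmentSpec(
    observations=specs.Array(shape=(4,), dtype=float),
    actions=specs.Array(shape=(2,), dtype=float),
    rewards=specs.Array(shape=(), dtype=float),
    discounts=specs.BoundedArray(shape=(), dtype=float, minimum=0., maximum=1.))
module = hk.without_apply_rng(hk.transform(
    lambda obs, act: hk.Linear(1)(jnp.concatenate([obs, act], axis=-1))))
network = make_network_from_module(module, spec)
params = network.init(jax.random.PRNGKey(0))  # init only needs a random key.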
Example #25
 def batched_recurrent_policy(
     params: network_types.Params, key: RNGKey,
     observation: Observation, core_state: RecurrentState
 ) -> Tuple[Union[Action, Tuple[Action, types.NestedArray]],
            RecurrentState]:
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     observation = utils.add_batch_dim(observation)
     output, new_state = recurrent_policy(params, key, observation,
                                          core_state)
     return output, new_state
Example #26
 def batched_policy(
     params: network_lib.Params, key: network_lib.PRNGKey,
     observation: network_lib.Observation
 ) -> Tuple[Union[network_lib.Action, Tuple[
         network_lib.Action, types.NestedArray]], network_lib.PRNGKey]:
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     key, key2 = jax.random.split(key)
     observation = utils.add_batch_dim(observation)
     output = policy(params, key2, observation)
     return utils.squeeze_batch_dim(output), key
Example #27
File: actors.py Project: shadowkun/acme
 def select_action(self,
                   observation: types.NestedArray) -> types.NestedArray:
     action, new_state = self._recurrent_policy(
         self._client.params,
         key=next(self._rng),
         observation=utils.add_batch_dim(observation),
         core_state=self._state)
     self._prev_state = self._state  # Keep previous state to save in replay.
     self._state = new_state  # Keep new state for next policy call.
     return utils.to_numpy_squeeze(action)
Example #28
        def batched_policy(
            params,
            observation,
            discrete_action,
        ):
            observation = utils.add_batch_dim(observation)
            action = utils.squeeze_batch_dim(
                policy(params, observation, discrete_action))

            return action
Example #29
  def __init__(self,
               network: hk.Transformed,
               obs_spec: specs.Array,
               optimizer: optax.GradientTransformation,
               rng: hk.PRNGSequence,
               dataset: tf.data.Dataset,
               loss_fn: LossFn = _sparse_categorical_cross_entropy,
               counter: counting.Counter = None,
               logger: loggers.Logger = None):
    """Initializes the learner."""

    def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
      # Pull out the data needed for updates.
      o_tm1, a_tm1, r_t, d_t, o_t = sample.data
      del r_t, d_t, o_t
      logits = network.apply(params, o_tm1)
      return jnp.mean(loss_fn(a_tm1, logits))

    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
      """Do a step of SGD."""
      grad_fn = jax.value_and_grad(loss)
      loss_value, gradients = grad_fn(state.params, sample)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      new_state = TrainingState(
          params=new_params, opt_state=new_opt_state, steps=steps)

      # Compute the global norm of the gradients for logging.
      global_gradient_norm = optax.global_norm(gradients)
      fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

      return new_state, fetches

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Get an iterator over the dataset.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
    # TODO(b/155086959): Fix type stubs and remove.

    # Initialise parameters and optimiser state.
    initial_params = network.init(
        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
    initial_opt_state = optimizer.init(initial_params)

    self._state = TrainingState(
        params=initial_params, opt_state=initial_opt_state, steps=0)

    self._sgd_step = jax.jit(sgd_step)
Example #30
 def batched_recurrent_policy(
     params: network_lib.Params, key: network_lib.PRNGKey,
     observation: network_lib.Observation, core_state: RecurrentState
 ) -> Tuple[Union[network_lib.Action, Tuple[
         network_lib.Action, types.NestedArray]], RecurrentState,
            network_lib.PRNGKey]:
     # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     observation = utils.add_batch_dim(observation)
     key, key2 = jax.random.split(key)
     output, new_state = recurrent_policy(params, key2, observation,
                                          core_state)
     return output, new_state, key