Example #1
def make_networks(
    env_spec: specs.EnvironmentSpec,
    forward_fn: Any,
    initial_state_fn: Any,
    unroll_fn: Any,
    batch_size) -> R2D2Networks:
  """Builds functional r2d2 network from recurrent model definitions."""

  # Make networks purely functional.
  forward_hk = hk.transform(forward_fn)
  initial_state_hk = hk.transform(initial_state_fn)
  unroll_hk = hk.transform(unroll_fn)

  # Define networks init functions.
  def initial_state_init_fn(rng, batch_size):
    return initial_state_hk.init(rng, batch_size)
  dummy_obs_batch = utils.tile_nested(
      utils.zeros_like(env_spec.observations), batch_size)
  dummy_obs_sequence = utils.add_batch_dim(dummy_obs_batch)
  def unroll_init_fn(rng, initial_state):
    return unroll_hk.init(rng, dummy_obs_sequence, initial_state)

  # Make FeedForwardNetworks.
  forward = networks_lib.FeedForwardNetwork(
      init=forward_hk.init, apply=forward_hk.apply)
  unroll = networks_lib.FeedForwardNetwork(
      init=unroll_init_fn, apply=unroll_hk.apply)
  initial_state = networks_lib.FeedForwardNetwork(
      init=initial_state_init_fn, apply=initial_state_hk.apply)
  return R2D2Networks(
      forward=forward, unroll=unroll, initial_state=initial_state)
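
For context, here is a minimal sketch (assumed, not taken from the snippet above) of the recurrent model definitions this factory expects: forward_fn runs a single step, initial_state_fn creates the recurrent state, and unroll_fn applies the same core over a sequence. The core architecture and num_actions are hypothetical.

# Hedged sketch of possible R2D2 model definitions; the architecture and
# num_actions below are assumptions, not part of the example above.
num_actions = 4  # hypothetical

def make_core():
  # Building the identical module in each function keeps the parameter
  # structures compatible across the three separate hk.transform calls.
  return hk.DeepRNN([
      hk.nets.MLP([256], activate_final=True),
      hk.LSTM(256),
      hk.Linear(num_actions),
  ])

def forward_fn(obs, state):
  return make_core()(obs, state)

def initial_state_fn(batch_size=None):
  return make_core().initial_state(batch_size)

def unroll_fn(obs_sequence, state):
  return hk.dynamic_unroll(make_core(), obs_sequence, state)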
Example #2
def make_networks(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (256, 256)
) -> SACNetworks:
    """Creates networks used by the agent."""

    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes),
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu,
                        activate_final=True),
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network1 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        network2 = hk.Sequential([
            hk.nets.MLP(list(hidden_layer_sizes) + [1],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu),
        ])
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return SACNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_obs), policy.apply),
        q_network=networks_lib.FeedForwardNetwork(
            lambda key: critic.init(key, dummy_obs, dummy_action),
            critic.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode())
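
As a usage illustration, the returned FeedForwardNetworks follow the usual init/apply contract; below is a hedged sketch where spec and the dummy batched observation are assumptions.

# Hedged usage sketch; `spec` is assumed to be a continuous-action spec.
networks = make_networks(spec)
key = jax.random.PRNGKey(0)
policy_params = networks.policy_network.init(key)
q_params = networks.q_network.init(key)

obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
dist = networks.policy_network.apply(policy_params, obs)
action = networks.sample(dist, key)
log_prob = networks.log_prob(dist, action)
q_values = networks.q_network.apply(q_params, obs, action)  # Shape [batch, 2].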
Example #3
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Tuple[int, ...] = (256, 256),
    critic_layer_sizes: Tuple[int, ...] = (256, 256),
    activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
) -> CRRNetworks:
  """Creates networks used by the agent."""
  num_actions = np.prod(spec.actions.shape, dtype=int)

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

  def _policy_fn(obs: jnp.ndarray) -> jnp.ndarray:
    network = hk.Sequential([
        hk.nets.MLP(
            list(policy_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_actions),
    ])
    return network(obs)

  policy = hk.without_apply_rng(hk.transform(_policy_fn))
  policy_network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)

  def _critic_fn(obs, action):
    network = hk.Sequential([
        hk.nets.MLP(
            list(critic_layer_sizes) + [1],
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=activation),
    ])
    data = jnp.concatenate([obs, action], axis=-1)
    return network(data)

  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply)

  return CRRNetworks(
      policy_network=policy_network,
      critic_network=critic_network,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Example #4
def make_networks(
    spec,
    build_actor_fn=build_standard_actor_fn,
    img_encoder_fn=None,
):
    """Creates networks used by the agent."""
    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    if isinstance(spec.actions, specs.DiscreteArray):
        num_dimensions = spec.actions.num_values
        # _actor_fn = procgen_networks.build_procgen_actor_fn(num_dimensions)
    else:
        num_dimensions = np.prod(spec.actions.shape, dtype=int)

    _actor_fn = build_actor_fn(num_dimensions)

    if img_encoder_fn is not None:
        img_encoder = hk.without_apply_rng(hk.transform(img_encoder_fn))
        key = jax.random.PRNGKey(seed=42)
        temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
        dummy_hidden = img_encoder.apply(temp_encoder_params,
                                         dummy_obs['state_image'])
        img_encoder_network = networks_lib.FeedForwardNetwork(
            lambda key: img_encoder.init(key, dummy_obs['state_image']),
            img_encoder.apply)
        dummy_policy_input = dict(
            state_image=dummy_hidden,
            state_dense=dummy_obs['state_dense'],
        )
    else:
        dummy_policy_input = dummy_obs
        img_encoder_network = None

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    return BCNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_policy_input), policy.apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        sample_eval=lambda params, key: params.mode(),
        img_encoder=img_encoder_network,
    )
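
The factory above is parameterized by build_actor_fn; since build_standard_actor_fn is not shown here, the following is a hedged sketch of what such a builder could look like for the continuous-action case (layer sizes and initializers are assumptions).

# Hypothetical builder; not the actual build_standard_actor_fn definition.
def build_standard_actor_fn(num_dimensions):
    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP([256, 256],
                        w_init=hk.initializers.VarianceScaling(
                            1.0, 'fan_in', 'uniform'),
                        activation=jax.nn.relu,
                        activate_final=True),
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)
    return _actor_fn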
Example #5
def make_networks(
    spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (300, 200),
    critic_layer_sizes: Sequence[int] = (400, 300),
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
) -> D4PGNetworks:
    """Creates networks used by the agent."""

    action_spec = spec.actions

    num_dimensions = np.prod(action_spec.shape, dtype=int)
    critic_atoms = jnp.linspace(vmin, vmax, num_atoms)

    def _actor_fn(obs):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
            networks_lib.NearZeroInitializedLinear(num_dimensions),
            networks_lib.TanhToSpec(action_spec),
        ])
        return network(obs)

    def _critic_fn(obs, action):
        network = hk.Sequential([
            utils.batch_concat,
            networks_lib.LayerNormMLP(
                layer_sizes=[*critic_layer_sizes, num_atoms]),
        ])
        value = network([obs, action])
        return value, critic_atoms

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return D4PGNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda rng: policy.init(rng, dummy_obs), policy.apply),
        critic_network=networks_lib.FeedForwardNetwork(
            lambda rng: critic.init(rng, dummy_obs, dummy_action),
            critic.apply))
Example #6
def make_networks(
        spec: specs.EnvironmentSpec,
        discrete_actions: bool = False) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""

    if discrete_actions:
        final_layer_size = spec.actions.num_values
    else:
        final_layer_size = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs, is_training=False, key=None):
        # is_training and key allow defining train/test-dependent modules
        # such as dropout.
        del is_training
        del key
        if discrete_actions:
            network = hk.nets.MLP([64, 64, final_layer_size])
        else:
            network = hk.Sequential([
                networks_lib.LayerNormMLP([64, 64], activate_final=True),
                networks_lib.NormalTanhDistribution(final_layer_size),
            ])
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)
    return network
Example #7
    def test_dqn(self):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        def network(x):
            model = hk.Sequential(
                [hk.Flatten(),
                 hk.nets.MLP([50, 50, spec.actions.num_values])])
            return model(x)

        # Make the network purely functional.
        network_hk = hk.without_apply_rng(hk.transform(network))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))

        network = networks_lib.FeedForwardNetwork(
            init=lambda rng: network_hk.init(rng, dummy_obs),
            apply=network_hk.apply)

        # Construct the agent.
        agent = dqn.DQN(environment_spec=spec,
                        network=network,
                        batch_size=10,
                        samples_per_insert=2,
                        min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=20)
Example #8
def make_haiku_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            networks_lib.CategoricalHead(num_actions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP([64, 64]),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Example #9
    def test_mean_random(self):
        x = jnp.ones(10)
        bx = jnp.ones((9, 10))
        ffn = RandomFFN()
        wrapped_ffn = networks.FeedForwardNetwork(init=functools.partial(
            ffn.init, x=x),
                                                  apply=ffn.apply)
        mean_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                               ensemble.apply_mean,
                                               num_networks=3)
        key = jax.random.PRNGKey(0)
        params = mean_ensemble.init(key)
        single_output = mean_ensemble.apply(params, x)
        self.assertEqual(single_output.shape, (15, ))
        batch_output = mean_ensemble.apply(params, bx)
        # Make sure all rows are equal:
        np.testing.assert_allclose(jnp.broadcast_to(batch_output[0],
                                                    batch_output.shape),
                                   batch_output,
                                   atol=1E-5,
                                   rtol=1E-5)

        # Check results explicitly:
        all_members = jnp.concatenate([
            jnp.expand_dims(ffn.apply(
                jax.tree_map(lambda p, i=i: p[i], params), bx),
                            axis=0) for i in range(3)
        ])
        batch_means = jnp.mean(all_members, axis=0)
        np.testing.assert_allclose(batch_output,
                                   batch_means,
                                   atol=1E-5,
                                   rtol=1E-5)
Example #10
    def test_round_robin_random(self):
        x = jnp.ones(10)  # Base input
        bx = jnp.ones((9, 10))  # Batched input
        ffn = RandomFFN()
        wrapped_ffn = networks.FeedForwardNetwork(init=functools.partial(
            ffn.init, x=x),
                                                  apply=ffn.apply)
        rr_ensemble = ensemble.make_ensemble(wrapped_ffn,
                                             ensemble.apply_round_robin,
                                             num_networks=3)

        key = jax.random.PRNGKey(0)
        params = rr_ensemble.init(key)
        out = rr_ensemble.apply(params, bx)
        # The output should be the same every 3 rows:
        blocks = jnp.split(out, 3, axis=0)
        np.testing.assert_array_equal(blocks[0], blocks[1])
        np.testing.assert_array_equal(blocks[0], blocks[2])
        self.assertTrue((out[0] != out[1]).any())

        for i in range(9):
            np.testing.assert_allclose(
                out[i],
                ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params),
                          bx[i]),
                atol=1E-5,
                rtol=1E-5)
Example #11
def make_discrete_networks(
    environment_spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Sequence[int] = (512, ),
    use_conv: bool = True,
) -> PPONetworks:
    """Creates networks used by the agent for discrete action environments.

  Args:
    environment_spec: Environment spec used to define number of actions.
    hidden_layer_sizes: Network definition.
    use_conv: Whether to use a conv or MLP feature extractor.
  Returns:
    PPONetworks
  """

    num_actions = environment_spec.actions.num_values

    def forward_fn(inputs):
        layers = []
        if use_conv:
            layers.extend([networks_lib.AtariTorso()])
        layers.extend([
            hk.nets.MLP(hidden_layer_sizes, activation=jax.nn.relu),
            networks_lib.CategoricalValueHead(num_values=num_actions)
        ])
        policy_value_network = hk.Sequential(layers)
        return policy_value_network(inputs)

    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))
    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Example #12
def get_fake_world_model() -> networks_lib.FeedForwardNetwork:
    def apply(params: Any, observation_t: jnp.ndarray, action_t: jnp.ndarray):
        del params
        return observation_t, jnp.ones((
            action_t.shape[0],
            1,
        ))

    return networks_lib.FeedForwardNetwork(init=lambda: None, apply=apply)
Example #13
def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork:
    """Like params_adding_ffn, but with pytree inputs, preserves structure."""
    def init_fn(key, sx=sx):
        return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)

    def apply_fn(params, x):
        return jax.tree_map(lambda p, v: p + v, params, x)

    return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn)
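
A short usage sketch for the helper above (the example pytree is an assumption): apply preserves the structure of its input.

# Hedged usage sketch; the input pytree below is hypothetical.
sx = {'a': jnp.zeros(3), 'b': jnp.zeros((2, 2))}
ffn = struct_params_adding_ffn(sx)
params = ffn.init(jax.random.PRNGKey(0))  # Dict with the same keys as sx.
out = ffn.apply(params, sx)               # Element-wise params + input, same structure.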
Example #14
def make_networks(
    spec: specs.EnvironmentSpec,
    direct_rl_networks: DirectRLNetworks,
    layer_sizes: Tuple[int, ...] = (256, 256),
    intrinsic_reward_coefficient: float = 1.0,
    extrinsic_reward_coefficient: float = 0.0,
) -> RNDNetworks[DirectRLNetworks]:
    """Creates networks used by the agent and returns RNDNetworks.

  Args:
    spec: Environment spec.
    direct_rl_networks: Networks used by a direct rl algorithm.
    layer_sizes: Layer sizes.
    intrinsic_reward_coefficient: Multiplier on intrinsic reward.
    extrinsic_reward_coefficient: Multiplier on extrinsic reward.

  Returns:
    The RND networks.
  """
    def _rnd_fn(obs, act):
        # RND does not use the action but other variants like RED do.
        del act
        network = networks_lib.LayerNormMLP(list(layer_sizes))
        return network(obs)

    target = hk.without_apply_rng(hk.transform(_rnd_fn))
    predictor = hk.without_apply_rng(hk.transform(_rnd_fn))

    # Create dummy observations and actions to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    return RNDNetworks(
        target=networks_lib.FeedForwardNetwork(
            lambda key: target.init(key, dummy_obs, ()), target.apply),
        predictor=networks_lib.FeedForwardNetwork(
            lambda key: predictor.init(key, dummy_obs, ()), predictor.apply),
        direct_rl_networks=direct_rl_networks,
        get_reward=functools.partial(
            rnd_reward_fn,
            intrinsic_reward_coefficient=intrinsic_reward_coefficient,
            extrinsic_reward_coefficient=extrinsic_reward_coefficient))
Example #15
def make_discriminator(
    environment_spec: specs.EnvironmentSpec,
    discriminator_transformed: hk.TransformedWithState,
    logpi_fn: Optional[Callable[
        [networks_lib.Params, networks_lib.Observation, networks_lib.Action],
        jnp.ndarray]] = None
) -> networks_lib.FeedForwardNetwork:
  """Creates the discriminator network.

  Args:
    environment_spec: Environment spec
    discriminator_transformed: Haiku transformed of the discriminator.
    logpi_fn: If the policy logpi function is provided, its output will be
      removed from the discriminator logit.

  Returns:
    The network.
  """

  def apply_fn(params: hk.Params,
               policy_params: networks_lib.Params,
               state: hk.State,
               transitions: types.Transition,
               is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    output, state = discriminator_transformed.apply(
        params, state, transitions.observation, transitions.action,
        transitions.next_observation, is_training, rng)
    if logpi_fn is not None:
      logpi = logpi_fn(policy_params, transitions.observation,
                       transitions.action)

      # Quick Maths:
      # D = exp(output)/(exp(output) + pi(a|s))
      # logit(D) = log(D/(1-D)) = log(exp(output)/pi(a|s))
      # logit(D) = output - logpi
      return output - logpi, state
    return output, state

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(environment_spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  return networks_lib.FeedForwardNetwork(
      # pylint: disable=g-long-lambda
      init=lambda rng: discriminator_transformed.init(
          rng, dummy_obs, dummy_actions, dummy_obs, False, rng),
      apply=apply_fn)
Example #16
def make_dqn_atari_network(
        environment_spec: specs.EnvironmentSpec
) -> networks.FeedForwardNetwork:
    """Creates networks for training DQN on Atari."""
    def network(inputs):
        model = hk.Sequential([
            networks.AtariTorso(),
            hk.nets.MLP([512, environment_spec.actions.num_values]),
        ])
        return model(inputs)

    network_hk = hk.without_apply_rng(hk.transform(network))
    obs = utils.add_batch_dim(utils.zeros_like(environment_spec.observations))
    return networks.FeedForwardNetwork(
        init=lambda rng: network_hk.init(rng, obs), apply=network_hk.apply)
Example #17
def make_ppo_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates Haiku networks to be used by the agent."""

    num_actions = spec.actions.num_values

    forward_fn = functools.partial(ppo_forward_fn, num_actions=num_actions)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
Example #18
def add_bc_pretraining(sac_networks: sac.SACNetworks) -> sac.SACNetworks:
    """Augments `sac_networks` to run BC pretraining in policy_network.init."""

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
    bc_network = bc.pretraining.convert_to_bc_network(
        sac_networks.policy_network)
    loss = bc.logp(sac_networks.log_prob)

    def bc_init(*unused_args):
        return bc.pretraining.train_with_bc(make_demonstrations, bc_network,
                                            loss)

    return dataclasses.replace(sac_networks,
                               policy_network=networks_lib.FeedForwardNetwork(
                                   bc_init, sac_networks.policy_network.apply))
Example #19
def convert_policy_value_to_bc_network(
    policy_value_network: networks_lib.FeedForwardNetwork
) -> networks_lib.FeedForwardNetwork:
    """Converts a network from e.g. PPO into a BC policy network.

  Args:
    policy_value_network: FeedForwardNetwork taking the observation as input.

  Returns:
    The BC policy network taking observation, is_training, key as input.
  """
    def apply(params, obs, is_training=False, key=None):
        del is_training, key
        actions, _ = policy_value_network.apply(params, obs)
        return actions

    return networks_lib.FeedForwardNetwork(policy_value_network.init, apply)
Example #20
def convert_to_bc_network(
    policy_network: networks_lib.FeedForwardNetwork
) -> networks_lib.FeedForwardNetwork:
    """Converts a policy_network from SAC/TD3/D4PG/.. into a BC policy network.

  Args:
    policy_network: FeedForwardNetwork taking the observation as input and
      returning action representation compatible with one of the BC losses.

  Returns:
    The BC policy network taking observation, is_training, key as input.
  """
    def apply(params, obs, is_training=False, key=None):
        del is_training, key
        return policy_network.apply(params, obs)

    return networks_lib.FeedForwardNetwork(policy_network.init, apply)
Example #21
def add_bc_pretraining(td3_networks: td3.TD3Networks) -> td3.TD3Networks:
    """Augments `td3_networks` to run BC pretraining in policy_network.init."""

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
    bc_network = bc.pretraining.convert_to_bc_network(
        td3_networks.policy_network)
    # TODO(lukstafi): consider passing noised policy.
    loss = bc.mse(lambda x, key: x)

    def bc_init(*unused_args):
        return bc.pretraining.train_with_bc(make_demonstrations, bc_network,
                                            loss)

    return dataclasses.replace(td3_networks,
                               policy_network=networks_lib.FeedForwardNetwork(
                                   bc_init, td3_networks.policy_network.apply))
Example #22
def add_bc_pretraining(ppo_networks: ppo.PPONetworks) -> ppo.PPONetworks:
    """Augments `ppo_networks` to run BC pretraining in policy_network.init."""

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
    bc_network = bc.pretraining.convert_policy_value_to_bc_network(
        ppo_networks.network)
    loss = bc.logp(ppo_networks.log_prob)

    # Note: despite only training the policy network, this will also include the
    # initial value network params.
    def bc_init(*unused_args):
        return bc.pretraining.train_with_bc(make_demonstrations, bc_network,
                                            loss)

    return dataclasses.replace(ppo_networks,
                               network=networks_lib.FeedForwardNetwork(
                                   bc_init, ppo_networks.network.apply))
Example #23
def make_network_from_module(
        module: hk.Transformed,
        spec: specs.EnvironmentSpec) -> networks.FeedForwardNetwork:
    """Creates a network with dummy init arguments using the specified module.

  Args:
    module: Module that expects one batch axis and one features axis for its
      inputs.
    spec: EnvironmentSpec shapes to derive dummy inputs.

  Returns:
    FeedForwardNetwork whose `init` method only takes a random key, and `apply`
    takes an observation and action and produces an output.
  """
    dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
    dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
    return networks.FeedForwardNetwork(
        lambda key: module.init(key, dummy_obs, dummy_action), module.apply)
Example #24
def make_networks(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent.

  The model used by the ARS paper is a simple clipped linear model.

  Args:
    spec: an environment spec

  Returns:
    A FeedForwardNetwork network.
  """

  obs_size = spec.observations.shape[0]
  act_size = spec.actions.shape[0]
  return networks_lib.FeedForwardNetwork(
      init=lambda _: jnp.zeros((obs_size, act_size)),
      apply=lambda matrix, obs: jnp.clip(jnp.matmul(obs, matrix), -1, 1))
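
Because the ARS policy is a zero-initialized linear map, init and apply reduce to creating a matrix and performing a clipped matmul; a hedged usage sketch (the environment spec is assumed):

# Hedged usage sketch; `spec` is assumed to have flat observation/action shapes.
network = make_networks(spec)
matrix = network.init(jax.random.PRNGKey(0))  # Zeros of shape [obs_size, act_size].
obs = utils.zeros_like(spec.observations)
action = network.apply(matrix, obs)           # Equivalent to jnp.clip(obs @ matrix, -1, 1).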
Example #25
def make_flax_networks(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates FLAX networks to be used by the agent."""

    num_actions = spec.actions.num_values

    class MLP(flax.deprecated.nn.Module):
        """MLP module."""
        def apply(self,
                  data: jnp.ndarray,
                  layer_sizes: Tuple[int, ...],
                  activation: Callable[[jnp.ndarray],
                                       jnp.ndarray] = flax.deprecated.nn.relu,
                  kernel_init: object = jax.nn.initializers.lecun_uniform(),
                  activate_final: bool = False,
                  bias: bool = True):
            hidden = data
            for i, hidden_size in enumerate(layer_sizes):
                hidden = flax.deprecated.nn.Dense(hidden,
                                                  hidden_size,
                                                  name=f'hidden_{i}',
                                                  kernel_init=kernel_init,
                                                  bias=bias)
                if i != len(layer_sizes) - 1 or activate_final:
                    hidden = activation(hidden)
            return hidden

    class PolicyValueModule(flax.deprecated.nn.Module):
        """MLP module."""
        def apply(self, inputs: jnp.ndarray):
            inputs = utils.batch_concat(inputs)
            logits = MLP(inputs, [64, 64, num_actions])
            value = MLP(inputs, [64, 64, 1])
            value = jnp.squeeze(value, axis=-1)
            return tfd.Categorical(logits=logits), value

    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    return networks_lib.FeedForwardNetwork(
        lambda rng: PolicyValueModule.init(rng, dummy_obs)[1],
        PolicyValueModule.call)
Example #26
def make_networks(
    spec,
    actor_fn_build_fn = build_mlp_actor_fn,
    actor_hidden_layer_sizes = (256, 256),
    critic_fn_build_fn = build_hk_batch_ensemble_mlp_critic_fn,
    # critic_fn_build_fn: Callable = build_mlp_critic_fn,
    critic_hidden_layer_sizes = (256, 256),
    use_double_q = False,
    ):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  _actor_fn = actor_fn_build_fn(num_dimensions, actor_hidden_layer_sizes)
  policy = hk.without_apply_rng(hk.transform(_actor_fn))

  _critic_fn = critic_fn_build_fn(critic_hidden_layer_sizes, use_double_q)
  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  critic_ensemble_init = ensemble_utils.transform_init_for_ensemble(
      critic.init, init_same=False)
  critic_ensemble_member_apply = (
      ensemble_utils.transform_apply_for_ensemble_member(critic.apply))
  critic_same_batch_ensemble_apply = (
      ensemble_utils.build_same_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))
  critic_diff_batch_ensemble_apply = (
      ensemble_utils.build_different_batch_ensemble_apply_fn(
          critic_ensemble_member_apply, 2))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return BatchEnsembleMSGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_ensemble_init=lambda ensemble_size, key: critic_ensemble_init(
          ensemble_size, key, dummy_obs, dummy_action),
      q_ensemble_member_apply=critic_ensemble_member_apply,
      q_ensemble_same_batch_apply=critic_same_batch_ensemble_apply,
      q_ensemble_different_batch_apply=critic_diff_batch_ensemble_apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode())
Example #27
def make_network(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""
    num_actions = spec.actions.num_values

    def actor_fn(obs, is_training=True, key=None):
        # is_training and key allow using train/test-dependent modules
        # such as dropout.
        del is_training
        del key
        mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
        return mlp(obs)

    policy = hk.without_apply_rng(hk.transform(actor_fn))

    # Create dummy observations to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)

    return network
Example #28
def make_action_candidates_network(
    spec,
    num_actions,
    discrete_rl_networks,
    torso_layer_sizes = (256,),
    head_layer_sizes = (256,),
    input_dropout_rate = 0.1,
    hidden_dropout_rate = 0.1):
  """Creates networks used by the agent and wraps it into Flax Model.

  Args:
    spec: Environment spec.
    num_actions: the number of actions proposed by the multi-modal model.
    discrete_rl_networks: Direct RL algorithm networks.
    torso_layer_sizes: Layer sizes of the torso.
    head_layer_sizes: Layer sizes of the heads.
    input_dropout_rate: Dropout rate for the input layer.
    hidden_dropout_rate: Dropout rate for the hidden layers.
  Returns:
    The Flax model.
  """
  dummy_obs, _ = get_dummy_batched_obs_and_actions(spec)
  encoder_module = Encoder(
      action_dim=np.prod(spec.actions.shape, dtype=int),
      num_actions=num_actions,
      torso_layer_sizes=torso_layer_sizes,
      head_layer_sizes=head_layer_sizes,
      input_dropout_rate=input_dropout_rate,
      hidden_dropout_rate=hidden_dropout_rate,)

  encoder = networks_lib.FeedForwardNetwork(
      lambda key: encoder_module.init(key, dummy_obs, is_training=False),
      encoder_module.apply)

  return AquademNetworks(
      encoder=encoder,
      discrete_rl_networks=discrete_rl_networks)
Example #29
def make_continuous_networks(
        environment_spec: specs.EnvironmentSpec,
        policy_layer_sizes: Sequence[int] = (64, 64),
        value_layer_sizes: Sequence[int] = (64, 64),
) -> PPONetworks:
    """Creates PPONetworks to be used for continuous action environments."""

    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP(policy_layer_sizes, activation=jnp.tanh),
            # Note: we don't respect bounded action specs here and instead
            # rely on CanonicalSpecWrapper to clip actions accordingly.
            networks_lib.MultivariateNormalDiagHead(num_dimensions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP(value_layer_sizes, activation=jnp.tanh),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Example #30
def make_q_network(spec,
                   hidden_layer_sizes=(512, 512, 256),
                   architecture='LayerNorm'):
  """DQN network for Aquadem algo."""

  def _q_fn(obs):
    if architecture == 'MLP':  # AQuaOff architecture
      network_fn = hk.nets.MLP
    elif architecture == 'LayerNorm':  # Original AQuaDem architecture
      network_fn = networks_lib.LayerNormMLP
    else:
      raise ValueError('Architecture not recognized.')

    network = network_fn(list(hidden_layer_sizes) + [spec.actions.num_values])
    value = network(obs)
    return value

  critic = hk.without_apply_rng(hk.transform(_q_fn))
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs), critic.apply)
  return critic_network