Example 1
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

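    # A trivial feed-forward policy: flatten the observation and take the argmax of a linear layer's outputs.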
    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy)

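    # Initialise the policy parameters from a dummy batched observation.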
    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

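    # The actor pulls parameters through the variable client and acts with the policy.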
    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example 2
    def actor_evaluator(
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor = actors.FeedForwardActor(
            policy=evaluator_network,
            random_key=random_key,
            # Inference happens on CPU, so it's better to move variables there too.
            variable_client=variable_utils.VariableClient(variable_source,
                                                          'policy',
                                                          device='cpu'))

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create the evaluation environment.
        environment = environment_factory(False)

        # Create the counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example 3
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = reverb_replay.server

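        # Clip gradients by global norm, then apply Adam updates.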
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.adam(config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(
            jax.random.PRNGKey(config.seed))
        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: networks_lib.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(config.epsilon).sample(
                key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

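        # The generic Agent base class uses these values to decide when the
        # learner steps relative to the actor's observations.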
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=max(config.batch_size, config.min_replay_size),
            observations_per_step=config.batch_size /
            config.samples_per_insert,
        )
Example 4
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

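    # Log-probability of the taken action: its logit minus logsumexp over all logits.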
    def logp_fn(logits, actions):
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

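    # Evaluation policy: epsilon-greedy over the network outputs.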
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    evaluator = actors.FeedForwardActor(
        policy=evaluator_network,
        random_key=key,
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client=variable_utils.VariableClient(learner,
                                                      'policy',
                                                      device='cpu'))

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Alternate learner steps with periodic evaluation.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example 5
  def make_actor(
      self,
      policy_network,
      adder: Optional[adders.Adder] = None,
      variable_source: Optional[core.VariableSource] = None,
  ) -> core.Actor:
    assert variable_source is not None
    key, self._random_key = jax.random.split(self._random_key)
    return actors.FeedForwardActor(
        policy=policy_network,
        random_key=key,
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client=variable_utils.VariableClient(variable_source, '',
                                                      device='cpu'),
        adder=adder,
    )
Example 6
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

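        # Epsilon-greedy behaviour policy over the network's action values.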
        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

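        # The actor fetches the latest parameters directly from the learner.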
        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(seed),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example 7
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert each demonstration episode into n-step transitions (here n_step=1, no extra discounting).
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = tfds.as_numpy(dataset)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)
    policy_network = hk.without_apply_rng(hk.transform(policy_network))

    # If the agent is non-autoregressive use epsilon=0 which will be a greedy
    # policy.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        action_values = policy_network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

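    # Parent counter, with a prefixed child counter for the learner.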
    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 optimizer=optax.adam(FLAGS.learning_rate),
                                 obs_spec=environment.observation_spec(),
                                 dataset=dataset,
                                 counter=learner_counter,
                                 rng=hk.PRNGSequence(FLAGS.seed))

    # Create the actor which defines how we take actions.
    variable_client = variable_utils.VariableClient(learner, '')
    evaluator = actors.FeedForwardActor(evaluator_network,
                                        variable_client=variable_client,
                                        rng=hk.PRNGSequence(FLAGS.seed))

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # Alternate learner steps with periodic evaluation.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example 8
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 0.5,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.05,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=n_step,
            batch_size=batch_size,
            max_replay_size=max_replay_size,
            min_replay_size=min_replay_size,
            priority_exponent=priority_exponent,
            discount=discount,
        )
        self._server = reverb_replay.server

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(seed),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)