Example #1
    def learner(
        self,
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        iterator = self._builder.make_dataset_iterator(replay)

        dummy_seed = 1
        environment_spec = (self._environment_spec
                            or specs.make_environment_spec(
                                self._environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = self._network_factory(environment_spec)

        if self._prefetch_size > 1:
            # When working with a single GPU we should prefetch to device for
            # efficiency. If running on TPU this isn't necessary, as computation
            # and input placement can be done automatically. For multi-GPU setups
            # the best current solution is to prefetch to host, although this may
            # change in the future.
            device = jax.devices()[0] if self._device_prefetch else None
            iterator = utils.prefetch(iterator,
                                      buffer_size=self._prefetch_size,
                                      device=device)
        else:
            logging.info('Not prefetching the iterator.')

        counter = counting.Counter(counter, 'learner')
        learner = self._builder.make_learner(random_key, networks, iterator,
                                             replay, counter)

        return savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=self._checkpointing_config.directory,
            add_uid=self._checkpointing_config.add_uid,
            max_to_keep=self._checkpointing_config.max_to_keep)
Example #2
    def build_learner(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: Optional[counting.Counter] = None,
        primary_learner: Optional[core.Learner] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        iterator = experiment.builder.make_dataset_iterator(replay)
        # make_dataset_iterator is responsible for putting data onto the
        # appropriate training devices; prefetching here simply ensures the
        # copy happens in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(random_key, networks,
                                                  iterator,
                                                  experiment.logger_factory,
                                                  spec, replay, counter)

        if primary_learner is None:
            learner = savers.CheckpointingRunner(
                learner,
                key='learner',
                subdirectory='learner',
                time_delta_minutes=5,
                directory=checkpointing_config.directory,
                add_uid=checkpointing_config.add_uid,
                max_to_keep=checkpointing_config.max_to_keep)
        else:
            learner.restore(primary_learner.save())
            # NOTE: This initially synchronizes secondary learner states with the
            # primary one. Further synchronization should be handled by the learner
            # properly doing a pmap/pmean on the loss/gradients, respectively.

        return learner
Example #3
    def build_learner(
        random_key: networks_lib.PRNGKey,
        counter: Optional[counting.Counter] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        dataset_key, random_key = jax.random.split(random_key)
        iterator = experiment.demonstration_dataset_factory(dataset_key)
        # make_demonstrations is responsible for putting data onto the
        # appropriate training devices; prefetching here simply ensures the
        # copy happens in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(
            random_key=random_key,
            networks=networks,
            dataset=iterator,
            logger_fn=experiment.logger_factory,
            environment_spec=spec,
            counter=counter)

        learner = savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

        return learner
Example #4
    def _initialize_train(self):
        """Initialize train.

    This includes initializing the input pipeline and Byol's state.
    """
        self._train_input = acme_utils.prefetch(self._build_train_input())

        # Check we haven't already restored params
        if self._byol_state is None:
            logging.info(
                'Initializing parameters rather than restoring from checkpoint.'
            )

            # initialize Byol and setup optimizer state
            inputs = next(self._train_input)
            init_byol = jax.pmap(self._make_initial_state, axis_name='i')

            # Init uses the same RNG key on all hosts+devices to ensure everyone
            # computes the same initial state and parameters.
            init_rng = jax.random.PRNGKey(self._random_seed)
            init_rng = helpers.bcast_local_devices(init_rng)

            self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)
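
The BYOL snippet above relies on helpers.bcast_local_devices to hand jax.pmap one copy of the RNG key per local device. Below is a minimal sketch of that pattern, assuming the helper simply replicates a value across local devices (the real helper lives in the BYOL codebase and may differ):

import jax


def bcast_local_devices(value):
    """Replicates `value` so a pmapped init function sees one copy per device."""
    return jax.device_put_replicated(value, jax.local_devices())


# The replicated key gains a leading axis of size jax.local_device_count(),
# which is what jax.pmap(..., axis_name='i') expects for its inputs.
init_rng = bcast_local_devices(jax.random.PRNGKey(0))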
Example #5
  def learner(
      self,
      random_key,
      replay,
      counter,
  ):
    """The Learning part of the agent."""

    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
      adder = self._builder.make_adder(replay)
      env = self._environment_factory(0)
      dataset = env.get_dataset()  # pytype: disable=attribute-error
      for t in tqdm.trange(dataset['observations'].shape[0]):
        discount = 1.0
        if t == 0 or dataset['timeouts'][t - 1]:
          step_type = dm_env.StepType.FIRST
        elif dataset['timeouts'][t]:
          step_type = dm_env.StepType.LAST
          discount = 0.0
        else:
          step_type = dm_env.StepType.MID

        ts = dm_env.TimeStep(
            step_type=step_type,
            reward=dataset['rewards'][t],
            discount=discount,
            observation=np.concatenate([dataset['observations'][t],
                                        dataset['infos/goal'][t]]),
        )
        if t == 0 or dataset['timeouts'][t - 1]:
          adder.add_first(ts)  # pytype: disable=attribute-error
        else:
          adder.add(action=dataset['actions'][t-1], next_timestep=ts)  # pytype: disable=attribute-error

        if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
          break

    iterator = self._builder.make_dataset_iterator(replay)

    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = self._network_factory(environment_spec)

    if self._prefetch_size > 1:
      # When working with a single GPU we should prefetch to device for
      # efficiency. If running on TPU this isn't necessary, as computation
      # and input placement can be done automatically. For multi-GPU setups
      # the best current solution is to prefetch to host, although this may
      # change in the future.
      device = jax.devices()[0] if self._device_prefetch else None
      iterator = utils.prefetch(
          iterator, buffer_size=self._prefetch_size, device=device)
    else:
      logging.info('Not prefetching the iterator.')

    counter = counting.Counter(counter, 'learner')

    learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                         counter)
    kwargs = {}
    if self._checkpointing_config:
      kwargs = vars(self._checkpointing_config)
    # Return the learning agent.
    return savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        **kwargs)
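
The loop above reconstructs episode boundaries from the offline dataset's timeout flags so that Reverb receives the same sequence of adder calls an online run would produce. For reference, a minimal sketch of that adder protocol on a live dm_env environment; env, policy and adder are hypothetical placeholders for whatever environment, policy callable and Acme adder you have:

def add_episode(env, policy, adder):
    """Feeds one episode to an adder: the online counterpart of the loop above."""
    timestep = env.reset()
    adder.add_first(timestep)
    while not timestep.last():
        action = policy(timestep.observation)
        timestep = env.step(action)
        adder.add(action=action, next_timestep=timestep)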
Example #6
    def __init__(self,
                 network: networks.QNetwork,
                 obs_spec: specs.Array,
                 discount: float,
                 importance_sampling_exponent: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optix.InitUpdate,
                 rng: hk.PRNGSequence,
                 max_abs_reward: float = 1.,
                 huber_loss_parameter: float = 1.,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner."""

        # Transform network into a pure function.
        network = hk.transform(network)

        def loss(params: hk.Params, target_params: hk.Params,
                 sample: reverb.ReplaySample):
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            keys, probs = sample.info[:2]

            # Forward pass.
            q_tm1 = network.apply(params, o_tm1)
            q_t_value = network.apply(target_params, o_t)
            q_t_selector = network.apply(params, o_t)

            # Cast and clip rewards.
            d_t = (d_t * discount).astype(jnp.float32)
            r_t = jnp.clip(r_t, -max_abs_reward,
                           max_abs_reward).astype(jnp.float32)

            # Compute double Q-learning n-step TD-error.
            batch_error = jax.vmap(rlax.double_q_learning)
            td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                   q_t_selector)
            batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

            # Importance weighting.
            importance_weights = (1. / probs).astype(jnp.float32)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)

            # Reweight.
            mean_loss = jnp.mean(importance_weights * batch_loss)  # []

            priorities = jnp.abs(td_error).astype(jnp.float64)

            return mean_loss, (keys, priorities)

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            new_state = TrainingState(params=new_params,
                                      target_params=state.target_params,
                                      opt_state=new_opt_state,
                                      step=state.step + 1)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs

        def update_priorities(outputs: LearnerOutputs):
            for key, priority in zip(outputs.keys, outputs.priorities):
                replay_client.mutate_priorities(
                    table=adders.DEFAULT_PRIORITY_TABLE,
                    updates={key: priority})

        # Internalise agent components (replay buffer, networks, optimizer).
        self._replay_client = replay_client
        self._iterator = utils.prefetch(iterator)

        # Internalise the hyperparameters.
        self._target_update_period = target_update_period

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialise parameters and optimiser state.
        initial_params = network.init(
            next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
        initial_target_params = network.init(
            next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
        initial_opt_state = optimizer.init(initial_params)

        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)

        self._forward = jax.jit(network.apply)
        self._sgd_step = jax.jit(sgd_step)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
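
The constructor above only wires the pieces together; a learner of this shape is normally driven by a step() method. A minimal sketch is given below, under the assumptions that TrainingState is a NamedTuple, that the target network is refreshed here (sgd_step above leaves target_params untouched), and that AsyncExecutor exposes a put() method for dispatching work:

    def step(self):
        """Samples one batch, applies an SGD update, and pushes new priorities."""
        samples = next(self._iterator)
        self._state, outputs = self._sgd_step(self._state, samples)

        # Ship the new priorities to Reverb without blocking the training step.
        if self._replay_client:
            self._async_priority_updater.put(outputs)

        # Periodically copy the online parameters into the target network.
        if self._state.step % self._target_update_period == 0:
            self._state = self._state._replace(target_params=self._state.params)

        # Record the step count and write it out for logging.
        counts = self._counter.increment(steps=1)
        self._logger.write(counts)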
Example #7
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initialize the SGD learner."""
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # SGD performs the loss, optimizer update and periodic target net update.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates training state
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        def postprocess_aux(extra: LossExtra) -> LossExtra:
            reverb_update = jax.tree_map(
                lambda a: jnp.reshape(a, (-1, *a.shape[2:])),
                extra.reverb_update)
            return extra._replace(metrics=jax.tree_map(jnp.mean,
                                                       extra.metrics),
                                  reverb_update=reverb_update)

        self._num_sgd_steps_per_step = num_sgd_steps_per_step
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step,
                                                  postprocess_aux)
        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        # Initialize the network parameters
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params)
        initial_target_params = self.network.init(key_target)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: ReverbUpdate) -> None:
            if replay_client is None:
                return
            keys, priorities = tree.map_structure(
                utils.fetch_devicearray,
                (reverb_update.keys, reverb_update.priorities))
            replay_client.mutate_priorities(table=replay_table_name,
                                            updates=dict(zip(keys,
                                                             priorities)))

        self._replay_client = replay_client
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
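
utils.process_multiple_batches above folds num_sgd_steps_per_step updates into a single jitted call. A minimal sketch of the idea using jax.lax.scan is shown below; this is not Acme's actual implementation, and it assumes the incoming batch carries a leading axis of size num_steps:

import jax


def process_multiple_batches(sgd_step, num_steps, postprocess_aux=None):
    """Wraps sgd_step so that one call consumes num_steps stacked batches."""

    def wrapped_step(state, stacked_batches):
        # lax.scan threads the training state through the individual updates;
        # stacked_batches must have a leading dimension of size num_steps.
        state, aux = jax.lax.scan(
            sgd_step, state, stacked_batches, length=num_steps)
        if postprocess_aux is not None:
            aux = postprocess_aux(aux)
        return state, aux

    return wrapped_step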
Example #8
def run_experiment(experiment: config.ExperimentConfig,
                   eval_every: int = 100,
                   num_eval_episodes: int = 1):
  """Runs a simple, single-threaded training loop using the default evaluators.

  It favours simplicity of the code, so only the basic features of the
  ExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many actor steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

  key = jax.random.PRNGKey(experiment.seed)

  # Create the environment and get its spec.
  environment = experiment.environment_factory(experiment.seed)
  environment_spec = experiment.environment_spec or specs.make_environment_spec(
      environment)

  # Create the networks and policy.
  networks = experiment.network_factory(environment_spec)
  policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=False)

  # Create the replay server and grab its address.
  replay_tables = experiment.builder.make_replay_tables(environment_spec,
                                                        policy)

  # Disable blocking of inserts by the tables' rate limiters, as this function
  # executes learning (sampling from the table) and data generation
  # (inserting into the table) sequentially from the same thread,
  # where a blocked insert would make the algorithm hang.
  replay_tables, rate_limiters_max_diff = _disable_insert_blocking(
      replay_tables)

  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # The parent counter allows sharing step counts between the train and eval
  # loops and the learner, making it possible to plot, for example, the
  # evaluator's return as a function of the number of training episodes.
  parent_counter = counting.Counter(time_delta=0.)

  # Create actor, and learner for generating, storing, and consuming
  # data respectively.
  dataset = experiment.builder.make_dataset_iterator(replay_client)
  # We always use prefetch, as it provides an iterator with an additional
  # 'ready' method.
  dataset = utils.prefetch(dataset, buffer_size=1)
  learner_key, key = jax.random.split(key)
  learner = experiment.builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=experiment.logger_factory,
      environment_spec=environment_spec,
      replay_client=replay_client,
      counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.))

  adder = experiment.builder.make_adder(replay_client, environment_spec, policy)
  actor_key, key = jax.random.split(key)
  actor = experiment.builder.make_actor(
      actor_key, policy, environment_spec, variable_source=learner, adder=adder)

  # Create the environment loop used for training.
  train_counter = counting.Counter(
      parent_counter, prefix='train', time_delta=0.)
  train_logger = experiment.logger_factory('train',
                                           train_counter.get_steps_key(), 0)

  # Replace the actor with a LearningActor. This makes sure that every time
  # that `update` is called on the actor it checks to see whether there is
  # any new data to learn from and if so it runs a learner step. The rate
  # at which new data is released is controlled by the replay table's
  # rate_limiter which is created by the builder.make_replay_tables call above.
  actor = _LearningActor(actor, learner, dataset, replay_tables,
                         rate_limiters_max_diff)

  train_loop = acme.EnvironmentLoop(
      environment,
      actor,
      counter=train_counter,
      logger=train_logger,
      observers=experiment.observers)

  if num_eval_episodes == 0:
    # No evaluation. Just run the training loop.
    train_loop.run(num_steps=experiment.max_num_actor_steps)
    return

  # Create the evaluation actor and loop.
  eval_counter = counting.Counter(parent_counter, prefix='eval', time_delta=0.)
  eval_logger = experiment.logger_factory('eval', eval_counter.get_steps_key(),
                                          0)
  eval_policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=True)
  eval_actor = experiment.builder.make_actor(
      random_key=jax.random.PRNGKey(experiment.seed),
      policy=eval_policy,
      environment_spec=environment_spec,
      variable_source=learner)
  eval_loop = acme.EnvironmentLoop(
      environment,
      eval_actor,
      counter=eval_counter,
      logger=eval_logger,
      observers=experiment.observers)

  steps = 0
  while steps < experiment.max_num_actor_steps:
    eval_loop.run(num_episodes=num_eval_episodes)
    steps += train_loop.run(num_steps=eval_every)
  eval_loop.run(num_episodes=num_eval_episodes)
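
A minimal usage sketch for the loop above, assuming the public acme.jax experiments API (the code above refers to its config module as config.ExperimentConfig); my_builder, make_environment and make_networks are hypothetical stand-ins for the agent builder and factories you would supply:

from acme.jax import experiments

# Hypothetical inputs: an ActorLearnerBuilder plus factories mapping
# seed -> dm_env.Environment and EnvironmentSpec -> agent networks.
experiment = experiments.ExperimentConfig(
    builder=my_builder,
    environment_factory=make_environment,
    network_factory=make_networks,
    seed=0,
    max_num_actor_steps=1_000_000)

experiments.run_experiment(
    experiment=experiment, eval_every=10_000, num_eval_episodes=5)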
Example #9
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 obs_spec: specs.Array,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initialize the SGD learner."""
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # SGD performs the loss, optimizer update and periodic target net update.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates training state
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialize the network parameters
        dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec))
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params, dummy_obs)
        initial_target_params = self.network.init(key_target, dummy_obs)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: Optional[ReverbUpdate]) -> None:
            if reverb_update is None or replay_client is None:
                return
            else:
                replay_client.mutate_priorities(
                    table=adders.DEFAULT_PRIORITY_TABLE,
                    updates=dict(
                        zip(reverb_update.keys, reverb_update.priorities)))

        self._replay_client = replay_client
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
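
Examples #7 and #9 both construct a TrainingState with the same five fields. A minimal sketch of the container they assume is shown below; the real definition lives alongside each learner and may carry extra fields or tighter type annotations:

from typing import Any, NamedTuple

import optax


class TrainingState(NamedTuple):
    """Training state threaded through sgd_step (a minimal sketch)."""
    params: Any          # online network parameters (a parameter pytree)
    target_params: Any   # parameters of the target network
    opt_state: optax.OptState
    steps: int
    rng_key: Any         # JAX PRNG key split on every SGD step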
Example #10
  def __init__(
      self,
      seed: int,
      environment_spec: specs.EnvironmentSpec,
      builder: builders.GenericActorLearnerBuilder,
      networks: Any,
      policy_network: Any,
      workdir: Optional[str] = '~/acme',
      min_replay_size: int = 1000,
      samples_per_insert: float = 256.0,
      batch_size: int = 256,
      num_sgd_steps_per_step: int = 1,
      prefetch_size: int = 1,
      device_prefetch: bool = True,
      counter: Optional[counting.Counter] = None,
      checkpoint: bool = True,
  ):
    """Initialize the agent.

    Args:
      seed: A random seed to use for this layout instance.
      environment_spec: description of the actions, observations, etc.
      builder: builder defining an RL algorithm to train.
      networks: network objects to be passed to the learner.
      policy_network: function that given an observation returns actions.
      workdir: if provided saves the state of the learner and the counter
        (if the counter is not None) into workdir.
      min_replay_size: minimum replay size before updating.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      batch_size: batch size for updates.
      num_sgd_steps_per_step: how many SGD steps the learner performs per 'step'
        call. Doing multiple SGD updates at once can be beneficial (especially
        to reduce TPU host-device transfer times), provided it does not hurt
        training, which needs to be verified empirically for each environment.
      prefetch_size: size of the prefetch buffer for the dataset iterator;
        values greater than 1 enable prefetching.
      device_prefetch: whether prefetching should happen to a device.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner
        and the counter (if the counter is not None).
    """
    if prefetch_size < 0:
      raise ValueError(f'Prefetch size={prefetch_size} should be non negative')

    key = jax.random.PRNGKey(seed)

    # Create the replay server and grab its address.
    replay_tables = builder.make_replay_tables(environment_spec)
    replay_server = reverb.Server(replay_tables, port=None)
    replay_client = reverb.Client(f'localhost:{replay_server.port}')

    # Create actor, dataset, and learner for generating, storing, and consuming
    # data respectively.
    adder = builder.make_adder(replay_client)

    def _is_reverb_queue(reverb_table: reverb.Table,
                         reverb_client: reverb.Client) -> bool:
      """Returns True iff the Reverb Table is actually a queue."""
      # TODO(sinopalnikov): make it more generic and check for a table that
      # needs special handling on update.
      info = reverb_client.server_info()
      table_info = info[reverb_table.name]
      is_queue = (
          table_info.max_times_sampled == 1 and
          table_info.sampler_options.fifo and
          table_info.remover_options.fifo)
      return is_queue

    is_reverb_queue = any(_is_reverb_queue(table, replay_client)
                          for table in replay_tables)

    dataset = builder.make_dataset_iterator(replay_client)
    if prefetch_size > 1:
      device = jax.devices()[0] if device_prefetch else None
      dataset = utils.prefetch(dataset, buffer_size=prefetch_size,
                               device=device)
    learner_key, key = jax.random.split(key)
    learner = builder.make_learner(
        random_key=learner_key,
        networks=networks,
        dataset=dataset,
        replay_client=replay_client,
        counter=counter)
    if not checkpoint or workdir is None:
      self._checkpointer = None
    else:
      objects_to_save = {'learner': learner}
      if counter is not None:
        objects_to_save.update({'counter': counter})
      self._checkpointer = savers.Checkpointer(
          objects_to_save,
          time_delta_minutes=30,
          subdirectory='learner',
          directory=workdir,
          add_uid=(workdir == '~/acme'))

    actor_key, key = jax.random.split(key)
    actor = builder.make_actor(
        actor_key, policy_network, adder, variable_source=learner)
    self._custom_update_fn = None
    if is_reverb_queue:
      # Reverb queue requires special handling on update: custom logic to
      # decide when it is safe to make a learner step. This is only needed for
      # the local agent, where the actor and the learner are running
      # synchronously and the learner will deadlock if it makes a step with
      # no data available.
      def custom_update():
        should_update_actor = False
        # Run a number of learner steps (usually gradient steps).
        # TODO(raveman): This is wrong. When running multi-level learners,
        # different levels might have different batch sizes. Find a solution.
        while all(table.can_sample(batch_size) for table in replay_tables):
          learner.step()
          should_update_actor = True

        if should_update_actor:
          # "wait=True" to make it more onpolicy
          actor.update(wait=True)

      self._custom_update_fn = custom_update

    effective_batch_size = batch_size * num_sgd_steps_per_step
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(effective_batch_size, min_replay_size),
        observations_per_step=float(effective_batch_size) / samples_per_insert)

    # Save the replay so we don't garbage collect it.
    self._replay_server = replay_server
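
The final super().__init__ call above turns the replay configuration into actor/learner pacing. A small worked example with the default arguments, assuming the parent class waits for min_observations observations and then runs one learner step every observations_per_step actor observations:

batch_size = 256
num_sgd_steps_per_step = 1
samples_per_insert = 256.0
min_replay_size = 1000

effective_batch_size = batch_size * num_sgd_steps_per_step          # 256
min_observations = max(effective_batch_size, min_replay_size)       # 1000
observations_per_step = effective_batch_size / samples_per_insert   # 1.0

# With these defaults the agent collects 1000 observations before learning,
# then takes one learner step (one 256-sample batch) per actor observation.
print(min_observations, observations_per_step)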
Example #11
    def _initialize_train(self, rng):
        """BYOL's _ExperimentState initialization.

    Args:
      rng: random number generator used to initialize parameters. If working in
        a multi device setup, this need to be a ShardedArray.
      dummy_input: a dummy image, used to compute intermediate outputs shapes.

    Returns:
      Initial EvalExperiment state.

    Raises:
      RuntimeError: invalid or empty checkpoint.
    """
        self._train_input = acme_utils.prefetch(self._build_train_input())

        # Check we haven't already restored params
        if self._experiment_state is None:

            inputs = next(self._train_input)

            if self._checkpoint_to_evaluate is not None:
                # Load params from checkpoint
                checkpoint_data = checkpointing.load_checkpoint(
                    self._checkpoint_to_evaluate)
                if checkpoint_data is None:
                    raise RuntimeError('Invalid checkpoint.')
                backbone_params = checkpoint_data[
                    'experiment_state'].online_params
                backbone_state = checkpoint_data[
                    'experiment_state'].online_state
                backbone_params = helpers.bcast_local_devices(backbone_params)
                backbone_state = helpers.bcast_local_devices(backbone_state)
            else:
                if not self._allow_train_from_scratch:
                    raise ValueError(
                        'No checkpoint specified, but `allow_train_from_scratch` '
                        'set to False')
                # Initialize with random parameters
                logging.info(
                    'No checkpoint specified, initializing the networks from scratch '
                    '(dry run mode)')
                backbone_params, backbone_state = jax.pmap(
                    functools.partial(self.forward_backbone.init,
                                      is_training=True),
                    axis_name='i')(rng=rng, inputs=inputs)

            init_experiment = jax.pmap(self._make_initial_state, axis_name='i')

            # Init uses the same RNG key on all hosts+devices to ensure everyone
            # computes the same initial state and parameters.
            init_rng = jax.random.PRNGKey(self._random_seed)
            init_rng = helpers.bcast_local_devices(init_rng)
            self._experiment_state = init_experiment(
                rng=init_rng,
                dummy_input=inputs,
                backbone_params=backbone_params,
                backbone_state=backbone_state)

            # Clear the backbone optimizer's state when the backbone is frozen.
            if self._freeze_backbone:
                self._experiment_state = _EvalExperimentState(
                    backbone_params=self._experiment_state.backbone_params,
                    classif_params=self._experiment_state.classif_params,
                    backbone_state=self._experiment_state.backbone_state,
                    backbone_opt_state=None,
                    classif_opt_state=self._experiment_state.classif_opt_state,
                )
Example #12
    def __init__(
        self,
        seed: int,
        environment_spec: specs.EnvironmentSpec,
        builder: builders.ActorLearnerBuilder,
        networks: Any,
        policy_network: Any,
        learner_logger: Optional[loggers.Logger] = None,
        workdir: Optional[str] = '~/acme',
        batch_size: int = 256,
        num_sgd_steps_per_step: int = 1,
        prefetch_size: int = 1,
        counter: Optional[counting.Counter] = None,
        checkpoint: bool = True,
    ):
        """Initialize the agent.

    Args:
      seed: A random seed to use for this layout instance.
      environment_spec: description of the actions, observations, etc.
      builder: builder defining an RL algorithm to train.
      networks: network objects to be passed to the learner.
      policy_network: function that given an observation returns actions.
      learner_logger: logger used by the learner.
      workdir: if provided saves the state of the learner and the counter
        (if the counter is not None) into workdir.
      batch_size: batch size for updates.
      num_sgd_steps_per_step: how many SGD steps the learner performs per 'step'
        call. Doing multiple SGD updates at once can be beneficial (especially
        to reduce TPU host-device transfer times), provided it does not hurt
        training, which needs to be verified empirically for each environment.
      prefetch_size: size of the prefetch buffer used for the dataset iterator.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner
        and the counter (if the counter is not None).
    """
        if prefetch_size < 0:
            raise ValueError(
                f'Prefetch size={prefetch_size} should be non negative')

        key = jax.random.PRNGKey(seed)

        # Create the replay server and grab its address.
        replay_tables = builder.make_replay_tables(environment_spec,
                                                   policy_network)

        # Disable blocking of inserts by tables' rate limiters, as LocalLayout
        # agents run inserts and sampling from the same thread and blocked insert
        # would result in a hang.
        new_tables = []
        for table in replay_tables:
            rl_info = table.info.rate_limiter_info
            rate_limiter = reverb.rate_limiters.RateLimiter(
                samples_per_insert=rl_info.samples_per_insert,
                min_size_to_sample=rl_info.min_size_to_sample,
                min_diff=rl_info.min_diff,
                max_diff=sys.float_info.max)
            new_tables.append(table.replace(rate_limiter=rate_limiter))
        replay_tables = new_tables

        replay_server = reverb.Server(replay_tables, port=None)
        replay_client = reverb.Client(f'localhost:{replay_server.port}')

        # Create actor, dataset, and learner for generating, storing, and consuming
        # data respectively.
        adder = builder.make_adder(replay_client, environment_spec,
                                   policy_network)

        dataset = builder.make_dataset_iterator(replay_client)
        # We always use prefetch, as it provides an iterator with an additional
        # 'ready' method.
        dataset = utils.prefetch(dataset, buffer_size=prefetch_size)
        learner_key, key = jax.random.split(key)
        learner = builder.make_learner(
            random_key=learner_key,
            networks=networks,
            dataset=dataset,
            logger_fn=(lambda label, steps_key=None, task_instance=None:
                       learner_logger),
            environment_spec=environment_spec,
            replay_client=replay_client,
            counter=counter)
        if not checkpoint or workdir is None:
            self._checkpointer = None
        else:
            objects_to_save = {'learner': learner}
            if counter is not None:
                objects_to_save.update({'counter': counter})
            self._checkpointer = savers.Checkpointer(
                objects_to_save,
                time_delta_minutes=30,
                subdirectory='learner',
                directory=workdir,
                add_uid=(workdir == '~/acme'))

        actor_key, key = jax.random.split(key)
        actor = builder.make_actor(actor_key,
                                   policy_network,
                                   environment_spec,
                                   variable_source=learner,
                                   adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         iterator=dataset,
                         replay_tables=replay_tables)

        # Save the replay so we don't garbage collect it.
        self._replay_server = replay_server
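
Example #12 inlines the same rate-limiter adjustment that run_experiment in Example #8 delegates to a _disable_insert_blocking helper. A minimal sketch of such a helper, reconstructed from the loop above (the real helper may return additional bookkeeping):

import sys
from typing import List, Sequence, Tuple

import reverb


def _disable_insert_blocking(
    tables: Sequence[reverb.Table]) -> Tuple[List[reverb.Table], List[float]]:
    """Returns copies of the tables whose rate limiters never block inserts."""
    modified_tables = []
    original_max_diffs = []
    for table in tables:
        rl_info = table.info.rate_limiter_info
        rate_limiter = reverb.rate_limiters.RateLimiter(
            samples_per_insert=rl_info.samples_per_insert,
            min_size_to_sample=rl_info.min_size_to_sample,
            min_diff=rl_info.min_diff,
            max_diff=sys.float_info.max)
        modified_tables.append(table.replace(rate_limiter=rate_limiter))
        # Keep the original bound so callers can still throttle learner steps.
        original_max_diffs.append(rl_info.max_diff)
    return modified_tables, original_max_diffs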