def setUp(self):
    super(ReverbReplayBufferTest, self).setUp()

    # Prepare the environment (and the corresponding specs).
    self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
    tensor_time_step_spec = tf.nest.map_structure(tensor_spec.from_spec,
                                                  self._env.time_step_spec())
    tensor_action_spec = tensor_spec.from_spec(self._env.action_spec())
    self._data_spec = trajectory.Trajectory(
        step_type=tensor_time_step_spec.step_type,
        observation=tensor_time_step_spec.observation,
        action=tensor_action_spec,
        policy_info=(),
        next_step_type=tensor_time_step_spec.step_type,
        reward=tensor_time_step_spec.reward,
        discount=tensor_time_step_spec.discount,
    )
    table_spec = tf.nest.map_structure(
        lambda s: tf.TensorSpec(dtype=s.dtype, shape=(None,) + s.shape),
        self._data_spec)
    self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

    # Initialize and start a Reverb server (and set up a client to it).
    self._table_name = 'test_table'
    uniform_table = reverb.Table(
        self._table_name,
        max_size=100,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=table_spec,
    )
    self._server = reverb.Server([uniform_table])
    self._py_client = reverb.Client('localhost:{}'.format(self._server.port))
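
A hedged follow-on sketch (not part of the original test): the server and spec built in
this setUp can be wrapped directly in a TF-Agents replay buffer. All names are reused
from above; sequence_length=3 is an assumption matching the 3-step episodes of
EpisodeCountingEnv.

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        table_name=self._table_name,
        sequence_length=3,  # assumed: one item per 3-step episode
        local_server=self._server)
    dataset = replay.as_dataset(sample_batch_size=2, num_steps=3)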
Example #2
    def test_dataset_with_variable_sequence_length_truncates(self):
        spec = tf.TensorSpec((), tf.int64)
        table_spec = tf.TensorSpec((None, ), tf.int64)
        table = reverb.Table(
            name=self._table_name,
            sampler=reverb.selectors.Fifo(),
            remover=reverb.selectors.Fifo(),
            max_times_sampled=1,
            max_size=100,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=table_spec,
        )
        server = reverb.Server([table])
        py_client = reverb.Client('localhost:{}'.format(server.port))

        # Insert two episodes: one of length 3 and one of length 5
        with py_client.trajectory_writer(10) as writer:
            writer.append(1)
            writer.append(2)
            writer.append(3)
            writer.create_item(self._table_name,
                               trajectory=writer.history[-3:],
                               priority=5)

        with py_client.trajectory_writer(10) as writer:
            writer.append(10)
            writer.append(20)
            writer.append(30)
            writer.append(40)
            writer.append(50)
            writer.create_item(self._table_name,
                               trajectory=writer.history[-5:],
                               priority=5)

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            spec,
            self._table_name,
            local_server=server,
            sequence_length=None,
            rate_limiter_timeout_ms=100)
        ds = replay.as_dataset(single_deterministic_pass=True, num_steps=2)
        it = iter(ds)

        # Expect [1, 2]
        data, _ = next(it)
        self.assertAllEqual(data, [1, 2])

        # Expect [10, 20]
        data, _ = next(it)
        self.assertAllEqual(data, [10, 20])

        # Expect [30, 40]
        data, _ = next(it)
        self.assertAllEqual(data, [30, 40])

        with self.assertRaises(StopIteration):
            next(it)
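
As a hedged aside (not in the original test), the same table can be inspected through
the Python client, as later examples do via server_info():

        table_info = py_client.server_info()[self._table_name]
        # 2 right after the two create_item calls; 0 once the single deterministic
        # pass above has consumed both items, since max_times_sampled=1.
        print(table_info.current_size)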
Example #3
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now, only single-environment collection is supported.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
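
A hedged invocation sketch. The policy and addresses are placeholders: my_collect_policy
is assumed to be a py_tf_eager_policy.PyTFEagerPolicyBase built elsewhere (e.g. loaded
from the trainer's saved model), and both Reverb servers are assumed to be reachable.

collect(
    summary_dir='/tmp/collect',
    environment_name='HalfCheetah-v2',
    collect_policy=my_collect_policy,
    replay_buffer_server_address='localhost:8008',
    variable_container_server_address='localhost:8009')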
Example #4
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      builder: builders.ActorLearnerBuilder,
      networks: Any,
      policy_network: actors.FeedForwardPolicy,
      min_replay_size: int = 1000,
      samples_per_insert: float = 256.0,
      batch_size: int = 256,
      num_sgd_steps_per_step: int = 1,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
    ):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      builder: builder defining an RL algorithm to train.
      networks: network objects to be passed to the learner.
      policy_network: function that given an observation returns actions.
      min_replay_size: minimum replay size before updating.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      batch_size: batch size for updates.
      num_sgd_steps_per_step: how many SGD steps the learner does per 'step' call.
        Doing multiple SGD updates at once is beneficial for performance
        (especially to reduce TPU host-device transfer times), provided it does
        not hurt training, which needs to be verified empirically for each
        environment.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """
    # Create the replay server and grab its address.
    replay_tables = builder.make_replay_tables(environment_spec)
    replay_server = reverb.Server(replay_tables, port=None)
    replay_client = reverb.Client(f'localhost:{replay_server.port}')

    # Create actor, dataset, and learner for generating, storing, and consuming
    # data respectively.
    adder = builder.make_adder(replay_client)
    dataset = builder.make_dataset_iterator(replay_client)
    learner = builder.make_learner(networks=networks, dataset=dataset,
                                   replay_client=replay_client, counter=counter,
                                   logger=logger, checkpoint=checkpoint)
    actor = builder.make_actor(policy_network, adder, variable_source=learner)

    effective_batch_size = batch_size * num_sgd_steps_per_step
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(effective_batch_size, min_replay_size),
        observations_per_step=float(effective_batch_size) / samples_per_insert)

    # Save the replay so we don't garbage collect it.
    self._replay_server = replay_server
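
A hedged usage sketch: an agent built this way is typically driven by acme's
EnvironmentLoop; `environment` (matching environment_spec) and `agent` (an instance of
this class) are assumed to exist.

from acme.environment_loop import EnvironmentLoop

loop = EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)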
Example #5
def trainer_main_tf_dataset(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")
    # reverb dataset
    def _make_dataset(_):
        dataset = reverb.ReplayDataset(
            f"localhost:{PORT}",
            TABLE_NAME,
            max_in_flight_samples_per_worker=config["common"]["batch_size"],
            dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
            shapes=(
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
                tf.TensorShape([]),
                tf.TensorShape((4, 84, 84)),
                tf.TensorShape([]),
            ),
        )
        dataset = dataset.batch(config["common"]["batch_size"], drop_remainder=True)
        return dataset

    num_parallel_calls = 16
    prefetch_size = 4
    dataset = tf.data.Dataset.range(num_parallel_calls)
    dataset = dataset.interleave(
        map_func=_make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls,
        deterministic=False,
    )
    dataset = dataset.prefetch(prefetch_size)
    numpy_iter = dataset.as_numpy_iterator()

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]
    ts = 0
    while True:
        ts += 1

        info, data = next(numpy_iter)
        indices = info.key
        weights = info.probability
        weights = (weights / weights.min()) ** (-0.4)

        loss = trainer.step(data, weights=weights)
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )

        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
Example #6
 def _push_nested_data(self, server_address: Optional[Text] = None) -> None:
     # Create Python client.
     address = server_address or self._server_address
     client = reverb.Client(address)
     with client.writer(1) as writer:
         writer.append([
             np.array(0, dtype=np.int64),
             np.array([1, 1], dtype=np.float64),
             np.array([[2], [3]], dtype=np.int32)
         ])
         writer.create_item(reverb_variable_container.DEFAULT_TABLE, 1, 1.0)
     self.assertEqual(
         client.server_info()[
             reverb_variable_container.DEFAULT_TABLE].current_size, 1)
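
A hedged sketch of the kind of server this helper assumes is already running; the exact
table settings of the real test fixture may differ.

variable_table = reverb.Table(
    name=reverb_variable_container.DEFAULT_TABLE,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    max_size=1,  # assumed: a variable container only needs the latest value
    rate_limiter=reverb.rate_limiters.MinSize(1),
)
server = reverb.Server([variable_table])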
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    config = _CONFIG.value

    # Connect to reverb
    reverb_client = reverb.Client(FLAGS.reverb_address)

    puddles = arenas.get_arena(config.env.pw.arena)
    pw = puddle_world.PuddleWorld(puddles=puddles,
                                  goal_position=geometry.Point((1.0, 1.0)))
    dpw = pw_utils.DiscretizedPuddleWorld(pw, config.env.pw.num_bins)

    if FLAGS.eval:
        eval_worker(dpw, reverb_client, config)
    else:
        train_worker(dpw, reverb_client, config)
Example #8
    def __init__(
        self,
        actor_id,
        environment_module,
        environment_fn_name,
        environment_kwargs,
        network_module,
        network_fn_name,
        network_kwargs,
        adder_module,
        adder_fn_name,
        adder_kwargs,
        replay_server_address,
        variable_server_name,
        variable_server_address,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
    ):
        # Counter and Logger
        self._actor_id = actor_id
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            f'actor_{actor_id}')

        # Create the environment
        self._environment = getattr(environment_module,
                                    environment_fn_name)(**environment_kwargs)
        env_spec = acme.make_environment_spec(self._environment)

        # Create actor's network
        self._network = getattr(network_module,
                                network_fn_name)(**network_kwargs)
        tf2_utils.create_variables(self._network, [env_spec.observations])

        self._variables = tree.flatten(self._network.variables)
        self._policy = tf.function(self._network)

        # The adder is used to insert observations into replay.
        self._adder = getattr(adder_module, adder_fn_name)(
            reverb.Client(replay_server_address), **adder_kwargs)

        variable_client = reverb.TFClient(variable_server_address)
        self._variable_dataset = variable_client.dataset(
            table=variable_server_name,
            dtypes=[tf.float32 for _ in self._variables],
            shapes=[v.shape for v in self._variables])
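
A hedged sketch (assumed method name, not part of the original class) of how such an
actor could refresh its weights from the variable dataset built above:

    def _update_variables(self):
        sample = next(iter(self._variable_dataset.take(1)))
        for var, value in zip(self._variables, sample.data):
            var.assign(value)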
Example #9
  def _assert_nested_variable_in_server(self,
                                        server_address: Optional[Text] = None
                                       ) -> None:
    # Create Python client.
    address = server_address or self._server_address
    client = reverb.Client(address)
    self.assertEqual(
        client.server_info()[
            reverb_variable_container.DEFAULT_TABLE].current_size, 1)

    # Draw one sample from the server using the Python client.
    content = next(
        iter(client.sample(reverb_variable_container.DEFAULT_TABLE, 1)))[0].data
    # Internally in Reverb the data is stored in the form of flat numpy lists.
    self.assertLen(content, 3)
    self.assertAllEqual(content[0], np.array(0, dtype=np.int64))
    self.assertAllClose(content[1], np.array([1, 1], dtype=np.float64))
    self.assertAllEqual(content[2], np.array([[2], [3]], dtype=np.int32))
Example #10
    def __init__(self, env_name,
                 buffer_table_name, buffer_server_port, buffer_min_size,
                 n_steps=2,
                 data=None, make_sparse=False):
        # environments; their hyperparameters
        self._train_env = gym.make(env_name)
        self._eval_env = gym.make(env_name)
        self._n_outputs = self._train_env.action_space.n  # number of actions
        self._input_shape = self._train_env.observation_space.shape

        # data contains weights, masks, and a corresponding reward
        self._data = data
        self._is_sparse = make_sparse
        assert not (not data and make_sparse), "Making a sparse model requires data with weights and masks"

        # networks
        self._model = None
        self._target_model = None

        # fraction of actions sampled at random (epsilon-greedy exploration)
        self._epsilon = 0.1

        # hyperparameters for optimization
        self._optimizer = keras.optimizers.Adam(lr=1e-3)
        self._loss_fn = keras.losses.mean_squared_error

        # buffer; hyperparameters for a reward calculation
        self._table_name = buffer_table_name
        # a client used to store data on the Reverb server
        self._replay_memory_client = reverb.Client(f'localhost:{buffer_server_port}')
        # make the sample batch size equal to the minimum buffer size
        self._sample_batch_size = buffer_min_size
        self._n_steps = n_steps  # number of steps stored per item; should be at least 2
        # (for details see _collect_trajectories_from_episode())
        # initialize a dataset to be used to sample data from a server
        self._dataset = storage.initialize_dataset(buffer_server_port, buffer_table_name,
                                                   self._input_shape, self._sample_batch_size, self._n_steps)
        self._iterator = iter(self._dataset)
        self._discount_rate = tf.constant(0.95, dtype=tf.float32)
        self._items_sampled = 0
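
A hedged sketch (assumed helper, not part of the original class) of an epsilon-greedy
action choice using the fields initialized above, assuming self._model has been built:

    def _epsilon_greedy_action(self, obs):
        if np.random.rand() < self._epsilon:
            return np.random.randint(self._n_outputs)  # random exploratory action
        q_values = self._model(obs[np.newaxis, ...], training=False)
        return int(tf.argmax(q_values[0]).numpy())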
Example #11
def worker_main(perwez_url, config, idx):
    reverb_client = reverb.Client(f"localhost:{PORT}")
    reverb_writer = reverb_client.writer(1)
    weight_recv = perwez.RecvSocket(perwez_url, "weight", broadcast=True)

    batch_size = config["common"]["batch_size"]
    num_workers = config["common"]["num_workers"]
    eps = 0.4 ** (1 + (idx / (num_workers - 1)) * 7)
    solver = get_solver(config, device="cpu")
    log_flag = idx >= num_workers + (-num_workers // 3)  # aligned with ray
    worker = get_worker(
        config,
        exploration=eps,
        solver=solver,
        logger=getLogger(f"worker{idx}") if log_flag else None,
    )

    while True:
        # load weights
        if not weight_recv.empty():
            worker.load_weights(io.BytesIO(weight_recv.recv()))

        # step
        data = worker.step_batch(batch_size)
        loss = worker.solver.calc_loss(data)
        # format
        s0, a, r, s1, done = data
        s0 = np.asarray(s0, dtype="f4")
        a = np.asarray(a, dtype="i8")
        r = np.asarray(r, dtype="f4")
        s1 = np.asarray(s1, dtype="f4")
        done = np.asarray(done, dtype="f4")
        loss = np.asarray(loss, dtype="f4")
        # upload
        for i, _ in enumerate(s0):
            reverb_writer.append([s0[i], a[i], r[i], s1[i], done[i]])
            reverb_writer.create_item(
                table=TABLE_NAME, num_timesteps=1, priority=loss[i]
            )
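
A hedged sketch of the table this worker (and the matching trainer examples) assumes
exists on the server at PORT; the sampler must be Prioritized to make the priority and
probability fields meaningful, but the exact constants are guesses.

server = reverb.Server(
    tables=[
        reverb.Table(
            name=TABLE_NAME,
            sampler=reverb.selectors.Prioritized(0.6),  # priority exponent: assumed
            remover=reverb.selectors.Fifo(),
            max_size=1_000_000,  # assumed
            rate_limiter=reverb.rate_limiters.MinSize(1_000),  # assumed
        )
    ],
    port=PORT,
)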
Example #12
    def learner(self, replay: reverb.Client, counter: counting.Counter):
        """The Learning part of the agent."""

        # Create the networks.
        network = self._network_factory(self._env_spec.actions)
        target_network = copy.deepcopy(network)

        tf2_utils.create_variables(network, [self._env_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [self._env_spec.observations])

        # The dataset object to learn from.
        replay_client = reverb.Client(replay.server_address)
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        logger = loggers.make_default_logger('learner',
                                             steps_key='learner_steps')

        # Return the learning agent.
        counter = counting.Counter(counter, 'learner')

        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=self._discount,
            importance_sampling_exponent=self._importance_sampling_exponent,
            learning_rate=self._learning_rate,
            target_update_period=self._target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            counter=counter,
            logger=logger)
        return tf2_savers.CheckpointingRunner(learner,
                                              subdirectory='dqn_learner',
                                              time_delta_minutes=60)
Example #13
def trainer_main_np_client(perwez_url, config):
    weight_send = perwez.SendSocket(perwez_url, "weight", broadcast=True)
    # init reverb
    reverb_client = reverb.Client(f"localhost:{PORT}")

    trainer = get_trainer(config)
    sync_weights_interval = config["common"]["sync_weights_interval"]
    ts = 0
    while True:
        ts += 1

        samples = reverb_client.sample(TABLE_NAME, config["common"]["batch_size"])
        samples = list(samples)
        data, indices, weights = _reverb_samples_to_ndarray(samples)
        weights = (weights / weights.min()) ** (-0.4)

        loss = trainer.step(data, weights=weights)
        reverb_client.mutate_priorities(
            TABLE_NAME, updates=dict(zip(np.asarray(indices), np.asarray(loss)))
        )

        if ts % sync_weights_interval == 0:
            weight_send.send(trainer.save_weights().getbuffer())
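
A hedged sketch of the _reverb_samples_to_ndarray helper referenced above (the real
implementation may differ); it assumes each sampled item holds a single timestep, as
written by the worker example with num_timesteps=1.

def _reverb_samples_to_ndarray(samples):
    timesteps = [s[0] for s in samples]  # one timestep per sampled item (assumed)
    data = [np.stack(column) for column in zip(*(t.data for t in timesteps))]
    indices = np.asarray([t.info.key for t in timesteps])
    weights = np.asarray([t.info.probability for t in timesteps])
    return data, indices, weights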
Example #14
def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
    """Creates a single process queue from an environment spec and extra_spec."""
    signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
    queue = reverb.Table.queue(name=replay_table_name,
                               max_size=max_queue_size,
                               signature=signature)
    server = reverb.Server([queue], port=None)
    can_sample = lambda: queue.can_sample(batch_size)

    # Component to add things into replay.
    address = f'localhost:{server.port}'
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    # We don't use datasets.make_reverb_dataset() here to avoid interleaving
    # and prefetching, which don't work well with the can_sample() check on update.
    dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=replay_table_name,
        max_in_flight_samples_per_worker=1,
        sequence_length=sequence_length,
        emit_timesteps=False)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    data_iterator = dataset.as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, can_sample=can_sample)
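
A hedged usage sketch: the returned ReverbReplay exposes the queue's can_sample() check,
so a consumer can drain complete batches only when they are available. The spec,
extra_spec, and constants below are assumptions.

reverb_replay = make_reverb_online_queue(
    environment_spec=spec,
    extra_spec=extra_spec,
    max_queue_size=1000,
    sequence_length=20,
    sequence_period=20,
    batch_size=16)
while reverb_replay.can_sample():
    batch = next(reverb_replay.data_iterator)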
Example #15
 def build(self):
     """Creates reverb server, client and dataset."""
     self._reverb_server = reverb.Server(
         tables=[
             reverb.Table(
                 name="replay_buffer",
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=self._max_replay_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1),
                 signature=self._signature,
             ),
         ],
         port=None,
     )
     self._reverb_client = reverb.Client(f"localhost:{self._reverb_server.port}")
     self._reverb_dataset = reverb.TrajectoryDataset.from_table_signature(
         server_address=f"localhost:{self._reverb_server.port}",
         table="replay_buffer",
         max_in_flight_samples_per_worker=2 * self._batch_size,
     )
     self._batched_dataset = self._reverb_dataset.batch(
         self._batch_size, drop_remainder=True
     ).as_numpy_iterator()
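
A hedged consumption sketch (names reused from build(); not part of the original class):
each element of the batched numpy iterator is a ReplaySample whose data follows
self._signature with a leading batch dimension.

     sample = next(self._batched_dataset)
     batch = sample.data  # nested structure matching self._signature, batched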
Example #16
def main(_):
  environment = fakes.ContinuousEnvironment(action_dim=8,
                                            observation_dim=87,
                                            episode_length=10000000)
  spec = specs.make_environment_spec(environment)
  replay_tables = make_replay_tables(spec)
  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')
  adder = make_adder(replay_client)

  timestep = environment.reset()
  adder.add_first(timestep)
  # TODO(raveman): Consider also filling the table to say 1M (too slow).
  for steps in range(10000):
    if steps % 1000 == 0:
      logging.info('Processed %s steps', steps)
    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
    next_timestep = environment.step(action)
    adder.add(action, next_timestep, extras=())

  for batch_size in [256, 256 * 8, 256 * 64]:
    for prefetch_size in [0, 1, 4]:
      print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
      ds = datasets.make_reverb_dataset(
          table='default',
          server_address=replay_client.server_address,
          batch_size=batch_size,
          prefetch_size=prefetch_size,
      )
      it = ds.as_numpy_iterator()

      for iteration in range(3):
        t = time.time()
        for _ in range(1000):
          _ = next(it)
        print(f'Iteration {iteration} finished in {time.time() - t}s')
Example #17
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: iter,
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        # replay table
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # demonstration table.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # launch server
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Component to add things into replay and demo
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # play demonstrations and write
        # exhaust the generator
        # TODO: MAX REPLAY SIZE
        _prev_action = 1  # this has to come from spec
        _add_first = True
        # include this to make datasets equivalent
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # reset to new episode
            if ts.last():
                _prev_action = None
                _add_first = True

        # replay dataset
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            # memory perf improvement attempt: https://github.com/deepmind/acme/issues/33
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # demonstration dataset
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #18
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        prior_network: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        prior_optimizer: Optional[snt.Optimizer] = None,
        distillation_cost: Optional[float] = 1e-3,
        entropy_regularizer_cost: Optional[float] = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        sequence_length: int = 10,
        sigma: float = 0.3,
        replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      prior_network: an optional `behavior prior` to regularize against.
      policy_optimizer: optimizer for the policy network updates.
      critic_optimizer: optimizer for the critic network updates.
      prior_optimizer: optimizer for the prior network updates.
      distillation_cost: a multiplier to be used when adding distillation
        against the prior to the losses.
      entropy_regularizer_cost: a multiplier used for per state sample based
        entropy added to the actor loss.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      sequence_length: number of timesteps to store for each trajectory.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      replay_table_name: string indicating what name to give the replay table.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """
        # Create the Builder object which will internally create agent components.
        builder = SVG0Builder(
            # TODO(mwhoffman): pass the config dataclass in directly.
            # TODO(mwhoffman): use the limiter rather than the workaround below.
            # Right now this modifies min_replay_size and samples_per_insert so that
            # they are not controlled by a limiter and are instead handled by the
            # Agent base class (the above TODO directly references this behavior).
            SVG0Config(
                discount=discount,
                batch_size=batch_size,
                prefetch_size=prefetch_size,
                target_update_period=target_update_period,
                policy_optimizer=policy_optimizer,
                critic_optimizer=critic_optimizer,
                prior_optimizer=prior_optimizer,
                distillation_cost=distillation_cost,
                entropy_regularizer_cost=entropy_regularizer_cost,
                min_replay_size=1,  # Let the Agent class handle this.
                max_replay_size=max_replay_size,
                samples_per_insert=None,  # Let the Agent class handle this.
                sequence_length=sequence_length,
                sigma=sigma,
                replay_table_name=replay_table_name,
            ))

        # TODO(mwhoffman): pass the network dataclass in directly.
        online_networks = SVG0Networks(
            policy_network=policy_network,
            critic_network=critic_network,
            prior_network=prior_network,
        )

        # Target networks are just a copy of the online networks.
        target_networks = copy.deepcopy(online_networks)

        # Initialize the networks.
        online_networks.init(environment_spec)
        target_networks.init(environment_spec)

        # TODO(mwhoffman): either make this Dataclass or pass only one struct.
        # The network struct passed to make_learner is just a tuple for the
        # time-being (for backwards compatibility).
        networks = (online_networks, target_networks)

        # Create the behavior policy.
        policy_network = online_networks.make_policy()

        # Create the replay server and grab its address.
        replay_tables = builder.make_replay_tables(environment_spec,
                                                   sequence_length)
        replay_server = reverb.Server(replay_tables, port=None)
        replay_client = reverb.Client(f'localhost:{replay_server.port}')

        # Create actor, dataset, and learner for generating, storing, and consuming
        # data respectively.
        adder = builder.make_adder(replay_client)
        actor = builder.make_actor(policy_network, adder)
        dataset = builder.make_dataset_iterator(replay_client)
        learner = builder.make_learner(networks, dataset, counter, logger,
                                       checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)

        # Save the replay so we don't garbage collect it.
        self._replay_server = replay_server
Example #19
    """Creates a single-process replay infrastructure from an environment spec."""
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_replay_size),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec=environment_spec))
    server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{server.port}'
    client = reverb.Client(address)
    adder = adders.NStepTransitionAdder(client=client,
                                        n_step=n_step,
                                        discount=discount)

    # The dataset provides an interface to sample from replay.
    data_iterator = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
        transition_adder=True,
    ).as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, client)
Example #20
def _reverb_client(port):
    return reverb.Client('localhost:{}'.format(port))
Example #21
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 encoder_network: types.TensorTransformation = tf.identity,
                 entropy_coeff: float = 0.01,
                 target_update_period: int = 0,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 policy_learn_rate: float = 3e-4,
                 critic_learn_rate: float = 5e-4,
                 prefetch_size: int = 4,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 250000,
                 samples_per_insert: float = 64.0,
                 n_step: int = 5,
                 sigma: float = 0.5,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      encoder_network: optional network to transform the observations before
        they are fed into any other network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.

        dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
        extra_spec = {
            'logP': tf.ones(shape=(1), dtype=tf.float32),
            'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec, extras_spec=extra_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(table=replay_table_name,
                                               server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Make sure observation network is a Sonnet Module.
        observation_network = model.MDPNormalization(environment_spec,
                                                     encoder_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations

        # Create the behavior policy.
        sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
        self._behavior_network = model.PolicyValueBehaviorNet(
            snt.Sequential([observation_network, policy_network]),
            sampling_head)

        # Create variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

        # Create the actor which defines how we take actions.
        actor = model.SACFeedForwardActor(self._behavior_network, adder)

        if target_update_period > 0:
            target_policy_network = copy.deepcopy(policy_network)
            target_critic_network = copy.deepcopy(critic_network)
            target_observation_network = copy.deepcopy(observation_network)

            tf2_utils.create_variables(target_policy_network, [emb_spec])
            tf2_utils.create_variables(target_critic_network,
                                       [emb_spec, act_spec])
            tf2_utils.create_variables(target_observation_network, [obs_spec])
        else:
            target_policy_network = policy_network
            target_critic_network = critic_network
            target_observation_network = observation_network

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
        critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

        # The learner updates the parameters (and initializes them).
        learner = learning.SACLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            sampling_head=sampling_head,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            target_update_period=target_update_period,
            learning_rate=policy_learn_rate,
            clipping=clipping,
            entropy_coeff=entropy_coeff,
            discount=discount,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example #22
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Remove batch dimensions.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
Example #23
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 observation_network: types.TensorTransformation = tf.identity,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 prefetch_size: int = 4,
                 target_update_period: int = 100,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0,
                 n_step: int = 5,
                 sigma: float = 0.3,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.ClippedGaussian(sigma),
            networks.ClipToSpec(act_spec),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(behavior_network, adder=adder)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.DDPGLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            target_update_period=target_update_period,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example #24
 def __init__(self, variable_server_name, variable_server_address):
   self._variable_server_name = variable_server_name
   self._variable_client = reverb.Client(variable_server_address)
Example #25
# Initialize Reverb Server
server = reverb.Server(tables=[
    reverb.Table(name='ReplayBuffer',
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1)),
    reverb.Table(name='PrioritizedReplayBuffer',
                 sampler=reverb.selectors.Prioritized(alpha),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1))
])

client = reverb.Client(f"localhost:{server.port}")
tf_client = reverb.TFClient(f"localhost:{server.port}")


# Helper Function
def env(n):
    e = {
        "obs": np.ones((n, obs_shape)),
        "act": np.zeros((n, act_shape)),
        "next_obs": np.ones((n, obs_shape)),
        "rew": np.zeros(n),
        "done": np.zeros(n)
    }
    return e
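
A hedged sketch (not part of the original snippet) of pushing the fake transitions above
into both tables with the plain client; client.insert takes one flat step plus a
per-table priority mapping.

batch = env(8)
for i in range(8):
    step = [batch["obs"][i], batch["act"][i], batch["rew"][i],
            batch["next_obs"][i], batch["done"][i]]
    client.insert(step, priorities={"ReplayBuffer": 1.0,
                                    "PrioritizedReplayBuffer": 1.0})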

Example #26
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        demonstration_dataset: tf.data.Dataset,
        demonstration_ratio: float,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      demonstration_dataset: tf.data.Dataset producing (timestep, action)
        tuples containing full episodes.
      demonstration_ratio: Ratio of transitions coming from demonstrations.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating. This and all
        following arguments are related to dataset construction and will be
        ignored if a dataset argument is passed.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            transition_adder=True)

        # Combine with demonstration dataset.
        transition = functools.partial(_n_step_transition_from_episode,
                                       n_step=n_step,
                                       discount=discount)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(prefetch_size)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = dqn.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
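
The demonstration mixing above relies on `tf.data.experimental.sample_from_datasets`, which interleaves several datasets according to the given weights. A minimal standalone sketch with toy data (not the agent's real transition structure):

import tensorflow as tf

demonstration_ratio = 0.25  # illustrative value

# Zeros stand in for replay samples, ones for demonstration samples.
replay_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([100])).repeat()
demo_ds = tf.data.Dataset.from_tensor_slices(tf.ones([100])).repeat()

mixed = tf.data.experimental.sample_from_datasets(
    [replay_ds, demo_ds],
    weights=[1.0 - demonstration_ratio, demonstration_ratio])

# The empirical fraction of demonstration elements should be close to 0.25.
frac_demo = sum(float(x) for x in mixed.take(2000)) / 2000.0
print(round(frac_demo, 2))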
Exemple #27
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        observation_network: types.TensorTransformation = tf.identity,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_policy_update_period: int = 100,
        target_critic_update_period: int = 100,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        policy_loss_module: snt.Module = None,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        n_step: int = 5,
        num_samples: int = 20,
        clipping: bool = True,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint: bool = True,
        replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_policy_update_period: number of updates to perform before updating
        the target policy network.
      target_critic_update_period: number of updates to perform before updating
        the target critic network.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      policy_loss_module: configured MPO loss function for the policy
        optimization; defaults to sensible values on the control suite.
        See `acme/tf/losses/mpo.py` for more details.
      policy_optimizer: optimizer to be used on the policy.
      critic_optimizer: optimizer to be used on the critic.
      n_step: number of steps to squash into a single transition.
      num_samples: number of actions to sample when doing a Monte Carlo
        integration with respect to the policy.
      clipping: whether to clip gradients by global norm.
      logger: logging object used to write to logs.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """

        # Create a replay server to add data to.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.MPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
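
A recurring pattern in these agents is creating target networks with `copy.deepcopy` before any variables exist, then building both copies so each owns an independent set of weights. A minimal Sonnet sketch of that idea (plain MLP instead of the agent's actual networks):

import copy

import sonnet as snt
import tensorflow as tf

online = snt.nets.MLP([64, 4])
target = copy.deepcopy(online)   # same architecture, separate module

# Calling the modules once creates their variables.
dummy_obs = tf.zeros([1, 8])
online(dummy_obs)
target(dummy_obs)

# The copies hold distinct variables, so the learner can update the online
# network every step and copy it into the target only periodically.
assert all(a is not b for a, b in zip(online.variables, target.variables))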
Exemple #28
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        cql_alpha: float = 1.,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint_subpath: str = '~/acme/',
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating. This and all
        following arguments are related to dataset construction and will be
        ignored if a dataset argument is passed.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      cql_alpha: scaling factor for the CQL regularization term.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint_subpath: directory for the checkpoint.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = CQLLearner(
            network=network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            cql_alpha=cql_alpha,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            counter=counter,
            checkpoint_subpath=checkpoint_subpath)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
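
The `reverb.selectors.Prioritized(priority_exponent)` selector used above samples item i with probability proportional to priority_i ** priority_exponent. A small NumPy sketch of that relationship (the priorities themselves are made up):

import numpy as np

priority_exponent = 0.6                        # default used above
priorities = np.array([1.0, 2.0, 4.0, 8.0])    # hypothetical item priorities

scaled = priorities ** priority_exponent
probs = scaled / scaled.sum()
print(probs.round(3))  # higher-priority items are favoured, but sub-linearly
                       # because the exponent is below 1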
Exemple #29
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(seed),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
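
The JAX policy above wraps the Q-network output in `rlax.epsilon_greedy`. A small standalone sketch of that policy head with toy Q-values and an illustrative epsilon (the agent above defaults to epsilon=0., i.e. fully greedy):

import jax
import jax.numpy as jnp
import rlax

epsilon = 0.05                                   # illustrative value
q_values = jnp.array([0.1, 0.5, 0.2, 0.9])       # toy action values

dist = rlax.epsilon_greedy(epsilon)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
actions = [int(dist.sample(k, q_values)) for k in keys]
print(actions)  # mostly 3 (the argmax), occasionally a random action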
Exemple #30
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.01,
        epsilon_schedule_timesteps: float = 20000,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        if store_lstm_state:
            extra_spec = {
                'core_state':
                tf2_utils.squeeze_batch_dim(network.initial_state(1)),
            }
        else:
            extra_spec = ()

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        self._adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = make_reverb_dataset(server_address=address,
                                      batch_size=batch_size,
                                      prefetch_size=prefetch_size,
                                      sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._saver = tf2_savers.Saver(learner.state)

        policy_network = snt.DeepRNN([
            network,
            EpsilonGreedyExploration(
                epsilon_init=epsilon_init,
                epsilon_final=epsilon_final,
                epsilon_schedule_timesteps=epsilon_schedule_timesteps)
        ])
        actor = actors.RecurrentActor(policy_network,
                                      self._adder,
                                      store_recurrent_state=store_lstm_state)

        max_Q_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
        ])
        self._deterministic_actor = actors.RecurrentActor(
            max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
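
For this recurrent agent the replay items are overlapping sequences, so the sequence length and the acting/learning throttle are computed differently from the feed-forward agents above. A short worked sketch (burn_in_length, trace_length and replay_period have no defaults, so the values here are assumptions):

# Assumed sequence parameters for illustration only.
burn_in_length = 10
trace_length = 40
replay_period = 20

# Defaults from the constructor above.
batch_size = 32
min_replay_size = 1000
samples_per_insert = 32.0

# Each stored item holds the burn-in steps, the training trace, and one
# extra step for bootstrapping.
sequence_length = burn_in_length + trace_length + 1                    # 51

# Throttling is expressed per replay period rather than per observation.
min_observations = replay_period * max(batch_size, min_replay_size)    # 20000
observations_per_step = float(replay_period * batch_size) / samples_per_insert  # 20.0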