Example #1
    def __init__(self, name="ReverbUniformReplayBuffer", reverb_server=None):
        super().__init__(name=name)
        self.device = Params.DEVICE

        with tf.device(self.device), self.name_scope:

            self.buffer_size = tf.cast(Params.BUFFER_SIZE, tf.int64)
            self.batch_size = tf.cast(Params.MINIBATCH_SIZE, tf.int64)
            self.batch_size_float = tf.cast(Params.MINIBATCH_SIZE, tf.float64)
            self.sequence_length = tf.cast(Params.N_STEP_RETURNS, tf.int64)

            # Initialize the reverb server
            if not reverb_server:
                self.reverb_server = reverb.Server(tables=[
                    reverb.Table(
                        name=Params.BUFFER_TYPE,
                        sampler=reverb.selectors.Uniform(),
                        remover=reverb.selectors.Fifo(),
                        max_size=self.buffer_size,
                        rate_limiter=reverb.rate_limiters.MinSize(
                            self.batch_size),
                    )
                ], )
            else:
                self.reverb_server = reverb_server

            dataset = reverb.ReplayDataset(
                server_address=f'localhost:{self.reverb_server.port}',
                table=Params.BUFFER_TYPE,
                max_in_flight_samples_per_worker=2 * self.batch_size,
                dtypes=Params.BUFFER_DATA_SPEC_DTYPES,
                shapes=Params.BUFFER_DATA_SPEC_SHAPES,
            )

            dataset = dataset.map(
                map_func=reduce_trajectory,
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
                deterministic=True,
            )
            dataset = dataset.batch(self.batch_size)
            dataset = dataset.prefetch(5)
            self.iterator = dataset.__iter__()
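A minimal usage sketch for the example above, assuming Params, reduce_trajectory and the enclosing ReverbUniformReplayBuffer class are defined as in the original module; the sampling call is only illustrative:

buffer = ReverbUniformReplayBuffer()

# A separate process (e.g. an actor) can write experience to the same table.
client = reverb.Client(f"localhost:{buffer.reverb_server.port}")

# Once the table holds at least MINIBATCH_SIZE items, pull one training batch.
batch = next(buffer.iterator)  # batched, reduce_trajectory-mapped ReplaySample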
Example #2
def _create_server(
    table: Text = reverb_variable_container.DEFAULT_TABLE,
    max_size: int = 1,
    signature: types.NestedTensorSpec = (tf.TensorSpec((), tf.int64), {
        'var1': (tf.TensorSpec((2), tf.float64), ),
        'var2':
        tf.TensorSpec((2, 1), tf.int32)
    })
) -> Tuple[reverb.Server, Text]:
    server = reverb.Server(tables=[
        reverb.Table(name=table,
                     sampler=reverb.selectors.Uniform(),
                     remover=reverb.selectors.Fifo(),
                     rate_limiter=reverb.rate_limiters.MinSize(1),
                     max_size=max_size,
                     max_times_sampled=0,
                     signature=signature)
    ],
                           port=portpicker.pick_unused_port())
    return server, 'localhost:{}'.format(server.port)
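A hedged sketch of how the helper above might be exercised, assuming the surrounding test module's imports (reverb, tf, reverb_variable_container); the inserted item mirrors the nested signature exactly:

server, address = _create_server()
client = reverb.Client(address)

# Insert one item whose nested structure matches the table signature.
client.insert(
    (tf.constant(1, tf.int64),
     {'var1': (tf.constant([0.5, 1.5], tf.float64),),
      'var2': tf.constant([[2], [3]], tf.int32)}),
    priorities={reverb_variable_container.DEFAULT_TABLE: 1.0})

# Read it back; client.sample returns a generator over sampled items.
item = next(client.sample(reverb_variable_container.DEFAULT_TABLE, num_samples=1))
server.stop()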
Example #3
 def __init__(self,
              num_tables: int = 1,
              min_size: int = 64,
              max_size: int = 100000,
              checkpointer=None):
     self._min_size = min_size
     self._table_names = [f"uniform_table_{i}" for i in range(num_tables)]
     self._server = reverb.Server(
         tables=[
             reverb.Table(
                 name=self._table_names[i],
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=int(max_size),
                 rate_limiter=reverb.rate_limiters.MinSize(min_size),
             ) for i in range(num_tables)
         ],
         # Sets the port to None to make the server pick one automatically.
         port=None,
         checkpointer=checkpointer)
Example #4
def create_reverb_server_for_replay_buffer_and_variable_container(
        collect_policy, train_step, replay_buffer_capacity, port):
    """Sets up one reverb server for replay buffer and variable container."""
    # Create the signature for the variable container holding the policy weights.
    variables = {
        reverb_variable_container.POLICY_KEY: collect_policy.variables(),
        reverb_variable_container.TRAIN_STEP_KEY: train_step
    }
    variable_container_signature = tf.nest.map_structure(
        lambda variable: tf.TensorSpec(variable.shape, dtype=variable.dtype),
        variables)

    # Create the signature for the replay buffer holding observed experience.
    replay_buffer_signature = tensor_spec.from_spec(
        collect_policy.collect_data_spec)

    # Create and start the replay buffer and variable container server.
    server = reverb.Server(
        tables=[
            reverb.Table(  # Replay buffer storing experience.
                name=reverb_replay_buffer.DEFAULT_TABLE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                # TODO(b/159073060): Set rate limiter for SAC properly.
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=replay_buffer_capacity,
                max_times_sampled=0,
                signature=replay_buffer_signature,
            ),
            reverb.Table(  # Variable container storing policy parameters.
                name=reverb_variable_container.DEFAULT_TABLE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=1,
                max_times_sampled=0,
                signature=variable_container_signature,
            ),
        ],
        port=port)
    return server
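A hedged call-site sketch for the helper above, assuming a TF-Agents agent and train_utils are set up elsewhere; agent is a placeholder name:

train_step = train_utils.create_train_step()
server = create_reverb_server_for_replay_buffer_and_variable_container(
    collect_policy=agent.collect_policy,
    train_step=train_step,
    replay_buffer_capacity=100_000,
    port=None)  # None lets Reverb pick an unused port.
server.wait()  # Block, e.g. when running as a standalone replay job.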
Example #5
    def test_uniform_table(self):
        table_name = 'test_uniform_table'
        queue_table = reverb.Table(
            table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=1000,
            rate_limiter=reverb.rate_limiters.MinSize(3))
        reverb_server = reverb.Server([queue_table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            local_server=reverb_server,
            sequence_length=1,
            dataset_buffer_size=1)

        with replay.py_client.trajectory_writer(
                num_keep_alive_refs=1) as writer:
            for i in range(3):
                writer.append(i)
                trajectory = writer.history[-1:]
                writer.create_item(table_name,
                                   trajectory=trajectory,
                                   priority=1)

        dataset = replay.as_dataset(sample_batch_size=1,
                                    num_steps=None,
                                    num_parallel_calls=1)

        iterator = iter(dataset)
        counts = [0] * 3
        for i in range(1000):
            item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
            counts[int(item_0)] += 1

        # Compare against 200 to avoid flakiness.
        self.assertGreater(counts[0], 200)
        self.assertGreater(counts[1], 200)
        self.assertGreater(counts[2], 200)
Example #6
def test_reverb(data):
    import reverb
    import tensorflow as tf

    print("TEST REVERB")
    print("initializing...")
    reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="req",
                sampler=reverb.selectors.Prioritized(0.6),
                remover=reverb.selectors.Fifo(),
                max_size=CAPACITY,
                rate_limiter=reverb.rate_limiters.MinSize(100),
            )
        ],
        port=15867,
    )
    client = reverb_server.in_process_client()
    for i in range(CAPACITY):
        client.insert([col[i] for col in data], {"req": np.random.rand()})
    dataset = reverb.ReplayDataset(
        server_address="localhost:15867",
        table="req",
        dtypes=(tf.float64, tf.float64, tf.float64, tf.float64),
        shapes=(
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([]),
            tf.TensorShape([]),
        ),
        max_in_flight_samples_per_worker=10,
    )
    dataset = dataset.batch(BATCH_SIZE)
    print("ready")
    t0 = time.perf_counter()
    for sample in dataset.take(TEST_CNT):
        pass
    t1 = time.perf_counter()
    print(TEST_CNT, t1 - t0)
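The benchmark above depends on module-level imports and constants that are not shown; the following sketch fills them in with purely hypothetical values so the snippet runs end to end:

import time

import numpy as np

CAPACITY = 1000   # hypothetical number of items inserted into the table
BATCH_SIZE = 32   # hypothetical batch size for the ReplayDataset
TEST_CNT = 100    # hypothetical number of batches to time

# Four columns matching the dtypes/shapes declared in ReplayDataset above.
frames = np.random.rand(CAPACITY, 1, 84, 84)
scalars = np.random.rand(CAPACITY)
test_reverb((frames, frames.copy(), scalars, scalars.copy()))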
Example #7
    def test_prioritized_table_max_sample(self):
        table_name = 'test_prioritized_table'
        table = reverb.Table(table_name,
                             sampler=reverb.selectors.Prioritized(1.0),
                             remover=reverb.selectors.Fifo(),
                             max_times_sampled=10,
                             rate_limiter=reverb.rate_limiters.MinSize(1),
                             max_size=3)
        reverb_server = reverb.Server([table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            sequence_length=1,
            local_server=reverb_server,
            dataset_buffer_size=1)

        with replay.py_client.trajectory_writer(1) as writer:
            for i in range(3):
                writer.append(i)
                writer.create_item(table_name,
                                   trajectory=writer.history[-1:],
                                   priority=i)

        dataset = replay.as_dataset(sample_batch_size=3, num_parallel_calls=3)

        self.assertTrue(table.can_sample(3))
        iterator = iter(dataset)
        counts = [0] * 3
        for i in range(10):
            item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x3.
            for item in item_0:
                counts[int(item)] += 1
        self.assertFalse(table.can_sample(3))

        # Same number of counts due to limit on max_times_sampled
        self.assertEqual(counts[0], 10)  # priority 0
        self.assertEqual(counts[1], 10)  # priority 1
        self.assertEqual(counts[2], 10)  # priority 2
Example #8
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    eval_table_size = _CONFIG.value.num_eval_points

    # TODO(joshgreaves): Choose an appropriate rate_limiter, max_size.
    server = reverb.Server(tables=[
        reverb.Table(name='successor_table',
                     sampler=reverb.selectors.Uniform(),
                     remover=reverb.selectors.Fifo(),
                     max_size=1_000_000,
                     rate_limiter=reverb.rate_limiters.MinSize(20_000)),
        reverb.Table(
            name='eval_table',
            sampler=reverb.selectors.Fifo(),
            remover=reverb.selectors.Fifo(),
            max_size=eval_table_size,
            rate_limiter=reverb.rate_limiters.MinSize(eval_table_size)),
    ],
                           port=FLAGS.port)
    server.wait()
Example #9
def make_reverb_online_queue(
    environment_spec: specs.EnvironmentSpec,
    extra_spec: Dict[str, Any],
    max_queue_size: int,
    sequence_length: int,
    sequence_period: int,
    batch_size: int,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
) -> ReverbReplay:
    """Creates a single process queue from an environment spec and extra_spec."""
    signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
    queue = reverb.Table.queue(name=replay_table_name,
                               max_size=max_queue_size,
                               signature=signature)
    server = reverb.Server([queue], port=None)
    can_sample = lambda: queue.can_sample(batch_size)

    # Component to add things into replay.
    address = f'localhost:{server.port}'
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    # We don't use datasets.make_reverb_dataset() here to avoid interleaving
    # and prefetching, which don't work well with the can_sample() check on update.
    dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=replay_table_name,
        max_in_flight_samples_per_worker=1,
        sequence_length=sequence_length,
        emit_timesteps=False)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    data_iterator = dataset.as_numpy_iterator()
    return ReverbReplay(server, adder, data_iterator, can_sample=can_sample)
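A usage sketch for the queue factory above, assuming an Acme-compatible environment; sampling blocks until batch_size full sequences have been added, hence the can_sample() guard:

# Hypothetical usage; `environment` is an Acme-compatible environment.
environment_spec = specs.make_environment_spec(environment)
reverb_replay = make_reverb_online_queue(
    environment_spec=environment_spec,
    extra_spec={},
    max_queue_size=1000,
    sequence_length=20,
    sequence_period=20,
    batch_size=16)

# An actor writes via reverb_replay.adder; the learner polls can_sample().
if reverb_replay.can_sample():
    batch = next(reverb_replay.data_iterator)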
Example #10
    def test_prioritized_table(self):
        table_name = 'test_prioritized_table'
        queue_table = reverb.Table(
            table_name,
            sampler=reverb.selectors.Prioritized(1.0),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
            max_size=3)
        reverb_server = reverb.Server([queue_table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            sequence_length=1,
            local_server=reverb_server,
            dataset_buffer_size=1)

        with replay.py_client.writer(max_sequence_length=1) as writer:
            for i in range(3):
                writer.append(i)
                writer.create_item(table=table_name,
                                   num_timesteps=1,
                                   priority=i)

        dataset = replay.as_dataset(sample_batch_size=1,
                                    num_steps=None,
                                    num_parallel_calls=None)

        iterator = iter(dataset)
        counts = [0] * 3
        for i in range(1000):
            item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
            counts[int(item_0)] += 1

        self.assertEqual(counts[0], 0)  # priority 0
        self.assertGreater(counts[1], 250)  # priority 1
        self.assertGreater(counts[2], 600)  # priority 2
Example #11
def main(_):
  environment = fakes.ContinuousEnvironment(action_dim=8,
                                            observation_dim=87,
                                            episode_length=10000000)
  spec = specs.make_environment_spec(environment)
  replay_tables = make_replay_tables(spec)
  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')
  adder = make_adder(replay_client)

  timestep = environment.reset()
  adder.add_first(timestep)
  # TODO(raveman): Consider also filling the table to say 1M (too slow).
  for steps in range(10000):
    if steps % 1000 == 0:
      logging.info('Processed %s steps', steps)
    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
    next_timestep = environment.step(action)
    adder.add(action, next_timestep, extras=())

  for batch_size in [256, 256 * 8, 256 * 64]:
    for prefetch_size in [0, 1, 4]:
      print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
      ds = datasets.make_reverb_dataset(
          table='default',
          server_address=replay_client.server_address,
          batch_size=batch_size,
          prefetch_size=prefetch_size,
      )
      it = ds.as_numpy_iterator()

      for iteration in range(3):
        t = time.time()
        for _ in range(1000):
          _ = next(it)
        print(f'Iteration {iteration} finished in {time.time() - t}s')
Example #12
 def build(self):
     """Creates reverb server, client and dataset."""
     self._reverb_server = reverb.Server(
         tables=[
             reverb.Table(
                 name="replay_buffer",
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=self._max_replay_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1),
                 signature=self._signature,
             ),
         ],
         port=None,
     )
     self._reverb_client = reverb.Client(f"localhost:{self._reverb_server.port}")
     self._reverb_dataset = reverb.TrajectoryDataset.from_table_signature(
         server_address=f"localhost:{self._reverb_server.port}",
         table="replay_buffer",
         max_in_flight_samples_per_worker=2 * self._batch_size,
     )
     self._batched_dataset = self._reverb_dataset.batch(
         self._batch_size, drop_remainder=True
     ).as_numpy_iterator()
Example #13
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 observation_network: types.TensorTransformation = tf.identity,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 prefetch_size: int = 4,
                 target_update_period: int = 100,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0,
                 n_step: int = 5,
                 sigma: float = 0.3,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])  # pytype: disable=wrong-arg-types

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.ClippedGaussian(sigma),
            networks.ClipToSpec(act_spec),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(behavior_network, adder=adder)

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=1e-4)
        critic_optimizer = snt.optimizers.Adam(learning_rate=1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.DDPGLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            target_update_period=target_update_period,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
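The snippet above shows only a constructor body; a hedged construction sketch, assuming the enclosing class is called DDPGAgent (the real class name is not shown) and that Sonnet policy and critic networks matching the environment spec exist elsewhere:

# Every name below is a placeholder for an object defined outside this snippet.
environment_spec = specs.make_environment_spec(environment)
agent = DDPGAgent(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
    batch_size=256,
    n_step=5,
    sigma=0.3)

From there, the usual Acme loop (acme.EnvironmentLoop(environment, agent).run()) would drive collection and learning.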
Example #14
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Remove batch dimensions.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
Example #15
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        demonstration_dataset: tf.data.Dataset,
        demonstration_ratio: float,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      demonstration_dataset: tf.data.Dataset producing (timestep, action)
        tuples containing full episodes.
      demonstration_ratio: Ratio of transitions coming from demonstrations.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating. This and all
        following arguments are related to dataset construction and will be
        ignored if a dataset argument is passed.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            transition_adder=True)

        # Combine with demonstration dataset.
        transition = functools.partial(_n_step_transition_from_episode,
                                       n_step=n_step,
                                       discount=discount)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(prefetch_size)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = dqn.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example #16
    def __init__(
        self,
        # ---
        env_name: str,
        port: int,
        # ---
        actor_units: list,
        clip_mean_min: float,
        clip_mean_max: float,
        init_noise: float,
        # ---
        min_replay_size: int,
        max_replay_size: int,
        samples_per_insert: int,
        # ---
        db_path: str,
    ):
        super(Server, self).__init__(env_name, False)
        self._port = port

        # Init actor's network
        self.actor = Actor(
            units=actor_units,
            n_outputs=np.prod(self._env.action_space.shape),
            clip_mean_min=clip_mean_min,
            clip_mean_max=clip_mean_max,
            init_noise=init_noise,
        )
        self.actor.build((None,) + self._env.observation_space.shape)

        # Show models details
        self.actor.summary()

        # Variables
        self._train_step = tf.Variable(
            0,
            trainable=False,
            dtype=tf.uint64,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            shape=(),
        )
        self._stop_agents = tf.Variable(
            False,
            trainable=False,
            dtype=tf.bool,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            shape=(),
        )

        # Table for storing variables
        self._variable_container = VariableContainer(
            db_server=f"localhost:{self._port}",
            table="variable",
            variables={
                "train_step": self._train_step,
                "stop_agents": self._stop_agents,
                "policy_variables": self.actor.variables,
            },
        )

        # Load DB from checkpoint or make a new one
        if db_path is None:
            checkpointer = None
        else:
            checkpointer = reverb.checkpointers.DefaultCheckpointer(path=db_path)

        if samples_per_insert:
            # 10% tolerance in rate
            samples_per_insert_tolerance = 0.1 * samples_per_insert
            error_buffer = min_replay_size * samples_per_insert_tolerance
            limiter = reverb.rate_limiters.SampleToInsertRatio(
                min_size_to_sample=min_replay_size,
                samples_per_insert=samples_per_insert,
                error_buffer=error_buffer,
            )
        else:
            limiter = reverb.rate_limiters.MinSize(min_replay_size)

        # Initialize the reverb server
        self.server = reverb.Server(
            tables=[
                reverb.Table(  # Replay buffer
                    name="experience",
                    sampler=reverb.selectors.Uniform(),
                    remover=reverb.selectors.Fifo(),
                    rate_limiter=limiter,
                    max_size=max_replay_size,
                    max_times_sampled=0,
                    signature={
                        "observation": tf.TensorSpec(
                            [*self._env.observation_space.shape],
                            self._env.observation_space.dtype,
                        ),
                        "action": tf.TensorSpec(
                            [*self._env.action_space.shape],
                            self._env.action_space.dtype,
                        ),
                        "reward": tf.TensorSpec([1], tf.float32),
                        "next_observation": tf.TensorSpec(
                            [*self._env.observation_space.shape],
                            self._env.observation_space.dtype,
                        ),
                        "terminal": tf.TensorSpec([1], tf.bool),
                    },
                ),
                reverb.Table(  # Variables container
                    name="variable",
                    sampler=reverb.selectors.Uniform(),
                    remover=reverb.selectors.Fifo(),
                    rate_limiter=reverb.rate_limiters.MinSize(1),
                    max_size=1,
                    max_times_sampled=0,
                    signature=self._variable_container.signature,
                ),
            ],
            port=self._port,
            checkpointer=checkpointer,
        )

        # Init variable container in DB
        self._variable_container.push_variables()
Example #17
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates PPO (Importance Ratio Clipping).

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and summaries
      will be written to.
    env_name: Name for the Mujoco environment to load.
    num_iterations: The number of iterations to perform collection and training.
    actor_fc_layers: List of fully_connected parameters for the actor network,
      where each item is the number of units in the layer.
    value_fc_layers: List of fully_connected parameters for the value network,
      where each item is the number of units in the layer.
    learning_rate: Learning rate used on the Adam optimizer.
    collect_sequence_length: Number of steps to take in each collect run.
    minibatch_size: Number of elements in each mini batch. If `None`, the entire
      collected sequence will be treated as one batch.
    num_epochs: Number of iterations to repeat over all collected data per data
      collection step. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
      Roboschool and 3 for Atari.
    importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For
      more detail, see explanation at the top of the doc.
    lambda_value: Lambda parameter for TD-lambda computation.
    discount_factor: Discount factor for return computation. Default to `0.99`
      which is the value used for all environments from (Schulman, 2017).
    entropy_regularization: Coefficient for entropy regularization loss term.
      Default to `0.0` because no entropy bonus was used in (Schulman, 2017).
    value_pred_loss_coef: Multiplier for value prediction loss to balance with
      policy gradient loss. Default to `0.5`, which was used for all
      environments in the OpenAI baseline implementation. This parameter is
      irrelevant unless you are sharing part of actor_net and value_net. In that
      case, you would want to tune this coefficient, whose value depends on the
      network architecture of your choice.
    use_gae: If True (default False), uses generalized advantage estimation for
      computing per-timestep advantage. Else, just subtracts value predictions
      from empirical return.
    use_td_lambda_return: If True (default False), uses td_lambda_return for
      training value function; here: `td_lambda_return = gae_advantage +
        value_predictions`. `use_gae` must be set to `True` as well to enable
        TD-lambda returns. If `use_td_lambda_return` is set to True while
        `use_gae` is False, the empirical return will be used and a warning will
        be logged.
    gradient_clipping: Norm length to clip gradients.
    value_clipping: Differences between new and old value predictions are clipped
      to this threshold. Value clipping could be helpful when training
      very deep networks. Default: no clipping.
    reverb_port: Port for reverb server, if None, use a randomly chosen unused
      port.
    replay_capacity: The maximum number of elements for the replay buffer. Items
      will be wasted if this is smaller than collect_sequence_length.
    policy_save_interval: How often, in train_steps, the policy will be saved.
    summary_interval: How often to write data into Tensorboard.
    eval_interval: How often to run evaluation, in train_steps.
    eval_episodes: Number of episodes to evaluate over.
    debug_summaries: Boolean for whether to gather debug summaries.
    summarize_grads_and_vars: If true, gradient summaries will be written.
  """
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
  # converting float64 inputs to float32.
  observation_tensor_spec = tf.TensorSpec(
      dtype=tf.float32, shape=observation_tensor_spec.shape)

  train_step = train_utils.create_train_step()
  actor_net_builder = ppo_actor_network.PPOActorNetwork()
  actor_net = actor_net_builder.create_sequential_actor_net(
      actor_fc_layers, action_tensor_spec)
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  current_iteration = tf.Variable(0, dtype=tf.int64)
  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.keras.optimizers.Adam(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to False
      # after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  reverb_server = reverb.Server(
      [
          reverb.Table(  # Replay buffer storing experience for training.
              name='training_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          ),
          reverb.Table(  # Replay buffer storing experience for normalization.
              name='normalization_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          )
      ],
      port=reverb_port)

  # Create the replay buffer.
  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='training_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)
  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='normalization_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay_train.py_client, ['training_table', 'normalization_table'],
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  def training_dataset_fn():
    return reverb_replay_train.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return reverb_replay_normalization.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  agent_learner = ppo_learner.PPOLearner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=training_dataset_fn,
      normalization_dataset_fn=normalization_dataset_fn,
      num_samples=1,
      num_epochs=num_epochs,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training on %s', env_name)
  last_eval_step = 0
  for i in range(num_iterations):
    collect_actor.run()
    rb_observer.flush()
    agent_learner.run()
    reverb_replay_train.clear()
    reverb_replay_normalization.clear()
    current_iteration.assign_add(1)

    # Eval only if `eval_interval` has been set. Then, eval if the current train
    # step is equal to or greater than the `last_eval_step` + `eval_interval` or if
    # this is the last iteration. This logic exists because agent_learner.run()
    # does not return after every train step.
    if (eval_interval and
        (agent_learner.train_step_numpy >= eval_interval + last_eval_step
         or i == num_iterations - 1)):
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      last_eval_step = agent_learner.train_step_numpy

  rb_observer.close()
  reverb_server.stop()
Example #18
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: iter,
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # replay table
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # demonstration table.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # launch server
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Component to add things into replay and demo
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # play demonstrations and write
        # exhaust the generator
        # TODO: MAX REPLAY SIZE
        _prev_action = 1  # this has to come from spec
        _add_first = True
        # include this to make datasets equivalent
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # reset to new episode
            if ts.last():
                _prev_action = None
                _add_first = True

        # replay dataset
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            # Memory/perf improvement attempt: https://github.com/deepmind/acme/issues/33
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # demonstration dataset
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #19
def train_eval(
        root_dir,
        env_name,
        # Training params
        train_sequence_length,
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_iterations=100000,
        # RNN params.
        q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""

    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    q_net = q_network_fn(num_actions=num_actions)

    sequence_length = train_sequence_length + 1
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        # n-step updates aren't supported with RNNs yet.
        n_step_update=1,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=sequence_length,
        stride_length=1,
        pad_end_of_episodes=True)

    def experience_dataset_fn():
        return reverb_replay.as_dataset(sample_batch_size=batch_size,
                                        num_steps=sequence_length)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first to
    # fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=collect_steps_per_iteration,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
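For context, a minimal sketch of calling the train_eval defined above; the argument values are illustrative only, and parameters not listed keep their defaults.

# Illustrative invocation only; the directory and step counts are placeholders.
train_eval(
    root_dir='/tmp/dqn_train_eval',
    num_iterations=10000,
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10)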
Exemple #20
0
def train_agent(iterations, modeldir, logdir, policydir):
    """Train and convert the model using TF Agents."""

    # TODO: add code to instantiate the training and evaluation environments
    # (a sketch covering the TODOs in this function follows it below)

    # TODO: add code to create a reinforcement learning agent that is going to be trained

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )  # specify signature here for validation at insertion time

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server,
    )

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
        REPLAY_BUFFER_CAPACITY)

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training.
    avg_return, _ = compute_avg_return_and_steps(
        eval_env, tf_agent.policy, NUM_EVAL_EPISODES)

    summary_writer = tf.summary.create_file_writer(logdir)

    for i in range(iterations):
        # TODO: add code to collect game episodes and train the agent

        logger = tf.get_logger()
        if i % EVAL_INTERVAL == 0:
            avg_return, avg_episode_length = compute_avg_return_and_steps(
                eval_env, eval_policy, NUM_EVAL_EPISODES)
            with summary_writer.as_default():
                tf.summary.scalar("Average return", avg_return, step=i)
                tf.summary.scalar("Average episode length",
                                  avg_episode_length,
                                  step=i)
                summary_writer.flush()
            logger.info(
                "iteration = {0}: Average Return = {1}, Average Episode Length = {2}"
                .format(i, avg_return, avg_episode_length))

    summary_writer.close()

    tf_policy_saver.save(policydir)
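The TODOs above are intentionally left open. For reference, here is a sketch of how they are filled in, mirroring the complete train_agent that appears later in this document; module-level constants such as BOARD_SIZE, DISCOUNT, FC_LAYER_PARAMS, LEARNING_RATE and COLLECT_EPISODES_PER_ITERATION, plus the collect_episode helper, are assumed to exist.

# Training and evaluation environments (first TODO).
train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
    board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)
eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
    board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# REINFORCE agent with a small sequential actor network (second TODO).
actor_net = tfa.networks.Sequential(
    [
        tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE],
                                      [BOARD_SIZE**2]),
        tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
        tf.keras.layers.Dense(BOARD_SIZE**2),
        tf.keras.layers.Lambda(
            lambda t: tfp.distributions.Categorical(logits=t)),
    ],
    input_spec=train_py_env.observation_spec())
tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    normalize_returns=True,
    train_step_counter=tf.Variable(0))

# Per-iteration collection and update (third TODO, inside the training loop).
collect_episode(train_py_env, collect_policy,
                COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer)
trajectories, _ = next(iter(replay_buffer.as_dataset(sample_batch_size=1)))
tf_agent.train(experience=trajectories)
replay_buffer.clear()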
Exemple #21
0
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 encoder_network: types.TensorTransformation = tf.identity,
                 entropy_coeff: float = 0.01,
                 target_update_period: int = 0,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 policy_learn_rate: float = 3e-4,
                 critic_learn_rate: float = 5e-4,
                 prefetch_size: int = 4,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 250000,
                 samples_per_insert: float = 64.0,
                 n_step: int = 5,
                 sigma: float = 0.5,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      encoder_network: optional network to transform the observations before
        they are fed into any network.
      entropy_coeff: coefficient of the entropy regularization term in the
        actor loss.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      policy_learn_rate: learning rate for the policy optimizer.
      critic_learn_rate: learning rate for the critic optimizer.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.

        dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
        extra_spec = {
            'logP': tf.ones(shape=(1), dtype=tf.float32),
            'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec, extras_spec=extra_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(table=replay_table_name,
                                               server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Make sure observation network is a Sonnet Module.
        observation_network = model.MDPNormalization(environment_spec,
                                                     encoder_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations

        # Create the behavior policy.
        sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
        self._behavior_network = model.PolicyValueBehaviorNet(
            snt.Sequential([observation_network, policy_network]),
            sampling_head)

        # Create variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

        # Create the actor which defines how we take actions.
        actor = model.SACFeedForwardActor(self._behavior_network, adder)

        if target_update_period > 0:
            target_policy_network = copy.deepcopy(policy_network)
            target_critic_network = copy.deepcopy(critic_network)
            target_observation_network = copy.deepcopy(observation_network)

            tf2_utils.create_variables(target_policy_network, [emb_spec])
            tf2_utils.create_variables(target_critic_network,
                                       [emb_spec, act_spec])
            tf2_utils.create_variables(target_observation_network, [obs_spec])
        else:
            target_policy_network = policy_network
            target_critic_network = critic_network
            target_observation_network = observation_network

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
        critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

        # The learner updates the parameters (and initializes them).
        learner = learning.SACLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            sampling_head=sampling_head,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            target_update_period=target_update_period,
            learning_rate=policy_learn_rate,
            clipping=clipping,
            entropy_coeff=entropy_coeff,
            discount=discount,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Exemple #22
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        cql_alpha: float = 1.,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint_subpath: str = '~/acme/',
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum replay size before updating. This and all
        following arguments are related to dataset construction and will be
        ignored if a dataset argument is passed.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action; ignored if a policy
        network is given.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
      cql_alpha: weight of the CQL regularization term added to the TD loss.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint_subpath: directory for the checkpoint.
    """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Use constant 0.05 epsilon greedy policy by default.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)
        policy_network = snt.Sequential([
            network,
            lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
        ])

        # Create a target network.
        target_network = copy.deepcopy(network)

        # Ensure that we create the variables before proceeding (maybe not needed).
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network, adder)

        # The learner updates the parameters (and initializes them).
        learner = CQLLearner(
            network=network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            cql_alpha=cql_alpha,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            counter=counter,
            checkpoint_subpath=checkpoint_subpath)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
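The adder created above is what the actor uses to write n-step transitions into the Reverb table. A standalone sketch of that insertion loop follows; `env` is a hypothetical dm_env.Environment matching environment_spec, and a real actor would pick actions from the policy network rather than the placeholder used here.

timestep = env.reset()
adder.add_first(timestep)
while not timestep.last():
    # Placeholder action drawn from the action spec; a real actor would
    # query the epsilon-greedy policy network instead.
    action = environment_spec.actions.generate_value()
    timestep = env.step(action)
    adder.add(action, timestep)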
Exemple #23
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        # Training params
        initial_collect_steps=1000,
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
    action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

    train_step = train_utils.create_train_step()
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03,
                                                               maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Exemple #24
0
def train_agent(iterations, modeldir, logdir, policydir):
    """Train and convert the model using TF Agents."""

    train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)
    eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)

    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # Alternatively you could use ActorDistributionNetwork as actor_net
    actor_net = tfa.networks.Sequential(
        [
            tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE],
                                          [BOARD_SIZE**2]),
            tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
            tf.keras.layers.Dense(BOARD_SIZE**2),
            tf.keras.layers.Lambda(
                lambda t: tfp.distributions.Categorical(logits=t)),
        ],
        input_spec=train_py_env.observation_spec())

    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    train_step_counter = tf.Variable(0)

    tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        optimizer=optimizer,
        normalize_returns=True,
        train_step_counter=train_step_counter)

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature
    )  # specify signature here for validation at insertion time

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server)

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
        REPLAY_BUFFER_CAPACITY)

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training.
    avg_return, _ = compute_avg_return_and_steps(
        eval_env, tf_agent.policy, NUM_EVAL_EPISODES)

    summary_writer = tf.summary.create_file_writer(logdir)

    for i in range(iterations):
        # Collect a few episodes using collect_policy and save to the replay buffer.
        collect_episode(train_py_env, collect_policy,
                        COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer)

        # Use data from the buffer and update the agent's network.
        iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
        trajectories, _ = next(iterator)
        tf_agent.train(experience=trajectories)
        replay_buffer.clear()

        logger = tf.get_logger()
        if i % EVAL_INTERVAL == 0:
            avg_return, avg_episode_length = compute_avg_return_and_steps(
                eval_env, eval_policy, NUM_EVAL_EPISODES)
            with summary_writer.as_default():
                tf.summary.scalar('Average return', avg_return, step=i)
                tf.summary.scalar('Average episode length',
                                  avg_episode_length,
                                  step=i)
                summary_writer.flush()
            logger.info(
                'iteration = {0}: Average Return = {1}, Average Episode Length = {2}'
                .format(i, avg_return, avg_episode_length))

    summary_writer.close()

    tf_policy_saver.save(policydir)
    # Convert to tflite model
    converter = tf.lite.TFLiteConverter.from_saved_model(
        policydir, signature_keys=['action'])
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    tflite_policy = converter.convert()
    with open(os.path.join(modeldir, 'planestrike_tf_agents.tflite'),
              'wb') as f:
        f.write(tflite_policy)
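As a quick sanity check after train_agent finishes (not part of the original code), the converted flatbuffer can be loaded back with the TFLite interpreter; the path below is hypothetical and should point at whatever modeldir was used.

import tensorflow as tf

tflite_path = '/tmp/planestrike_tf_agents.tflite'  # hypothetical location
interpreter = tf.lite.Interpreter(model_path=tflite_path)
# The policy was exported with signature_keys=['action'], so an 'action'
# signature should be listed here and can be run directly.
print(interpreter.get_signature_list())
action_fn = interpreter.get_signature_runner('action')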
Exemple #25
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(seed),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
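The network argument of this JAX agent is a Haiku-transformed function. A minimal sketch of building one is shown below; the layer sizes and num_actions are illustrative, and without_apply_rng matches the rng-free network.apply(params, observation) call used in the policy above.

import haiku as hk
import jax.numpy as jnp

num_actions = 4  # illustrative; in practice environment_spec.actions.num_values

def q_function(observation: jnp.ndarray) -> jnp.ndarray:
    # Small MLP emitting one Q-value per discrete action.
    return hk.nets.MLP([64, 64, num_actions])(observation)

network = hk.without_apply_rng(hk.transform(q_function))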
Exemple #26
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        builder: builders.ActorLearnerBuilder,
        networks: Any,
        policy_network: Any,
        min_replay_size: int = 1000,
        samples_per_insert: float = 256.0,
        batch_size: int = 256,
        num_sgd_steps_per_step: int = 1,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      builder: builder defining an RL algorithm to train.
      networks: network objects to be passed to the learner.
      policy_network: function that given an observation returns actions.
      min_replay_size: minimum replay size before updating.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      batch_size: batch size for updates.
      num_sgd_steps_per_step: how many sgd steps a learner does per 'step' call.
        For performance reasons (especially to reduce TPU host-device transfer
        times) it is performance-beneficial to do multiple sgd updates at once,
        provided that it does not hurt the training, which needs to be verified
        empirically for each environment.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """
        # Create the replay server and grab its address.
        replay_tables = builder.make_replay_tables(environment_spec)
        replay_server = reverb.Server(replay_tables, port=None)
        replay_client = reverb.Client(f'localhost:{replay_server.port}')

        # Create actor, dataset, and learner for generating, storing, and consuming
        # data respectively.
        adder = builder.make_adder(replay_client)
        dataset = builder.make_dataset_iterator(replay_client)
        learner = builder.make_learner(networks=networks,
                                       dataset=dataset,
                                       replay_client=replay_client,
                                       counter=counter,
                                       logger=logger,
                                       checkpoint=checkpoint)
        actor = builder.make_actor(policy_network,
                                   adder,
                                   variable_source=learner)

        effective_batch_size = batch_size * num_sgd_steps_per_step
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(effective_batch_size,
                                              min_replay_size),
                         observations_per_step=float(effective_batch_size) /
                         samples_per_insert)

        # Save the replay so we don't garbage collect it.
        self._replay_server = replay_server
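To make the batching arithmetic above concrete, here is a plain-number sketch using the constructor defaults.

batch_size = 256
num_sgd_steps_per_step = 1
samples_per_insert = 256.0
min_replay_size = 1000

effective_batch_size = batch_size * num_sgd_steps_per_step          # 256
min_observations = max(effective_batch_size, min_replay_size)       # 1000
observations_per_step = effective_batch_size / samples_per_insert   # 1.0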
Exemple #27
0
    return avg_return.numpy()[0]


# Standard implementations for evaluation metrics in the metrics module.

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)
table = reverb.Table(table_name,
                     max_size=replay_buffer_capacity,
                     sampler=reverb.selectors.Uniform(),
                     remover=reverb.selectors.Fifo(),
                     rate_limiter=reverb.rate_limiters.MinSize(1),
                     signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(replay_buffer.py_client,
                                                    table_name,
                                                    replay_buffer_capacity)


def collect_episode(environment, policy, num_episodes):

    driver = py_driver.PyDriver(environment,
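The definition above is cut off mid-call in this listing. One common completion, following the TF-Agents REINFORCE tutorial pattern and reusing the rb_observer created above (a sketch, not the original code):

def collect_episode(environment, policy, num_episodes):
    # Wrap the TF policy for the python driver and run whole episodes,
    # pushing each collected trajectory into the replay buffer via rb_observer.
    driver = py_driver.PyDriver(
        environment,
        py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
        [rb_observer],
        max_episodes=num_episodes)
    initial_time_step = environment.reset()
    driver.run(initial_time_step)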
Exemple #28
0
    discount: float = 1.,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    prefetch_size: int = 4,
) -> ReverbReplay:
    """Creates a single-process replay infrastructure from an environment spec."""
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_replay_size),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec=environment_spec))
    server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{server.port}'
    client = reverb.Client(address)
    adder = adders.NStepTransitionAdder(client=client,
                                        n_step=n_step,
                                        discount=discount)

    # The dataset provides an interface to sample from replay.
    data_iterator = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        environment_spec=environment_spec,
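    # (The listing is cut off above. A hedged sketch of how such a helper
    # typically finishes: close the make_reverb_dataset(...) call, then bundle
    # the server, adder, data iterator, and client into the ReverbReplay
    # container named in the return annotation, e.g.
    #   return ReverbReplay(server, adder, data_iterator, client=client)
    # where the exact field order depends on how ReverbReplay is defined.)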
Exemple #29
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Exemple #30
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        prior_network: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        prior_optimizer: Optional[snt.Optimizer] = None,
        distillation_cost: Optional[float] = 1e-3,
        entropy_regularizer_cost: Optional[float] = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        sequence_length: int = 10,
        sigma: float = 0.3,
        replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      prior_network: an optional `behavior prior` to regularize against.
      policy_optimizer: optimizer for the policy network updates.
      critic_optimizer: optimizer for the critic network updates.
      prior_optimizer: optimizer for the prior network updates.
      distillation_cost: a multiplier to be used when adding distillation
        against the prior to the losses.
      entropy_regularizer_cost: a multiplier used for per state sample based
        entropy added to the actor loss.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      sequence_length: number of timesteps to store for each trajectory.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      replay_table_name: string indicating what name to give the replay table.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """
        # Create the Builder object which will internally create agent components.
        builder = SVG0Builder(
            # TODO(mwhoffman): pass the config dataclass in directly.
            # TODO(mwhoffman): use the limiter rather than the workaround below.
            # Right now this modifies min_replay_size and samples_per_insert so that
            # they are not controlled by a limiter and are instead handled by the
            # Agent base class (the above TODO directly references this behavior).
            SVG0Config(
                discount=discount,
                batch_size=batch_size,
                prefetch_size=prefetch_size,
                target_update_period=target_update_period,
                policy_optimizer=policy_optimizer,
                critic_optimizer=critic_optimizer,
                prior_optimizer=prior_optimizer,
                distillation_cost=distillation_cost,
                entropy_regularizer_cost=entropy_regularizer_cost,
                min_replay_size=1,  # Let the Agent class handle this.
                max_replay_size=max_replay_size,
                samples_per_insert=None,  # Let the Agent class handle this.
                sequence_length=sequence_length,
                sigma=sigma,
                replay_table_name=replay_table_name,
            ))

        # TODO(mwhoffman): pass the network dataclass in directly.
        online_networks = SVG0Networks(
            policy_network=policy_network,
            critic_network=critic_network,
            prior_network=prior_network,
        )

        # Target networks are just a copy of the online networks.
        target_networks = copy.deepcopy(online_networks)

        # Initialize the networks.
        online_networks.init(environment_spec)
        target_networks.init(environment_spec)

        # TODO(mwhoffman): either make this Dataclass or pass only one struct.
        # The network struct passed to make_learner is just a tuple for the
        # time-being (for backwards compatibility).
        networks = (online_networks, target_networks)

        # Create the behavior policy.
        policy_network = online_networks.make_policy()

        # Create the replay server and grab its address.
        replay_tables = builder.make_replay_tables(environment_spec,
                                                   sequence_length)
        replay_server = reverb.Server(replay_tables, port=None)
        replay_client = reverb.Client(f'localhost:{replay_server.port}')

        # Create actor, dataset, and learner for generating, storing, and consuming
        # data respectively.
        adder = builder.make_adder(replay_client)
        actor = builder.make_actor(policy_network, adder)
        dataset = builder.make_dataset_iterator(replay_client)
        learner = builder.make_learner(networks, dataset, counter, logger,
                                       checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)

        # Save the replay so we don't garbage collect it.
        self._replay_server = replay_server