    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        if batch is None:
            if log_once("empty_batch_added_to_buffer"):
                logger.info(
                    "A batch that is `None` was added to {}. This can be "
                    "normal at the beginning of execution but might "
                    "indicate an issue.".format(type(self).__name__))
            return
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode == ReplayMode.LOCKSTEP:
                # Lockstep mode: Store under _ALL_POLICIES key (we will always
                # only sample from all policies at the same time).
                # This means storing a MultiAgentBatch to the underlying buffer
                self._add_to_underlying_buffer(_ALL_POLICIES, batch, **kwargs)
            else:
                # Store independent SampleBatches
                for policy_id, sample_batch in batch.policy_batches.items():
                    self._add_to_underlying_buffer(policy_id, sample_batch,
                                                   **kwargs)
        self._num_added += batch.count
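
A quick aside on the `as_multi_agent()` call used throughout these examples: it wraps a plain SampleBatch into a MultiAgentBatch keyed by DEFAULT_POLICY_ID, which is what lets `add()` treat every input uniformly. A minimal check, assuming a recent RLlib where this API is public:

import numpy as np
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

# A plain single-agent batch of 4 timesteps.
batch = SampleBatch({"obs": np.arange(4), "actions": np.zeros(4)})

# as_multi_agent() wraps it under DEFAULT_POLICY_ID, so the lockstep /
# per-policy dispatch above never has to special-case single-agent input.
ma_batch = batch.as_multi_agent()
assert list(ma_batch.policy_batches.keys()) == [DEFAULT_POLICY_ID]
assert ma_batch.count == 4
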
Example #2
    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode == ReplayMode.LOCKSTEP:
                # Lockstep mode: Store under _ALL_POLICIES key (we will always
                # only sample from all policies at the same time).
                # This means storing a MultiAgentBatch to the underlying buffer
                self.replay_buffers[_ALL_POLICIES].add_batch(batch)
                self.last_added_batches[_ALL_POLICIES].append(batch)
            else:
                # Store independent SampleBatches
                for policy_id, sample_batch in batch.policy_batches.items():
                    self.replay_buffers[policy_id].add_batch(sample_batch)
                    self.last_added_batches[policy_id].append(sample_batch)

        self.num_added += batch.count
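
The lockstep-vs-independent split in the two branches above can be pictured without any RLlib machinery. A toy sketch in plain Python (the names ALL and store are illustrative stand-ins for _ALL_POLICIES and self.replay_buffers, not RLlib API):

from collections import defaultdict

ALL = "__all__"            # stand-in for the _ALL_POLICIES sentinel
store = defaultdict(list)  # stand-in for the per-policy buffers

def add(policy_batches, lockstep):
    if lockstep:
        # One joint entry: all policies get sampled together later.
        store[ALL].append(policy_batches)
    else:
        # One entry per policy: each is sampled independently later.
        for pid, b in policy_batches.items():
            store[pid].append(b)

add({"pol1": [1, 2], "pol2": [3]}, lockstep=False)
assert set(store) == {"pol1", "pol2"}
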
Example #3
    def add(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
                    self.last_added_batches[_ALL_POLICIES].append(s)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    self._add_to_policy_buffer(policy_id, sample_batch)
                    self.last_added_batches[policy_id].append(sample_batch)
        self._num_added += batch.count
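
The lockstep branch above relies on `SampleBatch.timeslices()`, which cuts a batch into fixed-size pieces along the time axis. A small check of that behavior, assuming the same RLlib version as the snippet:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

b = SampleBatch({"obs": np.arange(4), "rewards": np.ones(4)})
# With replay_sequence_length == 1, each slice is a single timestep.
slices = b.timeslices(1)
assert len(slices) == 4 and slices[0].count == 1
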
Example #4
    def __call__(self, samples: SampleBatchType) -> SampleBatchType:
        _check_sample_batch_type(samples)
        wrapped = False

        if isinstance(samples, SampleBatch):
            samples = samples.as_multi_agent()
            wrapped = True

        for policy_id in samples.policy_batches:
            batch = samples.policy_batches[policy_id]
            for field in self.fields:
                if field not in batch:
                    raise KeyError(
                        f"`{field}` not found in SampleBatch for policy "
                        f"`{policy_id}`! Maybe this policy fails to add "
                        f"{field} in its `postprocess_trajectory` method? Or "
                        "this policy is not meant to learn at all and you "
                        "forgot to add it to the list under `config."
                        "multiagent.policies_to_train`.")
                batch[field] = standardized(batch[field])

        if wrapped:
            samples = samples.policy_batches[DEFAULT_POLICY_ID]

        return samples
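
`standardized()` (imported from RLlib's utilities in the original module) is, to a close approximation, plain z-scoring. A self-contained stand-in, where the name standardized_like and the exact epsilon floor are assumptions:

import numpy as np

def standardized_like(array: np.ndarray) -> np.ndarray:
    # Zero-mean, unit-variance rescaling; the epsilon floor guards
    # against division by zero for near-constant fields.
    return (array - array.mean()) / max(1e-4, array.std())

out = standardized_like(np.array([1.0, 2.0, 3.0, 4.0]))
assert abs(out.mean()) < 1e-8
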
Example #5
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        if batch is None:
            if log_once("empty_batch_added_to_buffer"):
                logger.info(
                    "A batch that is `None` was added to {}. This can be "
                    "normal at the beginning of execution but might "
                    "indicate an issue.".format(type(self).__name__))
            return
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            pids_and_batches = self._maybe_split_into_policy_batches(batch)
            for policy_id, sample_batch in pids_and_batches.items():
                self._add_to_underlying_buffer(policy_id, sample_batch,
                                               **kwargs)

        self._num_added += batch.count
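
`_maybe_split_into_policy_batches` is not shown in this example; judging from the earlier variants, it most likely folds the lockstep branch into a single helper. A hypothetical reconstruction (only the call site above is from the source; this body is an assumption):

def _maybe_split_into_policy_batches(self, batch):
    # Hypothetical sketch: under lockstep replay, keep the whole
    # MultiAgentBatch under the _ALL_POLICIES key; otherwise hand back
    # the per-policy sub-batches for independent storage.
    if self.replay_mode == ReplayMode.LOCKSTEP:
        return {_ALL_POLICIES: batch}
    return dict(batch.policy_batches)
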
Example #6
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            if self.replay_mode == ReplayMode.LOCKSTEP:
                # Lockstep mode: Store under _ALL_POLICIES key (we will always
                # only sample from all policies at the same time).
                # This means storing a MultiAgentBatch to the underlying buffer
                self._add_to_underlying_buffer(_ALL_POLICIES, batch, **kwargs)
            else:
                # Store independent SampleBatches
                for policy_id, sample_batch in batch.policy_batches.items():
                    self._add_to_underlying_buffer(policy_id, sample_batch,
                                                   **kwargs)
        self._num_added += batch.count
Example #7
    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch.

        Args:
            batch (SampleBatchType): The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        # If SampleBatch has prio-replay weights, average
                        # over these to use as a weight for the entire
                        # sequence.
                        if "weights" in time_slice and len(
                                time_slice["weights"]):
                            weight = np.mean(time_slice["weights"])
                        else:
                            weight = None
                        self.replay_buffers[policy_id].add(time_slice,
                                                           weight=weight)
        self.num_added += batch.count
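
The per-sequence priority weight in the else-branch above is simply the mean of the stored per-timestep weights. The same logic in isolation, with illustrative data:

import numpy as np

time_slice = {"weights": np.array([0.5, 1.0, 1.5])}

# Average the per-timestep prio-replay weights into one weight for the
# whole sequence; fall back to None (uniform priority) when absent.
w = time_slice.get("weights")
weight = float(np.mean(w)) if w is not None and len(w) else None
assert weight == 1.0
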
Example #8
    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            for policy_id, sample_batch in batch.policy_batches.items():
                self.replay_buffers[policy_id].add_batch(sample_batch)
                self.last_added_batches[policy_id].append(sample_batch)
        self.num_added += batch.count
Example #9
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Standardize fields of the given SampleBatch"""
    _check_sample_batch_type(samples)
    wrapped = False

    if isinstance(samples, SampleBatch):
        samples = samples.as_multi_agent()
        wrapped = True

    for policy_id in samples.policy_batches:
        batch = samples.policy_batches[policy_id]
        for field in fields:
            if field in batch:
                batch[field] = standardized(batch[field])

    if wrapped:
        samples = samples.policy_batches[DEFAULT_POLICY_ID]

    return samples
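
A typical use of `standardize_fields` is normalizing advantages before a PPO-style update. A short usage sketch, assuming the helpers from the original module (`standardized`, `_check_sample_batch_type`, DEFAULT_POLICY_ID) are in scope:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

samples = SampleBatch({"advantages": np.array([1.0, 2.0, 3.0, 4.0])})
out = standardize_fields(samples, ["advantages"])
# The field now has (approximately) zero mean and unit variance.
assert abs(out["advantages"].mean()) < 1e-6
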
Example #10
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # We need to split batches into timesteps, sequences, or episodes
        # here already to properly keep track of self.last_added_batches;
        # underlying buffers should not split up the batch any further.
        with self.add_batch_timer:
            if self._storage_unit == StorageUnit.TIMESTEPS:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(
                            time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)
            elif self._storage_unit == StorageUnit.SEQUENCES:
                timestep_count = 0
                for policy_id, sample_batch in batch.policy_batches.items():
                    for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                        start_seq = timestep_count
                        end_seq = timestep_count + seq_len
                        self.replay_buffers[policy_id].add(
                            sample_batch[start_seq:end_seq], **kwargs)
                        self.last_added_batches[policy_id].append(
                            sample_batch[start_seq:end_seq])
                        timestep_count = end_seq
            elif self._storage_unit == StorageUnit.EPISODES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    for eps in sample_batch.split_by_episode():
                        # Only add full episodes to the buffer
                        if (eps.get(SampleBatch.T)[0] == 0 and eps.get(
                                SampleBatch.DONES)[-1] == True  # noqa E712
                            ):
                            self.replay_buffers[policy_id].add(eps, **kwargs)
                            self.last_added_batches[policy_id].append(eps)
                        else:
                            if log_once("only_full_episodes"):
                                logger.info(
                                    "This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        self._num_added += batch.count
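
The SEQUENCES branch above slices the flat batch by its stored sequence lengths. The indexing logic on its own, with made-up lengths:

import numpy as np

seq_lens = [3, 2]        # stand-in for sample_batch[SampleBatch.SEQ_LENS]
data = np.arange(5)      # 5 flattened timesteps

slices, t = [], 0
for seq_len in seq_lens:
    # Each slice covers one stored sequence: [start, start + seq_len).
    slices.append(data[t:t + seq_len])
    t += seq_len

assert [list(s) for s in slices] == [[0, 1, 2], [3, 4]]
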
Example #11
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multi-agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if not self.local_worker.is_policy_to_train(
                        policy_id, samples):
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we
                # are training asynchronously.
                num_loaded_samples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(
                num_devices=len(self.devices))

            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(samples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0,
                        )

                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id)

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.get_policies_to_train()))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, learner_info
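
The minibatch SGD loop above walks random offsets into the pre-loaded train batch rather than re-slicing data. The offset arithmetic with illustrative numbers (500 loaded samples, per-device batch size 128):

import numpy as np

samples_per_device = 500
per_device_batch_size = 128
num_sgd_iter = 2

num_batches = max(1, samples_per_device // per_device_batch_size)  # -> 3
for _ in range(num_sgd_iter):
    permutation = np.random.permutation(num_batches)
    for batch_index in range(num_batches):
        # Offset into the pre-loaded batch, as would be passed to
        # learn_on_loaded_batch(offset, buffer_index=0) above.
        offset = permutation[batch_index] * per_device_batch_size
        assert 0 <= offset < samples_per_device
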