Example #1
    def add(self, batch: SampleBatchType):
        """Splits a SampleBatch into episodes and adds episodes
        to the episode buffer

        Args:
            batch: SampleBatch to be added
        """

        self.timesteps += batch.count
        episodes = batch.split_by_episode()

        for i, e in enumerate(episodes):
            episodes[i] = self.preprocess_episode(e)
        self.episodes.extend(episodes)

        if len(self.episodes) > self.max_length:
            delta = len(self.episodes) - self.max_length
            # Drop oldest episodes
            self.episodes = self.episodes[delta:]
Example #2
    def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
        self.check_can_estimate_for(batch)
        estimates = []
        for sub_batch in batch.split_by_episode():
            rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
            new_prob = np.exp(self.action_log_likelihood(sub_batch))

            # calculate importance ratios
            p = []
            for t in range(sub_batch.count):
                if t == 0:
                    pt_prev = 1.0
                else:
                    pt_prev = p[t - 1]
                p.append(pt_prev * new_prob[t] / old_prob[t])
            for t, v in enumerate(p):
                if t >= len(self.filter_values):
                    self.filter_values.append(v)
                    self.filter_counts.append(1.0)
                else:
                    self.filter_values[t] += v
                    self.filter_counts[t] += 1.0

            # calculate stepwise weighted IS estimate
            v_old = 0.0
            v_new = 0.0
            for t in range(sub_batch.count):
                v_old += rewards[t] * self.gamma ** t
                w_t = self.filter_values[t] / self.filter_counts[t]
                v_new += p[t] / w_t * rewards[t] * self.gamma ** t

            estimates.append(
                OffPolicyEstimate(
                    self.name,
                    {
                        "v_old": v_old,
                        "v_new": v_new,
                        "v_gain": v_new / max(1e-8, v_old),
                    },
                )
            )
        return estimates
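
The loop above builds cumulative importance ratios p_t per episode and maintains running per-timestep averages (filter_values / filter_counts) across episodes to form the weighted estimate. A minimal NumPy-only sketch of the same arithmetic on two made-up episodes (toy numbers, not the estimator's API):

import numpy as np

gamma = 0.99
episodes = [  # toy per-step behavior (old) and target (new) action probabilities
    {"rewards": np.array([1.0, 0.0, 2.0]),
     "old_prob": np.array([0.5, 0.5, 0.5]),
     "new_prob": np.array([0.6, 0.4, 0.7])},
    {"rewards": np.array([0.5, 1.0, 0.0]),
     "old_prob": np.array([0.4, 0.6, 0.5]),
     "new_prob": np.array([0.5, 0.5, 0.5])},
]

filter_values, filter_counts = [], []
for ep in episodes:
    # Cumulative importance ratios: p_t = prod_{k<=t} new_prob_k / old_prob_k.
    p = np.cumprod(ep["new_prob"] / ep["old_prob"])
    # Update the running per-timestep averages of the ratios.
    for t, v in enumerate(p):
        if t >= len(filter_values):
            filter_values.append(v)
            filter_counts.append(1.0)
        else:
            filter_values[t] += v
            filter_counts[t] += 1.0
    discounts = gamma ** np.arange(len(p))
    w = np.array(filter_values[:len(p)]) / np.array(filter_counts[:len(p)])
    v_old = float(np.sum(ep["rewards"] * discounts))
    v_new = float(np.sum(p / w * ep["rewards"] * discounts))
    print(f"v_old={v_old:.3f} v_new={v_new:.3f} v_gain={v_new / max(1e-8, v_old):.3f}")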
Example #3
    def __call__(self, batch: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker
        with learn_timer:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through train batch `num_sgd_iter` times.
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                learner_info = do_minibatch_sgd(
                    batch,
                    {
                        pid: lw.get_policy(pid)
                        for pid in self.policies or lw.get_policies_to_train(batch)
                    },
                    lw,
                    self.num_sgd_iter,
                    self.sgd_minibatch_size,
                    [],
                )
            # Single update step using train batch.
            else:
                learner_info = lw.learn_on_batch(batch)

            metrics.info[LEARNER_INFO] = learner_info
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(
                    lw.get_weights(self.policies or lw.get_policies_to_train(batch))
                )
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        lw.set_global_vars(_get_global_vars())
        return batch, learner_info
Example #4
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)

        if self.count_steps_by == "env_steps":
            size = batch.count
        else:
            assert isinstance(batch, MultiAgentBatch), (
                "`count_steps_by=agent_steps` only allowed in multi-agent "
                "environments!"
            )
            size = batch.agent_steps()

        # Incoming batch is an empty dummy batch -> Ignore.
        # Possibly produced automatically by a PolicyServer to unblock
        # an external env waiting for inputs from unresponsive/disconnected
        # client(s).
        if size == 0:
            return []

        self.count += size
        self.buffer.append(batch)

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, expected={}). ".format(self.count, self.min_batch_size)
                    + "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode."
                )
            out = SampleBatch.concat_samples(self.buffer)

            perf_counter = time.perf_counter()
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(perf_counter - self.last_batch_time)
            timer.push_units_processed(self.count)

            self.last_batch_time = perf_counter
            self.buffer = []
            self.count = 0
            return [out]
        return []
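
The operator above buffers incoming batches and only emits one concatenated batch once at least min_batch_size steps have accumulated (warning when it overshoots by more than 2x). A stripped-down sketch of that accumulation logic over plain Python lists (hypothetical names, no SampleBatch involved):

from typing import List, Optional

class SimpleConcat:
    """Accumulate lists of samples until a minimum size is reached."""

    def __init__(self, min_batch_size: int):
        self.min_batch_size = min_batch_size
        self.buffer: List[list] = []
        self.count = 0

    def __call__(self, batch: list) -> Optional[list]:
        if not batch:          # empty dummy batch -> ignore
            return None
        self.count += len(batch)
        self.buffer.append(batch)
        if self.count >= self.min_batch_size:
            out = [x for b in self.buffer for x in b]  # concatenate
            self.buffer, self.count = [], 0
            return out
        return None

acc = SimpleConcat(min_batch_size=5)
print(acc([1, 2]))        # None, not enough samples yet
print(acc([3, 4, 5, 6]))  # [1, 2, 3, 4, 5, 6]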
Example #5
def standardize_fields(samples: SampleBatchType,
                       fields: List[str]) -> SampleBatchType:
    """Standardize fields of the given SampleBatch"""
    _check_sample_batch_type(samples)
    wrapped = False

    if isinstance(samples, SampleBatch):
        samples = samples.as_multi_agent()
        wrapped = True

    for policy_id in samples.policy_batches:
        batch = samples.policy_batches[policy_id]
        for field in fields:
            if field in batch:
                batch[field] = standardized(batch[field])

    if wrapped:
        samples = samples.policy_batches[DEFAULT_POLICY_ID]

    return samples
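
standardized() itself is not shown here; it is assumed to rescale a column to zero mean and roughly unit standard deviation, as is commonly done for advantage columns. A minimal NumPy sketch of that assumed behavior:

import numpy as np

def standardized_sketch(x: np.ndarray) -> np.ndarray:
    # Zero-mean / unit-std normalization, with a floor on the std for stability.
    return (x - x.mean()) / max(1e-4, x.std())

advantages = np.array([1.0, 3.0, -2.0, 0.5])
print(standardized_sketch(advantages))  # approximately zero mean, unit std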
Example #6
 def _add_to_policy_buffer(self, policy_id: PolicyID,
                           batch: SampleBatchType) -> None:
     if self.replay_sequence_length == 1:
         timeslices = batch.timeslices(1)
     else:
         timeslices = timeslice_along_seq_lens_with_overlap(
             sample_batch=batch,
             zero_pad_max_seq_len=self.replay_sequence_length,
             pre_overlap=self.replay_burn_in,
             zero_init_states=self.replay_zero_init_states,
         )
     for time_slice in timeslices:
         # If SampleBatch has prio-replay weights, average
         # over these to use as a weight for the entire
         # sequence.
         if "weights" in time_slice and len(time_slice["weights"]):
             weight = np.mean(time_slice["weights"])
         else:
             weight = None
         self.replay_buffers[policy_id].add(time_slice, weight=weight)
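
The per-slice weight is simply the mean of the slice's prio-replay weights column, or None when that column is absent, so the whole sequence gets one priority. A tiny sketch with a plain dict standing in for a time slice:

import numpy as np

time_slice = {"obs": np.zeros((4, 2)), "weights": np.array([0.2, 0.4, 0.1, 0.3])}

# One scalar priority for the whole sequence, or None if no weights are present.
if "weights" in time_slice and len(time_slice["weights"]):
    weight = float(np.mean(time_slice["weights"]))
else:
    weight = None
print(weight)  # 0.25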
Example #7
    def estimate(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Compute off-policy estimates.

        Args:
            batch: The SampleBatch to run off-policy estimation on

        Returns:
            A dict consisting of the following metrics:
            - v_behavior: The discounted return averaged over episodes in the batch
            - v_behavior_std: The standard deviation corresponding to v_behavior
            - v_target: The estimated discounted return for `self.policy`,
            averaged over episodes in the batch
            - v_target_std: The standard deviation corresponding to v_target
            - v_gain: v_target / max(v_behavior, 1e-8), averaged over episodes
            - v_gain_std: The standard deviation corresponding to v_gain
        """
        batch = self.convert_ma_batch_to_sample_batch(batch)
        self.check_action_prob_in_batch(batch)
        estimates = {"v_behavior": [], "v_target": [], "v_gain": []}
        # Calculate Direct Method OPE estimates
        for episode in batch.split_by_episode():
            rewards = episode["rewards"]
            v_behavior = 0.0
            v_target = 0.0
            for t in range(episode.count):
                v_behavior += rewards[t] * self.gamma ** t

            init_step = episode[0:1]
            v_target = self.model.estimate_v(init_step)
            v_target = convert_to_numpy(v_target).item()

            estimates["v_behavior"].append(v_behavior)
            estimates["v_target"].append(v_target)
            estimates["v_gain"].append(v_target / max(v_behavior, 1e-8))
        estimates["v_behavior_std"] = np.std(estimates["v_behavior"])
        estimates["v_behavior"] = np.mean(estimates["v_behavior"])
        estimates["v_target_std"] = np.std(estimates["v_target"])
        estimates["v_target"] = np.mean(estimates["v_target"])
        estimates["v_gain_std"] = np.std(estimates["v_gain"])
        estimates["v_gain"] = np.mean(estimates["v_gain"])
        return estimates
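
Each per-episode pass reduces to a discounted sum of observed rewards (v_behavior) compared against a model-based value for the initial step (v_target). A minimal NumPy sketch of the metric arithmetic with a stubbed value estimate (the constant 9.5 is made up; the real code calls self.model.estimate_v()):

import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, 2.0, 1.0])

# v_behavior: discounted return actually observed in the episode.
v_behavior = float(np.sum(rewards * gamma ** np.arange(len(rewards))))

# v_target: what a fitted value model would predict for the first step (stubbed).
v_target = 9.5

v_gain = v_target / max(v_behavior, 1e-8)
print(v_behavior, v_target, round(v_gain, 3))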
Example #8
 def __call__(self,
              batch: SampleBatchType) -> (SampleBatchType, List[dict]):
     _check_sample_batch_type(batch)
     metrics = _get_shared_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
             lw = self.workers.local_worker()
             info = do_minibatch_sgd(
                 batch, {
                     pid: lw.get_policy(pid)
                     for pid in self.policies
                     or self.local_worker.policies_to_train
                 }, lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
             #  callback (shouldn't).
             metrics.info[LEARNER_INFO] = info
         else:
             info = self.workers.local_worker().learn_on_batch(batch)
             metrics.info[LEARNER_INFO] = extract_stats(
                 info, LEARNER_STATS_KEY)
             metrics.info["custom_metrics"] = extract_stats(
                 info, "custom_metrics")
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     if isinstance(batch, MultiAgentBatch):
         metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps(
         )
     # Update weights - after learning on the local worker - on all remote
     # workers.
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights(
                 self.policies or self.local_worker.policies_to_train))
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
Example #9
    def add(self, batch: SampleBatchType, weight: float) -> None:
        """Add a batch of experiences.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            weight: The weight of the added sample used in subsequent sampling
                steps.
        """
        idx = self._next_idx

        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()
        else:
            self._storage[self._next_idx] = batch

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full").
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha
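
Two bookkeeping details stand out above: the write index wraps around once the capacity (counted in timesteps) has been added since the last wrap, and the stored priority is weight ** alpha, with new samples defaulting to the current max priority. A compact sketch of just that bookkeeping on plain lists (no segment trees; names are illustrative):

class TinyCircularBuffer:
    """Sketch of capacity-based wrap-around plus priority exponentiation."""

    def __init__(self, capacity_ts: int, alpha: float = 0.6, max_priority: float = 1.0):
        self.capacity_ts = capacity_ts
        self.alpha = alpha
        self.max_priority = max_priority
        self.storage, self.priorities = [], []
        self._added_wrap = 0
        self._next_idx = 0

    def add(self, item, count: int, weight=None):
        idx = self._next_idx
        if idx >= len(self.storage):
            self.storage.append(item)
            self.priorities.append(0.0)
        else:
            self.storage[idx] = item
        # Wrap once `capacity_ts` timesteps have been written since the last wrap.
        self._added_wrap += count
        if self._added_wrap >= self.capacity_ts:
            self._added_wrap, self._next_idx = 0, 0
        else:
            self._next_idx += 1
        # New/unspecified samples get max priority so they are sampled soon.
        w = self.max_priority if weight is None else weight
        self.priorities[idx] = w ** self.alpha

buf = TinyCircularBuffer(capacity_ts=6)
for i in range(5):
    buf.add(f"batch{i}", count=2)
print(buf.storage, [round(p, 2) for p in buf.priorities])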
Example #10
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of one or more timesteps, a sequence, or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Update our timesteps counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        # Update add counts.
        self._num_add_calls += 1

        if self._num_timesteps_added < self.capacity:
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Eviction of older samples has already started (buffer is "full")
            self._eviction_started = True
            idx = random.randint(0, self._num_add_calls - 1)
            if idx < len(self._storage):
                self._num_evicted += 1
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
                # This is a bit of a hack: ReplayBuffer always inserts at
                # self._next_idx
                self._next_idx = idx
                self._storage[idx] = item

                assert item.count > 0, item
                warn_replay_capacity(item=item,
                                     num_items=self.capacity / item.count)
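
The eviction rule above is reservoir sampling: the i-th add replaces a uniformly drawn slot with probability capacity / i, so every item ever added is equally likely to still be in storage. A minimal standalone sketch of that rule:

import random

def reservoir_add(storage: list, capacity: int, num_add_calls: int, item) -> None:
    """Keep a uniform sample of all items ever added, using at most `capacity` slots."""
    if len(storage) < capacity:
        storage.append(item)
    else:
        # Replace a random slot with probability capacity / num_add_calls.
        idx = random.randint(0, num_add_calls - 1)
        if idx < len(storage):
            storage[idx] = item

random.seed(0)
storage, capacity = [], 3
for i in range(1, 11):          # 10 add calls
    reservoir_add(storage, capacity, i, f"item{i}")
print(storage)                  # a uniform 3-item sample of item1..item10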
Example #11
    def add(self, item: SampleBatchType, weight: float):
        warn_replay_buffer_size(
            item=item, num_items=self._maxsize / item.count)
        assert item.count > 0, item
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit maxsize.
        if self._num_timesteps_added_wrap >= self._maxsize:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0
Example #12
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently, adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multi-agent.
        batch = batch.as_multi_agent()

        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # We need to split batches into timesteps, sequences or episodes
        # here already to properly keep track of self.last_added_batches.
        # The underlying buffers should not split up the batch any further.
        with self.add_batch_timer:
            if self._storage_unit == StorageUnit.TIMESTEPS:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        self.replay_buffers[policy_id].add(
                            time_slice, **kwargs)
                        self.last_added_batches[policy_id].append(time_slice)
            elif self._storage_unit == StorageUnit.SEQUENCES:
                timestep_count = 0
                for policy_id, sample_batch in batch.policy_batches.items():
                    for seq_len in sample_batch.get(SampleBatch.SEQ_LENS):
                        start_seq = timestep_count
                        end_seq = timestep_count + seq_len
                        self.replay_buffers[policy_id].add(
                            sample_batch[start_seq:end_seq], **kwargs)
                        self.last_added_batches[policy_id].append(
                            sample_batch[start_seq:end_seq])
                        timestep_count = end_seq
            elif self._storage_unit == StorageUnit.EPISODES:
                for policy_id, sample_batch in batch.policy_batches.items():
                    for eps in sample_batch.split_by_episode():
                        # Only add full episodes to the buffer
                        if (eps.get(SampleBatch.T)[0] == 0 and eps.get(
                                SampleBatch.DONES)[-1] == True  # noqa E712
                            ):
                            self.replay_buffers[policy_id].add(eps, **kwargs)
                            self.last_added_batches[policy_id].append(eps)
                        else:
                            if log_once("only_full_episodes"):
                                logger.info(
                                    "This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")

        self._num_added += batch.count
Example #13
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if not self.local_worker.is_policy_to_train(
                        policy_id, samples):
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we are
                # training asynchronously.
                num_loaded_samples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(
                num_devices=len(self.devices))

            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(samples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0,
                        )

                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id)

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.get_policies_to_train()))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, learner_info
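
Each SGD epoch walks the preloaded train batch in num_batches fixed-size windows in a fresh random order, handing the policy only an integer offset into the loaded buffer. A small sketch of that offset arithmetic (plain function, not the RLlib API):

import numpy as np

def minibatch_offsets(loaded_samples: int, per_device_batch_size: int, num_sgd_iter: int):
    """Yield randomized minibatch offsets into a preloaded train batch."""
    num_batches = max(1, loaded_samples // per_device_batch_size)
    for _ in range(num_sgd_iter):
        for batch_index in np.random.permutation(num_batches):
            yield int(batch_index) * per_device_batch_size

np.random.seed(0)
# E.g. 1000 loaded samples, minibatches of 128, 2 SGD iterations.
offsets = list(minibatch_offsets(1000, 128, 2))
print(offsets)  # 7 offsets per iteration, shuffled each time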
Example #14
 def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
     batch = batch.decompress_if_needed()
     self._mixin_buffer.add_batch(batch)
     processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
     return processed_batches
Example #15
    def _add_to_underlying_buffer(
        self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
    ) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            **kwargs: Forward compatibility kwargs.
        """
        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self._storage_unit is StorageUnit.TIMESTEPS:
            if self.replay_sequence_length == 1:
                timeslices = batch.timeslices(1)
            else:
                timeslices = timeslice_along_seq_lens_with_overlap(
                    sample_batch=batch,
                    zero_pad_max_seq_len=self.replay_sequence_length,
                    pre_overlap=self.replay_burn_in,
                    zero_init_states=self.replay_zero_init_states,
                )
            for time_slice in timeslices:
                # If SampleBatch has prio-replay weights, average
                # over these to use as a weight for the entire
                # sequence.
                if self.replay_mode is ReplayMode.INDEPENDENT:
                    if "weights" in time_slice and len(time_slice["weights"]):
                        weight = np.mean(time_slice["weights"])
                    else:
                        weight = None

                    if "weight" in kwargs and weight is not None:
                        if log_once("overwrite_weight"):
                            logger.warning(
                                "Adding batches with column "
                                "`weights` to this buffer while "
                                "providing weights as a call argument "
                                "to the add method results in the "
                                "column being overwritten."
                            )

                    kwargs = {"weight": weight, **kwargs}
                else:
                    if "weight" in kwargs:
                        if log_once("lockstep_no_weight_allowed"):
                            logger.warning(
                                "Settings weights for batches in "
                                "lockstep mode is not allowed."
                                "Weights are being ignored."
                            )

                    kwargs = {**kwargs, "weight": None}
                self.replay_buffers[policy_id].add(time_slice, **kwargs)
        else:
            self.replay_buffers[policy_id].add(batch, **kwargs)
Example #16
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if policy_id not in self.policies:
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                policy = self.workers.local_worker().get_policy(policy_id)
                policy._debug_vars()
                tuples = policy._get_loss_inputs_dict(
                    batch, shuffle=self.shuffle_sequences)
                data_keys = list(policy._loss_input_dict_no_rnn.values())
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    batch_fetches_all_towers = []
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)

                        batch_fetches_all_towers.append(
                            tree.map_structure_with_path(
                                lambda p, *s: all_tower_reduce(p, *s),
                                *(batch_fetches["tower_{}".format(tower_num)]
                                  for tower_num in range(len(self.devices)))))

                # Reduce mean across all minibatch SGD steps (axis=0 to keep
                # all shapes as-is).
                fetches[policy_id] = tree.map_structure(
                    lambda *s: np.nanmean(s, axis=0),
                    *batch_fetches_all_towers)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = fetches
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches
Example #17
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        #  count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        #  not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
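
A simplified NumPy-only rendering of the docstring example above: split a 12-step array into slices along seq_lens, prepend pre_overlap timesteps from the previous slice (zeros for the first), and right-pad to zero_pad_max_seq_len. It reproduces only the index arithmetic and ignores episode boundaries and RNN state handling:

import numpy as np

def slice_with_overlap(x, seq_lens, pre_overlap, zero_pad_max_seq_len):
    slices, start = [], 0
    for seq_len in seq_lens:
        begin, end = start - pre_overlap, start + seq_len
        if begin < 0:
            # Zero-fill the missing prefix of the very first slice.
            chunk = np.concatenate([np.zeros(-begin, dtype=x.dtype), x[0:end]])
        else:
            chunk = x[begin:end]
        # Right-pad up to the requested max length (which includes the overlap).
        pad = max(0, zero_pad_max_seq_len - len(chunk))
        slices.append(np.concatenate([chunk, np.zeros(pad, dtype=x.dtype)]))
        start += seq_len
    return slices

ts = np.arange(12)  # timesteps 0..11
for s in slice_with_overlap(ts, seq_lens=[5, 5, 2], pre_overlap=3, zero_pad_max_seq_len=10):
    print(s)
# -> 0 0 0 0 1 2 3 4 0 0 / 2 3 4 5 6 7 8 9 0 0 / 7 8 9 10 11 0 0 0 0 0,
#    matching the docstring layout above.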
Example #18
    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if policy_id not in self.local_worker.policies_to_train:
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we are
                # training asynchronously.
                num_loaded_tuples[policy_id] = self.local_worker.policy_map[
                    policy_id].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                batch_fetches_all_towers = []
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        batch_fetches = policy.learn_on_loaded_batch(
                            permutation[batch_index] *
                            self.per_device_batch_size,
                            buffer_index=0)

                        # No towers: Single CPU.
                        if "tower_0" not in batch_fetches:
                            batch_fetches_all_towers.append(batch_fetches)
                        else:
                            batch_fetches_all_towers.append(
                                tree.map_structure_with_path(
                                    lambda p, *s: all_tower_reduce(p, *s),
                                    *(batch_fetches["tower_{}".format(
                                        tower_num)]
                                      for tower_num in range(len(self.devices))
                                      )))

                # Reduce mean across all minibatch SGD steps (axis=0 to keep
                # all shapes as-is).
                fetches[policy_id] = tree.map_structure(
                    lambda *s: np.nanmean(s, axis=0),
                    *batch_fetches_all_towers)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = fetches

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.local_worker.policies_to_train))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches
Example #19
    def _add_to_underlying_buffer(self, policy_id: PolicyID,
                                  batch: SampleBatchType, **kwargs) -> None:
        """Add a batch of experiences to the underlying buffer of a policy.

        If the storage unit is `timesteps`, cut the batch into timeslices
        before adding them to the appropriate buffer. Otherwise, let the
        underlying buffer decide how to slice batches.

        Args:
            policy_id: ID of the policy that corresponds to the underlying
                buffer.
            batch: SampleBatch to add to the underlying buffer.
            ``**kwargs``: Forward compatibility kwargs.
        """
        # Merge kwargs, overwriting standard call arguments
        kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args,
                                          kwargs)

        # For the storage unit `timesteps`, the underlying buffer will
        # simply store the samples how they arrive. For sequences and
        # episodes, the underlying buffer may split them itself.
        if self.storage_unit is StorageUnit.TIMESTEPS:
            timeslices = batch.timeslices(1)
        elif self.storage_unit is StorageUnit.SEQUENCES:
            timeslices = timeslice_along_seq_lens_with_overlap(
                sample_batch=batch,
                seq_lens=batch.get(SampleBatch.SEQ_LENS)
                if self.replay_sequence_override else None,
                zero_pad_max_seq_len=self.replay_sequence_length,
                pre_overlap=self.replay_burn_in,
                zero_init_states=self.replay_zero_init_states,
            )
        elif self.storage_unit == StorageUnit.EPISODES:
            timeslices = []
            for eps in batch.split_by_episode():
                if (eps.get(SampleBatch.T)[0] == 0
                        and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                    ):
                    # Only add full episodes to the buffer
                    timeslices.append(eps)
                else:
                    if log_once("only_full_episodes"):
                        logger.info("This buffer uses episodes as a storage "
                                    "unit and thus allows only full episodes "
                                    "to be added to it. Some samples may be "
                                    "dropped.")
        elif self.storage_unit == StorageUnit.FRAGMENTS:
            timeslices = [batch]
        else:
            raise ValueError("Unknown `storage_unit={}`".format(
                self.storage_unit))

        for slice in timeslices:
            # If SampleBatch has prio-replay weights, average
            # over these to use as a weight for the entire
            # sequence.
            if self.replay_mode is ReplayMode.INDEPENDENT:
                if "weights" in slice and len(slice["weights"]):
                    weight = np.mean(slice["weights"])
                else:
                    weight = None

                if "weight" in kwargs and weight is not None:
                    if log_once("overwrite_weight"):
                        logger.warning("Adding batches with column "
                                       "`weights` to this buffer while "
                                       "providing weights as a call argument "
                                       "to the add method results in the "
                                       "column being overwritten.")

                kwargs = {"weight": weight, **kwargs}
            else:
                if "weight" in kwargs:
                    if log_once("lockstep_no_weight_allowed"):
                        logger.warning("Settings weights for batches in "
                                       "lockstep mode is not allowed."
                                       "Weights are being ignored.")

                kwargs = {**kwargs, "weight": None}
            self.replay_buffers[policy_id].add(slice, **kwargs)
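
merge_dicts_with_warning() folds the buffer's standing call arguments into the per-call kwargs before each underlying add(). A plausible minimal rendering of such a helper, purely as an assumption about its behavior (not RLlib's actual implementation):

import logging

logger = logging.getLogger(__name__)

def merge_dicts_with_warning_sketch(defaults: dict, overrides: dict) -> dict:
    # Per-call overrides win; overlapping keys are reported.
    overlap = set(defaults) & set(overrides)
    if overlap:
        logger.warning("Overriding standing call args: %s", sorted(overlap))
    return {**defaults, **overrides}

print(merge_dicts_with_warning_sketch({"weight": 1.0}, {"weight": 0.5, "extra": 1}))
# {'weight': 0.5, 'extra': 1}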