Example #1
    def add(self, samples: SampleBatch):
        """Add a SampleBatch to storage.

        Optimized to avoid several queries for large sample batches.

        Args:
            samples: The sample batch
        """
        if samples.count >= self._maxsize:
            samples = samples.slice(samples.count - self._maxsize, None)
            end_idx = 0
            assign = [(slice(0, self._maxsize), samples)]
        else:
            start_idx = self._next_idx
            end_idx = (self._next_idx + samples.count) % self._maxsize
            if end_idx < start_idx:
                tailcount = self._maxsize - start_idx
                assign = [
                    (slice(start_idx, None), samples.slice(0, tailcount)),
                    (slice(end_idx), samples.slice(tailcount, None)),
                ]
            else:
                assign = [(slice(start_idx, end_idx), samples)]

        for field in self.fields:
            for slc, smp in assign:
                self._storage[field.name][slc] = smp[field.name]

        self._next_idx = end_idx
        self._curr_size = min(self._curr_size + samples.count, self._maxsize)
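The wrap-around write above is the core of a fixed-size ring buffer: when a write would run past the end of the storage, it is split into a tail slice and a head slice. A minimal numpy sketch of that split, with hypothetical sizes and a single field-free storage array (illustrative only, not the library's code):

import numpy as np

# Writing 4 new rows into a capacity-5 ring buffer starting at index 3.
maxsize, next_idx = 5, 3
storage = np.zeros(maxsize)
new = np.array([10.0, 11.0, 12.0, 13.0])

end_idx = (next_idx + len(new)) % maxsize  # == 2, i.e. the write wraps around
tailcount = maxsize - next_idx             # rows that still fit at the end
storage[next_idx:] = new[:tailcount]       # fills indices 3 and 4
storage[:end_idx] = new[tailcount:]        # fills indices 0 and 1
# storage is now [12., 13., 0., 10., 11.] and the next write starts at index 2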
Example #2
    def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        batch_count = batch.policy_batches[self.policy_id_to_count_for].count
        if self.drop_samples_for_other_agents:
            batch = MultiAgentBatch(
                policy_batches={
                    self.policy_id_to_count_for: batch.policy_batches[
                        self.policy_id_to_count_for]
                },
                env_steps=batch.policy_batches[
                    self.policy_id_to_count_for].count)

        self.buffer.append(batch)
        self.count += batch_count

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info("Collected more training samples than expected "
                            "(actual={}, expected={}). ".format(
                                self.count, self.min_batch_size) +
                            "This may be because you have many workers or "
                            "long episodes in 'complete_episodes' batch mode.")
            out = SampleBatch.concat_samples(self.buffer)
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
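The operator above follows a simple accumulate-then-concatenate pattern. A self-contained sketch of the same idea, using plain numpy arrays in place of SampleBatch objects and a hypothetical ConcatUntil class (illustrative only):

import numpy as np

class ConcatUntil:
    """Buffer incoming batches; emit one concatenated batch once enough rows arrived."""

    def __init__(self, min_batch_size: int):
        self.min_batch_size = min_batch_size
        self.buffer, self.count = [], 0

    def __call__(self, batch: np.ndarray) -> list:
        self.buffer.append(batch)
        self.count += len(batch)
        if self.count < self.min_batch_size:
            return []            # not enough rows yet
        out = np.concatenate(self.buffer)
        self.buffer, self.count = [], 0
        return [out]

collect = ConcatUntil(min_batch_size=8)
assert collect(np.arange(5)) == []
(out,) = collect(np.arange(5))   # 10 rows collected >= 8, so one batch is emitted
assert len(out) == 10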
Example #3
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        if self.done:
            # Warmup phase done, simply return batch
            return [batch]

        metrics = _get_shared_metrics()
        timesteps_total = metrics.counters[STEPS_SAMPLED_COUNTER]
        self.buffer.append(batch)
        self.count += batch.count
        assert self.count == timesteps_total

        if timesteps_total < self.learning_starts:
            # Return empty list if still in warmup
            return []

        # Warmup just done
        if self.count > self.learning_starts * 2:
            logger.info(  # pylint:disable=logging-fstring-interpolation
                "Collected more training samples than expected "
                f"(actual={self.count}, expected={self.learning_starts}). "
                "This may be because you have many workers or "
                "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        self.buffer = []
        self.count = 0
        self.done = True
        return [out]
Example #4
    def replay(self) -> SampleBatchType:
        """If this buffer was given a fake batch, return it; otherwise, return
        a MultiAgentBatch with samples.
        """
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)

        if self.num_added < self.replay_starts:
            return None
        with self.replay_timer:
            # Lockstep mode: sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
Example #5
 def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
     """Transition batch corresponding with the given indexes."""
     batch = {
         k: self._storage[k][idxes]
         for k in (f.name for f in self.fields)
     }
     return SampleBatch(batch)
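Because each field is stored as one contiguous array, sampling by index is a single fancy-indexing operation per field. A tiny sketch with made-up fields (not the library's storage layout):

import numpy as np

storage = {
    "obs": np.arange(10).reshape(5, 2),      # 5 rows of 2-dim observations
    "rewards": np.linspace(0.0, 1.0, 5),
}
idxes = np.array([0, 3, 3])                  # indices may repeat
batch = {k: v[idxes] for k, v in storage.items()}
assert batch["obs"].shape == (3, 2) and batch["rewards"].shape == (3,)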
Example #6
    def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
        """If this buffer was given a fake batch, return it, otherwise return
        a MultiAgentBatch with samples.
        """
        if self._fake_batch:
            if not isinstance(self._fake_batch, MultiAgentBatch):
                self._fake_batch = SampleBatch(
                    self._fake_batch).as_multi_agent()
            return self._fake_batch

        if self.num_added < self.replay_starts:
            return None
        with self.replay_timer:
            # Lockstep mode: sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `locksetp` mode!"
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            elif policy_id is not None:
                return self.replay_buffers[policy_id].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
Example #7
 def aggregate_into_larger_batch():
     if (sum(b.count for b in self.batch_being_built) >=
             self.config["train_batch_size"]):
         batch_to_add = SampleBatch.concat_samples(
             self.batch_being_built)
         self.batches_to_place_on_learner.append(batch_to_add)
         self.batch_being_built = []
Example #8
    def improve_policy(self, num_improvements: int) -> Dict[str, float]:
        """Call the policy to perform policy improvement using the augmented replay.

        Args:
            num_improvements: Number of times to call `policy.learn_on_batch`

        Returns:
            A dictionary of training and exploration statistics
        """
        policy = self.get_policy()
        batch_size = self.config["train_batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        stats = {}
        for _ in range(num_improvements):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            stats = get_learner_stats(policy.learn_on_batch(batch))
            self.tracker.num_steps_trained += batch.count

        stats.update(policy.get_exploration_info())
        return stats
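A quick worked example of the real-vs-virtual split computed above, assuming a train batch size of 256 and `real_data_ratio=0.1` (hypothetical config values):

batch_size = 256
real_data_ratio = 0.1
env_batch_size = int(batch_size * real_data_ratio)  # 25 environment transitions
model_batch_size = batch_size - env_batch_size      # 231 model-generated transitions
assert env_batch_size + model_batch_size == batch_size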
Example #9
    def generate_virtual_sample_batch(self,
                                      samples: SampleBatch) -> SampleBatch:
        """Rollout model with latest policy.

        Produces samples for populating the virtual buffer, hence no gradient
        information is retained.

        If a transition is terminal, the next transition, if any, restarts from
        the corresponding initial observation in `samples`.

        Args:
            samples: the transitions to extract initial states from

        Returns:
            A batch of transitions sampled from the model
        """
        virtual_samples = []
        obs = init_obs = self.convert_to_tensor(samples[SampleBatch.CUR_OBS])

        rollout_length = round(self.rollout_schedule(self.global_timestep))
        for _ in range(rollout_length):
            model = self.rng.choice(self.elite_models)

            action, _ = self.module.actor.sample(obs)
            next_obs, _ = model.sample(model(obs, action))
            reward = self.reward_fn(obs, action, next_obs)
            done = self.termination_fn(obs, action, next_obs)

            transition = {
                SampleBatch.CUR_OBS: obs,
                SampleBatch.ACTIONS: action,
                SampleBatch.NEXT_OBS: next_obs,
                SampleBatch.REWARDS: reward,
                SampleBatch.DONES: done,
            }
            virtual_samples += [
                SampleBatch(
                    {k: v.cpu().numpy()
                     for k, v in transition.items()})
            ]
            obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)

        return SampleBatch.concat_samples(virtual_samples)
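The final `torch.where` is what keeps terminated rollouts going: rows whose episode ended are reset to their initial observation while the rest continue from the model's prediction. A standalone sketch with hypothetical shapes:

import torch

batch_size, obs_dim = 4, 3
init_obs = torch.zeros(batch_size, obs_dim)         # stand-in initial observations
next_obs = torch.randn(batch_size, obs_dim)         # stand-in model predictions
done = torch.tensor([False, True, False, True])     # per-row termination flags

obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)
# Rows 1 and 3 are reset to init_obs; rows 0 and 2 continue from next_obs.
assert torch.equal(obs[1], init_obs[1]) and torch.equal(obs[0], next_obs[0])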
Example #10
def fake_batch(obs_space, action_space, batch_size=1):
    """Create a fake SampleBatch compatible with Policy.learn_on_batch."""
    samples = {
        SampleBatch.CUR_OBS: fake_space_samples(obs_space, batch_size),
        SampleBatch.ACTIONS: fake_space_samples(action_space, batch_size),
        SampleBatch.REWARDS: np.random.randn(batch_size),
        SampleBatch.NEXT_OBS: fake_space_samples(obs_space, batch_size),
        SampleBatch.DONES: np.random.randn(batch_size) > 0,
    }
    return SampleBatch(samples)
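`fake_space_samples` is not shown here; a plausible stand-in (an assumption, not the project's actual helper) simply stacks independent draws from the space's own sampler:

import numpy as np
from gymnasium import spaces   # the original code may use the older `gym` package

def fake_space_samples(space: spaces.Space, batch_size: int) -> np.ndarray:
    """Stack `batch_size` independent samples drawn from `space`."""
    return np.stack([space.sample() for _ in range(batch_size)])

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
assert fake_space_samples(obs_space, batch_size=4).shape == (4, 3)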
Example #11
    def transition_dataset(trajs: list[SampleBatch]) -> TensorDataset:
        """Convert a list of trajectories into a transition tensor dataset."""
        transitions = SampleBatch.concat_samples(trajs)

        dataset = TensorDataset(
            torch.from_numpy(transitions[SampleBatch.CUR_OBS]),
            torch.from_numpy(transitions[SampleBatch.ACTIONS]),
            torch.from_numpy(transitions[SampleBatch.NEXT_OBS]),
        )
        assert len(dataset) == transitions.count
        return dataset
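A usage sketch for the resulting TensorDataset: it plugs directly into a standard DataLoader for minibatched model training (shapes below are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(
    torch.randn(100, 4),   # observations
    torch.randn(100, 2),   # actions
    torch.randn(100, 4),   # next observations
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for obs, act, next_obs in loader:
    assert obs.shape[0] <= 32   # the last minibatch may be smaller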
Example #12
 def _train_dual_policies(self, samples: SampleBatch):
     learner_stats = {"learner_stats": {}}
     for policy_n, policy in enumerate(self.algorithms):
         if policy_n in self.DUAL_POLICIES:
             logger.debug(f"train policy {policy}")
             samples_copy = samples.copy()
             samples_copy = self._modify_batch_for_policy(
                 policy_n, samples_copy)
             learner_stats_one_policy = policy.learn_on_batch(samples_copy)
             learner_stats["learner_stats"][
                 f"algo{policy_n}"] = learner_stats_one_policy
     return learner_stats
Example #13
def group_batch_episodes(samples: SampleBatch) -> SampleBatch:
    """Return the sample batch with rows grouped by episode id.

    Moreover, rows are sorted by timestep.

    Warning:
        Modifies the sample batch in-place
    """
    # Assume "t" is the timestep key in the sample batch
    sorted_timestep_idxs = np.argsort(samples["t"])
    for key, val in samples.items():
        samples[key] = val[sorted_timestep_idxs]

    # Stable sort is important so that we don't alter the order
    # of timesteps
    sorted_episode_idxs = np.argsort(samples[SampleBatch.EPS_ID],
                                     kind="stable")
    for key, val in samples.items():
        samples[key] = val[sorted_episode_idxs]

    return samples
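The reason the second sort must be stable: after the first sort, rows are ordered by timestep, and a stable sort on episode id then groups episodes together without disturbing that timestep order. A small numpy demonstration with made-up data:

import numpy as np

t = np.array([1, 0, 1, 0])          # timesteps
eps_id = np.array([7, 5, 5, 7])     # episode ids

by_t = np.argsort(t)                # order rows by timestep first
t, eps_id = t[by_t], eps_id[by_t]

by_eps = np.argsort(eps_id, kind="stable")   # group by episode, keep t order
assert eps_id[by_eps].tolist() == [5, 5, 7, 7]
assert t[by_eps].tolist() == [0, 1, 0, 1]    # timesteps still ascending per episode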
Example #14
 def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
     _check_sample_batch_type(batch)
     self.buffer.append(batch)
     self.count += 1
     if self.count >= self.num_episodes:
         out = SampleBatch.concat_samples(self.buffer)
         timer = _get_shared_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []
Example #15
    def update_policy(self, times: int) -> StatDict:
        batch_size = self.config["batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        for _ in range(times):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            batch = self.lazy_tensor_dict(batch)
            info = self.improve_policy(batch)

        return info
Example #16
    def _learn_on_policy(self, samples: SampleBatch) -> dict:
        """Update on-policy components."""
        batch = self.lazy_tensor_dict(samples)
        episodes = [
            self.lazy_tensor_dict(s) for s in samples.split_by_episode()
        ]

        with self.optimizers.optimize("on_policy"):
            loss, info = self.loss_actor(episodes)
            kl_div = self._avg_kl_divergence(batch)
            loss = loss + kl_div * self.curr_kl_coeff
            loss.backward()

        info.update(self.extra_grad_info(batch, on_policy=True))
        info.update(self.update_kl_coeff(samples))
        return info
Example #17
    def sample(
            self,
            num_items: int,
            policy_id: Optional[PolicyID] = None) -> Optional[SampleBatchType]:
        """Samples a batch of size `num_items` from a policy's buffer

        If this buffer was given a fake batch, return it; otherwise,
        return a MultiAgentBatch with samples. If fewer than `num_items`
        records are in the policy's buffer, some samples in
        the results may be repeated to fulfil the batch size (`num_items`)
        request.

        Args:
            num_items: Number of items to sample from a policy's buffer.
            policy_id: ID of the policy that created the experiences we sample

        Returns:
            Concatenated batch of items. None if buffer is empty.
        """
        if self._fake_batch:
            if not isinstance(self._fake_batch, MultiAgentBatch):
                self._fake_batch = SampleBatch(
                    self._fake_batch).as_multi_agent()
            return self._fake_batch

        if self._num_added < self.replay_starts:
            return None
        with self.replay_timer:
            # Lockstep mode: sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `locksetp` mode!"
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            elif policy_id is not None:
                return self.replay_buffers[policy_id].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)
Example #18
def test_getitem(numpy_replay: NumpyReplayBuffer, sample_batch: SampleBatch,
                 idx):
    replay = numpy_replay

    batch = replay[idx]
    assert isinstance(batch, dict)
    assert all([
        np.allclose(batch[k], sample_batch[k][idx])
        for k in sample_batch.keys()
    ])

    mean = np.mean(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std = np.std(sample_batch[SampleBatch.CUR_OBS], axis=0)
    replay.update_obs_stats()
    batch = replay[idx]
    for key in SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS:
        expected = (sample_batch[key][idx] - mean) / (std + 1e-7)
        assert np.allclose(batch[key], expected)
Example #19
def test_getitem(filled_replay: NumpyReplayBuffer, sample_batch: SampleBatch,
                 idx):
    replay = filled_replay

    batch = replay[idx]
    assert isinstance(batch, dict)
    assert all([
        np.allclose(batch[k], sample_batch[k][idx])
        for k in sample_batch.keys()
    ])

    mean = np.mean(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std = np.std(sample_batch[SampleBatch.CUR_OBS], axis=0)
    std[std < 1e-12] = 1.0

    replay.compute_stats = True
    batch = replay[idx]
    for key in SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS:
        expected = (sample_batch[key][idx] - mean) / std
        assert np.allclose(batch[key], expected)
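The expected value in this test is a feature-wise standardization with the standard deviation floored at 1.0, so that constant features are left unscaled rather than divided by (nearly) zero. A small numpy sketch of that normalization with made-up observations:

import numpy as np

obs = np.array([[0.0, 1.0],
                [2.0, 1.0],
                [4.0, 1.0]])
mean = obs.mean(axis=0)            # [2.0, 1.0]
std = obs.std(axis=0)              # second feature is constant: std == 0.0
std[std < 1e-12] = 1.0             # floor tiny deviations to avoid dividing by zero
normalized = (obs - mean) / std
assert np.allclose(normalized[:, 1], 0.0)   # constant feature maps to zeros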
Example #20
 def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
     """Sample a batch of experiences corresponding to the given indexes."""
     self._num_sampled += len(idxes)
     data = self._encode_sample(idxes)
     return SampleBatch(dict(zip([f.name for f in self.fields], data)))
Example #21
 def sample(self, batch_size: int) -> SampleBatch:
     """Transition batch uniformly sampled with replacement."""
     return SampleBatch(self[self.sample_idxes(batch_size)])
Example #22
 def all_samples(self) -> SampleBatch:
     """All stored transitions."""
     return SampleBatch(self[:len(self)])
Example #23
 def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
     self._num_sampled += len(idxes)
     data = self._encode_sample(idxes)
     return SampleBatch(dict(zip([f.name for f in self.fields], data)))
Example #24
 def all_samples(self) -> SampleBatch:
     """All stored transitions."""
     return SampleBatch({
         k: self._storage[k][:len(self)]
         for k in (f.name for f in self.fields)
     })
Example #25
    def _initialize_loss(self):
        def fake_array(tensor, none_shape):
            shape = tensor.shape.as_list()
            non_none_shape = [s for s in shape if s is not None]
            none_shape = none_shape if isinstance(none_shape, list) else [none_shape]
            shape = none_shape + non_none_shape
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        T = self.config["model"]["max_seq_len"]
        B = self.config["train_batch_size"] // T
        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input, B * T),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input, B * T),
            SampleBatch.DONES: np.array([False] * B * T, dtype=bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space), B * T
            ),
            SampleBatch.REWARDS: np.array([0] * B * T, dtype=np.float32),
            SampleBatch.INFOS: np.array([self.sample_info] * B * T),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update(
                {
                    SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input, B * T),
                    SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input, B * T),
                }
            )

        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.repeat(
                np.expand_dims(h, 0), B * T, 0
            )
            dummy_batch["state_out_{}".format(i)] = np.repeat(
                np.expand_dims(h, 0), B * T, 0
            )
            state_batches.append(np.repeat(np.expand_dims(h, 0), B * T, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([T] * B * T, dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v, B * T)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())

        postprocessed_batch = self.postprocess_trajectory(SampleBatch(dummy_batch))

        # model forward pass for the loss (needed after postprocess to
        # overwrite any tensor state from that call)
        self.model(self._input_dict, self._state_in, self._seq_lens)

        if self._obs_include_prev_action_reward:
            train_batch = UsageTrackingDict(
                {
                    SampleBatch.PREV_ACTIONS: self._prev_action_input,
                    SampleBatch.PREV_REWARDS: self._prev_reward_input,
                    SampleBatch.CUR_OBS: self._obs_input,
                }
            )
            loss_inputs = [
                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
                (SampleBatch.CUR_OBS, self._obs_input),
            ]
        else:
            train_batch = UsageTrackingDict({SampleBatch.CUR_OBS: self._obs_input})
            loss_inputs = [
                (SampleBatch.CUR_OBS, self._obs_input),
            ]

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None,) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_in):
            train_batch["state_in_{}".format(i)] = si
        train_batch["seq_lens"] = self._seq_lens

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        self._loss_input_dict = train_batch
        loss = self._do_loss_init(train_batch)
        for k in sorted(train_batch.accessed_keys):
            if k != "seq_lens" and not k.startswith("state_in_"):
                loss_inputs.append((k, train_batch[k]))

        TFPolicy._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads)
            )
        self._sess.run(tf.global_variables_initializer())
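For reference, a numpy-only sketch of what `fake_array` accomplishes above (assumed behavior, with explicit shape/dtype arguments instead of a TF placeholder): the unknown batch dimension is replaced by a concrete size and the array is filled with zeros of the matching dtype.

import numpy as np

def fake_array_like(shape, batch_size, dtype=np.float32):
    """Build a zero-filled dummy array, replacing None dims with `batch_size`."""
    concrete = [batch_size] + [s for s in shape if s is not None]
    return np.zeros(concrete, dtype=dtype)

dummy_obs = fake_array_like(shape=[None, 4], batch_size=32)
assert dummy_obs.shape == (32, 4) and dummy_obs.dtype == np.float32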