Example #1
    def aggregate_into_larger_batch():
        # Nested helper: uses `self` from the enclosing method's scope.
        if (sum(b.count for b in self.batch_being_built) >=
                self.config["train_batch_size"]):
            batch_to_add = SampleBatch.concat_samples(
                self.batch_being_built)
            self.batches_to_place_on_learner.append(batch_to_add)
            self.batch_being_built = []
Example #2
    def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        batch_count = batch.policy_batches[self.policy_id_to_count_for].count
        if self.drop_samples_for_other_agents:
            batch = MultiAgentBatch(
                policy_batches={
                    self.policy_id_to_count_for:
                        batch.policy_batches[self.policy_id_to_count_for]
                },
                env_steps=batch.policy_batches[
                    self.policy_id_to_count_for].count)

        self.buffer.append(batch)
        self.count += batch_count

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info("Collected more training samples than expected "
                            "(actual={}, expected={}). ".format(
                                self.count, self.min_batch_size) +
                            "This may be because you have many workers or "
                            "long episodes in 'complete_episodes' batch mode.")
            out = SampleBatch.concat_samples(self.buffer)
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
Example #3
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        if self.done:
            # Warmup phase done, simply return batch
            return [batch]

        metrics = _get_shared_metrics()
        timesteps_total = metrics.counters[STEPS_SAMPLED_COUNTER]
        self.buffer.append(batch)
        self.count += batch.count
        assert self.count == timesteps_total

        if timesteps_total < self.learning_starts:
            # Return an empty list while still in warmup
            return []

        # Warmup just done
        if self.count > self.learning_starts * 2:
            logger.info(  # pylint:disable=logging-fstring-interpolation
                "Collected more training samples than expected "
                f"(actual={self.count}, expected={self.learning_starts}). "
                "This may be because you have many workers or "
                "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        self.buffer = []
        self.count = 0
        self.done = True
        return [out]
Example #4
    def improve_policy(self, num_improvements: int) -> Dict[str, float]:
        """Call the policy to perform policy improvement using the augmented replay.

        Args:
            num_improvements: Number of times to call `policy.learn_on_batch`

        Returns:
            A dictionary of training and exploration statistics
        """
        policy = self.get_policy()
        batch_size = self.config["train_batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        stats = {}
        for _ in range(num_improvements):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            stats = get_learner_stats(policy.learn_on_batch(batch))
            self.tracker.num_steps_trained += batch.count

        stats.update(policy.get_exploration_info())
        return stats
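A quick worked example of the real/model data split above (values are illustrative assumptions, not taken from the source config): with a 256-transition train batch and real_data_ratio = 0.25, a quarter of each training batch is drawn from the real replay buffer and the remainder from the virtual, model-generated replay.

    # Illustrative values only (assumed for this sketch).
    batch_size = 256
    real_data_ratio = 0.25
    env_batch_size = int(batch_size * real_data_ratio)    # 64 real transitions
    model_batch_size = batch_size - env_batch_size         # 192 model-generated transitions
    assert env_batch_size + model_batch_size == batch_size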
Example #5
    def transition_dataset(trajs: list[SampleBatch]) -> TensorDataset:
        """Convert a list of trajectories into a transition tensor dataset."""
        transitions = SampleBatch.concat_samples(trajs)

        dataset = TensorDataset(
            torch.from_numpy(transitions[SampleBatch.CUR_OBS]),
            torch.from_numpy(transitions[SampleBatch.ACTIONS]),
            torch.from_numpy(transitions[SampleBatch.NEXT_OBS]),
        )
        assert len(dataset) == transitions.count
        return dataset
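A minimal usage sketch (assuming the standard torch.utils.data API; `trajs` is a hypothetical list of previously collected SampleBatch trajectories, and the function is called as a plain function for illustration): the returned TensorDataset yields (obs, action, next_obs) tuples, so it can be wrapped in a DataLoader for minibatch iteration, e.g. when fitting a dynamics model.

    from torch.utils.data import DataLoader

    # `trajs` is assumed to be a list of SampleBatch trajectories collected elsewhere.
    dataset = transition_dataset(trajs)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    for obs, action, next_obs in loader:
        pass  # e.g., regress next_obs from (obs, action) when fitting a dynamics model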
Example #6
    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        self.buffer.append(batch)
        # Each appended batch counts as one episode toward `num_episodes`.
        self.count += 1
        if self.count >= self.num_episodes:
            out = SampleBatch.concat_samples(self.buffer)
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []
Example #7
    def update_policy(self, times: int) -> StatDict:
        batch_size = self.config["batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        info = {}  # Ensure `info` is defined even if `times` == 0.
        for _ in range(times):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            batch = self.lazy_tensor_dict(batch)
            info = self.improve_policy(batch)

        return info
Example #8
    def generate_virtual_sample_batch(self,
                                      samples: SampleBatch) -> SampleBatch:
        """Rollout model with latest policy.

        Produces samples for populating the virtual buffer, hence no gradient
        information is retained.

        When a transition is terminal, the rollout for that row restarts from the
        corresponding initial state passed in via `samples`.

        Args:
            samples: the transitions to extract initial states from

        Returns:
            A batch of transitions sampled from the model
        """
        virtual_samples = []
        obs = init_obs = self.convert_to_tensor(samples[SampleBatch.CUR_OBS])

        rollout_length = round(self.rollout_schedule(self.global_timestep))
        for _ in range(rollout_length):
            model = self.rng.choice(self.elite_models)

            action, _ = self.module.actor.sample(obs)
            next_obs, _ = model.sample(model(obs, action))
            reward = self.reward_fn(obs, action, next_obs)
            done = self.termination_fn(obs, action, next_obs)

            transition = {
                SampleBatch.CUR_OBS: obs,
                SampleBatch.ACTIONS: action,
                SampleBatch.NEXT_OBS: next_obs,
                SampleBatch.REWARDS: reward,
                SampleBatch.DONES: done,
            }
            virtual_samples += [
                SampleBatch(
                    {k: v.cpu().numpy()
                     for k, v in transition.items()})
            ]
            obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)

        return SampleBatch.concat_samples(virtual_samples)
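The torch.where(done.unsqueeze(-1), init_obs, next_obs) line above is what restarts finished rollouts: rows whose episode terminated continue from their original initial observation, while the remaining rows continue from the model's predicted next observation. A small self-contained sketch of that broadcast, with toy shapes chosen purely for illustration:

    import torch

    init_obs = torch.zeros(3, 2)               # initial observations (batch of 3, obs dim 2)
    next_obs = torch.ones(3, 2)                # model-predicted next observations
    done = torch.tensor([False, True, False])  # per-row termination flags

    # Terminated rows fall back to their initial observation; the rest keep next_obs.
    obs = torch.where(done.unsqueeze(-1), init_obs, next_obs)
    # obs == [[1., 1.], [0., 0.], [1., 1.]]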