def compute_next_values(self, next_states):
        """
            Compute Q(S, B) for every next state and every discretised budget
            with a single batched forward pass through the value network.

            S: set of states
            B: set of budgets (discretised)
        :param next_states: batch of next states; assumed shape [batch, state_size] — TODO confirm
        :return: Q values at all (next state, budget) pairs, as a numpy array
        """
        logger.debug("-Forward pass")
        # Build the cartesian product sb of all next states s with all budgets b:
        # each state is repeated once per budget (repeat along dim 1, then view),
        # while the budget column is tiled across all states.
        n_betas = len(self.betas_for_discretisation)
        ss = next_states.squeeze().repeat((1, n_betas)) \
            .view((len(next_states) * n_betas, self._value_network.size_state))
        bb = torch.from_numpy(self.betas_for_discretisation).float().unsqueeze(1).to(device=self.device)
        bb = bb.repeat((len(next_states), 1))
        sb = torch.cat((ss, bb), dim=1).unsqueeze(1)

        # To avoid spikes in memory, we actually split the batch in several minibatches.
        # torch.split slices along dim 0 according to the bin sizes, replacing the
        # quadratic sum(batch_sizes[:i]) prefix-sum indexing of a manual loop.
        batch_sizes = near_split(x=len(sb), num_bins=self.config["split_batches"])
        q_values = []
        for mini_batch in torch.split(sb, batch_sizes):
            q_values.append(self._value_network(mini_batch))
            torch.cuda.empty_cache()  # release cached GPU memory between minibatches
        return torch.cat(q_values).detach().cpu().numpy()
# "Esempio n. 2" / "0" — caption and vote count left over from the original
# source listing (extraction artifact); kept here as a comment so the file parses.
    def run_batched_episodes(self):
        """
            Alternatively,
            - run multiple sample-collection jobs in parallel
            - update model

            Each batch: save the current agent, fan sample collection out to
            worker processes, merge their trajectories into the replay memory,
            then fit the model once on the gathered experience.
        """
        episode = 0
        episode_duration = 14  # TODO: use a fixed number of samples instead
        # Split the total number of samples into batches of the configured size.
        batch_sizes = near_split(self.num_episodes * episode_duration,
                                 size_bins=self.agent.config["batch_size"])
        self.agent.reset()
        for batch, batch_size in enumerate(batch_sizes):
            logger.info(
                "[BATCH={}/{}]---------------------------------------".format(
                    batch + 1, len(batch_sizes)))
            logger.info(
                "[BATCH={}/{}][run_batched_episodes] #samples={}".format(
                    batch + 1, len(batch_sizes), len(self.agent.memory)))
            logger.info(
                "[BATCH={}/{}]---------------------------------------".format(
                    batch + 1, len(batch_sizes)))
            # Save current agent so the workers can reload it from disk.
            model_path = self.save_agent_model(identifier=batch)

            # Prepare workers: each collects a near-equal share of this batch's
            # samples, with a distinct seed and a distinct global start offset.
            env_config, agent_config = serialize(self.env), serialize(
                self.agent)
            cpu_processes = self.agent.config["processes"] or os.cpu_count()
            workers_sample_counts = near_split(batch_size, cpu_processes)
            # Global start index of each worker's slice within the whole run:
            # exclusive prefix sum of this batch's counts, shifted by the
            # samples already collected in previous batches.
            workers_starts = list(
                np.cumsum(np.insert(workers_sample_counts[:-1], 0, 0)) +
                np.sum(batch_sizes[:batch]))
            base_seed = self.seed(batch * cpu_processes)[0]
            workers_seeds = [base_seed + i for i in range(cpu_processes)]
            workers_params = list(
                zip_with_singletons(env_config, agent_config,
                                    workers_sample_counts, workers_starts,
                                    workers_seeds, model_path, batch))

            # Collect trajectories — in-process when a single worker suffices,
            # otherwise via a process pool.
            logger.info("Collecting {} samples with {} workers...".format(
                batch_size, cpu_processes))
            if cpu_processes == 1:
                results = [Evaluation.collect_samples(*workers_params[0])]
            else:
                with Pool(processes=cpu_processes) as pool:
                    results = pool.starmap(Evaluation.collect_samples,
                                           workers_params)
            trajectories = [
                trajectory for worker in results for trajectory in worker
            ]

            # Fill memory
            for trajectory in trajectories:
                # Check whether the episode was properly finished before logging
                if trajectory[-1].terminal:
                    self.after_all_episodes(
                        episode,
                        [transition.reward for transition in trajectory])
                episode += 1
                # Plain loop instead of a side-effect list comprehension,
                # which built and discarded a list of None results.
                for transition in trajectory:
                    self.agent.record(*transition)

            # Fit model
            self.agent.update()