def compute_next_values(self, next_states):
    """
    Compute Q(S, B) with a single forward pass.

    S: set of states
    B: set of budgets (discretised)

    :param next_states: batch of next states
    :return: Q values at the next states
    """
    logger.debug("-Forward pass")
    # Compute the cartesian product sb of all next states s with all budgets b
    ss = next_states.squeeze().repeat((1, len(self.betas_for_discretisation))) \
        .view((len(next_states) * len(self.betas_for_discretisation), self._value_network.size_state))
    bb = torch.from_numpy(self.betas_for_discretisation).float().unsqueeze(1).to(device=self.device)
    bb = bb.repeat((len(next_states), 1))
    sb = torch.cat((ss, bb), dim=1).unsqueeze(1)

    # To avoid memory spikes, split the batch into several minibatches
    batch_sizes = near_split(x=len(sb), num_bins=self.config["split_batches"])
    q_values = []
    for minibatch in range(self.config["split_batches"]):
        mini_batch = sb[sum(batch_sizes[:minibatch]):sum(batch_sizes[:minibatch + 1])]
        q_values.append(self._value_network(mini_batch))
        torch.cuda.empty_cache()
    return torch.cat(q_values).detach().cpu().numpy()
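
# --- Illustrative sketch (added for clarity; not part of the original source) ---
# The pairing of states with budgets above can be hard to follow from the tensor
# ops alone. The standalone helper below reproduces it under the assumption that
# next_states has shape (N, 1, state_size) and there are K discretised budgets:
# each state is repeated K times and concatenated with every budget value,
# yielding an (N*K, 1, state_size + 1) input batch for the value network.
import numpy as np
import torch

def state_budget_product(next_states, betas, state_size):
    """Hypothetical standalone version of the state-budget pairing done above."""
    n, k = len(next_states), len(betas)
    ss = next_states.squeeze().repeat((1, k)).view((n * k, state_size))
    bb = torch.from_numpy(betas).float().unsqueeze(1).repeat((n, 1))
    return torch.cat((ss, bb), dim=1).unsqueeze(1)

# Example: 2 states of size 3 and 4 budget levels give 8 state-budget pairs:
# state_budget_product(torch.rand(2, 1, 3), np.linspace(0, 1, 4), 3).shape
# -> torch.Size([8, 1, 4])
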
def run_batched_episodes(self):
    """
    Alternate between:
    - running multiple sample-collection jobs in parallel
    - updating the model
    """
    episode = 0
    episode_duration = 14  # TODO: use a fixed number of samples instead
    batch_sizes = near_split(self.num_episodes * episode_duration,
                             size_bins=self.agent.config["batch_size"])
    self.agent.reset()
    for batch, batch_size in enumerate(batch_sizes):
        logger.info("[BATCH={}/{}]---------------------------------------".format(batch + 1, len(batch_sizes)))
        logger.info("[BATCH={}/{}][run_batched_episodes] #samples={}".format(batch + 1, len(batch_sizes),
                                                                             len(self.agent.memory)))
        logger.info("[BATCH={}/{}]---------------------------------------".format(batch + 1, len(batch_sizes)))
        # Save the current agent
        model_path = self.save_agent_model(identifier=batch)

        # Prepare workers
        env_config, agent_config = serialize(self.env), serialize(self.agent)
        cpu_processes = self.agent.config["processes"] or os.cpu_count()
        workers_sample_counts = near_split(batch_size, cpu_processes)
        workers_starts = list(np.cumsum(np.insert(workers_sample_counts[:-1], 0, 0)) + np.sum(batch_sizes[:batch]))
        base_seed = self.seed(batch * cpu_processes)[0]
        workers_seeds = [base_seed + i for i in range(cpu_processes)]
        workers_params = list(zip_with_singletons(env_config, agent_config, workers_sample_counts,
                                                  workers_starts, workers_seeds, model_path, batch))

        # Collect trajectories
        logger.info("Collecting {} samples with {} workers...".format(batch_size, cpu_processes))
        if cpu_processes == 1:
            results = [Evaluation.collect_samples(*workers_params[0])]
        else:
            with Pool(processes=cpu_processes) as pool:
                results = pool.starmap(Evaluation.collect_samples, workers_params)
        trajectories = [trajectory for worker in results for trajectory in worker]

        # Fill memory
        for trajectory in trajectories:
            # Check whether the episode was properly finished before logging
            if trajectory[-1].terminal:
                self.after_all_episodes(episode, [transition.reward for transition in trajectory])
            episode += 1
            for transition in trajectory:
                self.agent.record(*transition)

        # Fit model
        self.agent.update()
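
# --- Illustrative sketch (added for clarity; not part of the original source) ---
# Both methods rely on the near_split and zip_with_singletons helpers. The
# implementations below are plausible versions consistent with how they are
# used above (an assumption, not code taken from this source): near_split
# divides a total into near-equal integer bins, and zip_with_singletons zips
# sequences while broadcasting scalar arguments to every worker.
import itertools
import numpy as np

def near_split(x, num_bins=None, size_bins=None):
    """Split x into bins of near-equal size; the bins always sum to x."""
    if num_bins:
        quotient, remainder = divmod(x, num_bins)
        return [quotient + 1] * remainder + [quotient] * (num_bins - remainder)
    elif size_bins:
        return near_split(x, num_bins=int(np.ceil(x / size_bins)))

def zip_with_singletons(*args):
    """Zip arguments, repeating any non-sequence argument as a singleton."""
    return zip(*[arg if isinstance(arg, (list, tuple)) else itertools.repeat(arg)
                 for arg in args])

# Examples:
# near_split(10, num_bins=3)                      -> [4, 3, 3]
# list(zip_with_singletons([64, 64, 64], "cfg"))  -> [(64, 'cfg'), (64, 'cfg'), (64, 'cfg')]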