Beispiel #1
0
 def write_metadata(self):
     """
     Write the serialized environment and agent to a JSON metadata file.

     The file goes into ``self.monitor.directory``. Its name is produced by
     formatting ``self.METADATA_FILE`` with ``"<monitor_id>.<process id>"``.
     Including the process id presumably keeps concurrent processes from
     overwriting each other's file. The JSON is written with sorted keys and
     a 4-space indent.
     """
     payload = {
         'env': serialize(self.env),
         'agent': serialize(self.agent),
     }
     run_tag = '{}.{}'.format(self.monitor.monitor_id, os.getpid())
     target_path = Path(self.monitor.directory) / self.METADATA_FILE.format(run_tag)
     with open(target_path, 'w') as handle:
         json.dump(payload, handle, sort_keys=True, indent=4)
Beispiel #2
0
    def run_batched_episodes(self):
        """
        Train the agent by alternating parallel sample collection and model updates.

        The total sample budget (``self.num_episodes`` times a fixed nominal
        episode duration) is split into batches with ``near_split`` using
        ``size_bins=self.agent.config["batch_size"]``. For each batch:

        - the current agent model is saved, and its path is passed to every worker;
        - the batch is split across ``self.agent.config["processes"]`` workers
          (``os.cpu_count()`` when that setting is falsy, 1 if the CPU count is
          unknown). Each worker gets its own sample count, global start index
          and seed, and runs ``Evaluation.collect_samples``;
        - every collected transition is recorded into the agent, and trajectories
          ending in a terminal transition are reported via ``self.after_all_episodes``;
        - the agent model is updated once.
        """
        episode = 0
        episode_duration = 14  # TODO: use a fixed number of samples instead
        batch_sizes = near_split(self.num_episodes * episode_duration,
                                 size_bins=self.agent.config["batch_size"])
        batches_count = len(batch_sizes)
        # Total number of samples scheduled in the batches already processed.
        # Kept as an int so that worker start indices have a consistent type.
        samples_before_batch = 0
        self.agent.reset()
        for batch, batch_size in enumerate(batch_sizes):
            separator = "[BATCH={}/{}]---------------------------------------".format(
                batch + 1, batches_count)
            logger.info(separator)
            logger.info(
                "[BATCH={}/{}][run_batched_episodes] #samples={}".format(
                    batch + 1, batches_count, len(self.agent.memory)))
            logger.info(separator)
            # Save current agent; the returned path is handed to every worker
            model_path = self.save_agent_model(identifier=batch)

            # Prepare workers
            env_config = serialize(self.env)
            agent_config = serialize(self.agent)
            # os.cpu_count() returns None when it cannot be determined.
            cpu_processes = (self.agent.config["processes"]
                             or os.cpu_count() or 1)
            workers_sample_counts = near_split(batch_size, cpu_processes)
            # Each worker starts at the global sample index where its share
            # begins: the previous batches' total plus a prefix sum of the
            # preceding workers' counts.
            workers_starts = []
            next_start = samples_before_batch
            for worker_count in workers_sample_counts:
                workers_starts.append(next_start)
                next_start += worker_count
            base_seed = self.seed(batch * cpu_processes)[0]
            workers_seeds = [base_seed + i for i in range(cpu_processes)]
            workers_params = list(
                zip_with_singletons(env_config, agent_config,
                                    workers_sample_counts, workers_starts,
                                    workers_seeds, model_path, batch))

            # Collect trajectories
            logger.info("Collecting {} samples with {} workers...".format(
                batch_size, cpu_processes))
            if cpu_processes == 1:
                # Single worker: run in-process instead of spawning a pool
                results = [Evaluation.collect_samples(*workers_params[0])]
            else:
                with Pool(processes=cpu_processes) as pool:
                    results = pool.starmap(Evaluation.collect_samples,
                                           workers_params)
            trajectories = [
                trajectory for worker in results for trajectory in worker
            ]

            # Fill memory
            for trajectory in trajectories:
                if not trajectory:
                    # Defensive: an empty trajectory has no transitions to
                    # record and no final transition to inspect.
                    continue
                # Check whether the episode was properly finished before logging
                if trajectory[-1].terminal:
                    self.after_all_episodes(
                        episode,
                        [transition.reward for transition in trajectory])
                episode += 1
                for transition in trajectory:
                    self.agent.record(*transition)

            # Fit model
            self.agent.update()
            samples_before_batch += batch_size