Example #1
0
    def get_agent(self, test=False):
        """Build an Agent from the specification reported by a RunnerBox.

        Parameters
        ----------
        test : bool
            If True, temporarily set ``kwargs['test'] = True`` while the
            runner box is constructed, so the agent is configured for
            evaluation rather than training.

        Returns
        -------
        Agent
            A new agent built from ``self.model`` and the kwargs reported
            by the runner box.
        """
        if test:
            self.kwargs['test'] = True

        # The constructor arguments are identical for the remote and local
        # paths — build them once instead of duplicating both call sites.
        box_args = (Agent, self.model, self.env_instance)
        box_kwargs = dict(runner_position=0, returns=self.returns, **self.kwargs)

        if self.use_ray:
            # get agent specifications from a remote runner box
            runner_box = RunnerBox.remote(*box_args, **box_kwargs)
            agent_kwargs = ray.get(runner_box.get_agent_kwargs.remote())
        else:
            # get agent specifications from a local runner box
            runner_box = RunnerBox(*box_args, **box_kwargs)
            agent_kwargs = runner_box.get_agent_kwargs()

        agent = Agent(self.model, **agent_kwargs)

        # NOTE(review): this resets 'test' to False even if it was already
        # True on entry; if callers may pre-set kwargs['test'], save and
        # restore the previous value instead (cf. the epsilon/temperature
        # save/restore pattern used elsewhere in this project).
        if test:
            self.kwargs['test'] = False

        return agent
Example #2
0
    def get_data_no_ray(self, total_steps):
        """Collect rollout data synchronously with a single local RunnerBox.

        Parameters
        ----------
        total_steps : int or None
            Temporary override for ``self.total_steps``; the original
            value is restored before returning.

        Returns
        -------
        The accumulated ``self.data``.
        """
        self.reset_data()

        # Temporarily swap in the requested step budget.
        if total_steps is not None:
            old_steps = self.total_steps
            self.total_steps = total_steps

        # Single local (non-ray) runner box at position 0.
        runner_box = RunnerBox(
            Agent,
            self.model,
            self.env_instance,
            runner_position=0,
            returns=self.returns,
            **self.kwargs,
        )

        # Select the sampling mode once up front. Binding the method
        # directly replaces the original lambda assignments (PEP 8 E731).
        if self.run_episodes:
            run = runner_box.run_n_episodes
        else:
            run = runner_box.run_n_steps

        # Keep sampling until _store() reports the budget is exhausted.
        not_done = True
        while not_done:
            result, _ = run(self.runner_steps)
            not_done = self._store([result])

        # Restore the original step budget.
        if total_steps is not None:
            self.total_steps = old_steps

        return self.data
Example #3
0
    def get_agent(self, test=False):
        """Build an Agent from kwargs reported by a remote RunnerBox.

        When ``test`` is True, exploration is disabled for the duration of
        this call: epsilon is forced to 0 and temperature to a near-zero
        value, and both are restored to their previous values before
        returning.
        """
        saved = None
        if test:
            # Remember the exploration settings so they can be restored.
            saved = (self.kwargs["epsilon"], self.kwargs["temperature"])
            self.kwargs["epsilon"] = 0
            self.kwargs["temperature"] = 0.0001

        # Ask a throwaway remote runner box for the agent specification.
        box = RunnerBox.remote(
            Agent,
            self.model,
            self.env_instance,
            runner_position=0,
            returns=self.returns,
            **self.kwargs,
        )
        spec = ray.get(box.get_agent_kwargs.remote())
        agent = Agent(self.model, **spec)

        if saved is not None:
            # Put the original exploration settings back.
            self.kwargs["epsilon"], self.kwargs["temperature"] = saved

        return agent
Example #4
0
    def get_data(self, do_print=False, total_steps=None):
        """Collect rollout data from parallel remote RunnerBoxes.

        Repeatedly waits for remote sampling tasks and stores their
        results until ``self._store`` reports the step budget is reached.

        BUG FIX (review): the original body relaunched a task on EVERY box
        each iteration, discarded the still-pending object refs returned
        by ``ray.wait`` (leaking in-flight work), and rebuilt the box list
        as ``dones + runner_boxes`` — duplicating boxes each iteration and
        invalidating the runner_position -> list-index mapping used to find
        finished boxes. This version keeps one list of pending task refs
        and relaunches only the boxes whose task completed.

        Parameters
        ----------
        do_print : bool
            If True, print progress information each iteration.
        total_steps : int or None
            Temporary override for ``self.total_steps``; restored before
            returning.

        Returns
        -------
        The accumulated ``self.data``.
        """
        if total_steps is not None:
            old_steps = self.total_steps
            self.total_steps = total_steps

        # One remote runner box per parallel runner. Each task returns its
        # runner_position, which indexes back into this (never-mutated) list.
        runner_boxes = [
            RunnerBox.remote(
                Agent,
                self.model,
                self.env_instance,
                runner_position=i,
                returns=self.returns,
                **self.kwargs,
            )
            for i in range(self.num_parallel)
        ]

        def _launch(box):
            # Start one sampling task on *box* in the configured mode.
            if self.run_episodes:
                return box.run_n_episodes.remote(self.runner_steps)
            return box.run_n_steps.remote(self.runner_steps)

        # Initial batch: one task per box.
        pending = [_launch(box) for box in runner_boxes]

        t = 0
        not_done = True
        # run as long as the number of total steps is not yet reached
        while not_done:
            ready, pending = ray.wait(
                pending,
                num_returns=self.remote_min_returns,
                timeout=self.remote_time_out,
            )

            # Each finished task returns a tuple (data_agg, runner_position).
            results = []
            indexes = []
            for result, index in ray.get(ready):
                results.append(result)
                indexes.append(index)

            if do_print:
                print(f"iteration: {t}, storing results of {len(results)} runners")
            not_done = self._store(results)

            # Relaunch only the boxes whose task just finished.
            pending.extend(_launch(runner_boxes[i]) for i in indexes)
            t += 1

        if total_steps is not None:
            self.total_steps = old_steps

        return self.data
Example #5
0
    def get_agent(self, test=False):
        """Build an Agent from kwargs reported by a remote RunnerBox.

        Parameters
        ----------
        test : bool
            If True, temporarily set ``kwargs['test'] = True`` while the
            runner box is constructed.

        Returns
        -------
        Agent
            A new agent built from ``self.model`` and the kwargs reported
            by the runner box.
        """
        if test:
            self.kwargs["test"] = True

        # Read the flag once: the original indexed self.kwargs["special_env"]
        # twice and raised KeyError when the key was absent; .get() tolerates
        # a missing key (treated as False) and guarantees both the swap and
        # the restore below see the same value.
        special_env = self.kwargs.get("special_env", False)
        if special_env:
            # Hand the raw environment to the runner box; the wrapped
            # instance is restored before returning.
            actual_env_instance = self.env_instance
            self.env_instance = self.environment

        # get agent specifications from the runner box
        runner_box = RunnerBox.remote(
            Agent,
            self.model,
            self.env_instance,
            runner_position=0,
            returns=self.returns,
            **self.kwargs,
        )
        agent_kwargs = ray.get(runner_box.get_agent_kwargs.remote())
        agent = Agent(self.model, **agent_kwargs)

        # NOTE(review): resets 'test' to False even if it was True on entry;
        # save/restore the previous value if callers may pre-set it.
        if test:
            self.kwargs["test"] = False

        if special_env:
            self.env_instance = actual_env_instance

        return agent
Example #6
0
    def get_data(self, do_print=False, total_steps=None):
        """Collect rollout data from parallel remote RunnerBoxes.

        Clears previously collected data, launches one sampling task per
        runner box, then repeatedly waits for finished tasks, stores their
        results, and relaunches only the boxes whose task completed —
        until ``self._store`` reports the step budget is exhausted.

        Parameters
        ----------
        do_print : bool
            If True, print per-iteration progress information.
        total_steps : int or None
            Temporary override for ``self.total_steps``; the original
            value is restored before returning.

        Returns
        -------
        The accumulated ``self.data``.
        """
        self.reset_data()

        # Temporarily swap in the requested step budget.
        if total_steps is not None:
            old_steps = self.total_steps
            self.total_steps = total_steps

        not_done = True

        # NOTE(review): assumes 'special_env' is always present in kwargs —
        # a missing key raises KeyError here; confirm callers guarantee it.
        # When set, the raw environment is handed to the runner boxes and
        # the wrapped instance is restored before returning.
        if self.kwargs["special_env"]:
            actual_env_instance = self.env_instance
            self.env_instance = self.environment

        # create list of runner boxes; runner_position i is returned by each
        # task and used below to index back into this (never-mutated) list
        runner_boxes = [
            RunnerBox.remote(
                Agent,
                self.model,
                self.env_instance,
                runner_position=i,
                returns=self.returns,
                **self.kwargs,
            )
            for i in range(self.num_parallel)
        ]
        t = 0

        # initial processes: launch one sampling task per box, in the
        # configured mode (episodes vs. fixed step counts)
        if self.run_episodes:
            runner_box_list = [b.run_n_episodes.remote(self.runner_steps) for b in runner_boxes]
        else:
            runner_box_list = [b.run_n_steps.remote(self.runner_steps) for b in runner_boxes]

        # run as long as not yet reached number of total steps
        while not_done:
            if do_print:
                print('Run: ',t)
            # wait for at least remote_min_returns tasks (or time out);
            # runner_box_list is rebound to the still-pending object refs
            ready, runner_box_list = ray.wait(
                runner_box_list,
                num_returns=self.remote_min_returns,
                timeout=self.remote_time_out,
                )

            # boxes returns list of tuples (data_agg, index)
            returns = ray.get(ready)
            results = []
            indexes = []
            for r in returns:
                result, index = r
                results.append(result)
                indexes.append(index)
                if do_print:
                    print(f'Avg steps of runner {index}: {len(result["state"])/self.runner_steps}')

            # store data from dones; _store returns False once the budget
            # is exhausted, ending the loop
            not_done = self._store(results)
            # get boxes that are already done (index == runner_position)
            accessed_mapping = map(runner_boxes.__getitem__, indexes)
            dones = list(accessed_mapping)
            # concatenate newly created processes to remaining runner_boxes
            if self.run_episodes:
                runner_box_list.extend([b.run_n_episodes.remote(self.runner_steps) for b in dones])
            else:
                runner_box_list.extend([b.run_n_steps.remote(self.runner_steps) for b in dones])
            t += 1

        # NOTE(review): deleting the ref list does not cancel tasks that are
        # still in flight; they finish in the background and are discarded.
        del runner_box_list

        # Restore the original step budget.
        if total_steps is not None:
            self.total_steps = old_steps

        if self.kwargs["special_env"]:
            self.env_instance = actual_env_instance

        return self.data