Example #1
    def step(self, policy_output):
        assert policy_output.n_elems() == self.envs_running.size()[0]
        outputs = policy_output.unfold()
        alls = []
        env_run = {}
        for b in range(len(outputs)):
            idx_env = self.envs_running[b]
            action = policy_output["action"][b]
            last_action = action
            if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                # Discrete action spaces expect a plain Python int.
                action = action.item()
            else:
                # Continuous (e.g. Box) action spaces expect a list of floats.
                action = action.tolist()
            # Keep the action as a 1-element tensor for the "last_action" field
            # of the next observation.
            last_action = last_action.unsqueeze(0)

            # Mid-episode step: this is not the initial observation of an episode.
            initial_state = torch.tensor([False])

            frame, reward, done, unused_info = self.gym_envs[idx_env].step(action)
            reward = torch.tensor([reward])
            frame = format_frame(frame)
            if isinstance(frame, torch.Tensor):
                frame = {"frame": frame}
            if not done:
                # Remember the slots that are still alive; terminated environments
                # are dropped from envs_running below.
                env_run[b] = idx_env

            done = torch.tensor([done])
            r = DictTensor({
                "reward": reward,
                "done": done,
                "initial_state": initial_state,
                "last_action": last_action,
                **frame,
            })
            alls.append(r)

        # Batch the per-environment results into a single DictTensor.
        d = DictTensor.cat(alls)

        # Keep only the rows of environments that are still running, and update
        # envs_running accordingly for the next call to step().
        keys = []
        values = []
        for key, value in env_run.items():
            keys.append(key)
            values.append(value)
        dd = d.index(torch.tensor(keys).long())
        old_envs_running = self.envs_running
        self.envs_running = torch.tensor(values)
        return (d, old_envs_running), (dd, self.envs_running)
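
A minimal usage sketch for this step method, assuming it belongs to a gym vector-environment wrapper in the style of rlstructures, that `env` is an already-constructed instance exposing the reset/step shown in these examples, and that the action space is discrete (the import path and the all-zero action are illustrative assumptions):

    import torch
    from rlstructures import DictTensor  # assumed import path

    obs, envs_running = env.reset()      # batched initial frames + indices of running envs

    # One action per running environment (discrete action 0 as a placeholder policy).
    policy_output = DictTensor(
        {"action": torch.zeros(envs_running.size(0), dtype=torch.int64)}
    )

    (obs_all, before), (obs_alive, after) = env.step(policy_output)
    # obs_all   : one row per environment that was running when step() was called
    # obs_alive : only the rows of environments that did not terminate (indices in `after`)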
Example #2
    def reset(self, env_info=DictTensor({})):
        N = self.n_envs()
        self.envs_running = torch.arange(N)
        reward = torch.zeros(N)

        # Initialise last_action with a default action for every environment.
        if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
            last_action = torch.zeros(N, dtype=torch.int64)
        else:
            a = self.gym_envs[0].action_space.sample()
            last_action = torch.tensor(a).unsqueeze(0).repeat(N, 1)

        done = torch.zeros(N).bool()
        initial_state = torch.ones(N).bool()
        # Reset every underlying gym environment. When per-environment reset
        # information is provided, row n of env_info is forwarded to gym_envs[n].
        if env_info.empty():
            frames = [format_frame(e.reset()) for e in self.gym_envs]
        else:
            frames = []
            for n in range(len(self.gym_envs)):
                v = {k: env_info[k][n].tolist() for k in env_info.keys()}
                frames.append(format_frame(self.gym_envs[n].reset(env_info=v)))
        # Wrap raw tensor observations as {"frame": tensor} so that every entry
        # becomes a dict that can be turned into a DictTensor.
        _frames = []
        for f in frames:
            if isinstance(f, torch.Tensor):
                _frames.append({"frame": f})
            else:
                _frames.append(f)
        frames = [DictTensor(_f) for _f in _frames]
        frames = DictTensor.cat(frames)
        # Attach the bookkeeping fields expected by the policy/trajectory code.
        frames.set("reward", reward)
        frames.set("done", done)
        frames.set("initial_state", initial_state)
        frames.set("last_action", last_action)
        return frames, self.envs_running
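
A short sketch of the two ways this reset can be called, under the same assumptions as above; the wrapped gym environments are assumed to accept an env_info keyword in their own reset, and the "seed" field is a made-up example of per-environment reset information:

    import torch
    from rlstructures import DictTensor  # assumed import path

    # Plain reset: every underlying environment is reset without extra information.
    obs, envs_running = env.reset()

    # Reset with per-environment information: row n of env_info is forwarded as
    # gym_envs[n].reset(env_info=...). The "seed" field is hypothetical.
    env_info = DictTensor({"seed": torch.arange(env.n_envs())})
    obs, envs_running = env.reset(env_info=env_info)

    # The returned DictTensor carries reward/done/initial_state/last_action plus
    # the formatted frame fields, one row per environment.
    print(obs.keys())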
Example #3
    def step(self, policy_output):
        assert policy_output.n_elems() == self.n_envs()
        outputs = policy_output.unfold()
        alls = []
        alls_after = []
        for b in range(len(outputs)):
            action = policy_output["action"][b]
            last_action = action
            if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                # Discrete action spaces expect a plain Python int.
                action = action.item()
            else:
                # Continuous (e.g. Box) action spaces expect a list of floats.
                action = action.tolist()
            last_action = last_action.unsqueeze(0)

            initial_state = torch.tensor([False])

            frame, reward, done, unused_info = self.gym_envs[b].step(action)
            reward = torch.tensor([reward])
            frame = format_frame(frame)
            if isinstance(frame, torch.Tensor):
                frame = {"frame": frame}

            done = torch.tensor([done])
            r = DictTensor({
                "reward": reward,
                "done": done,
                "initial_state": initial_state,
                "last_action": last_action,
                **frame,
            })
            alls.append(r)

            if done:
                # The episode just ended: reset this environment in place so the
                # batch always contains n_envs() running environments.
                if "set" in dir(self.gym_envs[b]):
                    self.gym_envs[b].set(self.env_info[b])

                if self.env_info.empty():
                    frame = self.gym_envs[b].reset()
                else:
                    v = {k: self.env_info[k][b].tolist()
                         for k in self.env_info.keys()}
                    frame = self.gym_envs[b].reset(env_info=v)

                frame = format_frame(frame)
                if isinstance(frame, torch.Tensor):
                    frame = {"frame": frame}

                # Reset last_action to a default action for the new episode.
                if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                    last_action = torch.zeros(1, dtype=torch.int64)
                else:
                    a = self.gym_envs[0].action_space.sample()
                    last_action = torch.tensor([a])

                initial_state = torch.tensor([True])
                reward = torch.tensor([0.0])
                r = DictTensor({
                    "reward": reward,
                    "done": done,
                    "initial_state": initial_state,
                    "last_action": last_action,
                    **frame,
                })
                alls_after.append(r)
            else:
                # Environment still running: the next-slot observation is simply
                # the observation produced by this step.
                alls_after.append(r)

        # next_observation: the transition produced by this step for each env.
        # next_observation_next_slot: the observation to feed the policy next
        # (a fresh reset frame for environments that just terminated).
        next_observation = DictTensor.cat(alls)
        next_observation_next_slot = DictTensor.cat(alls_after)
        return (
            (next_observation, torch.arange(self.n_envs())),
            (next_observation_next_slot, torch.arange(self.n_envs())),
        )
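
Unlike Example #1, terminated environments are reset in place here, so the batch size never shrinks. A rollout loop over this "infinite" variant could look like the sketch below, under the same assumptions as before (the fixed action 0 stands in for a real policy):

    import torch
    from rlstructures import DictTensor  # assumed import path

    obs, env_idx = env.reset()
    for _ in range(100):  # arbitrary rollout length
        # Hypothetical policy: always pick discrete action 0 for every environment.
        policy_output = DictTensor(
            {"action": torch.zeros(env.n_envs(), dtype=torch.int64)}
        )
        (step_obs, _), (obs, env_idx) = env.step(policy_output)
        # step_obs holds the transition that just happened (done may be True);
        # obs is what the policy should see next: for environments that terminated,
        # it is already the first frame of a fresh episode.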