Example #1
class CWTO_EnvWrapper(Wrapper):
    """Two-player wrapper: an observer chooses which entries of the raw
    observation to reveal, and a player then acts in the wrapped gym
    environment using only the masked observation."""
    def __init__(self,
                 work_env,
                 env_name,
                 obs_spaces,
                 action_spaces,
                 serial,
                 force_float32=True,
                 act_null_value=[0, 0],
                 obs_null_value=[0, 0],
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 fully_obs=False,
                 rand_obs=False,
                 inc_player_last_act=False,
                 max_episode_length=np.inf,
                 cont_act=False):
        env = work_env(env_name)
        super().__init__(env)
        o = self.env.reset()
        self.inc_player_last_act = inc_player_last_act
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.cont_act = cont_act
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.fully_obs = fully_obs
        self.rand_obs = rand_obs
        self.player_turn = False
        self.serial = serial
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        # self.obs_action_translator = obs_action_translator
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        # Total number of scalar entries in the observation.
        self.obs_size = int(np.prod(self.env.observation_space.shape))
        if serial:
            self.ser_cum_act = np.zeros(self.env.observation_space.shape)
            self.ser_counter = 0
        else:
            self.power_vec = 2**np.arange(self.obs_size)[::-1]
            self.obs_action_translator = obs_action_translator
        if len(obs_spaces) > 1:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[1]
        else:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[0]
        if len(action_spaces) > 1:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[1]
        else:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[0]

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=act_null_value[0],
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=act_null_value[1],
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=obs_null_value[0],
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=obs_null_value[1],
            force_float32=force_float32)

    def step(self, action):
        if self.player_turn:
            self.player_turn = False
            a = self.player_action_space.revert(action)
            if a.size <= 1:
                a = a.item()
            o, r, d, info = self.env.step(a)
            self.last_obs = o
            self.last_action = a
            if self.serial:
                obs = np.concatenate(
                    [np.zeros(self.last_obs_act.shape), self.last_masked_obs])
            else:
                obs = np.concatenate([self.last_obs_act, self.last_masked_obs])
            if self.inc_player_last_act:
                obs = np.append(obs, a)
            obs = self.observer_observation_space.convert(obs)
            if self.time_limit:
                if "TimeLimit.truncated" in info:
                    info["timeout"] = info.pop("TimeLimit.truncated")
                else:
                    info["timeout"] = False

            # Only the timeout flag is carried over to the observer's next
            # turn; default to False when the env has no TimeLimit wrapper.
            self.last_info = info.get("timeout", False)
            info = False
            if isinstance(r, float):
                r = np.dtype("float32").type(r)  # Scalar float32.
            self.last_reward = r
            # if (not d) and (self.observer_reward_shaping is not None):
            #     r = self.observer_reward_shaping(r,self.last_obs_act)
            self.curr_episode_length += 1
            if self.curr_episode_length >= self.max_episode_length:
                d = True
            self.last_done = d
            return EnvStep(obs, r, d, info)

        else:
            # A non-binary observer action is interpreted as per-entry
            # Bernoulli probabilities and sampled into a binary reveal mask.
            if not np.array_equal(action, action.astype(bool)):
                action = np.random.binomial(1, action)
            r_action = self.observer_action_space.revert(action)
            if self.serial:
                # Serial mode: the observer reveals one observation entry per
                # step; after obs_size observer steps, the player gets a turn.
                if self.fully_obs:
                    r_action = 1
                elif self.rand_obs:
                    r_action = random.randint(0, 1)
                self.ser_cum_act[self.ser_counter] = r_action
                self.ser_counter += 1
                if self.ser_counter == self.obs_size:
                    self.player_turn = True
                    self.ser_counter = 0
                    masked_obs = np.multiply(
                        np.reshape(self.ser_cum_act, self.last_obs.shape),
                        self.last_obs)
                    self.last_masked_obs = masked_obs
                    self.last_obs_act = self.ser_cum_act.copy()
                    self.ser_cum_act = np.zeros(
                        self.env.observation_space.shape)
                    r = self.last_reward
                    # if self.player_reward_shaping is not None:
                    #     r = self.player_reward_shaping(r, self.last_obs_act)
                    d = self.last_done
                    info = self.last_info
                    obs = np.concatenate([
                        np.reshape(self.last_obs_act, masked_obs.shape),
                        masked_obs
                    ])
                    obs = self.player_observation_space.convert(obs)
                else:
                    r = 0
                    info = False
                    obs = np.concatenate([
                        np.reshape(self.ser_cum_act,
                                   self.last_masked_obs.shape),
                        self.last_masked_obs
                    ])
                    if self.inc_player_last_act:
                        obs = np.append(obs, self.last_action)

                    obs = self.observer_observation_space.convert(obs)
                    d = False

            else:
                if not self.cont_act:
                    # A discrete observer action is decoded into a binary
                    # reveal mask using the powers-of-two vector.
                    r_action = self.obs_action_translator(
                        r_action, self.power_vec, self.obs_size)
                if self.fully_obs:
                    r_action = np.ones(r_action.shape)
                elif self.rand_obs:
                    r_action = np.random.randint(0, 2, r_action.shape)
                self.player_turn = True
                self.last_obs_act = r_action
                masked_obs = np.multiply(
                    np.reshape(r_action, self.last_obs.shape), self.last_obs)
                self.last_masked_obs = masked_obs
                info = self.last_info
                r = self.last_reward
                # if self.player_reward_shaping is not None:
                #     r = self.player_reward_shaping(r, r_action)
                d = self.last_done
                obs = np.concatenate(
                    [np.reshape(r_action, masked_obs.shape), masked_obs])
                obs = self.player_observation_space.convert(obs)

            return EnvStep(obs, r, d, info)

    def reset(self):
        self.curr_episode_length = 0
        self.last_done = False
        self.last_reward = 0
        self.last_action = self.player_action_space.revert(
            self.player_action_space.null_value())
        if self.serial:
            self.ser_cum_act = np.zeros(self.env.observation_space.shape)
            self.ser_counter = 0
        self.player_turn = False
        o = self.env.reset()
        self.last_obs = o
        obs = np.concatenate([np.zeros(o.shape), np.zeros(o.shape)])
        if self.inc_player_last_act:
            obs = np.append(obs, self.last_action)
        self.last_obs_act = np.zeros(o.shape)
        self.last_masked_obs = np.zeros(o.shape)
        obs = self.observer_observation_space.convert(obs)
        return obs

    def spaces(self):
        comb_spaces = [
            EnvSpaces(observation=self.player_observation_space,
                      action=self.player_action_space),
            EnvSpaces(observation=self.observer_observation_space,
                      action=self.observer_action_space)
        ]
        return comb_spaces

    def action_space(self):
        if self.player_turn:
            return self.player_action_space
        else:
            return self.observer_action_space

    def observation_space(self):
        if self.player_turn:
            return self.player_observation_space
        else:
            return self.observer_observation_space

    def set_fully_observable(self):
        self.fully_obs = True

    def set_random_observation(self):
        self.rand_obs = True
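
The class above is excerpted from a larger module: helpers such as reward_shaping_ph and obs_action_translator, and the GymSpaceWrapper, EnvStep, and EnvSpaces types (apparently rlpyt's), are defined or imported elsewhere in that module, so the snippet is not runnable on its own. The core mechanic is easy to demonstrate in isolation, though: the observer's binary action masks the raw observation, and the player then receives the reveal mask and the masked observation concatenated into one vector. The short sketch below reproduces only that masking step with a made-up 4-element observation; the values are illustrative assumptions, not taken from the example.

import numpy as np

obs = np.array([0.03, -0.41, 0.02, 0.55])   # hypothetical raw observation
mask = np.array([1, 0, 1, 0])               # observer's per-entry reveal bits

# Hidden entries are zeroed out, then the player sees [mask, masked_obs].
masked_obs = np.multiply(np.reshape(mask, obs.shape), obs)
player_obs = np.concatenate([np.reshape(mask, masked_obs.shape), masked_obs])

print(player_obs)   # the reveal mask followed by the masked observation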
Example #2
class CWTO_EnvWrapperAtari(Wrapper):
    """Atari variant of the CWTO wrapper: the observer picks a position on the
    frame and only a window of the frame (determined by window_size and that
    position) is revealed to the player; hidden pixels are either set to a -1
    sentinel or, with add_channel=True, the reveal mask is stacked onto the
    observation as an extra channel."""
    def __init__(self,
                 env_name,
                 window_size,
                 force_float32=True,
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 max_episode_length=np.inf,
                 add_channel=False):
        self.serial = False
        env = AtariEnv(game=env_name)
        env.metadata = None
        env.reward_range = None
        super().__init__(env)
        o = self.env.reset()
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        self.add_channel = add_channel
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=self.env.action_space.null_value(),
            force_float32=force_float32,
        )
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=self.env.observation_space.null_value(),
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.player_turn = False
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        self.obs_size = self.env.observation_space.shape
        self.window_size = window_size
        self.obs_action_translator = obs_action_translator

        player_obs_space = self.env.observation_space
        if add_channel:
            player_obs_space = IntBox(low=player_obs_space.low,
                                      high=player_obs_space.high,
                                      shape=player_obs_space.shape,
                                      dtype=player_obs_space.dtype,
                                      null_value=player_obs_space.null_value())
        player_act_space = self.env.action_space
        observer_obs_space = self.env.observation_space
        observer_act_space = Box(low=np.asarray([0.0, 0.0]),
                                 high=np.asarray([
                                     self.env.observation_space.shape[0],
                                     self.env.observation_space.shape[1]
                                 ]))

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=player_act_space.null_value(),
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=np.zeros(2),
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=player_obs_space.null_value(),
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=observer_obs_space.null_value(),
            force_float32=force_float32)

    def step(self, action):
        if self.player_turn:
            self.player_turn = False
            a = self.player_action_space.revert(action)
            if a.size <= 1:
                a = a.item()
            o, r, d, info = self.env.step(a)
            self.last_obs = o
            self.last_action = a
            obs = self.observer_observation_space.convert(o)
            if self.time_limit:
                if "TimeLimit.truncated" in info:
                    info["timeout"] = info.pop("TimeLimit.truncated")
                else:
                    info["timeout"] = False

            self.last_info = info
            if isinstance(r, float):
                r = np.dtype("float32").type(r)  # Scalar float32.
            self.last_reward = r
            self.curr_episode_length += 1
            if self.curr_episode_length >= self.max_episode_length:
                d = True
            self.last_done = d
            return EnvStep(obs, r, d, info)

        else:
            r_action = self.observer_action_space.revert(action)
            # Translate the observer's chosen frame position into a binary
            # window mask over the observation.
            r_action = self.obs_action_translator(r_action, self.window_size,
                                                  self.obs_size)
            self.player_turn = True
            self.last_obs_act = r_action
            masked_obs = np.multiply(r_action, self.last_obs)
            info = self.last_info
            r = self.last_reward
            d = self.last_done
            if self.add_channel:
                # Stack the reveal mask onto the masked frame as an extra channel.
                masked_obs = np.concatenate([r_action, masked_obs], axis=0)
            else:
                # Mark hidden pixels with a -1 sentinel instead.
                masked_obs[r_action == 0] = -1
            obs = self.player_observation_space.convert(masked_obs)

            return EnvStep(obs, r, d, info)

    def reset(self):
        self.curr_episode_length = 0
        self.last_done = False
        self.last_reward = 0
        self.last_action = self.player_action_space.revert(
            self.player_action_space.null_value())
        self.player_turn = False
        o = self.env.reset()
        self.last_obs = o
        self.last_obs_act = np.zeros(o.shape)
        obs = self.observer_observation_space.convert(o)
        return obs

    def spaces(self):
        comb_spaces = [
            EnvSpaces(observation=self.player_observation_space,
                      action=self.player_action_space),
            EnvSpaces(observation=self.observer_observation_space,
                      action=self.observer_action_space)
        ]
        return comb_spaces

    def action_space(self):
        if self.player_turn:
            return self.player_action_space
        else:
            return self.observer_action_space

    def observation_space(self):
        if self.player_turn:
            return self.player_observation_space
        else:
            return self.observer_observation_space

    def seed(self, seed1=0, seed2=0):
        # Seeding is intentionally a no-op for this wrapper.
        return
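
As in the first example, obs_action_translator is a helper from the surrounding module; for the Atari wrapper it turns the observer's chosen position into a binary window mask that is multiplied with the frame. The sketch below re-implements that idea as a hypothetical stand-in (the function name window_mask and the anchoring of the window at the chosen position are assumptions for illustration, not the module's actual translator), so the masking branch of step() can be tried in isolation.

import numpy as np

def window_mask(pos, window_size, obs_shape):
    # Hypothetical stand-in for obs_action_translator: ones inside a
    # window_size x window_size patch anchored at pos, zeros elsewhere.
    mask = np.zeros(obs_shape)
    r = int(np.clip(pos[0], 0, obs_shape[0] - window_size))
    c = int(np.clip(pos[1], 0, obs_shape[1] - window_size))
    mask[r:r + window_size, c:c + window_size] = 1
    return mask

frame = np.random.randint(0, 256, size=(84, 84))   # fake grayscale frame
mask = window_mask(pos=(30, 40), window_size=16, obs_shape=frame.shape)

masked = np.multiply(mask, frame)   # reveal only the window
masked[mask == 0] = -1              # same -1 sentinel the wrapper uses when add_channel=False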