Example #1
File: gym.py Project: zizai/rlpyt
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o = self.env.reset()
        o, r, d, info = self.env.step(self.env.action_space.sample())
        # Walk down the wrapper stack looking for gym's TimeLimit wrapper.
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)  # check the unwrapped env, not self.env
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self._time_limit = time_limit
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)
Example #2
class ProcgenWrapper(GymEnvWrapper):
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self._time_limit = time_limit
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        w = self.observation_space.space.shape[1]
        h = self.observation_space.space.shape[0]
        c = self.observation_space.space.shape[2]
        self.observation_space.space.shape = (c, h, w)
        build_info_tuples(info)

    def step(self, action):
        """Reverts the action from rlpyt format to gym format (i.e. if composite-to-
        dictionary spaces), steps the gym environment, converts the observation
        from gym to rlpyt format (i.e. if dict-to-composite), and converts the
        env_info from dictionary into namedtuple."""
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o.transpose((2, 0, 1)))
        if self._time_limit:
            if "TimeLimit.truncated" in info:
                info["timeout"] = info.pop("TimeLimit.truncated")
            else:
                info["timeout"] = False
        info = info_to_nt(info)
        if isinstance(r, float):
            r = np.dtype("float32").type(r)  # Scalar float32.
        return EnvStep(obs, r, d, info)

    def reset(self):
        """Returns converted observation from gym env reset."""
        return self.observation_space.convert(self.env.reset().transpose(
            (2, 0, 1)))

    def seed(self, seed):
        return
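
The transpose((2, 0, 1)) calls in step() and reset() above move the channel axis first, converting gym's HWC image layout to the CHW layout that PyTorch convolutions expect. A minimal numpy check (the 64x64x3 shape is illustrative of Procgen's observations):

import numpy as np

o = np.zeros((64, 64, 3), dtype=np.uint8)  # gym-style image: (H, W, C)
chw = o.transpose((2, 0, 1))               # reorder axes to (C, H, W)
assert chw.shape == (3, 64, 64)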
Example #3
class GymEnvWrapper(Wrapper):
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)

        o = self.env.reset()
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)

        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self._time_limit = time_limit
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)

    def step(self, action):
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o)
        if self._time_limit:
            if "TimeLimit.truncated" in info:
                info["timeout"] = info.pop("TimeLimit.truncated")
            else:
                info["timeout"] = False
        info = info_to_nt(info)
        return EnvStep(obs, r, d, info)

    def reset(self):
        new_value = self.env.reset()
        sample = self.observation_space.convert(new_value)
        return sample

    @property
    def spaces(self):
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )
Example #4
    def test_seed(self):
        space = GymSpaceWrapper(
            gym.spaces.Box(low=np.zeros(1), high=np.ones(1)))
        space.seed(0)
        sample_1 = space.sample()
        space.seed(0)
        sample_2 = space.sample()
        self.assertEqual(sample_1, sample_2)

        sample_3 = space.sample()
        self.assertNotEqual(sample_1, sample_3)
Example #5
class RLPytWrapper(gym.Wrapper):
    """
    Wrap the gym environment with namedtuple
    """
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o = self.env.reset()
        o, r, d, info = self.env.dummy_action()
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)

    def step(self, action):
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o)
        info = info_to_nt(info)
        return EnvStep(obs, r, d, info)

    def reset(self):
        return self.observation_space.convert(self.env.reset())

    @property
    def spaces(self):
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )
Example #6
File: gym.py Project: Xingyu-Lin/softagent
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o = self.env.reset()
        o, r, d, info = self.env.step(self.env.action_space.sample())
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)
Example #7
File: gym.py Project: Xingyu-Lin/softagent
class GymEnvWrapper(Wrapper):
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o = self.env.reset()
        o, r, d, info = self.env.step(self.env.action_space.sample())
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)

    def step(self, action):
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o)
        info = info_to_nt(info)
        return EnvStep(obs, r, d, info)

    def reset(self):
        return self.observation_space.convert(self.env.reset())

    @property
    def spaces(self):
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )
Example #8
    def __init__(self, wrapped_env, act_null_value=0, force_float32=True):
        self._wrapped_env = wrapped_env
        action_dim = 3
        self.action_space = GymSpaceWrapper(
            space=gym.spaces.Box(low=-1,
                                 high=1,
                                 shape=(action_dim, ),
                                 dtype=np.float32),
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = Composite([
            Box(low=-np.inf, high=np.inf, shape=(64, 64, 3), dtype=np.float32),
            Box(np.array([-1] * 100), np.array([1] * 100), dtype=np.float32)
        ], OBS)
        self.spaces = EnvSpaces(observation=self.observation_space,
                                action=self.action_space)
        self._dtype = None
        self.current_location = None
Example #9
File: gym.py Project: zizai/rlpyt
class GymEnvWrapper(Wrapper):
    """Gym-style wrapper for converting the Openai Gym interface to the
    rlpyt interface.  Action and observation spaces are wrapped by rlpyt's
    ``GymSpaceWrapper``.

    Output `env_info` is automatically converted from a dictionary to a
    corresponding namedtuple, which the rlpyt sampler expects.  For this to
    work, every key that might appear in the gym environments `env_info` at
    any step must appear at the first step after a reset, as the `env_info`
    entries will have sampler memory pre-allocated for them (so they also
    cannot change dtype or shape).  (see `EnvInfoWrapper`, `build_info_tuples`,
    and `info_to_nt` in file or more help/details)

    Warning:
        Unrecognized keys in `env_info` appearing later during use will be
        silently ignored.

    This wrapper looks for gym's ``TimeLimit`` env wrapper to
    see whether to add the field ``timeout`` to env info.   
    """
    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        super().__init__(env)
        o = self.env.reset()
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self._time_limit = time_limit
        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        build_info_tuples(info)

    def step(self, action):
        """Reverts the action from rlpyt format to gym format (i.e. if composite-to-
        dictionary spaces), steps the gym environment, converts the observation
        from gym to rlpyt format (i.e. if dict-to-composite), and converts the
        env_info from dictionary into namedtuple."""
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o)
        if self._time_limit:
            if "TimeLimit.truncated" in info:
                info["timeout"] = info.pop("TimeLimit.truncated")
            else:
                info["timeout"] = False
        info = info_to_nt(info)
        if isinstance(r, float):
            r = np.dtype("float32").type(r)  # Scalar float32.
        return EnvStep(obs, r, d, info)

    def reset(self):
        """Returns converted observation from gym env reset."""
        return self.observation_space.convert(self.env.reset())

    @property
    def spaces(self):
        """Returns the rlpyt spaces for the wrapped env."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )
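
A minimal usage sketch for the wrapper above (assuming rlpyt is installed, that GymEnvWrapper is imported from rlpyt.envs.gym, and using CartPole-v0 purely as an illustrative env; gym's classic 4-tuple step API is assumed):

import gym
from rlpyt.envs.gym import GymEnvWrapper

env = GymEnvWrapper(gym.make("CartPole-v0"))
o = env.reset()                  # observation converted to rlpyt format
a = env.action_space.sample()    # sample from the wrapped "act" space
o, r, d, info = env.step(a)      # EnvStep namedtuple unpacks into 4 fields
print(info.timeout)              # field added because TimeLimit was detected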
Example #10
    def __init__(self,
                 work_env,
                 env_name,
                 obs_spaces,
                 action_spaces,
                 serial,
                 force_float32=True,
                 act_null_value=[0, 0],
                 obs_null_value=[0, 0],
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 fully_obs=False,
                 rand_obs=False,
                 inc_player_last_act=False,
                 max_episode_length=np.inf,
                 cont_act=False):
        env = work_env(env_name)
        super().__init__(env)
        o = self.env.reset()
        self.inc_player_last_act = inc_player_last_act
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.cont_act = cont_act
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.fully_obs = fully_obs
        self.rand_obs = rand_obs
        self.player_turn = False
        self.serial = serial
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        # self.obs_action_translator = obs_action_translator
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        dd = self.env.observation_space.shape
        obs_size = 1
        for d in dd:
            obs_size *= d
        self.obs_size = obs_size
        if serial:
            self.ser_cum_act = np.zeros(self.env.observation_space.shape)
            self.ser_counter = 0
        else:
            self.power_vec = 2**np.arange(self.obs_size)[::-1]
            self.obs_action_translator = obs_action_translator
        if len(obs_spaces) > 1:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[1]
        else:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[0]
        if len(action_spaces) > 1:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[1]
        else:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[0]

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=act_null_value[0],
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=act_null_value[1],
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=obs_null_value[0],
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=obs_null_value[1],
            force_float32=force_float32)
Example #11
class CWTO_EnvWrapper(Wrapper):
    def __init__(self,
                 work_env,
                 env_name,
                 obs_spaces,
                 action_spaces,
                 serial,
                 force_float32=True,
                 act_null_value=[0, 0],
                 obs_null_value=[0, 0],
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 fully_obs=False,
                 rand_obs=False,
                 inc_player_last_act=False,
                 max_episode_length=np.inf,
                 cont_act=False):
        env = work_env(env_name)
        super().__init__(env)
        o = self.env.reset()
        self.inc_player_last_act = inc_player_last_act
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.cont_act = cont_act
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.fully_obs = fully_obs
        self.rand_obs = rand_obs
        self.player_turn = False
        self.serial = serial
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        # self.obs_action_translator = obs_action_translator
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        dd = self.env.observation_space.shape
        obs_size = 1
        for d in dd:
            obs_size *= d
        self.obs_size = obs_size
        if serial:
            self.ser_cum_act = np.zeros(self.env.observation_space.shape)
            self.ser_counter = 0
        else:
            self.power_vec = 2**np.arange(self.obs_size)[::-1]
            self.obs_action_translator = obs_action_translator
        if len(obs_spaces) > 1:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[1]
        else:
            player_obs_space = obs_spaces[0]
            observer_obs_space = obs_spaces[0]
        if len(action_spaces) > 1:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[1]
        else:
            player_act_space = action_spaces[0]
            observer_act_space = action_spaces[0]

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=act_null_value[0],
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=act_null_value[1],
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=obs_null_value[0],
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=obs_null_value[1],
            force_float32=force_float32)

    def step(self, action):
        if self.player_turn:
            self.player_turn = False
            a = self.player_action_space.revert(action)
            if a.size <= 1:
                a = a.item()
            o, r, d, info = self.env.step(a)
            self.last_obs = o
            self.last_action = a
            if self.serial:
                obs = np.concatenate(
                    [np.zeros(self.last_obs_act.shape), self.last_masked_obs])
            else:
                obs = np.concatenate([self.last_obs_act, self.last_masked_obs])
            if self.inc_player_last_act:
                obs = np.append(obs, a)
            obs = self.observer_observation_space.convert(obs)
            if self.time_limit:
                if "TimeLimit.truncated" in info:
                    info["timeout"] = info.pop("TimeLimit.truncated")
                else:
                    info["timeout"] = False

            self.last_info = (info["timeout"])
            info = (False)
            if isinstance(r, float):
                r = np.dtype("float32").type(r)  # Scalar float32.
            self.last_reward = r
            # if (not d) and (self.observer_reward_shaping is not None):
            #     r = self.observer_reward_shaping(r,self.last_obs_act)
            self.curr_episode_length += 1
            if self.curr_episode_length >= self.max_episode_length:
                d = True
            self.last_done = d
            return EnvStep(obs, r, d, info)

        else:
            if not np.array_equal(action, action.astype(bool)):
                action = np.random.binomial(1, action)
            r_action = self.observer_action_space.revert(action)
            if self.serial:
                if self.fully_obs:
                    r_action = 1
                elif self.rand_obs:
                    r_action = random.randint(0, 1)
                self.ser_cum_act[self.ser_counter] = r_action
                self.ser_counter += 1
                if self.ser_counter == self.obs_size:
                    self.player_turn = True
                    self.ser_counter = 0
                    masked_obs = np.multiply(
                        np.reshape(self.ser_cum_act, self.last_obs.shape),
                        self.last_obs)
                    self.last_masked_obs = masked_obs
                    self.last_obs_act = self.ser_cum_act.copy()
                    self.ser_cum_act = np.zeros(
                        self.env.env.observation_space.shape)
                    r = self.last_reward
                    # if self.player_reward_shaping is not None:
                    #     r = self.player_reward_shaping(r, self.last_obs_act)
                    d = self.last_done
                    info = self.last_info
                    obs = np.concatenate([
                        np.reshape(self.last_obs_act, masked_obs.shape),
                        masked_obs
                    ])
                    obs = self.player_observation_space.convert(obs)
                else:
                    r = 0
                    info = False
                    obs = np.concatenate([
                        np.reshape(self.ser_cum_act,
                                   self.last_masked_obs.shape),
                        self.last_masked_obs
                    ])
                    if self.inc_player_last_act:
                        obs = np.append(obs, self.last_action)

                    obs = self.observer_observation_space.convert(obs)
                    d = False

            else:
                if not self.cont_act:
                    r_action = self.obs_action_translator(
                        r_action, self.power_vec, self.obs_size)
                if self.fully_obs:
                    r_action = np.ones(r_action.shape)
                elif self.rand_obs:
                    r_action = np.random.randint(0, 2, r_action.shape)
                self.player_turn = True
                self.last_obs_act = r_action
                masked_obs = np.multiply(
                    np.reshape(r_action, self.last_obs.shape), self.last_obs)
                self.last_masked_obs = masked_obs
                info = self.last_info
                r = self.last_reward
                # if self.player_reward_shaping is not None:
                #     r = self.player_reward_shaping(r, r_action)
                d = self.last_done
                obs = np.concatenate(
                    [np.reshape(r_action, masked_obs.shape), masked_obs])
                obs = self.player_observation_space.convert(obs)

            return EnvStep(obs, r, d, info)

    def reset(self):
        self.curr_episode_length = 0
        self.last_done = False
        self.last_reward = 0
        self.last_action = self.player_action_space.revert(
            self.player_action_space.null_value())
        if self.serial:
            self.ser_cum_act = np.zeros(self.env.observation_space.shape)
            self.ser_counter = 0
        self.player_turn = False
        o = self.env.reset()
        self.last_obs = o
        obs = np.concatenate([np.zeros(o.shape), np.zeros(o.shape)])
        if self.inc_player_last_act:
            obs = np.append(obs, self.last_action)
        self.last_obs_act = np.zeros(o.shape)
        self.last_masked_obs = np.zeros(o.shape)
        obs = self.observer_observation_space.convert(obs)
        return obs

    def spaces(self):
        comb_spaces = [
            EnvSpaces(observation=self.player_observation_space,
                      action=self.player_action_space),
            EnvSpaces(observation=self.observer_observation_space,
                      action=self.observer_action_space)
        ]
        return comb_spaces

    def action_space(self):
        if self.player_turn:
            return self.player_action_space
        else:
            return self.observer_action_space

    def observation_space(self):
        if self.player_turn:
            return self.player_observation_space
        else:
            return self.observer_observation_space

    def set_fully_observable(self):
        self.fully_obs = True

    def set_random_observation(self):
        self.rand_obs = True
Example #12
    def __init__(self,
                 env_name,
                 window_size,
                 force_float32=True,
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 max_episode_length=np.inf,
                 add_channel=False):
        self.serial = False
        env = AtariEnv(game=env_name)
        env.metadata = None
        env.reward_range = None
        super().__init__(env)
        o = self.env.reset()
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        self.add_channel = add_channel
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=self.env.action_space.null_value(),
            force_float32=force_float32,
        )
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=self.env.observation_space.null_value(),
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.player_turn = False
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        self.obs_size = self.env.observation_space.shape
        self.window_size = window_size
        self.obs_action_translator = obs_action_translator

        player_obs_space = self.env.observation_space
        if add_channel:
            player_obs_space = IntBox(low=player_obs_space.low,
                                      high=player_obs_space.high,
                                      shape=player_obs_space.shape,
                                      dtype=player_obs_space.dtype,
                                      null_value=player_obs_space.null_value())
        player_act_space = self.env.action_space
        observer_obs_space = self.env.observation_space
        observer_act_space = Box(low=np.asarray([0.0, 0.0]),
                                 high=np.asarray([
                                     self.env.observation_space.shape[0],
                                     self.env.observation_space.shape[1]
                                 ]))

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=player_act_space.null_value(),
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=np.zeros(2),
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=player_obs_space.null_value(),
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=observer_obs_space.null_value(),
            force_float32=force_float32)
Example #13
class CWTO_EnvWrapperAtari(Wrapper):
    def __init__(self,
                 env_name,
                 window_size,
                 force_float32=True,
                 player_reward_shaping=None,
                 observer_reward_shaping=None,
                 max_episode_length=np.inf,
                 add_channel=False):
        self.serial = False
        env = AtariEnv(game=env_name)
        env.metadata = None
        env.reward_range = None
        super().__init__(env)
        o = self.env.reset()
        self.max_episode_length = max_episode_length
        self.curr_episode_length = 0
        self.add_channel = add_channel
        o, r, d, info = self.env.step(self.env.action_space.sample())
        env_ = self.env
        time_limit = isinstance(self.env, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            info["timeout"] = False  # gym's TimeLimit.truncated invalid name.
        self.time_limit = time_limit
        self._action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=self.env.action_space.null_value(),
            force_float32=force_float32,
        )
        self._observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=self.env.observation_space.null_value(),
            force_float32=force_float32,
        )
        del self.action_space
        del self.observation_space
        self.player_turn = False
        self.last_done = False
        self.last_reward = 0
        self.last_info = {}
        if player_reward_shaping is None:
            self.player_reward_shaping = reward_shaping_ph
        else:
            self.player_reward_shaping = player_reward_shaping
        if observer_reward_shaping is None:
            self.observer_reward_shaping = reward_shaping_ph
        else:
            self.observer_reward_shaping = observer_reward_shaping
        self.obs_size = self.env.observation_space.shape
        self.window_size = window_size
        self.obs_action_translator = obs_action_translator

        player_obs_space = self.env.observation_space
        if add_channel:
            player_obs_space = IntBox(low=player_obs_space.low,
                                      high=player_obs_space.high,
                                      shape=player_obs_space.shape,
                                      dtype=player_obs_space.dtype,
                                      null_value=player_obs_space.null_value())
        player_act_space = self.env.action_space
        observer_obs_space = self.env.observation_space
        observer_act_space = Box(low=np.asarray([0.0, 0.0]),
                                 high=np.asarray([
                                     self.env.observation_space.shape[0],
                                     self.env.observation_space.shape[1]
                                 ]))

        self.player_action_space = GymSpaceWrapper(
            space=player_act_space,
            name="act",
            null_value=player_act_space.null_value(),
            force_float32=force_float32)
        self.observer_action_space = GymSpaceWrapper(
            space=observer_act_space,
            name="act",
            null_value=np.zeros(2),
            force_float32=force_float32)
        self.player_observation_space = GymSpaceWrapper(
            space=player_obs_space,
            name="obs",
            null_value=player_obs_space.null_value(),
            force_float32=force_float32)
        self.observer_observation_space = GymSpaceWrapper(
            space=observer_obs_space,
            name="obs",
            null_value=observer_obs_space.null_value(),
            force_float32=force_float32)

    def step(self, action):
        if self.player_turn:
            self.player_turn = False
            a = self.player_action_space.revert(action)
            if a.size <= 1:
                a = a.item()
            o, r, d, info = self.env.step(a)
            self.last_obs = o
            self.last_action = a
            obs = self.observer_observation_space.convert(o)
            if self.time_limit:
                if "TimeLimit.truncated" in info:
                    info["timeout"] = info.pop("TimeLimit.truncated")
                else:
                    info["timeout"] = False

            self.last_info = info  #(info["timeout"])
            #             info = (False)
            if isinstance(r, float):
                r = np.dtype("float32").type(r)  # Scalar float32.
            self.last_reward = r
            self.curr_episode_length += 1
            if self.curr_episode_length >= self.max_episode_length:
                d = True
            self.last_done = d
            return EnvStep(obs, r, d, info)

        else:
            r_action = self.observer_action_space.revert(action)
            r_action = self.obs_action_translator(r_action, self.window_size,
                                                  self.obs_size)
            self.player_turn = True
            self.last_obs_act = r_action
            masked_obs = np.multiply(r_action, self.last_obs)
            info = self.last_info
            r = self.last_reward
            d = self.last_done
            if self.add_channel:
                masked_obs = np.concatenate([r_action, masked_obs], axis=0)
            else:
                masked_obs[r_action == 0] = -1
            obs = self.player_observation_space.convert(masked_obs)

            return EnvStep(obs, r, d, info)

    def reset(self):
        self.curr_episode_length = 0
        self.last_done = False
        self.last_reward = 0
        self.last_action = self.player_action_space.revert(
            self.player_action_space.null_value())
        self.player_turn = False
        o = self.env.reset()
        self.last_obs = o
        self.last_obs_act = np.zeros(o.shape)
        obs = self.observer_observation_space.convert(o)
        return obs

    def spaces(self):
        comb_spaces = [
            EnvSpaces(observation=self.player_observation_space,
                      action=self.player_action_space),
            EnvSpaces(observation=self.observer_observation_space,
                      action=self.observer_action_space)
        ]
        return comb_spaces

    def action_space(self):
        if self.player_turn:
            return self.player_action_space
        else:
            return self.observer_action_space

    def observation_space(self):
        if self.player_turn:
            return self.player_observation_space
        else:
            return self.observer_observation_space

    def seed(self, seed1=0, seed2=0):
        return