Example #1
File: env.py  Project: cg31/cule
    def __init__(self, env_name, num_envs, color_mode='rgb', device='cpu', rescale=False,
                 frameskip=1, repeat_prob=0.25, episodic_life=False, max_noop_steps=30,
                 max_episode_length=10000):
        """Initialize the ALE class with a given environment

        Args:
            env_name (str): The name of the Atari rom
            num_envs (int): The number of environments to run
            color_mode (str): RGB ('rgb') or grayscale ('gray') observations
            device (str) : Device ('cpu' or 'cuda') the ALEs are mapped to
            rescale (bool) : Rescale grayscale observations to 84x84
            frameskip (int) : Number of frames to skip during training
            repeat_prob (float) : Probability of repeating previous action
            episodic_life (bool) : Set 'done' at the end of a life
            max_noop_steps (int) : Maximum number of initial NOOP steps cached for random starts
            max_episode_length (int) : Maximum number of steps per episode
        """

        assert (color_mode == 'rgb') or (color_mode == 'gray')
        if color_mode == 'rgb' and rescale:
            raise ValueError('Rescaling is only valid in grayscale color mode')

        self.cart = Rom(env_name)
        super(Env, self).__init__(self.cart, num_envs, max_noop_steps)

        self.device = torch.device(device)
        self.num_envs = num_envs
        self.rescale = rescale
        self.frameskip = frameskip
        self.repeat_prob = repeat_prob
        self.is_cuda = self.device.type == 'cuda'
        self.is_training = False
        self.episodic_life = episodic_life
        self.height = 84 if self.rescale else self.cart.screen_height()
        self.width = 84 if self.rescale else self.cart.screen_width()
        self.num_channels = 3 if color_mode == 'rgb' else 1

        self.action_set = torch.Tensor([int(s) for s in self.cart.minimal_actions()]).to(self.device).byte()

        # check if FIRE is in the action set
        self.fire_reset = int(torchcule_atari.FIRE) in self.action_set

        self.action_space = spaces.Discrete(self.action_set.size(0))
        self.observation_space = spaces.Box(low=0, high=255, shape=(self.num_channels, self.height, self.width), dtype=np.uint8)

        self.observations1 = torch.zeros((num_envs, self.height, self.width, self.num_channels), device=self.device, dtype=torch.uint8)
        self.observations2 = torch.zeros((num_envs, self.height, self.width, self.num_channels), device=self.device, dtype=torch.uint8)
        self.done = torch.zeros(num_envs, device=self.device, dtype=torch.bool)
        self.actions = torch.zeros(num_envs, device=self.device, dtype=torch.uint8)
        self.last_actions = torch.zeros(num_envs, device=self.device, dtype=torch.uint8)
        self.lives = torch.zeros(num_envs, device=self.device, dtype=torch.int32)
        self.rewards = torch.zeros(num_envs, device=self.device, dtype=torch.float32)

        self.states = torch.zeros((num_envs, self.state_size()), device=self.device, dtype=torch.uint8)
        self.frame_states = torch.zeros((num_envs, self.frame_state_size()), device=self.device, dtype=torch.uint8)
        self.ram = torch.randint(0, 255, (num_envs, self.cart.ram_size()), device=self.device, dtype=torch.uint8)
        self.tia = torch.zeros((num_envs, self.tia_update_size()), device=self.device, dtype=torch.int32)
        self.frame_buffer = torch.zeros((num_envs, 300 * self.cart.screen_width()), device=self.device, dtype=torch.uint8)
        self.cart_offsets = torch.zeros(num_envs, device=self.device, dtype=torch.int32)
        self.rand_states = torch.randint(0, np.iinfo(np.int32).max, (num_envs,), device=self.device, dtype=torch.int32)
        self.cached_states = torch.zeros((max_noop_steps, self.state_size()), device=self.device, dtype=torch.uint8)
        self.cached_ram = torch.randint(0, 255, (max_noop_steps, self.cart.ram_size()), device=self.device, dtype=torch.uint8)
        self.cached_frame_states = torch.zeros((max_noop_steps, self.frame_state_size()), device=self.device, dtype=torch.uint8)
        self.cached_tia = torch.zeros((max_noop_steps, self.tia_update_size()), device=self.device, dtype=torch.int32)
        self.cache_index = torch.zeros((num_envs,), device=self.device, dtype=torch.int32)

        self.set_cuda(self.is_cuda, self.device.index if self.is_cuda else -1)
        self.initialize(self.states.data_ptr(),
                        self.frame_states.data_ptr(),
                        self.ram.data_ptr(),
                        self.tia.data_ptr(),
                        self.frame_buffer.data_ptr(),
                        self.cart_offsets.data_ptr(),
                        self.action_set.data_ptr(),
                        self.rand_states.data_ptr(),
                        self.cached_states.data_ptr(),
                        self.cached_ram.data_ptr(),
                        self.cached_frame_states.data_ptr(),
                        self.cached_tia.data_ptr(),
                        self.cache_index.data_ptr())
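For context, the constructor above is normally used through the torchcule.atari.Env class shown in the next example. The short sketch below is illustrative only: the game name, environment count, and device string are assumptions (it presumes the corresponding ROM can be found by the Rom loader), but it shows a typical construction on CPU or GPU:

import torch
from torchcule.atari import Env

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Grayscale observations rescaled to 84x84 for 16 parallel environments.
# rescale=True is only valid together with color_mode='gray'.
envs = Env('breakout', 16, color_mode='gray', device=device,
           rescale=True, frameskip=1, episodic_life=True)

print(envs.action_space)       # Discrete space over the minimal action set
print(envs.observation_space)  # Box(0, 255, (1, 84, 84), uint8)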
Example #2
File: env.py  Project: cg31/cule
class Env(torchcule_atari.AtariEnv):
    """
    ALE (Arcade Learning Environment)

    This class provides access to ALE environments that may be executed on the CPU
    or GPU.

    Example:
		import argparse
		import numpy as np
		from torchcule.atari import Env

		parser = argparse.ArgumentParser(description="CuLE")
		parser.add_argument("game", type=str,
							help="Atari game name (breakout)")
		parser.add_argument("--n", type=int, default=20,
							help="Number of atari environments")
		parser.add_argument("--s", type=int, default=200,
							help="Number steps/frames to generate per environment")
		parser.add_argument("--c", type=str, default='rgb',
							help="Color mode (rgb or gray)")
		parser.add_argument("--rescale", action='store_true',
							help="Resize output frames to 84x84 using bilinear interpolation")
		args = parser.parse_args()

		color_mode = args.c
		num_envs = args.n
		num_steps = args.s

		env = Env(args.game, num_envs, color_mode=color_mode, rescale=args.rescale)
		observations = env.reset()

		for _ in np.arange(num_steps):
			actions = env.sample_random_actions()
			observations, reward, done, info = env.step(actions)
    """

    def __init__(self, env_name, num_envs, color_mode='rgb', device='cpu', rescale=False,
                 frameskip=1, repeat_prob=0.25, episodic_life=False, max_noop_steps=30,
                 max_episode_length=10000):
        """Initialize the ALE class with a given environment

        Args:
            env_name (str): The name of the Atari rom
            num_envs (int): The number of environments to run
            color_mode (str): RGB ('rgb') or grayscale ('gray') observations
            device (str) : Device ('cpu' or 'cuda') the ALEs are mapped to
            rescale (bool) : Rescale grayscale observations to 84x84
            frameskip (int) : Number of frames to skip during training
            repeat_prob (float) : Probability of repeating previous action
            episodic_life (bool) : Set 'done' at the end of a life
            max_noop_steps (int) : Maximum number of initial NOOP steps cached for random starts
            max_episode_length (int) : Maximum number of steps per episode
        """

        assert (color_mode == 'rgb') or (color_mode == 'gray')
        if color_mode == 'rgb' and rescale:
            raise ValueError('Rescaling is only valid in grayscale color mode')

        self.cart = Rom(env_name)
        super(Env, self).__init__(self.cart, num_envs, max_noop_steps)

        self.device = torch.device(device)
        self.num_envs = num_envs
        self.rescale = rescale
        self.frameskip = frameskip
        self.repeat_prob = repeat_prob
        self.is_cuda = self.device.type == 'cuda'
        self.is_training = False
        self.episodic_life = episodic_life
        self.height = 84 if self.rescale else self.cart.screen_height()
        self.width = 84 if self.rescale else self.cart.screen_width()
        self.num_channels = 3 if color_mode == 'rgb' else 1

        self.action_set = torch.Tensor([int(s) for s in self.cart.minimal_actions()]).to(self.device).byte()

        # check if FIRE is in the action set
        self.fire_reset = int(torchcule_atari.FIRE) in self.action_set

        self.action_space = spaces.Discrete(self.action_set.size(0))
        self.observation_space = spaces.Box(low=0, high=255, shape=(self.num_channels, self.height, self.width), dtype=np.uint8)

        self.observations1 = torch.zeros((num_envs, self.height, self.width, self.num_channels), device=self.device, dtype=torch.uint8)
        self.observations2 = torch.zeros((num_envs, self.height, self.width, self.num_channels), device=self.device, dtype=torch.uint8)
        self.done = torch.zeros(num_envs, device=self.device, dtype=torch.bool)
        self.actions = torch.zeros(num_envs, device=self.device, dtype=torch.uint8)
        self.last_actions = torch.zeros(num_envs, device=self.device, dtype=torch.uint8)
        self.lives = torch.zeros(num_envs, device=self.device, dtype=torch.int32)
        self.rewards = torch.zeros(num_envs, device=self.device, dtype=torch.float32)

        self.states = torch.zeros((num_envs, self.state_size()), device=self.device, dtype=torch.uint8)
        self.frame_states = torch.zeros((num_envs, self.frame_state_size()), device=self.device, dtype=torch.uint8)
        self.ram = torch.randint(0, 255, (num_envs, self.cart.ram_size()), device=self.device, dtype=torch.uint8)
        self.tia = torch.zeros((num_envs, self.tia_update_size()), device=self.device, dtype=torch.int32)
        self.frame_buffer = torch.zeros((num_envs, 300 * self.cart.screen_width()), device=self.device, dtype=torch.uint8)
        self.cart_offsets = torch.zeros(num_envs, device=self.device, dtype=torch.int32)
        self.rand_states = torch.randint(0, np.iinfo(np.int32).max, (num_envs,), device=self.device, dtype=torch.int32)
        self.cached_states = torch.zeros((max_noop_steps, self.state_size()), device=self.device, dtype=torch.uint8)
        self.cached_ram = torch.randint(0, 255, (max_noop_steps, self.cart.ram_size()), device=self.device, dtype=torch.uint8)
        self.cached_frame_states = torch.zeros((max_noop_steps, self.frame_state_size()), device=self.device, dtype=torch.uint8)
        self.cached_tia = torch.zeros((max_noop_steps, self.tia_update_size()), device=self.device, dtype=torch.int32)
        self.cache_index = torch.zeros((num_envs,), device=self.device, dtype=torch.int32)

        self.set_cuda(self.is_cuda, self.device.index if self.is_cuda else -1)
        self.initialize(self.states.data_ptr(),
                        self.frame_states.data_ptr(),
                        self.ram.data_ptr(),
                        self.tia.data_ptr(),
                        self.frame_buffer.data_ptr(),
                        self.cart_offsets.data_ptr(),
                        self.action_set.data_ptr(),
                        self.rand_states.data_ptr(),
                        self.cached_states.data_ptr(),
                        self.cached_ram.data_ptr(),
                        self.cached_frame_states.data_ptr(),
                        self.cached_tia.data_ptr(),
                        self.cache_index.data_ptr())

    def to(self, device):
        if self.is_cuda:
            torch.cuda.current_stream().synchronize()
            self.sync_this_stream()
            self.sync_other_stream()

        self.device = torch.device(device)
        self.is_cuda = self.device.type == 'cuda'
        self.set_cuda(self.is_cuda, self.device.index if self.is_cuda else -1)

        self.observations1 = self.observations1.to(self.device)
        self.observations2 = self.observations2.to(self.device)
        self.done = self.done.to(self.device)
        self.actions = self.actions.to(self.device)
        self.last_actions = self.last_actions.to(self.device)
        self.lives = self.lives.to(self.device)
        self.rewards = self.rewards.to(self.device)
        self.action_set = self.action_set.to(self.device)

        self.states = self.states.to(self.device)
        self.frame_states = self.frame_states.to(self.device)
        self.ram = self.ram.to(self.device)
        self.tia = self.tia.to(self.device)
        self.frame_buffer = self.frame_buffer.to(self.device)
        self.cart_offsets = self.cart_offsets.to(self.device)
        self.rand_states = self.rand_states.to(self.device)
        self.cached_states = self.cached_states.to(self.device)
        self.cached_ram = self.cached_ram.to(self.device)
        self.cached_frame_states = self.cached_frame_states.to(self.device)
        self.cached_tia = self.cached_tia.to(self.device)
        self.cache_index = self.cache_index.to(self.device)

        self.initialize(self.states.data_ptr(),
                        self.frame_states.data_ptr(),
                        self.ram.data_ptr(),
                        self.tia.data_ptr(),
                        self.frame_buffer.data_ptr(),
                        self.cart_offsets.data_ptr(),
                        self.action_set.data_ptr(),
                        self.rand_states.data_ptr(),
                        self.cached_states.data_ptr(),
                        self.cached_ram.data_ptr(),
                        self.cached_frame_states.data_ptr(),
                        self.cached_tia.data_ptr(),
                        self.cache_index.data_ptr())

        if self.is_cuda:
            torch.cuda.current_stream().synchronize()
            self.sync_this_stream()
            self.sync_other_stream()

    def train(self, frameskip=4):
        """Set ALE to training mode"""
        self.frameskip = frameskip
        self.is_training = True

    def eval(self):
        """Set ALE to evaluation mode"""
        self.is_training = False

    def minimal_actions(self):
        """Minimal number of actions for the environment

        Returns:
            ByteTensor: minimal set of actions for the environment
        """
        return self.action_set

    def sample_random_actions(self, asyn=False):
        """Generate a random set of actions

        Returns:
            ByteTensor: random action indices, one per environment
        """
        return torch.randint(self.minimal_actions().size(0), (self.num_envs,), device=self.device, dtype=torch.uint8)

    def screen_shape(self):
        """Get the shape of the observations

        Returns:
            tuple(int,int): Tuple containing height and width of observations
        """
        return (self.height, self.width)

    def reset(self, seeds=None, initial_steps=50, verbose=False, asyn=False):
        """Reset the environments

        Args:
            seeds (list[int]): seeds to use for initialization
            initial_steps (int): number of initial NOOP steps to execute during initialization

        Returns:
            ByteTensor: observations for each environment after the reset
        """
        if seeds is None:
            seeds = torch.randint(np.iinfo(np.int32).max, (self.num_envs,), dtype=torch.int32, device=self.device)

        if self.is_cuda:
            self.sync_other_stream()
            stream = torch.cuda.current_stream()

        super(Env, self).reset(seeds.data_ptr())

        if self.is_training:
            iterator = range(math.ceil(initial_steps / self.frameskip))

            if verbose:
                from tqdm import tqdm
                iterator = tqdm(iterator)

            for _ in iterator:
                actions = self.sample_random_actions()
                self.step(actions, asyn=True)

        if self.is_cuda:
            self.sync_this_stream()
            if not asyn:
                stream.synchronize()

        return self.observations1

    def step(self, player_a_actions, player_b_actions=None, asyn=False):
        """Take a step in the environment by apply a set of actions

        Args:
            player_a_actions (ByteTensor): action indices to apply in each environment
            player_b_actions (ByteTensor, optional): action indices for the second player, if supported

        Returns:
            ByteTensor: observations for each environment
            FloatTensor: sum of rewards over the frameskip steps in each environment
            BoolTensor: 'done' state for each environment
            dict: miscellaneous information ('ale.lives' gives the remaining lives per environment)
        """

        # sanity checks
        assert player_a_actions.size(0) == self.num_envs

        self.rewards.zero_()
        self.observations1.zero_()
        self.observations2.zero_()
        self.done.zero_()

        self.player_a_actions = self.action_set[player_a_actions.long()]
        player_a_actions_ptr = self.player_a_actions.data_ptr()

        if player_b_actions is not None:
            self.player_b_actions = self.action_set[player_b_actions.long()]
            player_b_actions_ptr = self.player_b_actions.data_ptr()
        else:
            player_b_actions_ptr = 0

        if self.is_cuda:
            self.sync_other_stream()

        for frame in range(self.frameskip):
            super(Env, self).step(self.fire_reset and self.is_training, player_a_actions_ptr, player_b_actions_ptr, self.done.data_ptr())
            self.get_data(self.episodic_life, self.done.data_ptr(), self.rewards.data_ptr(), self.lives.data_ptr())
            if frame == (self.frameskip - 2):
                self.generate_frames(self.rescale, False, self.num_channels, self.observations2.data_ptr())

        self.reset_states()
        self.generate_frames(self.rescale, True, self.num_channels, self.observations1.data_ptr())

        if self.is_cuda:
            self.sync_this_stream()
            if not asyn:
                torch.cuda.current_stream().synchronize()

        self.observations1 = torch.max(self.observations1, self.observations2)

        info = {'ale.lives': self.lives}

        return self.observations1, self.rewards, self.done, info

    def get_states(self, indices):
        from torchcule.atari.state import State
        return [State(s) for s in super(Env, self).get_states([i for i in indices.cpu()])]

    def set_states(self, indices, states):
        super(Env, self).set_states([i for i in indices.cpu()], [s.state for s in states])
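Putting the class to work, a typical random-policy rollout could look like the sketch below. The game name, number of environments, and step count are illustrative assumptions; the calls themselves (train, reset, sample_random_actions, step, to) are the methods defined above:

import torch
from torchcule.atari import Env

device = 'cuda' if torch.cuda.is_available() else 'cpu'
env = Env('pong', 32, color_mode='gray', device=device,
          rescale=True, episodic_life=True)

env.train(frameskip=4)                      # training mode: random warm-up steps during reset
observations = env.reset(initial_steps=50)  # uint8 tensor of shape (32, 84, 84, 1)

total_rewards = torch.zeros(env.num_envs, device=env.device)
for _ in range(1000):
    actions = env.sample_random_actions()   # uint8 indices into the minimal action set
    observations, rewards, done, info = env.step(actions)
    total_rewards += rewards

print(total_rewards.mean().item())

# env.to('cuda') or env.to('cpu') would move every emulator buffer to the new
# device and re-bind the underlying pointers, as in the to() method above.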