Example No. 1
    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 exploration, end_exploration, exploration_epoch_anneal,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.num_actions = player.get_action_space().num_actions()
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # TODO just use a semaphore?
        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
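The locals() loop above stores every constructor argument on self, so a caller only supplies hyperparameters. A minimal usage sketch, assuming this __init__ belongs to tensorpack's ExpReplay callback from its DQN example; the factory and every value below are illustrative, not taken from the snippet:

    # Hypothetical construction; make_player and all values are placeholders.
    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=make_player(),            # assumed factory returning an RLEnvironment
        state_shape=(84, 84),
        batch_size=64,
        memory_size=10**6,
        init_memory_size=5 * 10**4,
        exploration=1.0,                 # initial epsilon for the epsilon-greedy policy
        end_exploration=0.1,             # final epsilon
        exploration_epoch_anneal=0.01,   # epsilon subtracted per epoch
        update_frequency=4,
        history_len=4)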
Example No. 2
 def __init__(self, exe_path, json_path, actions=ACTIONS, height=HEIGHT, width=WIDTH, gray=False, record=False):
   super(ThorPlayer, self).__init__()
   assert os.path.isfile(exe_path), 'wrong path of executable binary for Thor'
   assert os.path.isfile(json_path), 'wrong path of target json file'
   self.height = height
   self.width = width
   self.gray = gray
   self.record = record
   # set Thor controller
   self.env = robosims.controller.ChallengeController(
               unity_path=exe_path,
               height=self.height,
               width=self.width,
               record_actions=self.record)
   
   # read targets from the json file
   with open(json_path) as f:
     self.targets = json.loads(f.read())
   self.num_targets = len(self.targets)
   
   self.rng = get_rng(self)
   self.actions = actions
   self.current_episode_score = StatCounter()
   self.env.start()
   self.restart_episode()
Example No. 3
    def __init__(self, df, random_move=True):
        """Yield a board configuration and next move from a LMDB data point

        Args:
            df: dataflow of LMDB entries
            random_move (bool, optional): pick a random move from the match
        """
        rng = get_rng(self)

        def func(dp):
            raw = dp[0]
            max_moves = len(raw) // 2  # integer division keeps move_id an int

            # game is too short -> skip
            if max_moves < 10:
                return None

            # all moves up to the last one (we want to predict at least one move)
            move_id = 1 + max_moves - 2
            if random_move:
                move_id = rng.randint(1, max_moves - 1)

            features = np.zeros((FEATURE_LEN, 19, 19), dtype=np.int32)
            next_move = goplanes.planes_from_bytes(raw.tobytes(), features,
                                                   move_id)

            assert not np.isnan(features).any()

            return [features, int(next_move)]

        super(GameDecoder, self).__init__(df, func)
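A hedged wiring sketch for the decoder above. It assumes the games were stored with tensorpack's LMDBSerializer and that each datapoint is a single byte array, as the raw = dp[0] line suggests; the path is a placeholder:

    from tensorpack.dataflow import LMDBSerializer

    df = LMDBSerializer.load('/path/to/games.lmdb', shuffle=True)  # placeholder path
    df = GameDecoder(df, random_move=True)
    df.reset_state()
    for features, next_move in df:
        # features: int32 planes of shape (FEATURE_LEN, 19, 19)
        # next_move: integer index of the move to predict
        break

If GameDecoder extends tensorpack's MapData, as the super().__init__(df, func) call suggests, datapoints for which func returns None (games shorter than 10 moves) are dropped rather than yielded.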
Example No. 4
    def __init__(self, pattern, window_size=5, nr_examples=10):
        super(VideoPatchesFlow, self).__init__()

        self.pattern = pattern
        self.window_size = window_size
        self.nr_examples = nr_examples
        from tensorpack.utils import get_rng
        self.rng = get_rng(self)
Example No. 5
 def benchmark():
     a = AtariPlayer(sys.argv[1], viz=False, height_range=(28, -8))
     num = a.get_action_space().num_actions()
     rng = get_rng(num)
     start = time.time()
     cnt = 0
     while True:
         act = rng.choice(range(num))
         r, o = a.action(act)
         a.current_state()
         cnt += 1
         if cnt == 5000:
             break
     print(time.time() - start)
Example No. 6
    def __init__(self,
                 predictor_io_names,
                 get_player,
                 num_parallel_players,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 update_frequency, history_len,
                 state_dtype='uint8'):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            get_player (-> gym.Env): a callable which returns a player.
            num_parallel_players (int): number of players to run in parallel.
                Standard DQN uses 1.
                Parallelism increases speed, but will affect the distribution of
                experiences in the replay buffer.
            state_shape (tuple): shape of a single state.
            batch_size (int): size of each batch sampled for training.
            memory_size (int): capacity of the replay memory, in transitions.
            init_memory_size (int): number of transitions to collect before
                training starts.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            state_dtype (str): numpy dtype of the stored states.
        """
        assert len(state_shape) in [1, 2, 3], state_shape
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = 1.0  # default initial exploration

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        self.mem = ReplayMemory(memory_size, state_shape, self.history_len, dtype=state_dtype)
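The switch from a player instance to a get_player callable is what makes the parallelism workable: each runner calls the factory and owns its environment, so no simulator state is shared across threads. A sketch of the idea; AtariPlayer and the ROM path are illustrative stand-ins:

    # Each parallel runner calls the factory, so no environment is shared.
    def get_player():
        return AtariPlayer('breakout.bin', viz=0)   # placeholder ROM path

    players = [get_player() for _ in range(4)]      # one env per runner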
Example No. 7
    def __init__(self, player, predictor, memory, history_len):
        """
        Args:
            player (gym.Env)
            predictor (callable): the model forward function which takes a
                state and returns the prediction.
            memory (ReplayMemory): the replay memory to store experience to.
            history_len (int): number of history frames to concatenate.
        """
        self.player = player
        self.num_actions = player.action_space.n
        self.predictor = predictor
        self.memory = memory
        self.state_shape = memory.state_shape
        self.dtype = memory.dtype
        self.history_len = history_len

        self._current_episode = []
        self._current_ob = player.reset()
        self._current_game_score = StatCounter()  # store per-step reward
        self.total_scores = []  # store per-game total score

        self.rng = get_rng(self)
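history_len implies the zero-filled stacking described in the docstrings above: at the start of an episode fewer than history_len real frames exist, so the missing ones are zeros. A standalone numpy sketch of that padding scheme; the function name and channel layout are illustrative:

    import numpy as np

    def stack_history(frames, history_len, state_shape, dtype='uint8'):
        # Concatenate the last history_len frames, zero-filling at episode start.
        frames = frames[-history_len:]
        pad = history_len - len(frames)
        if pad > 0:
            frames = [np.zeros(state_shape, dtype=dtype)] * pad + frames
        return np.stack(frames, axis=-1)    # shape: state_shape + (history_len,)

    # First step of an episode: only one real frame exists yet.
    obs = np.ones((84, 84), dtype='uint8')
    state = stack_history([obs], history_len=4, state_shape=(84, 84))
    assert state.shape == (84, 84, 4) and state[..., :3].sum() == 0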
Example No. 8
 def reset_state(self):
     """ Reset the RNG """
     self.rng = get_rng(self)
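This one-liner is the pattern behind most snippets on this page: tensorpack's get_rng returns a numpy.random.RandomState seeded from the current time, the process id, and the object itself, and reset_state runs once in each worker process, so forked dataflows get independent seeds instead of cloning the parent's RNG state. A minimal sketch of a dataflow built on the same pattern (the class and its method bodies are illustrative; older tensorpack versions name the iteration method get_data instead of __iter__):

    from tensorpack.dataflow import DataFlow
    from tensorpack.utils import get_rng

    class NoiseFlow(DataFlow):
        # Illustrative dataflow yielding random vectors.
        def reset_state(self):
            # Runs in each worker process after forking, giving every
            # worker an independently seeded RandomState.
            self.rng = get_rng(self)

        def __iter__(self):
            while True:
                yield [self.rng.uniform(size=10)]

tensorpack's RNGDataFlow base class ships exactly this reset_state, which is why so many of the examples here look identical.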
Example No. 9
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info


if __name__ == '__main__':
    import sys

    a = AtariPlayer(sys.argv[1], viz=0.03)
    num = a.action_space.n
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
        state, reward, isOver, info = a.step(act)
        if isOver:
            print(info)
            a.reset()
        print("Reward:", reward)
Example No. 10
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with a random number of null ops.
            live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using the latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                #cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width),
                                            dtype=np.uint8)
        self._restart_episode()
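Because this version exposes action_space, observation_space, and a gym-style step (see the __main__ block in the previous example), it can be driven like any gym.Env. An illustrative random-play loop; the ROM path is a placeholder:

    env = AtariPlayer('breakout.bin', viz=0)    # placeholder ROM path
    ob = env.reset()
    done, total = False, 0.0
    while not done:
        act = env.action_space.sample()         # uniform random action
        ob, reward, done, info = env.step(act)
        total += reward
    print('episode reward:', total, info)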
Example No. 11
 def reset_state(self):
     self.rng = get_rng(self)
     self.gnr = self.gnr_constructor()
Example No. 12
 def reset_state(self):
     self.rng = get_rng(self)
Example No. 13
    def __init__(self,
                 rom_file,
                 viz=0,
                 height_range=(None, None),
                 frame_skip=4,
                 image_shape=(84, 84),
                 nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
                "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn(
                    "https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
                )

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()
Example No. 14
            a.current_state()
            cnt += 1
            if cnt == 5000:
                break
        print(time.time() - start)

    if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
        import threading
        import multiprocessing
        for k in range(3):
            # th = multiprocessing.Process(target=benchmark)
            th = threading.Thread(target=benchmark)
            th.start()
            time.sleep(0.02)
        benchmark()
    else:
        a = AtariPlayer(sys.argv[1],
                        viz=0.03, height_range=(28, -8))
        num = a.get_action_space().num_actions()
        rng = get_rng(num)
        import time
        while True:
            # im = a.grab_image()
            # cv2.imshow(a.romname, im)
            act = rng.choice(range(num))
            print(act)
            r, o = a.action(act)
            a.current_state()
            # time.sleep(0.1)
            print(r, o)
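The commented-out multiprocessing.Process line above marks the tradeoff this benchmark probes: threads only scale if the emulator releases the GIL inside a.action(); otherwise separate processes are needed. A hedged process-based variant of the same launcher:

    import multiprocessing

    procs = [multiprocessing.Process(target=benchmark) for _ in range(3)]
    for p in procs:
        p.start()
    benchmark()          # a fourth stream in the parent, mirroring the original
    for p in procs:
        p.join()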
Example No. 15
 def reset_state(self):
     self.rng = get_rng(self)
Example No. 16
def get_raw_env(experiment):
    if experiment == 'STANDARD':
        return soccer_environment.SoccerEnvironment
    elif experiment == 'PASSING':
        return soccer_environment.SoccerPassingBallEnvironment
    elif experiment == 'SAVING':
        return soccer_environment.SoccerSavingBallEnvironment
    assert 0, 'unknown experiment: {}'.format(experiment)


if __name__ == '__main__':
    pl = SoccerPlayer(image_shape=(84, 84),
                      viz=1,
                      frame_skip=1,
                      field='large',
                      ai_frame_skip=1,
                      team_size=2,
                      raw_env=SoccerPassingBallEnvironment)
    rng = get_rng(5)
    import time
    while True:
        # im = a.grab_image()
        # cv2.imshow(a.romname, im)
        act = rng.choice(range(5))
        act = 4  # override: always take action 4, ignoring the random choice above
        r, o = pl.action(act)
        pl.current_state()
        print(pl.get_changing_counter())
        # time.sleep(0.1)
Example No. 17
    def __init__(self, rom_file, viz=0, height_range=(None, None),
                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using the latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()