Example #1
class MySimulatorMaster(SimulatorMaster, Callback):
    def __init__(self,
                 worker_id,
                 neptune_client,
                 pipe_c2s,
                 pipe_s2c,
                 model,
                 dummy,
                 predictor_threads,
                 predict_batch_size=16,
                 do_train=True):
        # predictor_threads replaces the former PREDICTOR_THREAD constant
        super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c,
                                                args.simulator_procs,
                                                os.getpid())
        self.M = model
        self.do_train = do_train

        # the second queue: buffers training datapoints, bounded by args.my_sim_master_queue
        self.queue = queue.Queue(maxsize=args.my_sim_master_queue)
        self.dummy = dummy
        self.predictor_threads = predictor_threads

        self.last_queue_put = start_timer()
        self.queue_put_times = []
        self.predict_batch_size = predict_batch_size
        self.counter = 0

        self.worker_id = worker_id
        self.neptune_client = neptune_client
        self.stats = defaultdict(StatCounter)
        self.games = StatCounter()

    def _setup_graph(self):
        with tf.device('/cpu:0'):
            with tf.variable_scope(tf.get_variable_scope(), reuse=None):
                self.sess = self.trainer.sess
                self.async_predictor = MultiThreadAsyncPredictor(
                    self.trainer.get_predict_funcs(['state'],
                                                   ['logitsT', 'pred_value'],
                                                   self.predictor_threads),
                    batch_size=self.predict_batch_size)
                self.async_predictor.run()

    def _on_state(self, state, ident):
        ident, ts = ident
        client = self.clients[ident]

        if self.dummy:
            action = 0
            value = 0.0
            client.memory.append(
                TransitionExperience(state, action, None, value=value))
            self.send_queue.put([ident, dumps(action)])
        else:

            def cb(outputs):
                # distrib, value, global_step, isAlive  = outputs.result()
                o = outputs.result()
                if o[-1]:
                    distrib = o[0]
                    value = o[1]
                    global_step = o[2]
                    assert np.all(np.isfinite(distrib)), distrib
                    action = np.random.choice(len(distrib), p=distrib)
                    client = self.clients[ident]
                    client.memory.append(
                        TransitionExperience(state,
                                             action,
                                             None,
                                             value=value,
                                             ts=ts))
                else:
                    self.send_queue.put([ident, dumps((0, 0, False))])
                    return

                #print"Q-debug: MySimulatorMaster send_queue before put, size: ", self.send_queue.qsize(), '/', self.send_queue.maxsize
                self.send_queue.put(
                    [ident, dumps((action, global_step, True))])

            self.async_predictor.put_task([state], cb)

    def _on_episode_over(self, ident):
        ident, ts = ident

        client = self.clients[ident]
        # send game score to neptune
        self.games.feed(self.stats[ident].sum)
        self.stats[ident].reset()

        if self.games.count == 10:
            self.neptune_client.send(
                (self.worker_id, ('online', self.games.average)))
            self.games.reset()

        self._parse_memory(0, ident, True, ts)

    def _on_datapoint(self, ident):
        ident, ts = ident
        client = self.clients[ident]

        self.stats[ident].feed(client.memory[-1].reward)

        if len(client.memory) == LOCAL_TIME_MAX + 1:
            R = client.memory[-1].value
            self._parse_memory(R, ident, False, ts)

    def _parse_memory(self, init_r, ident, isOver, ts):
        client = self.clients[ident]
        mem = client.memory
        if not isOver:
            last = mem[-1]
            mem = mem[:-1]

        mem.reverse()
        R = float(init_r)
        for idx, k in enumerate(mem):
            R = np.clip(k.reward, -1, 1) + GAMMA * R
            point_ts = k.ts
            self.log_queue_put()
            if self.do_train:
                self.queue.put(
                    [k.state, k.action, R, point_ts, init_r, isOver])

        if not isOver:
            client.memory = [last]
        else:
            client.memory = []

    def log_queue_put(self):
        self.counter += 1
        elapsed_last_put = elapsed_time_ms(self.last_queue_put)
        self.queue_put_times.append(elapsed_last_put)
        k = 1000  # window size for the running mean of put intervals
        if self.counter % 1 == 0:  # modulus of 1 logs on every put; raise it to throttle
            logger.debug("queue_put_times elapsed {elapsed}".format(
                elapsed=elapsed_last_put))
            logger.debug("queue_put_times {puts_s} puts/s".format(
                puts_s=1000.0 / np.mean(self.queue_put_times[-k:])))
        self.last_queue_put = start_timer()
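
For reference, _parse_memory above walks the stored transitions backwards, clipping each reward to [-1, 1] and accumulating the discounted return, bootstrapped from the last value estimate when the episode has not ended. Below is a minimal standalone sketch of that computation; GAMMA and the plain reward list are stand-ins for the originals:

import numpy as np

GAMMA = 0.99  # assumed discount factor; the original defines it elsewhere

def n_step_returns(rewards, bootstrap_value, is_over):
    """Discounted returns for a rollout, oldest transition first."""
    # bootstrap from the critic's value estimate unless the episode ended
    R = 0.0 if is_over else float(bootstrap_value)
    returns = []
    for r in reversed(rewards):
        R = float(np.clip(r, -1, 1)) + GAMMA * R
        returns.append(R)
    returns.reverse()
    return returns

# A 3-step rollout that did not terminate, bootstrapped from V = 0.5;
# the 2.0 reward is clipped to 1.0 before discounting:
# n_step_returns([1.0, 0.0, 2.0], 0.5, False) -> [2.4652495, 1.48005, 1.495]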
Example #2
class AtariPlayer(RLEnvironment):
    """
    A wrapper for atari emulator.
    Will automatically restart when a real episode ends (isOver might be just
    lost of lives but not game over).
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 height_range=(None, None),
                 frame_skip=4,
                 image_shape=(84, 84),
                 nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the ROM
        :param frame_skip: repeat the action for this many consecutive frames
        :param image_shape: (w, h)
        :param height_range: (h1, h2) rows to keep when cropping
        :param viz: visualization mode.
            Set to 0 to disable.
            Set to a positive number for the delay in seconds between shown frames.
            Set to a string naming a directory in which to store frames.
        :param nullop_start: start each episode with a random number of null ops
        :param live_lost_as_eoe: treat a loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
                "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn(
                    "https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
                )

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # deterministic action repeats (manual.pdf suggests otherwise)
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def current_state(self):
        """
        :returns: a gray-scale (h, w, 1) float32 image
        """
        ret = self._grab_raw_image()
        # element-wise max with the previous raw screen, to remove sprite flicker
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                #m = cv2.resize(ret, (1920,1200))
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1], :].astype(
            'float32')
        # weights 0.299, 0.587, 0.114, same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        ret = np.expand_dims(ret, axis=2)
        return ret

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives
        if isOver:
            self.finish_episode()
        if self.ale.game_over():
            self.restart_episode()
        return (r, isOver)
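
A minimal interaction loop for the AtariPlayer above, assuming a Breakout ROM is available locally (the file name is illustrative; bare names are also resolved via get_dataset_path):

import numpy as np

player = AtariPlayer('breakout.bin', viz=0, frame_skip=4)
num_actions = len(player.actions)  # minimal action set reported by the ALE

for _ in range(100):
    state = player.current_state()        # (84, 84, 1) float32 grayscale frame
    act = np.random.randint(num_actions)  # random policy, for illustration only
    reward, is_over = player.action(act)
    # no manual reset needed: AtariPlayer restarts itself on game over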