Example #1
def eval_with_funcs(predictors, nr_eval, get_player_fn):
    """
    Args:
        predictors ([PredictorBase])
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue

        def func(self, *args, **kwargs):
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn(train=False)
                while not self.stopped():
                    try:
                        score = play_one_episode(player, self.func)
                        # print("Score, ", score)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, score)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        r = q.get()
        stat.feed(r)
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    while q.qsize():
        r = q.get()
        stat.feed(r)

    if stat.count > 0:
        return (stat.average, stat.max)
    return (0, 0)
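All of the eval_with_funcs variants in this collection follow the same fan-out/fan-in pattern: one worker thread per predictor keeps pushing episode scores into a shared queue, the main thread consumes exactly nr_eval results, stops the workers, and then drains whatever arrived during shutdown into the StatCounter. A minimal stdlib-only sketch of that pattern is shown below; play_one_episode_stub and eval_parallel are invented stand-ins, not tensorpack API.

# illustrative sketch (not from the original code)
import queue
import random
import threading
import time

def play_one_episode_stub():
    # stand-in for play_one_episode(player, func): pretend to run an episode
    time.sleep(0.01)
    return random.random()

def eval_parallel(n_workers, nr_eval):
    q = queue.Queue()
    stop = threading.Event()

    def worker():
        # each worker keeps producing episode scores until asked to stop
        while not stop.is_set():
            q.put(play_one_episode_stub())

    threads = [threading.Thread(target=worker) for _ in range(n_workers)]
    for t in threads:
        t.start()

    scores = [q.get() for _ in range(nr_eval)]  # consume exactly nr_eval results
    stop.set()                                  # ask the workers to stop
    for t in threads:
        t.join()
    while not q.empty():                        # drain scores produced during shutdown
        scores.append(q.get())
    return sum(scores) / len(scores), max(scores)

print(eval_parallel(n_workers=4, nr_eval=20))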
Example #2
    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            state_shape (tuple): h, w, c
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        assert len(state_shape) == 3, state_shape
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()
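The `for k, v in locals().items(): setattr(self, k, v)` loop at the top of the constructor is a compact idiom for copying every constructor argument onto the instance; it works because at that point locals() contains only the parameters. A tiny self-contained illustration (the class name and fields here are invented):

# illustrative sketch (not from the original code)
class ReplayConfig:
    def __init__(self, batch_size, memory_size, init_memory_size, update_frequency):
        # copy every constructor argument (except self) onto the instance
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)

cfg = ReplayConfig(batch_size=64, memory_size=10**6,
                   init_memory_size=5 * 10**4, update_frequency=4)
print(cfg.batch_size, cfg.update_frequency)  # 64 4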
Example #3
    def _clear_history(self):
        """ clear history buffer with current state
        """
        # stat counter to store current score or accumulated reward
        self.rewards = StatCounter()

        self._agent_nodes = np.zeros((self.max_num_frames, self.dims),
                                     dtype=int)  # [(0,) * self.dims] * self._history_length
        # self._IOU_history = np.zeros((self._history_length,))
        self._distances = np.zeros((self.max_num_frames,))
        # list of value lists
        # self._qvalues_history = np.zeros(
        #     (self.max_num_frames, self.actions))  # [(0,) * self.actions] * self._history_length
        self.reward_history = np.zeros((self.max_num_frames,))
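The buffers above are preallocated to a fixed length (one row per possible frame of the episode) and zero-filled, so each step is recorded by indexed assignment rather than list appends. A short sketch of the same bookkeeping, with made-up sizes:

# illustrative sketch (not from the original code)
import numpy as np

max_num_frames, dims = 5, 3

# zero-filled buffers, one slot per possible step of the episode
agent_nodes = np.zeros((max_num_frames, dims), dtype=int)
distances = np.zeros((max_num_frames,))
reward_history = np.zeros((max_num_frames,))

# writing step t is a single indexed assignment
t = 0
agent_nodes[t] = (12, 40, 7)
distances[t] = 3.5
reward_history[t] = -1.0
print(agent_nodes[t], distances[t], reward_history[t])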
Example #4
    def __init__(self,
                 # model,
                 agent_name,
                 player,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency,
                 encoding_file='../AutoEncoder/encoding.npy'):
        init_memory_size = int(init_memory_size)
        # self.model = model

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.agent_name = agent_name
        self.exploration = init_exploration
        self.num_actions = num_actions
        self.encoding = np.load(encoding_file)
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        self.player.prepare()
        self._comb_mask = True
        self._fine_mask = None
        self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()
Example #5
def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
    """
    Args:
        predictors ([PredictorBase])
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self.func = func
            self.q = queue

        def run(self):
            with self.default_sess():
                player = get_player_fn()
                while not self.stopped():
                    try:
                        val = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, val)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    def fetch():
        val = q.get()
        stat.feed(val)
        if verbose:
            if val > 0:
                logger.info("farmer wins")
            else:
                logger.info("lord wins")

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    while q.qsize():
        fetch()
    farmer_win_rate = stat.average
    return farmer_win_rate
Example #6
    def _init(self):
        logger.info("[{}]: agent init, isTrain={}".format(self._agentIdent, self._isTrain))
        self._episodeCount = -1
        from tensorpack.utils.utils import get_rng
        self._rng = get_rng(self)
        from tensorpack.utils.stats import StatCounter
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self._memorySaver = None
        save_dir = self._kwargs.pop('save_dir', None)
        if save_dir is not None:
            self._memorySaver = MemorySaver(save_dir,
                                            self._kwargs.pop('max_save_item', 3),
                                            self._kwargs.pop('min_save_score', None),
                                            )
        self.restart_episode()
        pass
Example #7
def eval_with_funcs(predictors, nr_eval, get_player_fn):
    """
    Args:
        predictors ([PredictorBase])
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue

        def func(self, *args, **kwargs):
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn(train=False)
                while not self.stopped():
                    try:
                        score = play_one_episode(player, self.func)
                        # print("Score, ", score)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, score)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()
    try:
        for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
            r = q.get()
            stat.feed(r)
        logger.info("Waiting for all the workers to finish the last run...")
        for k in threads:
            k.stop()
        for k in threads:
            k.join()
        while q.qsize():
            r = q.get()
            stat.feed(r)
    except:
        logger.exception("Eval")
    finally:
        if stat.count > 0:
            return (stat.average, stat.max)
        return (0, 0)
Example #8
class MedicalPlayer(gym.Env):
    """
        Class that provides 3D medical image environment.
        This is just an implementation of the classic "agent-environment loop".
        Each time-step, the agent chooses an action, and the environment returns
        an observation and a reward.
    """

    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 screen_dims=(27, 27), history_length=20, multiscale=True,
                 max_num_frames=0, saveGif=False, saveVideo=False):
        """
        :param directory: environment or game name

        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames

        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d, w) - defaults (27, 27)

        :param nullop_start: start with random number of null ops

        :param history_length: length of the location and q-value history
            buffers used to detect oscillation

        :param max_num_frames: maximum number of frames per episode.
        """

        # ######################################################################
        # ## generate evaluation results from 19 different points
        # ## save results in csv file
        # self.csvfile = 'DuelDoubleDQN_multiscale_brain_mri_point_pc_ROI_45_45_45_midl2018.csv'
        # if not train:
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filenames", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5, 0.25, 0.75]
        # y = [0.5, 0.25, 0.75]
        # z = [0.5, 0.25, 0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        # ######################################################################

        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0

        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames

        # stores information: terminal, score, distError
        self.info = None

        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo

        # training flag
        self.task = task

        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)

        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0

            if isinstance(viz, int):  # check if viz is an int number
                viz = float(viz)
            self.viz = viz

            if self.viz and isinstance(self.viz, float):  # check if viz is a float number
                self.viewer = None
                self.gif_buffer = []

        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()

        # get action space and minimal action set
        self.action_space = spaces.Discrete(4)  # change number actions here
        self.actions = self.action_space.n

        # the observation space also needs to be set up
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0)

        # add your data loader here
        # 'play' uses returnLandmarks=False, training/evaluation uses returnLandmarks=True
        if self.task == 'play':
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                      returnLandmarks=False)
            print("self.files:", self.files)
        else:
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                      returnLandmarks=True)


        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()

        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
            restart the current episode
        """
        self.terminal = False
        self.reward = 0
        self.cnt = 0 # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self._loc_history = [(0,) * self.dims] * self._history_length

        # list of q-value lists
        self._qvalues_history = [(0,) * self.actions] * self._history_length
        self.new_random_game()

    def new_random_game(self):
        """
            1.load image,
            2.set dimensions,
            3.randomize start point,
            4.init _screen, qvals,
            5.calc distance to goal
        """
        self.terminal = False
        self.viewer = None

        # ######################################################################
        # ## generate evaluation results from 19 different points
        # if self.count_points ==0:
        #     print('\n============== new game ===============\n')
        #     # save results
        #     if self.total_loc:
        #         with open(self.csvfile, 'a') as outcsv:
        #             fields= [self.filenames, self.cur_dist]
        #             writer = csv.writer(outcsv)
        #             writer.writerow(map(lambda x: x, fields))
        #         self.total_loc = []
        #     # sample a new image
        #     self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        #     scale = next(self.start_points)
        #     self.count_points +=1
        # else:
        #     self.count_points += 1
        #     logger.info('count_points {}'.format(self.count_points))
        #     scale = next(self.start_points)
        #
        # x = int(scale[0] * self._image.dims[0])
        # y = int(scale[1] * self._image.dims[1])
        # z = int(scale[2] * self._image.dims[2])
        # logger.info('starting point {}-{}-{}'.format(x,y,z))
        # ######################################################################

        # # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            ## brain
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3

            ## cardiac
            # self.action_step = 6
            # self.xscale = 2
            # self.yscale = 2
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1

        # image volume size
        self._image_dims = self._image.dims
        self._image.data = np.array(self._image)

        print("image_dim:", self._image_dims)

        #######################################################################
        # select random starting point
        # add padding to avoid start right on the border of the image
        if self.task == 'train':
            skip_thickness = (int(self._image_dims[0] / 5),
                              int(self._image_dims[1] / 5))
        else:
            skip_thickness = (int(self._image_dims[0] / 4),
                              int(self._image_dims[1] / 4))

        x = self.rng.randint(0 + skip_thickness[0],
                             self._image_dims[0] - skip_thickness[0])

        y = self.rng.randint(0 + skip_thickness[1],
                             self._image_dims[1] - skip_thickness[1])
        #######################################################################

        # locations are stored as (x, y)
        self._location = (x, y)
        self._start_location = (x, y)

        self._qvalues = [0, ] * self.actions
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = 0
        else:
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)

    def calcDistance(self, points1, points2, spacing=(1, 1)):
        """
            calculate the distance between two points in mm
        """
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)

        return np.linalg.norm(points1 - points2)

    # still based on gym -- should we dig into gym to clarify how it handles this?
    def step(self, act, qvalues):
        """
        The environment's step function returns exactly what we need.

        Args:
          act: the discrete action to take.
          qvalues: q-values of the current state, stored in the history buffer.
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.

          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.

          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)

          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """

        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False

        # these are still discrete moves, not truly continuous actions
        # FORWARD Y+ ---------------------------------------------------------
        if act == 1:
            next_location = (current_loc[0],
                             round(current_loc[1] + self.action_step))
            if next_location[1] >= self._image_dims[1]:
                # print(' trying to go out the image Y+ ',)
                next_location = current_loc
                go_out = True

        # RIGHT X+ -----------------------------------------------------------
        if act == 2:
            next_location = (round(current_loc[0] + self.action_step),
                             current_loc[1])
            if next_location[0] >= self._image_dims[0]:
                # print(' trying to go out the image X+ ',)
                next_location = current_loc
                go_out = True

        # LEFT X- -----------------------------------------------------------
        if act == 3:
            next_location = (round(current_loc[0] - self.action_step),
                             current_loc[1])
            if next_location[0] <= 0:
                # print(' trying to go out the image X- ',)
                next_location = current_loc
                go_out = True

        # BACKWARD Y- ---------------------------------------------------------
        if act == 4:
            next_location = (current_loc[0],
                             round(current_loc[1] - self.action_step))
            if next_location[1] <= 0:
                # print(' trying to go out the image Y- ',)
                next_location = current_loc
                go_out = True

        # ---------------------------------------------------------------------
        # ---------------------------------------------------------------------
        # punish -1 reward if the agent tries to go out
        if (self.task!='play'):
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)

        # update screen, reward ,location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        if (self.task == 'train'):
            if self.cur_dist <= 1:
                self.terminal = True
                self.num_success.feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames: self.terminal = True

        # update history buffer with new location and qvalues
        if (self.task != 'play'):
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)
        self._update_history()

        # check if the agent oscillates (i.e. is stuck)
        if self._oscillate:
            self._location = self.getBestLocation()
            self._screen = self._current_state()

            if (self.task != 'play'):
                self.cur_dist = self.calcDistance(self._location,
                                                  self._target_loc,
                                                  self.spacing)
            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.xscale -= 1
                    self.yscale -= 1
                    self.action_step = int(self.action_step / 3)
                    self._clear_history()

                # terminate if scale is less than 1
                else:
                    self.terminal = True
                    if self.cur_dist <= 1: self.num_success.feed(1)
            else:
                self.terminal = True
                if self.cur_dist <= 1: self.num_success.feed(1)

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        self.current_episode_score.feed(self.reward)

        info = {'score': self.current_episode_score.sum, 'gameOver': self.terminal,
                'distError': distance_error, 'filenames': self.filename}

        # #######################################################################
        # ## generate evaluation results from 19 different points
        # if self.terminal:
        #     logger.info(info)
        #     self.total_loc.append(self._location)
        #     if not(self.count_points == 19):
        #         self._restart_episode()
        #     else:
        #         mean_location = np.mean(self.total_loc,axis=0)
        #         logger.info('total_loc {} \n mean_location{}'.format(self.total_loc, mean_location))
        #         self.cur_dist = self.calcDistance(mean_location,
        #                                           self._target_loc,
        #                                           self.spacing)
        #         logger.info('final distance error {} \n'.format(self.cur_dist))
        #         self.count_points = 0
        # #######################################################################

        return self._current_state(), self.reward, self.terminal, info

    def getBestLocation(self):
        '''
            get the location with the best q-value from the last four
            locations stored in history
        '''
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._loc_history[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)

        # best_idx = best_qvalues.argmax()
        best_idx = best_qvalues.argmin()
        best_location = last_loc_history[best_idx]

        return best_location

    def _clear_history(self):
        '''
            clear history buffer with current state
        '''
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state
        '''
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        self._loc_history[-1] = self._location

        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues

    def _current_state(self):
        """
        crop image data around current location to update what network sees.
        update rectangle

        :return: new state
        """
        # initialize screen with zeros - all background
        screen = np.zeros((self.screen_dims)).astype('float32')

        # screen uses coordinate system relative to origin (0, 0, 0)
        screen_xmin, screen_ymin = 0, 0
        screen_xmax, screen_ymax = self.screen_dims

        print("image_data:", self._image)

        # extract boundary locations using coordinate system relative to "global" image
        # width, height, depth in terms of screen coord system
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))

        if xmax > self._image_dims[0]:
            xmax = self._image_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._image_dims[1]:
            ymax = self._image_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))

        # crop image data to update what network sees
        # image coordinate system becomes screen coordinates
        # scale can be thought of as a stride
        screen[screen_xmin:screen_xmax, screen_ymin:screen_ymax] = self._image.data[
                                                                            xmin:xmax:self.xscale,
                                                                            ymin:ymax:self.yscale]

        # update rectangle limits from input image coordinates
        # this is what the network sees
        self.rectangle = Rectangle(xmin, xmax,
                                   ymin, ymax)

        return screen

    def get_plane(self):
        return self._image.data[:, :]

    def _calc_reward(self, current_loc, next_loc):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc,
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc,
                                      self.spacing)
        return curr_dist - next_dist

    @property
    def _oscillate(self):
        """
            Return True if the agent is stuck and oscillating
        """
        counter = Counter(self._loc_history)
        freq = counter.most_common()

        if freq[0][0] == (0, 0):
            if freq[1][1] > 3:
                return True
            return False
        elif freq[0][1] > 3:
            return True
        return False

    def get_action_meanings(self):
        """ return a list of action-meaning strings """
        ACTION_MEANING = {
            1: "FORWARD",  # MOVE Y+
            2: "RIGHT",  # MOVE X+
            3: "LEFT",  # MOVE X-
            4: "BACKWARD",  # MOVE Y-
        }
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
            return screen dimensions
        """
        return (self.width, self.height)

    def lives(self):
        return None

    def reset_stat(self):
        """
            Reset all statistics counter
        """
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        # pass
        # get dimensions
        current_point = self._location
        target_point = self._target_loc

        # get image and convert it to pyglet
        plane = self.get_plane()  # the 2D image plane

        # plane = np.squeeze(self._current_state()[:,:,13])
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 1
        scale_y = 1
        img = cv2.resize(plane,
                         (int(scale_x*plane.shape[1]), int(scale_y*plane.shape[0])),
                         interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to RGB
        # skip if there is a viewer open
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)

        # draw current point
        self.viewer.draw_circle(radius=scale_x * 1,
                                pos_x=scale_x * current_point[0],
                                pos_y=scale_y * current_point[1],
                                color=(0.0, 0.0, 1.0, 1.0))

        # draw a box around the agent - what the network sees ROI
        self.viewer.draw_rect(scale_x*self.rectangle.xmin, scale_y*self.rectangle.ymin,
                              scale_x*self.rectangle.xmax, scale_y*self.rectangle.ymax)

        self.viewer.display_text('Agent ', color=(204, 204, 0, 255),
                                 x=self.rectangle.xmin - 15,
                                 y=self.rectangle.ymin)
        # display info
        text = 'Spacing ' + str(self.xscale)
        self.viewer.display_text(text, color=(204, 204, 0, 255),
                                 x=10, y=self._image_dims[1]-80)

        # ---------------------------------------------------------------------
        if (self.task != 'play'):
            # draw a transparent circle around target point with variable radius
            # based on the difference z-direction
            diff_z = scale_x * abs(current_point[2]-target_point[2])

            self.viewer.draw_circle(radius=diff_z,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 0.2))
            # draw target point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 1.0))
            # display info
            color = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
            text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
            self.viewer.display_text(text, color=color, x=10, y=20)

        # ---------------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()

        # time.sleep(self.viz)
        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)

            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)

            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.saveGif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
                    act = input("select action: d (delete) / q (quit): ").lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                resolution = str(3 * self.viewer.img_width) + 'x' + str(3 * self.viewer.img_height)
                save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                            '-pattern_type', 'sequence', '-start_number', '0', '-r',
                            '6', '-i', dirname + '/%04d.png', '-s', resolution,
                            '-vcodec', 'libx264', '-b:v', '2567k', self.filename + '.mp4']
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
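The _oscillate property above declares the agent stuck when some visited location shows up more than three times in the recent history, skipping the all-zero placeholder entries of a freshly cleared buffer. A self-contained sketch of that check; the helper name and sample history are invented:

# illustrative sketch (not from the original code)
from collections import Counter

def oscillating(loc_history, threshold=3):
    # most_common() sorts locations by visit count; (0, 0) entries are the
    # zero-filled placeholders of a freshly cleared history and are skipped
    freq = Counter(loc_history).most_common()
    if freq[0][0] == (0, 0):
        return len(freq) > 1 and freq[1][1] > threshold
    return freq[0][1] > threshold

history = [(0, 0)] * 9 + [(12, 7), (13, 7)] * 3 + [(12, 7)]  # (12, 7) visited 4 times
print(oscillating(history))  # True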
Example #9
class AgentBase(GymEnv):
    def __init__(self, agentIdent, is_train=False, auto_restart=True, **kwargs):
        # super(AgentBase, self).__init__(name='torcs')
        self.auto_restart = auto_restart
        self._isTrain = is_train
        self._agentIdent = agentIdent
        self._kwargs = kwargs
        self._init()

    def _init(self):
        logger.info("[{}]: agent init, isTrain={}".format(self._agentIdent, self._isTrain))
        self._episodeCount = -1
        from tensorpack.utils.utils import get_rng
        self._rng = get_rng(self)
        from tensorpack.utils.stats import StatCounter
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self._memorySaver = None
        save_dir = self._kwargs.pop('save_dir', None)
        if save_dir is not None:
            self._memorySaver = MemorySaver(save_dir,
                                            self._kwargs.pop('max_save_item', 3),
                                            self._kwargs.pop('min_save_score', None),
                                            )
        self.restart_episode()
        pass

    def restart_episode(self):
        self.rwd_counter.reset()
        self.__ob = self.reset()

    def finish_episode(self):
        score = self.rwd_counter.sum
        self.stats['score'].append(score)
        logger.info("episode finished, rewards = {:.3f}, episode = {}, steps = {}"
                    .format(score, self._episodeCount, self._episodeSteps))

    def current_state(self):
        return self.__ob

    def reset(self):
        self._episodeCount += 1
        ret = self._reset()
        self._episodeRewards = 0.
        self._episodeSteps = 0
        if self._memorySaver:
            self._memorySaver.createMemory(self._episodeCount)
        logger.info("restart, episode={}".format(self._episodeCount))
        return ret

    @abc.abstractmethod
    def _reset(self):
        pass

    def action(self, pred):
        ob, act, r, isOver, info = self._step(pred)
        self.rwd_counter.feed(r)
        if self._memorySaver:
            self._memorySaver.addCurrent(ob, act, r, isOver)
        self.__ob = ob
        self._episodeSteps += 1
        self._episodeRewards += r
        if isOver:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return act, r, isOver

    @abc.abstractmethod
    def _step(self, action):
        raise NotImplementedError

    def get_action_space(self):
        raise NotImplementedError
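AgentBase is a template-method class: reset() and action() handle the episode bookkeeping (reward counting, episode counters, optional memory saving) and delegate the environment-specific work to the abstract _reset() and _step(). A stripped-down, stdlib-only sketch of that structure; AgentSkeleton and RandomWalkAgent are invented for illustration and are not part of the original code:

# illustrative sketch (not from the original code)
import abc
import random

class AgentSkeleton(abc.ABC):
    def __init__(self):
        self._episode_count = -1
        self.reset()

    def reset(self):
        # shared bookkeeping; environment-specific work lives in _reset()
        self._episode_count += 1
        self._episode_reward = 0.0
        self._ob = self._reset()

    def action(self, pred):
        ob, act, r, is_over = self._step(pred)
        self._episode_reward += r
        self._ob = ob
        if is_over:
            print("episode {} finished, reward={:.2f}".format(
                self._episode_count, self._episode_reward))
            self.reset()
        return act, r, is_over

    @abc.abstractmethod
    def _reset(self):
        raise NotImplementedError

    @abc.abstractmethod
    def _step(self, action):
        raise NotImplementedError

class RandomWalkAgent(AgentSkeleton):
    def _reset(self):
        self._t = 0
        return 0.0

    def _step(self, action):
        # toy dynamics: random observation, episode ends after 10 steps
        self._t += 1
        ob = random.random()
        return ob, action, ob - 0.5, self._t >= 10

agent = RandomWalkAgent()
for step in range(25):
    agent.action(pred=0)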
Example #10
def eval_with_funcs(predictors,
                    nr_eval,
                    get_player_fn,
                    directory=None,
                    files_list=None):
    """
    Args:
        predictors ([PredictorBase])

    Runs episodes in parallel, returning statistics about the model performance.
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue, distErrorQueue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue
            self.q_dist = distErrorQueue

        def func(self, *args, **kwargs):
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn(directory=directory,
                                       task=False,
                                       files_list=files_list)
                while not self.stopped():
                    try:
                        score, filename, distance_error, q_values = play_one_episode(
                            player, self.func)
                        # print("Score, ", score)
                    except RuntimeError:
                        return
                    self.queue_put_stoppable(self.q, score)
                    self.queue_put_stoppable(self.q_dist, distance_error)

    q = queue.Queue()
    q_dist = queue.Queue()

    threads = [Worker(f, q, q_dist) for f in predictors]

    # start all workers
    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()
    dist_stat = StatCounter()

    # show progress bar w/ tqdm
    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        r = q.get()
        stat.feed(r)
        dist = q_dist.get()
        dist_stat.feed(dist)

    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    while q.qsize():
        r = q.get()
        stat.feed(r)

    while q_dist.qsize():
        dist = q_dist.get()
        dist_stat.feed(dist)

    if stat.count > 0:
        return (stat.average, stat.max, dist_stat.average, dist_stat.max)
    return (0, 0, 0, 0)
Example #11
    def __init__(
            self,
            directory=None,
            viz=False,
            task=False,
            files_list=None,
            observation_dims=(27, 27, 27),
            multiscale=False,  # FIXME automatic dimensions
            max_num_frames=20,
            saveGif=False,
            saveVideo=False):  # FIXME hardcoded max num frames!
        """
        :param directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param observation_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param location_history_length: consider loss of lives as end of
            episode (useful for training)
        :param max_num_frames: maximum number of frames per episode.
        """
        super(Brain_Env, self).__init__()

        print(
            "warning! max num frames hard coded to {}!".format(max_num_frames),
            flush=True)

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.observation_dims = observation_dims
        self.dims = len(self.observation_dims)
        # multi-scale agent
        self.multiscale = multiscale
        # FIXME force multiscale false for now
        self.multiscale = False

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = observation_dims
        elif self.dims == 3:
            self.width, self.height, self.depth = observation_dims
        else:
            raise ValueError

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # TODO: understand this viz setup
            # visualization setup
            #     if isinstance(viz, six.string_types):  # check if viz is a string
            #         assert os.path.isdir(viz), viz
            #         viz = 0
            #     if isinstance(viz, int):
            #         viz = float(viz)
            self.viz = viz
        #     if self.viz and isinstance(self.viz, float):
        #         self.viewer = None
        #         self.gif_buffer = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.observation_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = max_num_frames
        # TODO initialize _observation_bounds limits from input image coordinates
        self._observation_bounds = ObservationBounds(0, 0, 0, 0, 0, 0)
        # add your data loader here
        # TODO: look into returnLandmarks
        # if self.task == 'play':
        #     self.files = filesListBrainMRLandmark(directory, files_list,
        #                                           returnLandmarks=False)
        # else:
        #     self.files = filesListBrainMRLandmark(directory, files_list,
        #                                           returnLandmarks=True)
        self.files = FilesListCubeNPY(directory, files_list)

        # self.files = filesListFetalUSLandmark(directory,files_list)
        # self.files = filesListCardioMRLandmark(directory,files_list)
        # prepare file sampler
        self.filepath = None
        self.file_sampler = self.files.sample_circular()  # returns generator
        # reset buffer, terminal, counters, and init new_random_game
        # we put this here so that init_player in DQN.py doesn't try to update_history
        self._clear_history()  # init arrays
        self._restart_episode()
        # self.viz = True  # FIXME viz should default False
        assert (np.shape(self._state) == self.observation_dims)
        assert np.isclose(jaccard(self.original_state, self.original_state), 1)
Example #12
class EnvRunner(object):
    """
    A class responsible for stepping the environment with an epsilon-greedy
    policy and filling the results into the experience replay buffer.
    """
    def __init__(self, player, predictor, memory, history_len):
        """
        Args:
            player (gym.Env)
            predictor (callable): the model forward function which takes a
                state and returns the prediction.
            memory (ReplayMemory): the replay memory to store experience to.
            history_len (int):
        """
        self.player = player
        self.num_actions = player.action_space.n
        self.predictor = predictor
        self.memory = memory
        self.state_shape = memory.state_shape
        self.dtype = memory.dtype
        self.history_len = history_len

        self._current_episode = []
        self._current_ob = player.reset()
        self._current_game_score = StatCounter()  # store per-step reward
        self.total_scores = []  # store per-game total score

        self.rng = get_rng(self)

    def step(self, exploration):
        """
        Run the environment for one step.
        If the episode ends, store the entire episode to the replay memory.
        """
        old_s = self._current_ob
        if self.rng.rand() <= exploration:
            act = self.rng.choice(range(self.num_actions))
        else:
            history = self.recent_state()
            history.append(old_s)
            history = np.stack(history, axis=-1)  # state_shape + (Hist,)

            # assume batched network
            history = np.expand_dims(history, axis=0)
            q_values = self.predictor(history)[0][0]  # this is the bottleneck
            act = np.argmax(q_values)

        self._current_ob, reward, isOver, info = self.player.step(act)
        self._current_game_score.feed(reward)
        self._current_episode.append(Experience(old_s, act, reward, isOver))

        if isOver:
            flush_experience = True
            if 'ale.lives' in info:  # if running Atari, do something special
                if info['ale.lives'] != 0:
                    # only record score and flush experience
                    # when a whole game is over (not when an episode is over)
                    flush_experience = False
            self.player.reset()

            if flush_experience:
                self.total_scores.append(self._current_game_score.sum)
                self._current_game_score.reset()

                # Ensure that the whole episode of experience is continuous in the replay buffer
                with self.memory.writer_lock:
                    for exp in self._current_episode:
                        self.memory.append(exp)
                self._current_episode.clear()

    def recent_state(self):
        """
        Get the recent state (with stacked history) of the environment.

        Returns:
            a list of ``history_len - 1`` elements, each of shape ``self.state_shape``
        """
        expected_len = self.history_len - 1
        if len(self._current_episode) >= expected_len:
            return [k.state for k in self._current_episode[-expected_len:]]
        else:
            states = [np.zeros(self.state_shape, dtype=self.dtype)] * (expected_len - len(self._current_episode))
            states.extend([k.state for k in self._current_episode])
            return states
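recent_state() always returns history_len - 1 frames: when the running episode is shorter than that, it left-pads with zero frames of the same shape and dtype. A small numpy sketch of just that padding logic; the function name and shapes are hypothetical:

# illustrative sketch (not from the original code)
import numpy as np

def recent_state(episode_states, history_len, state_shape, dtype=np.float32):
    """Return the last history_len - 1 states, zero-padded on the left."""
    expected = history_len - 1
    if len(episode_states) >= expected:
        return list(episode_states[-expected:])
    pad = [np.zeros(state_shape, dtype=dtype)] * (expected - len(episode_states))
    return pad + list(episode_states)

state_shape = (84, 84)
episode = [np.full(state_shape, i, dtype=np.float32) for i in range(2)]  # only 2 frames so far
hist = recent_state(episode, history_len=4, state_shape=state_shape)
print(len(hist), [int(s.max()) for s in hist])  # 3 [0, 0, 1]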
Example #13
def play_one_episode(env, func):
    env.reset()
    env.prepare()
    r = 0
    stats = [StatCounter() for _ in range(7)]
    while r == 0:
        last_cards_value = env.get_last_outcards()
        last_cards_char = to_char(last_cards_value)
        last_out_cards = Card.val2onehot60(last_cards_value)
        last_category_idx = env.get_last_outcategory_idx()
        curr_cards_char = to_char(env.get_curr_handcards())
        is_active = last_cards_value.size == 0

        s = env.get_state_prob()
        intention, r, category_idx = env.step_auto()

        if category_idx == 14:
            continue
        minor_cards_targets = pick_minor_targets(category_idx,
                                                 to_char(intention))

        if not is_active:
            if category_idx == Category.QUADRIC.value and category_idx != last_category_idx:
                passive_decision_input = 1
                passive_bomb_input = intention[0] - 3
                passive_decision_prob, passive_bomb_prob, _, _, _, _, _ = func(
                    [
                        s.reshape(1, -1),
                        last_out_cards.reshape(1, -1),
                        np.zeros([s.shape[0]])
                    ])
                stats[0].feed(
                    int(passive_decision_input == np.argmax(
                        passive_decision_prob)))
                stats[1].feed(
                    int(passive_bomb_input == np.argmax(passive_bomb_prob)))

            else:
                if category_idx == Category.BIGBANG.value:
                    passive_decision_input = 2
                    passive_decision_prob, _, _, _, _, _, _ = func([
                        s.reshape(1, -1),
                        last_out_cards.reshape(1, -1),
                        np.zeros([s.shape[0]])
                    ])
                    stats[0].feed(
                        int(passive_decision_input == np.argmax(
                            passive_decision_prob)))
                else:
                    if category_idx != Category.EMPTY.value:
                        passive_decision_input = 3
                        # OFFSET_ONE
                        # 1st, Feb - remove relative card output since shift is hard for the network to learn
                        passive_response_input = intention[0] - 3
                        if passive_response_input < 0:
                            print("something bad happens")
                            passive_response_input = 0
                        passive_decision_prob, _, passive_response_prob, _, _, _, _ = func(
                            [
                                s.reshape(1, -1),
                                last_out_cards.reshape(1, -1),
                                np.zeros([s.shape[0]])
                            ])
                        stats[0].feed(
                            int(passive_decision_input == np.argmax(
                                passive_decision_prob)))
                        stats[2].feed(
                            int(passive_response_input == np.argmax(
                                passive_response_prob)))
                    else:
                        passive_decision_input = 0
                        passive_decision_prob, _, _, _, _, _, _ = func([
                            s.reshape(1, -1),
                            last_out_cards.reshape(1, -1),
                            np.zeros([s.shape[0]])
                        ])
                        stats[0].feed(
                            int(passive_decision_input == np.argmax(
                                passive_decision_prob)))

        else:
            seq_length = get_seq_length(category_idx, intention)

            # ACTIVE OFFSET ONE!
            active_decision_input = category_idx - 1
            active_response_input = intention[0] - 3
            _, _, _, active_decision_prob, active_response_prob, active_seq_prob, _ = func(
                [
                    s.reshape(1, -1),
                    last_out_cards.reshape(1, -1),
                    np.zeros([s.shape[0]])
                ])

            stats[3].feed(
                int(active_decision_input == np.argmax(active_decision_prob)))
            stats[4].feed(
                int(active_response_input == np.argmax(active_response_prob)))

            if seq_length is not None:
                # length offset one
                seq_length_input = seq_length - 1
                stats[5].feed(
                    int(seq_length_input == np.argmax(active_seq_prob)))

        if minor_cards_targets is not None:
            main_cards = pick_main_cards(category_idx, to_char(intention))
            handcards = curr_cards_char.copy()
            state = s.copy()
            for main_card in main_cards:
                handcards.remove(main_card)
            cards_onehot = Card.char2onehot60(main_cards)

            # we must make the order in each 4 batch correct...
            discard_onehot_from_s_60(state, cards_onehot)

            is_pair = False
            minor_type = 0
            if category_idx == Category.THREE_TWO.value or category_idx == Category.THREE_TWO_LINE.value:
                is_pair = True
                minor_type = 1
            for target in minor_cards_targets:
                target_val = Card.char2value_3_17(target) - 3
                _, _, _, _, _, _, minor_response_prob = func([
                    state.copy().reshape(1, -1),
                    last_out_cards.reshape(1, -1),
                    np.array([minor_type])
                ])
                stats[6].feed(
                    int(target_val == np.argmax(minor_response_prob)))
                cards = [target]
                handcards.remove(target)
                if is_pair:
                    if target not in handcards:
                        logger.warn('something wrong...')
                        logger.warn('minor: {}'.format(target))
                        logger.warn('main_cards: {}'.format(main_cards))
                        logger.warn('handcards: {}'.format(handcards))
                    else:
                        handcards.remove(target)
                        cards.append(target)

                # correct for one-hot state
                cards_onehot = Card.char2onehot60(cards)

                # print(s.shape)
                # print(cards_onehot.shape)
                discard_onehot_from_s_60(state, cards_onehot)
    return stats
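Each StatCounter in stats above tracks one prediction head and is fed int(label == argmax(prob)), so its average is that head's top-1 accuracy over the episodes played. A self-contained sketch of the same accuracy bookkeeping, with a minimal stand-in for StatCounter and random predictions:

# illustrative sketch (not from the original code)
import numpy as np

class RunningAccuracy:
    """Minimal stand-in for StatCounter: feed 0/1 hits, read the average."""
    def __init__(self):
        self.hits, self.count = 0, 0

    def feed(self, hit):
        self.hits += hit
        self.count += 1

    @property
    def average(self):
        return self.hits / self.count if self.count else 0.0

heads = [RunningAccuracy() for _ in range(2)]  # e.g. decision head, response head
rng = np.random.default_rng(0)
for _ in range(100):
    label = rng.integers(0, 4)
    prob = rng.random(4)
    heads[0].feed(int(label == np.argmax(prob)))                      # top-1 hit for head 0
    heads[1].feed(int(rng.integers(0, 4) == np.argmax(rng.random(4))))
print([round(h.average, 2) for h in heads])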
Example #14
class AtariPlayer(RLEnvironment):
    """
    A wrapper for atari emulator.
    Will automatically restart when a real episode ends (isOver might be just
    a loss of lives rather than a game over).
    """

    def __init__(self, rom_file, viz=0, height_range=(None, None),
                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
        # 0.299, 0.587, 0.114; same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives
        if isOver:
            self.finish_episode()
        if self.ale.game_over():
            self.restart_episode()
        return (r, isOver)
Example #15
0
    def __init__(self,
                 rom_file,
                 viz=0,
                 height_range=(None, None),
                 frame_skip=4,
                 image_shape=(84, 84),
                 nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()
Example #16
0
        if t > 10 * 60:  # eval takes too long
            self.eval_episode = int(self.eval_episode * 0.94)
        self.trainer.monitors.put_scalar('farmer_win_rate', farmer_win_rate)
        self.trainer.monitors.put_scalar('lord_win_rate', 1 - farmer_win_rate)


if __name__ == '__main__':
    # encoding = np.load('encoding.npy')
    # print(encoding.shape)
    # env = Env()
    # stat = StatCounter()
    # init_cards = np.arange(21)
    # # init_cards = np.append(init_cards[::4], init_cards[1::4])
    # for _ in range(10):
    #     fw = play_one_episode(env, lambda b: np.random.rand(1, 1, 100) if b[1][0] else np.random.rand(1, 1, 21), [100, 21])
    #     stat.feed(int(fw))
    # print('lord win rate: {}'.format(1. - stat.average))
    env = Env()
    stat = StatCounter()
    for i in range(100):
        env.reset()
        print('begin')
        env.prepare()
        r = 0
        while r == 0:
            role = env.get_role_ID()
            intention, r, _ = env.step_auto()
            # print('lord gives' if role == 2 else 'farmer gives', to_char(intention))
        stat.feed(int(r < 0))
    print(stat.average)
Example #17
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 # model,
                 agent_name,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency,
                 pipe_exp2sim, pipe_sim2exp):
        logger.info('starting expreplay {}'.format(agent_name))
        self.init_memory_size = int(init_memory_size)

        self.context = zmq.Context()
        # no reply for now
        # self.exp2sim_socket = self.context.socket(zmq.ROUTER)
        # self.exp2sim_socket.set_hwm(20)
        # self.exp2sim_socket.bind(pipe_exp2sim)

        self.sim2exp_socket = self.context.socket(zmq.PULL)
        self.sim2exp_socket.set_hwm(2)
        self.sim2exp_socket.bind(pipe_sim2exp)

        self.queue = queue.Queue(maxsize=1000)

        # self.model = model

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.agent_name = agent_name

        self.exploration = init_exploration
        self.num_actions = num_actions
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        # self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_recv_thread(self):
        def f():
            msg = self.sim2exp_socket.recv(copy=False).bytes
            msg = loads(msg)
            print('{}: received msg'.format(self.agent_name))
            try:
                self.queue.put_nowait(msg)
            except Exception:
                logger.info('put queue failed!')
            # send response or not?

        recv_thread = LoopThread(f, pausable=False)
        # recv_thread.daemon = True
        recv_thread.name = "recv thread"
        return recv_thread

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            i = 0
            # synchronous training
            while i < self.update_frequency:
                if self._populate_exp():
                    i += 1
                    time.sleep(0.1)

            # for _ in range(self.update_frequency):
            #     self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("{} populating replay memory with epsilon={} ...".format(self.agent_name, self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                if self._populate_exp():
                    pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        try:
            # do not wait for an update; otherwise some agents may keep training on an old replay buffer before new transitions arrive
            state, action, reward, isOver, comb_mask, fine_mask = self.queue.get_nowait()
            self._current_game_score.feed(reward)
            # print(reward)

            if isOver:
                self._player_scores.feed(self._current_game_score.sum)
                self._current_game_score.reset()
            self.mem.append(Experience(np.stack(state), action, reward, isOver, comb_mask, np.stack(fine_mask)))
            return True
        except queue.Empty:
            return False

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        comb_mask = np.asarray([e[4] for e in batch_exp], dtype='bool')
        fine_mask = np.asarray([e[5] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver, comb_mask, fine_mask]

    def _setup_graph(self):
        self._recv_th = self.get_recv_thread()
        self._recv_th.start()
        # self.curr_predictor = self.trainer.get_predictor([self.agent_name + '/state:0', self.agent_name + '_comb_mask:0', self.agent_name + '/fine_mask:0'], [self.agent_name + '/Qvalue:0'])

    def _before_train(self):
        logger.info('{}-receive thread started'.format(self.agent_name))

        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

        self._init_memory()

    def _trigger(self):
        from simulator.tools import mean_score_logger
        v = self._player_scores
        try:
            mean, max = v.average, v.max
            logger.info('{} mean_score: {}'.format(self.agent_name, mean))
            mean_score_logger('{} mean_score: {}\n'.format(self.agent_name, mean))
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
        except Exception:
            logger.exception(self.agent_name + " Cannot log training scores.")
        v.reset()
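Since this ExpReplay is both a DataFlow and a Callback, it would typically be handed to the trainer twice: once as the input pipeline and once as a callback so that _setup_graph, _before_train and _trigger run. A minimal sketch under that assumption; apart from ExpReplay itself, every name and value below (model class, shapes, pipe addresses) is illustrative:
from tensorpack import TrainConfig, QueueInput, ModelSaver, SimpleTrainer, launch_train_with_config

expreplay = ExpReplay(
    agent_name='agent1',
    state_shape=(2, 60),                    # illustrative
    num_actions=(13, 15),                   # illustrative; a 2-tuple is logged above
    batch_size=32,
    memory_size=100000, init_memory_size=10000,
    init_exploration=1.0,
    update_frequency=4,
    pipe_exp2sim='ipc://@exp2sim', pipe_sim2exp='ipc://@sim2exp')

config = TrainConfig(
    model=MyDQNModel(),                     # hypothetical ModelDesc
    data=QueueInput(expreplay),             # consumed as a DataFlow
    callbacks=[ModelSaver(), expreplay],    # registered as a Callback as well
    steps_per_epoch=10000,                  # illustrative
)
launch_train_with_config(config, SimpleTrainer())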
Example #18
0
class AtariPlayer(RLEnvironment):
    """
    A wrapper for atari emulator.
    Will automatically restart when a real episode ends (isOver might only
    indicate a loss of lives rather than game over).
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 height_range=(None, None),
                 frame_skip=4,
                 image_shape=(84, 84),
                 nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1], :].astype(
            'float32')
        # 0.299, 0.587, 0.114; same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives
        if isOver:
            self.finish_episode()
        if self.ale.game_over():
            self.restart_episode()
        return (r, isOver)
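A hedged usage sketch for this RLEnvironment-style AtariPlayer; the ROM filename is an assumption and only methods shown above are called:
# Minimal sketch, assuming a valid ROM file; 'breakout.bin' is illustrative.
player = AtariPlayer('breakout.bin', viz=0, frame_skip=4,
                     image_shape=(84, 84), live_lost_as_eoe=True)
n_actions = len(player.actions)        # ALE minimal action set
frame = player.current_state()         # (84, 84) uint8 grayscale frame
reward, is_over = player.action(0)     # repeat action 0 for frame_skip frames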
Example #19
0
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.

    Info:
        score: the accumulated reward in the current game
        gameOver: True when the current game is Over
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with a random number of null ops.
            live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.current_episode_score = StatCounter()

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width))
        self._restart_episode()

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret.astype('float32')
        # 0.299, 0.587, 0.114; same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def _reset(self):
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def _step(self, act):
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        trueIsOver = isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': trueIsOver
        }
        return self._current_state(), r, isOver, info
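With older gym versions, reset()/step() dispatch to the underscore methods defined here; the sketch below calls the underscore methods directly to stay within what the snippet defines (the ROM name is illustrative):
env = AtariPlayer('breakout.bin', viz=0, frame_skip=4, max_num_frames=60000)
obs = env._reset()                                   # (h, w) uint8 frame
obs, r, is_over, info = env._step(env.action_space.sample())
print(info['score'], info['gameOver'])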
Example #20
0
def eval_child(model_cls, args, log_dir, model_dir, collect_hallu_stats=True):
    """
    Args:
        model_cls (PetridishModel) :
        args :
        log_dir (str): where to log
        model_dir (str) : where to load from
        collect_hallu_stats (bool) : whether to collect hallu stats if there are any.
    Returns:
        eval_vals (list) : a list of evaluation-related values.
        The first is the validation error on the specified validation set;
        it is followed by hallucination stats.
    """
    ckpt = tf.train.latest_checkpoint(model_dir)
    if not ckpt:
        logger.info("No model exists. Do not sort")
        return []
    args.compute_hallu_stats = True
    (model, args, ds_val, insrc_val, output_names,
     output_funcs) = get_training_params(model_cls, args, is_training=False)
    n_outputs = len(output_names)
    logger.info("{} num vals present. Will use the final perf {} as eval score".format(\
        n_outputs, output_names[-1]))
    stats_handlers = [StatCounter() for _ in range(n_outputs)]

    # additional handlers for hallucinations
    if collect_hallu_stats:
        hallu_stats_names = get_net_info_hallu_stats_output_names(
            model.net_info)
        stats_handlers.extend([StatCounter() for _ in hallu_stats_names])
        output_names.extend(hallu_stats_names)
    # Note: at this point stats_handlers[n_outputs-1:] contains all
    # the values needed for evaluation.

    # batch size counter
    sample_counter = StatCounter()
    # ignore loading certain variables during inference
    ignore_names = getattr(model, 'load_ignore_var_names', [])
    pred_config = PredictConfig(model=model,
                                input_names=model._input_names,
                                output_names=output_names,
                                session_init=SaverRestore(ckpt,
                                                          ignore=ignore_names))
    predictor = OfflinePredictor(pred_config)

    # two types of input, dataflow or input_source
    if ds_val:
        gen = ds_val.get_data()
        ds_val.reset_state()
        input_sess = None
    else:
        if not insrc_val.setup_done():
            insrc_val.setup(model.get_inputs_desc())
        sess_config = get_default_sess_config()
        sess_config.device_count['GPU'] = 0
        input_tensors = insrc_val.get_input_tensors()
        sess_creator = tf.train.ChiefSessionCreator(config=sess_config)
        input_sess = tf.train.MonitoredSession(sess_creator)

        def _gen_func():
            insrc_val.reset_state()
            for _ in range(insrc_val.size()):
                yield input_sess.run(input_tensors)

        gen = _gen_func()

    for dp_idx, dp in enumerate(gen):
        output = predictor(*dp)
        batch_size = output[n_outputs - 1].shape[0]
        sample_counter.feed(batch_size)
        for o, handler in zip(output, stats_handlers):
            handler.feed(np.sum(o))
        if (args.debug_steps_per_epoch
                and dp_idx + 1 >= args.debug_steps_per_epoch):
            # stop early during debugging
            break
    eval_vals = []
    N = float(sample_counter.sum)
    for hi, handler in enumerate(stats_handlers):
        stat = handler.sum / float(N)
        logger.info('Stat {} has an avg of {}'.format(hi, stat))
        if hi < n_outputs:
            o_func = output_funcs[hi]
            if o_func is not None:
                stat = o_func(stat)
        if hi >= n_outputs - 1:
            # Note that again n_outputs - 1 is the eval val
            # followed by hallu stats.
            eval_vals.append(stat)
    if input_sess:
        input_sess.close()
    logger.info("evaluation_value={}".format(eval_vals))
    return eval_vals
Example #21
0
    def __init__(self, rom_file, viz=0, height_range=(None, None),
                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number of null ops
        :param live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()
Example #22
0
    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 screen_dims=(27, 27), history_length=20, multiscale=True,
                 max_num_frames=0, saveGif=False, saveVideo=False):
        """
        :param directory: environment or game name

        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames

        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d, w) - defaults (27, 27)

        :param history_length: length of the location history buffer used to
            detect oscillations

        :param max_num_frames: maximum number of frames per episode.
        """

        # ######################################################################
        # ## generate evaluation results from 19 different points
        # ## save results in csv file
        # self.csvfile = 'DuelDoubleDQN_multiscale_brain_mri_point_pc_ROI_45_45_45_midl2018.csv'
        # if not train:
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filenames", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5, 0.25, 0.75]
        # y = [0.5, 0.25, 0.75]
        # z = [0.5, 0.25, 0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        # ######################################################################

        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0

        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames

        # stores information: terminal, score, distError
        self.info = None

        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo

        # training flag
        self.task = task

        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)

        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0

            if isinstance(viz, int):  # check if viz is an int number
                viz = float(viz)
            self.viz = viz

            if self.viz and isinstance(self.viz, float):  # check if viz is a float
                self.viewer = None
                self.gif_buffer = []

        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()

        # get action space and minimal action set
        self.action_space = spaces.Discrete(4)  # change number actions here
        self.actions = self.action_space.n

        # an observation space also needs to be configured
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0)

        # add your data loader here
        # in 'play' mode returnLandmarks is False; otherwise it is True
        if self.task == 'play':
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                  returnLandmarks=False)
            print("self.files:", self.files)
        else:
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                  returnLandmarks=True)


        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()

        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()
Example #23
0
class SoccerPlayer(RLEnvironment):
    """
    A wrapper for pygame_soccer emulator.
    Will automatically restart when a real episode ends (isOver might only
    indicate a loss of lives rather than game over).
    """
    SOCCER_WIDTH = 288
    SOCCER_HEIGHT = 192

    def __init__(self,
                 viz=0,
                 field=None,
                 partial=False,
                 radius=2,
                 frame_skip=4,
                 image_shape=(84, 84),
                 mode=None,
                 team_size=1,
                 ai_frame_skip=1,
                 raw_env=soccer_environment.SoccerEnvironment):
        super(SoccerPlayer, self).__init__()

        if team_size > 1 and mode is not None:
            self.mode = mode.split(',')
        else:
            self.mode = [mode]
        self.field = field
        self.partial = partial
        self.viz = viz
        if self.viz:
            self.renderer_options = soccer_renderer.RendererOptions(
                show_display=True, max_fps=10, enable_key_events=True)
        else:
            self.renderer_options = None

        if self.field == 'large':
            map_path = file_util.resolve_path(__file__,
                                              '../data/map/soccer_large.tmx')
        else:
            map_path = None

        self.team_size = team_size
        self.env_options = soccer_environment.SoccerEnvironmentOptions(
            team_size=self.team_size,
            map_path=map_path,
            ai_frame_skip=ai_frame_skip)
        self.env = raw_env(env_options=self.env_options,
                           renderer_options=self.renderer_options)

        self.computer_team_name = self.env.team_names[1]
        self.player_team_name = self.env.team_names[0]

        # Partial
        if self.partial:
            self.radius = radius
            self.player_agent_index = self.env.get_agent_index(
                self.player_team_name, 0)

        self.actions = self.env.actions
        self.frame_skip = frame_skip
        self.image_shape = image_shape

        self.last_info = {}
        self.agent_actions = ['STAND'] * (self.team_size * 2)
        self.changing_counter = 0
        self.timestep = 0
        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        self.env.render()
        if self.partial:
            screenshot = self.env.renderer.get_po_screenshot(
                self.player_agent_index, self.radius)
        else:
            screenshot = self.env.renderer.get_screenshot()
        return screenshot

    def _get_computer_actions(self):
        # Collaborator
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.player_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 0 + i] = action
        # Opponent
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.computer_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 1 + i] = action
        return np.asarray([
            self.env.actions.index(act if act else 'STAND')
            for act in self.agent_actions
        ])

    def _set_opponent_mode(self, mode):
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.computer_team_name, i)
            m = mode[i]
            self.env.state.set_agent_mode(index, m)

    def _set_collaborator_mode(self, mode):
        for i in range(1, self.team_size):
            index = self.env.get_agent_index(self.player_team_name, i)
            m = mode[i - 1]
            self.env.state.set_agent_mode(index, m)

    def _set_computer_mode(self, mode):
        if mode[0] is None or len(mode) < self.team_size * 2 - 1:
            return
        if mode[0] in ['OFFENSIVE', 'DEFENSIVE']:
            # Collaborator
            if self.team_size >= 2:
                self._set_collaborator_mode(mode[:(self.team_size - 1)])
            # Opponent
            self._set_opponent_mode(mode[(self.team_size - 1):])

    def current_state(self):
        ret = self._grab_raw_image()
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        self.env.reset()
        self._set_computer_mode(self.mode)
        self.last_raw_screen = self._grab_raw_image()
        self.changing_counter = 0
        self.timestep = 0

    def action(self, act):
        ball_pos_agent_old = self.env.state.get_ball_possession()
        r = 0
        ball_poss_old = self.env.state.get_ball_possession()['team_name']
        for k in range(self.frame_skip):
            self.timestep += 1
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()

            if self.mode[0] == 'WEAKCOOP':
                actions = {}
                for team_name in self.env.team_names:
                    for team_agent_index in range(self.env.options.team_size):
                        agent_index = self.env.get_agent_index(
                            team_name, team_agent_index)
                        agent_action = self.env._get_ai_action(
                            team_name, team_agent_index)
                        print(team_name +
                              self.env.state.get_agent_mode(agent_index))
                        actions[agent_index] = agent_action
                player_index = self.env.get_agent_index(
                    self.player_team_name, 0)
                coop_index = self.env.get_agent_index(self.player_team_name, 1)

                actions[player_index] = self.env.actions[act]
                if random.random() < 0.5:
                    actions[coop_index] = random.choice(self.env.actions)
                ret = self.env.take_action(actions)

            elif self.mode[0] == 'ALL_RANDOM':
                if self.team_size == 1:
                    player_index = self.env.get_agent_index(
                        self.player_team_name, 0)
                    opponent_index = self.env.get_agent_index(
                        self.computer_team_name, 0)
                    actions = {
                        player_index: self.env.actions[act],
                        opponent_index: random.choice(self.env.actions)
                    }
                else:
                    actions = {}
                    for team_name in [
                            self.player_team_name, self.computer_team_name
                    ]:
                        for team_index in range(self.team_size):
                            agent_index = self.env.get_agent_index(
                                team_name, team_index)
                            actions[agent_index] = random.choice(
                                self.env.actions)
                    player_index = self.env.get_agent_index(
                        self.player_team_name, 0)
                    actions[player_index] = self.env.actions[act]
                ret = self.env.take_action(actions)
            # else:
            # print(self.env.actions[act])
            # ret = self.env.take_action(self.env.actions[act])

            else:
                if self.mode[0] == 'OPPONENT_DYNAMIC':
                    choices = ['OFFENSIVE', 'DEFENSIVE']
                    if self.timestep % random.randint(4, 10) == 0:
                        new_modes = [
                            random.choice(choices)
                            for i in range(self.team_size)
                        ]
                        self._set_opponent_mode(new_modes)

                if self.mode[0] == 'COOP_DYNAMIC':
                    choices = ['OFFENSIVE', 'DEFENSIVE']
                    if self.timestep % random.randint(4, 10) == 0:
                        new_modes = [
                            random.choice(choices)
                            for i in range(self.team_size - 1)
                        ]
                        self._set_collaborator_mode(new_modes)

                actions = {}
                for team_name in self.env.team_names:
                    for team_agent_index in range(self.env.options.team_size):
                        agent_index = self.env.get_agent_index(
                            team_name, team_agent_index)
                        agent_action = self.env._get_ai_action(
                            team_name, team_agent_index)
                        # print(team_name + self.env.state.get_agent_mode(agent_index))
                        actions[agent_index] = agent_action
                player_index = self.env.get_agent_index(
                    self.player_team_name, 0)
                actions[player_index] = self.env.actions[act]
                ret = self.env.take_action(actions)

            if k == 0:
                self.last_info['agent_actions'] = self._get_computer_actions()
            r += ret.reward

            if self.env.state.is_terminal():
                break

        self.current_episode_score.feed(r)
        isOver = self.env.state.is_terminal()
        ball_pos_agent_new = self.env.state.get_ball_possession()

        if ball_pos_agent_old['team_name'] == ball_pos_agent_new[
                'team_name'] and ball_pos_agent_new['team_name'] == 'PLAYER':
            if ball_pos_agent_old['team_agent_index'] != ball_pos_agent_new[
                    'team_agent_index']:
                self.changing_counter += 1

        if isOver:
            self.finish_episode()
            self.restart_episode()
        return (r, isOver)

    def get_internal_state(self):
        return self.last_info

    def get_changing_counter(self):
        return self.changing_counter
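A brief, hedged sketch of driving this SoccerPlayer; the mode value and the fixed action index are illustrative, and only the methods defined above are used:
# Minimal sketch, assuming pygame_soccer is installed.
player = SoccerPlayer(viz=0, mode='OFFENSIVE', team_size=1, frame_skip=4)
is_over = False
while not is_over:
    frame = player.current_state()     # (84, 84) uint8 frame
    r, is_over = player.action(0)      # always pick action index 0
print(player.get_internal_state(), player.get_changing_counter())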
Example #24
0
def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
    """
    Args:
        predictors ([PredictorBase])
    """
    class Worker(StoppableThread, ShareSessionThread):
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue

        def func(self, *args, **kwargs):
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn()
                while not self.stopped():
                    try:
                        stats = play_one_episode(player, self.func)
                    except RuntimeError:
                        return
                    scores = [
                        stat.average if stat.count > 0 else -1
                        for stat in stats
                    ]
                    self.queue_put_stoppable(self.q, scores)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stats = [StatCounter() for _ in range(7)]

    def fetch():
        scores = q.get()
        for i, score in enumerate(scores):
            if scores[i] >= 0:
                stats[i].feed(scores[i])
        accs = [stat.average if stat.count > 0 else 0 for stat in stats]
        if verbose:
            logger.info("passive decision accuracy: {}\n"
                        "passive bomb accuracy: {}\n"
                        "passive response accuracy: {}\n"
                        "active decision accuracy: {}\n"
                        "active response accuracy: {}\n"
                        "active sequence accuracy: {}\n"
                        "minor response accuracy: {}\n".format(
                            accs[0], accs[1], accs[2], accs[3], accs[4],
                            accs[5], accs[6]))

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    while q.qsize():
        fetch()
    accs = [stat.average if stat.count > 0 else 0 for stat in stats]
    return accs
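A hedged call site for this evaluator; `predictors` and the player factory are assumptions:
# Minimal sketch: evaluate a list of prediction functions over 100 episodes.
accs = eval_with_funcs(predictors, nr_eval=100,
                       get_player_fn=lambda: Env(), verbose=True)
# accs[0..6] follow the order of the verbose log above: passive decision,
# passive bomb, passive response, active decision, active response,
# active sequence, minor response.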
Example #25
0
    def __init__(self,
                 directory=None,
                 files_list=None,
                 viz=False,
                 train=False,
                 screen_dims=(27, 27, 27),
                 spacing=(1, 1, 1),
                 nullop_start=30,
                 history_length=30,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False):
        """
        :param directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param history_length: length of the history buffers used to detect
            oscillations
        """
        super(MedicalPlayer, self).__init__()

        self.reset_stat()
        self._supervised = False
        self._init_action_angle_step = 8
        self._init_action_dist_step = 4

        #######################################################################
        ## save results in csv file
        self.csvfile = 'dummy.csv'
        if not train:
            with open(self.csvfile, 'w') as outcsv:
                fields = ["filename", "dist_error", "angle_error"]
                writer = csv.writer(outcsv)
                writer.writerow(map(lambda x: x, fields))
        #######################################################################

        # read files from directory - add your data loader here
        self.files = filesListCardioMRPlane(directory, files_list)

        # prepare file sampler
        self.sampled_files = self.files.sample_circular()
        self.filepath = None
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.train = train
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self._plane_size = screen_dims
        self.dims = len(self.screen_dims)
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims
        # plane sampling spacings
        self.init_spacing = np.array(spacing)
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(8)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # circular buffer to store plane parameters history [4,history_length]
        self._plane_history = deque(maxlen=self._history_length)
        self._bestq_history = deque(maxlen=self._history_length)
        self._dist_history = deque(maxlen=self._history_length)
        self._dist_history_params = deque(maxlen=self._history_length)
        self._dist_supervised_history = deque(maxlen=self._history_length)
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._qvalues = [
            0,
        ] * self.actions

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        self._restart_episode()
Example #26
0
class Unity3DPlayer(RLEnvironment):
    '''
    ACTION_TABLE = [(0.5, 0.0), # Forward
                    (-0.5, 0.0), # Backward
                    (0.5, 1.0), # Forward-Right
                    (-0.5, 1.0), # Backward-Right
                    (0.5, -1.0), # Forward-Left
                    (-0.5, -1.0) ] # Backward-Left 
    '''
    ACTION_TABLE = [(2.0 * ACTION_SCALE, 0.0 * ACTION_SCALE),
                    (2.0 * ACTION_SCALE, 0.5 * ACTION_SCALE),
                    (2.0 * ACTION_SCALE, -0.5 * ACTION_SCALE)]

    def __init__(self,
                 connection,
                 skip=1,
                 dumpdir=None,
                 viz=False,
                 auto_restart=True):
        if connection is not None:
            with _ENV_LOCK:
                self.gymenv = Unity3DEnvironment(server_address=connection)
            self.use_dir = dumpdir
            self.skip = skip
            self.reset_stat()
            self.rwd_counter = StatCounter()
            self.restart_episode()
            self.auto_restart = auto_restart
            self.viz = viz
        self.connection = connection

    def restart_episode(self):
        self.rwd_counter.reset()
        self.rwd_counter.feed(0)
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        cv2.imwrite('state_%04d.png' % self.connection[1], self._ob)
        return self._ob

    def action(self, act):
        env_act = self.ACTION_TABLE[act]
        for i in range(self.skip):
            self._ob, r, isOver, info = self.gymenv.step(env_act)
            if r > 0:
                r = 0.0
            if r < 0.0:
                isOver = True
            if isOver:
                break
        self.rwd_counter.feed(r)
        if isOver:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return r, isOver

    def get_action_space(self):
        return DiscreteActionSpace(len(self.ACTION_TABLE))

    def close(self):
        self.gymenv.close()
Example #27
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    # quickly fill the memory for debug
    def _fake_init_memory(self):
        from copy import deepcopy
        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < 5:
                self._populate_exp()
                pbar.update()
            while len(self.mem) < self.init_memory_size:
                self.mem.append(deepcopy(self.mem._hist[0]))
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
            history = self.mem.recent_state()
            history.append(old_s)
            history = np.stack(history, axis=2)

            # assume batched network
            q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
            act = np.argmax(q_values)
        self._current_ob, reward, isOver, info = self.player.step(act)
        self._current_game_score.feed(reward)
        if isOver:
            if info['ale.lives'] == 0:  # only record score when a whole game is over (not when an episode is over)
                self._player_scores.feed(self._current_game_score.sum)
                self._current_game_score.reset()
            self.player.reset()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def _debug_sample(self, sample):
        import cv2

        def view_state(comb_state):
            state = comb_state[:, :, :-1]
            next_state = comb_state[:, :, 1:]
            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
            r = np.concatenate([r, r2], axis=0)
            cv2.imshow("state", r)
            cv2.waitKey()
        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
        if sample[1] or sample[3]:
            view_state(sample[0])

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    # DataFlow method:
    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - self.history_len - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    # Callback methods:
    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            mean, max = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
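# Illustrative sketch only (not part of the example above): a standalone version
# of the epsilon-greedy rule used by _populate_exp. `predict_q` stands in for the
# batched Q-value predictor and is an assumed name; only numpy is required.
import numpy as np

def epsilon_greedy(history, exploration, num_actions, predict_q, rng=np.random):
    if rng.rand() <= exploration:
        return rng.choice(range(num_actions))            # explore: random action
    q_values = predict_q(history[None, :, :, :])[0][0]   # exploit: batched forward pass
    return int(np.argmax(q_values))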
Example #28
0
class ThorPlayer(RLEnvironment):
  """
  a wrapper for Thor environment.
  """
  def __init__(self, exe_path, json_path, actions=ACTIONS, height=HEIGHT, width=WIDTH, gray=False, record=False):
    super(ThorPlayer, self).__init__()
    assert os.path.isfile(exe_path), 'wrong path of executable binary for Thor'
    assert os.path.isfile(json_path), 'wrong path of target json file'
    self.height = height
    self.width = width
    self.gray = gray
    self.record = record
    # set Thor controller
    self.env = robosims.controller.ChallengeController(
                unity_path=exe_path,
                height=self.height,
                width=self.width,
                record_actions=self.record)
    
    # read targets from the json file
    with open(json_path) as f:
      self.targets = json.loads(f.read())
    self.num_targets = len(self.targets)
    
    self.rng = get_rng(self)
    self.actions = actions
    self.current_episode_score = StatCounter()
    self.env.start()
    self.restart_episode()

  def current_state(self):
    # image of current state, numpy array of (h, w, 3) in RGB order
    img = self.env.last_event.frame
    success = self.env.last_event.metadata['lastActionSuccess']
    found = self.env.target_found()
    if self.gray:
      img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    
    return img, success, found

  def get_action_space(self):
    return DiscreteActionSpace(len(self.actions))

  def next_target(self):
    idx = self.rng.choice(range(self.num_targets))
    return self.targets[idx]

  def restart_episode(self):
    """
    reset the episode counter and
    initialize the env by a random selected target
    """
    self.current_episode_score.reset()
    target = self.next_target()
    self.env.initialize_target(target)

  def action(self, act):
    """
    Perform an action.
    Will automatically start a new episode if isOver
    """
    r = 0.0
    isOver = False
    event = self.env.step(action=dict(action=self.actions[act]))
    if not event.metadata['lastActionSuccess']:
      r -= 0.01
    if self.env.target_found():
      r += 100.0
      isOver = True
    self.current_episode_score.feed(r)
    if isOver:
      self.restart_episode()
    return (r, isOver)
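# Illustrative usage sketch only (not part of the example above). The binary and
# target-json paths are placeholders; ACTIONS/HEIGHT/WIDTH come from the module.
import random

def run_thor_episode(exe_path='thor/thor-binary', json_path='thor/targets.json',
                     max_steps=100):
    player = ThorPlayer(exe_path, json_path)
    for _ in range(max_steps):
        act = random.randrange(len(player.actions))  # random policy, for illustration
        reward, is_over = player.action(act)
        if is_over:  # target found; ThorPlayer already restarted the episode
            break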
Example #29
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency, history_len,
                 arg_type=None):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)


        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        ###############################################################################
        # HITL UPDATE
        self.hmem_full = False
        if self.update_frequency < 4:
            self.hmem = HumanDemReplayMemory(memory_size, state_shape, history_len, arg_type=arg_type)
            self.hmem.load_experience()
            self.hmem_full = True
            logger.info("HITL buffer full")

        ###############################################################################
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._player_distError = StatCounter()

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            ###############################################################################
            # HITL UPDATE
            # since self.update_frequency == 0 during pretraining, the loop below adds no transitions.
            ###############################################################################
            #logger.info("update_frequency: {}".format(self.update_frequency))

            for _ in range(int(self.update_frequency)):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    # quickly fill the memory for debug
    def _fake_init_memory(self):
        from copy import deepcopy
        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < 5:
                self._populate_exp()
                pbar.update()
            while len(self.mem) < self.init_memory_size:
                self.mem.append(deepcopy(self.mem._hist[0]))
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""


        old_s = self._current_ob

        # initialize q_values to zeros
        q_values = [0, ] * self.num_actions

        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
            history = self.mem.recent_state()
            history.append(old_s)
            if np.ndim(history) == 4:  # 3d states
                history = np.stack(history, axis=3)
                # assume batched network - this is the bottleneck
                q_values = self.predictor(history[None, :, :, :, :])[0][0]
            else:
                history = np.stack(history, axis=2)
                # assume batched network - this is the bottleneck
                q_values = self.predictor(history[None, :, :, :])[0][0]

            act = np.argmax(q_values)

        self._current_ob, reward, isOver, info = self.player.step(act, q_values)

        if isOver:
            # if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
            #     self._player_scores.feed(info['score'])
            self._player_scores.feed(info['score'])
            self._player_distError.feed(info['distError'])
            self.player.reset()
        # As generated by AI human = False
        self.mem.append(Experience(old_s, act, reward, isOver, False))

    def _debug_sample(self, sample):
        import cv2

        def view_state(comb_state):
            state = comb_state[:, :, :-1]
            next_state = comb_state[:, :, 1:]
            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
            r = np.concatenate([r, r2], axis=0)
            cv2.imshow("state", r)
            cv2.waitKey()

        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
        if sample[1] or sample[3]:
            view_state(sample[0])

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        ###############################################################################
        # HITL UPDATE
        # if self.update_frequency == 0:
        #     logger.info("logging update freq ...".format(self.update_frequency))
        while True:
            # During pretraining, sample only from the HITL buffer
            if self.update_frequency == 0:
                idx = self.rng.randint(
                    self._populate_job_queue.maxsize * 4,
                    len(self.hmem) - self.history_len - 1,
                    size=self.batch_size)
                batch_exp = [self.hmem.sample(i) for i in idx]

                yield self._process_batch(batch_exp)
                logger.info("Human batch ...")
                self._populate_job_queue.put(1)
            # After pretraining, sample from both the HITL and the agent buffer
            elif self.hmem_full:
                ex_idx = self.rng.randint(
                    self._populate_job_queue.maxsize * self.update_frequency,
                    len(self.mem) - self.history_len - 1,
                    size=38)    # 38 agent transitions (~80% of the batch)
                hu_idx = self.rng.randint(
                    self._populate_job_queue.maxsize * 4,
                    len(self.hmem) - self.history_len - 1,
                    size=10)    # 10 human demonstrations (~20% of the batch)


                batch_exp = [self.mem.sample(i) for i in ex_idx]
                for j in hu_idx:
                    batch_exp.append(self.hmem.sample(j))

                yield self._process_batch(batch_exp)
                logger.info("Mixed batch: 0.8 agent / 0.2 human ...")
                self._populate_job_queue.put(1)
            # HITL buffer not loaded, therefore sample only from the agent buffer
            else:
                idx = self.rng.randint(
                    self._populate_job_queue.maxsize * self.update_frequency,
                    len(self.mem) - self.history_len - 1,
                    size=self.batch_size)
                batch_exp = [self.mem.sample(i) for i in idx]

                yield self._process_batch(batch_exp)
                self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        human = np.asarray([e[4] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver, human]

    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        # log player statistics in training
        v = self._player_scores
        dist = self._player_distError
        try:
            mean, max = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
            mean, max = dist.average, dist.max
            self.trainer.monitors.put_scalar('expreplay/mean_dist', mean)
            self.trainer.monitors.put_scalar('expreplay/max_dist', max)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
        dist.reset()

        # monitor number of played games and successes of reaching the target
        if self.player.num_games.count:
            self.trainer.monitors.put_scalar('n_games',
                                             np.asscalar(self.player.num_games.sum))
        else:
            self.trainer.monitors.put_scalar('n_games', 0)

        if self.player.num_success.count:
            self.trainer.monitors.put_scalar('n_success',
                                             np.asscalar(self.player.num_success.sum))
            self.trainer.monitors.put_scalar('n_success_ratio',
                                             self.player.num_success.sum / self.player.num_games.sum)
        else:
            self.trainer.monitors.put_scalar('n_success', 0)
            self.trainer.monitors.put_scalar('n_success_ratio', 0)
        # reset stats
        self.player.reset_stat()
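# Illustrative sketch only (not part of the example above). The mixed batch in
# get_data draws 38 agent transitions and 10 human demonstrations per batch of
# 48 (roughly 0.8 / 0.2). A parameterised version of that split might look like
# this; the function name, `margin` and `human_ratio` arguments are assumptions.
def sample_mixed_batch(agent_mem, human_mem, rng, batch_size=48, human_ratio=0.2,
                       history_len=4, margin=20):
    n_human = int(round(batch_size * human_ratio))
    n_agent = batch_size - n_human
    agent_idx = rng.randint(margin, len(agent_mem) - history_len - 1, size=n_agent)
    human_idx = rng.randint(margin, len(human_mem) - history_len - 1, size=n_human)
    batch = [agent_mem.sample(i) for i in agent_idx]
    batch += [human_mem.sample(j) for j in human_idx]
    return batch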
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(self,
                 directory=None,
                 files_list=None,
                 viz=False,
                 train=False,
                 screen_dims=(27, 27, 27),
                 spacing=(1, 1, 1),
                 nullop_start=30,
                 history_length=30,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False):
        """
        :param train_directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param history_length: consider lost of lives as end of
            episode (useful for training)
        """
        super(MedicalPlayer, self).__init__()

        self.reset_stat()
        self._supervised = False
        self._init_action_angle_step = 8
        self._init_action_dist_step = 4

        #######################################################################
        ## save results in csv file
        self.csvfile = 'dummy.csv'
        if not train:
            with open(self.csvfile, 'w') as outcsv:
                fields = ["filename", "dist_error", "angle_error"]
                writer = csv.writer(outcsv)
                writer.writerow(fields)
        #######################################################################

        # read files from directory - add your data loader here
        self.files = filesListCardioMRPlane(directory, files_list)

        # prepare file sampler
        self.sampled_files = self.files.sample_circular()
        self.filepath = None
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.train = train
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self._plane_size = screen_dims
        self.dims = len(self.screen_dims)
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims
        # plane sampling spacings
        self.init_spacing = np.array(spacing)
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(8)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # circular buffer to store plane parameters history [4,history_length]
        self._plane_history = deque(maxlen=self._history_length)
        self._bestq_history = deque(maxlen=self._history_length)
        self._dist_history = deque(maxlen=self._history_length)
        self._dist_history_params = deque(maxlen=self._history_length)
        self._dist_supervised_history = deque(maxlen=self._history_length)
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._qvalues = [
            0,
        ] * self.actions

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        self._restart_episode()

    # -------------------------------------------------------------------------

    def _reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart the current episode
        """
        self.terminal = False
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset score stat counter
        self._plane_history.clear()
        self._bestq_history.clear()
        self._dist_history.clear()
        self._dist_history_params.clear()
        self._dist_supervised_history.clear()
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length

        self.new_random_game()

    # -------------------------------------------------------------------------

    def new_random_game(self):
        # print('\n============== new game ===============\n')
        self.terminal = False
        self.viewer = None

        # sample a new image
        (self.sitk_image, self.sitk_image_2ch, self.sitk_image_4ch,
         self.landmarks, self.filepath) = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)
        # image volume size
        self._image_dims = self.sitk_image.GetSize()
        self.action_angle_step = copy.deepcopy(self._init_action_angle_step)
        self.action_dist_step = copy.deepcopy(self._init_action_dist_step)
        self.spacing = self.init_spacing.copy()

        # find center point of the initial plane
        if self.train:
            # sample randomly ±10% around the center point
            skip_thickness = (int(self._image_dims[0] / 2.5),
                              int(self._image_dims[1] / 2.5),
                              int(self._image_dims[2] / 2.5))
            x = self.rng.randint(0 + skip_thickness[0],
                                 self._image_dims[0] - skip_thickness[0])
            y = self.rng.randint(0 + skip_thickness[1],
                                 self._image_dims[1] - skip_thickness[1])
            z = self.rng.randint(0 + skip_thickness[2],
                                 self._image_dims[2] - skip_thickness[2])
        else:
            # during testing start sample a plane around the center point
            x, y, z = (self._image_dims[0] / 2, self._image_dims[1] / 2,
                       self._image_dims[2] / 2)
        self._origin3d_point = (int(x), int(y), int(z))
        # Get ground truth plane
        # logger.info('filename {} '.format(self.filename))
        self._groundTruth_plane = Plane(
            *getGroundTruthPlane(self.sitk_image,
                                 self.sitk_image_4ch,
                                 self._origin3d_point,
                                 self._plane_size,
                                 spacing=self.spacing))
        # get an isotropic 1mm groundtruth plane
        # image_size = (int(min(self._image_dims)),)*3
        image_size = self._image_dims
        self.groundTruth_plane_iso = Plane(
            *getGroundTruthPlane(self.sitk_image, self.sitk_image_4ch,
                                 self._origin3d_point, image_size, [1, 1, 1]))

        self.landmarks_gt, _ = zip(*[
            projectPointOnPlane(point, self._groundTruth_plane.norm,
                                self._groundTruth_plane.origin)
            for point in self.landmarks
        ])
        # logger.info('groundTruth {}'.format(self._groundTruth_plane.params))
        # Get initial plane and set current plane the same
        self._plane = self._init_plane = Plane(
            *getInitialPlane(self.sitk_image, self._plane_size,
                             self._origin3d_point, self.spacing))
        _, dist = zip(*[
            projectPointOnPlane(point, self._plane.norm, self._plane.origin)
            for point in self.landmarks_gt
        ])
        # calculate current distance between initial and ground truth planes
        # self.cur_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
        #                                      self._plane.points)
        self.cur_dist = np.mean(np.abs(dist))
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params,
            self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)

        self._screen = self._current_state()

    # -------------------------------------------------------------------------

    def step(self, act, qvalues):
        """The environment's step function returns exactly what we need.
        Args:
          action:
        Returns:
          observation (object):
            an environment-specific object representing your observation of the environment. For example, pixel data from a camera, joint angles and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies between environments, but the goal is always to increase your total reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all) tasks are divided up into well-defined episodes, and done being True indicates the episode has terminated. (For example, perhaps the pole tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be useful for learning (for example, it might contain the raw probabilities behind the environment's last state change). However, official evaluations of your agent are not allowed to use this for learning.
        """
        self.terminal = False
        self._qvalues = qvalues
        # get current plane params
        current_plane_params = np.copy(self._plane.params)
        next_plane_params = current_plane_params.copy()
        # ---------------------------------------------------------------------
        # theta x+ (param a)
        if (act == 0): next_plane_params[0] += self.action_angle_step
        # theta y+ (param b)
        if (act == 1): next_plane_params[1] += self.action_angle_step
        # theta z+ (param c)
        if (act == 2): next_plane_params[2] += self.action_angle_step
        # dist d+
        if (act == 3): next_plane_params[3] += self.action_dist_step

        # theta x- (param a)
        if (act == 4): next_plane_params[0] -= self.action_angle_step
        # theta y- (param b)
        if (act == 5): next_plane_params[1] -= self.action_angle_step
        # theta z- (param c)
        if (act == 6): next_plane_params[2] -= self.action_angle_step
        # dist d-
        if (act == 7): next_plane_params[3] -= self.action_dist_step
        # ---------------------------------------------------------------------
        # self.reward = self._calc_reward_points(self._plane.points,
        #                                        next_plane.points)
        self.reward = self._calc_reward_params(current_plane_params,
                                               next_plane_params)
        # threshold reward between -1 and 1
        self.reward = np.sign(self.reward)
        go_out = False

        if self._supervised and self.train:
            ## supervised
            dist_queue = deque(maxlen=self.actions)
            plane_params_queue = deque(maxlen=self.actions)
            # enumerate all candidate actions in the same order as the action
            # space: +angle on params a, b, c and +dist on d, followed by the
            # four negative counterparts; keep each candidate's params and its
            # scaled distance to the ground-truth plane
            param_deltas = [(0, +self.action_angle_step),
                            (1, +self.action_angle_step),
                            (2, +self.action_angle_step),
                            (3, +self.action_dist_step),
                            (0, -self.action_angle_step),
                            (1, -self.action_angle_step),
                            (2, -self.action_angle_step),
                            (3, -self.action_dist_step)]
            for param_idx, delta in param_deltas:
                next_plane_params = np.copy(self._plane.params)
                next_plane_params[param_idx] += delta
                plane_params_queue.append(next_plane_params)
                # dist_queue.append(calcMeanDistTwoPlanes(self._groundTruth_plane.points, plane_params_queue[-1].points))
                dist_queue.append(
                    calcDistTwoParams(self._groundTruth_plane.params,
                                      plane_params_queue[-1],
                                      scale_angle=self.action_angle_step,
                                      scale_dist=self.action_dist_step))

            # -----------------------------------------------------------------
            # get best plane based on lowest distance to the target
            next_plane_idx = np.argmin(dist_queue)
            next_plane = Plane(*getPlane(self.sitk_image,
                                         self._origin3d_point,
                                         plane_params_queue[next_plane_idx],
                                         self._plane_size,
                                         spacing=self.spacing))
            self._dist_supervised_history.append(np.min(dist_queue))

        else:
            ## unsupervised or testing
            # get the new plane using new params result from taking the action
            next_plane = Plane(*getPlane(self.sitk_image,
                                         self._origin3d_point,
                                         next_plane_params,
                                         self._plane_size,
                                         spacing=self.spacing))
            # -----------------------------------------------------------------
            # check if the screen is not full of zeros (background)
            go_out = checkBackgroundRatio(next_plane,
                                          min_pixel_val=0.5,
                                          ratio=0.8)
            # also check if go out (sampling from outside the volume)
            # by checking if the new origin
            if not go_out:
                go_out = checkOriginLocation(self.sitk_image,
                                             next_plane.origin)
            # also check if plane parameters got very high
            if not go_out:
                go_out = checkParamsBound(next_plane.params,
                                          self._groundTruth_plane.params)
            # punish lowest reward if the agent tries to go out and keep same plane
            if go_out:
                self.reward = -1  # lowest possible reward
                next_plane = copy.deepcopy(self._plane)
                if self.train: self.terminal = True  # end episode and restart

        # ---------------------------------------------------------------------
        # update current plane
        self._plane = copy.deepcopy(next_plane)
        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames: self.terminal = True
        # check oscillation and reduce action step or terminate if minimum
        if self._oscillate:
            if self.train and self._supervised:
                self._plane = self.getBestPlaneTrain()
            else:
                self._plane = self.getBestPlane()

            # find distance metrics
            # self.cur_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points, self._plane.points)
            _, dist = zip(*[
                projectPointOnPlane(point, self._plane.norm,
                                    self._plane.origin)
                for point in self.landmarks_gt
            ])
            self.cur_dist = np.max(np.abs(dist))
            self._update_heirarchical()
            self._clear_history()
            # terminate if distance steps are less than 1
            if self.action_dist_step < 1: self.terminal = True

        # ---------------------------------------------------------------------
        # find distance error
        _, dist = zip(*[
            projectPointOnPlane(point, self._plane.norm, self._plane.origin)
            for point in self.landmarks_gt
        ])
        self.cur_dist = np.mean(np.abs(dist))
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params,
            self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)
        self.current_episode_score.feed(self.reward)
        self._update_history()  # store results in memory

        # terminate if distance between params are low during training
        if self.train and (self.cur_dist_params <= 1):
            self.terminal = True
            self.num_success.feed(1)

        # ---------------------------------------------------------------------

        # # supervised reward (for debuging)
        # reward_supervised = self._calc_reward_params(current_plane_params,
        #                                        self._plane.params)
        # # threshold reward between -1 and 1
        # self.reward = np.sign(np.around(reward_supervised,decimals=1))

        # ---------------------------------------------------------------------

        # render screen if viz is on
        if self.viz:
            if isinstance(self.viz, float):
                self.display()

        A = normalizeUnitVector(self._groundTruth_plane.norm)
        B = normalizeUnitVector(self._plane.norm)
        angle_between_norms = np.rad2deg(np.arccos(A.dot(B)))

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'distError': self.cur_dist,
            'distAngle': angle_between_norms,
            'filename': self.filename
        }

        if self.terminal:
            with open(self.csvfile, 'a') as outcsv:
                fields = [self.filename, self.cur_dist, angle_between_norms]
                writer = csv.writer(outcsv)
                writer.writerow(fields)

        return self._current_state(), self.reward, self.terminal, info

    # -------------------------------------------------------------------------

    def _update_heirarchical(self):
        self.action_angle_step = int(self.action_angle_step / 2)
        self.action_dist_step = self.action_dist_step - 1
        # self.spacing -= 1
        if (self.spacing[0] > 1): self.spacing -= 1
        self._groundTruth_plane = Plane(
            *getGroundTruthPlane(self.sitk_image,
                                 self.sitk_image_4ch,
                                 self._origin3d_point,
                                 self._plane_size,
                                 spacing=self.spacing))

    def getBestPlane(self):
        ''' get the plane with the best q-value among the last locations
        stored in history
        '''
        best_idx = np.argmin(self._bestq_history)
        # best_idx = np.argmax(self._bestq_history)
        return self._plane_history[best_idx]

    def getBestPlaneTrain(self):
        ''' get the plane with the lowest supervised distance among the last
        locations stored in history
        '''
        best_idx = np.argmin(self._dist_supervised_history)
        # best_idx = np.argmax(self._bestq_history)
        return self._plane_history[best_idx]

    def _current_state(self):
        """
        :returns: a gray-scale (h, w, d) float image
        """
        return self._plane.grid_smooth

    def _clear_history(self):
        self._plane_history.clear()
        self._bestq_history.clear()
        self._dist_history.clear()
        self._dist_history_params.clear()
        self._dist_supervised_history.clear()
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state
        '''
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        # loc = self._plane.origin
        loc = self._plane.params
        # logger.info('loc {}'.format(loc))
        self._loc_history[-1] = (np.around(loc[0], decimals=2),
                                 np.around(loc[1], decimals=2),
                                 np.around(loc[2], decimals=2),
                                 np.around(loc[3], decimals=2))
        # update distance history
        self._dist_history.append(self.cur_dist)
        self._dist_history_params.append(self.cur_dist_params)
        # update params history
        self._plane_history.append(self._plane)
        self._bestq_history.append(np.max(self._qvalues))
        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues

    def _calc_reward_points(self, prev_points, next_points):
        ''' Calculate the new reward based on the euclidean distance to the target plane
        '''
        prev_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          prev_points)
        next_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          next_points)

        return prev_dist - next_dist

    def _calc_reward_params(self, prev_params, next_params):
        ''' Calculate the new reward based on the euclidean distance to the target plane
        '''
        # logger.info('prev_params {}'.format(np.around(prev_params,2)))
        # logger.info('next_params {}'.format(np.around(next_params,2)))
        prev_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            prev_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)
        next_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            next_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)

        return prev_dist - next_dist

    @property
    def _oscillate(self):
        ''' Return True if the agent is stuck and oscillating
        '''
        counter = Counter(self._loc_history)
        freq = counter.most_common()
        # return False if history is empty (beginning of the game)
        if len(freq) < 2: return False
        # check frequency
        if freq[0][0] == (0, 0, 0, 0):
            if (freq[1][1] > 2):
                # logger.info('oscillating {}'.format(self._loc_history))
                return True
            else:
                return False
        elif (freq[0][1] > 2):
            # logger.info('oscillating {}'.format(self._loc_history))
            return True
        return False

    def get_action_meanings(self):
        ''' return a list of action-meaning strings '''
        ACTION_MEANING = {
            0: "inc_x",  # increment +1 the norm angle in x-direction
            1: "inc_y",  # increment +1 the norm angle in y-direction
            2: "inc_z",  # increment +1 the norm angle in z-direction
            3: "inc_d",  # increment +1 the norm distance d to origin
            4: "dec_x",  # decrement -1 the norm angle in x-direction
            5: "dec_y",  # decrement -1 the norm angle in y-direction
            6: "dec_z",  # decrement -1 the norm angle in z-direction
            7: "dec_d",  # decrement -1 the norm distance d to origin
        }
        return [ACTION_MEANING[i] for i in range(self.actions)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        # pass
        # --------------------------------------------------------------------
        ## planes seen by the agent
        # # get image and convert it to pyglet
        # plane = self._plane.grid[:,:,round(self.depth/2)] # z-plane
        # # concatenate groundtruth image
        # gt_plane = self._groundTruth_plane.grid[:,:,round(self.depth/2)]
        # --------------------------------------------------------------------
        ## whole plan
        # image_size = (int(min(self._image_dims)),)*3
        image_size = self._image_dims
        current_plane = Plane(*getPlane(self.sitk_image,
                                        self._origin3d_point,
                                        self._plane.params,
                                        image_size,
                                        spacing=[1, 1, 1]))

        # get image and convert it to pyglet
        plane = current_plane.grid[:, :, int(image_size[2] / 2)]  # z-plane
        # concatenate groundtruth image
        gt_plane = self.groundTruth_plane_iso.grid[:, :,
                                                   int(image_size[2] / 2)]
        # --------------------------------------------------------------------
        # concatenate two planes side by side
        plane = np.concatenate((plane, gt_plane), axis=1)
        #
        img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 5
        scale_y = 5
        # img = cv2.resize(img,
        #                  (int(scale_x*img.shape[1]),int(scale_y*img.shape[0])),
        #                  interpolation=cv2.INTER_LINEAR)
        # create a viewer if one is not already open
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)
        self.viewer.display_text('Current Plane',
                                 color=(0, 0, 204, 255),
                                 x=int(0.7 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)
        self.viewer.display_text('Ground Truth',
                                 color=(0, 0, 204, 255),
                                 x=int(4.3 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)

        # display info
        dist_color_flag = False
        if len(self._dist_history) > 1:
            dist_color_flag = self.cur_dist < self._dist_history[-2]

        color_dist = (0, 204, 0, 255) if dist_color_flag else (204, 0, 0, 255)
        text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=color_dist,
                                 x=int(3 * img.shape[1] / 8),
                                 y=5 * scale_y)

        dist_color_flag = False
        if len(self._dist_history_params) > 1:
            dist_color_flag = self.cur_dist_params < self._dist_history_params[
                -2]

        # color_dist = (0,255,0,255) if dist_color_flag else (255,0,0,255)
        # text = 'Params Error ' + str(round(self.cur_dist_params,3))
        # self.viewer.display_text(text, color=color_dist,
        # x=int(6*img.shape[1]/8), y=5*scale_y)
        text = 'Spacing ' + str(round(self.spacing[0], 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=(204, 204, 0, 255),
                                 x=int(6 * img.shape[1] / 8),
                                 y=5 * scale_y)

        color_reward = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0,
                                                                 255)
        text = 'Reward ' + "%+d" % round(self.reward, 3)
        self.viewer.display_text(text,
                                 color=color_reward,
                                 x=2 * scale_x,
                                 y=5 * scale_y)

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = (pyglet.image.get_buffer_manager()
                          .get_color_buffer().get_image_data())
            data = image_data.get_data('RGB', image_data.width * 3)
            # set_trace()
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr).convert('P')
            self.gif_buffer.append(im)

            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.savegif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if (self.cnt <= 1):
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '3', '-i', dirname + '/%04d.png', '-s', '1280x720',
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
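# Illustrative sketch only (not part of the example above): a standalone version
# of the oscillation test used by the _oscillate property, with assumed names.
# The agent is considered stuck when the most frequent non-initial entry in the
# location history occurs more than `threshold` times.
from collections import Counter

def is_oscillating(loc_history, initial=(0, 0, 0, 0), threshold=2):
    freq = Counter(loc_history).most_common()
    if len(freq) < 2:              # history is still empty / all-initial
        return False
    if freq[0][0] == initial:      # skip the zero-filled initial entries
        return freq[1][1] > threshold
    return freq[0][1] > threshold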
Example #31
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 # model,
                 agent_name,
                 player,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency,
                 encoding_file='../AutoEncoder/encoding.npy'):
        init_memory_size = int(init_memory_size)
        # self.model = model

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.agent_name = agent_name
        self.exploration = init_exploration
        self.num_actions = num_actions
        self.encoding = np.load(encoding_file)
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        self.player.prepare()
        self._comb_mask = True
        self._fine_mask = None
        self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_combinations(self, curr_cards_char, last_cards_char):
        if len(curr_cards_char) > 10:
            card_mask = Card.char2onehot60(curr_cards_char).astype(np.uint8)
            mask = augment_action_space_onehot60
            a = np.expand_dims(1 - card_mask, 0) * mask
            invalid_row_idx = set(np.where(a > 0)[0])
            if len(last_cards_char) == 0:
                invalid_row_idx.add(0)

            valid_row_idx = [i for i in range(len(augment_action_space)) if i not in invalid_row_idx]

            mask = mask[valid_row_idx, :]
            idx_mapping = dict(zip(range(mask.shape[0]), valid_row_idx))

            # augment mask
            # TODO: known issue: 555444666 will not decompose into 5554 and 66644
            combs = get_combinations_nosplit(mask, card_mask)
            combs = [([] if len(last_cards_char) == 0 else [0]) + [clamp_action_idx(idx_mapping[idx]) for idx in comb] for comb in combs]

            if len(last_cards_char) > 0:
                idx_must_be_contained = set(
                    [idx for idx in valid_row_idx if CardGroup.to_cardgroup(augment_action_space[idx]). \
                        bigger_than(CardGroup.to_cardgroup(last_cards_char))])
                combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)]
                self._fine_mask = np.zeros([len(combs), self.num_actions[1]], dtype=bool)
                for i in range(len(combs)):
                    for j in range(len(combs[i])):
                        if combs[i][j] in idx_must_be_contained:
                            self._fine_mask[i][j] = True
            else:
                self._fine_mask = None
        else:
            mask = get_mask_onehot60(curr_cards_char, action_space, None).reshape(len(action_space), 15, 4).sum(-1).astype(
                np.uint8)
            valid = mask.sum(-1) > 0
            cards_target = Card.char2onehot60(curr_cards_char).reshape(-1, 4).sum(-1).astype(np.uint8)
            combs = get_combinations_recursive(mask[valid, :], cards_target)
            idx_mapping = dict(zip(range(valid.shape[0]), np.where(valid)[0]))

            combs = [([] if len(last_cards_char) == 0 else [0]) + [idx_mapping[idx] for idx in comb] for comb in combs]

            if len(last_cards_char) > 0:
                valid[0] = True
                idx_must_be_contained = set(
                    [idx for idx in range(len(action_space)) if valid[idx] and CardGroup.to_cardgroup(action_space[idx]). \
                        bigger_than(CardGroup.to_cardgroup(last_cards_char))])
                combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)]
                self._fine_mask = np.zeros([len(combs), self.num_actions[1]], dtype=bool)
                for i in range(len(combs)):
                    for j in range(len(combs[i])):
                        if combs[i][j] in idx_must_be_contained:
                            self._fine_mask[i][j] = True
            else:
                self._fine_mask = None
        return combs

    def subsample_combs_masks(self, combs, masks, num_sample):
        if masks is not None:
            assert len(combs) == masks.shape[0]
        idx = np.random.permutation(len(combs))[:num_sample]
        return [combs[i] for i in idx], (masks[idx] if masks is not None else None)

    def get_state_and_action_spaces(self, action=None):

        def cards_char2embedding(cards_char):
            test = (action_space_onehot60 == Card.char2onehot60(cards_char))
            test = np.all(test, axis=1)
            target = np.where(test)[0]
            return self.encoding[target[0]]

        last_two_cards_char = self.player.get_last_two_cards()
        last_cards_char = last_two_cards_char[0]
        if not last_cards_char:
            last_cards_char = last_two_cards_char[1]
        curr_cards_char = self.player.get_curr_handcards()
        if self._comb_mask:
            # print(curr_cards_char, last_cards_char)
            combs = self.get_combinations(curr_cards_char, last_cards_char)
            if len(combs) > self.num_actions[0]:
                combs, self._fine_mask = self.subsample_combs_masks(combs, self._fine_mask, self.num_actions[0])
            # TODO: utilize temporal relations to speedup
            available_actions = [[action_space[idx] for idx in comb] for comb in combs]
            # print(available_actions)
            # print('-------------------------------------------')
            assert len(combs) > 0
            if self._fine_mask is not None:
                self._fine_mask = self.pad_fine_mask(self._fine_mask)
            self.pad_action_space(available_actions)
            state = [np.stack([self.encoding[idx] for idx in comb]) for comb in combs]
            assert len(state) > 0
            prob_state = self.player.get_state_prob()
            # test = action_space_onehot60 == Card.char2onehot60(last_cards_char)
            # test = np.all(test, axis=1)
            # target = np.where(test)[0]
            # assert target.size == 1
            extra_state = np.concatenate([cards_char2embedding(last_two_cards_char[0]), cards_char2embedding(last_two_cards_char[1]), prob_state])
            for i in range(len(state)):
                state[i] = np.concatenate([state[i], np.tile(extra_state[None, :], [state[i].shape[0], 1])], axis=-1)
            state = self.pad_state(state)
            assert state.shape[0] == self.num_actions[0] and state.shape[1] == self.num_actions[1]
        else:
            assert action is not None
            if self._fine_mask is not None:
                self._fine_mask = self._fine_mask[action]
            available_actions = self._action_space[action]
            state = self._current_ob[action:action+1, :, :]
            state = np.repeat(state, self.num_actions[0], axis=0)
            assert state.shape[0] == self.num_actions[0] and state.shape[1] == self.num_actions[1]
        return state, available_actions

    def pad_fine_mask(self, mask):
        if mask.shape[0] < self.num_actions[0]:
            mask = np.concatenate([mask, np.repeat(mask[-1:], self.num_actions[0] - mask.shape[0], 0)], 0)
        return mask

    def pad_action_space(self, available_actions):
        # print(available_actions)
        for i in range(len(available_actions)):
            available_actions[i] += [available_actions[i][-1]] * (self.num_actions[1] - len(available_actions[i]))
        if len(available_actions) < self.num_actions[0]:
            available_actions.extend([available_actions[-1]] * (self.num_actions[0] - len(available_actions)))

    # input is a list of N * HIDDEN_STATE
    def pad_state(self, state):
        # since our net uses a max operation, duplicating the last row keeps the result the same
        newstates = []
        for s in state:
            assert s.shape[0] <= self.num_actions[1]
            s = np.concatenate([s, np.repeat(s[-1:, :], self.num_actions[1] - s.shape[0], axis=0)], axis=0)
            newstates.append(s)
        newstates = np.stack(newstates, axis=0)
        if len(state) < self.num_actions[0]:
            state = np.concatenate([newstates, np.repeat(newstates[-1:, :, :], self.num_actions[0] - newstates.shape[0], axis=0)], axis=0)
        else:
            state = newstates
        return state

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        comb_mask = self._comb_mask
        if not self._comb_mask and self._fine_mask is not None:
            fine_mask = self._fine_mask if self._fine_mask.shape[0] == max(self.num_actions[0], self.num_actions[1]) \
                else np.pad(self._fine_mask, (0, max(self.num_actions[0], self.num_actions[1]) - self._fine_mask.shape[0]), 'constant', constant_values=(0, 0))
        else:
            fine_mask = np.ones([max(self.num_actions[0], self.num_actions[1])], dtype=bool)
        last_cards_char = self.player.get_last_outcards()
        if self.rng.rand() <= self.exploration:
            if not self._comb_mask and self._fine_mask is not None:
                q_values = np.random.rand(self.num_actions[1])
                q_values[np.where(np.logical_not(self._fine_mask))[0]] = np.nan
                act = np.nanargmax(q_values)
                # print(q_values)
                # print(act)
            else:
                act = self.rng.choice(range(self.num_actions[0 if comb_mask else 1]))
        else:
            q_values = self.curr_predictor(old_s[None, :, :, :], np.array([comb_mask]), np.array([fine_mask]))[0][0]
            if not self._comb_mask and self._fine_mask is not None:
                q_values = q_values[:self.num_actions[1]]
                assert np.all(q_values[np.where(np.logical_not(self._fine_mask))[0]] < -100)
                q_values[np.where(np.logical_not(self._fine_mask))[0]] = np.nan
            act = np.nanargmax(q_values)
            assert act < self.num_actions[0 if comb_mask else 1]
            # print(q_values)
            # print(act)
            # clamp action to valid range
            act = min(act, self.num_actions[0 if comb_mask else 1] - 1)
        winner = -1
        reward = 0
        if comb_mask:
            isOver = False
        else:
            if len(last_cards_char) > 0:
                if act > 0:
                    if not CardGroup.to_cardgroup(self._action_space[act]).bigger_than(CardGroup.to_cardgroup(last_cards_char)):
                        print('warning, some error happened, ', self._action_space[act], last_cards_char)
                        raise Exception("card comparison error")
            winner, isOver = self.player.step(self._action_space[act])

        # step for AI farmers
        while not isOver and self.player.get_curr_agent_name() != self.agent_name:
            handcards = self.player.get_curr_handcards()
            last_two_cards = self.player.get_last_two_cards()
            prob_state = self.player.get_state_prob()
            action = self.predictors[self.player.get_curr_agent_name()].predict(handcards, last_two_cards, prob_state)
            winner, isOver = self.player.step(action)

        if isOver:
            if self.agent_name == winner:
                reward = 1
            else:
                if self.player.get_all_agent_names().index(winner) + self.player.get_all_agent_names().index(self.agent_name) == 3:
                    reward = 1
                else:
                    reward = -1
        self._current_game_score.feed(reward)

        if isOver:
            self._player_scores.feed(self._current_game_score.sum)
            self.player.reset()
            self.player.prepare()
            self._comb_mask = True
            self.prestart()
            self._current_game_score.reset()
        else:
            self._comb_mask = not self._comb_mask
        self._current_ob, self._action_space = self.get_state_and_action_spaces(act if not self._comb_mask else None)
        self.mem.append(Experience(old_s, act, reward, isOver, comb_mask, fine_mask))

    def prestart(self):
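        # let the other agents act through their predictors until it is this
        # agent's turn, then refresh the cached observation and action space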
        while self.player.get_curr_agent_name() != self.agent_name:
            handcards = self.player.get_curr_handcards()
            last_two_cards = self.player.get_last_two_cards()
            prob_state = self.player.get_state_prob()
            action = self.predictors[self.player.get_curr_agent_name()].predict(handcards, last_two_cards, prob_state)

            self.player.step(action)
        self._current_ob, self._action_space = self.get_state_and_action_spaces()

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
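            # the lower bound skips the oldest part of the replay buffer, which
            # pending populate jobs (up to maxsize * update_frequency new
            # transitions) may be overwriting concurrently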
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        comb_mask = np.asarray([e[4] for e in batch_exp], dtype='bool')
        fine_mask = np.asarray([e[5] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver, comb_mask, fine_mask]

    def _setup_graph(self):
        self.curr_predictor = self.trainer.get_predictor([self.agent_name + '/state:0', self.agent_name + '_comb_mask:0', self.agent_name + '/fine_mask:0'], [self.agent_name + '/Qvalue:0'])
        self.predictors = {n: Predictor(self.trainer.get_predictor([n + '/state:0', n + '_comb_mask:0', n + '/fine_mask:0'], [n + '/Qvalue:0'])) for n in self.player.get_all_agent_names()}

    def _before_train(self):
        self.prestart()

        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            mean, max = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
        except Exception:
            logger.exception(self.agent_name + " Cannot log training scores.")
        v.reset()
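As a rough illustration, a DataFlow/Callback like this ExpReplay is usually wired into a tensorpack training loop along these lines (a minimal sketch only: MyModel, make_env and the numeric settings are hypothetical placeholders, and the exact arguments may differ from the original project):

from tensorpack import QueueInput, TrainConfig, SimpleTrainer, launch_train_with_config

expreplay = ExpReplay(
    agent_name='agent1',           # hypothetical agent name
    player=make_env(),             # hypothetical environment factory
    state_shape=(100, 21, 256),    # placeholder state shape
    num_actions=(100, 21),         # placeholder (combinations, groups per combination)
    batch_size=32,
    memory_size=200000, init_memory_size=50000,
    init_exploration=1.0,
    update_frequency=4)

config = TrainConfig(
    model=MyModel(),               # hypothetical tensorpack ModelDesc
    data=QueueInput(expreplay),    # the same object serves as the DataFlow ...
    callbacks=[expreplay],         # ... and as the Callback driving the simulator
    steps_per_epoch=10000,
    max_epoch=100)
launch_train_with_config(config, SimpleTrainer())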
Example #32
class Brain_Env(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(
            self,
            directory=None,
            viz=False,
            task=False,
            files_list=None,
            observation_dims=(27, 27, 27),
            multiscale=False,  # FIXME automatic dimensions
            max_num_frames=20,  # FIXME hardcoded max num frames!
            saveGif=False,
            saveVideo=False):
        """
        :param directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to a positive number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param observation_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param location_history_length: consider loss of lives as end of
            episode (useful for training)
        :param max_num_frames: maximum number of frames per episode.
        """
        super(Brain_Env, self).__init__()

        print(
            "warning! max num frames hard coded to {}!".format(max_num_frames),
            flush=True)

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.observation_dims = observation_dims
        self.dims = len(self.observation_dims)
        # multi-scale agent
        self.multiscale = multiscale
        # FIXME force multiscale false for now
        self.multiscale = False

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = observation_dims
        elif self.dims == 3:
            self.width, self.height, self.depth = observation_dims
        else:
            raise ValueError

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # TODO: understand this viz setup
            # visualization setup
            #     if isinstance(viz, six.string_types):  # check if viz is a string
            #         assert os.path.isdir(viz), viz
            #         viz = 0
            #     if isinstance(viz, int):
            #         viz = float(viz)
            self.viz = viz
        #     if self.viz and isinstance(self.viz, float):
        #         self.viewer = None
        #         self.gif_buffer = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.observation_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = max_num_frames
        # TODO initialize _observation_bounds limits from input image coordinates
        self._observation_bounds = ObservationBounds(0, 0, 0, 0, 0, 0)
        # add your data loader here
        # TODO: look into returnLandmarks
        # if self.task == 'play':
        #     self.files = filesListBrainMRLandmark(directory, files_list,
        #                                           returnLandmarks=False)
        # else:
        #     self.files = filesListBrainMRLandmark(directory, files_list,
        #                                           returnLandmarks=True)
        self.files = FilesListCubeNPY(directory, files_list)

        # self.files = filesListFetalUSLandmark(directory,files_list)
        # self.files = filesListCardioMRLandmark(directory,files_list)
        # prepare file sampler
        self.filepath = None
        self.file_sampler = self.files.sample_circular()  # returns generator
        # reset buffer, terminal, counters, and init new_random_game
        # we put this here so that init_player in DQN.py doesn't try to update_history
        self._clear_history()  # init arrays
        self._restart_episode()
        # self.viz = True  # FIXME viz should default False
        assert (np.shape(self._state) == self.observation_dims)
        assert np.isclose(jaccard(self.original_state, self.original_state), 1)

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._observe()

    def _restart_episode(self):
        """
        restart current episode
        """
        self.terminal = False
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self.new_random_game()

    def new_random_game(self):
        """
        load image,
        set dimensions,
        randomize start point,
        init _screen, qvals,
        calc distance to goal
        """
        self.terminal = False
        self.viewer = None

        # # sample a new image
        self.filepath, self.filename = next(self.file_sampler)
        self._state = np.load(self.filepath).astype(float)
        self.original_state = np.copy(self._state)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            raise NotImplementedError
            # ## brain
            # self.action_step = 9
            # self.xscale = 3
            # self.yscale = 3
            # self.zscale = 3
            ## cardiac
            # self.action_step = 6
            # self.xscale = 2
            # self.yscale = 2
            # self.zscale = 2
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1
            self.zscale = 1
        # image volume size
        self._state_dims = np.shape(self._state)
        #######################################################################
        ## select random starting point
        # add padding to avoid start right on the border of the image
        if (self.task == 'train'):
            skip_thickness = (int(self._state_dims[0] / 5),
                              int(self._state_dims[1] / 5),
                              int(self._state_dims[2] / 5))
        else:  # TODO: why a different skip thickness here?
            skip_thickness = (int(self._state_dims[0] / 4),
                              int(self._state_dims[1] / 4),
                              int(self._state_dims[2] / 4))

        # FIXME randomly select one of the ground truth voxels as a starting point
        binary_grid = self.original_state.astype(bool)
        x_span, y_span, z_span = self.original_state.shape
        x, y, z = np.indices((x_span, y_span, z_span))
        positions = np.c_[x[binary_grid == 1], y[binary_grid == 1],
                          z[binary_grid == 1]]
        # pick a random row as starting position
        self._location = positions[np.random.choice(positions.shape[0],
                                                    1)].flatten()
        # print("starting location ", self._location)
        self._start_location = self._location

        # # randomly select the starting coords
        # x = self.rng.randint(0 + skip_thickness[0],
        #                      self._state_dims[0] - skip_thickness[0])
        # y = self.rng.randint(0 + skip_thickness[1],
        #                      self._state_dims[1] - skip_thickness[1])
        # z = self.rng.randint(0 + skip_thickness[2],
        #                      self._state_dims[2] - skip_thickness[2])
        #######################################################################

        # self._location = np.array([x, y, z])
        # self._start_location = np.array([x, y, z])
        self._qvalues = np.zeros(self.actions)
        self._observation = self._observe()
        self.curr_IOU = self.calc_IOU()
        print("first IOU ", self.curr_IOU)
        self.reward = self._calc_reward(False, False)
        self._update_history()
        # we've finished iteration 0. now, step begins with cnt = 1
        self.cnt += 1

    def calc_IOU(self):
        """ calculate the Intersection over Union AKA Jaccard Index
        between two images

        https://en.wikipedia.org/wiki/Jaccard_index
        """
        # flatten because jaccard_similarity_score expects 1D arrays
        state = self._state.ravel()
        state[state != -1] = 0  # mask out non-agent trajectory
        state = state.astype(bool)  # everything non-zero => True
        if not state.any():  # no agent trajectory
            print(" no state trajectory found")
            iou = 0.0
        else:
            iou = jaccard(state, self.original_state)
            # print("computed iou ", iou)
            # print("sum(agent) ", sum(state), "sum(original state)", sum(self.original_state), "computed iou ", iou)
        # print("agent \n", state.shape)
        # print("og \n", original_state.shape)
        # np.save("agent", state)
        # np.save("og", original_state)
        # assert isinstance(iou, )
        return iou

    def step(self, act, qvalues):
        """The environment's step function returns exactly what we need.
        Args:
          act: integer action index chosen by the agent.
          qvalues: q-values for the current state (stored in the history buffer).
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """
        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False
        backtrack = False
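        # action encoding (each step scaled by self.action_step):
        #   0 -> +z, 1 -> +y, 2 -> +x, 3 -> -x, 4 -> -y, 5 -> -z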

        # UP Z+ -----------------------------------------------------------
        if (act == 0):
            proposed_location = current_loc + np.array([0, 0, 1
                                                        ]) * self.action_step
        # FORWARD Y+ ---------------------------------------------------------
        elif (act == 1):
            proposed_location = current_loc + np.array([0, 1, 0
                                                        ]) * self.action_step
        # RIGHT X+ -----------------------------------------------------------
        elif (act == 2):
            proposed_location = current_loc + np.array([1, 0, 0
                                                        ]) * self.action_step
        # LEFT X- -----------------------------------------------------------
        elif act == 3:
            proposed_location = current_loc + np.array([-1, 0, 0
                                                        ]) * self.action_step
        # BACKWARD Y- ---------------------------------------------------------
        elif act == 4:
            proposed_location = current_loc + np.array([0, -1, 0
                                                        ]) * self.action_step
        # DOWN Z- -----------------------------------------------------------
        elif act == 5:
            proposed_location = current_loc + np.array([0, 0, -1
                                                        ]) * self.action_step
        else:
            raise ValueError

        # print("action ", act, "loc ", self._location, "proposed ", proposed_location, "diff ", proposed_location-self._location)

        if not self._is_in_bounds(proposed_location):  # went out of bounds
            # do not update current_loc
            go_out = True
        else:  # in bounds
            transposed = proposed_location.T
            # https://stackoverflow.com/a/25823710/4212158
            if np.any(
                    np.isclose(np.unique(self._agent_nodes, axis=0),
                               transposed).all(axis=1)):
                # print("backtracking detected ", transposed, "hist ", np.unique(self._agent_nodes, axis=0), np.isclose(np.unique(self._agent_nodes, axis=0), transposed).all(axis=1))
                # we backtracked
                backtrack = True
            else:
                # we are in bounds, AND we didn't back track. accept new location
                self._location = proposed_location
                # only update state, iou if we've changed location
                self._observation = self._observe()
                self.curr_IOU = self.calc_IOU()

        # punish -1 reward if the agent tries to go out
        #if (self.task != 'play'):  # TODO: why is this necessary?
        self.reward = self._calc_reward(
            go_out, backtrack
        )  # TODO I think reward needs to be calculated after increment cnt
        # update screen, reward ,location, terminal
        self._update_history()

        # terminate during training once the IoU exceeds the threshold
        if (self.task == 'train'):
            if self.curr_IOU >= 0.9:
                print("finishing episode, IOU = ", self.curr_IOU)
                self.terminal = True
                self.num_success.feed(1)
                self.display()

        # terminate if maximum number of steps is reached

        if self.cnt >= self.max_num_frames - 1:
            print("finishing episode, exceeded max_frames ",
                  self.max_num_frames, " IOU = ", self.curr_IOU)
            self.terminal = True
            # self.display()

        # update history buffer with new location and qvalues
        if (self.task != 'play'):
            self.curr_IOU = self.calc_IOU()

        # check if agent oscillates
        # if self._oscillate:
        # TODO: rewind history, recalculate IOU
        # self._location = self.get_best_node()  # TODO replace
        # self._observation = self._observe()
        # if (self.task != 'play'):
        # self.curr_IOU = self.calc_IOU()
        # multi-scale steps
        # if self.multiscale:
        #     if self.xscale > 1:
        #         self.xscale -= 1
        #         self.yscale -= 1
        #         self.zscale -= 1
        #         self.action_step = int(self.action_step / 3)
        #         self._clear_history()
        #     # terminate if scale is less than 1
        #     else:
        #         self.terminal = True
        #         if self.curr_IOU >= 0.9: self.num_success.feed(1)
        # else:
        # self.terminal = True
        # if self.curr_IOU >= 0.9: self.num_success.feed(1)

        # # render screen if viz is on  FIXME this displays at each step
        # with _ALE_LOCK:
        #     if self.viz:
        #         if isinstance(self.viz, float):
        #             self.display()

        self.current_episode_score.feed(self.reward)
        self.cnt += 1

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'IoU': self.curr_IOU,
            'filename': self.filename
        }

        return self._observe(), self.reward, self.terminal, info

    def get_best_node(self):
        ''' get the location with the best q-value from the last four locations
        stored in history

        TODO: make sure nodes don't have overlap
        '''
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._agent_nodes[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)
        # best_idx = best_qvalues.argmax()
        best_idx = best_qvalues.argmin()
        best_location = last_loc_history[best_idx]

        return best_location

    def _clear_history(self):
        """ clear history buffer with current state
        """
        # TODO: double check these np arrays work in place of the lists
        self._agent_nodes = np.zeros(
            (self._history_length,
             self.dims))  # [(0,) * self.dims] * self._history_length
        self._IOU_history = np.zeros((self._history_length, ))
        # list of q-value lists
        self._qvalues_history = np.zeros(
            (self._history_length,
             self.actions))  # [(0,) * self.actions] * self._history_length
        self.reward_history = np.zeros((self._history_length, ))

    def _update_history(self):
        """ update history buffer with current state
        """
        # update location history
        self._agent_nodes[self.cnt] = self._location
        # update jaccard index history
        self._IOU_history[self.cnt] = self.curr_IOU
        # and the reward
        self.reward_history[self.cnt] = self.reward
        # update q-value history
        self._qvalues_history[self.cnt] = self._qvalues

    def _observe(self):
        """
        crop image data around current location to update what network sees.
        update _observation_bounds

        :return: new state
        """
        # initialize screen with zeros - all background
        observation = np.zeros((self.observation_dims))

        # screen uses coordinate system relative to origin (0, 0, 0)
        screen_xmin, screen_ymin, screen_zmin = 0, 0, 0
        screen_xmax, screen_ymax, screen_zmax = self.observation_dims

        # extract boundary locations using coordinate system relative to "global" image
        # width, height, depth in terms of screen coord system
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
            zmin = self._location[2] - int(self.depth * self.zscale / 2) - 1
            zmax = self._location[2] + int(self.depth * self.zscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)
            zmin = self._location[2] - round(self.depth * self.zscale / 2)
            zmax = self._location[2] + round(self.depth * self.zscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))
        if zmin < 0:
            zmin = 0
            screen_zmin = screen_zmax - len(np.arange(zmin, zmax, self.zscale))
        if xmax > self._state_dims[0]:
            xmax = self._state_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._state_dims[1]:
            ymax = self._state_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))
        if zmax > self._state_dims[2]:
            zmax = self._state_dims[2]
            screen_zmax = screen_zmin + len(np.arange(zmin, zmax, self.zscale))

        # take image, mask it w agent trajectory
        agent_trajectory = self.trajectory_to_branch()
        agent_trajectory *= -1  # agent frames are negative
        # paste agent trajectory ontop of original state, but only when vals are not 0
        agent_mask = agent_trajectory.astype(bool)
        if agent_mask.any():  # agent trajectory not empty
            np.copyto(self._state,
                      agent_trajectory,
                      casting='no',
                      where=agent_mask)
            assert self._state is not None

        # crop image data to update what network sees
        # image coordinate system becomes screen coordinates
        # scale can be thought of as a stride
        # TODO: check if we need to keep "stride" from upstream
        observation[screen_xmin:screen_xmax, screen_ymin:screen_ymax,
                    screen_zmin:screen_zmax] = self._state[xmin:xmax,
                                                           ymin:ymax,
                                                           zmin:zmax]

        # update _observation_bounds limits from input image coordinates
        # this is what the network sees
        self._observation_bounds = ObservationBounds(xmin, xmax, ymin, ymax,
                                                     zmin, zmax)

        return observation

    def trajectory_to_branch(self):
        """take location history, generate connected branches using Vaa3d plugin
        FIXME this function is horribly inefficient
        """
        locations = self._agent_nodes
        # print("og state shape ", np.shape(self.original_state))
        # print("self obs dims ", self.observation_dims)
        # if the agent hasn't drawn any nodes, then the branch is empty. skip pipeline, return empty arr.
        if not locations.any():  # if all zeros, evals to False
            return np.zeros_like(self.original_state)
        else:
            # TODO: make tmp files not collide when doing multiprocessing
            output_swc = save_branch_as_swc(locations,
                                            "agent_trajectory",
                                            output_dir="tmp",
                                            overwrite=True)
            # TODO: be explicit about bounds to swc_to_tiff
            output_tiff = swc_to_TIFF("agent_trajectory",
                                      output_swc,
                                      output_dir="tmp",
                                      overwrite=True)
            output_npy = TIFF_to_npy("agent_trajectory",
                                     output_tiff,
                                     output_dir="tmp",
                                     overwrite=True)
            output_npy = np.load(output_npy).astype(float)
            tiff_max = np.amax(np.fabs(output_npy))
            if not np.isclose(tiff_max, 0):  # normalize if tiff is not blank
                output_npy = output_npy / tiff_max
            return output_npy

    def crop_brain(self, xmin, xmax, ymin, ymax, zmin, zmax):
        return self._state[xmin:xmax, ymin:ymax, zmin:zmax]

    def _calc_reward(self, go_out, backtrack):
        """ Calculate the new reward based on the increase in IoU
        TODO: if current location is same as past location, always penalize (discourage retracing)
        """
        if go_out or backtrack:
            reward = -1
        else:
            # TODO, double check if indexes are correct
            if self.cnt == 0:
                previous_IOU = 0.
            else:
                previous_IOU = self._IOU_history[self.cnt - 1]
            IOU_difference = self.curr_IOU - previous_IOU
            print("curr IOU = ", self.curr_IOU, "prev IOU = ",
                  self._IOU_history[self.cnt - 1], "diff = ", IOU_difference)
            assert isinstance(IOU_difference, float)
            if IOU_difference > 0:
                reward = 1
            else:
                reward = -1
        return reward

    def _is_in_bounds(self, coords):
        x, y, z = coords
        bounds = self._observation_bounds
        return ((bounds.xmin <= x <= bounds.xmax - 1
                 and bounds.ymin <= y <= bounds.ymax - 1
                 and bounds.zmin <= z <= bounds.zmax - 1))

    @property
    def _oscillate(self):
        """ Return True if the agent is stuck and oscillating
        """
        # TODO reimplement
        # TODO: erase last few frames if oscillation is detected
        counter = Counter(self._agent_nodes)
        freq = counter.most_common()

        # TODO: clean up this frequency check
        if freq[0][0] == (0, 0, 0):
            if (freq[1][1] > 3):
                return True
            else:
                return False
        elif (freq[0][1] > 3):
            return True
        return False

    def get_action_meanings(self):
        """ return array of integers for actions"""
        ACTION_MEANING = {
            1: "UP",  # MOVE Z+
            2: "FORWARD",  # MOVE Y+
            3: "RIGHT",  # MOVE X+
            4: "LEFT",  # MOVE X-
            5: "BACKWARD",  # MOVE Y-
            6: "DOWN",  # MOVE Z-
        }
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self):
        """this is called at every step"""
        current_point = self._location
        # img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        # scale_x = 1
        # scale_y = 1

        # print("nodes ", self._agent_nodes)
        # print("ious", self._IOU_history)
        print("reward history ", np.unique(self.reward_history))
        print("IOU history ", np.unique(self._IOU_history))
        plotter = Viewer(self.original_state,
                         zip(self._agent_nodes, self._IOU_history),
                         filepath=self.filename)
        #
        # #
        # # from viewer import SimpleImageViewer
        # # self.viewer = SimpleImageViewer(self._state,
        # #                                 scale_x=1,
        # #                                 scale_y=1,
        # #                                 filepath=self.filename)
        #     self.gif_buffer = []
        #
        #
        # # render and wait (viz) time between frames
        # self.viewer.render()
        # # time.sleep(self.viz)
        # # save gif
        if self.saveGif:
            # if self.saveGif:
            # TODO make this a method of viewer
            raise NotImplementedError
            # image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            # data = image_data.get_data('RGB', image_data.width * 3)
            # arr = np.array(bytearray(data)).astype('uint8')
            # arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            # im = Image.fromarray(arr)
            # self.gif_buffer.append(im)
            #
            # if not self.terminal:
            #     gifname = self.filename.split('.')[0] + '.gif'
            #     self.viewer.saveGif(gifname, arr=self.gif_buffer,
            #                         duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            # if self.cnt <= 1:
            #     if os.path.isdir(dirname):
            #         logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
            #         act = input("select action: d (delete) / q (quit): ").lower().strip()
            #         if act == 'd':
            #             shutil.rmtree(dirname, ignore_errors=True)
            #         else:
            #             raise OSError("Directory {} exits!".format(dirname))
            #     os.mkdir(dirname)

            vid_fpath = self.filename + '.mp4'
            # vid_fpath = dirname + '/' + self.filename + '.mp4'
            plotter.save_vid(vid_fpath, self.max_num_frames - 1)
            # plotter.show_agent()

        if self.viz:  # show progress
            # plotter.show()
            # actually, let's just save the files for later
            output_dir = os.path.abspath("saved_trajectories/")
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            # outfile_fpath = os.path.join(output_dir, input_fname + ".npy")
            #
            # # don't overwrite
            # if not os.path.isfile(outfile_fpath) or overwrite:
            #     desired_len = 16
            #     img_array = tiff2array.imread(input_fpath)
            #     # make all arrays the same shape
            #     # format: ((top, bottom), (left, right))
            #     shp = img_array.shape
            #     # print(shp, flush=True)
            #     if shp != (desired_len, desired_len, desired_len):
            #         try:
            #             img_array = np.pad(img_array, (
            #             (0, desired_len - shp[0]), (0, desired_len - shp[1]), (0, desired_len - shp[2])),
            #                                'constant')
            #         except ValueError:
            #             raise
            #             # print(shp, flush=True)  # don't wait for all threads to finish before printing
            #
            np.savez(os.path.join(output_dir, self.filename),
                     locations=self._agent_nodes,
                     original_state=self.original_state,
                     reward_history=self.reward_history)
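The calc_IOU method above relies on a jaccard helper that is not shown in this snippet. A minimal stand-in for two same-shaped binary volumes (an assumption about its behaviour, consistent with how it is called here) could look like:

import numpy as np

def jaccard(a, b):
    # Jaccard index / IoU of two binary masks: intersection size over union
    # size, defined as 0.0 when both masks are empty.
    a = np.asarray(a, dtype=bool).ravel()
    b = np.asarray(b, dtype=bool).ravel()
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum()) / float(union)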
Example #33
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(self,
                 directory=None,
                 viz=False,
                 task=False,
                 files_list=None,
                 screen_dims=(27, 27, 27),
                 history_length=20,
                 multiscale=True,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False,
                 data_type=None):
        """
        :param directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to a positive number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param location_history_length: consider loss of lives as end of
            episode (useful for training)
        :param max_num_frames: maximum number of frames per episode.
        """
        # ######################################################################
        # ## generate evaluation results from 19 different points
        # ## save results in csv file
        # self.csvfile = 'DuelDoubleDQN_multiscale_brain_mri_point_pc_ROI_45_45_45_midl2018.csv'
        # if not train:
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filename", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5,0.25,0.75]
        # y = [0.5,0.25,0.75]
        # z = [0.5,0.25,0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        # ######################################################################

        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale
        # type of data
        self.data_type = data_type
        # directory is the file used for logging evaluation
        self.directory = directory

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0, 0, 0)
        # add your data loader here
        self.set_dataLoader(files_list)

        # prepare file sampler
        self.filepath = None
        self.HITL_logger = []
        self._loc_history = None
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def set_dataLoader(self, files_list):
        if self.data_type == 'BrainMRI':
            self.data_loader = filesListBrainMRLandmark
        elif self.data_type == 'CardiacMRI':
            self.data_loader = filesListCardioLandmark
        elif self.data_type == 'FetalUS':
            self.data_loader = filesListFetalUSLandmark
        elif self.data_type == "HITL":
            self.data_loader = fileHITL

        if self.task == 'play':
            self.files = self.data_loader(files_list, returnLandmarks=False)
        else:
            self.files = self.data_loader(files_list, returnLandmarks=True)

        self.sampled_files = self.files.sample_circular()

    def HITL_episode_log(self):
        """ Method to save episode info for HITL """
        log = {
            'states': self._loc_history,
            'rewards': self._reward_history,
            'actions': self._act_history,
            'target': self._target_loc,
            'img_name': self.filename,
            'is_over':
            [False for i in range(len(self._loc_history) - 1)] + [True],
            'resolution': self._res_history,
        }
        self.HITL_logger.append(log)

    def HITL_set_location(self, location, res):
        """ Method to set the location in the image to that specified in the logs """
        self._location = location
        self.xscale = res
        self.yscale = res
        self.zscale = res

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart current episode
        """
        if self.task == 'browse' and self._loc_history:
            self.HITL_episode_log()

        self.terminal = False
        self.reward = 0
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self._loc_history = [(0, ) * self.dims] * self._history_length
        # list of q-value lists
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._clear_history()
        self.new_random_game()

    def new_random_game(self):
        """
        load image,
        set dimensions,
        randomize start point,
        init _screen, qvals,
        calc distance to goal
        """
        self.terminal = False
        self.viewer = None
        # ######################################################################
        # ## generate evaluation results from 19 different points
        # if self.count_points ==0:
        #     print('\n============== new game ===============\n')
        #     # save results
        #     if self.total_loc:
        #         with open(self.csvfile, 'a') as outcsv:
        #             fields= [self.filename, self.cur_dist]
        #             writer = csv.writer(outcsv)
        #             writer.writerow(map(lambda x: x, fields))
        #         self.total_loc = []
        #     # sample a new image
        #     self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        #     scale = next(self.start_points)
        #     self.count_points +=1
        # else:
        #     self.count_points += 1
        #     logger.info('count_points {}'.format(self.count_points))
        #     scale = next(self.start_points)
        #
        # x = int(scale[0] * self._image.dims[0])
        # y = int(scale[1] * self._image.dims[1])
        # z = int(scale[2] * self._image.dims[2])
        # logger.info('starting point {}-{}-{}'.format(x,y,z))
        # ######################################################################

        # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(
            self.sampled_files)
        self.filename = os.path.basename(self.filepath)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            # #cardiac
            # if self.data_type == 'CardiacMRI':
            #     self.action_step = 6
            #     self.xscale = 2
            #     self.yscale = 2
            #     self.zscale = 2
            # #brain or fetal
            # else:
            #     self.action_step = 9
            #     self.xscale = 3
            #     self.yscale = 3
            #     self.zscale = 3
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3
            self.zscale = 3

        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1
            self.zscale = 1
        # image volume size
        self._image_dims = self._image.dims

        #######################################################################
        ## select random starting point
        # add padding to avoid start right on the border of the image
        if (self.task == 'train'):
            skip_thickness = (int(self._image_dims[0] / 5),
                              int(self._image_dims[1] / 5),
                              int(self._image_dims[2] / 5))
        else:
            skip_thickness = (int(self._image_dims[0] / 4),
                              int(self._image_dims[1] / 4),
                              int(self._image_dims[2] / 4))

        x = self.rng.randint(0 + skip_thickness[0],
                             self._image_dims[0] - skip_thickness[0])
        y = self.rng.randint(0 + skip_thickness[1],
                             self._image_dims[1] - skip_thickness[1])
        z = self.rng.randint(0 + skip_thickness[2],
                             self._image_dims[2] - skip_thickness[2])
        #######################################################################

        self._location = (x, y, z)
        self._start_location = (x, y, z)
        self._qvalues = [
            0,
        ] * self.actions
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = 0
        else:
            self.cur_dist = self.calcDistance(self._location, self._target_loc,
                                              self.spacing)

    def calcDistance(self, points1, points2, spacing=(1, 1, 1)):
        """ calculate the distance between two points in mm"""
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)
        return np.linalg.norm(points1 - points2)

    def step(self, act, qvalues, viewer=None):
        """The environment's step function returns exactly what we need.
        Args:
          act: integer action index chosen by the agent.
          qvalues: q-values for the current state.
          viewer: optional viewer used for rendering.
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """
        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False
        self.viewer = viewer

        # UP Z+ -----------------------------------------------------------
        if (act == 0):
            next_location = (current_loc[0], current_loc[1],
                             round(current_loc[2] + self.action_step))
            if (next_location[2] >= self._image_dims[2]):
                # print(' trying to go out the image Z+ ',)
                next_location = current_loc
                go_out = True

        # FORWARD Y+ ---------------------------------------------------------
        if (act == 1):
            next_location = (current_loc[0],
                             round(current_loc[1] + self.action_step),
                             current_loc[2])
            if (next_location[1] >= self._image_dims[1]):
                # print(' trying to go out the image Y+ ',)
                next_location = current_loc
                go_out = True
        # RIGHT X+ -----------------------------------------------------------
        if (act == 2):
            next_location = (round(current_loc[0] + self.action_step),
                             current_loc[1], current_loc[2])
            if next_location[0] >= self._image_dims[0]:
                # print(' trying to go out the image X+ ',)
                next_location = current_loc
                go_out = True
        # LEFT X- -----------------------------------------------------------
        if act == 3:
            next_location = (round(current_loc[0] - self.action_step),
                             current_loc[1], current_loc[2])
            if next_location[0] <= 0:
                # print(' trying to go out the image X- ',)
                next_location = current_loc
                go_out = True
        # BACKWARD Y- ---------------------------------------------------------
        if act == 4:
            next_location = (current_loc[0],
                             round(current_loc[1] - self.action_step),
                             current_loc[2])
            if next_location[1] <= 0:
                # print(' trying to go out the image Y- ',)
                next_location = current_loc
                go_out = True
        # DOWN Z- -----------------------------------------------------------
        if act == 5:
            next_location = (current_loc[0], current_loc[1],
                             round(current_loc[2] - self.action_step))
            if next_location[2] <= 0:
                # print(' trying to go out the image Z- ',)
                next_location = current_loc
                go_out = True
        # ---------------------------------------------------------------------
        # ---------------------------------------------------------------------
        # punish -1 reward if the agent tries to go out
        if (self.task != 'play'):
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)
        # update screen, reward, location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        if (self.task == 'train'):
            if self.cur_dist <= 1:
                # print('Terminal Condition DISTANCE')
                self.terminal = True
                self.num_success.feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            # print('Terminal Condition NUMBER OF FRAMES')
            self.terminal = True

        # update history buffer with new location and qvalues
        if (self.task != 'play'):
            self.cur_dist = self.calcDistance(self._location, self._target_loc,
                                              self.spacing)
        self._update_history()

        # check if agent oscillates
        if self._oscillate:
            self._location = self.getBestLocation()
            self._screen = self._current_state()

            if self.task != 'play':
                self.cur_dist = self.calcDistance(self._location,
                                                  self._target_loc,
                                                  self.spacing)
            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.adjustMultiScale()
                # terminate if scale is less than 1
                else:
                    self.terminal = True
                    # print("TERMINAL OCCILATE")
                    if self.cur_dist <= 1:
                        self.num_success.feed(1)
            else:
                self.terminal = True
                # print("TERMINAL OCCILATE")
                if self.cur_dist <= 1:
                    self.num_success.feed(1)

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        self.current_episode_score.feed(self.reward)
        # print(self.reward)  # this is the reward received at every step

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'distError': distance_error,
            'filename': self.filename
        }

        if self.terminal:
            # store results when batch evaluation
            if self.directory:
                path = self.directory
                with open(path, 'a') as outcsv:
                    fields = [
                        info['filename'], info['score'], info['distError']
                    ]
                    writer = csv.writer(outcsv)
                    writer.writerow(fields)

        # #######################################################################
        # ## generate evaluation results from 19 different points
        # if self.terminal:
        #     logger.info(info)
        #     self.total_loc.append(self._location)
        #     if not(self.count_points == 19):
        #         self._restart_episode()
        #     else:
        #         mean_location = np.mean(self.total_loc,axis=0)
        #         logger.info('total_loc {} \n mean_location{}'.format(self.total_loc, mean_location))
        #         self.cur_dist = self.calcDistance(mean_location,
        #                                           self._target_loc,
        #                                           self.spacing)
        #         logger.info('final distance error {} \n'.format(self.cur_dist))
        #         self.count_points = 0
        # #######################################################################

        return self._current_state(), self.reward, self.terminal, info

    def stepManual(self, act, viewer):
        """ Version of above for browse mode allowing the user to navigate
            through an uploaded img
        """
        # self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False
        self.viewer = viewer
        self._act = act

        # -1 passed during init so skip updating current location
        if act == -1:
            pass
        else:
            # UP Z+ -----------------------------------------------------------
            if (act == 0):
                next_location = (current_loc[0], current_loc[1],
                                 round(current_loc[2] + self.action_step))
                if (next_location[2] >= self._image_dims[2]):
                    # print(' trying to go out the image Z+ ',)
                    next_location = current_loc
                    go_out = True

            # FORWARD Y+ ---------------------------------------------------------
            if (act == 1):
                next_location = (current_loc[0],
                                 round(current_loc[1] + self.action_step),
                                 current_loc[2])
                if (next_location[1] >= self._image_dims[1]):
                    # print(' trying to go out the image Y+ ',)
                    next_location = current_loc
                    go_out = True
            # RIGHT X+ -----------------------------------------------------------
            if (act == 2):
                next_location = (round(current_loc[0] + self.action_step),
                                 current_loc[1], current_loc[2])
                if next_location[0] >= self._image_dims[0]:
                    # print(' trying to go out the image X+ ',)
                    next_location = current_loc
                    go_out = True
            # LEFT X- -----------------------------------------------------------
            if act == 3:
                next_location = (round(current_loc[0] - self.action_step),
                                 current_loc[1], current_loc[2])
                if next_location[0] <= 0:
                    # print(' trying to go out the image X- ',)
                    next_location = current_loc
                    go_out = True
            # BACKWARD Y- ---------------------------------------------------------
            if act == 4:
                next_location = (current_loc[0],
                                 round(current_loc[1] - self.action_step),
                                 current_loc[2])
                if next_location[1] <= 0:
                    # print(' trying to go out the image Y- ',)
                    next_location = current_loc
                    go_out = True
            # DOWN Z- -----------------------------------------------------------
            if act == 5:
                next_location = (current_loc[0], current_loc[1],
                                 round(current_loc[2] - self.action_step))
                if next_location[2] <= 0:
                    # print(' trying to go out the image Z- ',)
                    next_location = current_loc
                    go_out = True

            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)

            self._location = next_location

        self._screen = self._current_state()
        self.cur_dist = self.calcDistance(self._location, self._target_loc,
                                          self.spacing)

        self._update_history()

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        return self._current_state()

    def getBestLocation(self):
        ''' Get the location with the best q-value from the last four
        locations stored in history.
        '''
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._loc_history[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)
        # best_idx = best_qvalues.argmax()
        best_idx = best_qvalues.argmin()
        best_location = last_loc_history[best_idx]

        return best_location

    def adjustMultiScale(self, higherRes=True):
        '''Adjusts the agent's step size'''
        if higherRes:
            self.xscale -= 1
            self.yscale -= 1
            self.zscale -= 1
            self.action_step = int(self.action_step / 3)
        else:
            self.xscale += 1
            self.yscale += 1
            self.zscale += 1
            self.action_step = int(self.action_step * 3)

        self._clear_history()

    def _clear_history(self):
        ''' clear history buffer with current state
        '''
        if self.task == 'browse':
            self._loc_history = []
            self._act_history = []
            self._reward_history = []
            self._res_history = []
        else:
            self._loc_history = [(0, ) * self.dims] * self._history_length
            self._qvalues_history = [(0, ) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state
        '''
        if self.task == 'browse':
            self._loc_history.append(self._location)
            self._act_history.append(self._act)
            self._res_history.append(self.xscale)
            self._reward_history.append(self.reward)
        else:
            # update location history
            self._loc_history[:-1] = self._loc_history[1:]
            self._loc_history[-1] = self._location
            # update q-value history
            self._qvalues_history[:-1] = self._qvalues_history[1:]
            self._qvalues_history[-1] = self._qvalues

    def _current_state(self):
        """
        Crop image data around the current location to update what the
        network sees, and update the bounding rectangle.

        :return: new state
        """
        # initialize screen with zeros - all background
        screen = np.zeros(self.screen_dims, dtype=self._image.data.dtype)

        # screen uses coordinate system relative to origin (0, 0, 0)
        screen_xmin, screen_ymin, screen_zmin = 0, 0, 0
        screen_xmax, screen_ymax, screen_zmax = self.screen_dims

        # extract boundary locations using coordinate system relative to "global" image
        # width, height, depth in terms of screen coord system
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
            zmin = self._location[2] - int(self.depth * self.zscale / 2) - 1
            zmax = self._location[2] + int(self.depth * self.zscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)
            zmin = self._location[2] - round(self.depth * self.zscale / 2)
            zmax = self._location[2] + round(self.depth * self.zscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))
        if zmin < 0:
            zmin = 0
            screen_zmin = screen_zmax - len(np.arange(zmin, zmax, self.zscale))
        if xmax > self._image_dims[0]:
            xmax = self._image_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._image_dims[1]:
            ymax = self._image_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))
        if zmax > self._image_dims[2]:
            zmax = self._image_dims[2]
            screen_zmax = screen_zmin + len(np.arange(zmin, zmax, self.zscale))

        # crop image data to update what network sees
        # image coordinate system becomes screen coordinates
        # scale can be thought of as a stride
        screen[screen_xmin:screen_xmax, screen_ymin:screen_ymax,
               screen_zmin:screen_zmax] = self._image.data[
                   xmin:xmax:self.xscale, ymin:ymax:self.yscale,
                   zmin:zmax:self.zscale]

        # update rectangle limits from input image coordinates
        # this is what the network sees
        self.rectangle = Rectangle(xmin, xmax, ymin, ymax, zmin, zmax)

        return screen

    def get_plane_z(self, z=0):
        im = self._image.data[:, :, z]
        if self.data_type in ['BrainMRI', 'CardiacMRI']:
            im = np.rot90(im, 1)  # Rotate 90 degrees ccw
        return im

    def get_plane_x(self, x=0):
        im = self._image.data[x, :, :]
        im = np.rot90(im, 1)
        return im

    def get_plane_y(self, y=0):
        im = self._image.data[:, y, :]
        im = np.rot90(im, 1)
        return im

    def _calc_reward(self, current_loc, next_loc):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc,
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc, self.spacing)
        return curr_dist - next_dist

    @property
    def _oscillate(self):
        """ Return True if the agent is stuck and oscillating
        """
        counter = Counter(self._loc_history)
        freq = counter.most_common()

        if freq[0][0] == (0, 0, 0):
            if (freq[1][1] > 3):
                return True
            else:
                return False
        elif (freq[0][1] > 3):
            return True

    def get_action_meanings(self):
        """ return array of integers for actions"""
        ACTION_MEANING = {
            1: "UP",  # MOVE Z+
            2: "FORWARD",  # MOVE Y+
            3: "RIGHT",  # MOVE X+
            4: "LEFT",  # MOVE X-
            5: "BACKWARD",  # MOVE Y-
            6: "DOWN",  # MOVE Z-
        }
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        # get dimensions
        current_point = self._location
        target_point = self._target_loc
        # get image and convert it to pyglet

        plane = self.get_plane_z(current_point[2])
        plane_x = self.get_plane_x(current_point[0])
        plane_y = self.get_plane_y(current_point[1])

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 2
        scale_y = 2
        scale_z = 2
        current_point = (current_point[0] * scale_x,
                         current_point[1] * scale_y,
                         current_point[2] * scale_z)
        if target_point is not None:
            target_point = (target_point[0] * scale_x,
                            target_point[1] * scale_y,
                            target_point[2] * scale_z)
        self.rectangle = (self.rectangle[0] * scale_x, self.rectangle[1] *
                          scale_x, self.rectangle[2] * scale_y,
                          self.rectangle[3] * scale_y, self.rectangle[4] *
                          scale_z, self.rectangle[5] * scale_z)
        img = cv2.resize(
            plane,
            (int(scale_x * plane.shape[1]), int(scale_y * plane.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_x = cv2.resize(
            plane_x,
            (int(scale_x * plane_x.shape[1]), int(scale_y * plane_x.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_y = cv2.resize(
            plane_y,
            (int(scale_y * plane_y.shape[1]), int(scale_y * plane_y.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to RGB
        img_x = cv2.cvtColor(img_x, cv2.COLOR_GRAY2RGB)  # convert to RGB
        img_y = cv2.cvtColor(img_y, cv2.COLOR_GRAY2RGB)  # convert to RGB

        ########################################################################
        # PyQt GUI Code Section

        # Section of code to get initial value to be stored in a pickle object
        # (Uncomment if you wish to modify default_data.pickle)
        # viewer_param = {
        #     "arrs": (img, img_x, img_y),
        #     "filepath": self.filename
        # }
        # with open("default_data.pickle", "wb") as f:
        #     viewer_param = pickle.dump(viewer_param, f)
        #     exit()

        # Sleep until resumed (automatic mode only; skipped in browse mode)
        if self.task != 'browse':
            while self.viewer.right_widget.automatic_mode.thread.pause:
                time.sleep(0.5)

                # Check whether thread should be killed (pause)
                if self.viewer.right_widget.automatic_mode.thread.terminate:
                    exit()

            # Check whether thread should be killed (general)
            if self.viewer.right_widget.automatic_mode.thread.terminate:
                exit()

        # Need to emit signal here (to draw images)
        self.viewer.widget.agent_signal.emit({
            "arrs": (img, img_x, img_y),
            "agent_loc": current_point,
            "target": target_point,
            "error": self.cur_dist,
            "scale": self.xscale,
            "rect": self.rectangle,
            "task": self.task,
            "is_terminal": self.terminal,
            "cnt": self.cnt
        })

        if self.task != 'browse':
            # Control agent speed
            if self.viewer.right_widget.automatic_mode.thread.speed == WorkerThread.FAST:
                time.sleep(0)
            elif self.viewer.right_widget.automatic_mode.thread.speed == WorkerThread.MEDIUM:
                time.sleep(0.5)
            else:
                time.sleep(1.5)

            ########################################################################

            # save gif
            if self.saveGif:
                image_data = pyglet.image.get_buffer_manager(
                ).get_color_buffer().get_image_data()
                data = image_data.get_data('RGB', image_data.width * 3)
                arr = np.array(bytearray(data)).astype('uint8')
                arr = np.flip(
                    np.reshape(arr, (image_data.height, image_data.width, -1)),
                    0)
                im = Image.fromarray(arr)
                self.gif_buffer.append(im)

                if not self.terminal:
                    gifname = self.filename.split('.')[0] + '.gif'
                    self.viewer.saveGif(gifname,
                                        arr=self.gif_buffer,
                                        duration=self.viz)
            if self.saveVideo:
                dirname = 'tmp_video'
                if self.cnt <= 1:
                    if os.path.isdir(dirname):
                        logger.warn(
                            """Log directory {} exists! Use 'd' to delete it. """
                            .format(dirname))
                        act = input("select action: d (delete) / q (quit): "
                                    ).lower().strip()
                        if act == 'd':
                            shutil.rmtree(dirname, ignore_errors=True)
                        else:
                            raise OSError(
                                "Directory {} exits!".format(dirname))
                    os.mkdir(dirname)

                frame = dirname + '/' + '%04d' % self.cnt + '.png'
                pyglet.image.get_buffer_manager().get_color_buffer().save(
                    frame)
                if self.terminal:
                    resolution = str(3 * self.viewer.img_width) + 'x' + str(
                        3 * self.viewer.img_height)
                    save_cmd = [
                        'ffmpeg', '-f', 'image2', '-framerate', '30',
                        '-pattern_type', 'sequence', '-start_number', '0',
                        '-r', '6', '-i', dirname + '/%04d.png', '-s',
                        resolution, '-vcodec', 'libx264', '-b:v', '2567k',
                        self.filename + '.mp4'
                    ]
                    subprocess.check_output(save_cmd)
                    shutil.rmtree(dirname, ignore_errors=True)
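
# --- Added illustration (not part of the original example) -------------------
# A minimal, self-contained sketch of how a driver loop consumes the Gym-style
# (observation, reward, done, info) tuple documented in step() above. ToyEnv,
# its target index, and the trivial policy are made-up stand-ins, not the
# medical environment defined here.
import numpy as np


class ToyEnv:
    """Toy 1-D task: the agent walks toward a fixed target index."""

    def __init__(self, target=7, max_steps=50):
        self.target = target
        self.max_steps = max_steps

    def reset(self):
        self.pos, self.steps = 0, 0
        return np.array([self.pos], dtype=np.float32)

    def step(self, act):
        # act 0 moves right, act 1 moves left; reward is the decrease in distance
        old_dist = abs(self.target - self.pos)
        self.pos += 1 if act == 0 else -1
        self.steps += 1
        new_dist = abs(self.target - self.pos)
        done = new_dist == 0 or self.steps >= self.max_steps
        info = {'distError': new_dist, 'gameOver': done}
        return np.array([self.pos], dtype=np.float32), old_dist - new_dist, done, info


env = ToyEnv()
obs, done, total_reward = env.reset(), False, 0.0
while not done:
    act = 0  # an "always move right" policy stands in for argmax over Q-values
    obs, reward, done, info = env.step(act)
    total_reward += reward
print('episode reward:', total_reward, 'final distance:', info['distError'])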
Example #34
0
    def reset_stat(self):
        """ Reset all statistics counters"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()
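
# --- Added illustration (not part of the original example) -------------------
# A minimal, self-contained stand-in for the feed/average/max/count/sum/reset
# interface of the tensorpack StatCounter used above. This is a sketch of the
# bookkeeping pattern, not tensorpack's actual implementation.
class MiniStatCounter:
    def __init__(self):
        self.reset()

    def reset(self):
        self._values = []

    def feed(self, value):
        self._values.append(value)

    @property
    def count(self):
        return len(self._values)

    @property
    def sum(self):
        return sum(self._values)

    @property
    def average(self):
        return sum(self._values) / len(self._values) if self._values else 0.0

    @property
    def max(self):
        return max(self._values) if self._values else 0.0


scores = MiniStatCounter()
for s in (1.0, 3.0, 2.0):
    scores.feed(s)
print(scores.count, scores.average, scores.max)  # 3 2.0 3.0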
Example #35
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """
    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size,
                 init_memory_size,
                 init_exploration,
                 update_frequency,
                 encoding_file='../AutoEncoder/encoding.npy'):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = num_actions
        self.encoding = np.load(encoding_file)
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        # init_cards = np.arange(36)
        # self.player.prepare_manual(init_cards)
        self.player.prepare()
        # self._current_ob = self.player.get_state_prob()
        self._current_ob = self.get_state()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_state(self):
        def cards_char2embedding(cards_char):
            test = (action_space_onehot60 == Card.char2onehot60(cards_char))
            test = np.all(test, axis=1)
            target = np.where(test)[0]
            return self.encoding[target[0]]

        s = self.player.get_state_prob()
        s = np.concatenate(
            [Card.val2onehot60(self.player.get_curr_handcards()), s])
        last_two_cards_char = self.player.get_last_two_cards()
        last_two_cards_char = [to_char(c) for c in last_two_cards_char]
        return np.concatenate([
            s,
            cards_char2embedding(last_two_cards_char[0]),
            cards_char2embedding(last_two_cards_char[1])
        ])

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(
            self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
        else:
            mask = get_mask(to_char(self.player.get_curr_handcards()),
                            action_space,
                            to_char(self.player.get_last_outcards()))
            q_values = self.predictor(old_s[None, ...])[0][0]
            q_values[mask == 0] = np.nan
            act = np.nanargmax(q_values)
            assert act < self.num_actions
        reward, isOver, _ = self.player.step_manual(to_value(
            action_space[act]))

        # step for AI
        while not isOver and self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
            _, reward, _ = self.player.step_auto()
            isOver = (reward != 0)
        if ROLE_ID_TO_TRAIN == 2:
            reward = -reward
        self._current_game_score.feed(reward)

        if isOver:
            # print('lord wins' if reward > 0 else 'farmer wins')
            self._player_scores.feed(self._current_game_score.sum)
            # print(self._current_game_score.sum)
            while True:
                self.player.reset()
                # init_cards = np.arange(36)
                # self.player.prepare_manual(init_cards)
                self.player.prepare()
                early_stop = False
                while self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
                    _, reward, _ = self.player.step_auto()
                    isOver = (reward != 0)
                    if isOver:
                        print('prestart ends too early! now resetting env')
                        early_stop = True
                        break
                if early_stop:
                    continue
                self._current_ob = self.get_state()
                break
            self._current_game_score.reset()
        self._current_ob = self.get_state()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def debug(self, cnt=100000):
        with get_tqdm(total=cnt) as pbar:
            for i in range(cnt):
                self.mem.append(
                    Experience(
                        np.zeros(
                            [self.num_actions[0], self.num_actions[1], 256]),
                        0, 0))
                # self._current_ob, self._action_space = self.get_state_and_action_spaces(None)
                pbar.update()

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(self._populate_job_queue.maxsize *
                                   self.update_frequency,
                                   len(self.mem) - 1,
                                   size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        while self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
            self.player.step_auto()
            self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            mean, max = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
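
# --- Added illustration (not part of the original example) -------------------
# A minimal, self-contained sketch of the producer/consumer coupling behind
# get_data() above: the consumer yields a batch, then posts a token on a
# bounded queue; a background thread drains tokens and appends fresh items to
# the shared memory (as _populate_exp does). All names here are illustrative.
import queue
import random
import threading
import time

memory = [random.random() for _ in range(32)]   # stands in for ReplayMemory
tokens = queue.Queue(maxsize=5)                 # like _populate_job_queue


def populate_forever():
    while True:
        tokens.get()                            # wait for a "please refill" token
        memory.append(random.random())          # like _populate_exp()


threading.Thread(target=populate_forever, daemon=True).start()


def batches(batch_size=4):
    while True:
        yield random.sample(memory, batch_size)  # like sampling batch_exp
        tokens.put(1)                            # ask the producer for more data


for i, batch in enumerate(batches()):
    if i >= 3:
        break
    time.sleep(0.01)
print('memory size after a few batches:', len(memory))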
Example #36
0
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """
    def __init__(self, predictor_io_names, player, state_shape, batch_size,
                 memory_size, init_memory_size, init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(
            self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    # quickly fill the memory for debug
    def _fake_init_memory(self):
        from copy import deepcopy
        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < 5:
                self._populate_exp()
                pbar.update()
            while len(self.mem) < self.init_memory_size:
                self.mem.append(deepcopy(self.mem._hist[0]))
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration or (len(self.mem) <=
                                                   self.history_len):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
            history = self.mem.recent_state()
            history.append(old_s)
            history = np.stack(history, axis=2)

            # assume batched network
            q_values = self.predictor([[history]])[0][0]  # this is the bottleneck
            act = np.argmax(q_values)
        self._current_ob, reward, isOver, info = self.player.step(act)
        if isOver:
            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
                self._player_scores.feed(info['score'])
            self.player.reset()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def _debug_sample(self, sample):
        import cv2

        def view_state(comb_state):
            state = comb_state[:, :, :-1]
            next_state = comb_state[:, :, 1:]
            r = np.concatenate(
                [state[:, :, k] for k in range(self.history_len)], axis=1)
            r2 = np.concatenate(
                [next_state[:, :, k] for k in range(self.history_len)], axis=1)
            r = np.concatenate([r, r2], axis=0)
            cv2.imshow("state", r)
            cv2.waitKey()

        print("Act: ", sample[2], " reward:", sample[1], " isOver: ",
              sample[3])
        if sample[1] or sample[3]:
            view_state(sample[0])

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(self._populate_job_queue.maxsize *
                                   self.update_frequency,
                                   len(self.mem) - self.history_len - 1,
                                   size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            mean, max = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
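
# --- Added illustration (not part of the original example) -------------------
# A minimal, self-contained sketch of the epsilon-greedy population step in
# _populate_exp() above, with a plain deque standing in for ReplayMemory and a
# random vector standing in for the Q predictor. Everything here is a toy.
from collections import deque, namedtuple

import numpy as np

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
rng = np.random.RandomState(0)
mem = deque(maxlen=1000)
num_actions, exploration = 4, 0.1
state = rng.rand(4).astype('float32')

for _ in range(100):
    old_s = state
    if rng.rand() <= exploration:
        act = rng.randint(num_actions)          # explore: random action
    else:
        q_values = rng.rand(num_actions)        # stands in for predictor(old_s)
        act = int(np.argmax(q_values))          # exploit: greedy action
    # a toy transition: random next state, reward, and a 5% episode-end chance
    state = rng.rand(4).astype('float32')
    reward, is_over = float(rng.randn()), bool(rng.rand() < 0.05)
    mem.append(Experience(old_s, act, reward, is_over))

print('replay memory size:', len(mem))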
Example #37
0
class AgentBase(GymEnv):
    def __init__(self,
                 agentIdent,
                 is_train=False,
                 auto_restart=True,
                 **kwargs):
        # super(AgentBase, self).__init__(name='torcs')
        self.auto_restart = auto_restart
        self._isTrain = is_train
        self._agentIdent = agentIdent
        self._kwargs = kwargs
        self._init()

    def _init(self):
        logger.info("[{}]: agent init, isTrain={}".format(
            self._agentIdent, self._isTrain))
        self._episodeCount = -1
        from tensorpack.utils.utils import get_rng
        self._rng = get_rng(self)
        from tensorpack.utils.stats import StatCounter
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self._memorySaver = None
        save_dir = self._kwargs.pop('save_dir', None)
        if save_dir is not None:
            self._memorySaver = MemorySaver(
                save_dir,
                self._kwargs.pop('max_save_item', 3),
                self._kwargs.pop('min_save_score', None),
            )
        self.restart_episode()

    def restart_episode(self):
        self.rwd_counter.reset()
        self.__ob = self.reset()

    def finish_episode(self):
        score = self.rwd_counter.sum
        self.stats['score'].append(score)
        logger.info(
            "episode finished, rewards = {:.3f}, episode = {}, steps = {}".
            format(score, self._episodeCount, self._episodeSteps))

    def current_state(self):
        return self.__ob

    def reset(self):
        self._episodeCount += 1
        ret = self._reset()
        self._episodeRewards = 0.
        self._episodeSteps = 0
        if self._memorySaver:
            self._memorySaver.createMemory(self._episodeCount)
        logger.info("restart, episode={}".format(self._episodeCount))
        return ret

    @abc.abstractmethod
    def _reset(self):
        pass

    def action(self, pred):
        ob, act, r, isOver, info = self._step(pred)
        self.rwd_counter.feed(r)
        if self._memorySaver:
            self._memorySaver.addCurrent(ob, act, r, isOver)
        self.__ob = ob
        self._episodeSteps += 1
        self._episodeRewards += r
        if isOver:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return act, r, isOver

    @abc.abstractmethod
    def _step(self, action):
        raise NotImplementedError

    def get_action_space(self):
        raise NotImplementedError
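
# --- Added illustration (not part of the original example) -------------------
# A minimal, self-contained sketch of the template-method pattern AgentBase
# relies on: the base class drives reward bookkeeping and episode restarts,
# while a subclass supplies _reset() and _step(). This toy does not inherit
# from GymEnv and is not part of the original code.
import abc


class MiniAgentBase(abc.ABC):
    def __init__(self):
        self._episode_reward = 0.0
        self._ob = self._reset()

    @abc.abstractmethod
    def _reset(self):
        """Return the initial observation."""

    @abc.abstractmethod
    def _step(self, action):
        """Return (observation, reward, is_over)."""

    def action(self, act):
        ob, r, is_over = self._step(act)
        self._episode_reward += r
        self._ob = ob
        if is_over:
            print('episode finished, rewards = {:.3f}'.format(self._episode_reward))
            self._episode_reward = 0.0
            self._ob = self._reset()
        return r, is_over


class CountdownAgent(MiniAgentBase):
    """Toy task: every action decrements a counter; the episode ends at zero."""

    def _reset(self):
        self.remaining = 3
        return self.remaining

    def _step(self, action):
        self.remaining -= 1
        return self.remaining, 1.0, self.remaining == 0


agent = CountdownAgent()
for _ in range(3):
    agent.action(0)  # the third call prints "episode finished, rewards = 3.000"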