Example #1
import gym
import numpy as np
from gym import spaces


class PreproWrapper(gym.Wrapper):
    """
    Wrapper for Pong to apply preprocessing
    Stores the state in the variable self.obs
    """
    def __init__(self, env, prepro, shape, overwrite_render=True, high=255):
        """
        Args:
            env: (gym env)
            prepro: (function) to apply to a state for preprocessing
            shape: (list) shape of obs after prepro
            overwrite_render: (bool) if True, render is overwritten to visualize the effect of prepro
            high: (int) max value of state after prepro
        """
        super(PreproWrapper, self).__init__(env)
        self.overwrite_render = overwrite_render
        self.viewer = None
        self.prepro = prepro
        self.observation_space = spaces.Box(low=0,
                                            high=high,
                                            shape=shape,
                                            dtype=np.uint8)
        self.high = high

    def step(self, action):
        """
        Overrides the environment's step function to apply preprocessing
        """
        obs, reward, done, info = self.env.step(action)
        self.obs = self.prepro(obs)
        return self.obs, reward, done, info

    def reset(self):
        self.obs = self.prepro(self.env.reset())
        return self.obs

    def _render(self, mode='human', close=False):
        """
        Overrides the _render function to visualize preprocessing
        """

        if self.overwrite_render:
            if close:
                if self.viewer is not None:
                    self.viewer.close()
                    self.viewer = None
                return
            img = self.obs
            if mode == 'rgb_array':
                return img
            elif mode == 'human':
                from gym.envs.classic_control import rendering
                if self.viewer is None:
                    self.viewer = rendering.SimpleImageViewer()
                self.viewer.imshow(img)

        else:
            super(PreproWrapper, self)._render(mode, close)
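
A minimal usage sketch for the wrapper above, continuing from the imports in the snippet (the greyscale function and the Pong env id are illustrative assumptions, not part of the original snippet):

def greyscale(state):
    """Hypothetical prepro: crop, grey-scale, and downsample a (210, 160, 3) Pong frame."""
    state = np.mean(state[35:195], axis=2, keepdims=True)  # crop rows, average channels
    return state[::2, ::2].astype(np.uint8)                # downsample to (80, 80, 1)

env = PreproWrapper(gym.make("Pong-v0"), prepro=greyscale, shape=(80, 80, 1), high=255)
obs = env.reset()  # preprocessed first frame, also stored in env.obs
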
Example #2

    def _render(self, mode='human', close=False):
        """
        Overrides the _render function to visualize preprocessing
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return
        img = self.obs
        if mode == 'rgb_array':
            return img
        elif mode == 'human':
            from gym.envs.classic_control import rendering
            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
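
A sketch of how this render hook is typically exercised, assuming env is an instance of the wrapper above (mode names follow the old gym convention):

frame = env._render(mode='rgb_array')  # returns the stored preprocessed observation
env._render(mode='human')              # opens or refreshes a SimpleImageViewer window
env._render(close=True)                # closes the viewer and releases the window
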
Example #3
    def display(self):
        """this is called at every step"""
        current_point = self._location
        # img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        # scale_x = 1
        # scale_y = 1

        # print("nodes ", self._agent_nodes)
        # print("ious", self._IOU_history)
        print("reward history ", np.unique(self.reward_history))
        print("IOU history ", np.unique(self._IOU_history))
        plotter = Viewer(self.original_state,
                         zip(self._agent_nodes, self._IOU_history),
                         filepath=self.filename)
        # from viewer import SimpleImageViewer
        # self.viewer = SimpleImageViewer(self._state,
        #                                 scale_x=1,
        #                                 scale_y=1,
        #                                 filepath=self.filename)
        # self.gif_buffer = []
        # render and wait (viz) time between frames
        # self.viewer.render()
        # time.sleep(self.viz)
        # save gif
        if self.saveGif:
            # TODO make this a method of viewer
            raise NotImplementedError
            # image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            # data = image_data.get_data('RGB', image_data.width * 3)
            # arr = np.array(bytearray(data)).astype('uint8')
            # arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            # im = Image.fromarray(arr)
            # self.gif_buffer.append(im)
            #
            # if not self.terminal:
            #     gifname = self.filename.split('.')[0] + '.gif'
            #     self.viewer.saveGif(gifname, arr=self.gif_buffer,
            #                         duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            # if self.cnt <= 1:
            #     if os.path.isdir(dirname):
            #         logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
            #         act = input("select action: d (delete) / q (quit): ").lower().strip()
            #         if act == 'd':
            #             shutil.rmtree(dirname, ignore_errors=True)
            #         else:
            #             raise OSError("Directory {} exits!".format(dirname))
            #     os.mkdir(dirname)

            vid_fpath = self.filename + '.mp4'
            # vid_fpath = dirname + '/' + self.filename + '.mp4'
            plotter.save_vid(vid_fpath, self.max_num_frames - 1)
            # plotter.show_agent()

        if self.viz:  # show progress
            # plotter.show()
            # actually, let's just save the files for later
            output_dir = os.path.abspath("saved_trajectories/")
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            # outfile_fpath = os.path.join(output_dir, input_fname + ".npy")
            #
            # # don't overwrite
            # if not os.path.isfile(outfile_fpath) or overwrite:
            #     desired_len = 16
            #     img_array = tiff2array.imread(input_fpath)
            #     # make all arrays the same shape
            #     # format: ((top, bottom), (left, right))
            #     shp = img_array.shape
            #     # print(shp, flush=True)
            #     if shp != (desired_len, desired_len, desired_len):
            #         try:
            #             img_array = np.pad(img_array, (
            #             (0, desired_len - shp[0]), (0, desired_len - shp[1]), (0, desired_len - shp[2])),
            #                                'constant')
            #         except ValueError:
            #             raise
            #             # print(shp, flush=True)  # don't wait for all threads to finish before printing
            #
            np.savez(os.path.join(output_dir, self.filename),
                     locations=self._agent_nodes,
                     original_state=self.original_state,
                     reward_history=self.reward_history)
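
Since the branch above persists trajectories with np.savez, loading them back is symmetric (a sketch; the file name is illustrative):

import numpy as np

data = np.load("saved_trajectories/some_episode.npz")  # hypothetical file
locations = data["locations"]            # agent nodes visited during the episode
original_state = data["original_state"]  # the input image the agent acted on
reward_history = data["reward_history"]  # per-step rewards
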
Example #4
    def display(self, return_rgb_array=False):
        # pass
        # --------------------------------------------------------------------
        ## planes seen by the agent
        # # get image and convert it to pyglet
        # plane = self._plane.grid[:,:,round(self.depth/2)] # z-plane
        # # concatenate groundtruth image
        # gt_plane = self._groundTruth_plane.grid[:,:,round(self.depth/2)]
        # --------------------------------------------------------------------
        ## whole plane
        # image_size = (int(min(self._image_dims)),)*3
        image_size = self._image_dims
        current_plane = Plane(*getPlane(self.sitk_image,
                                        self._origin3d_point,
                                        self._plane.params,
                                        image_size,
                                        spacing=[1, 1, 1]))

        # get image and convert it to pyglet
        plane = current_plane.grid[:, :, int(image_size[2] / 2)]  # z-plane
        # concatenate groundtruth image
        gt_plane = self.groundTruth_plane_iso.grid[:, :,
                                                   int(image_size[2] / 2)]
        # --------------------------------------------------------------------
        # concatenate two planes side by side
        plane = np.concatenate((plane, gt_plane), axis=1)
        #
        img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 5
        scale_y = 5
        # img = cv2.resize(img,
        #                  (int(scale_x*img.shape[1]),int(scale_y*img.shape[0])),
        #                  interpolation=cv2.INTER_LINEAR)
        # open a viewer if there is none yet
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)
        self.viewer.display_text('Current Plane',
                                 color=(0, 0, 204, 255),
                                 x=int(0.7 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)
        self.viewer.display_text('Ground Truth',
                                 color=(0, 0, 204, 255),
                                 x=int(4.3 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)

        # display info
        dist_color_flag = False
        if len(self._dist_history) > 1:
            dist_color_flag = self.cur_dist < self._dist_history[-2]

        color_dist = (0, 204, 0, 255) if dist_color_flag else (204, 0, 0, 255)
        text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=color_dist,
                                 x=int(3 * img.shape[1] / 8),
                                 y=5 * scale_y)

        dist_color_flag = False
        if len(self._dist_history_params) > 1:
            dist_color_flag = (self.cur_dist_params <
                               self._dist_history_params[-2])

        # color_dist = (0,255,0,255) if dist_color_flag else (255,0,0,255)
        # text = 'Params Error ' + str(round(self.cur_dist_params,3))
        # self.viewer.display_text(text, color=color_dist,
        # x=int(6*img.shape[1]/8), y=5*scale_y)
        text = 'Spacing ' + str(round(self.spacing[0], 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=(204, 204, 0, 255),
                                 x=int(6 * img.shape[1] / 8),
                                 y=5 * scale_y)

        color_reward = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
        text = 'Reward ' + "%+d" % round(self.reward, 3)
        self.viewer.display_text(text,
                                 color=color_reward,
                                 x=2 * scale_x,
                                 y=5 * scale_y)

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr).convert('P')
            self.gif_buffer.append(im)

            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.savegif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if (self.cnt <= 1):
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '3', '-i', dirname + '/%04d.png', '-s', '1280x720',
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
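
The gif branch above delegates writing to the viewer's savegif; a minimal Pillow-only equivalent, assuming gif_buffer holds the PIL images built above (the function name is illustrative):

def save_gif(gifname, gif_buffer, duration):
    # write the first frame and append the rest; Pillow expects duration in ms per frame
    gif_buffer[0].save(gifname,
                       save_all=True,
                       append_images=gif_buffer[1:],
                       duration=int(duration * 1000),
                       loop=0)
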
Example #5
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(self,
                 directory=None,
                 files_list=None,
                 viz=False,
                 train=False,
                 screen_dims=(27, 27, 27),
                 spacing=(1, 1, 1),
                 nullop_start=30,
                 history_length=30,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False):
        """
        :param directory: directory that contains the data files
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param history_length: length of the history buffer used to detect
            oscillations (useful for training)
        """
        super(MedicalPlayer, self).__init__()

        self.reset_stat()
        self._supervised = False
        self._init_action_angle_step = 8
        self._init_action_dist_step = 4

        #######################################################################
        ## save results in csv file
        self.csvfile = 'dummy.csv'
        if not train:
            with open(self.csvfile, 'w') as outcsv:
                fields = ["filename", "dist_error", "angle_error"]
                writer = csv.writer(outcsv)
                writer.writerow(fields)
        #######################################################################

        # read files from directory - add your data loader here
        self.files = filesListCardioMRPlane(directory, files_list)

        # prepare file sampler
        self.sampled_files = self.files.sample_circular()
        self.filepath = None
        # maximum number of frames (steps) per episode
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.train = train
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self._plane_size = screen_dims
        self.dims = len(self.screen_dims)
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims
        # plane sampling spacings
        self.init_spacing = np.array(spacing)
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(8)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # circular buffer to store plane parameters history [4,history_length]
        self._plane_history = deque(maxlen=self._history_length)
        self._bestq_history = deque(maxlen=self._history_length)
        self._dist_history = deque(maxlen=self._history_length)
        self._dist_history_params = deque(maxlen=self._history_length)
        self._dist_supervised_history = deque(maxlen=self._history_length)
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._qvalues = [0] * self.actions

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        self._restart_episode()

    # -------------------------------------------------------------------------

    def _reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart the current episode
        """
        self.terminal = False
        self.cnt = 0  # counter to limit the number of steps per episode
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset score stat counter
        self._plane_history.clear()
        self._bestq_history.clear()
        self._dist_history.clear()
        self._dist_history_params.clear()
        self._dist_supervised_history.clear()
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length

        self.new_random_game()

    # -------------------------------------------------------------------------

    def new_random_game(self):
        # print('\n============== new game ===============\n')
        self.terminal = False
        self.viewer = None

        # sample a new image
        (self.sitk_image, self.sitk_image_2ch, self.sitk_image_4ch,
         self.landmarks, self.filepath) = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)
        # image volume size
        self._image_dims = self.sitk_image.GetSize()
        self.action_angle_step = copy.deepcopy(self._init_action_angle_step)
        self.action_dist_step = copy.deepcopy(self._init_action_dist_step)
        self.spacing = self.init_spacing.copy()

        # find center point of the initial plane
        if self.train:
            # sample randomly ±10% around the center point
            skip_thickness = (int(self._image_dims[0] / 2.5),
                              int(self._image_dims[1] / 2.5),
                              int(self._image_dims[2] / 2.5))
            x = self.rng.randint(0 + skip_thickness[0],
                                 self._image_dims[0] - skip_thickness[0])
            y = self.rng.randint(0 + skip_thickness[1],
                                 self._image_dims[1] - skip_thickness[1])
            z = self.rng.randint(0 + skip_thickness[2],
                                 self._image_dims[2] - skip_thickness[2])
        else:
            # during testing, start with a plane at the center point
            x, y, z = (self._image_dims[0] / 2, self._image_dims[1] / 2,
                       self._image_dims[2] / 2)
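        # e.g. for _image_dims = (100, 100, 100): skip_thickness = (40, 40, 40),
        # so training start points are drawn from [40, 60) per axis (center +-10%),
        # while testing always starts at the exact center (50, 50, 50)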
        self._origin3d_point = (int(x), int(y), int(z))
        # Get ground truth plane
        # logger.info('filename {} '.format(self.filename))
        self._groundTruth_plane = Plane(
            *getGroundTruthPlane(self.sitk_image,
                                 self.sitk_image_4ch,
                                 self._origin3d_point,
                                 self._plane_size,
                                 spacing=self.spacing))
        # get an isotropic 1mm groundtruth plane
        # image_size = (int(min(self._image_dims)),)*3
        image_size = self._image_dims
        self.groundTruth_plane_iso = Plane(
            *getGroundTruthPlane(self.sitk_image, self.sitk_image_4ch,
                                 self._origin3d_point, image_size, [1, 1, 1]))

        self.landmarks_gt, _ = zip(*[
            projectPointOnPlane(point, self._groundTruth_plane.norm,
                                self._groundTruth_plane.origin)
            for point in self.landmarks
        ])
        # logger.info('groundTruth {}'.format(self._groundTruth_plane.params))
        # Get initial plane and set current plane the same
        self._plane = self._init_plane = Plane(
            *getInitialPlane(self.sitk_image, self._plane_size,
                             self._origin3d_point, self.spacing))
        _, dist = zip(*[
            projectPointOnPlane(point, self._plane.norm, self._plane.origin)
            for point in self.landmarks_gt
        ])
        # calculate current distance between initial and ground truth planes
        # self.cur_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
        #                                      self._plane.points)
        self.cur_dist = np.mean(np.abs(dist))
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params,
            self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)

        self._screen = self._current_state()

    # -------------------------------------------------------------------------

    def step(self, act, qvalues):
        """The environment's step function returns exactly what we need.
        Args:
          act: (int) index of the chosen action
          qvalues: (array) q-values of the current state, kept for history checks
        Returns:
          observation (object):
            an environment-specific object representing your observation of the environment. For example, pixel data from a camera, joint angles and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies between environments, but the goal is always to increase your total reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all) tasks are divided up into well-defined episodes, and done being True indicates the episode has terminated. (For example, perhaps the pole tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be useful for learning (for example, it might contain the raw probabilities behind the environment's last state change). However, official evaluations of your agent are not allowed to use this for learning.
        """
        self.terminal = False
        self._qvalues = qvalues
        # get current plane params
        current_plane_params = np.copy(self._plane.params)
        next_plane_params = current_plane_params.copy()
        # ---------------------------------------------------------------------
        # theta x+ (param a)
        if (act == 0): next_plane_params[0] += self.action_angle_step
        # theta y+ (param b)
        if (act == 1): next_plane_params[1] += self.action_angle_step
        # theta z+ (param c)
        if (act == 2): next_plane_params[2] += self.action_angle_step
        # dist d+
        if (act == 3): next_plane_params[3] += self.action_dist_step

        # theta x- (param a)
        if (act == 4): next_plane_params[0] -= self.action_angle_step
        # theta y- (param b)
        if (act == 5): next_plane_params[1] -= self.action_angle_step
        # theta z- (param c)
        if (act == 6): next_plane_params[2] -= self.action_angle_step
        # dist d-
        if (act == 7): next_plane_params[3] -= self.action_dist_step
        # ---------------------------------------------------------------------
        # self.reward = self._calc_reward_points(self._plane.points,
        #                                        next_plane.points)
        self.reward = self._calc_reward_params(current_plane_params,
                                               next_plane_params)
        # threshold reward between -1 and 1
        self.reward = np.sign(self.reward)
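        # e.g. moving 0.7 closer to the target gives prev_dist - next_dist = 0.7
        # and a clipped reward of +1; moving away gives -1; no change gives 0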
        go_out = False

        if self._supervised and self.train:
            ## supervised
            dist_queue = deque(maxlen=self.actions)
            plane_params_queue = deque(maxlen=self.actions)
            # enumerate all candidate actions as (param index, signed step):
            # theta x/y/z +- (params a, b, c) and dist +- (param d)
            candidate_steps = ((0, +self.action_angle_step),
                               (1, +self.action_angle_step),
                               (2, +self.action_angle_step),
                               (3, +self.action_dist_step),
                               (0, -self.action_angle_step),
                               (1, -self.action_angle_step),
                               (2, -self.action_angle_step),
                               (3, -self.action_dist_step))
            for param_idx, step_size in candidate_steps:
                next_plane_params = np.copy(self._plane.params)
                next_plane_params[param_idx] += step_size
                plane_params_queue.append(next_plane_params)
                # dist_queue.append(calcMeanDistTwoPlanes(self._groundTruth_plane.points, plane_params_queue[-1].points))
                dist_queue.append(
                    calcDistTwoParams(self._groundTruth_plane.params,
                                      plane_params_queue[-1],
                                      scale_angle=self.action_angle_step,
                                      scale_dist=self.action_dist_step))

            # -----------------------------------------------------------------
            # get best plane based on lowest distance to the target
            next_plane_idx = np.argmin(dist_queue)
            next_plane = Plane(*getPlane(self.sitk_image,
                                         self._origin3d_point,
                                         plane_params_queue[next_plane_idx],
                                         self._plane_size,
                                         spacing=self.spacing))
            self._dist_supervised_history.append(np.min(dist_queue))

        else:
            ## unsupervised or testing
            # get the new plane using new params result from taking the action
            next_plane = Plane(*getPlane(self.sitk_image,
                                         self._origin3d_point,
                                         next_plane_params,
                                         self._plane_size,
                                         spacing=self.spacing))
            # -----------------------------------------------------------------
            # check if the screen is not full of zeros (background)
            go_out = checkBackgroundRatio(next_plane,
                                          min_pixel_val=0.5,
                                          ratio=0.8)
            # also check if the agent went out (sampling from outside the volume)
            # by checking if the new origin lies outside the image volume
            if not go_out:
                go_out = checkOriginLocation(self.sitk_image,
                                             next_plane.origin)
            # also check if plane parameters got very high
            if not go_out:
                go_out = checkParamsBound(next_plane.params,
                                          self._groundTruth_plane.params)
            # give the lowest reward if the agent tries to go out, and keep the same plane
            if go_out:
                self.reward = -1  # lowest possible reward
                next_plane = copy.deepcopy(self._plane)
                if self.train: self.terminal = True  # end episode and restart

        # ---------------------------------------------------------------------
        # update current plane
        self._plane = copy.deepcopy(next_plane)
        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames: self.terminal = True
        # check oscillation and reduce action step or terminate if minimum
        if self._oscillate:
            if self.train and self._supervised:
                self._plane = self.getBestPlaneTrain()
            else:
                self._plane = self.getBestPlane()

            # find distance metrics
            # self.cur_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points, self._plane.points)
            _, dist = zip(*[
                projectPointOnPlane(point, self._plane.norm,
                                    self._plane.origin)
                for point in self.landmarks_gt
            ])
            self.cur_dist = np.max(np.abs(dist))
            self._update_hierarchical()
            self._clear_history()
            # terminate if distance steps are less than 1
            if self.action_dist_step < 1: self.terminal = True

        # ---------------------------------------------------------------------
        # find distance error
        _, dist = zip(*[
            projectPointOnPlane(point, self._plane.norm, self._plane.origin)
            for point in self.landmarks_gt
        ])
        self.cur_dist = np.mean(np.abs(dist))
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params,
            self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)
        self.current_episode_score.feed(self.reward)
        self._update_history()  # store results in memory

        # terminate if the distance between params is low during training
        if self.train and (self.cur_dist_params <= 1):
            self.terminal = True
            self.num_success.feed(1)

        # ---------------------------------------------------------------------

        # # supervised reward (for debugging)
        # reward_supervised = self._calc_reward_params(current_plane_params,
        #                                        self._plane.params)
        # # threshold reward between -1 and 1
        # self.reward = np.sign(np.around(reward_supervised,decimals=1))

        # ---------------------------------------------------------------------

        # render screen if viz is on
        if self.viz:
            if isinstance(self.viz, float):
                self.display()

        A = normalizeUnitVector(self._groundTruth_plane.norm)
        B = normalizeUnitVector(self._plane.norm)
        angle_between_norms = np.rad2deg(np.arccos(A.dot(B)))

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'distError': self.cur_dist,
            'distAngle': angle_between_norms,
            'filename': self.filename
        }

        if self.terminal:
            with open(self.csvfile, 'a') as outcsv:
                fields = [self.filename, self.cur_dist, angle_between_norms]
                writer = csv.writer(outcsv)
                writer.writerow(fields)

        return self._current_state(), self.reward, self.terminal, info

    # -------------------------------------------------------------------------

    def _update_hierarchical(self):
        self.action_angle_step = int(self.action_angle_step / 2)
        self.action_dist_step = self.action_dist_step - 1
        # self.spacing -= 1
        if (self.spacing[0] > 1): self.spacing -= 1
        self._groundTruth_plane = Plane(
            *getGroundTruthPlane(self.sitk_image,
                                 self.sitk_image_4ch,
                                 self._origin3d_point,
                                 self._plane_size,
                                 spacing=self.spacing))

    def getBestPlane(self):
        ''' Get the plane with the best q-value from the last locations
        stored in history.
        '''
        best_idx = np.argmin(self._bestq_history)
        # best_idx = np.argmax(self._bestq_history)
        return self._plane_history[best_idx]

    def getBestPlaneTrain(self):
        ''' Get the plane with the lowest supervised distance from the last
        locations stored in history.
        '''
        best_idx = np.argmin(self._dist_supervised_history)
        # best_idx = np.argmax(self._bestq_history)
        return self._plane_history[best_idx]

    def _current_state(self):
        """
        :returns: a gray-scale (h, w, d) float image
        """
        return self._plane.grid_smooth

    def _clear_history(self):
        self._plane_history.clear()
        self._bestq_history.clear()
        self._dist_history.clear()
        self._dist_history_params.clear()
        self._dist_supervised_history.clear()
        # self._loc_history = [(0,) * self.dims] * self._history_length
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state
        '''
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        loc = self._plane.params  # use the plane parameters (a, b, c, d) as the location
        # logger.info('loc {}'.format(loc))
        self._loc_history[-1] = (np.around(loc[0], decimals=2),
                                 np.around(loc[1], decimals=2),
                                 np.around(loc[2], decimals=2),
                                 np.around(loc[3], decimals=2))
        # update distance history
        self._dist_history.append(self.cur_dist)
        self._dist_history_params.append(self.cur_dist_params)
        # update params history
        self._plane_history.append(self._plane)
        self._bestq_history.append(np.max(self._qvalues))
        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues

    def _calc_reward_points(self, prev_points, next_points):
        ''' Calculate the new reward based on the euclidean distance to the target plane
        '''
        prev_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          prev_points)
        next_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          next_points)

        return prev_dist - next_dist

    def _calc_reward_params(self, prev_params, next_params):
        ''' Calculate the new reward based on the euclidean distance to the target plane
        '''
        # logger.info('prev_params {}'.format(np.around(prev_params,2)))
        # logger.info('next_params {}'.format(np.around(next_params,2)))
        prev_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            prev_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)
        next_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            next_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)

        return prev_dist - next_dist

    @property
    def _oscillate(self):
        ''' Return True if the agent is stuck and oscillating
        '''
        counter = Counter(self._loc_history)
        freq = counter.most_common()
        # return False if history is empty (beginning of the game)
        if len(freq) < 2: return False
        # check the frequency of the most common location
        if freq[0][0] == (0, 0, 0, 0):
            if (freq[1][1] > 2):
                # logger.info('oscillating {}'.format(self._loc_history))
                return True
            else:
                return False
        elif (freq[0][1] > 2):
            # logger.info('oscillating {}'.format(self._loc_history))
            return True
        return False
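    # A worked example of the oscillation test above (hypothetical history,
    # history_length = 30, three visits to the same rounded location):
    #   _loc_history = [(0, 0, 0, 0)] * 27 + [(1.0, 2.0, 3.0, 4.0)] * 3
    #   Counter(...).most_common() -> [((0, 0, 0, 0), 27), ((1.0, 2.0, 3.0, 4.0), 3)]
    #   freq[0][0] == (0, 0, 0, 0) and freq[1][1] == 3 > 2  ->  oscillating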

    def get_action_meanings(self):
        ''' return a list of action-meaning strings '''
        ACTION_MEANING = {
            0: "inc_x",  # increment +1 the norm angle in x-direction
            1: "inc_y",  # increment +1 the norm angle in y-direction
            2: "inc_z",  # increment +1 the norm angle in z-direction
            3: "inc_d",  # increment +1 the norm distance d to origin
            4: "dec_x",  # decrement -1 the norm angle in x-direction
            5: "dec_y",  # decrement -1 the norm angle in y-direction
            6: "dec_z",  # decrement -1 the norm angle in z-direction
            7: "dec_d",  # decrement -1 the norm distance d to origin
        }
        return [ACTION_MEANING[i] for i in range(self.actions)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        # pass
        # --------------------------------------------------------------------
        ## planes seen by the agent
        # # get image and convert it to pyglet
        # plane = self._plane.grid[:,:,round(self.depth/2)] # z-plane
        # # concatenate groundtruth image
        # gt_plane = self._groundTruth_plane.grid[:,:,round(self.depth/2)]
        # --------------------------------------------------------------------
        ## whole plane
        # image_size = (int(min(self._image_dims)),)*3
        image_size = self._image_dims
        current_plane = Plane(*getPlane(self.sitk_image,
                                        self._origin3d_point,
                                        self._plane.params,
                                        image_size,
                                        spacing=[1, 1, 1]))

        # get image and convert it to pyglet
        plane = current_plane.grid[:, :, int(image_size[2] / 2)]  # z-plane
        # concatenate groundtruth image
        gt_plane = self.groundTruth_plane_iso.grid[:, :,
                                                   int(image_size[2] / 2)]
        # --------------------------------------------------------------------
        # concatenate two planes side by side
        plane = np.concatenate((plane, gt_plane), axis=1)
        #
        img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 5
        scale_y = 5
        # img = cv2.resize(img,
        #                  (int(scale_x*img.shape[1]),int(scale_y*img.shape[0])),
        #                  interpolation=cv2.INTER_LINEAR)
        # open a viewer if there is none yet
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)
        self.viewer.display_text('Current Plane',
                                 color=(0, 0, 204, 255),
                                 x=int(0.7 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)
        self.viewer.display_text('Ground Truth',
                                 color=(0, 0, 204, 255),
                                 x=int(4.3 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)

        # display info
        dist_color_flag = False
        if len(self._dist_history) > 1:
            dist_color_flag = self.cur_dist < self._dist_history[-2]

        color_dist = (0, 204, 0, 255) if dist_color_flag else (204, 0, 0, 255)
        text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=color_dist,
                                 x=int(3 * img.shape[1] / 8),
                                 y=5 * scale_y)

        dist_color_flag = False
        if len(self._dist_history_params) > 1:
            dist_color_flag = (self.cur_dist_params <
                               self._dist_history_params[-2])

        # color_dist = (0,255,0,255) if dist_color_flag else (255,0,0,255)
        # text = 'Params Error ' + str(round(self.cur_dist_params,3))
        # self.viewer.display_text(text, color=color_dist,
        # x=int(6*img.shape[1]/8), y=5*scale_y)
        text = 'Spacing ' + str(round(self.spacing[0], 3)) + 'mm'
        self.viewer.display_text(text,
                                 color=(204, 204, 0, 255),
                                 x=int(6 * img.shape[1] / 8),
                                 y=5 * scale_y)

        color_reward = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
        text = 'Reward ' + "%+d" % round(self.reward, 3)
        self.viewer.display_text(text,
                                 color=color_reward,
                                 x=2 * scale_x,
                                 y=5 * scale_y)

        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr).convert('P')
            self.gif_buffer.append(im)

            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.savegif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if (self.cnt <= 1):
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '3', '-i', dirname + '/%04d.png', '-s', '1280x720',
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
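
A hedged interaction-loop sketch for the environment above (constructor arguments and the q-value source are illustrative; step takes both the chosen action and the q-values it logs for the oscillation checks):

import numpy as np

env = MedicalPlayer(directory="data/", files_list="files.txt",
                    screen_dims=(27, 27, 27), max_num_frames=200)
obs = env._reset()  # older gym dispatches env.reset() to _reset()
terminal = False
while not terminal:
    qvalues = np.random.rand(env.actions)  # stand-in for a DQN forward pass
    act = int(np.argmax(qvalues))          # greedy action
    obs, reward, terminal, info = env.step(act, qvalues)
print(info['distError'], info['distAngle'])
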
Example #6
    def display(self, return_rgb_array=False):
        # pass
        # get dimensions
        current_point = self._location
        target_point = self._target_loc

        # get image and convert it to pyglet
        plane = self.get_plane(current_point[2])  # z-plane

        # plane = np.squeeze(self._current_state()[:,:,13])
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 1
        scale_y = 1
        img = cv2.resize(plane,
                         (int(scale_x*plane.shape[1]), int(scale_y*plane.shape[0])),
                         interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # open a viewer if there is none yet
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)

        # draw current point
        self.viewer.draw_circle(radius=scale_x * 1,
                                pos_x=scale_x * current_point[0],
                                pos_y=scale_y * current_point[1],
                                color=(0.0, 0.0, 1.0, 1.0))

        # draw a box around the agent - what the network sees ROI
        self.viewer.draw_rect(scale_x*self.rectangle.xmin, scale_y*self.rectangle.ymin,
                              scale_x*self.rectangle.xmax, scale_y*self.rectangle.ymax)

        self.viewer.display_text('Agent ', color=(204, 204, 0, 255),
                                 x=self.rectangle.xmin - 15,
                                 y=self.rectangle.ymin)
        # display info
        text = 'Spacing ' + str(self.xscale)
        self.viewer.display_text(text, color=(204, 204, 0, 255),
                                 x=10, y=self._image_dims[1]-80)

        # ---------------------------------------------------------------------
        if (self.task != 'play'):
            # draw a transparent circle around target point with variable radius
            # based on the difference z-direction
            diff_z = scale_x * abs(current_point[2]-target_point[2])

            self.viewer.draw_circle(radius=diff_z,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 0.2))
            # draw target point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 1.0))
            # display info
            color = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
            text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
            self.viewer.display_text(text, color=color, x=10, y=20)

        # ---------------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()

        # time.sleep(self.viz)
        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)

            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)

            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.saveGif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
                    act = input("select action: d (delete) / q (quit): ").lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                resolution = str(3 * self.viewer.img_width) + 'x' + str(3 * self.viewer.img_height)
                save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                            '-pattern_type', 'sequence', '-start_number', '0', '-r',
                            '6', '-i', dirname + '/%04d.png', '-s', resolution,
                            '-vcodec', 'libx264', '-b:v', '2567k', self.filename + '.mp4']
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
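
The saveVideo branches in the last three examples share the same frame-sequence-to-mp4 step; distilled into a standalone helper it looks like this (a sketch built only from the arguments visible above; the function name is illustrative):

import shutil
import subprocess

def frames_to_mp4(dirname, out_fpath, resolution='1280x720', fps='6'):
    """Encode dirname/0000.png, 0001.png, ... into an H.264 mp4, then clean up."""
    save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                '-pattern_type', 'sequence', '-start_number', '0',
                '-r', fps, '-i', dirname + '/%04d.png',
                '-s', resolution, '-vcodec', 'libx264', '-b:v', '2567k',
                out_fpath]
    subprocess.check_output(save_cmd)
    shutil.rmtree(dirname, ignore_errors=True)
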
Example #7
class MedicalPlayer(gym.Env):
    """
        Class that provides 3D medical image environment.
        This is just an implementation of the classic "agent-environment loop".
        Each time-step, the agent chooses an action, and the environment returns
        an observation and a reward.
    """

    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 screen_dims=(27, 27), history_length=20, multiscale=True,
                 max_num_frames=0, saveGif=False, saveVideo=False):
        """
        :param directory: directory that contains the data files

        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames

        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d, w) - defaults (27, 27)

        :param history_length: length of the history buffer used to detect
            oscillations (useful for training)

        :param max_num_frames: maximum number of frames per episode.
        """

        # ######################################################################
        # ## generate evaluation results from 19 different points
        # ## save results in csv file
        # self.csvfile = 'DuelDoubleDQN_multiscale_brain_mri_point_pc_ROI_45_45_45_midl2018.csv'
        # if not train:
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filenames", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5, 0.25, 0.75]
        # y = [0.5, 0.25, 0.75]
        # z = [0.5, 0.25, 0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        # ######################################################################

        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit the number of steps per episode
        self.cnt = 0

        # maximum number of frames (steps) per episode
        self.max_num_frames = max_num_frames

        # stores information: terminal, score, distError
        self.info = None

        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo

        # task flag (train / eval / play)
        self.task = task

        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)

        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0

            if isinstance(viz, int):  # check if viz is an int
                viz = float(viz)
            self.viz = viz

            if self.viz and isinstance(self.viz, float):  # check if viz is a float
                self.viewer = None
                self.gif_buffer = []

        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()

        # get action space and minimal action set
        self.action_space = spaces.Discrete(4)  # change number actions here
        self.actions = self.action_space.n

        # every environment needs an observation space configured
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0)

        # add your data loader here
        # 'play' mode loads files without landmarks; other tasks load them
        if self.task == 'play':
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                      returnLandmarks=False)
            print("self.files:", self.files)
        else:
            self.files = filesListLVCardiacMRLandmark(files_list,
                                                      returnLandmarks=True)


        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()

        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
            restart the current episode
        """
        self.terminal = False
        self.reward = 0
        self.cnt = 0 # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self._loc_history = [(0,) * self.dims] * self._history_length

        # list of q-value lists
        self._qvalues_history = [(0,) * self.actions] * self._history_length
        self.new_random_game()

    def new_random_game(self):
        """
            1.load image,
            2.set dimensions,
            3.randomize start point,
            4.init _screen, qvals,
            5.calc distance to goal
        """
        self.terminal = False
        self.viewer = None

        # ######################################################################
        # ## generate evaluation results from 19 different points
        # if self.count_points ==0:
        #     print('\n============== new game ===============\n')
        #     # save results
        #     if self.total_loc:
        #         with open(self.csvfile, 'a') as outcsv:
        #             fields= [self.filenames, self.cur_dist]
        #             writer = csv.writer(outcsv)
        #             writer.writerow(map(lambda x: x, fields))
        #         self.total_loc = []
        #     # sample a new image
        #     self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        #     scale = next(self.start_points)
        #     self.count_points +=1
        # else:
        #     self.count_points += 1
        #     logger.info('count_points {}'.format(self.count_points))
        #     scale = next(self.start_points)
        #
        # x = int(scale[0] * self._image.dims[0])
        # y = int(scale[1] * self._image.dims[1])
        # z = int(scale[2] * self._image.dims[2])
        # logger.info('starting point {}-{}-{}'.format(x,y,z))
        # ######################################################################

        # # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            ## brain
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3

            ## cardiac
            # self.action_step = 6
            # self.xscale = 2
            # self.yscale = 2
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1

        # image volume size
        self._image_dims = self._image.dims
        self._image.data = np.array(self._image)

        print("image_dim:", self._image_dims)

        #######################################################################
        # select random starting point
        # add padding to avoid start right on the border of the image
        if self.task == 'train':
            skip_thickness = (int(self._image_dims[0] / 5),
                              int(self._image_dims[1] / 5))
        else:
            skip_thickness = (int(self._image_dims[0] / 4),
                              int(self._image_dims[1] / 4))

        x = self.rng.randint(0 + skip_thickness[0],
                             self._image_dims[0] - skip_thickness[0])

        y = self.rng.randint(0 + skip_thickness[1],
                             self._image_dims[1] - skip_thickness[1])
        #######################################################################

        # locations are stored as (x, y)
        self._location = (x, y)
        self._start_location = (x, y)

        self._qvalues = [0, ] * self.actions
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = 0
        else:
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)

    def calcDistance(self, points1, points2, spacing=(1, 1)):
        """
            calculate the distance between two points in mm
        """
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)

        return np.linalg.norm(points1 - points2)
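
    # e.g. with spacing=(1.25, 1.25) mm per pixel, points (0, 0) and (3, 4)
    # are np.linalg.norm([3.75, 5.0]) = 6.25 mm apart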

    # this is still based on gym -- worth working out how gym handles it?
    def step(self, act, qvalues):
        """
        The environment's step function returns exactly what we need.

        Args:
          act: the action chosen by the agent
          qvalues: q-values for the current state, stored for oscillation checks
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.

          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.

          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)

          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """

        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False

        # these are still discrete moves, not truly continuous actions
        # note: locations are 2D (x, y) in this class
        # FORWARD Y+ ---------------------------------------------------------
        if act == 1:
            next_location = (current_loc[0],
                             round(current_loc[1] + self.action_step))
            if next_location[1] >= self._image_dims[1]:
                # print(' trying to go out the image Y+ ',)
                next_location = current_loc
                go_out = True

        # RIGHT X+ -----------------------------------------------------------
        if act == 2:
            next_location = (round(current_loc[0] + self.action_step),
                             current_loc[1])
            if next_location[0] >= self._image_dims[0]:
                # print(' trying to go out the image X+ ',)
                next_location = current_loc
                go_out = True

        # LEFT X- ------------------------------------------------------------
        if act == 3:
            next_location = (round(current_loc[0] - self.action_step),
                             current_loc[1])
            if next_location[0] <= 0:
                # print(' trying to go out the image X- ',)
                next_location = current_loc
                go_out = True

        # BACKWARD Y- --------------------------------------------------------
        if act == 4:
            next_location = (current_loc[0],
                             round(current_loc[1] - self.action_step))
            if next_location[1] <= 0:
                # print(' trying to go out the image Y- ',)
                next_location = current_loc
                go_out = True

        # ---------------------------------------------------------------------
        # ---------------------------------------------------------------------
        # punish -1 reward if the agent tries to go out
        if self.task != 'play':
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)

        # update screen, reward ,location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        if (self.task == 'train'):
            if self.cur_dist <= 1:
                self.terminal = True
                self.num_success.feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames: self.terminal = True

        # update history buffer with new location and qvalues
        if (self.task != 'play'):
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)
        self._update_history()

        # check if the agent oscillates
        if self._oscillate:
            self._location = self.getBestLocation()
            self._screen = self._current_state()

            if (self.task != 'play'):
                self.cur_dist = self.calcDistance(self._location,
                                                  self._target_loc,
                                                  self.spacing)
            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.xscale -= 1
                    self.yscale -= 1
                    self.action_step = int(self.action_step / 3)
                    self._clear_history()

                # terminate if scale is less than 1
                else:
                    self.terminal = True
                    if self.cur_dist <= 1: self.num_success.feed(1)
            else:
                self.terminal = True
                if self.cur_dist <= 1: self.num_success.feed(1)

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        self.current_episode_score.feed(self.reward)

        info = {'score': self.current_episode_score.sum, 'gameOver': self.terminal,
                'distError': distance_error, 'filenames': self.filename}

        # #######################################################################
        # ## generate evaluation results from 19 different points
        # if self.terminal:
        #     logger.info(info)
        #     self.total_loc.append(self._location)
        #     if not(self.count_points == 19):
        #         self._restart_episode()
        #     else:
        #         mean_location = np.mean(self.total_loc,axis=0)
        #         logger.info('total_loc {} \n mean_location{}'.format(self.total_loc, mean_location))
        #         self.cur_dist = self.calcDistance(mean_location,
        #                                           self._target_loc,
        #                                           self.spacing)
        #         logger.info('final distance error {} \n'.format(self.cur_dist))
        #         self.count_points = 0
        # #######################################################################

        return self._current_state(), self.reward, self.terminal, info

    def getBestLocation(self):
        '''
            get the location with the best q-value from the last four
            locations stored in history
        '''
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._loc_history[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)

        # best_idx = best_qvalues.argmax()
        best_idx = best_qvalues.argmin()
        best_location = last_loc_history[best_idx]

        return best_location

    def _clear_history(self):
        '''
            clear history buffer with current state
        '''
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state
        '''
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        self._loc_history[-1] = self._location

        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues
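
    # the slice assignment shifts the buffer left by one and writes the
    # newest entry at the end, e.g. [a, b, c] with new location d -> [b, c, d]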

    def _current_state(self):
        """
        crop image data around current location to update what network sees.
        update rectangle

        :return: new state
        """
        # initialize screen with zeros - all background
        screen = np.zeros((self.screen_dims)).astype('float32')

        # screen uses coordinate system relative to origin (0, 0, 0)
        screen_xmin, screen_ymin = 0, 0
        screen_xmax, screen_ymax = self.screen_dims

        print("image_data:", self._image)

        # extract boundary locations using coordinate system relative to "global" image
        # width, height, depth in terms of screen coord system
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))

        if xmax > self._image_dims[0]:
            xmax = self._image_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._image_dims[1]:
            ymax = self._image_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))

        # crop image data to update what network sees
        # image coordinate system becomes screen coordinates
        # scale can be thought of as a stride
        screen[screen_xmin:screen_xmax, screen_ymin:screen_ymax] = self._image.data[
                                                                            xmin:xmax:self.xscale,
                                                                            ymin:ymax:self.yscale]

        # update rectangle limits from input image coordinates
        # this is what the network sees
        self.rectangle = Rectangle(xmin, xmax,
                                   ymin, ymax)

        return screen
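
    # crop arithmetic, e.g.: with width=27, xscale=3 (odd branch) and
    # location x=60, the crop spans x in [19, 100) sampled every 3 voxels,
    # i.e. len(np.arange(19, 100, 3)) == 27 columns, matching screen_dims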

    def get_plane(self):
        return self._image.data[:, :]

    def _calc_reward(self, current_loc, next_loc):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc,
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc,
                                      self.spacing)
        return curr_dist - next_dist
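
    # e.g. stepping from 10.0 mm to 7.5 mm from the target gives reward 2.5;
    # stepping further away gives the same magnitude with a negative sign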

    @property
    def _oscillate(self):
        """
            Return True if the agent is stuck and oscillating
        """
        counter = Counter(self._loc_history)
        freq = counter.most_common()

        if freq[0][0] == (0, 0):
            if freq[1][1] > 3:
                return True
            else:
                return False
        elif freq[0][1] > 3:
            return True
        return False
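
    # note: (0, 0) entries are the placeholders the history buffer starts
    # with, so they are skipped when ranking visit counts; e.g. if the recent
    # history revisits (42, 17) more than three times, most_common() ranks it
    # first with a count above 3 and _oscillate returns True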

    def get_action_meanings(self):
        """ return array of integers for actions"""
        ACTION_MEANING = {
            1: "FORWARD",  # MOVE Y+
            2: "RIGHT",  # MOVE X+
            3: "LEFT",  # MOVE X-
            4: "BACKWARD",  # MOVE Y-
        }
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
            return screen dimensions
        """
        return (self.width, self.height)

    def lives(self):
        return None

    def reset_stat(self):
        """
            Reset all statistics counter
        """
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        # pass
        # get dimensions
        current_point = self._location
        target_point = self._target_loc

        # get image and convert it to pyglet
        plane = self.get_plane()  # full 2D plane (locations here have no z)

        # plane = np.squeeze(self._current_state()[:,:,13])
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 1
        scale_y = 1
        img = cv2.resize(plane,
                         (int(scale_x*plane.shape[1]), int(scale_y*plane.shape[0])),
                         interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to RGB
        # skip if there is a viewer open
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []

        # display image
        self.viewer.draw_image(img)

        # draw current point
        self.viewer.draw_circle(radius=scale_x * 1,
                                pos_x=scale_x * current_point[0],
                                pos_y=scale_y * current_point[1],
                                color=(0.0, 0.0, 1.0, 1.0))

        # draw a box around the agent - what the network sees ROI
        self.viewer.draw_rect(scale_x*self.rectangle.xmin, scale_y*self.rectangle.ymin,
                              scale_x*self.rectangle.xmax, scale_y*self.rectangle.ymax)

        self.viewer.display_text('Agent ', color=(204, 204, 0, 255),
                                 x=self.rectangle.xmin - 15,
                                 y=self.rectangle.ymin)
        # display info
        text = 'Spacing ' + str(self.xscale)
        self.viewer.display_text(text, color=(204, 204, 0, 255),
                                 x=10, y=self._image_dims[1]-80)

        # ---------------------------------------------------------------------
        if self.task != 'play':
            # draw a transparent circle around the target point (the 3D
            # version scales its radius with the z-distance; locations here
            # are 2D, so a fixed radius is used)
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 0.2))
            # draw target point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x*target_point[0],
                                    pos_y=scale_y*target_point[1],
                                    color=(1.0, 0.0, 0.0, 1.0))
            # display info
            color = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
            text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
            self.viewer.display_text(text, color=color, x=10, y=20)

        # ---------------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()

        # time.sleep(self.viz)
        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)

            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)

            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.saveGif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
                    act = input("select action: d (delete) / q (quit): ").lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                resolution = str(3 * self.viewer.img_width) + 'x' + str(3 * self.viewer.img_height)
                save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                            '-pattern_type', 'sequence', '-start_number', '0', '-r',
                            '6', '-i', dirname + '/%04d.png', '-s', resolution,
                            '-vcodec', 'libx264', '-b:v', '2567k', self.filename + '.mp4']
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
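
For context, a minimal sketch of the agent-environment loop that the step() docstring above describes; the random q-values stand in for a trained DQN forward pass, and the files_list path is a placeholder:

import numpy as np

env = MedicalPlayer(files_list='landmarks_train.txt',  # placeholder path
                    task='train', screen_dims=(27, 27), max_num_frames=1500)
obs = env.reset()
done = False
while not done:
    qvalues = np.random.rand(env.actions)  # stand-in for a DQN forward pass
    act = int(np.argmax(qvalues)) + 1      # valid actions here are 1..4
    obs, reward, done, info = env.step(act, qvalues)
print(info['distError'], info['filenames'])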
Example #8
    def display(self, return_rgb_array=False):
        # Initializations
        planes = np.flipud(
            np.transpose(self.get_plane(self._location[0][2], agent=0)))
        shape = np.shape(planes)

        target_points = []
        current_points = []

        for i in range(self.agents):
            # get landmarks
            current_points.append(self._location[i])
            if self.task != 'play':
                target_points.append(self._target_loc[i])
            else:
                target_points.append(None)
            # get current plane
            current_plane = np.flipud(
                np.transpose(self.get_plane(current_points[i][2], agent=i)))

            if i > 0:
                # get image in z-axis
                planes = np.hstack((planes, current_plane))

        shifts_x = [np.shape(current_plane)[1] * i for i in range(self.agents)]
        shifts_y = [0] * self.agents

        # get image and convert it to pyglet + convert to rgb
        # # horizontal concat
        # planes = np.array(planes)#.ravel(order='C') # C for cardiac
        # np.transpose(planes, (2,1,0))
        # img = cv2.cvtColor(np.flipud(planes.reshape((shape[1],
        #                                   shape[0]*shape[2]),
        #                                   order='C')), # F for cardiac
        #                    cv2.COLOR_GRAY2RGB)
        # # vertical concat
        # planes = np.array(planes)
        # img = cv2.cvtColor(planes.reshape(shape[0]*shape[1], shape[2]),
        #                    cv2.COLOR_GRAY2RGB)

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_y = 1
        scale_x = 1
        img = cv2.resize(
            planes,
            (int(scale_y * planes.shape[1]), int(scale_x * planes.shape[0])),
            interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # skip if there is a viewer open
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_y=1,
                                            scale_x=1,
                                            filepath=self.filename[i] + str(i))
            self.gif_buffer = []
        # display image
        self.viewer.draw_image(img)

        # plot landmarks
        for i in range(self.agents):
            # get landmarks - correct location if image is flipped and transposed
            current_point = (shape[0] - current_points[i][1] + shifts_y[i],
                             current_points[i][0] + shifts_x[i],
                             current_points[i][2])
            if self.task != 'play':
                target_point = (shape[0] - target_points[i][1] + shifts_y[i],
                                target_points[i][0] + shifts_x[i],
                                target_points[i][2])
            # draw current point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_y=scale_y * current_point[1],
                                    pos_x=scale_x * current_point[0],
                                    color=(0.0, 0.0, 1.0, 1.0))
            # draw a box around the agent - what the network sees ROI
            # - correct location if image is flipped
            self.viewer.draw_rect(
                scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                scale_x * (self.rectangle[i].xmin + shifts_x[i]),
                scale_y * (shape[0] - self.rectangle[i].ymax + shifts_y[i]),
                scale_x * (self.rectangle[i].xmax + shifts_x[i]))
            self.viewer.display_text(
                'Agent ' + str(i),
                color=(204, 204, 0, 255),
                x=scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                y=scale_x * (self.rectangle[i].xmin + shifts_x[i]))
            # display info
            text = 'Spacing ' + str(self.xscale)
            self.viewer.display_text(text, color=(204, 204, 0, 255), x=8, y=8)
            #self._image_dims[1]-(int)(0.2*self._image_dims[1])-5)

            # -----------------------------------------------------------------
            if self.task != 'play':
                # draw a transparent circle around target point with variable radius
                # based on the difference z-direction
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(radius=diff_z,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 0.2))
                # draw target point
                self.viewer.draw_circle(radius=scale_x * 1,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 1.0))
                # display info
                color = (0, 204, 0, 255) if self.reward[i] > 0 else (204, 0, 0,
                                                                     255)
                text = 'Error - ' + 'Agent ' + str(i) + ' - ' + str(
                    round(self.cur_dist[i], 3)) + 'mm'
                self.viewer.display_text(
                    text,
                    color=color,
                    x=scale_y * (int(1.0 * shape[0]) - 15 + shifts_y[i]),
                    y=scale_x * (8 + shifts_x[i]))

        # -----------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()
        # time.sleep(self.viz)
        # save gif
        if self.saveGif:

            image_data = pyglet.image.get_buffer_manager().get_color_buffer(
            ).get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if all(self.terminal):
                gifname = self.filename[0].split('.')[0] + '_{}.gif'.format(i)
                self.viewer.saveGif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video_cardiac'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if all(self.terminal):
                resolution = str(3 * self.viewer.img_width) + 'x' + str(
                    3 * self.viewer.img_height)
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '6', '-i', dirname + '/%04d.png', '-s', resolution,
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename[0] + '_{}_agents.mp4'.format(i + 1)
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
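
The display() above tiles one plane per agent horizontally and shifts each agent's drawing coordinates by the widths of the preceding planes. A short sketch of that montage arithmetic in isolation, with made-up plane sizes:

import numpy as np

# three hypothetical 64x64 grayscale planes, one per agent
planes = [np.zeros((64, 64), np.float32) for _ in range(3)]

montage = np.hstack(planes)  # shape (64, 192)
shifts_x = [planes[0].shape[1] * i for i in range(len(planes))]

# a point (x, y) in agent i's own plane lands at (x + shifts_x[i], y) in the
# montage, e.g. agent 2's point (10, 20) maps to (138, 20)
print(montage.shape, shifts_x)  # (64, 192) [0, 64, 128]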
Example #9
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(self,
                 directory=None,
                 viz=False,
                 task=False,
                 files_list=None,
                 screen_dims=(27, 27, 27),
                 history_length=28,
                 multiscale=True,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False,
                 agents=1,
                 reward_strategy=1):
        """
        :param directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to a positive number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param location_history_length: consider loss of lives as end of
            episode (useful for training)
        :param max_num_frames: maximum number of frames per episode.
        """
        # ######################################################################
        # ## generate evaluation results from 19 different points
        # ## save results in csv file
        # self.csvfile = 'DuelDoubleDQN_multiscale_brain_mri_point_pc_ROI_45_45_45_midl2018.csv'
        # if not train:
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filename", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5,0.25,0.75]
        # y = [0.5,0.25,0.75]
        # z = [0.5,0.25,0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        # ######################################################################

        super(MedicalPlayer, self).__init__()
        self.agents = agents
        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = [StatCounter()
                                      for _ in range(self.agents)]
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)
        self.reward_strategy = reward_strategy

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
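        # nested comprehensions (rather than list * n) give every agent its
        # own independent buffer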
        self._loc_history = [[(0, ) * self.dims
                              for _ in range(self._history_length)]
                             for _ in range(self.agents)]
        self._qvalues_history = [[(0, ) * self.actions
                                  for _ in range(self._history_length)]
                                 for _ in range(self.agents)]
        # initialize rectangle limits from input image coordinates
        self.rectangle = [Rectangle(0, 0, 0, 0, 0, 0)] * int(self.agents)

        # add your data loader here
        if self.task == 'play':
            self.files = filesListBrainMRLandmark(files_list,
                                                  returnLandmarks=False,
                                                  agents=self.agents)
        else:
            self.files = filesListBrainMRLandmark(files_list,
                                                  returnLandmarks=True,
                                                  agents=self.agents)

        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart current episode
        """
        logger.info("Medical Player restarting episode")
        self.terminal = [False] * self.agents
        self.reward = np.zeros((self.agents, ))
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self._loc_history = [[(0, ) * self.dims
                              for _ in range(self._history_length)]
                             for _ in range(self.agents)]
        # list of q-value lists
        self._qvalues_history = [[(0, ) * self.actions
                                  for _ in range(self._history_length)]
                                 for _ in range(self.agents)]
        for i in range(0, self.agents):
            self.current_episode_score[i].reset()
        self.new_random_game()

    def new_random_game(self):
        """
        load image,
        set dimensions,
        randomize start point,
        init _screen, qvals,
        calc distance to goal
        """
        self.terminal = [False] * self.agents
        self.viewer = None
        # ######################################################################
        # ## generate evaluation results from 19 different points
        # if self.count_points ==0:
        #     print('\n============== new game ===============\n')
        #     # save results
        #     if self.total_loc:
        #         with open(self.csvfile, 'a') as outcsv:
        #             fields= [self.filename, self.cur_dist]
        #             writer = csv.writer(outcsv)
        #             writer.writerow(map(lambda x: x, fields))
        #         self.total_loc = []
        #     # sample a new image
        #     self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        #     scale = next(self.start_points)
        #     self.count_points +=1
        # else:
        #     self.count_points += 1
        #     logger.info('count_points {}'.format(self.count_points))
        #     scale = next(self.start_points)
        #
        # x = int(scale[0] * self._image.dims[0])
        # y = int(scale[1] * self._image.dims[1])
        # z = int(scale[2] * self._image.dims[2])
        # logger.info('starting point {}-{}-{}'.format(x,y,z))
        # ######################################################################

        # # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(
            self.sampled_files)
        self.filename = [
            os.path.basename(self.filepath[i]) for i in range(self.agents)
        ]

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            ## brain
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3
            self.zscale = 3
            ## cardiac
            # self.action_step = 6
            # self.xscale = 2
            # self.yscale = 2
            # self.zscale = 2
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1
            self.zscale = 1
        # image volume size
        self._image_dims = self._image[0].dims

        #######################################################################
        ## select random starting point
        # add padding to avoid start right on the border of the image
        if self.task == 'train':
            skip_thickness = (int(self._image_dims[0] / 5),
                              int(self._image_dims[1] / 5),
                              int(self._image_dims[2] / 5))
        else:
            skip_thickness = (int(self._image_dims[0] / 4),
                              int(self._image_dims[1] / 4),
                              int(self._image_dims[2] / 4))

        # TODO: should agents start at the same random point? they can get stuck
        #x=[self.rng.randint(0 + skip_thickness[0], self._image_dims[0] - skip_thickness[0])] * self.agents
        #y=[self.rng.randint(0 + skip_thickness[1], self._image_dims[1] - skip_thickness[1])] * self.agents
        #z=[self.rng.randint(0 + skip_thickness[2], self._image_dims[2] - skip_thickness[2])] * self.agents

        x = [
            self.rng.randint(0 + skip_thickness[0],
                             self._image_dims[0] - skip_thickness[0])
            for _ in range(self.agents)
        ]
        y = [
            self.rng.randint(0 + skip_thickness[1],
                             self._image_dims[1] - skip_thickness[1])
            for _ in range(self.agents)
        ]
        z = [
            self.rng.randint(0 + skip_thickness[2],
                             self._image_dims[2] - skip_thickness[2])
            for _ in range(self.agents)
        ]

        #######################################################################

        self._location = [(x[i], y[i], z[i]) for i in range(self.agents)]
        self._start_location = [(x[i], y[i], z[i]) for i in range(self.agents)]
        self._qvalues = [[
            0,
        ] * self.actions] * self.agents
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = [
                0,
            ] * self.agents
        else:
            self.cur_dist = [
                self.calcDistance(self._location[i], self._target_loc[i],
                                  self.spacing) for i in range(self.agents)
            ]
        logger.info("Current distance is " + str(self.cur_dist))

    def calcDistance(self, points1, points2, spacing=(1, 1, 1)):
        """ calculate the distance between two points in mm"""
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)
        return np.linalg.norm(points1 - points2)
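
    # e.g. with spacing=(1, 1, 2) mm per voxel, (0, 0, 0) and (0, 3, 4) are
    # np.linalg.norm([0, 3, 8]) = sqrt(73) ≈ 8.54 mm apart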

    def step(self, act, q_values, isOver):
        """The environment's step function returns exactly what we need.
        Args:
          act:
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """
        for i in range(self.agents):
            if isOver[i]:
                act[i] = 10  # no movement branch matches 10, so the agent stays put
        self._qvalues = q_values
        current_loc = self._location
        next_location = copy.deepcopy(current_loc)

        self.terminal = [False] * self.agents
        go_out = [False] * self.agents

        ######################## agent i movement #############################
        for i in range(self.agents):
            # UP Z+ -----------------------------------------------------------
            if (act[i] == 0):
                next_location[i] = (current_loc[i][0], current_loc[i][1],
                                    round(current_loc[i][2] +
                                          self.action_step))
                if (next_location[i][2] >= self._image_dims[2]):
                    # print(' trying to go out the image Z+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True

            # FORWARD Y+ ---------------------------------------------------------
            if (act[i] == 1):
                next_location[i] = (current_loc[i][0],
                                    round(current_loc[i][1] +
                                          self.action_step), current_loc[i][2])
                if (next_location[i][1] >= self._image_dims[1]):
                    # print(' trying to go out the image Y+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # RIGHT X+ -----------------------------------------------------------
            if (act[i] == 2):
                next_location[i] = (round(current_loc[i][0] +
                                          self.action_step), current_loc[i][1],
                                    current_loc[i][2])
                if next_location[i][0] >= self._image_dims[0]:
                    # print(' trying to go out the image X+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # LEFT X- -----------------------------------------------------------
            if act[i] == 3:
                next_location[i] = (round(current_loc[i][0] -
                                          self.action_step), current_loc[i][1],
                                    current_loc[i][2])
                if next_location[i][0] <= 0:
                    # print(' trying to go out the image X- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # BACKWARD Y- ---------------------------------------------------------
            if act[i] == 4:
                next_location[i] = (current_loc[i][0],
                                    round(current_loc[i][1] -
                                          self.action_step), current_loc[i][2])
                if next_location[i][1] <= 0:
                    # print(' trying to go out the image Y- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # DOWN Z- -----------------------------------------------------------
            if act[i] == 5:
                next_location[i] = (current_loc[i][0], current_loc[i][1],
                                    round(current_loc[i][2] -
                                          self.action_step))
                if next_location[i][2] <= 0:
                    # print(' trying to go out the image Z- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # -----------------------------------------------------------------

        #######################################################################

        # punish -1 reward if the agent tries to go out
        if self.task != 'play':
            for i in range(0, self.agents):
                if go_out[i]:
                    self.reward[i] = -1
                else:
                    # if self.task=='train' or self.task=='eval':
                    if self.reward_strategy == 1:
                        self.reward[i] = self._calc_reward(current_loc[i],
                                                           next_location[i],
                                                           agent=i)
                    elif self.reward_strategy == 2:
                        self.reward[i] = self._calc_reward_geometric(
                            current_loc[i], next_location[i], agent=i)
                    elif self.reward_strategy == 3:
                        self.reward[i] = self._distance_to_other_agents(
                            current_loc, next_location, agent=i)
                    elif self.reward_strategy == 4:
                        self.reward[
                            i] = self._distance_to_other_agents_and_line(
                                current_loc, next_location, agent=i)
                    elif self.reward_strategy == 5:
                        self.reward[
                            i] = self._distance_to_other_agents_and_line_no_point(
                                current_loc, next_location, agent=i)
                    elif self.reward_strategy == 6:
                        self.reward[
                            i] = self._calc_reward_geometric_generalized(
                                current_loc[i], next_location[i], agent=i)
                    # else:
                    #     self.reward[i]= self._calc_reward(current_loc[i], next_location[i],agent=i)

        # update screen, reward ,location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        if self.task == 'train':
            for i in range(self.agents):
                if self.cur_dist[i] <= 1:
                    self.terminal[i] = True
                    self.num_success[i].feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            for i in range(self.agents):
                self.terminal[i] = True

        # update history buffer with new location and qvalues
        if self.task != 'play':
            for i in range(self.agents):
                self.cur_dist[i] = self.calcDistance(self._location[i],
                                                     self._target_loc[i],
                                                     self.spacing)

        self._update_history()
        # check if agent oscillates
        if self._oscillate:
            self._location = self.getBestLocation()
            # self._location=[item for sublist in temp for item in sublist]
            self._screen = self._current_state()

            if self.task != 'play':
                for i in range(self.agents):
                    self.cur_dist[i] = self.calcDistance(
                        self._location[i], self._target_loc[i], self.spacing)

            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.xscale -= 1
                    self.yscale -= 1
                    self.zscale -= 1
                    self.action_step = int(self.action_step / 3)
                    self._clear_history()
                # terminate if scale is less than 1
                else:
                    for i in range(self.agents):
                        self.terminal[i] = True
                        if self.cur_dist[i] <= 1:
                            self.num_success[i].feed(1)
            else:
                for i in range(self.agents):
                    self.terminal[i] = True
                    if self.cur_dist[i] <= 1:
                        self.num_success[i].feed(1)
        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        for i in range(self.agents):
            self.current_episode_score[i].feed(self.reward[i])

        info = {}
        for i in range(self.agents):
            info['score_{}'.format(i)] = self.current_episode_score[i].sum
            info['gameOver_{}'.format(i)] = self.terminal[i]
            info['distError_{}'.format(i)] = distance_error[i]
            info['filename_{}'.format(i)] = self.filename[i]

        # #######################################################################
        # ## generate evaluation results from 19 different points
        # if self.terminal:
        #     logger.info(info)
        #     self.total_loc.append(self._location)
        #     if not(self.count_points == 19):
        #         self._restart_episode()
        #     else:
        #         mean_location = np.mean(self.total_loc,axis=0)
        #         logger.info('total_loc {} \n mean_location{}'.format(self.total_loc, mean_location))
        #         self.cur_dist = self.calcDistance(mean_location,
        #                                           self._target_loc,
        #                                           self.spacing)
        #         logger.info('final distance error {} \n'.format(self.cur_dist))
        #         self.count_points = 0
        # #######################################################################
        return self._current_state(), self.reward, self.terminal, info

    def getBestLocation(self):
        ''' get the location with the best q-value from the last four
        locations stored in history
        '''
        best_location = []
        for i in range(self.agents):
            last_qvalues_history = self._qvalues_history[i][-4:]
            last_loc_history = self._loc_history[i][-4:]
            best_qvalues = np.max(last_qvalues_history, axis=1)
            best_idx = best_qvalues.argmin()
            best_location.append(last_loc_history[best_idx])

        return best_location

    def _clear_history(self):
        ''' clear history buffer with current states
        '''
        self._loc_history = [[(0, ) * self.dims
                              for _ in range(self._history_length)]
                             for _ in range(self.agents)]
        self._qvalues_history = [[(0, ) * self.actions
                                  for _ in range(self._history_length)]
                                 for _ in range(self.agents)]

    def _update_history(self):
        ''' update history buffer with current states
        '''
        for i in range(self.agents):
            # update location history
            self._loc_history[i].pop(0)
            self._loc_history[i].insert(len(self._loc_history[i]),
                                        self._location[i])

            # update q-value history
            # self._qvalues_history[i][:-1] = self._qvalues_history[i][1:]
            # self._qvalues_history[i][-1] = np.ravel(self._qvalues[i])
            self._qvalues_history[i].pop(0)
            self._qvalues_history[i].insert(len(self._qvalues_history[i]),
                                            self._qvalues[i])
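
    # pop(0) plus insert at the end keeps each agent's buffer a fixed
    # length, e.g. [a, b, c] with new location d -> [b, c, d]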

    def _current_state(self):
        """
        crop image data around current location to update what network sees.
        update rectangle

        :return: new state
        """
        # initialize screen with zeros - all background
        screen = np.zeros(
            (self.agents, self.screen_dims[0], self.screen_dims[1],
             self.screen_dims[2])).astype(self._image[0].data.dtype)

        for i in range(self.agents):
            # screen uses coordinate system relative to origin (0, 0, 0)
            screen_xmin, screen_ymin, screen_zmin = 0, 0, 0
            screen_xmax, screen_ymax, screen_zmax = self.screen_dims

            # extract boundary locations using coordinate system relative to "global" image
            # width, height, depth in terms of screen coord system

            if self.xscale % 2:
                xmin = self._location[i][0] - int(
                    self.width * self.xscale / 2) - 1
                xmax = self._location[i][0] + int(self.width * self.xscale / 2)
                ymin = self._location[i][1] - int(
                    self.height * self.yscale / 2) - 1
                ymax = self._location[i][1] + int(
                    self.height * self.yscale / 2)
                zmin = self._location[i][2] - int(
                    self.depth * self.zscale / 2) - 1
                zmax = self._location[i][2] + int(self.depth * self.zscale / 2)
            else:
                xmin = self._location[i][0] - round(
                    self.width * self.xscale / 2)
                xmax = self._location[i][0] + round(
                    self.width * self.xscale / 2)
                ymin = self._location[i][1] - round(
                    self.height * self.yscale / 2)
                ymax = self._location[i][1] + round(
                    self.height * self.yscale / 2)
                zmin = self._location[i][2] - round(
                    self.depth * self.zscale / 2)
                zmax = self._location[i][2] + round(
                    self.depth * self.zscale / 2)

            ###########################################################

            # check if they violate image boundary and fix it
            if xmin < 0:
                xmin = 0
                screen_xmin = screen_xmax - len(
                    np.arange(xmin, xmax, self.xscale))
            if ymin < 0:
                ymin = 0
                screen_ymin = screen_ymax - len(
                    np.arange(ymin, ymax, self.yscale))
            if zmin < 0:
                zmin = 0
                screen_zmin = screen_zmax - len(
                    np.arange(zmin, zmax, self.zscale))
            if xmax > self._image_dims[0]:
                xmax = self._image_dims[0]
                screen_xmax = screen_xmin + len(
                    np.arange(xmin, xmax, self.xscale))
            if ymax > self._image_dims[1]:
                ymax = self._image_dims[1]
                screen_ymax = screen_ymin + len(
                    np.arange(ymin, ymax, self.yscale))
            if zmax > self._image_dims[2]:
                zmax = self._image_dims[2]
                screen_zmax = screen_zmin + len(
                    np.arange(zmin, zmax, self.zscale))

            # crop image data to update what network sees
            # image coordinate system becomes screen coordinates
            # scale can be thought of as a stride
            screen[i, screen_xmin:screen_xmax, screen_ymin:screen_ymax,
                   screen_zmin:screen_zmax] = self._image[i].data[
                       xmin:xmax:self.xscale, ymin:ymax:self.yscale,
                       zmin:zmax:self.zscale]

            ###########################################################
            # update rectangle limits from input image coordinates
            # this is what the network sees
            self.rectangle[i] = Rectangle(xmin, xmax, ymin, ymax, zmin, zmax)
        return screen
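A one-dimensional sketch of the boundary handling above, with made-up sizes: when the requested window starts outside the volume, xmin is clamped and the screen window is shifted so the strided slice lands flush against its right edge, leaving zero padding on the left.

import numpy as np

image = np.arange(20)        # stand-in for one image axis
xscale, width = 3, 5         # stride and screen width
xmin, xmax = -4, -4 + width * xscale
screen_xmin, screen_xmax = 0, width
if xmin < 0:
    xmin = 0
    screen_xmin = screen_xmax - len(np.arange(xmin, xmax, xscale))
screen = np.zeros(width, dtype=image.dtype)
screen[screen_xmin:screen_xmax] = image[xmin:xmax:xscale]
print(screen)                # [0 0 3 6 9] - the first element is padding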

    # Should the argument agent be renamed to image instead?
    def get_plane(self, z=0, agent=0):
        return self._image[agent].data[:, :, z]

    def _calc_reward(self, current_loc, next_loc, agent):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc[agent],
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc[agent],
                                      self.spacing)
        return curr_dist - next_dist
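A standalone numeric check of the shaped reward, using a hypothetical dist_mm helper that mirrors calcDistance: moving two slices closer along a 2 mm z-axis yields +4.0.

import numpy as np

def dist_mm(p, q, spacing):
    return np.linalg.norm(np.array(spacing) * (np.array(p) - np.array(q)))

target, spacing = (10, 10, 10), (1.0, 1.0, 2.0)
curr, nxt = (10, 10, 14), (10, 10, 12)
reward = dist_mm(curr, target, spacing) - dist_mm(nxt, target, spacing)
print(reward)  # 8.0 - 4.0 = 4.0, positive because the agent moved closer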

    #TODO: does this not return the oscillation for the first agent only?
    @property
    def _oscillate(self):
        """ Return True if all agents are stuck and oscillating
        """
        for i in range(self.agents):
            counter = Counter(self._loc_history[i])
            freq = counter.most_common()
            # At the beginning of an episode the history is prefilled with
            # (0, 0, 0); do not count that entry's frequency
            if freq[0][0] == (0, 0, 0):
                if len(freq) < 2:
                    return False
                if freq[1][1] < 2:
                    return False
            elif freq[0][1] < 2:
                return False
        return True
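A small demonstration of the check above: the (0, 0, 0) prefill entry is skipped, and the agent bouncing back to (5, 5, 5) gives that location a count of 2, which flags oscillation.

from collections import Counter

loc_history = [(0, 0, 0), (0, 0, 0), (5, 5, 5), (6, 5, 5), (5, 5, 5)]
freq = Counter(loc_history).most_common()
print(freq)  # [((0, 0, 0), 2), ((5, 5, 5), 2), ((6, 5, 5), 1)]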

    def get_action_meanings(self):
        """ return array of integers for actions"""
        ACTION_MEANING = {
            1: "UP",  # MOVE Z+
            2: "FORWARD",  # MOVE Y+
            3: "RIGHT",  # MOVE X+
            4: "LEFT",  # MOVE X-
            5: "BACKWARD",  # MOVE Y-
            6: "DOWN",  # MOVE Z-
        }
        # self.actions is an int (action_space.n), so iterate over the keys
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        # one independent counter per agent; [StatCounter()] * n would alias
        # a single counter across all agents
        self.num_success = [StatCounter() for _ in range(self.agents)]
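The comprehension above matters because multiplying a list copies references; a quick sketch of the pitfall it avoids, using only the StatCounter calls (feed/sum) already seen in this listing.

aliased = [StatCounter()] * 2
aliased[0].feed(1)
print(aliased[1].sum)    # 1 - both slots share one counter
separate = [StatCounter() for _ in range(2)]
separate[0].feed(1)
print(separate[1].sum)   # 0 - independent counters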

    def display(self, return_rgb_array=False):
        # Initializations
        planes = np.flipud(
            np.transpose(self.get_plane(self._location[0][2], agent=0)))
        shape = np.shape(planes)

        target_points = []
        current_points = []

        for i in range(self.agents):
            # get landmarks
            current_points.append(self._location[i])
            if self.task != 'play':
                target_points.append(self._target_loc[i])
            else:
                target_points.append(None)
            # get current plane
            current_plane = np.flipud(
                np.transpose(self.get_plane(current_points[i][2], agent=i)))

            if i > 0:
                # get image in z-axis
                planes = np.hstack((planes, current_plane))

        shifts_x = [np.shape(current_plane)[1] * i for i in range(self.agents)]
        shifts_y = [0] * self.agents

        # get image and convert it to pyglet + convert to rgb
        # # horizontal concat
        # planes = np.array(planes)#.ravel(order='C') # C for cardiac
        # np.transpose(planes, (2,1,0))
        # img = cv2.cvtColor(np.flipud(planes.reshape((shape[1],
        #                                   shape[0]*shape[2]),
        #                                   order='C')), # F for cardiac
        #                    cv2.COLOR_GRAY2RGB)
        # # vertical concat
        # planes = np.array(planes)
        # img = cv2.cvtColor(planes.reshape(shape[0]*shape[1], shape[2]),
        #                    cv2.COLOR_GRAY2RGB)

        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_y = 1
        scale_x = 1
        img = cv2.resize(
            planes,
            (int(scale_y * planes.shape[1]), int(scale_x * planes.shape[0])),
            interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # create the viewer once, if none is open yet
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_y=1,
                                            scale_x=1,
                                            filepath=self.filename[i] + str(i))  # note: i is the last agent index from the loop above
            self.gif_buffer = []
        # display image
        self.viewer.draw_image(img)

        # plot landmarks
        for i in range(self.agents):
            # get landmarks - correct location if image is flipped and tranposed
            current_point = (shape[0] - current_points[i][1] + shifts_y[i],
                             current_points[i][0] + shifts_x[i],
                             current_points[i][2])
            if self.task != 'play':
                target_point = (shape[0] - target_points[i][1] + shifts_y[i],
                                target_points[i][0] + shifts_x[i],
                                target_points[i][2])
            # draw current point
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_y=scale_y * current_point[1],
                                    pos_x=scale_x * current_point[0],
                                    color=(0.0, 0.0, 1.0, 1.0))
            # draw a box around the agent - what the network sees ROI
            # - correct location if image is flipped
            self.viewer.draw_rect(
                scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                scale_x * (self.rectangle[i].xmin + shifts_x[i]),
                scale_y * (shape[0] - self.rectangle[i].ymax + shifts_y[i]),
                scale_x * (self.rectangle[i].xmax + shifts_x[i]))
            self.viewer.display_text(
                'Agent ' + str(i),
                color=(204, 204, 0, 255),
                x=scale_y * (shape[0] - self.rectangle[i].ymin + shifts_y[i]),
                y=scale_x * (self.rectangle[i].xmin + shifts_x[i]))
            # display info
            text = 'Spacing ' + str(self.xscale)
            self.viewer.display_text(text, color=(204, 204, 0, 255), x=8, y=8)
            #self._image_dims[1]-(int)(0.2*self._image_dims[1])-5)

            # -----------------------------------------------------------------
            if self.task != 'play':
                # draw a transparent circle around target point with variable radius
                # based on the difference z-direction
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(radius=diff_z,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 0.2))
                # draw target point
                self.viewer.draw_circle(radius=scale_x * 1,
                                        pos_x=scale_x * target_point[0],
                                        pos_y=scale_y * target_point[1],
                                        color=(1.0, 0.0, 0.0, 1.0))
                # display info
                color = (0, 204, 0, 255) if self.reward[i] > 0 else (204, 0, 0,
                                                                     255)
                text = 'Error - ' + 'Agent ' + str(i) + ' - ' + str(
                    round(self.cur_dist[i], 3)) + 'mm'
                self.viewer.display_text(
                    text,
                    color=color,
                    x=scale_y * (int(1.0 * shape[0]) - 15 + shifts_y[i]),
                    y=scale_x * (8 + shifts_x[i]))

        # -----------------------------------------------------------------

        # render and wait (viz) time between frames
        self.viewer.render()
        # time.sleep(self.viz)
        # save gif
        if self.saveGif:

            image_data = (pyglet.image.get_buffer_manager()
                          .get_color_buffer().get_image_data())
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            if all(self.terminal):
                gifname = self.filename[0].split('.')[0] + '_{}.gif'.format(i)
                self.viewer.saveGif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video_cardiac'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if all(self.terminal):
                resolution = str(3 * self.viewer.img_width) + 'x' + str(
                    3 * self.viewer.img_height)
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '6', '-i', dirname + '/%04d.png', '-s', resolution,
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename[0] + '_{}_agents.mp4'.format(i + 1)
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
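A sketch of the montage logic above, with made-up plane sizes: each agent's z-plane is transposed, flipped, and stacked left-to-right, and the overlay x-coordinates are then offset by one plane-width per agent.

import numpy as np

planes = [np.random.rand(40, 30) for _ in range(2)]  # fake (x, y) agent planes
montage = np.hstack([np.flipud(np.transpose(p)) for p in planes])
shifts_x = [np.transpose(planes[0]).shape[1] * i for i in range(2)]
print(montage.shape, shifts_x)  # (30, 80) [0, 40]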
Example #10
    def display(self, return_rgb_array=False):
        # pass
        for i in range(0, self.agents):
            # get dimensions
            current_point = self._location[i]
            target_point = None
            if self.task != "play":
                target_point = self._target_loc[i]
            # print("_location", self._location)
            # print("_target_loc", self._target_loc)
            # print("current_point", current_point)
            # print("target_point", target_point)
            # get image and convert it to pyglet
            plane = self.get_plane(current_point[2], agent=i)  # z-plane
            # plane = np.squeeze(self._current_state()[:,:,13])
            img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to RGB
            # rescale image
            # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
            scale_x = 2
            scale_y = 2
            #
            img = cv2.resize(
                img,
                (int(scale_x * img.shape[1]), int(scale_y * img.shape[0])),
                interpolation=cv2.INTER_LINEAR,
            )
            # create the viewer once, if none is open yet
            if (not self.viewer) and self.viz:
                from viewer import SimpleImageViewer

                self.viewer = SimpleImageViewer(arr=img,
                                                scale_x=1,
                                                scale_y=1,
                                                filepath=self.filepath[i] +
                                                str(i))
                self.gif_buffer = []
            # display image
            self.viewer.draw_image(img)
            # draw current point
            self.viewer.draw_circle(
                radius=scale_x * 1,
                pos_x=scale_x * current_point[0],
                pos_y=scale_y * current_point[1],
                color=(0.0, 0.0, 1.0, 1.0),
            )
            # draw a box around the agent - what the network sees ROI
            self.viewer.draw_rect(
                scale_x * self.rectangle[i].xmin,
                scale_y * self.rectangle[i].ymin,
                scale_x * self.rectangle[i].xmax,
                scale_y * self.rectangle[i].ymax,
            )
            self.viewer.display_text(
                "Agent " + str(i),
                color=(204, 204, 0, 255),
                x=scale_x * self.rectangle[i].xmin - 15,
                y=scale_y * self.rectangle[i].ymin,
            )
            # display info
            text = "Spacing " + str(self.xscale)
            self.viewer.display_text(text,
                                     color=(204, 204, 0, 255),
                                     x=10,
                                     y=self._image_dims[1] - 80)

            # ---------------------------------------------------------------------

            if self.task != "play":
                # draw a transparent circle around target point with variable radius
                # based on the difference z-direction
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(
                    radius=diff_z,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 0.2),
                )
                # draw target point
                self.viewer.draw_circle(
                    radius=scale_x * 1,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 1.0),
                )
                # display info
                color = (0, 204, 0, 255) if self.reward[i] > 0 else (204, 0, 0,
                                                                     255)
                text = "Error " + str(round(self.cur_dist[i], 3)) + "mm"
                self.viewer.display_text(text, color=color, x=10, y=20)

            # ---------------------------------------------------------------------

            # render and wait (viz) time between frames
            self.viewer.render()
            # time.sleep(self.viz)
            # save gif
            if self.saveGif:
                image_data = (pyglet.image.get_buffer_manager().
                              get_color_buffer().get_image_data())
                data = image_data.get_data("RGB", image_data.width * 3)
                arr = np.array(bytearray(data)).astype("uint8")
                arr = np.flip(
                    np.reshape(arr, (image_data.height, image_data.width, -1)),
                    0)
                im = Image.fromarray(arr)
                self.gif_buffer.append(im)

                if not self.terminal[i]:
                    gifname = self.filepath[0] + ".gif"
                    self.viewer.saveGif(gifname,
                                        arr=self.gif_buffer,
                                        duration=self.viz)
            if self.saveVideo:
                dirname = "tmp_video"
                if self.cnt <= 1:
                    if os.path.isdir(dirname):
                        logger.warn(
                            """Log directory {} exists! Use 'd' to delete it. """
                            .format(dirname))
                        act = (input("select action: d (delete) / q (quit): ").
                               lower().strip())
                        if act == "d":
                            shutil.rmtree(dirname, ignore_errors=True)
                        else:
                            raise OSError(
                                "Directory {} exists!".format(dirname))
                    os.mkdir(dirname)

                frame = dirname + "/" + "%04d" % self.cnt + ".png"
                pyglet.image.get_buffer_manager().get_color_buffer().save(
                    frame)
                if self.terminal[i]:
                    resolution = (str(3 * self.viewer.img_width) + "x" +
                                  str(3 * self.viewer.img_height))
                    save_cmd = [
                        "ffmpeg",
                        "-f",
                        "image2",
                        "-framerate",
                        "30",
                        "-pattern_type",
                        "sequence",
                        "-start_number",
                        "0",
                        "-r",
                        "6",
                        "-i",
                        dirname + "/%04d.png",
                        "-s",
                        resolution,
                        "-vcodec",
                        "libx264",
                        "-b:v",
                        "2567k",
                        self.filepath[i] + ".mp4",
                    ]
                    subprocess.check_output(save_cmd)
                    shutil.rmtree(dirname, ignore_errors=True)
Example #11
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward."""
    def __init__(
        self,
        directory=None,
        viz=False,
        task=False,
        files_list=None,
        screen_dims=(27, 27, 27),
        history_length=20,
        multiscale=True,
        max_num_frames=0,
        saveGif=False,
        saveVideo=False,
        agents=2,
        fiducials=None,
        infDir="../inference",
    ):
        """
        :param train_directory: environment or game name
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param nullop_start: start with random number of null ops
        :param location_history_length: consider lost of lives as end of
            episode (useful for training)
        :max_num_frames: maximum numbe0r of frames per episode.
        """
        # self.csvfile = 'DQN_fetal_US_agents_2_400k_RC_LC_CRP.csv'
        #
        # # if os.path.exists(self.csvfile): sys.exit('csv file exists')
        #
        # if task!='train':
        #     with open(self.csvfile, 'w') as outcsv:
        #         fields = ["filename", "dist_error"]
        #         writer = csv.writer(outcsv)
        #         writer.writerow(map(lambda x: x, fields))
        #
        # x = [0.5, 0.25, 0.75]
        # y = [0.5, 0.25, 0.75]
        # z = [0.5, 0.25, 0.75]
        # self.start_points = []
        # for combination in itertools.product(x, y, z):
        #     if 0.5 in combination: self.start_points.append(combination)
        # self.start_points = itertools.cycle(self.start_points)
        # self.count_points = 0
        # self.total_loc = []
        ######################################################################

        super(MedicalPlayer, self).__init__()
        # number of agents
        self.agents = agents
        self.fiducials = fiducials
        # inits stat counters
        self.reset_stat()
        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change the number of actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = []
        self._qvalues_history = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = []
        self.rectangle = []
        for i in range(0, self.agents):
            self.current_episode_score.append(StatCounter())
            self._loc_history.append([(0, ) * self.dims] *
                                     self._history_length)
            self._qvalues_history.append([(0, ) * self.actions] *
                                         self._history_length)
            self.rectangle.append(Rectangle(
                0, 0, 0, 0, 0,
                0))  # initialize rectangle limits from input image coordinates

        # add your data loader here
        if self.task == "play":
            self.files = filesListBrainMRLandmark(
                files_list,
                returnLandmarks=False,
                eval=True,
                fiducials=fiducials,
                infDir=infDir,
                agents=self.agents,
            )
        else:
            if self.task == "eval":
                self.files = filesListBrainMRLandmark(
                    files_list,
                    returnLandmarks=True,
                    fiducials=fiducials,
                    eval=True,
                    infDir=infDir,
                    agents=self.agents,
                )
            else:
                self.files = filesListBrainMRLandmark(
                    files_list,
                    returnLandmarks=True,
                    fiducials=fiducials,
                    eval=False,
                    infDir=infDir,
                    agents=self.agents,
                )

        # prepare file sampler
        self.filepath = None
        self._image = None
        self._target_loc = None
        self.spacing = None
        self.sampled_files = self.files.sample_circular()
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()
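A hypothetical construction sketch (the argument values below are assumptions for illustration, not taken from the repo): a two-agent evaluation environment with rendering disabled.

env = MedicalPlayer(files_list='./list_of_images.txt',
                    task='eval',
                    screen_dims=(27, 27, 27),
                    history_length=20,
                    agents=2,
                    fiducials=[0, 1],
                    viz=0)
obs = env.reset()  # one (27, 27, 27) crop per agent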

    def reset(self):
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart the current episode
        """
        self.terminal = [False] * self.agents
        self.reward = np.zeros((self.agents, ))
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self._loc_history = []
        self._qvalues_history = []
        for i in range(0, self.agents):
            self.current_episode_score[i].reset()

            self._loc_history.append([(0, ) * self.dims] *
                                     self._history_length)
            # list of q-value lists
            self._qvalues_history.append([(0, ) * self.actions] *
                                         self._history_length)

        self.new_random_game()

    def new_random_game(self):
        """
        load image,
        set dimensions,
        randomize start point,
        init _screen, qvals,
        calc distance to goal
        """
        self.terminal = [False] * self.agents

        self.viewer = None

        #
        # if self.task!='train':
        #     #######################################################################
        #     ## generate results for yuwanwei landmark miccai2018 paper
        #     ## save results in csv file
        #     if self.count_points == 0:
        #         print('\n============== new game ===============\n')
        #         # save results
        #         if self.total_loc:
        #             with open(self.csvfile, 'a') as outcsv:
        #                 fields = [self.filename, self.cur_dist]
        #                 writer = csv.writer(outcsv)
        #                 writer.writerow(map(lambda x: x, fields))
        #             self.total_loc = []
        #         # sample a new image
        #         self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        #         scale = next(self.start_points)
        #         self.count_points += 1
        #     else:
        #         self.count_points += 1
        #         logger.info('count_points {}'.format(self.count_points))
        #         scale = next(self.start_points)
        #
        #     x_temp = int(scale[0] * self._image[0].dims[0])
        #     y_temp = int(scale[1] * self._image[0].dims[1])
        #     z_temp = int(scale[2] * self._image[0].dims[2])
        #     logger.info('starting point {}-{}-{}'.format(x_temp, y_temp, z_temp))
        #     #######################################################################
        # else:
        self._image, self._target_loc, self.filepath, self.spacing = next(
            self.sampled_files)
        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            ## brain
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3
            self.zscale = 3
            ## cardiac
            # self.action_step =   6
            # self.xscale = 2
            # self.yscale = 2
            # self.zscale = 2
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1
            self.zscale = 1
        # image volume size
        self._image_dims = self._image[0].dims

        #######################################################################
        # ## select random starting point
        # # add padding to avoid start right on the border of the image
        # if self.task == "train":
        #     skip_thickness = (
        #         (int)(self._image_dims[0] / 5),
        #         (int)(self._image_dims[1] / 5),
        #         (int)(self._image_dims[2] / 5),
        #     )
        # else:
        #     skip_thickness = (
        #         int(self._image_dims[0] / 4),
        #         int(self._image_dims[1] / 4),
        #         int(self._image_dims[2] / 4),
        #     )
        #
        # # if self.task == 'train':
        # x = []
        # y = []
        # z = []
        # for i in range(0, self.agents):
        #     x.append(
        #         self.rng.randint(
        #             0 + skip_thickness[0], self._image_dims[0] - skip_thickness[0]
        #         )
        #     )
        #     y.append(
        #         self.rng.randint(
        #             0 + skip_thickness[1], self._image_dims[1] - skip_thickness[1]
        #         )
        #     )
        #     z.append(
        #         self.rng.randint(
        #             0 + skip_thickness[2], self._image_dims[2] - skip_thickness[2]
        #         )
        #     )
        # # else:
        # #     x=[]
        # #     y=[]
        # #     z=[]
        # #     for i in range(0,self.agents):
        # #         x.append(x_temp)
        # #         y.append(y_temp)
        # #         z.append(z_temp)
        #
        #######################################################################

        self._location = []
        self._start_location = []
        for i in self.fiducials:
            self._location.append(tuple(meanFiducialLocations[i]))
            self._start_location.append(tuple(meanFiducialLocations[i]))
        # independent zero rows per agent (list multiplication would alias them)
        self._qvalues = [[0] * self.actions for _ in range(self.agents)]
        self._screen = self._current_state()

        if self.task == "play":
            self.cur_dist = [0] * self.agents
        else:
            self.cur_dist = []
            for i in range(0, self.agents):
                self.cur_dist.append(
                    self.calcDistance(self._location[i], self._target_loc[i],
                                      self.spacing))

    def calcDistance(self, points1, points2, spacing=(1, 1, 1)):
        """ calculate the distance between two points in mm"""
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)
        return np.linalg.norm(points1 - points2)
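Because the spacing vector scales each axis before the norm, equal index offsets can be very different physical distances; a standalone copy of the computation for illustration.

import numpy as np

def calc_distance_mm(p1, p2, spacing=(1, 1, 1)):
    s = np.array(spacing)
    return np.linalg.norm(s * np.array(p1) - s * np.array(p2))

print(calc_distance_mm((0, 0, 0), (1, 0, 0), (0.5, 0.5, 3.0)))  # 0.5 mm
print(calc_distance_mm((0, 0, 0), (0, 0, 1), (0.5, 0.5, 3.0)))  # 3.0 mm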

    def step(self, act, q_values, isOver):
        """The environment's step function returns exactly what we need.
        Args:
          act: list of chosen actions, one per agent
          q_values: current q-value vectors, one per agent
          isOver: flags marking agents whose episode has already ended
        Returns:
          observation (object):
            an environment-specific object representing your observation of
            the environment. For example, pixel data from a camera, joint angles
            and joint velocities of a robot, or the board state in a board game.
          reward (float):
            amount of reward achieved by the previous action. The scale varies
            between environments, but the goal is always to increase your total
            reward.
          done (boolean):
            whether it's time to reset the environment again. Most (but not all)
            tasks are divided up into well-defined episodes, and done being True
            indicates the episode has terminated. (For example, perhaps the pole
            tipped too far, or you lost your last life.)
          info (dict):
            diagnostic information useful for debugging. It can sometimes be
            useful for learning (for example, it might contain the raw
            probabilities behind the environment's last state change). However,
            official evaluations of your agent are not allowed to use this for
            learning.
        """
        for i in range(0, self.agents):
            if isOver[i]:
                act[i] = 10  # out-of-range sentinel: no movement branch matches
        self._qvalues = q_values
        current_loc = self._location
        next_location = copy.deepcopy(current_loc)

        self.terminal = [False] * self.agents
        go_out = [False] * self.agents
        ###################### agent movement #######################################
        for i in range(0, self.agents):
            # UP Z+ -----------------------------------------------------------
            if act[i] == 0:
                next_location[i] = (
                    current_loc[i][0],
                    current_loc[i][1],
                    round(current_loc[i][2] + self.action_step),
                )
                if next_location[i][2] >= self._image_dims[2]:
                    # print(' trying to go out the image Z+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True

            # FORWARD Y+ ---------------------------------------------------------
            if act[i] == 1:
                next_location[i] = (
                    current_loc[i][0],
                    round(current_loc[i][1] + self.action_step),
                    current_loc[i][2],
                )
                if next_location[i][1] >= self._image_dims[1]:
                    # print(' trying to go out the image Y+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # RIGHT X+ -----------------------------------------------------------
            if act[i] == 2:
                next_location[i] = (
                    round(current_loc[i][0] + self.action_step),
                    current_loc[i][1],
                    current_loc[i][2],
                )
                if next_location[i][0] >= self._image_dims[0]:
                    # print(' trying to go out the image X+ ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # LEFT X- -----------------------------------------------------------
            if act[i] == 3:
                next_location[i] = (
                    round(current_loc[i][0] - self.action_step),
                    current_loc[i][1],
                    current_loc[i][2],
                )
                if next_location[i][0] <= 0:
                    # print(' trying to go out the image X- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # BACKWARD Y- ---------------------------------------------------------
            if act[i] == 4:
                next_location[i] = (
                    current_loc[i][0],
                    round(current_loc[i][1] - self.action_step),
                    current_loc[i][2],
                )
                if next_location[i][1] <= 0:
                    # print(' trying to go out the image Y- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # DOWN Z- -----------------------------------------------------------
            if act[i] == 5:
                next_location[i] = (
                    current_loc[i][0],
                    current_loc[i][1],
                    round(current_loc[i][2] - self.action_step),
                )
                if next_location[i][2] <= 0:
                    # print(' trying to go out the image Z- ',)
                    next_location[i] = current_loc[i]
                    go_out[i] = True
            # ---------------------------------------------------------------------
        #############################################################################

        # ---------------------------------------------------------------------
        # punish -1 reward if the agent tries to go out
        if self.task != "play":
            for i in range(0, self.agents):
                if go_out[i]:
                    self.reward[i] = -1
                else:
                    self.reward[i] = self._calc_reward(current_loc[i],
                                                       next_location[i],
                                                       agent=i)

        # update screen, reward ,location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        if self.task == "train":
            for i in range(0, self.agents):
                if self.cur_dist[i] <= 1:
                    self.terminal[i] = True
                    self.num_success[i].feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            for i in range(0, self.agents):
                self.terminal[i] = True

        # update history buffer with new location and qvalues
        if self.task != "play":
            for i in range(0, self.agents):
                self.cur_dist[i] = self.calcDistance(self._location[i],
                                                     self._target_loc[i],
                                                     self.spacing)

        self._update_history()

        # check if agent oscillates
        if self._oscillate:
            self._location = self.getBestLocation()
            # self._location=[item for sublist in temp for item in sublist]
            self._screen = self._current_state()

            if self.task != "play":
                for i in range(0, self.agents):
                    self.cur_dist[i] = self.calcDistance(
                        self._location[i], self._target_loc[i], self.spacing)

            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.xscale -= 1
                    self.yscale -= 1
                    self.zscale -= 1
                    self.action_step = int(self.action_step / 3)
                    self._clear_history()
                # terminate if scale is less than 1
                else:

                    for i in range(0, self.agents):
                        self.terminal[i] = True
                        if self.cur_dist[i] <= 1:
                            self.num_success[i].feed(1)

            else:

                for i in range(0, self.agents):
                    self.terminal[i] = True
                    if self.cur_dist[i] <= 1:
                        self.num_success[i].feed(1)
        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        for i in range(0, self.agents):
            self.current_episode_score[i].feed(self.reward[i])

        info = {}
        for i in range(0, self.agents):
            info["score_{}".format(i)] = self.current_episode_score[i].sum
            info["gameOver_{}".format(i)] = self.terminal[i]
            info["distError_{}".format(i)] = distance_error[i]
            info["filename_{}".format(i)] = self.filepath[i]
            info["location_{}".format(i)] = self._location[i]

            #######################################################################
            ## generate results for yuwanwei landmark miccai2018 paper

        # if all(self.terminal):
        #     logger.info(info)
        #     self.total_loc.append(self._location)
        #     if not (self.count_points == 5):
        #         self._restart_episode()
        #     else:
        #         mean_location = np.mean(self.total_loc, axis=0)
        #         for i in range(0,self.agents):
        #             logger.info('agent {}  \n mean_location{}'.format(i, mean_location[i]))
        #             if self.task != 'play':
        #                 self.cur_dist[i] = self.calcDistance(mean_location[i],
        #                                               self._target_loc[i],
        #                                               self.spacing[i])
        #                 logger.info('agent {} , final distance error {} \n'.format(i,self.cur_dist[i]))
        #         self.count_points = 0
        #######################################################################

        return self._current_state(), self.reward, self.terminal, info
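The oscillation branch above implies a coarse-to-fine schedule: each time the agents get stuck, the sampling scale drops by one and the step size divides by three, so a run goes (3, 9) -> (2, 3) -> (1, 1) before terminating.

scale, action_step = 3, 9
while scale > 1:
    print(scale, action_step)  # (3, 9) then (2, 3)
    scale -= 1
    action_step = int(action_step / 3)
print(scale, action_step)      # (1, 1)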

    def getBestLocation(self):
        """ get best location with best qvalue from last for locations
        stored in history
        """
        best_location = []
        for i in range(0, self.agents):
            last_qvalues_history = self._qvalues_history[i][-4:]
            last_loc_history = self._loc_history[i][-4:]
            best_qvalues = np.max(last_qvalues_history, axis=1)
            best_idx = best_qvalues.argmax()  # step whose best Q-value is largest
            best_location.append(last_loc_history[best_idx])
        #
        # last_qvalues_history=[]
        # last_loc_history=[]
        # best_qvalues=[]
        # best_idx=[]
        #
        # for i in range(0,self.agents):
        #     last_qvalues_history.append(self._qvalues_history[i][-4:])
        #     last_loc_history.append( self._loc_history[i][-4:])
        #     best_qvalues.append(np.max(last_qvalues_history[i], axis=1))
        #     best_idx.append(best_qvalues[i].argmin())
        #     best_location.append(last_loc_history[best_idx[i]])

        return best_location

    def _clear_history(self):
        """ clear history buffer with current state
        """
        self._loc_history = []
        self._qvalues_history = []
        for i in range(0, self.agents):
            self._loc_history.append([(0, ) * self.dims] *
                                     self._history_length)
            self._qvalues_history.append([(0, ) * self.actions] *
                                         self._history_length)

    def _update_history(self):
        """ update history buffer with current state
        """
        # update location history
        for i in range(0, self.agents):
            self._loc_history[i][:-1] = self._loc_history[i][1:]
            self._loc_history[i][-1] = self._location[i]

            # update q-value history
            self._qvalues_history[i][:-1] = self._qvalues_history[i][1:]
            self._qvalues_history[i][-1] = np.ravel(self._qvalues[i])

    def _current_state(self):
        """
        crop image data around current location to update what network sees.
        update rectangle

        :return: new state
        """
        # initialize screen with zeros - all background

        screen = np.zeros(
            (self.agents, self.screen_dims[0], self.screen_dims[1],
             self.screen_dims[2])).astype(self._image[0].data.dtype)

        for i in range(0, self.agents):
            # screen uses coordinate system relative to origin (0, 0, 0)
            screen_xmin, screen_ymin, screen_zmin = 0, 0, 0
            screen_xmax, screen_ymax, screen_zmax = self.screen_dims

            # extract boundary locations using coordinate system relative to "global" image
            # width, height, depth in terms of screen coord system

            if self.xscale % 2:
                xmin = self._location[i][0] - int(
                    self.width * self.xscale / 2) - 1
                xmax = self._location[i][0] + int(self.width * self.xscale / 2)
                ymin = self._location[i][1] - int(
                    self.height * self.yscale / 2) - 1
                ymax = self._location[i][1] + int(
                    self.height * self.yscale / 2)
                zmin = self._location[i][2] - int(
                    self.depth * self.zscale / 2) - 1
                zmax = self._location[i][2] + int(self.depth * self.zscale / 2)
            else:
                xmin = self._location[i][0] - round(
                    self.width * self.xscale / 2)
                xmax = self._location[i][0] + round(
                    self.width * self.xscale / 2)
                ymin = self._location[i][1] - round(
                    self.height * self.yscale / 2)
                ymax = self._location[i][1] + round(
                    self.height * self.yscale / 2)
                zmin = self._location[i][2] - round(
                    self.depth * self.zscale / 2)
                zmax = self._location[i][2] + round(
                    self.depth * self.zscale / 2)

            ###########################################################

            # check if they violate image boundary and fix it
            if xmin < 0:
                xmin = 0
                screen_xmin = screen_xmax - len(
                    np.arange(xmin, xmax, self.xscale))
            if ymin < 0:
                ymin = 0
                screen_ymin = screen_ymax - len(
                    np.arange(ymin, ymax, self.yscale))
            if zmin < 0:
                zmin = 0
                screen_zmin = screen_zmax - len(
                    np.arange(zmin, zmax, self.zscale))
            if xmax > self._image_dims[0]:
                xmax = self._image_dims[0]
                screen_xmax = screen_xmin + len(
                    np.arange(xmin, xmax, self.xscale))
            if ymax > self._image_dims[1]:
                ymax = self._image_dims[1]
                screen_ymax = screen_ymin + len(
                    np.arange(ymin, ymax, self.yscale))
            if zmax > self._image_dims[2]:
                zmax = self._image_dims[2]
                screen_zmax = screen_zmin + len(
                    np.arange(zmin, zmax, self.zscale))

            # crop image data to update what network sees
            # image coordinate system becomes screen coordinates
            # scale can be thought of as a stride
            screen[i, screen_xmin:screen_xmax, screen_ymin:screen_ymax,
                   screen_zmin:screen_zmax, ] = self._image[i].data[
                       xmin:xmax:self.xscale, ymin:ymax:self.yscale,
                       zmin:zmax:self.zscale, ]

            ###########################################################
            # update rectangle limits from input image coordinates
            # this is what the network sees
            self.rectangle[i] = Rectangle(xmin, xmax, ymin, ymax, zmin, zmax)

        return screen

    def get_plane(self, z=0, agent=0):
        return self._image[agent].data[:, :, z]

    def _calc_reward(self, current_loc, next_loc, agent):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """

        curr_dist = self.calcDistance(current_loc, self._target_loc[agent],
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc[agent],
                                      self.spacing)
        dist = curr_dist - next_dist

        return dist

    @property
    def _oscillate(self):
        """ Return True if the agent is stuck and oscillating
        """
        counter = []
        freq = []
        for i in range(0, self.agents):
            counter.append(Counter(self._loc_history[i]))
            freq.append(counter[i].most_common())

            if freq[i][0][0] == (0, 0, 0):
                # history is still mostly prefill; a second entry is needed
                if len(freq[i]) < 2:
                    return False
                if freq[i][1][1] > 3:
                    return True
                else:
                    return False
            elif freq[i][0][1] > 3:
                return True

    def get_action_meanings(self):
        """ return array of integers for actions"""
        ACTION_MEANING = {
            1: "UP",  # MOVE Z+
            2: "FORWARD",  # MOVE Y+
            3: "RIGHT",  # MOVE X+
            4: "LEFT",  # MOVE X-
            5: "BACKWARD",  # MOVE Y-
            6: "DOWN",  # MOVE Z-
        }
        # self.actions is an int (action_space.n), so iterate over the keys
        return [ACTION_MEANING[i] for i in range(1, self.actions + 1)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        # one independent counter per agent; [StatCounter()] * n would alias
        # a single counter across all agents
        self.num_success = [StatCounter() for _ in range(self.agents)]

    def display(self, return_rgb_array=False):
        # pass
        for i in range(0, self.agents):
            # get dimensions
            current_point = self._location[i]
            target_point = None
            if self.task != "play":
                target_point = self._target_loc[i]
            # print("_location", self._location)
            # print("_target_loc", self._target_loc)
            # print("current_point", current_point)
            # print("target_point", target_point)
            # get image and convert it to pyglet
            plane = self.get_plane(current_point[2], agent=i)  # z-plane
            # plane = np.squeeze(self._current_state()[:,:,13])
            img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB)  # convert to RGB
            # rescale image
            # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
            scale_x = 2
            scale_y = 2
            #
            img = cv2.resize(
                img,
                (int(scale_x * img.shape[1]), int(scale_y * img.shape[0])),
                interpolation=cv2.INTER_LINEAR,
            )
            # create the viewer once, if none is open yet
            if (not self.viewer) and self.viz:
                from viewer import SimpleImageViewer

                self.viewer = SimpleImageViewer(arr=img,
                                                scale_x=1,
                                                scale_y=1,
                                                filepath=self.filepath[i] +
                                                str(i))
                self.gif_buffer = []
            # display image
            self.viewer.draw_image(img)
            # draw current point
            self.viewer.draw_circle(
                radius=scale_x * 1,
                pos_x=scale_x * current_point[0],
                pos_y=scale_y * current_point[1],
                color=(0.0, 0.0, 1.0, 1.0),
            )
            # draw a box around the agent - what the network sees ROI
            self.viewer.draw_rect(
                scale_x * self.rectangle[i].xmin,
                scale_y * self.rectangle[i].ymin,
                scale_x * self.rectangle[i].xmax,
                scale_y * self.rectangle[i].ymax,
            )
            self.viewer.display_text(
                "Agent " + str(i),
                color=(204, 204, 0, 255),
                x=scale_x * self.rectangle[i].xmin - 15,
                y=scale_y * self.rectangle[i].ymin,
            )
            # display info
            text = "Spacing " + str(self.xscale)
            self.viewer.display_text(text,
                                     color=(204, 204, 0, 255),
                                     x=10,
                                     y=self._image_dims[1] - 80)

            # ---------------------------------------------------------------------

            if self.task != "play":
                # draw a transparent circle around target point with variable radius
                # based on the difference z-direction
                diff_z = scale_x * abs(current_point[2] - target_point[2])
                self.viewer.draw_circle(
                    radius=diff_z,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 0.2),
                )
                # draw target point
                self.viewer.draw_circle(
                    radius=scale_x * 1,
                    pos_x=scale_x * target_point[0],
                    pos_y=scale_y * target_point[1],
                    color=(1.0, 0.0, 0.0, 1.0),
                )
                # display info
                color = (0, 204, 0, 255) if self.reward[i] > 0 else (204, 0, 0,
                                                                     255)
                text = "Error " + str(round(self.cur_dist[i], 3)) + "mm"
                self.viewer.display_text(text, color=color, x=10, y=20)

            # ---------------------------------------------------------------------

            # render and wait (viz) time between frames
            self.viewer.render()
            # time.sleep(self.viz)
            # save gif
            if self.saveGif:
                image_data = (pyglet.image.get_buffer_manager().
                              get_color_buffer().get_image_data())
                data = image_data.get_data("RGB", image_data.width * 3)
                arr = np.array(bytearray(data)).astype("uint8")
                arr = np.flip(
                    np.reshape(arr, (image_data.height, image_data.width, -1)),
                    0)
                im = Image.fromarray(arr)
                self.gif_buffer.append(im)

                if not self.terminal[i]:
                    gifname = self.filepath[0] + ".gif"
                    self.viewer.saveGif(gifname,
                                        arr=self.gif_buffer,
                                        duration=self.viz)
            if self.saveVideo:
                dirname = "tmp_video"
                if self.cnt <= 1:
                    if os.path.isdir(dirname):
                        logger.warn(
                            """Log directory {} exists! Use 'd' to delete it. """
                            .format(dirname))
                        act = (input("select action: d (delete) / q (quit): ").
                               lower().strip())
                        if act == "d":
                            shutil.rmtree(dirname, ignore_errors=True)
                        else:
                            raise OSError(
                                "Directory {} exists!".format(dirname))
                    os.mkdir(dirname)

                frame = dirname + "/" + "%04d" % self.cnt + ".png"
                pyglet.image.get_buffer_manager().get_color_buffer().save(
                    frame)
                if self.terminal[i]:
                    resolution = (str(3 * self.viewer.img_width) + "x" +
                                  str(3 * self.viewer.img_height))
                    save_cmd = [
                        "ffmpeg",
                        "-f",
                        "image2",
                        "-framerate",
                        "30",
                        "-pattern_type",
                        "sequence",
                        "-start_number",
                        "0",
                        "-r",
                        "6",
                        "-i",
                        dirname + "/%04d.png",
                        "-s",
                        resolution,
                        "-vcodec",
                        "libx264",
                        "-b:v",
                        "2567k",
                        self.filepath[i] + ".mp4",
                    ]
                    subprocess.check_output(save_cmd)
                    shutil.rmtree(dirname, ignore_errors=True)