Beispiel #1
0
 def _init(self):
     """Initialise per-agent state and start the first episode.

     Sets the episode counter to -1, creates ``self._rng`` and
     ``self.rwd_counter``, resets statistics, and optionally builds a
     ``MemorySaver``. ``save_dir`` is always popped from ``self._kwargs``;
     ``max_save_item`` (default 3) and ``min_save_score`` (default None) are
     popped only when ``save_dir`` is given.
     """
     logger.info("[{}]: agent init, isTrain={}".format(self._agentIdent, self._isTrain))
     self._episodeCount = -1
     from tensorpack.utils.utils import get_rng
     from tensorpack.utils.stats import StatCounter
     self._rng = get_rng(self)
     self.reset_stat()
     self.rwd_counter = StatCounter()

     self._memorySaver = None
     target_dir = self._kwargs.pop('save_dir', None)
     if target_dir is not None:
         # pop in the same order as the constructor arguments
         max_items = self._kwargs.pop('max_save_item', 3)
         min_score = self._kwargs.pop('min_save_score', None)
         self._memorySaver = MemorySaver(target_dir, max_items, min_score)
     self.restart_episode()
Beispiel #2
0
    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player. Must expose
                ``action_space.n`` and ``reset()``.
            state_shape (tuple): h, w, c
            batch_size (int): stored as ``self.batch_size``; presumably the
                number of transitions per sampled training batch.
            memory_size (int): capacity argument passed to ``ReplayMemory``.
            init_memory_size (int-like): converted with ``int()`` and stored;
                presumably the number of transitions collected before
                training starts (consumed elsewhere).
            init_exploration (float): initial value of ``self.exploration``.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
        """
        assert len(state_shape) == 3, state_shape

        # Store the constructor arguments explicitly. The former
        # ``for k, v in locals().items(): setattr(self, k, v)`` trick can raise
        # "dictionary changed size during iteration" on CPython < 3.13 while a
        # sys.settrace-based tool (e.g. pdb) is active: every line event
        # re-syncs the frame's locals dict and inserts ``k``/``v`` into it.
        self.predictor_io_names = predictor_io_names
        self.player = player
        self.state_shape = state_shape
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.init_memory_size = int(init_memory_size)
        self.init_exploration = init_exploration
        self.update_frequency = update_frequency
        self.history_len = history_len

        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()
    def __init__(self,
                 directory=None,
                 viz=False,
                 task=False,
                 files_list=None,
                 screen_dims=(27, 27, 27),
                 history_length=20,
                 multiscale=True,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False):
        """Single-agent landmark environment.

        :param directory: not used by this constructor (kept for API
            compatibility).
        :param viz: visualization
            set to 0/False to disable
            set to +ve number to be the delay between frames to show
            set to a string naming an existing directory; it is only
            validated here and visualization is then disabled (viz=0)
        :param task: run mode; when it equals 'play' the data loader is
            built with ``returnLandmarks=False``, otherwise with ``True``.
        :param files_list: forwarded to ``filesListBrainMRLandmark``.
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn, (width, height) or (width, height, depth) - defaults
            (27,27,27)
        :param history_length: number of past locations / Q-values kept to
            detect oscillations.
        :param multiscale: enable the multi-scale agent.
        :param max_num_frames: maximum number of frames (steps) per episode.
        :param saveGif: save the display as a gif.
        :param saveVideo: save the display as a video.
        """
        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            # viewer / gif buffer only exist when visualization is enabled
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [(0, ) * self.dims] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0, 0, 0)
        # add your data loader here; landmarks are withheld in 'play' mode
        self.files = filesListBrainMRLandmark(
            files_list, returnLandmarks=(self.task != 'play'))

        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()
Beispiel #4
0
 def reset_state(self):
     """Install a freshly seeded ``self.rng`` and then run the base-class reset."""
     fresh_rng = get_rng(self)
     self.rng = fresh_rng
     super(lymphomaBase, self).reset_state()
Beispiel #5
0
 def __init__(self, pipe_c2s, pipe_s2c, model):
     """Keep the model and create a bounded work queue plus a private RNG.

     The queue holds at most ``BATCH_SIZE * 8 * 2`` items.
     """
     super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
     from tensorpack.utils.utils import get_rng
     self.M = model
     capacity = BATCH_SIZE * 8 * 2
     self.queue = queue.Queue(maxsize=capacity)
     self._rng = get_rng(self)
Beispiel #6
0
 def __init__(self):
     """Give this instance its own random number generator (``self.rng``)."""
     rng = get_rng(self)
     self.rng = rng
Beispiel #7
0
 def __init__(self, pipe_c2s, pipe_s2c, model):
     """Keep the model and create a bounded work queue plus a private RNG.

     The queue holds at most ``BATCH_SIZE * 8 * 2`` items.
     """
     super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
     from tensorpack.utils.utils import get_rng
     self.M = model
     capacity = BATCH_SIZE * 8 * 2
     self.queue = queue.Queue(maxsize=capacity)
     self._rng = get_rng(self)
Beispiel #8
0
    def __init__(self,
                 data_dir='/home/Pearl/quantm/QuadRand/data/SNEMI/',
                 viz=0,
                 frame_skip=8,
                 nullop_start=30,
                 max_num_frames=0):
        """
        Args:
            data_dir: dataset root. Images are read from ``images/*.tif`` and
                labels from ``labels/*.tif`` beneath it, in natural sort order.
                NOTE(review): the default is a machine-specific absolute path.
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string naming an existing directory; it is only
                validated here and visualization is then disabled (viz=0).
            nullop_start: start with random number of null ops.
            max_num_frames: accepted for signature compatibility but not used
                by this constructor.
        """
        super(ImagePlayer, self).__init__()

        # Read the image and label stacks.
        import glob
        import skimage.io
        import cv2
        from natsort import natsorted
        self.imageFiles = natsorted(
            glob.glob(os.path.join(data_dir, 'images/*.tif')))
        self.labelFiles = natsorted(
            glob.glob(os.path.join(data_dir, 'labels/*.tif')))

        print(self.imageFiles)
        print(self.labelFiles)
        self.images = [skimage.io.imread(path) for path in self.imageFiles]
        # NOTE(review): original note says a label stack is 3x100x1024x1024 — confirm
        self.labels = [skimage.io.imread(path) for path in self.labelFiles]
        self.image = None
        self.label = None
        self.estim = None
        self.heap = None
        self.root = None

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _SAFE_LOCK:
            self.rng = get_rng(self)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(data_dir)
                cv2.namedWindow(self.windowname)

        self.width = DIMX
        self.height = DIMY
        # Seven discrete action ids 0..6.
        # NOTE(review): the original comment mapped them to "bg, fg, split" — confirm semantics.
        self.actions = list(range(7))
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width, 1),
                                            dtype=np.uint8)
        self._restart_episode()
Beispiel #9
0
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom (a bare file name is resolved through
                ``get_dataset_path('atari_rom', ...)``).
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with random number of null ops.
            live_lost_as_eoe: consider lost of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                # All ALE keys/values here are bytes (and the ROM path is
                # encoded below); a Python 3 ``str`` is rejected by a ctypes
                # ``c_char_p`` argument, so encode the directory as well.
                self.ale.setString(b'record_screen_dir', viz.encode('utf-8'))
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width),
                                            dtype=np.uint8)
        self._restart_episode()
 def reset_state(self):
     """Re-seed ``self.rng`` and print the ``is_training`` flag.

     Uses the call form of ``print``: the former ``print self.is_training``
     statement is a SyntaxError on Python 3, whereas single-argument
     ``print(...)`` produces the same output on Python 2 and 3.
     """
     self.rng = get_rng(self)
     print(self.is_training)
Beispiel #11
0
    def __init__(self,
                 directory=None,
                 viz=False,
                 task=False,
                 files_list=None,
                 file_type="brain",
                 landmark_ids=None,
                 screen_dims=(27, 27, 27),
                 history_length=28,
                 multiscale=True,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False,
                 agents=1,
                 oscillations_allowed=4,
                 logger=None):
        """Multi-agent landmark environment.

        :param directory: not used by this constructor (kept for API
            compatibility).
        :param viz: visualization
            set to 0/False to disable
            set to +ve number to be the delay between frames to show
            set to a string naming an existing directory; it is only
            validated here and visualization is then disabled (viz=0)
        :param task: run mode; landmarks are requested from the loader
            unless it equals 'play'.
        :param files_list: forwarded to the data loader.
        :param file_type: one of "brain", "cardiac" or "fetal"; selects the
            data loader.
        :param landmark_ids: forwarded to ``self.files.sample_circular``.
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn, (width, height) or (width, height, depth) - defaults
            (27,27,27)
        :param history_length: number of past locations / Q-values kept per
            agent to detect oscillations.
        :param multiscale: enable the multi-scale agent.
        :param max_num_frames: maximum number of frames (steps) per episode.
        :param saveGif: save the display as a gif.
        :param saveVideo: save the display as a video.
        :param agents: number of agents.
        :param oscillations_allowed: stored as ``self.oscillations_allowed``.
        :param logger: stored as ``self.logger``.
        :raises ValueError: if ``file_type`` is not recognised.
        """
        super(MedicalPlayer, self).__init__()
        self.agents = agents
        self.oscillations_allowed = oscillations_allowed
        self.logger = logger
        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []
        # One independent score counter per agent. ``[StatCounter()] * n``
        # would alias a single mutable counter across all agents.
        self.current_episode_score = [StatCounter() for _ in range(self.agents)]
        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [[(0, ) * self.dims
                              for _ in range(self._history_length)]
                             for _ in range(self.agents)]
        self._qvalues_history = [[(0, ) * self.actions
                                  for _ in range(self._history_length)]
                                 for _ in range(self.agents)]
        # initialize rectangle limits from input image coordinates
        # (one object per agent, never a shared reference)
        self.rectangle = [Rectangle(0, 0, 0, 0, 0, 0)
                          for _ in range(int(self.agents))]

        returnLandmarks = (self.task != 'play')

        # add your data loader here
        if file_type == "brain":
            self.files = filesListBrainMRLandmark(files_list, returnLandmarks,
                                                  self.agents)
        elif file_type == "cardiac":
            self.files = filesListCardioLandmark(files_list, returnLandmarks,
                                                 self.agents)
        elif file_type == "fetal":
            self.files = filesListFetalUSLandmark(files_list, returnLandmarks,
                                                  self.agents)
        else:
            # Fail early instead of an AttributeError on ``self.files`` below.
            raise ValueError(
                "unknown file_type {!r}; expected 'brain', 'cardiac' or "
                "'fetal'".format(file_type))

        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular(landmark_ids)
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()
Beispiel #12
0
 def __init__(self, num):
     """Discrete action space over ``num`` actions with its own RNG."""
     super(DiscreteActionSpace, self).__init__()
     self.num = num
     generator = get_rng(self)
     self.rng = generator
Beispiel #13
0
    def __init__(self,
                 directory=None,
                 files_list=None,
                 viz=False,
                 train=False,
                 screen_dims=(27, 27, 27),
                 spacing=(1, 1, 1),
                 nullop_start=30,
                 history_length=30,
                 max_num_frames=0,
                 saveGif=False,
                 saveVideo=False):
        """Plane-detection environment.

        :param directory: forwarded to ``filesListBrainMRPlane``.
        :param files_list: forwarded to ``filesListBrainMRPlane``.
        :param viz: visualization
            set to 0/False to disable
            set to +ve number to be the delay between frames to show
            set to a string naming an existing directory; it is only
            validated here and visualization is then disabled (viz=0)
        :param train: training flag. When False, ``self.csvfile`` is
            (re)created with a "filename,dist_error,angle_error" header row.
        :param screen_dims: plane size, (width, height) or
            (width, height, depth) - defaults (27,27,27)
        :param spacing: initial plane sampling spacing (stored as an array).
        :param nullop_start: accepted but not used by this constructor.
        :param history_length: length of the history buffers used to detect
            oscillations.
        :param max_num_frames: maximum number of frames (steps) per episode.
        :param saveGif: save the display as a gif.
        :param saveVideo: save the display as a video.
        """
        super(MedicalPlayer, self).__init__()

        self.reset_stat()
        self._supervised = False
        self._init_action_angle_step = 8
        self._init_action_dist_step = 4

        # save results in csv file (header written only when evaluating)
        self.csvfile = 'dummy.csv'
        if not train:
            with open(self.csvfile, 'w') as outcsv:
                fields = ["filename", "dist_error", "angle_error"]
                writer = csv.writer(outcsv)
                writer.writerow(fields)

        # read files from directory - add your data loader here
        self.files = filesListBrainMRPlane(directory, files_list)

        # prepare file sampler
        self.sampled_files = self.files.sample_circular()
        self.filepath = None
        # maximum number of frames (steps) per episodes
        self.cnt = 0
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.train = train
        # image dimension (2D/3D)
        self._plane_size = screen_dims
        self.dims = len(self._plane_size)
        if self.dims == 2:
            self.width, self.height = self._plane_size
        else:
            self.width, self.height, self.depth = self._plane_size
        # plane sampling spacings
        self.init_spacing = np.array(spacing)
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(8)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=self._plane_size)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # circular buffer to store plane parameters history [4,history_length]
        self._plane_history = deque(maxlen=self._history_length)
        self._bestq_history = deque(maxlen=self._history_length)
        self._dist_history = deque(maxlen=self._history_length)
        self._dist_history_params = deque(maxlen=self._history_length)
        self._dist_supervised_history = deque(maxlen=self._history_length)
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._qvalues = [0] * self.actions

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        self._restart_episode()
Beispiel #14
0
    def __init__(self, rom_file, viz=0,
                 frame_skip=4, nullop_start=30,
                 live_lost_as_eoe=True, max_num_frames=0):
        """
        Args:
            rom_file: path to the rom (a bare file name is resolved through
                ``get_dataset_path('atari_rom', ...)``).
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with random number of null ops.
            live_lost_as_eoe: consider lost of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                # All ALE keys/values here are bytes (and the ROM path is
                # encoded below); a Python 3 ``str`` is rejected by a ctypes
                # ``c_char_p`` argument, so encode the directory as well.
                self.ale.setString(b'record_screen_dir', viz.encode('utf-8'))
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width))
        self._restart_episode()
Beispiel #15
0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info


if __name__ == '__main__':
    import sys

    # Smoke test: take uniformly random actions forever with on-screen display,
    # printing the reward of every step and the info dict at episode end.
    env = AtariPlayer(sys.argv[1], viz=0.03)
    n_actions = env.action_space.n
    rng = get_rng(n_actions)
    while True:
        action = rng.choice(range(n_actions))
        _, step_reward, done, step_info = env.step(action)
        if done:
            print(step_info)
            env.reset()
        print("Reward:", step_reward)
Beispiel #16
0
 def reset_state(self):
     """Replace ``self.rng`` with a newly created generator."""
     new_rng = get_rng(self)
     self.rng = new_rng
Beispiel #17
0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info


if __name__ == '__main__':
    import sys

    # Smoke test: random play with a 0.03 s display delay; runs until killed.
    game = AtariPlayer(sys.argv[1], viz=0.03)
    action_count = game.action_space.n
    rng = get_rng(action_count)
    while True:
        chosen = rng.choice(range(action_count))
        _, gained, finished, details = game.step(chosen)
        if finished:
            print(details)
            game.reset()
        print("Reward:", gained)
Beispiel #18
0
    for i in xrange(x):
        for j in xrange(y):
            im = image[i:i + xo, j:j + yo]
            if im.shape == obj.shape:
                if method == 'absolute':
                    diff = float(np.sum(im != obj)) / obj.size
                elif method == 'relative':
                    diff = np.sum(abs(im - obj) / 255.0) / obj.size
            if diff < max_diff:
                answers.append((i, j))
    return answers


if __name__ == '__main__':
    # Ad-hoc driver script: builds a player via get_player() and runs the
    # routine selected by ``task``. Only the 'save' branch is visible here.
    player = get_player()
    rng = get_rng()  # only referenced by the commented-out line below

    task = 'detect'  # NOTE(review): no 'detect' branch is visible here — confirm it exists

    import matplotlib.pyplot as plt  # NOTE(review): unused in the visible code
    if task == 'save':
        # Step the player 50 times with action 1 and save each state with
        # np.save under 'freeway/<ENV_NAME>_<i>'.
        for i in xrange(50):  # NOTE(review): xrange exists only on Python 2
            #random_action = rng.choice(range(NUM_ACTIONS))
            player.action(1)
            # Original image: (210 * 160 * 3) by print player.current_state().shape
            # print player.current_state().shape
            #if i == 25:
            #    player.restart_episode()
            file_name = 'freeway/' + ENV_NAME + '_' + str(i)
            X = player.current_state()
            np.save(file_name, X)
Beispiel #19
0
 def __init__(self, wmin, hmin, wmax=None, hmax=None):
     """Create ``self.rng`` first, then pass the size bounds to the parent."""
     generator = get_rng()
     self.rng = generator
     super().__init__(wmin, hmin, wmax, hmax)
Beispiel #20
0
    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 observation_dims=(27, 27, 27), multiscale=False,  # FIXME automatic dimensions
                 max_num_frames=0, saveGif=False, saveVideo=False):  # FIXME hardcoded max num frames!
        """Cube environment backed by ``FilesListCubeNPY`` data.

        :param directory: forwarded to ``FilesListCubeNPY``.
        :param viz: visualization setting, stored unchanged as ``self.viz``.
        :param task: run mode, stored as ``self.task``.
        :param files_list: forwarded to ``FilesListCubeNPY``.
        :param observation_dims: shape of the frame cropped from the image to
            feed it to dqn, (width, height) or (width, height, depth) -
            defaults (27,27,27)
        :param multiscale: currently ignored; ``self.multiscale`` is forced
            to False.
        :param max_num_frames: maximum number of frames (steps) per episode.
        :param saveGif: save the display as a gif.
        :param saveVideo: save the display as a video.
        :raises ValueError: if ``observation_dims`` has neither 2 nor 3 entries.
        """
        super(Brain_Env, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.task = task
        # image dimension (2D/3D)
        self.observation_dims = observation_dims
        self.dims = len(self.observation_dims)
        # multi-scale agent
        # FIXME force multiscale false for now (the ``multiscale`` argument is ignored)
        self.multiscale = False

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = observation_dims
        elif self.dims == 3:
            self.width, self.height, self.depth = observation_dims
        else:
            raise ValueError(
                "observation_dims must have 2 or 3 entries, got {!r}".format(
                    observation_dims))

        with THREAD_LOCKER:
            self.rng = get_rng(self)
        self.viz = viz

        print("viz {} gif {} video {}".format(self.viz, self.saveGif, self.saveVideo))

        # get action space and minimal action set
        self.action_space = spaces.Discrete(6)  # change number actions here
        self.actions = self.action_space.n
        # Bounds are [-1, 1]; uint8 cannot represent -1 (it wraps to 255 and
        # yields an empty box), so declare a float dtype matching the bounds.
        self.observation_space = spaces.Box(low=-1., high=1.,
                                            shape=self.observation_dims,
                                            dtype=np.float32)
        # TODO initialize _observation_bounds limits from input image coordinates
        # -1 to compensate for 0 indexing
        # NOTE(review): indexes observation_dims[2], so a 2-D observation_dims
        # raises IndexError here — confirm 2-D is actually supported.
        self._observation_bounds = ObservationBounds(0,
                                                     self.observation_dims[0] - 1,
                                                     0,
                                                     self.observation_dims[1] - 1,
                                                     0,
                                                     self.observation_dims[2] - 1)

        self.files = FilesListCubeNPY(directory, files_list)

        # self.files = filesListFetalUSLandmark(directory,files_list)
        # self.files = filesListCardioMRLandmark(directory,files_list)
        # prepare file sampler
        self.filepath = None
        self.file_sampler = self.files.sample_circular()  # returns generator
        # reset buffer, terminal, counters, and init new_random_game
        # we put this here so that init_player in DQN.py doesn't try to update_history
        self._clear_history()  # init arrays
        self._restart_episode()
        assert (np.shape(self._state) == self.observation_dims)
        # test jaccard
        assert np.isclose(jaccard(self.original_state, self.original_state)[0], 1)