class Unity3DPlayer(RLEnvironment):
    """Discrete-action RL environment backed by a ``Unity3DEnvironment`` server.

    Each discrete action index is translated through :attr:`ACTION_TABLE` into
    a 2-element command for the Unity environment and repeated ``skip`` times.
    """

    # Discrete action index -> command sent to the Unity environment.
    ACTION_TABLE = [(0.5, 0.0),    # Forward
                    (-0.5, 0.0),   # Backward
                    (0.5, 1.0),    # Forward-Right
                    (-0.5, 1.0),   # Backward-Right
                    (0.5, -1.0),   # Forward-Left
                    (-0.5, -1.0)]  # Backward-Left

    def __init__(self, connection, skip=1, dumpdir=None, viz=False, auto_restart=True):
        """
        Args:
            connection: server address passed to ``Unity3DEnvironment``. When
                ``None`` no environment is created here, so ``self.gymenv``
                must already exist, because :meth:`restart_episode` is called
                at the end of construction.
            skip (int): number of environment steps each action is repeated
                for. Must be >= 1.
            dumpdir: stored as ``self.use_dir``.
            viz: falsy to disable rendering, otherwise the delay in seconds
                slept after each render in :meth:`current_state`.
            auto_restart (bool): restart automatically when an episode ends.

        Raises:
            ValueError: if ``skip`` is smaller than 1 (no step would ever be
                taken, so :meth:`action` could not produce a reward).
        """
        if skip < 1:
            raise ValueError("skip must be >= 1, got {}".format(skip))
        if connection is not None:
            with _ENV_LOCK:
                self.gymenv = Unity3DEnvironment(server_address=connection)
        self.use_dir = dumpdir
        self.skip = skip
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self.restart_episode()
        self.auto_restart = auto_restart
        self.viz = viz

    def restart_episode(self):
        """Reset the reward counter and the environment; store the new observation."""
        self.rwd_counter.reset()
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        """Record the accumulated reward of the finished episode in ``stats['score']``."""
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        """Return the latest observation, rendering first when ``viz`` is set."""
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        return self._ob

    def action(self, act):
        """Apply discrete action ``act`` for up to ``self.skip`` environment steps.

        Rewards of all executed steps are summed, so no reward is lost when
        frames are skipped. The repetition stops early when the environment
        reports the episode is over, or when a single step's reward is
        <= -1.0, which is treated as a terminal penalty.

        Returns:
            tuple: ``(total_reward, isOver)``.
        """
        env_act = self.ACTION_TABLE[act]
        total_reward = 0.0
        isOver = False
        for _ in range(self.skip):
            self._ob, r, isOver, _info = self.gymenv.step(env_act)
            total_reward += r
            # a per-step penalty of -1 or worse ends the episode
            if r <= -1.0:
                isOver = True
            if isOver:
                break
        self.rwd_counter.feed(total_reward)
        if isOver:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return total_reward, isOver

    def get_action_space(self):
        """Return a discrete action space with one entry per ``ACTION_TABLE`` row."""
        return DiscreteActionSpace(len(self.ACTION_TABLE))

    def close(self):
        """Close the underlying Unity environment."""
        self.gymenv.close()
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self, predictor_io_names, player, state_shape, batch_size,
                 memory_size, init_memory_size, init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (gym.Env): the player.
            state_shape (tuple): h, w, c
            batch_size (int): number of transitions in each yielded batch.
            memory_size (int): capacity of the replay memory.
            init_memory_size (int): number of transitions collected before
                training starts.
            init_exploration (float): initial epsilon for epsilon-greedy.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
        """
        assert len(state_shape) == 3, state_shape
        init_memory_size = int(init_memory_size)

        # expose every constructor argument as an attribute of the same name
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_simulator_thread(self):
        """Create (but do not start) the thread that keeps filling the memory.

        Every notification taken from ``_populate_job_queue`` (one is put after
        each yielded batch) triggers ``update_frequency`` new transitions.
        """
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        """Fill memory up to ``init_memory_size`` transitions, then mark it ready."""
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    # quickly fill the memory for debug
    def _fake_init_memory(self):
        """Debug-only fill: collect 5 real transitions, then duplicate one of them."""
        from copy import deepcopy
        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < 5:
                self._populate_exp()
                pbar.update()
            while len(self.mem) < self.init_memory_size:
                # NOTE(review): relies on the private ``_hist`` attribute of ReplayMemory
                self.mem.append(deepcopy(self.mem._hist[0]))
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
            history = self.mem.recent_state()
            history.append(old_s)
            history = np.concatenate(history, axis=-1)  # H,W,HistxC
            history = np.expand_dims(history, axis=0)

            # assume batched network
            q_values = self.predictor(history)[0][0]  # this is the bottleneck
            act = np.argmax(q_values)
        self._current_ob, reward, isOver, info = self.player.step(act)
        self._current_game_score.feed(reward)

        if isOver:
            # Atari envs report remaining lives: only record a score when the
            # whole game is over (not when a single episode/life is over).
            # Envs without 'ale.lives' end the game with every episode.
            if info.get('ale.lives', 0) == 0:
                self._player_scores.feed(self._current_game_score.sum)
                self._current_game_score.reset()
            # the next transition must start from the new episode's observation,
            # not from the terminal observation of the episode that just ended
            self._current_ob = self.player.reset()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def _debug_sample(self, sample):
        """Show a sampled transition with OpenCV when its reward or isOver is set.

        ``sample`` is indexed as (state, reward, action, isOver), the same order
        ``_process_batch`` reads.
        """
        import cv2

        def view_state(comb_state):
            state = comb_state[:, :, :-1]
            next_state = comb_state[:, :, 1:]
            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
            r = np.concatenate([r, r2], axis=0)
            cv2.imshow("state", r)
            cv2.waitKey()
        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
        if sample[1] or sample[3]:
            view_state(sample[0])

    def _process_batch(self, batch_exp):
        """Stack sampled (state, reward, action, isOver) tuples into typed arrays.

        Returns:
            list: ``[state, action, reward, isOver]``.
        """
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    # DataFlow method:
    def __iter__(self):
        """Yield batches forever once the memory has been initialized."""
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - self.history_len - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            # ask the simulator thread for more transitions
            self._populate_job_queue.put(1)

    # Callback methods:
    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        """Log mean/max of the finished game scores, then reset them."""
        v = self._player_scores
        try:
            mean_score, max_score = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean_score)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
class MedicalPlayer(gym.Env):
    """
    Class that provides a 2D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward.

    Locations, the cropped screen and target landmarks are 2D ``(x, y)``.
    """

    # action index -> (dx, dy) unit direction of one move; a move spans
    # ``self.action_step`` pixels. ``spaces.Discrete(4)`` produces indices
    # 0..3: indices 1..3 keep their original meaning and index 0 is the
    # BACKWARD move that used to be numbered 4 (4 is still accepted).
    _ACTION_DIRECTIONS = {
        0: (0, -1),  # BACKWARD Y-
        1: (0, 1),   # FORWARD  Y+
        2: (1, 0),   # RIGHT    X+
        3: (-1, 0),  # LEFT     X-
        4: (0, -1),  # BACKWARD Y- (legacy index)
    }

    _ACTION_MEANING = {
        0: "BACKWARD",  # MOVE Y-
        1: "FORWARD",   # MOVE Y+
        2: "RIGHT",     # MOVE X+
        3: "LEFT",      # MOVE X-
        4: "BACKWARD",  # MOVE Y-
    }

    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 screen_dims=(27, 27), history_length=20, multiscale=True,
                 max_num_frames=0, saveGif=False, saveVideo=False):
        """
        :param directory: accepted for interface compatibility; not used here.
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            set to a string to be the directory for storing frames
        :param task: 'train', 'play' or another mode value; 'play' loads files
            without landmarks and skips reward/distance computation.
        :param files_list: passed to ``filesListLVCardiacMRLandmark``.
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (w, h) - defaults (27, 27)
        :param history_length: number of past locations / q-values kept to
            detect oscillation.
        :param multiscale: start with a coarse sampling stride and refine it
            each time the agent oscillates.
        :param max_num_frames: maximum number of frames (steps) per episode.
        :param saveGif: save the visualization of each episode as a gif.
        :param saveVideo: save the visualization of each episode as an mp4.
        """
        super(MedicalPlayer, self).__init__()

        # inits stat counters
        self.reset_stat()

        # counter to limit number of steps per episodes
        self.cnt = 0
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # task mode ('train', 'play', ...)
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale

        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)

        # visualization setup
        if isinstance(viz, six.string_types):  # check if viz is a string
            assert os.path.isdir(viz), viz
            viz = 0
        if isinstance(viz, int):  # check if viz is an int number
            viz = float(viz)
        self.viz = viz
        if self.viz and isinstance(self.viz, float):  # check if viz is a float number
            self.viewer = None
            self.gif_buffer = []

        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()

        # get action space and minimal action set
        self.action_space = spaces.Discrete(4)  # change number actions here
        self.actions = self.action_space.n
        # gym environments are expected to declare an observation space
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)

        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0)

        # data loader: 'play' mode has no ground-truth landmarks
        if self.task == 'play':
            self.files = filesListLVCardiacMRLandmark(files_list, returnLandmarks=False)
            logger.debug("self.files: %s", self.files)
        else:
            self.files = filesListLVCardiacMRLandmark(files_list, returnLandmarks=True)

        # prepare file sampler
        self.filepath = None
        self.sampled_files = self.files.sample_circular()

        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def reset(self):
        """Start a new episode and return its first observation."""
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """
        restart current episode
        """
        self.terminal = False
        self.reward = 0
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self._loc_history = [(0,) * self.dims] * self._history_length
        # list of q-value lists
        self._qvalues_history = [(0,) * self.actions] * self._history_length
        self.new_random_game()

    def new_random_game(self):
        """
        1. load image, 2. set dimensions, 3. randomize start point,
        4. init _screen, qvals, 5. calc distance to goal
        """
        self.terminal = False
        self.viewer = None

        # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        if self.multiscale:
            # brain settings (an alternative cardiac setting used
            # action_step = 6 with x/y scale 2)
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1

        # image volume size
        self._image_dims = self._image.dims
        # NOTE(review): relies on the image object being convertible by np.array
        self._image.data = np.array(self._image)
        logger.debug("image_dim: %s", self._image_dims)

        # select random starting point
        # add padding to avoid starting right on the border of the image
        if self.task == 'train':
            skip_thickness = (int(self._image_dims[0] / 5), int(self._image_dims[1] / 5))
        else:
            skip_thickness = (int(self._image_dims[0] / 4), int(self._image_dims[1] / 4))

        x = self.rng.randint(0 + skip_thickness[0], self._image_dims[0] - skip_thickness[0])
        y = self.rng.randint(0 + skip_thickness[1], self._image_dims[1] - skip_thickness[1])

        # locations are stored as (x, y)
        self._location = (x, y)
        self._start_location = (x, y)
        self._qvalues = [0, ] * self.actions
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = 0
        else:
            self.cur_dist = self.calcDistance(self._location, self._target_loc, self.spacing)

    def calcDistance(self, points1, points2, spacing=(1, 1)):
        """ calculate the distance between two points in mm"""
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)
        return np.linalg.norm(points1 - points2)

    def step(self, act, qvalues):
        """
        Move the agent one step and return the environment response.

        Args:
            act: action index, see ``_ACTION_DIRECTIONS`` (0..3 from the
                action space; 4 is also accepted as BACKWARD).
            qvalues: q-values of the current state; stored in the history that
                is used to pick a location when the agent oscillates.

        Returns:
            observation: cropped screen around the new location.
            reward (float): decrease in distance to the target, or -1 when the
                move would leave the image. Not updated in 'play' mode.
            done (bool): whether the episode terminated.
            info (dict): 'score', 'gameOver', 'distError', 'filenames'.

        Raises:
            ValueError: if ``act`` is not a known action index.
        """
        if act not in self._ACTION_DIRECTIONS:
            raise ValueError("Unknown action {}".format(act))

        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        go_out = False

        # discrete moves of ``action_step`` pixels along one image axis
        dx, dy = self._ACTION_DIRECTIONS[act]
        x, y = current_loc
        if dx:
            x = round(x + dx * self.action_step)
        if dy:
            y = round(y + dy * self.action_step)
        next_location = (x, y)

        # moves leaving the image are refused (and punished below): positive
        # moves must stay below the image size, negative moves above 0
        if ((dx > 0 and x >= self._image_dims[0]) or (dx < 0 and x <= 0) or
                (dy > 0 and y >= self._image_dims[1]) or (dy < 0 and y <= 0)):
            next_location = current_loc
            go_out = True

        # punish -1 reward if the agent tries to go out
        if self.task != 'play':
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)

        # update screen, reward, location, terminal
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        # NOTE(review): cur_dist still holds the distance of the previous
        # location at this point; it is only refreshed below.
        if self.task == 'train':
            if self.cur_dist <= 1:
                self.terminal = True
                self.num_success.feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            self.terminal = True

        # update history buffer with new location and qvalues
        if self.task != 'play':
            self.cur_dist = self.calcDistance(self._location, self._target_loc, self.spacing)
        self._update_history()

        # check whether the agent oscillates
        if self._oscillate:
            self._location = self.getBestLocation()
            self._screen = self._current_state()

            if self.task != 'play':
                self.cur_dist = self.calcDistance(self._location, self._target_loc, self.spacing)

            # multi-scale steps
            if self.multiscale:
                if self.xscale > 1:
                    self.xscale -= 1
                    self.yscale -= 1
                    self.action_step = int(self.action_step / 3)
                    self._clear_history()
                # terminate if scale is less than 1
                else:
                    self.terminal = True
                    if self.cur_dist <= 1:
                        self.num_success.feed(1)
            else:
                self.terminal = True
                if self.cur_dist <= 1:
                    self.num_success.feed(1)

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz:
                if isinstance(self.viz, float):
                    self.display()

        distance_error = self.cur_dist
        self.current_episode_score.feed(self.reward)

        info = {'score': self.current_episode_score.sum,
                'gameOver': self.terminal,
                'distError': distance_error,
                'filenames': self.filename}

        return self._current_state(), self.reward, self.terminal, info

    def getBestLocation(self):
        """Pick a location from the last 4 entries of the history.

        For each of the last 4 steps the maximum q-value is taken, and the
        location of the step whose maximum is the smallest is returned.
        NOTE(review): despite the "best" naming this uses argmin - confirm the
        intent.
        """
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._loc_history[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)
        best_idx = best_qvalues.argmin()
        return last_loc_history[best_idx]

    def _clear_history(self):
        """ clear history buffer with current state
        """
        self._loc_history = [(0,) * self.dims] * self._history_length
        self._qvalues_history = [(0,) * self.actions] * self._history_length

    def _update_history(self):
        """ update history buffer with current state
        """
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        self._loc_history[-1] = self._location
        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues

    def _current_state(self):
        """
        crop image data around current location to update what network sees.
        update rectangle

        :return: new state
        """
        # initialize screen with zeros - all background
        screen = np.zeros(self.screen_dims).astype('float32')

        # screen uses coordinate system relative to origin (0, 0)
        screen_xmin, screen_ymin = 0, 0
        screen_xmax, screen_ymax = self.screen_dims
        logger.debug("image_data: %s", self._image)

        # extract boundary locations using coordinate system relative to "global" image
        # width, height in terms of screen coord system
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))
        if xmax > self._image_dims[0]:
            xmax = self._image_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._image_dims[1]:
            ymax = self._image_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))

        # crop image data to update what network sees
        # image coordinate system becomes screen coordinates
        # scale can be thought of as a stride
        screen[screen_xmin:screen_xmax, screen_ymin:screen_ymax] = self._image.data[
            xmin:xmax:self.xscale, ymin:ymax:self.yscale]

        # update rectangle limits from input image coordinates
        # this is what the network sees
        self.rectangle = Rectangle(xmin, xmax, ymin, ymax)

        return screen

    def get_plane(self):
        """Return the image data (all rows and columns)."""
        return self._image.data[:, :]

    def _calc_reward(self, current_loc, next_loc):
        """ Calculate the new reward based on the decrease in euclidean distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc, self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc, self.spacing)
        return curr_dist - next_dist

    @property
    def _oscillate(self):
        """Return True if the agent is stuck and oscillating.

        The agent oscillates when one location occurs more than 3 times in
        the location history; the all-zero placeholder used to pre-fill the
        history is ignored.
        """
        placeholder = (0,) * self.dims
        for loc, count in Counter(self._loc_history).most_common():
            if loc != placeholder:
                return count > 3
        return False

    def get_action_meanings(self):
        """ return array of strings naming each action index of the action space"""
        return [self._ACTION_MEANING[i] for i in range(self.actions)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        """Render the image, the agent, its ROI and (outside 'play') the target.

        Optionally records gif frames and/or video frames.
        ``return_rgb_array`` is accepted but unused.
        """
        current_point = self._location
        target_point = self._target_loc
        # get the image plane
        plane = self.get_plane()
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 1
        scale_y = 1
        img = cv2.resize(plane,
                         (int(scale_x * plane.shape[1]), int(scale_y * plane.shape[0])),
                         interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to rgb
        # create a viewer if none is open yet
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []
        # display image
        self.viewer.draw_image(img)
        # draw current point
        self.viewer.draw_circle(radius=scale_x * 1,
                                pos_x=scale_x * current_point[0],
                                pos_y=scale_y * current_point[1],
                                color=(0.0, 0.0, 1.0, 1.0))
        # draw a box around the agent - what the network sees ROI
        self.viewer.draw_rect(scale_x * self.rectangle.xmin, scale_y * self.rectangle.ymin,
                              scale_x * self.rectangle.xmax, scale_y * self.rectangle.ymax)
        self.viewer.display_text('Agent ', color=(204, 204, 0, 255),
                                 x=self.rectangle.xmin - 15,
                                 y=self.rectangle.ymin)
        # display info
        text = 'Spacing ' + str(self.xscale)
        self.viewer.display_text(text, color=(204, 204, 0, 255),
                                 x=10, y=self._image_dims[1] - 80)

        if self.task != 'play':
            # draw target point (2D: locations have no z component)
            self.viewer.draw_circle(radius=scale_x * 1,
                                    pos_x=scale_x * target_point[0],
                                    pos_y=scale_y * target_point[1],
                                    color=(1.0, 0.0, 0.0, 1.0))
            # display info
            color = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
            text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
            self.viewer.display_text(text, color=color, x=10, y=20)

        # render the frame
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            # write the gif once, when the episode is over
            if self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.saveGif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)
        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warning("""Log directory {} exists! Use 'd' to delete it. """.format(dirname))
                    act = input("select action: d (delete) / q (quit): ").lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                resolution = str(3 * self.viewer.img_width) + 'x' + str(3 * self.viewer.img_height)
                save_cmd = ['ffmpeg', '-f', 'image2', '-framerate', '30',
                            '-pattern_type', 'sequence', '-start_number', '0', '-r',
                            '6', '-i', dirname + '/%04d.png', '-s', resolution,
                            '-vcodec', 'libx264', '-b:v', '2567k', self.filename + '.mp4']
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self, predictor_io_names, player, state_shape, num_actions,
                 batch_size, memory_size, init_memory_size, init_exploration,
                 update_frequency, encoding_file='../AutoEncoder/encoding.npy'):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (RLEnvironment): the player.
            state_shape: shape of one stored state, passed to ``ReplayMemory``.
            num_actions (int): size of the discrete action space.
            batch_size (int): number of transitions in each yielded batch.
            memory_size (int): capacity of the replay memory.
            init_memory_size (int): number of transitions collected before
                training starts.
            init_exploration (float): initial epsilon for epsilon-greedy.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            encoding_file (str): ``.npy`` file loaded with ``np.load``; its rows
                are looked up by action-space index to embed played cards.
        """
        init_memory_size = int(init_memory_size)

        # expose every constructor argument as an attribute of the same name
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = num_actions
        self.encoding = np.load(encoding_file)
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        self.player.prepare()
        self._current_ob = self.get_state()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_state(self):
        """Build the observation vector for the role being trained.

        Concatenates the ``Card.val2onehot60`` encoding of the current hand,
        ``player.get_state_prob()`` and the embeddings of the last two played
        card combinations.
        """
        def cards_char2embedding(cards_char):
            # find the action-space row matching these cards and return its embedding
            test = (action_space_onehot60 == Card.char2onehot60(cards_char))
            test = np.all(test, axis=1)
            target = np.where(test)[0]
            return self.encoding[target[0]]

        s = self.player.get_state_prob()
        s = np.concatenate([Card.val2onehot60(self.player.get_curr_handcards()), s])
        last_two_cards_char = self.player.get_last_two_cards()
        last_two_cards_char = [to_char(c) for c in last_two_cards_char]
        return np.concatenate([
            s,
            cards_char2embedding(last_two_cards_char[0]),
            cards_char2embedding(last_two_cards_char[1])
        ])

    def get_simulator_thread(self):
        """Create (but do not start) the thread that keeps filling the memory."""
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        """Fill memory up to ``init_memory_size`` transitions, then mark it ready."""
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration:
            act = self.rng.choice(range(self.num_actions))
        else:
            # restrict the greedy choice to legal moves for the current hand
            mask = get_mask(to_char(self.player.get_curr_handcards()), action_space,
                            to_char(self.player.get_last_outcards()))
            q_values = self.predictor(old_s[None, ...])[0][0]
            q_values[mask == 0] = np.nan
            act = np.nanargmax(q_values)
        assert act < self.num_actions
        reward, isOver, _ = self.player.step_manual(to_value(action_space[act]))

        # step for AI
        while not isOver and self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
            _, reward, _ = self.player.step_auto()
            isOver = (reward != 0)
        if ROLE_ID_TO_TRAIN == 2:
            reward = -reward
        self._current_game_score.feed(reward)

        if isOver:
            self._player_scores.feed(self._current_game_score.sum)
            # start a new game; retry when it ends before the trained role acts.
            # Separate names are used here so the terminal transition stored
            # below keeps its own reward and isOver=True.
            while True:
                self.player.reset()
                self.player.prepare()
                early_stop = False
                while self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
                    _, prestart_reward, _ = self.player.step_auto()
                    if prestart_reward != 0:
                        logger.warning('prestart ends too early! now resetting env')
                        early_stop = True
                        break
                if early_stop:
                    continue
                break
            self._current_game_score.reset()
        self._current_ob = self.get_state()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def debug(self, cnt=100000):
        """Fill the replay memory with ``cnt`` all-zero dummy transitions."""
        with get_tqdm(total=cnt) as pbar:
            for _ in range(cnt):
                self.mem.append(Experience(np.zeros(self.state_shape, dtype='float32'), 0, 0, False))
                pbar.update()

    def get_data(self):
        """Yield batches forever once the memory has been initialized."""
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            # ask the simulator thread for more transitions
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        """Stack sampled (state, action, reward, isOver) tuples into typed arrays."""
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        # let the automatic players act until it is the trained role's turn,
        # then take the observation for that role
        while self.player.get_role_ID() != ROLE_ID_TO_TRAIN:
            self.player.step_auto()
        self._current_ob = self.get_state()
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        """Log mean/max of the finished game scores, then reset them."""
        v = self._player_scores
        try:
            mean_score, max_score = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean_score)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
class MedicalPlayer(gym.Env):
    """3D medical-image landmark-localisation environment.

    Implements the classic "agent-environment loop": at every time-step the
    agent picks one of six axis-aligned moves and the environment returns the
    patch (of shape ``screen_dims``) cropped around the agent's new location,
    a reward, a terminal flag and an info dict.

    Outside of ``task == 'play'`` the reward is the decrease in
    spacing-weighted Euclidean distance to the target landmark, or -1 when a
    move would leave the image. With ``multiscale`` enabled the agent starts
    with stride 3 / step 9 and refines the scale each time it is detected to
    be oscillating.
    """

    def __init__(self, directory=None, viz=False, task=False, files_list=None,
                 screen_dims=(27, 27, 27), history_length=20, multiscale=True,
                 max_num_frames=0, saveGif=False, saveVideo=False,
                 data_type=None):
        """
        :param directory: optional path of a CSV file; when set, one row
            ``(filename, score, distError)`` is appended to it at every
            terminal step.
        :param viz: visualization
            - falsy to disable;
            - a positive number enables ``display()`` on every step
              (ints are converted to float);
            - a string must name an existing directory, but visualization
              is then disabled (``viz`` is reset to 0).
        :param task: ``'train'``, ``'play'`` or ``'browse'`` (the literals
            this class compares against).
        :param files_list: passed to the data loader selected by
            ``data_type``.
        :param screen_dims: shape (width, height, depth) of the patch cropped
            around the agent and fed to the network - defaults (27, 27, 27).
        :param history_length: number of recent locations / q-values kept
            for oscillation detection.
        :param multiscale: start with stride 3 and refine on oscillation.
        :param max_num_frames: an episode terminates once its step counter
            reaches this value (with the default 0, every episode ends after
            its first step).
        :param saveGif: accumulate displayed frames into a GIF.
        :param saveVideo: dump displayed frames and encode an mp4 with ffmpeg.
        :param data_type: one of 'BrainMRI', 'CardiacMRI', 'FetalUS', 'HITL'.
        :raises ValueError: if ``data_type`` is not supported.
        """
        super(MedicalPlayer, self).__init__()
        # inits stat counters
        self.reset_stat()
        # counter to limit number of steps per episode
        self.cnt = 0
        # maximum number of frames (steps) per episode
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # options to save the display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # task mode ('train', 'play' or 'browse')
        self.task = task
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self.dims = len(self.screen_dims)
        # multi-scale agent
        self.multiscale = multiscale
        # type of data, selects the data loader
        self.data_type = data_type
        # CSV file for logging evaluation results
        self.directory = directory
        # init env dimensions
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims

        with _ALE_LOCK:
            self.rng = get_rng(self)
            # visualization setup
            if isinstance(viz, six.string_types):  # check if viz is a string
                assert os.path.isdir(viz), viz
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.viewer = None
                self.gif_buffer = []

        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # six axis-aligned moves (see _next_location)
        self.action_space = spaces.Discrete(6)
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims,
                                            dtype=np.uint8)
        # history buffer for storing last locations to check oscillations
        self._history_length = history_length
        # initialize rectangle limits from input image coordinates
        self.rectangle = Rectangle(0, 0, 0, 0, 0, 0)
        # add your data loader here
        self.set_dataLoader(files_list)
        # prepare file sampler
        self.filepath = None
        self.HITL_logger = []
        # None (not []) so that the first _restart_episode() logs nothing
        self._loc_history = None
        # reset buffer, terminal, counters, and init new_random_game
        self._restart_episode()

    def set_dataLoader(self, files_list):
        """Select the data loader for ``self.data_type`` and start sampling.

        Landmarks are requested from the loader unless ``task == 'play'``.

        :raises ValueError: if ``self.data_type`` is not supported.
        """
        if self.data_type == 'BrainMRI':
            self.data_loader = filesListBrainMRLandmark
        elif self.data_type == 'CardiacMRI':
            self.data_loader = filesListCardioLandmark
        elif self.data_type == 'FetalUS':
            self.data_loader = filesListFetalUSLandmark
        elif self.data_type == "HITL":
            self.data_loader = fileHITL
        else:
            # previously this surfaced later as an obscure AttributeError
            raise ValueError(
                "Unsupported data_type {!r}; expected one of 'BrainMRI', "
                "'CardiacMRI', 'FetalUS', 'HITL'".format(self.data_type))

        if self.task == 'play':
            self.files = self.data_loader(files_list, returnLandmarks=False)
        else:
            self.files = self.data_loader(files_list, returnLandmarks=True)
        self.sampled_files = self.files.sample_circular()

    def HITL_episode_log(self):
        """ Method to save episode info for HITL """
        log = {
            'states': self._loc_history,
            'rewards': self._reward_history,
            'actions': self._act_history,
            'target': self._target_loc,
            'img_name': self.filename,
            # only the last recorded step is terminal
            'is_over': [False] * (len(self._loc_history) - 1) + [True],
            'resolution': self._res_history,
        }
        self.HITL_logger.append(log)

    def HITL_set_location(self, location, res):
        """ Method to set the location in the image to that specified in the logs """
        self._location = location
        self.xscale = res
        self.yscale = res
        self.zscale = res

    def reset(self):
        """Restart the episode and return the initial observation."""
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """Restart the current episode.

        In 'browse' mode the finished episode is first appended to
        ``HITL_logger`` when its location history is non-empty.
        """
        if self.task == 'browse' and self._loc_history:
            self.HITL_episode_log()
        self.terminal = False
        self.reward = 0
        self.cnt = 0  # counter to limit number of steps per episode
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset the stat counter
        self._loc_history = [(0, ) * self.dims] * self._history_length
        # list of q-value lists
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._clear_history()
        self.new_random_game()

    def new_random_game(self):
        """Start a new game on the next sampled image.

        Loads the next image, sets the initial scale / step size, draws a
        random start point away from the image border, and initialises the
        screen, q-values and distance to the target.
        """
        self.terminal = False
        self.viewer = None
        # sample a new image
        self._image, self._target_loc, self.filepath, self.spacing = next(
            self.sampled_files)
        self.filename = os.path.basename(self.filepath)

        # multiscale (e.g. start with 3 -> 2 -> 1)
        # scale can be thought of as sampling stride
        # NOTE: an earlier variant used action_step=6 / scale 2 for
        # 'CardiacMRI' (and 9 / 3 otherwise).
        if self.multiscale:
            self.action_step = 9
            self.xscale = 3
            self.yscale = 3
            self.zscale = 3
        else:
            self.action_step = 1
            self.xscale = 1
            self.yscale = 1
            self.zscale = 1
        # image volume size
        self._image_dims = self._image.dims

        # select random starting point; skip a margin of 1/5 (train) or
        # 1/4 (otherwise) of each dimension to avoid starting on the border
        divisor = 5 if self.task == 'train' else 4
        skip_thickness = (int(self._image_dims[0] / divisor),
                          int(self._image_dims[1] / divisor),
                          int(self._image_dims[2] / divisor))
        x = self.rng.randint(0 + skip_thickness[0],
                             self._image_dims[0] - skip_thickness[0])
        y = self.rng.randint(0 + skip_thickness[1],
                             self._image_dims[1] - skip_thickness[1])
        z = self.rng.randint(0 + skip_thickness[2],
                             self._image_dims[2] - skip_thickness[2])

        self._location = (x, y, z)
        self._start_location = (x, y, z)
        self._qvalues = [0, ] * self.actions
        self._screen = self._current_state()

        if self.task == 'play':
            self.cur_dist = 0
        else:
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)

    def calcDistance(self, points1, points2, spacing=(1, 1, 1)):
        """Euclidean distance between two points after scaling each axis by
        ``spacing`` (voxel size; presumably in mm - confirm with the loaders).
        """
        spacing = np.array(spacing)
        points1 = spacing * np.array(points1)
        points2 = spacing * np.array(points2)
        return np.linalg.norm(points1 - points2)

    def _next_location(self, current_loc, act):
        """Compute the location reached from ``current_loc`` by action ``act``.

        Actions: 0 = UP (Z+), 1 = FORWARD (Y+), 2 = RIGHT (X+),
        3 = LEFT (X-), 4 = BACKWARD (Y-), 5 = DOWN (Z-); each moves
        ``self.action_step`` voxels along one axis.

        :return: ``(next_location, go_out)``. When the move would leave the
            image (coordinate >= image size, or <= 0) the agent stays at
            ``current_loc`` and ``go_out`` is True.
        :raises ValueError: if ``act`` is not in 0..5.
        """
        step = self.action_step
        if act == 0:  # UP Z+
            next_location = (current_loc[0], current_loc[1],
                             round(current_loc[2] + step))
            go_out = next_location[2] >= self._image_dims[2]
        elif act == 1:  # FORWARD Y+
            next_location = (current_loc[0], round(current_loc[1] + step),
                             current_loc[2])
            go_out = next_location[1] >= self._image_dims[1]
        elif act == 2:  # RIGHT X+
            next_location = (round(current_loc[0] + step), current_loc[1],
                             current_loc[2])
            go_out = next_location[0] >= self._image_dims[0]
        elif act == 3:  # LEFT X-
            next_location = (round(current_loc[0] - step), current_loc[1],
                             current_loc[2])
            go_out = next_location[0] <= 0
        elif act == 4:  # BACKWARD Y-
            next_location = (current_loc[0], round(current_loc[1] - step),
                             current_loc[2])
            go_out = next_location[1] <= 0
        elif act == 5:  # DOWN Z-
            next_location = (current_loc[0], current_loc[1],
                             round(current_loc[2] - step))
            go_out = next_location[2] <= 0
        else:
            raise ValueError(
                "Invalid action {!r}; expected an integer in 0..5".format(act))
        if go_out:
            next_location = current_loc
        return next_location, bool(go_out)

    def step(self, act, qvalues, viewer=None):
        """Apply action ``act`` and advance the environment by one step.

        Args:
            act: action index in 0..5 (see ``_next_location``).
            qvalues: the agent's q-values; stored in the q-value history
                consulted by ``getBestLocation`` when oscillating.
            viewer: GUI object used by ``display()`` when ``viz`` is set.

        Returns:
            observation: the patch cropped around the (new) location.
            reward (float): decrease in distance to the target, or -1 for an
                attempted out-of-image move; not updated when task == 'play'.
            done (bool): whether the episode has terminated.
            info (dict): ``score``, ``gameOver``, ``distError``, ``filename``.

        Raises:
            ValueError: if ``act`` is not a valid action.
        """
        self._qvalues = qvalues
        current_loc = self._location
        self.terminal = False
        self.viewer = viewer

        next_location, go_out = self._next_location(current_loc, act)

        # punish -1 reward if the agent tries to go out
        if self.task != 'play':
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)
        # update location and screen
        self._location = next_location
        self._screen = self._current_state()

        # terminate if the distance is less than 1 during training
        # NOTE(review): cur_dist still holds the distance *before* this move
        # (it is refreshed below), so this check lags one step - confirm.
        if self.task == 'train' and self.cur_dist <= 1:
            self.terminal = True
            self.num_success.feed(1)

        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            self.terminal = True

        # refresh distance, then update history with new location/qvalues
        if self.task != 'play':
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)
        self._update_history()

        # if the agent oscillates, jump to the best recent location and
        # either refine the scale or terminate
        if self._oscillate:
            self._location = self.getBestLocation()
            self._screen = self._current_state()
            if self.task != 'play':
                self.cur_dist = self.calcDistance(self._location,
                                                  self._target_loc,
                                                  self.spacing)
            if self.multiscale and self.xscale > 1:
                # multi-scale steps: move to a finer scale
                self.adjustMultiScale()
            else:
                # finest scale reached (or single scale): terminate
                self.terminal = True
                if self.cur_dist <= 1:
                    self.num_success.feed(1)

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz and isinstance(self.viz, float):
                self.display()

        distance_error = self.cur_dist
        self.current_episode_score.feed(self.reward)

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'distError': distance_error,
            'filename': self.filename
        }

        # store results when batch evaluation
        if self.terminal and self.directory:
            # newline='' as required by the csv module
            with open(self.directory, 'a', newline='') as outcsv:
                csv.writer(outcsv).writerow(
                    [info['filename'], info['score'], info['distError']])

        return self._current_state(), self.reward, self.terminal, info

    def stepManual(self, act, viewer):
        """Version of ``step`` for browse mode, letting the user navigate
        through an uploaded image.

        ``act == -1`` (passed during init) skips the move and only renders.
        NOTE(review): for -1 no history entry is recorded - confirm that
        HITL log consumers expect this.

        :return: the patch cropped around the current location.
        :raises ValueError: if ``act`` is neither -1 nor in 0..5.
        """
        current_loc = self._location
        self.terminal = False
        self.viewer = viewer
        self._act = act

        if act != -1:
            next_location, go_out = self._next_location(current_loc, act)
            if go_out:
                self.reward = -1
            else:
                self.reward = self._calc_reward(current_loc, next_location)

            self._location = next_location
            self._screen = self._current_state()
            self.cur_dist = self.calcDistance(self._location,
                                              self._target_loc,
                                              self.spacing)
            self._update_history()

        # render screen if viz is on
        with _ALE_LOCK:
            if self.viz and isinstance(self.viz, float):
                self.display()

        return self._current_state()

    def getBestLocation(self):
        """Pick a location among the last four stored in history.

        For each of the last four steps the maximum stored q-value is taken,
        and the location whose maximum is the *smallest* is returned.
        NOTE(review): the name suggests argmax; argmin is kept as in the
        original (argmax had been deliberately commented out) - confirm.
        """
        last_qvalues_history = self._qvalues_history[-4:]
        last_loc_history = self._loc_history[-4:]
        best_qvalues = np.max(last_qvalues_history, axis=1)
        best_idx = best_qvalues.argmin()
        return last_loc_history[best_idx]

    def adjustMultiScale(self, higherRes=True):
        """Adjust the agent's scale/stride by one level (finer when
        ``higherRes``), dividing or multiplying the step size by 3, and
        clear the history."""
        if higherRes:
            self.xscale -= 1
            self.yscale -= 1
            self.zscale -= 1
            self.action_step = int(self.action_step / 3)
        else:
            self.xscale += 1
            self.yscale += 1
            self.zscale += 1
            self.action_step = int(self.action_step * 3)
        self._clear_history()

    def _clear_history(self):
        """Reset the history buffers.

        Browse mode uses growing lists (empty here); other modes use
        fixed-length buffers pre-filled with zero placeholders.
        """
        if self.task == 'browse':
            self._loc_history = []
            self._act_history = []
            self._reward_history = []
            self._res_history = []
        else:
            self._loc_history = [(0, ) * self.dims] * self._history_length
            self._qvalues_history = [(0, ) * self.actions
                                     ] * self._history_length

    def _update_history(self):
        """Record the current state in the history buffers.

        Browse mode appends location, action, scale and reward; other modes
        shift the fixed-length location / q-value buffers left by one and
        store the newest entry last.
        """
        if self.task == 'browse':
            self._loc_history.append(self._location)
            self._act_history.append(self._act)
            self._res_history.append(self.xscale)
            self._reward_history.append(self.reward)
        else:
            # update location history
            self._loc_history[:-1] = self._loc_history[1:]
            self._loc_history[-1] = self._location
            # update q-value history
            self._qvalues_history[:-1] = self._qvalues_history[1:]
            self._qvalues_history[-1] = self._qvalues

    def _current_state(self):
        """Crop the image around the current location (what the network sees).

        The crop is sampled with stride ``xscale`` / ``yscale`` / ``zscale``;
        any part falling outside the image stays zero. Also updates
        ``self.rectangle`` with the crop limits in image coordinates.

        :return: new state, an array of shape ``screen_dims``.
        """
        # initialize screen with zeros - all background
        screen = np.zeros((self.screen_dims)).astype(self._image.data.dtype)

        # screen uses coordinate system relative to origin (0, 0, 0)
        screen_xmin, screen_ymin, screen_zmin = 0, 0, 0
        screen_xmax, screen_ymax, screen_zmax = self.screen_dims

        # extract boundary locations using coordinate system relative to the
        # "global" image; width, height, depth in terms of screen coordinates
        if self.xscale % 2:
            xmin = self._location[0] - int(self.width * self.xscale / 2) - 1
            xmax = self._location[0] + int(self.width * self.xscale / 2)
            ymin = self._location[1] - int(self.height * self.yscale / 2) - 1
            ymax = self._location[1] + int(self.height * self.yscale / 2)
            zmin = self._location[2] - int(self.depth * self.zscale / 2) - 1
            zmax = self._location[2] + int(self.depth * self.zscale / 2)
        else:
            xmin = self._location[0] - round(self.width * self.xscale / 2)
            xmax = self._location[0] + round(self.width * self.xscale / 2)
            ymin = self._location[1] - round(self.height * self.yscale / 2)
            ymax = self._location[1] + round(self.height * self.yscale / 2)
            zmin = self._location[2] - round(self.depth * self.zscale / 2)
            zmax = self._location[2] + round(self.depth * self.zscale / 2)

        # check if they violate image boundary and fix it
        if xmin < 0:
            xmin = 0
            screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale))
        if ymin < 0:
            ymin = 0
            screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale))
        if zmin < 0:
            zmin = 0
            screen_zmin = screen_zmax - len(np.arange(zmin, zmax, self.zscale))
        if xmax > self._image_dims[0]:
            xmax = self._image_dims[0]
            screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale))
        if ymax > self._image_dims[1]:
            ymax = self._image_dims[1]
            screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale))
        if zmax > self._image_dims[2]:
            zmax = self._image_dims[2]
            screen_zmax = screen_zmin + len(np.arange(zmin, zmax, self.zscale))

        # crop image data to update what network sees;
        # image coordinate system becomes screen coordinates,
        # scale can be thought of as a stride
        screen[screen_xmin:screen_xmax, screen_ymin:screen_ymax,
               screen_zmin:screen_zmax] = self._image.data[
                   xmin:xmax:self.xscale,
                   ymin:ymax:self.yscale,
                   zmin:zmax:self.zscale]

        # update rectangle limits from input image coordinates
        # this is what the network sees
        self.rectangle = Rectangle(xmin, xmax, ymin, ymax, zmin, zmax)

        return screen

    def get_plane_z(self, z=0):
        """Return the image slice at index ``z`` along the third axis,
        rotated 90 degrees ccw for 'BrainMRI' and 'CardiacMRI'."""
        im = self._image.data[:, :, z]
        if self.data_type in ['BrainMRI', 'CardiacMRI']:
            im = np.rot90(im, 1)  # Rotate 90 degrees ccw
        return im

    def get_plane_x(self, x=0):
        """Return the image slice at index ``x`` along the first axis,
        rotated 90 degrees ccw."""
        im = self._image.data[x, :, :]
        im = np.rot90(im, 1)
        return im

    def get_plane_y(self, y=0):
        """Return the image slice at index ``y`` along the second axis,
        rotated 90 degrees ccw."""
        im = self._image.data[:, y, :]
        im = np.rot90(im, 1)
        return im

    def _calc_reward(self, current_loc, next_loc):
        """ Calculate the new reward based on the decrease in euclidean
        distance to the target location
        """
        curr_dist = self.calcDistance(current_loc, self._target_loc,
                                      self.spacing)
        next_dist = self.calcDistance(next_loc, self._target_loc,
                                      self.spacing)
        return curr_dist - next_dist

    @property
    def _oscillate(self):
        """Return True if the agent is stuck and oscillating.

        The agent oscillates when one location appears more than 3 times in
        the location history. When the all-zero placeholder that pre-fills
        the history is the most common entry, the runner-up is checked
        instead (if there is none, the agent is not oscillating).
        """
        freq = Counter(self._loc_history).most_common()
        if not freq:
            return False
        if freq[0][0] == (0, ) * self.dims:
            return len(freq) > 1 and freq[1][1] > 3
        return freq[0][1] > 3

    def get_action_meanings(self):
        """Return the human-readable name of each action, indexed by action."""
        ACTION_MEANING = {
            0: "UP",  # MOVE Z+
            1: "FORWARD",  # MOVE Y+
            2: "RIGHT",  # MOVE X+
            3: "LEFT",  # MOVE X-
            4: "BACKWARD",  # MOVE Y-
            5: "DOWN",  # MOVE Z-
        }
        # self.actions is the number of actions (an int), not a sequence
        return [ACTION_MEANING[i] for i in range(self.actions)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        """Send the current view to the PyQt viewer and optionally record it.

        Takes the z, x and y planes through the agent location, upscales them
        by 2, converts them to RGB and emits them (with agent / target
        location, distance error, scale and crop rectangle) through
        ``self.viewer.widget.agent_signal``. Outside of 'browse' mode it also
        honours the viewer thread's pause / terminate / speed controls.

        :param return_rgb_array: unused.
        """
        # get dimensions
        current_point = self._location
        target_point = self._target_loc
        # get image planes through the current location
        plane = self.get_plane_z(current_point[2])
        plane_x = self.get_plane_x(current_point[0])
        plane_y = self.get_plane_y(current_point[1])
        # rescale image
        # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4
        scale_x = 2
        scale_y = 2
        scale_z = 2
        current_point = (current_point[0] * scale_x,
                         current_point[1] * scale_y,
                         current_point[2] * scale_z)
        if target_point is not None:
            target_point = (target_point[0] * scale_x,
                            target_point[1] * scale_y,
                            target_point[2] * scale_z)
        # NOTE: replaces self.rectangle with a scaled plain tuple; the next
        # _current_state() call recomputes it in image coordinates.
        self.rectangle = (self.rectangle[0] * scale_x,
                          self.rectangle[1] * scale_x,
                          self.rectangle[2] * scale_y,
                          self.rectangle[3] * scale_y,
                          self.rectangle[4] * scale_z,
                          self.rectangle[5] * scale_z)

        img = cv2.resize(
            plane,
            (int(scale_x * plane.shape[1]), int(scale_y * plane.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_x = cv2.resize(
            plane_x,
            (int(scale_x * plane_x.shape[1]), int(scale_y * plane_x.shape[0])),
            interpolation=cv2.INTER_LINEAR)
        img_y = cv2.resize(
            plane_y,
            (int(scale_y * plane_y.shape[1]), int(scale_y * plane_y.shape[0])),
            interpolation=cv2.INTER_LINEAR)

        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # convert to rgb
        img_x = cv2.cvtColor(img_x, cv2.COLOR_GRAY2RGB)  # convert to rgb
        img_y = cv2.cvtColor(img_y, cv2.COLOR_GRAY2RGB)  # convert to rgb

        ########################################################################
        # PyQt GUI Code Section

        # Section of code to get initial value to be stored in a pickle object
        # (Uncomment if you wish to modify default_data.pickle)
        # viewer_param = {
        #     "arrs": (img, img_x, img_y),
        #     "filepath": self.filename
        # }
        # with open("default_data.pickle", "wb") as f:
        #     viewer_param = pickle.dump(viewer_param, f)
        # exit()

        # Sleep until resume (not in browse mode)
        if self.task != 'browse':
            while self.viewer.right_widget.automatic_mode.thread.pause:
                time.sleep(0.5)
                # Check whether thread should be killed (pause)
                if self.viewer.right_widget.automatic_mode.thread.terminate:
                    exit()

            # Check whether thread should be killed (general)
            if self.viewer.right_widget.automatic_mode.thread.terminate:
                exit()

        # Need to emit signal here (to draw images)
        self.viewer.widget.agent_signal.emit({
            "arrs": (img, img_x, img_y),
            "agent_loc": current_point,
            "target": target_point,
            "error": self.cur_dist,
            "scale": self.xscale,
            "rect": self.rectangle,
            "task": self.task,
            "is_terminal": self.terminal,
            "cnt": self.cnt
        })

        if self.task != 'browse':
            # Control agent speed
            speed = self.viewer.right_widget.automatic_mode.thread.speed
            if speed == WorkerThread.FAST:
                time.sleep(0)
            elif speed == WorkerThread.MEDIUM:
                time.sleep(0.5)
            else:
                time.sleep(1.5)

        ########################################################################

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager(
            ).get_color_buffer().get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr)
            self.gif_buffer.append(im)

            # NOTE(review): the GIF is (re)written on every *non*-terminal
            # frame; `if self.terminal` may have been intended - confirm.
            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.saveGif(gifname,
                                    arr=self.gif_buffer,
                                    duration=self.viz)

        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """
                        .format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError(
                            "Directory {} exists!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                resolution = str(3 * self.viewer.img_width) + 'x' + str(
                    3 * self.viewer.img_height)
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '6', '-i', dirname + '/%04d.png', '-s', resolution,
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.

    Each decision of the trained agent is split into two alternating phases,
    selected by ``self._comb_mask``:

    * combination phase (``_comb_mask`` True): choose one of up to
      ``num_actions[0]`` candidate combinations of card groups;
    * fine phase (``_comb_mask`` False): choose one of up to
      ``num_actions[1]`` card groups inside the chosen combination, optionally
      restricted by ``self._fine_mask``. Only this phase actually steps the game.
    """

    def __init__(self,
                 # model,
                 agent_name,
                 player,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency,
                 encoding_file='../AutoEncoder/encoding.npy'):
        """
        Args:
            agent_name (str): name of the trained agent; also the prefix of its
                tensors in the graph (see :meth:`_setup_graph`).
            player: the game environment. It must provide reset/prepare/step and
                the hand/last-cards query methods used below.
            state_shape (tuple): shape of one state, forwarded to ReplayMemory.
            num_actions (tuple of 2 ints): (max number of combinations,
                max number of card groups per combination).
            batch_size (int): number of transitions per training batch.
            memory_size (int): capacity of the replay memory.
            init_memory_size (int): transitions to collect before training starts.
            init_exploration (float): initial epsilon for epsilon-greedy.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
            encoding_file (str): ``.npy`` file with one embedding row per
                index of ``action_space``.
        """
        init_memory_size = int(init_memory_size)
        # self.model = model

        # Expose every constructor argument as an attribute of the same name.
        # Iterate over a snapshot so that the loop's own variables cannot
        # mutate the locals dict while it is being iterated.
        for k, v in list(locals().items()):
            if k != 'self':
                setattr(self, k, v)
        self.agent_name = agent_name
        self.exploration = init_exploration
        self.num_actions = num_actions
        self.encoding = np.load(encoding_file)
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.player.reset()
        self.player.prepare()
        # Every decision starts in the combination phase.
        self._comb_mask = True
        self._fine_mask = None
        self._current_ob, self._action_space = self.get_state_and_action_spaces()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def _build_fine_mask(self, combs, idx_must_be_contained):
        """Return a bool array of shape (len(combs), num_actions[1]) whose entry
        [i, j] is True iff ``combs[i][j]`` is in ``idx_must_be_contained``."""
        fine_mask = np.zeros([len(combs), self.num_actions[1]], dtype=bool)
        for i, comb in enumerate(combs):
            for j, idx in enumerate(comb):
                if idx in idx_must_be_contained:
                    fine_mask[i][j] = True
        return fine_mask

    def get_combinations(self, curr_cards_char, last_cards_char):
        """Enumerate ways of splitting the current hand into card groups.

        Each combination is a list of action indices. When ``last_cards_char``
        is non-empty, index 0 is prepended to every combination (presumably the
        "pass" action — confirm), and only combinations containing at least one
        group bigger than ``last_cards_char`` are kept.

        Side effect: sets ``self._fine_mask`` to the per-combination mask of
        those bigger groups, or to None when ``last_cards_char`` is empty.

        Args:
            curr_cards_char: cards currently in hand.
            last_cards_char: cards the chosen groups have to beat (may be empty).

        Returns:
            list of lists of int.
        """
        if len(curr_cards_char) > 10:
            # Large hands: search the augmented action space with
            # get_combinations_nosplit.
            card_mask = Card.char2onehot60(curr_cards_char).astype(np.uint8)
            mask = augment_action_space_onehot60
            # A row is invalid if it needs a card that is not in hand.
            a = np.expand_dims(1 - card_mask, 0) * mask
            invalid_row_idx = set(np.where(a > 0)[0])
            if len(last_cards_char) == 0:
                invalid_row_idx.add(0)

            valid_row_idx = [i for i in range(len(augment_action_space)) if i not in invalid_row_idx]

            mask = mask[valid_row_idx, :]
            # compact row index -> index in the augmented action space
            idx_mapping = dict(zip(range(mask.shape[0]), valid_row_idx))

            # augment mask
            # TODO: known issue: 555444666 will not decompose into 5554 and 66644
            combs = get_combinations_nosplit(mask, card_mask)
            combs = [([] if len(last_cards_char) == 0 else [0]) +
                     [clamp_action_idx(idx_mapping[idx]) for idx in comb] for comb in combs]

            if len(last_cards_char) > 0:
                last_group = CardGroup.to_cardgroup(last_cards_char)
                idx_must_be_contained = set(
                    [idx for idx in valid_row_idx
                     if CardGroup.to_cardgroup(augment_action_space[idx]).bigger_than(last_group)])
                combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)]
                self._fine_mask = self._build_fine_mask(combs, idx_must_be_contained)
            else:
                self._fine_mask = None
        else:
            # Small hands: recursive search over the plain action space,
            # reduced to per-rank card counts (15 ranks x 4 suits summed).
            mask = get_mask_onehot60(curr_cards_char, action_space, None).reshape(
                len(action_space), 15, 4).sum(-1).astype(np.uint8)
            valid = mask.sum(-1) > 0
            cards_target = Card.char2onehot60(curr_cards_char).reshape(-1, 4).sum(-1).astype(np.uint8)
            combs = get_combinations_recursive(mask[valid, :], cards_target)
            # compact row index -> index in action_space
            idx_mapping = dict(zip(range(valid.shape[0]), np.where(valid)[0]))

            combs = [([] if len(last_cards_char) == 0 else [0]) +
                     [idx_mapping[idx] for idx in comb] for comb in combs]

            if len(last_cards_char) > 0:
                valid[0] = True
                last_group = CardGroup.to_cardgroup(last_cards_char)
                idx_must_be_contained = set(
                    [idx for idx in range(len(action_space))
                     if valid[idx] and CardGroup.to_cardgroup(action_space[idx]).bigger_than(last_group)])
                combs = [comb for comb in combs if not idx_must_be_contained.isdisjoint(comb)]
                self._fine_mask = self._build_fine_mask(combs, idx_must_be_contained)
            else:
                self._fine_mask = None
        return combs

    def subsample_combs_masks(self, combs, masks, num_sample):
        """Randomly keep at most ``num_sample`` combinations, together with the
        matching rows of ``masks`` when masks is not None."""
        if masks is not None:
            assert len(combs) == masks.shape[0]
        idx = np.random.permutation(len(combs))[:num_sample]
        return [combs[i] for i in idx], (masks[idx] if masks is not None else None)

    def get_state_and_action_spaces(self, action=None):
        """Build the network input and the candidate actions for the current phase.

        Combination phase (``self._comb_mask`` True):
        1. Enumerate combinations, subsampling to ``num_actions[0]`` if needed.
        2. Embed each group of each combination.
        3. Append the embeddings of the last two plays and
           ``player.get_state_prob()``.
        4. Pad everything to ``num_actions``.

        Fine phase: repeat row ``action`` of the previous observation
        ``num_actions[0]`` times, and narrow the fine mask and action space to
        that combination.

        Args:
            action (int or None): combination chosen in the preceding
                combination phase; required in the fine phase.

        Returns:
            (state, available_actions). ``state`` has leading dims
            (num_actions[0], num_actions[1]).
        """
        def cards_char2embedding(cards_char):
            # Look up the action_space row equal to these cards and return its embedding.
            test = (action_space_onehot60 == Card.char2onehot60(cards_char))
            test = np.all(test, axis=1)
            target = np.where(test)[0]
            return self.encoding[target[0]]

        last_two_cards_char = self.player.get_last_two_cards()
        last_cards_char = last_two_cards_char[0]
        if not last_cards_char:
            last_cards_char = last_two_cards_char[1]
        curr_cards_char = self.player.get_curr_handcards()
        if self._comb_mask:
            combs = self.get_combinations(curr_cards_char, last_cards_char)
            if len(combs) > self.num_actions[0]:
                combs, self._fine_mask = self.subsample_combs_masks(combs, self._fine_mask, self.num_actions[0])
            # TODO: utilize temporal relations to speedup
            available_actions = [[action_space[idx] for idx in comb] for comb in combs]
            assert len(combs) > 0
            if self._fine_mask is not None:
                self._fine_mask = self.pad_fine_mask(self._fine_mask)
            self.pad_action_space(available_actions)
            state = [np.stack([self.encoding[idx] for idx in comb]) for comb in combs]
            assert len(state) > 0
            prob_state = self.player.get_state_prob()
            extra_state = np.concatenate([cards_char2embedding(last_two_cards_char[0]),
                                          cards_char2embedding(last_two_cards_char[1]),
                                          prob_state])
            # Append the shared context to every group row of every combination.
            state = [np.concatenate([s, np.tile(extra_state[None, :], [s.shape[0], 1])], axis=-1)
                     for s in state]
            state = self.pad_state(state)
            assert state.shape[0] == self.num_actions[0] and state.shape[1] == self.num_actions[1]
        else:
            assert action is not None
            if self._fine_mask is not None:
                self._fine_mask = self._fine_mask[action]
            available_actions = self._action_space[action]
            state = self._current_ob[action:action + 1, :, :]
            state = np.repeat(state, self.num_actions[0], axis=0)
            assert state.shape[0] == self.num_actions[0] and state.shape[1] == self.num_actions[1]
        return state, available_actions

    def pad_fine_mask(self, mask):
        """Pad ``mask`` along axis 0 to ``num_actions[0]`` rows by repeating its last row."""
        if mask.shape[0] < self.num_actions[0]:
            mask = np.concatenate([mask, np.repeat(mask[-1:], self.num_actions[0] - mask.shape[0], 0)], 0)
        return mask

    def pad_action_space(self, available_actions):
        """In place: pad each combination to ``num_actions[1]`` groups and the
        list to ``num_actions[0]`` combinations, by repeating last elements."""
        for i in range(len(available_actions)):
            available_actions[i] += [available_actions[i][-1]] * (self.num_actions[1] - len(available_actions[i]))
        if len(available_actions) < self.num_actions[0]:
            available_actions.extend([available_actions[-1]] * (self.num_actions[0] - len(available_actions)))

    # input is a list of N * HIDDEN_STATE
    def pad_state(self, state):
        """Stack a list of (n_i, D) arrays into (num_actions[0], num_actions[1], D),
        padding by duplicating last rows."""
        # since out net uses max operation, we just dup the last row and keep the result same
        newstates = []
        for s in state:
            assert s.shape[0] <= self.num_actions[1]
            s = np.concatenate([s, np.repeat(s[-1:, :], self.num_actions[1] - s.shape[0], axis=0)], axis=0)
            newstates.append(s)
        newstates = np.stack(newstates, axis=0)
        if len(state) < self.num_actions[0]:
            state = np.concatenate([newstates, np.repeat(newstates[-1:, :, :], self.num_actions[0] - newstates.shape[0], axis=0)], axis=0)
        else:
            state = newstates
        return state

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        comb_mask = self._comb_mask
        # The stored fine mask always has this fixed length so batches can be stacked.
        mask_len = max(self.num_actions[0], self.num_actions[1])
        if not self._comb_mask and self._fine_mask is not None:
            if self._fine_mask.shape[0] == mask_len:
                fine_mask = self._fine_mask
            else:
                fine_mask = np.pad(self._fine_mask, (0, mask_len - self._fine_mask.shape[0]),
                                   'constant', constant_values=(0, 0))
        else:
            fine_mask = np.ones([mask_len], dtype=bool)
        last_cards_char = self.player.get_last_outcards()
        if self.rng.rand() <= self.exploration:
            if not self._comb_mask and self._fine_mask is not None:
                # Random choice restricted to entries allowed by the fine mask.
                q_values = np.random.rand(self.num_actions[1])
                q_values[np.where(np.logical_not(self._fine_mask))[0]] = np.nan
                act = np.nanargmax(q_values)
            else:
                act = self.rng.choice(range(self.num_actions[0 if comb_mask else 1]))
        else:
            q_values = self.curr_predictor(old_s[None, :, :, :], np.array([comb_mask]), np.array([fine_mask]))[0][0]
            if not self._comb_mask and self._fine_mask is not None:
                q_values = q_values[:self.num_actions[1]]
                assert np.all(q_values[np.where(np.logical_not(self._fine_mask))[0]] < -100)
                q_values[np.where(np.logical_not(self._fine_mask))[0]] = np.nan
            act = np.nanargmax(q_values)
            assert act < self.num_actions[0 if comb_mask else 1]
        # clamp action to valid range
        act = min(act, self.num_actions[0 if comb_mask else 1] - 1)

        winner = -1
        reward = 0
        if comb_mask:
            # Choosing a combination does not advance the game.
            isOver = False
        else:
            if len(last_cards_char) > 0:
                if act > 0:
                    if not CardGroup.to_cardgroup(self._action_space[act]).bigger_than(CardGroup.to_cardgroup(last_cards_char)):
                        print('warning, some error happened, ', self._action_space[act], last_cards_char)
                        raise Exception("card comparison error")
            winner, isOver = self.player.step(self._action_space[act])

            # step for AI farmers
            while not isOver and self.player.get_curr_agent_name() != self.agent_name:
                handcards = self.player.get_curr_handcards()
                last_two_cards = self.player.get_last_two_cards()
                prob_state = self.player.get_state_prob()
                action = self.predictors[self.player.get_curr_agent_name()].predict(handcards, last_two_cards, prob_state)
                winner, isOver = self.player.step(action)

            if isOver:
                if self.agent_name == winner:
                    reward = 1
                else:
                    names = self.player.get_all_agent_names()
                    # Agent indices 1 and 2 (sum 3) count as the same side.
                    if names.index(winner) + names.index(self.agent_name) == 3:
                        reward = 1
                    else:
                        reward = -1
        self._current_game_score.feed(reward)

        if isOver:
            self._player_scores.feed(self._current_game_score.sum)
            self.player.reset()
            self.player.prepare()
            self._comb_mask = True
            self.prestart()
            self._current_game_score.reset()
        else:
            self._comb_mask = not self._comb_mask
            self._current_ob, self._action_space = self.get_state_and_action_spaces(act if not self._comb_mask else None)
        self.mem.append(Experience(old_s, act, reward, isOver, comb_mask, fine_mask))

    def prestart(self):
        """Let the other agents play (via ``self.predictors``) until it is
        ``self.agent_name``'s turn, then rebuild the current observation."""
        while self.player.get_curr_agent_name() != self.agent_name:
            handcards = self.player.get_curr_handcards()
            last_two_cards = self.player.get_last_two_cards()
            prob_state = self.player.get_state_prob()
            action = self.predictors[self.player.get_curr_agent_name()].predict(handcards, last_two_cards, prob_state)
            self.player.step(action)
        self._current_ob, self._action_space = self.get_state_and_action_spaces()

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            # ask the simulator thread for update_frequency new transitions
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        """Convert sampled experiences into stacked arrays:
        [state, action, reward, isOver, comb_mask, fine_mask]."""
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        comb_mask = np.asarray([e[4] for e in batch_exp], dtype='bool')
        fine_mask = np.asarray([e[5] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver, comb_mask, fine_mask]

    def _setup_graph(self):
        # NOTE(review): the comb-mask tensor is looked up as "<name>_comb_mask:0"
        # while the other inputs use "<name>/...:0" -- confirm this matches the graph.
        self.curr_predictor = self.trainer.get_predictor(
            [self.agent_name + '/state:0', self.agent_name + '_comb_mask:0', self.agent_name + '/fine_mask:0'],
            [self.agent_name + '/Qvalue:0'])
        self.predictors = {n: Predictor(self.trainer.get_predictor(
            [n + '/state:0', n + '_comb_mask:0', n + '/fine_mask:0'], [n + '/Qvalue:0']))
            for n in self.player.get_all_agent_names()}

    def _before_train(self):
        self.prestart()
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        """Log mean/max per-game score since the last trigger, then reset the counter."""
        v = self._player_scores
        try:
            mean, max_score = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception(self.agent_name + " Cannot log training scores.")
        v.reset()
class SoccerPlayer(RLEnvironment):
    """
    A wrapper for the pygame_soccer emulator.

    Observations are gray-scale screenshots of the field (partially observable
    around one agent when ``partial`` is set). The environment restarts itself
    automatically whenever ``env.state.is_terminal()`` is true after an action.
    """
    SOCCER_WIDTH = 288
    SOCCER_HEIGHT = 192

    def __init__(self, viz=0, height_range=(None, None), field='large',
                 partial=False, radius=2, frame_skip=4, image_shape=(84, 84),
                 nullop_start=30, mode=None, team_size=2, ai_frame_skip=1):
        """
        Args:
            viz: if truthy, render to a window (max 10 fps, key events enabled).
            height_range: stored as ``self.height_range``; not used by this class.
            field (str): only ``'large'`` is supported.
            partial (bool): if True, screenshots come from
                ``get_po_screenshot`` for agent 0 of the player team.
            radius (int): radius passed to ``get_po_screenshot`` in partial mode.
            frame_skip (int): emulator steps performed per :meth:`action`.
            image_shape (tuple): ``dsize`` for ``cv2.resize``, i.e. (width, height).
            nullop_start: stored as ``self.nullop_start``; not used by this class.
            mode: must be None; other modes are not implemented.
            team_size (int): number of agents per team.
            ai_frame_skip (int): forwarded to ``SoccerEnvironmentOptions``.
        """
        super(SoccerPlayer, self).__init__()
        self.mode = mode
        self.field = field
        self.partial = partial
        self.viz = viz
        assert mode is None, 'Not impl'
        assert field == 'large', 'No small 2vs2'
        if self.viz:
            self.renderer_options = soccer_renderer.RendererOptions(
                show_display=True, max_fps=10, enable_key_events=True)
        else:
            self.renderer_options = None
        map_path = file_util.resolve_path(__file__, '../data/map/soccer_large.tmx')
        self.team_size = team_size
        self.env_options = soccer_environment.SoccerEnvironmentOptions(
            team_size=self.team_size, map_path=map_path, ai_frame_skip=ai_frame_skip)
        self.env = soccer_environment.SoccerEnvironment(
            env_options=self.env_options, renderer_options=self.renderer_options)
        self.computer_team_name = self.env.team_names[1]
        self.player_team_name = self.env.team_names[0]
        # Partial observability: screenshots are taken for this agent only.
        if self.partial:
            self.radius = radius
            self.player_agent_index = self.env.get_agent_index(
                self.player_team_name, 0)
        self.width, self.height = self.SOCCER_WIDTH, self.SOCCER_HEIGHT
        self.actions = self.env.actions
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape
        self.last_info = {}
        # One action name per agent: player team first, then computer team.
        self.agent_actions = ['STAND'] * (self.team_size * 2)
        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _get_computer_actions(self):
        """Return the action index of every agent, as reported by
        ``env.state.get_agent_action``.

        The result has ``2 * team_size`` entries: the player team's agents
        first, then the computer team's. A falsy action is reported as 'STAND'.
        """
        # Collaborator
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.player_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 0 + i] = action
        # Opponent
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.computer_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 1 + i] = action
        return np.asarray([
            self.env.actions.index(act if act else 'STAND')
            for act in self.agent_actions
        ])

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        self.env.render()
        if self.partial:
            screenshot = self.env.renderer.get_po_screenshot(
                self.player_agent_index, self.radius)
        else:
            screenshot = self.env.renderer.get_screenshot()
        return screenshot

    def current_state(self):
        """
        :returns: a gray-scale uint8 image resized with ``cv2.resize`` to
            ``self.image_shape`` (given as (width, height)).
        """
        ret = self._grab_raw_image()
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        self.env.reset()
        self.last_raw_screen = self._grab_raw_image()

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            ret = self.env.take_action(self.env.actions[act])
            if k == 0:
                # Record every agent's action from the first skipped frame.
                self.last_info['agent_actions'] = self._get_computer_actions()
            r += ret.reward
            if self.env.state.is_terminal():
                break

        self.current_episode_score.feed(r)
        isOver = self.env.state.is_terminal()
        if isOver:
            self.finish_episode()
            self.restart_episode()
        return (r, isOver)

    def get_internal_state(self):
        """Return the info dict filled by :meth:`action` (key 'agent_actions')."""
        return self.last_info
class Brain_Env(gym.Env): """Class that provides 3D medical image environment. This is just an implementation of the classic "agent-environment loop". Each time-step, the agent chooses an action, and the environment returns an observation and a reward.""" def __init__( self, directory=None, viz=False, task=False, files_list=None, observation_dims=(27, 27, 27), multiscale=False, # FIXME automatic dimensions max_num_frames=20, saveGif=False, saveVideo=False): # FIXME hardcoded max num frames! """ :param train_directory: environment or game name :param viz: visualization set to 0 to disable set to +ve number to be the delay between frames to show set to a string to be the directory for storing frames :param observation_dims: shape of the frame cropped from the image to feed it to dqn (d,w,h) - defaults (27,27,27) :param nullop_start: start with random number of null ops :param location_history_length: consider lost of lives as end of episode (useful for training) :max_num_frames: maximum number of frames per episode. """ super(Brain_Env, self).__init__() print( "warning! 
max num frames hard coded to {}!".format(max_num_frames), flush=True) # inits stat counters self.reset_stat() # counter to limit number of steps per episodes self.cnt = 0 # maximum number of frames (steps) per episodes self.max_num_frames = max_num_frames # stores information: terminal, score, distError self.info = None # option to save display as gif self.saveGif = saveGif self.saveVideo = saveVideo # training flag self.task = task # image dimension (2D/3D) self.observation_dims = observation_dims self.dims = len(self.observation_dims) # multi-scale agent self.multiscale = multiscale # FIXME force multiscale false for now self.multiscale = False # init env dimensions if self.dims == 2: self.width, self.height = observation_dims elif self.dims == 3: self.width, self.height, self.depth = observation_dims else: raise ValueError with _ALE_LOCK: self.rng = get_rng(self) # TODO: understand this viz setup # visualization setup # if isinstance(viz, six.string_types): # check if viz is a string # assert os.path.isdir(viz), viz # viz = 0 # if isinstance(viz, int): # viz = float(viz) self.viz = viz # if self.viz and isinstance(self.viz, float): # self.viewer = None # self.gif_buffer = [] # stat counter to store current score or accumlated reward self.current_episode_score = StatCounter() # get action space and minimal action set self.action_space = spaces.Discrete(6) # change number actions here self.actions = self.action_space.n self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_dims, dtype=np.uint8) # history buffer for storing last locations to check oscilations self._history_length = max_num_frames # TODO initialize _observation_bounds limits from input image coordinates self._observation_bounds = ObservationBounds(0, 0, 0, 0, 0, 0) # add your data loader here # TODO: look into returnLandmarks # if self.task == 'play': # self.files = filesListBrainMRLandmark(directory, files_list, # returnLandmarks=False) # else: # self.files = 
filesListBrainMRLandmark(directory, files_list, # returnLandmarks=True) self.files = FilesListCubeNPY(directory, files_list) # self.files = filesListFetalUSLandmark(directory,files_list) # self.files = filesListCardioMRLandmark(directory,files_list) # prepare file sampler self.filepath = None self.file_sampler = self.files.sample_circular() # returns generator # reset buffer, terminal, counters, and init new_random_game # we put this here so that init_player in DQN.py doesn't try to update_history self._clear_history() # init arrays self._restart_episode() # self.viz = True # FIXME viz should default False assert (np.shape(self._state) == self.observation_dims) assert np.isclose(jaccard(self.original_state, self.original_state), 1) def reset(self): # with _ALE_LOCK: self._restart_episode() return self._observe() def _restart_episode(self): """ restart current episode """ self.terminal = False self.cnt = 0 # counter to limit number of steps per episodes self.num_games.feed(1) self.current_episode_score.reset() # reset the stat counter self.new_random_game() def new_random_game(self): """ load image, set dimensions, randomize start point, init _screen, qvals, calc distance to goal """ self.terminal = False self.viewer = None # # sample a new image self.filepath, self.filename = next(self.file_sampler) self._state = np.load(self.filepath).astype(float) self.original_state = np.copy(self._state) # multiscale (e.g. 
start with 3 -> 2 -> 1) # scale can be thought of as sampling stride if self.multiscale: raise NotImplementedError # ## brain # self.action_step = 9 # self.xscale = 3 # self.yscale = 3 # self.zscale = 3 ## cardiac # self.action_step = 6 # self.xscale = 2 # self.yscale = 2 # self.zscale = 2 else: self.action_step = 1 self.xscale = 1 self.yscale = 1 self.zscale = 1 # image volume size self._state_dims = np.shape(self._state) ####################################################################### ## select random starting point # add padding to avoid start right on the border of the image if (self.task == 'train'): skip_thickness = (int(self._state_dims[0] / 5), int(self._state_dims[1] / 5), int(self._state_dims[2] / 5)) else: # TODO: wtf why different skip thickness skip_thickness = (int(self._state_dims[0] / 4), int(self._state_dims[1] / 4), int(self._state_dims[2] / 4)) # FIXME randomly select one of the ground truth voxels as a starting point binary_grid = self.original_state.astype(bool) x_span, y_span, z_span = self.original_state.shape x, y, z = np.indices((x_span, y_span, z_span)) positions = np.c_[x[binary_grid == 1], y[binary_grid == 1], z[binary_grid == 1]] # pick a random row as starting position self._location = positions[np.random.choice(positions.shape[0], 1)].flatten() # print("starting location ", self._location) self._start_location = self._location # # randomly select the starting coords # x = self.rng.randint(0 + skip_thickness[0], # self._state_dims[0] - skip_thickness[0]) # y = self.rng.randint(0 + skip_thickness[1], # self._state_dims[1] - skip_thickness[1]) # z = self.rng.randint(0 + skip_thickness[2], # self._state_dims[2] - skip_thickness[2]) ####################################################################### # self._location = np.array([x, y, z]) # self._start_location = np.array([x, y, z]) self._qvalues = np.zeros(self.actions) self._observation = self._observe() self.curr_IOU = self.calc_IOU() print("first IOU ", self.curr_IOU) 
self.reward = self._calc_reward(False, False) self._update_history() # we've finished iteration 0. now, step begins with cnt = 1 self.cnt += 1 def calc_IOU(self): """ calculate the Intersection over Union AKA Jaccard Index between two images https://en.wikipedia.org/wiki/Jaccard_index """ # flatten bc jaccard_similarity_score expects 1D arrays state = self._state.ravel() state[state != -1] = 0 # mask out non-agent trajectory state = state.astype(bool) # everything non-zero => True if not state.any(): # no agent trajectory print(" no state trajectory found") iou = 0.0 else: iou = jaccard(state, self.original_state) # print("computed iou ", iou) # print("sum(agent) ", sum(state), "sum(original state)", sum(self.original_state), "computed iou ", iou) # print("agent \n", state.shape) # print("og \n", original_state.shape) # np.save("agent", state) # np.save("og", original_state) # assert isinstance(iou, ) return iou def step(self, act, qvalues): """The environment's step function returns exactly what we need. Args: act: Returns: observation (object): an environment-specific object representing your observation of the environment. For example, pixel data from a camera, joint angles and joint velocities of a robot, or the board state in a board game. reward (float): amount of reward achieved by the previous action. The scale varies between environments, but the goal is always to increase your total reward. done (boolean): whether it's time to reset the environment again. Most (but not all) tasks are divided up into well-defined episodes, and done being True indicates the episode has terminated. (For example, perhaps the pole tipped too far, or you lost your last life.) info (dict): diagnostic information useful for debugging. It can sometimes be useful for learning (for example, it might contain the raw probabilities behind the environment's last state change). However, official evaluations of your agent are not allowed to use this for learning. 
""" self._qvalues = qvalues current_loc = self._location self.terminal = False go_out = False backtrack = False # UP Z+ ----------------------------------------------------------- if (act == 0): proposed_location = current_loc + np.array([0, 0, 1 ]) * self.action_step # FORWARD Y+ --------------------------------------------------------- elif (act == 1): proposed_location = current_loc + np.array([0, 1, 0 ]) * self.action_step # RIGHT X+ ----------------------------------------------------------- elif (act == 2): proposed_location = current_loc + np.array([1, 0, 0 ]) * self.action_step # LEFT X- ----------------------------------------------------------- elif act == 3: proposed_location = current_loc + np.array([-1, 0, 0 ]) * self.action_step # BACKWARD Y- --------------------------------------------------------- elif act == 4: proposed_location = current_loc + np.array([0, -1, 0 ]) * self.action_step # DOWN Z- ----------------------------------------------------------- elif act == 5: proposed_location = current_loc + np.array([0, 0, -1 ]) * self.action_step else: raise ValueError # print("action ", act, "loc ", self._location, "proposed ", proposed_location, "diff ", proposed_location-self._location) if not self._is_in_bounds(proposed_location): # went out of bounds # do not update current_loc go_out = True else: # in bounds transposed = proposed_location.T # https://stackoverflow.com/a/25823710/4212158 if np.any( np.isclose(np.unique(self._agent_nodes, axis=0), transposed).all(axis=1)): # print("backtracking detected ", transposed, "hist ", np.unique(self._agent_nodes, axis=0), np.isclose(np.unique(self._agent_nodes, axis=0), transposed).all(axis=1)) # we backtracked backtrack = True else: # we are in bounds, AND we didn't back track. 
accept new location self._location = proposed_location # only update state, iou if we've changed location self._observation = self._observe() self.curr_IOU = self.calc_IOU() # punish -1 reward if the agent tries to go out #if (self.task != 'play'): # TODO: why is this necessary? self.reward = self._calc_reward( go_out, backtrack ) # TODO I think reward needs to be calculated after increment cnt # update screen, reward ,location, terminal self._update_history() # terminate if the distance is less than 1 during trainig if (self.task == 'train'): if self.curr_IOU >= 0.9: print("finishing episode, IOU = ", self.curr_IOU) self.terminal = True self.num_success.feed(1) self.display() # terminate if maximum number of steps is reached if self.cnt >= self.max_num_frames - 1: print("finishing episode, exceeded max_frames ", self.max_num_frames, " IOU = ", self.curr_IOU) self.terminal = True # self.display() # update history buffer with new location and qvalues if (self.task != 'play'): self.curr_IOU = self.calc_IOU() # check if agent oscillates # if self._oscillate: # TODO: rewind history, recalculate IOU # self._location = self.get_best_node() # TODO replace # self._observation = self._observe() # if (self.task != 'play'): # self.curr_IOU = self.calc_IOU() # multi-scale steps # if self.multiscale: # if self.xscale > 1: # self.xscale -= 1 # self.yscale -= 1 # self.zscale -= 1 # self.action_step = int(self.action_step / 3) # self._clear_history() # # terminate if scale is less than 1 # else: # self.terminal = True # if self.curr_IOU >= 0.9: self.num_success.feed(1) # else: # self.terminal = True # if self.curr_IOU >= 0.9: self.num_success.feed(1) # # render screen if viz is on FIXME this displays at each step # with _ALE_LOCK: # if self.viz: # if isinstance(self.viz, float): # self.display() self.current_episode_score.feed(self.reward) self.cnt += 1 info = { 'score': self.current_episode_score.sum, 'gameOver': self.terminal, 'IoU': self.curr_IOU, 'filename': self.filename } 
return self._observe(), self.reward, self.terminal, info def get_best_node(self): ''' get best location with best qvalue from last for locations stored in history TODO: make sure nodes dont have overlap ''' last_qvalues_history = self._qvalues_history[-4:] last_loc_history = self._agent_nodes[-4:] best_qvalues = np.max(last_qvalues_history, axis=1) # best_idx = best_qvalues.argmax() best_idx = best_qvalues.argmin() best_location = last_loc_history[best_idx] return best_location def _clear_history(self): """ clear history buffer with current state """ # TODO: double check these np arrays work in place of the lists self._agent_nodes = np.zeros( (self._history_length, self.dims)) # [(0,) * self.dims] * self._history_length self._IOU_history = np.zeros((self._history_length, )) # list of q-value lists self._qvalues_history = np.zeros( (self._history_length, self.actions)) # [(0,) * self.actions] * self._history_length self.reward_history = np.zeros((self._history_length, )) def _update_history(self): """ update history buffer with current state """ # update location history self._agent_nodes[self.cnt] = self._location # update jaccard index history self._IOU_history[self.cnt] = self.curr_IOU # and the reward self.reward_history[self.cnt] = self.reward # update q-value history self._qvalues_history[self.cnt] = self._qvalues def _observe(self): """ crop image data around current location to update what network sees. 
update _observation_bounds :return: new state """ # initialize screen with zeros - all background observation = np.zeros((self.observation_dims)) # screen uses coordinate system relative to origin (0, 0, 0) screen_xmin, screen_ymin, screen_zmin = 0, 0, 0 screen_xmax, screen_ymax, screen_zmax = self.observation_dims # extract boundary locations using coordinate system relative to "global" image # width, height, depth in terms of screen coord system if self.xscale % 2: xmin = self._location[0] - int(self.width * self.xscale / 2) - 1 xmax = self._location[0] + int(self.width * self.xscale / 2) ymin = self._location[1] - int(self.height * self.yscale / 2) - 1 ymax = self._location[1] + int(self.height * self.yscale / 2) zmin = self._location[2] - int(self.depth * self.zscale / 2) - 1 zmax = self._location[2] + int(self.depth * self.zscale / 2) else: xmin = self._location[0] - round(self.width * self.xscale / 2) xmax = self._location[0] + round(self.width * self.xscale / 2) ymin = self._location[1] - round(self.height * self.yscale / 2) ymax = self._location[1] + round(self.height * self.yscale / 2) zmin = self._location[2] - round(self.depth * self.zscale / 2) zmax = self._location[2] + round(self.depth * self.zscale / 2) # check if they violate image boundary and fix it if xmin < 0: xmin = 0 screen_xmin = screen_xmax - len(np.arange(xmin, xmax, self.xscale)) if ymin < 0: ymin = 0 screen_ymin = screen_ymax - len(np.arange(ymin, ymax, self.yscale)) if zmin < 0: zmin = 0 screen_zmin = screen_zmax - len(np.arange(zmin, zmax, self.zscale)) if xmax > self._state_dims[0]: xmax = self._state_dims[0] screen_xmax = screen_xmin + len(np.arange(xmin, xmax, self.xscale)) if ymax > self._state_dims[1]: ymax = self._state_dims[1] screen_ymax = screen_ymin + len(np.arange(ymin, ymax, self.yscale)) if zmax > self._state_dims[2]: zmax = self._state_dims[2] screen_zmax = screen_zmin + len(np.arange(zmin, zmax, self.zscale)) # take image, mask it w agent trajectory agent_trajectory = 
self.trajectory_to_branch() agent_trajectory *= -1 # agent frames are negative # paste agent trajectory ontop of original state, but only when vals are not 0 agent_mask = agent_trajectory.astype(bool) if agent_mask.any(): # agent trajectory not empty np.copyto(self._state, agent_trajectory, casting='no', where=agent_mask) assert self._state is not None # crop image data to update what network sees # image coordinate system becomes screen coordinates # scale can be thought of as a stride # TODO: check if we need to keep "stride" from upstream observation[screen_xmin:screen_xmax, screen_ymin:screen_ymax, screen_zmin:screen_zmax] = self._state[xmin:xmax, ymin:ymax, zmin:zmax] # update _observation_bounds limits from input image coordinates # this is what the network sees self._observation_bounds = ObservationBounds(xmin, xmax, ymin, ymax, zmin, zmax) return observation def trajectory_to_branch(self): """take location history, generate connected branches using Vaa3d plugin FIXME this function is horribly inefficient """ locations = self._agent_nodes # print("og state shape ", np.shape(self.original_state)) # print("self obs dims ", self.observation_dims) # if the agent hasn't drawn any nodes, then the branch is empty. skip pipeline, return empty arr. 
if not locations.any(): # if all zeros, evals to False return np.zeros_like(self.original_state) else: # TODO: make tmp files not collide when doing multiprocessing output_swc = save_branch_as_swc(locations, "agent_trajectory", output_dir="tmp", overwrite=True) # TODO: be explicit about bounds to swc_to_tiff output_tiff = swc_to_TIFF("agent_trajectory", output_swc, output_dir="tmp", overwrite=True) output_npy = TIFF_to_npy("agent_trajectory", output_tiff, output_dir="tmp", overwrite=True) output_npy = np.load(output_npy).astype(float) tiff_max = np.amax(np.fabs(output_npy)) if not np.isclose(tiff_max, 0): # normalize if tiff is not blank output_npy = output_npy / tiff_max return output_npy def crop_brain(self, xmin, xmax, ymin, ymax, zmin, zmax): return self.state[xmin:xmax, ymin:ymax, zmin:zmax] def _calc_reward(self, go_out, backtrack): """ Calculate the new reward based on the increase in IoU TODO: if current location is same as past location, always penalize (discourage retracing) """ if go_out or backtrack: reward = -1 else: # TODO, double check if indexes are correct if self.cnt == 0: previous_IOU = 0. else: previous_IOU = self._IOU_history[self.cnt - 1] IOU_difference = self.curr_IOU - previous_IOU print("curr IOU = ", self.curr_IOU, "prev IOU = ", self._IOU_history[self.cnt - 1], "diff = ", IOU_difference) assert isinstance(IOU_difference, float) if IOU_difference > 0: reward = 1 else: reward = -1 return reward def _is_in_bounds(self, coords): x, y, z = coords bounds = self._observation_bounds return ((bounds.xmin <= x <= bounds.xmax - 1 and bounds.ymin <= y <= bounds.ymax - 1 and bounds.zmin <= z <= bounds.zmax - 1)) @property def _oscillate(self): """ Return True if the agent is stuck and oscillating """ # TODO reimplement # TODO: erase last few frames if oscillation is detected counter = Counter(self._agent_nodes) freq = counter.most_common() # TODO: wtF? 
if freq[0][0] == (0, 0, 0): if (freq[1][1] > 3): return True else: return False elif (freq[0][1] > 3): return True def get_action_meanings(self): """ return array of integers for actions""" ACTION_MEANING = { 1: "UP", # MOVE Z+ 2: "FORWARD", # MOVE Y+ 3: "RIGHT", # MOVE X+ 4: "LEFT", # MOVE X- 5: "BACKWARD", # MOVE Y- 6: "DOWN", # MOVE Z- } return [ACTION_MEANING[i] for i in self.actions] @property def getScreenDims(self): """ return screen dimensions """ return (self.width, self.height, self.depth) def lives(self): return None def reset_stat(self): """ Reset all statistics counter""" self.stats = defaultdict(list) self.num_games = StatCounter() self.num_success = StatCounter() def display(self): """this is called at every step""" current_point = self._location # img = cv2.cvtColor(plane, cv2.COLOR_GRAY2RGB) # congvert to rgb # rescale image # INTER_NEAREST, INTER_LINEAR, INTER_AREA, INTER_CUBIC, INTER_LANCZOS4 # scale_x = 1 # scale_y = 1 # print("nodes ", self._agent_nodes) # print("ious", self._IOU_history) print("reward history ", np.unique(self.reward_history)) print("IOU history ", np.unique(self._IOU_history)) plotter = Viewer(self.original_state, zip(self._agent_nodes, self._IOU_history), filepath=self.filename) # # # # # from viewer import SimpleImageViewer # # self.viewer = SimpleImageViewer(self._state, # # scale_x=1, # # scale_y=1, # # filepath=self.filename) # self.gif_buffer = [] # # # # render and wait (viz) time between frames # self.viewer.render() # # time.sleep(self.viz) # # save gif if self.saveGif: # if self.saveGif: # TODO make this a method of viewer raise NotImplementedError # image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() # data = image_data.get_data('RGB', image_data.width * 3) # arr = np.array(bytearray(data)).astype('uint8') # arr = np.flip(np.reshape(arr, (image_data.height, image_data.width, -1)), 0) # im = Image.fromarray(arr) # self.gif_buffer.append(im) # # if not self.terminal: # gifname = 
self.filename.split('.')[0] + '.gif' # self.viewer.saveGif(gifname, arr=self.gif_buffer, # duration=self.viz) if self.saveVideo: dirname = 'tmp_video' # if self.cnt <= 1: # if os.path.isdir(dirname): # logger.warn("""Log directory {} exists! Use 'd' to delete it. """.format(dirname)) # act = input("select action: d (delete) / q (quit): ").lower().strip() # if act == 'd': # shutil.rmtree(dirname, ignore_errors=True) # else: # raise OSError("Directory {} exits!".format(dirname)) # os.mkdir(dirname) vid_fpath = self.filename + '.mp4' # vid_fpath = dirname + '/' + self.filename + '.mp4' plotter.save_vid(vid_fpath, self.max_num_frames - 1) # plotter.show_agent() if self.viz: # show progress # plotter.show() # actually, let's just save the files for later output_dir = os.path.abspath("saved_trajectories/") if not os.path.exists(output_dir): os.mkdir(output_dir) # outfile_fpath = os.path.join(output_dir, input_fname + ".npy") # # # don't overwrite # if not os.path.isfile(outfile_fpath) or overwrite: # desired_len = 16 # img_array = tiff2array.imread(input_fpath) # # make all arrays the same shape # # format: ((top, bottom), (left, right)) # shp = img_array.shape # # print(shp, flush=True) # if shp != (desired_len, desired_len, desired_len): # try: # img_array = np.pad(img_array, ( # (0, desired_len - shp[0]), (0, desired_len - shp[1]), (0, desired_len - shp[2])), # 'constant') # except ValueError: # raise # # print(shp, flush=True) # don't wait for all threads to finish before printing # np.savez(output_dir + self.filename, locations=self._agent_nodes, original_state=self.original_state, reward_history=self.reward_history)
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic
    DeepMind DQN settings.

    Info:
        score: the accumulated reward in the current game
        gameOver: True when the current game is Over
    """

    def __init__(self, rom_file, viz=0,
                 frame_skip=4, nullop_start=30,
                 live_lost_as_eoe=True, max_num_frames=0):
        """
        Args:
            rom_file: path to the rom
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with a random number (in [0, nullop_start)) of
                null ops. 0 disables the random null-op start.
            live_lost_as_eoe: consider lost of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                # a string means "record frames into this directory"
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.current_episode_score = StatCounter()
        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width))
        self._restart_episode()

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret.astype('float32')
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        """Reset the ALE game and play a random number of null-op frames."""
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start; RandomState.randint(0) raises ValueError,
        # so a non-positive nullop_start means "no null-ops"
        n = self.rng.randint(self.nullop_start) if self.nullop_start > 0 else 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def _reset(self):
        # only a real game over restarts; a lost life just continues
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def _step(self, act):
        """Repeat ``act`` for ``frame_skip`` frames (stopping early on game
        over or, if enabled, a lost life).

        :returns: (observation, summed reward, isOver, info)
        """
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        trueIsOver = isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {
            'score': self.current_episode_score.sum,
            'gameOver': trueIsOver
        }
        return self._current_state(), r, isOver, info
class AtariPlayer(RLEnvironment):
    """
    A wrapper for atari emulator.
    Will automatically restart when a real episode ends (isOver might be just
    lost of lives but not game over).
    """

    def __init__(self, rom_file, viz=0, height_range=(None, None),
                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number (in [0, nullop_start))
            of null ops. 0 disables the random null-op start.
        :param live_lost_as_eoe: consider lost of lives as end of episode. useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                # a string means "record frames into this directory"
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        """Reset the ALE game and play a random number of null-op frames."""
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start; RandomState.randint(0) raises ValueError,
        # so a non-positive nullop_start means "no null-ops"
        n = self.rng.randint(self.nullop_start) if self.nullop_start > 0 else 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives
        if isOver:
            self.finish_episode()
        # only a real game over restarts the emulator
        if self.ale.game_over():
            self.restart_episode()
        return (r, isOver)
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 predictor_io_names,
                 player,
                 state_shape,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency, history_len):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict Q value from state.
            player (gym.Env): the player; must provide ``action_space.n``,
                ``reset()`` and ``step(act)``.
            history_len (int): length of history frames to concat. Zero-filled
                initial frames.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        # store every constructor argument as an attribute of the same name
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.num_actions = player.action_space.n
        logger.info("Number of Legal actions: {}".format(self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape, history_len)
        self._current_ob = self.player.reset()
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            # each notification from get_data() adds update_frequency transitions
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()
        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    # quickly fill the memory for debug
    def _fake_init_memory(self):
        from copy import deepcopy
        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < 5:
                self._populate_exp()
                pbar.update()

            while len(self.mem) < self.init_memory_size:
                self.mem.append(deepcopy(self.mem._hist[0]))
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s = self._current_ob
        if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
            act = self.rng.choice(range(self.num_actions))
        else:
            # build a history state
            history = self.mem.recent_state()
            history.append(old_s)
            history = np.stack(history, axis=2)

            # assume batched network
            q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
            act = np.argmax(q_values)
        self._current_ob, reward, isOver, info = self.player.step(act)
        self._current_game_score.feed(reward)

        if isOver:
            if info['ale.lives'] == 0:  # only record score when a whole game is over (not when an episode is over)
                self._player_scores.feed(self._current_game_score.sum)
                self._current_game_score.reset()
            # the next transition must start from the observation that
            # reset() returns, not from the terminal frame of this one
            self._current_ob = self.player.reset()
        self.mem.append(Experience(old_s, act, reward, isOver))

    def _debug_sample(self, sample):
        import cv2

        def view_state(comb_state):
            state = comb_state[:, :, :-1]
            next_state = comb_state[:, :, 1:]
            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
            r = np.concatenate([r, r2], axis=0)
            cv2.imshow("state", r)
            cv2.waitKey()
        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
        if sample[1] or sample[3]:
            view_state(sample[0])

    def _process_batch(self, batch_exp):
        """Stack sampled (state, reward, action, isOver) tuples into arrays,
        returned in the order [state, action, reward, isOver]."""
        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver]

    # DataFlow method:
    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - self.history_len - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            # ask the simulator thread for more transitions
            self._populate_job_queue.put(1)

    # Callback methods:
    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            # avoid shadowing the builtin max()
            mean, max_score = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()
class AgentBase(GymEnv):
    """Episode-bookkeeping base for agents built on ``GymEnv``.

    Subclasses provide ``_reset()`` (returns the first observation) and
    ``_step(action)`` (returns ``(ob, act, reward, isOver, info)``).
    """

    def __init__(self, agentIdent, is_train=False, auto_restart=True, **kwargs):
        # super(AgentBase, self).__init__(name='torcs')
        self.auto_restart = auto_restart
        self._isTrain = is_train
        self._agentIdent = agentIdent
        self._kwargs = kwargs
        self._init()

    def _init(self):
        """Set up counters, RNG, stats and the optional MemorySaver, then
        start the first episode."""
        logger.info("[{}]: agent init, isTrain={}".format(self._agentIdent, self._isTrain))
        self._episodeCount = -1
        from tensorpack.utils.utils import get_rng
        self._rng = get_rng(self)
        from tensorpack.utils.stats import StatCounter
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self._memorySaver = None
        target_dir = self._kwargs.pop('save_dir', None)
        if target_dir is not None:
            # the two pops happen in this order, only when saving is enabled
            max_items = self._kwargs.pop('max_save_item', 3)
            min_score = self._kwargs.pop('min_save_score', None)
            self._memorySaver = MemorySaver(target_dir, max_items, min_score)
        self.restart_episode()

    def restart_episode(self):
        self.rwd_counter.reset()
        self.__ob = self.reset()

    def finish_episode(self):
        total = self.rwd_counter.sum
        self.stats['score'].append(total)
        msg = "episode finished, rewards = {:.3f}, episode = {}, steps = {}"
        logger.info(msg.format(total, self._episodeCount, self._episodeSteps))

    def current_state(self):
        return self.__ob

    def reset(self):
        """Advance the episode counter, reset the env and the per-episode
        tallies, and return the new observation."""
        self._episodeCount += 1
        first_ob = self._reset()
        self._episodeRewards = 0.
        self._episodeSteps = 0
        if self._memorySaver:
            self._memorySaver.createMemory(self._episodeCount)
        logger.info("restart, episode={}".format(self._episodeCount))
        return first_ob

    @abc.abstractmethod
    def _reset(self):
        pass

    def action(self, pred):
        """Step the env with ``pred``; returns ``(act, reward, isOver)``."""
        ob, act, reward, done, _info = self._step(pred)
        self.rwd_counter.feed(reward)
        saver = self._memorySaver
        if saver:
            saver.addCurrent(ob, act, reward, done)
        self.__ob = ob
        self._episodeSteps += 1
        self._episodeRewards += reward
        if not done:
            return act, reward, done
        self.finish_episode()
        if self.auto_restart:
            self.restart_episode()
        return act, reward, done

    @abc.abstractmethod
    def _step(self, action):
        raise NotImplementedError

    def get_action_space(self):
        raise NotImplementedError
class AgentBase(GymEnv):
    """Base class adding episode counting, reward tallies, logging and
    optional memory saving on top of ``GymEnv``.

    Subclasses implement ``_reset()`` and ``_step(action)``; the latter must
    return ``(ob, act, reward, isOver, info)``.
    """

    def __init__(self, agentIdent, is_train=False, auto_restart=True, **kwargs):
        # super(AgentBase, self).__init__(name='torcs')
        self._kwargs = kwargs
        self._agentIdent = agentIdent
        self._isTrain = is_train
        self.auto_restart = auto_restart
        self._init()

    def _init(self):
        """Initialise bookkeeping state and begin the first episode."""
        logger.info("[{}]: agent init, isTrain={}".format(
            self._agentIdent, self._isTrain))
        self._episodeCount = -1
        from tensorpack.utils.utils import get_rng
        self._rng = get_rng(self)
        from tensorpack.utils.stats import StatCounter
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self._memorySaver = None
        options = self._kwargs
        out_dir = options.pop('save_dir', None)
        if out_dir is not None:
            self._memorySaver = MemorySaver(
                out_dir,
                options.pop('max_save_item', 3),
                options.pop('min_save_score', None),
            )
        self.restart_episode()

    def restart_episode(self):
        self.rwd_counter.reset()
        self.__ob = self.reset()

    def finish_episode(self):
        episode_score = self.rwd_counter.sum
        self.stats['score'].append(episode_score)
        logger.info(
            "episode finished, rewards = {:.3f}, episode = {}, steps = {}".format(
                episode_score, self._episodeCount, self._episodeSteps))

    def current_state(self):
        return self.__ob

    def reset(self):
        """Start a new episode and return its first observation."""
        self._episodeCount += 1
        observation = self._reset()
        self._episodeRewards, self._episodeSteps = 0., 0
        saver = self._memorySaver
        if saver:
            saver.createMemory(self._episodeCount)
        logger.info("restart, episode={}".format(self._episodeCount))
        return observation

    @abc.abstractmethod
    def _reset(self):
        pass

    def action(self, pred):
        """Run one env step; returns ``(act, reward, isOver)`` and handles the
        end of an episode (stats, optional auto-restart)."""
        new_ob, chosen, gain, over, _ = self._step(pred)
        self.rwd_counter.feed(gain)
        if self._memorySaver:
            self._memorySaver.addCurrent(new_ob, chosen, gain, over)
        self.__ob = new_ob
        self._episodeSteps += 1
        self._episodeRewards += gain
        if over:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return chosen, gain, over

    @abc.abstractmethod
    def _step(self, action):
        raise NotImplementedError

    def get_action_space(self):
        raise NotImplementedError
class SoccerPlayer(RLEnvironment):
    """
    A wrapper for pygame_soccer emulator.
    Will automatically restart when a real episode ends (isOver might be just
    lost of lives but not game over).
    """

    SOCCER_WIDTH = 288
    SOCCER_HEIGHT = 192

    def __init__(self, viz=0, field=None, partial=False, radius=2,
                 frame_skip=4, image_shape=(84, 84), mode=None, team_size=1,
                 ai_frame_skip=1, raw_env=soccer_environment.SoccerEnvironment):
        """
        :param mode: None, a single mode name, or (for team_size > 1) a
            comma-separated list of modes (collaborators first, then opponents).
        """
        super(SoccerPlayer, self).__init__()
        if team_size > 1 and mode is not None:
            self.mode = mode.split(',')
        else:
            self.mode = [mode]
        self.field = field
        self.partial = partial
        self.viz = viz
        if self.viz:
            self.renderer_options = soccer_renderer.RendererOptions(
                show_display=True, max_fps=10, enable_key_events=True)
        else:
            self.renderer_options = None
        if self.field == 'large':
            map_path = file_util.resolve_path(__file__, '../data/map/soccer_large.tmx')
        else:
            map_path = None
        self.team_size = team_size
        self.env_options = soccer_environment.SoccerEnvironmentOptions(
            team_size=self.team_size, map_path=map_path, ai_frame_skip=ai_frame_skip)
        self.env = raw_env(env_options=self.env_options,
                           renderer_options=self.renderer_options)

        self.computer_team_name = self.env.team_names[1]
        self.player_team_name = self.env.team_names[0]

        # Partial
        if self.partial:
            self.radius = radius
            self.player_agent_index = self.env.get_agent_index(
                self.player_team_name, 0)

        self.actions = self.env.actions
        self.frame_skip = frame_skip
        self.image_shape = image_shape

        self.last_info = {}
        self.agent_actions = ['STAND'] * (self.team_size * 2)
        self.changing_counter = 0
        self.timestep = 0
        self.current_episode_score = StatCounter()

        self.restart_episode()

    def _grab_raw_image(self):
        """Render and return a screenshot (partial view if ``partial``)."""
        self.env.render()
        if self.partial:
            screenshot = self.env.renderer.get_po_screenshot(
                self.player_agent_index, self.radius)
        else:
            screenshot = self.env.renderer.get_screenshot()
        return screenshot

    def _get_computer_actions(self):
        """Return the indices of every agent's last action (player team
        first, then the computer team); missing actions count as 'STAND'."""
        # Collaborator
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.player_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 0 + i] = action
        # Opponent
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.computer_team_name, i)
            action = self.env.state.get_agent_action(index)
            self.agent_actions[self.team_size * 1 + i] = action
        return np.asarray([
            self.env.actions.index(act if act else 'STAND')
            for act in self.agent_actions
        ])

    def _set_opponent_mode(self, mode):
        for i in range(self.team_size):
            index = self.env.get_agent_index(self.computer_team_name, i)
            m = mode[i]
            self.env.state.set_agent_mode(index, m)

    def _set_collaborator_mode(self, mode):
        # index 0 of the player team is the learning agent, skip it
        for i in range(1, self.team_size):
            index = self.env.get_agent_index(self.player_team_name, i)
            m = mode[i - 1]
            self.env.state.set_agent_mode(index, m)

    def _set_computer_mode(self, mode):
        """Apply fixed OFFENSIVE/DEFENSIVE modes to collaborators and
        opponents. Other mode names (dynamic/random schemes) are handled in
        ``action`` and ignored here."""
        if mode[0] is None or len(mode) < self.team_size * 2 - 1:
            return
        # previously misspelled as 'OFFENVIE'/'DFFENSIVE', so fixed modes
        # were silently never applied
        if mode[0] in ['OFFENSIVE', 'DEFENSIVE']:
            # Collaborator
            if self.team_size >= 2:
                self._set_collaborator_mode(mode[:(self.team_size - 1)])
            # Opponent
            self._set_opponent_mode(mode[(self.team_size - 1):])

    def current_state(self):
        ret = self._grab_raw_image()
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        self.current_episode_score.reset()
        self.env.reset()
        self._set_computer_mode(self.mode)
        self.last_raw_screen = self._grab_raw_image()
        self.changing_counter = 0
        self.timestep = 0

    def action(self, act):
        """
        Repeat action ``act`` for up to ``frame_skip`` frames, with the
        other agents driven according to ``self.mode``.

        :returns: (reward, isOver)
        """
        ball_pos_agent_old = self.env.state.get_ball_possession()
        r = 0
        for k in range(self.frame_skip):
            self.timestep += 1
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            if self.mode[0] == 'WEAKCOOP':
                actions = {}
                for team_name in self.env.team_names:
                    for team_agent_index in range(self.env.options.team_size):
                        agent_index = self.env.get_agent_index(
                            team_name, team_agent_index)
                        agent_action = self.env._get_ai_action(
                            team_name, team_agent_index)
                        actions[agent_index] = agent_action
                player_index = self.env.get_agent_index(
                    self.player_team_name, 0)
                coop_index = self.env.get_agent_index(self.player_team_name, 1)
                actions[player_index] = self.env.actions[act]
                # the collaborator acts randomly half of the time
                if random.random() < 0.5:
                    actions[coop_index] = random.choice(self.env.actions)
                ret = self.env.take_action(actions)
            elif self.mode[0] == 'ALL_RANDOM':
                if self.team_size == 1:
                    player_index = self.env.get_agent_index(
                        self.player_team_name, 0)
                    opponent_index = self.env.get_agent_index(
                        self.computer_team_name, 0)
                    actions = {
                        player_index: self.env.actions[act],
                        opponent_index: random.choice(self.env.actions)
                    }
                else:
                    actions = {}
                    for team_name in [
                            self.player_team_name, self.computer_team_name
                    ]:
                        for team_index in range(self.team_size):
                            agent_index = self.env.get_agent_index(
                                team_name, team_index)
                            actions[agent_index] = random.choice(
                                self.env.actions)
                    player_index = self.env.get_agent_index(
                        self.player_team_name, 0)
                    actions[player_index] = self.env.actions[act]
                ret = self.env.take_action(actions)
            else:
                if self.mode[0] == 'OPPONENT_DYNAMIC':
                    choices = ['OFFENSIVE', 'DEFENSIVE']
                    if self.timestep % random.randint(4, 10) == 0:
                        new_modes = [
                            random.choice(choices)
                            for i in range(self.team_size)
                        ]
                        self._set_opponent_mode(new_modes)
                if self.mode[0] == 'COOP_DYNAMIC':
                    choices = ['OFFENSIVE', 'DEFENSIVE']
                    if self.timestep % random.randint(4, 10) == 0:
                        new_modes = [
                            random.choice(choices)
                            for i in range(self.team_size - 1)
                        ]
                        self._set_collaborator_mode(new_modes)
                actions = {}
                for team_name in self.env.team_names:
                    for team_agent_index in range(self.env.options.team_size):
                        agent_index = self.env.get_agent_index(
                            team_name, team_agent_index)
                        agent_action = self.env._get_ai_action(
                            team_name, team_agent_index)
                        actions[agent_index] = agent_action
                player_index = self.env.get_agent_index(
                    self.player_team_name, 0)
                actions[player_index] = self.env.actions[act]
                ret = self.env.take_action(actions)

            if k == 0:
                self.last_info['agent_actions'] = self._get_computer_actions()
            r += ret.reward
            if self.env.state.is_terminal():
                break

        self.current_episode_score.feed(r)
        isOver = self.env.state.is_terminal()
        ball_pos_agent_new = self.env.state.get_ball_possession()
        # count passes: possession stays with the PLAYER team but moves to
        # a different agent
        if ball_pos_agent_old['team_name'] == ball_pos_agent_new[
                'team_name'] and ball_pos_agent_new['team_name'] == 'PLAYER':
            if ball_pos_agent_old['team_agent_index'] != ball_pos_agent_new[
                    'team_agent_index']:
                self.changing_counter += 1
        if isOver:
            self.finish_episode()
            self.restart_episode()
        return (r, isOver)

    def get_internal_state(self):
        return self.last_info

    def get_changing_counter(self):
        return self.changing_counter
class ThorPlayer(RLEnvironment):
    """
    a wrapper for Thor environment.
    """

    def __init__(self, exe_path, json_path, actions=ACTIONS,
                 height=HEIGHT, width=WIDTH, gray=False, record=False):
        super(ThorPlayer, self).__init__()
        assert os.path.isfile(exe_path), 'wrong path of executable binary for Thor'
        assert os.path.isfile(json_path), 'wrong path of target json file'
        self.height = height
        self.width = width
        self.gray = gray
        self.record = record

        # set Thor controller
        self.env = robosims.controller.ChallengeController(
            unity_path=exe_path,
            height=self.height,
            width=self.width,
            record_actions=self.record)

        # read targets from the json file
        with open(json_path) as f:
            self.targets = json.load(f)
        self.num_targets = len(self.targets)

        self.rng = get_rng(self)
        self.actions = actions
        self.current_episode_score = StatCounter()

        self.env.start()
        self.restart_episode()

    def current_state(self):
        """Return ``(img, success, found)``: the current frame (RGB, or
        gray-scale if ``gray``), whether the last action succeeded, and
        whether the target has been found."""
        img = self.env.last_event.frame
        success = self.env.last_event.metadata['lastActionSuccess']
        found = self.env.target_found()
        if self.gray:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return img, success, found

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def next_target(self):
        idx = self.rng.choice(range(self.num_targets))
        return self.targets[idx]

    def finish_episode(self):
        # record the episode score, like the other RLEnvironment players
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        """
        reset the episode counter and initialize the env by a random selected target
        """
        self.current_episode_score.reset()
        target = self.next_target()
        self.env.initialize_target(target)

    def action(self, act):
        """
        Perform an action. Will automatically start a new episode if isOver.

        Reward: -0.01 for a failed action, +100 when the target is found
        (which also ends the episode).

        :returns: (reward, isOver)
        """
        r = 0.0
        isOver = False
        event = self.env.step(action=dict(action=self.actions[act]))
        if not event.metadata['lastActionSuccess']:
            r -= 0.01
        if self.env.target_found():
            r += 100.0
            isOver = True
        self.current_episode_score.feed(r)
        if isOver:
            self.finish_episode()
            self.restart_episode()
        return (r, isOver)
class ExpReplay(DataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    Transitions are not simulated here. They arrive as serialized messages on
    a ZMQ PULL socket bound to ``pipe_sim2exp``. A receive thread decodes
    them into ``self.queue``, and a simulator thread moves them from the
    queue into the replay memory.

    This implementation assumes that state is batch-able, and the network takes batched inputs.
    """

    def __init__(self,
                 # model,
                 agent_name,
                 state_shape,
                 num_actions,
                 batch_size,
                 memory_size, init_memory_size,
                 init_exploration,
                 update_frequency,
                 pipe_exp2sim, pipe_sim2exp):
        """
        Args:
            agent_name (str): name used to tag log messages.
            state_shape: shape of one state, forwarded to ``ReplayMemory``.
            num_actions: logged with two placeholders, so it must unpack into
                exactly two values.
            batch_size (int): number of transitions per yielded batch.
            memory_size (int): capacity forwarded to ``ReplayMemory``.
            init_memory_size: number of transitions to collect before
                training starts; converted to ``int``.
            init_exploration: initial epsilon, stored as ``self.exploration``.
            update_frequency (int): number of new transitions to add to
                memory after sampling a batch of transitions for training.
            pipe_exp2sim: address for a reply channel to the simulator;
                stored but not used (no reply is sent for now).
            pipe_sim2exp: ZMQ address the PULL socket binds to.
        """
        logger.info('starting expreplay {}'.format(agent_name))
        # Convert once and store the converted value. It is used as a count
        # (tqdm total and memory-size comparison).
        init_memory_size = int(init_memory_size)

        self.context = zmq.Context()
        # no reply socket to the simulator for now
        self.sim2exp_socket = self.context.socket(zmq.PULL)
        self.sim2exp_socket.set_hwm(2)
        self.sim2exp_socket.bind(pipe_sim2exp)
        self.queue = queue.Queue(maxsize=1000)

        # keep every constructor argument as an attribute
        self.agent_name = agent_name
        self.state_shape = state_shape
        self.num_actions = num_actions
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.init_memory_size = init_memory_size
        self.init_exploration = init_exploration
        self.update_frequency = update_frequency
        self.pipe_exp2sim = pipe_exp2sim
        self.pipe_sim2exp = pipe_sim2exp

        self.exploration = init_exploration
        logger.info("Number of Legal actions: {}, {}".format(*self.num_actions))

        self.rng = get_rng(self)
        self._init_memory_flag = threading.Event()  # tell if memory has been initialized

        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()

    def get_recv_thread(self):
        """Build (but do not start) a loop thread.

        Each iteration receives one message from the simulator socket,
        deserializes it and enqueues it without blocking. If the queue is
        full, the message is dropped and a log line is written.
        """
        def f():
            msg = self.sim2exp_socket.recv(copy=False).bytes
            msg = loads(msg)
            print('{}: received msg'.format(self.agent_name))
            try:
                self.queue.put_nowait(msg)
            except queue.Full:
                logger.info('put queue failed!')
            # send response or not?

        recv_thread = LoopThread(f, pausable=False)
        # recv_thread.daemon = True
        recv_thread.name = "recv thread"
        return recv_thread

    def get_simulator_thread(self):
        """Build (but do not start) the thread that fills replay memory.

        Each time a job token arrives on ``_populate_job_queue`` (one per
        batch yielded by :meth:`get_data`), it moves ``update_frequency`` new
        transitions into memory. If no transition is available, it polls
        every 0.1 s.
        """
        def populate_job_func():
            self._populate_job_queue.get()
            added = 0
            # synchronous training
            while added < self.update_frequency:
                if self._populate_exp():
                    added += 1
                time.sleep(0.1)

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        """Block until memory holds ``init_memory_size`` transitions, then
        signal :meth:`get_data` that sampling may start."""
        logger.info("{} populating replay memory with epsilon={} ...".format(self.agent_name, self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                if self._populate_exp():
                    pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """Move one received transition into replay memory.

        Returns:
            bool: True if a transition was added, False if the receive queue
            was empty.
        """
        try:
            # Do not wait for an update. This may cause some agents' old
            # replay buffers to be trained more times before a new buffer
            # comes in.
            state, action, reward, isOver, comb_mask, fine_mask = self.queue.get_nowait()
        except queue.Empty:
            return False

        self._current_game_score.feed(reward)
        if isOver:
            self._player_scores.feed(self._current_game_score.sum)
            self._current_game_score.reset()
        self.mem.append(Experience(np.stack(state), action, reward, isOver, comb_mask, np.stack(fine_mask)))
        return True

    def get_data(self):
        """Yield batches of uniformly sampled transitions forever.

        Waits for the initial fill first. After every batch, it posts one job
        token so the simulator thread adds ``update_frequency`` new
        transitions.
        """
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            # NOTE(review): the lower bound skips the oldest entries,
            # presumably because the simulator thread may overwrite them
            # while sampling. Confirm against ReplayMemory.sample.
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]

            yield self._process_batch(batch_exp)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        """Stack sampled transitions into per-field numpy arrays:
        [state, action, reward, isOver, comb_mask, fine_mask]."""
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        comb_mask = np.asarray([e[4] for e in batch_exp], dtype='bool')
        fine_mask = np.asarray([e[5] for e in batch_exp], dtype='bool')
        return [state, action, reward, isOver, comb_mask, fine_mask]

    def _setup_graph(self):
        # start receiving simulator messages as soon as the graph is set up
        self._recv_th = self.get_recv_thread()
        self._recv_th.start()

    def _before_train(self):
        logger.info('{}-receive thread started'.format(self.agent_name))
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()
        self._init_memory()

    def _trigger(self):
        """Log mean/max episode scores since the last trigger, then reset them."""
        from simulator.tools import mean_score_logger
        v = self._player_scores
        try:
            mean_score, max_score = v.average, v.max
            logger.info('{} mean_score: {}'.format(self.agent_name, mean_score))
            mean_score_logger('{} mean_score: {}\n'.format(self.agent_name, mean_score))
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean_score)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception(self.agent_name + " Cannot log training scores.")
        v.reset()
class MedicalPlayer(gym.Env):
    """Class that provides 3D medical image environment.
    This is just an implementation of the classic "agent-environment loop".
    Each time-step, the agent chooses an action, and the environment returns
    an observation and a reward.

    The agent moves a plane sampled from a 3D image by changing its four
    parameters (three angle components and a distance). The goal is to bring
    it close to a ground-truth plane.
    """

    # action id -> (index into the plane params, sign of the update).
    # Params 0-2 move by ``action_angle_step``; param 3 moves by
    # ``action_dist_step``.
    _ACTION_TO_PARAM = {
        0: (0, 1),   # theta x+ (param a)
        1: (1, 1),   # theta y+ (param b)
        2: (2, 1),   # theta z+ (param c)
        3: (3, 1),   # dist d+
        4: (0, -1),  # theta x- (param a)
        5: (1, -1),  # theta y- (param b)
        6: (2, -1),  # theta z- (param c)
        7: (3, -1),  # dist d-
    }

    def __init__(self, directory=None, files_list=None, viz=False,
                 train=False, screen_dims=(27, 27, 27), spacing=(1, 1, 1),
                 nullop_start=30, history_length=30, max_num_frames=0,
                 saveGif=False, saveVideo=False):
        """
        :param directory: data location, forwarded to ``filesListCardioMRPlane``
        :param files_list: file list, forwarded to ``filesListCardioMRPlane``
        :param viz: visualization
            set to 0 to disable
            set to +ve number to be the delay between frames to show
            a string must name an existing directory; it disables on-screen
            display (frames are not stored by this class)
        :param train: training mode. Start positions are random; leaving the
            volume ends the episode; the results csv header is not written.
        :param screen_dims: shape of the frame cropped from the image to feed
            it to dqn (d,w,h) - defaults (27,27,27)
        :param spacing: initial plane sampling spacing
        :param nullop_start: accepted for interface compatibility; unused here
        :param history_length: length of the history buffers used to detect
            oscillation
        :param max_num_frames: maximum number of steps per episode. The
            episode ends once the step count reaches it, so the default 0
            ends every episode after one step.
        :param saveGif: save displayed frames as a gif
        :param saveVideo: save displayed frames as an mp4 (via ffmpeg)
        """
        super(MedicalPlayer, self).__init__()
        self.reset_stat()
        self._supervised = False
        self._init_action_angle_step = 8
        self._init_action_dist_step = 4

        # save results in csv file (header only when evaluating)
        self.csvfile = 'dummy.csv'
        if not train:
            with open(self.csvfile, 'w', newline='') as outcsv:
                writer = csv.writer(outcsv)
                writer.writerow(["filename", "dist_error", "angle_error"])

        # read files from directory - add your data loader here
        self.files = filesListCardioMRPlane(directory, files_list)
        # prepare file sampler
        self.sampled_files = self.files.sample_circular()
        self.filepath = None
        # maximum number of frames (steps) per episodes
        self.max_num_frames = max_num_frames
        # stores information: terminal, score, distError
        self.info = None
        # option to save display as gif / video
        self.saveGif = saveGif
        self.saveVideo = saveVideo
        # training flag
        self.train = train
        # image dimension (2D/3D)
        self.screen_dims = screen_dims
        self._plane_size = screen_dims
        self.dims = len(self.screen_dims)
        if self.dims == 2:
            self.width, self.height = screen_dims
        else:
            self.width, self.height, self.depth = screen_dims
        # plane sampling spacings
        self.init_spacing = np.array(spacing)
        # stat counter to store current score or accumulated reward
        self.current_episode_score = StatCounter()
        # get action space and minimal action set
        self.action_space = spaces.Discrete(8)  # change number actions here
        self.actions = self.action_space.n
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=self.screen_dims)
        # history buffers for storing last locations to check oscillations
        self._history_length = history_length
        self._plane_history = deque(maxlen=self._history_length)
        self._bestq_history = deque(maxlen=self._history_length)
        self._dist_history = deque(maxlen=self._history_length)
        self._dist_history_params = deque(maxlen=self._history_length)
        self._dist_supervised_history = deque(maxlen=self._history_length)
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length
        self._qvalues = [0, ] * self.actions

        with _ALE_LOCK:
            self.rng = get_rng(self)

        # visualization setup
        if isinstance(viz, six.string_types):  # check if viz is a string
            assert os.path.isdir(viz), viz
            viz = 0
        if isinstance(viz, int):
            viz = float(viz)
        self.viz = viz
        if self.viz and isinstance(self.viz, float):
            self.viewer = None
            self.gif_buffer = []

        self._restart_episode()

    # -------------------------------------------------------------------------

    def _reset(self):
        """Start a new episode and return its first observation."""
        # with _ALE_LOCK:
        self._restart_episode()
        return self._current_state()

    def _restart_episode(self):
        """Restart the current episode: reset counters and histories, then
        sample a new image."""
        self.terminal = False
        self.cnt = 0  # counter to limit number of steps per episodes
        self.num_games.feed(1)
        self.current_episode_score.reset()  # reset score stat counter
        self._clear_history()
        self.new_random_game()

    # -------------------------------------------------------------------------

    def new_random_game(self):
        """Sample the next image and set up the ground-truth, initial and
        current planes, plus the initial distance metrics."""
        self.terminal = False
        self.viewer = None
        # sample a new image
        (self.sitk_image, self.sitk_image_2ch, self.sitk_image_4ch,
         self.landmarks, self.filepath) = next(self.sampled_files)
        self.filename = os.path.basename(self.filepath)
        # image volume size
        self._image_dims = self.sitk_image.GetSize()
        self.action_angle_step = copy.deepcopy(self._init_action_angle_step)
        self.action_dist_step = copy.deepcopy(self._init_action_dist_step)
        self.spacing = self.init_spacing.copy()
        # find center point of the initial plane
        if self.train:
            # sample randomly +-10% around the center point
            # (skip the outer 40% on each side of every axis)
            x, y, z = [
                self.rng.randint(int(d / 2.5), d - int(d / 2.5))
                for d in self._image_dims[:3]
            ]
        else:
            # during testing start sample a plane around the center point
            x, y, z = (self._image_dims[0] / 2,
                       self._image_dims[1] / 2,
                       self._image_dims[2] / 2)
        self._origin3d_point = (int(x), int(y), int(z))
        # Get ground truth plane
        self._groundTruth_plane = Plane(*getGroundTruthPlane(
            self.sitk_image, self.sitk_image_4ch, self._origin3d_point,
            self._plane_size, spacing=self.spacing))
        # get an isotropic 1mm groundtruth plane over the whole image
        image_size = self._image_dims
        self.groundTruth_plane_iso = Plane(*getGroundTruthPlane(
            self.sitk_image, self.sitk_image_4ch, self._origin3d_point,
            image_size, [1, 1, 1]))
        # project the landmarks onto the ground-truth plane
        self.landmarks_gt, _ = zip(*[
            projectPointOnPlane(point, self._groundTruth_plane.norm,
                                self._groundTruth_plane.origin)
            for point in self.landmarks
        ])
        # Get initial plane and set current plane the same
        self._plane = self._init_plane = Plane(*getInitialPlane(
            self.sitk_image, self._plane_size, self._origin3d_point,
            self.spacing))
        # current distance between initial and ground truth planes
        self.cur_dist = self._landmark_distance_error()
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params, self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)
        self._screen = self._current_state()

    # -------------------------------------------------------------------------

    def _apply_action(self, params, act):
        """Return a copy of plane ``params`` moved by the discrete action ``act``.

        See ``_ACTION_TO_PARAM``. An ``act`` outside 0-7 leaves the params
        unchanged.
        """
        new_params = np.copy(params)
        mapping = self._ACTION_TO_PARAM.get(act)
        if mapping is not None:
            idx, sign = mapping
            step = self.action_dist_step if idx == 3 else self.action_angle_step
            new_params[idx] += sign * step
        return new_params

    def _best_supervised_plane(self):
        """Try every action from the current plane and return the resulting
        plane with the lowest ``calcDistTwoParams`` distance to the ground
        truth. That minimum distance is also appended to
        ``_dist_supervised_history``."""
        candidates = [self._apply_action(self._plane.params, a)
                      for a in range(self.actions)]
        dists = [
            calcDistTwoParams(self._groundTruth_plane.params, params,
                              scale_angle=self.action_angle_step,
                              scale_dist=self.action_dist_step)
            for params in candidates
        ]
        best_idx = np.argmin(dists)
        best_plane = Plane(*getPlane(self.sitk_image, self._origin3d_point,
                                     candidates[best_idx], self._plane_size,
                                     spacing=self.spacing))
        self._dist_supervised_history.append(np.min(dists))
        return best_plane

    def _is_out_of_bounds(self, plane):
        """Return a truthy value if ``plane`` should be rejected: mostly
        background, origin outside the volume, or parameters too far from the
        ground truth."""
        # check if the screen is full of zeros (background)
        if checkBackgroundRatio(plane, min_pixel_val=0.5, ratio=0.8):
            return True
        # check if the plane origin falls outside the volume
        if checkOriginLocation(self.sitk_image, plane.origin):
            return True
        # check if plane parameters got very high
        return checkParamsBound(plane.params, self._groundTruth_plane.params)

    def _landmark_distance_error(self):
        """Mean absolute distance between the ground-truth landmark
        projections (``landmarks_gt``) and the current plane, as returned by
        ``projectPointOnPlane``."""
        _, dist = zip(*[
            projectPointOnPlane(point, self._plane.norm, self._plane.origin)
            for point in self.landmarks_gt
        ])
        return np.mean(np.abs(dist))

    def step(self, act, qvalues):
        """Take one action and advance the environment.

        Args:
            act: discrete action id in [0, 8); see ``get_action_meanings``.
            qvalues: q-values of the current state. Stored in the history
                used by ``getBestPlane`` when oscillation is detected.

        Returns:
            observation (object): the current plane grid.
            reward (float): sign of the parameter-distance improvement, or
                -1 if the move would leave the volume.
            done (boolean): whether the episode has terminated.
            info (dict): 'score', 'gameOver', 'distError', 'distAngle',
                'filename'.
        """
        self.terminal = False
        self._qvalues = qvalues
        # get current plane params and the params after taking the action
        current_plane_params = np.copy(self._plane.params)
        next_plane_params = self._apply_action(current_plane_params, act)
        self.reward = self._calc_reward_params(current_plane_params,
                                               next_plane_params)
        # threshold reward between -1 and 1
        self.reward = np.sign(self.reward)

        if self._supervised and self.train:
            # supervised: move to the best neighbouring plane instead
            next_plane = self._best_supervised_plane()
        else:
            # unsupervised or testing: plane from the chosen action
            next_plane = Plane(*getPlane(self.sitk_image,
                                         self._origin3d_point,
                                         next_plane_params,
                                         self._plane_size,
                                         spacing=self.spacing))

        # punish lowest reward if the agent tries to go out and keep same plane
        if self._is_out_of_bounds(next_plane):
            self.reward = -1  # lowest possible reward
            next_plane = copy.deepcopy(self._plane)
            if self.train:
                self.terminal = True  # end episode and restart

        # update current plane
        self._plane = copy.deepcopy(next_plane)
        # terminate if maximum number of steps is reached
        self.cnt += 1
        if self.cnt >= self.max_num_frames:
            self.terminal = True

        # check oscillation and reduce action step or terminate if minimum
        if self._oscillate:
            if self.train and self._supervised:
                self._plane = self.getBestPlaneTrain()
            else:
                self._plane = self.getBestPlane()
            self._update_heirarchical()
            self._clear_history()
            # terminate if distance steps are less than 1
            if self.action_dist_step < 1:
                self.terminal = True

        # find distance error
        self.cur_dist = self._landmark_distance_error()
        self.cur_dist_params = calcDistTwoParams(
            self._groundTruth_plane.params, self._plane.params,
            scale_angle=self.action_angle_step,
            scale_dist=self.action_dist_step)
        self.current_episode_score.feed(self.reward)
        self._update_history()  # store results in memory

        # terminate if distance between params are low during training
        if self.train and (self.cur_dist_params <= 1):
            self.terminal = True
            self.num_success.feed(1)

        # render screen if viz is on
        if self.viz:
            if isinstance(self.viz, float):
                self.display()

        A = normalizeUnitVector(self._groundTruth_plane.norm)
        B = normalizeUnitVector(self._plane.norm)
        # clip: rounding can push the dot product of unit vectors slightly
        # outside [-1, 1], where arccos returns NaN
        cos_angle = np.clip(A.dot(B), -1.0, 1.0)
        angle_between_norms = np.rad2deg(np.arccos(cos_angle))
        info = {
            'score': self.current_episode_score.sum,
            'gameOver': self.terminal,
            'distError': self.cur_dist,
            'distAngle': angle_between_norms,
            'filename': self.filename
        }

        if self.terminal:
            with open(self.csvfile, 'a', newline='') as outcsv:
                writer = csv.writer(outcsv)
                writer.writerow([self.filename, self.cur_dist,
                                 angle_between_norms])

        return self._current_state(), self.reward, self.terminal, info

    # -------------------------------------------------------------------------

    def _update_heirarchical(self):
        """Refine the search after oscillation. Halve the angle step, reduce
        the distance step by one, lower the sampling spacing by one while it
        is above 1, and recompute the ground-truth plane at that spacing."""
        self.action_angle_step = int(self.action_angle_step / 2)
        self.action_dist_step = self.action_dist_step - 1
        if self.spacing[0] > 1:
            self.spacing -= 1
        self._groundTruth_plane = Plane(*getGroundTruthPlane(
            self.sitk_image, self.sitk_image_4ch, self._origin3d_point,
            self._plane_size, spacing=self.spacing))

    def getBestPlane(self):
        ''' Return the stored plane whose entry in ``_bestq_history`` (max
        q-value at that step) is lowest. '''
        # NOTE(review): argmin picks the *lowest* max-q; the commented-out
        # alternative was argmax. Confirm which is intended.
        best_idx = np.argmin(self._bestq_history)
        return self._plane_history[best_idx]

    def getBestPlaneTrain(self):
        ''' Return the stored plane with the lowest supervised distance. '''
        best_idx = np.argmin(self._dist_supervised_history)
        return self._plane_history[best_idx]

    def _current_state(self):
        """
        :returns: the smoothed grid of the current plane
        """
        return self._plane.grid_smooth

    def _clear_history(self):
        """Empty all history buffers and reset the location / q-value
        histories to zero placeholders."""
        for buf in (self._plane_history, self._bestq_history,
                    self._dist_history, self._dist_history_params,
                    self._dist_supervised_history):
            buf.clear()
        self._loc_history = [(0, ) * 4] * self._history_length
        self._qvalues_history = [(0, ) * self.actions] * self._history_length

    def _update_history(self):
        ''' update history buffer with current state '''
        # location history: shift left, newest entry (rounded params) last
        self._loc_history[:-1] = self._loc_history[1:]
        params = self._plane.params
        self._loc_history[-1] = tuple(np.around(params[i], decimals=2)
                                      for i in range(4))
        # update distance history
        self._dist_history.append(self.cur_dist)
        self._dist_history_params.append(self.cur_dist_params)
        # update params history
        self._plane_history.append(self._plane)
        self._bestq_history.append(np.max(self._qvalues))
        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues

    def _calc_reward_points(self, prev_points, next_points):
        ''' Reward as the decrease of the mean point distance to the
        target plane (``calcMeanDistTwoPlanes``). '''
        prev_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          prev_points)
        next_dist = calcMeanDistTwoPlanes(self._groundTruth_plane.points,
                                          next_points)
        return prev_dist - next_dist

    def _calc_reward_params(self, prev_params, next_params):
        ''' Reward as the decrease of the scaled parameter distance to the
        target plane (``calcScaledDistTwoParams``). '''
        prev_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            prev_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)
        next_dist = calcScaledDistTwoParams(self._groundTruth_plane.params,
                                            next_params,
                                            scale_angle=self.action_angle_step,
                                            scale_dist=self.action_dist_step)
        return prev_dist - next_dist

    @property
    def _oscillate(self):
        ''' Return True if the agent is stuck and oscillating, i.e. some
        location appears more than twice in the location history. '''
        freq = Counter(self._loc_history).most_common()
        # return False if history is empty (beginning of the game)
        if len(freq) < 2:
            return False
        # skip the all-zero placeholder entries used to pre-fill the history
        if freq[0][0] == (0, 0, 0, 0):
            return freq[1][1] > 2
        return freq[0][1] > 2

    def get_action_meanings(self):
        ''' return the list of action names, indexed by action id '''
        ACTION_MEANING = {
            0: "inc_x",  # increment +1 the norm angle in x-direction
            1: "inc_y",  # increment +1 the norm angle in y-direction
            2: "inc_z",  # increment +1 the norm angle in z-direction
            3: "inc_d",  # increment +1 the norm distance d to origin
            4: "dec_x",  # decrement -1 the norm angle in x-direction
            5: "dec_y",  # decrement -1 the norm angle in y-direction
            6: "dec_z",  # decrement -1 the norm angle in z-direction
            7: "dec_d",  # decrement -1 the norm distance d to origin
        }
        # self.actions is the number of actions (an int)
        return [ACTION_MEANING[i] for i in range(self.actions)]

    @property
    def getScreenDims(self):
        """
        return screen dimensions (3D only: relies on ``self.depth``)
        """
        return (self.width, self.height, self.depth)

    def lives(self):
        return None

    def reset_stat(self):
        """ Reset all statistics counter"""
        self.stats = defaultdict(list)
        self.num_games = StatCounter()
        self.num_success = StatCounter()

    def display(self, return_rgb_array=False):
        """Show the current plane next to the ground-truth plane.

        Both planes are resampled at 1mm spacing over the whole image size and
        shown as middle z-slices in a ``SimpleImageViewer`` window, with error,
        spacing and reward text overlaid. Frames are optionally saved to a gif
        or mp4. ``return_rgb_array`` is accepted but not used.
        """
        image_size = self._image_dims
        current_plane = Plane(*getPlane(self.sitk_image, self._origin3d_point,
                                        self._plane.params, image_size,
                                        spacing=[1, 1, 1]))
        # middle z-slices of the current and ground-truth planes
        plane = current_plane.grid[:, :, int(image_size[2] / 2)]
        gt_plane = self.groundTruth_plane_iso.grid[:, :, int(image_size[2] / 2)]
        # concatenate two planes side by side
        plane = np.concatenate((plane, gt_plane), axis=1)
        # gray -> 3 channels (same as cv2.COLOR_GRAY2RGB)
        # NOTE(review): assumes the viewer expects an (h, w, 3) array - confirm
        img = np.repeat(plane[:, :, np.newaxis], 3, axis=2)
        # offsets for text placement
        scale_x = 5
        scale_y = 5

        # create a viewer unless one is already open
        if (not self.viewer) and self.viz:
            from viewer import SimpleImageViewer
            self.viewer = SimpleImageViewer(arr=img,
                                            scale_x=1,
                                            scale_y=1,
                                            filepath=self.filename)
            self.gif_buffer = []
        # display image
        self.viewer.draw_image(img)
        self.viewer.display_text('Current Plane',
                                 color=(0, 0, 204, 255),
                                 x=int(0.7 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)
        self.viewer.display_text('Ground Truth',
                                 color=(0, 0, 204, 255),
                                 x=int(4.3 * img.shape[1] / 7),
                                 y=img.shape[0] - 3)

        # display info: green when the error decreased since the last step
        dist_color_flag = False
        if len(self._dist_history) > 1:
            dist_color_flag = self.cur_dist < self._dist_history[-2]
        color_dist = (0, 204, 0, 255) if dist_color_flag else (204, 0, 0, 255)
        text = 'Error ' + str(round(self.cur_dist, 3)) + 'mm'
        self.viewer.display_text(text, color=color_dist,
                                 x=int(3 * img.shape[1] / 8), y=5 * scale_y)

        text = 'Spacing ' + str(round(self.spacing[0], 3)) + 'mm'
        self.viewer.display_text(text, color=(204, 204, 0, 255),
                                 x=int(6 * img.shape[1] / 8), y=5 * scale_y)

        color_reward = (0, 204, 0, 255) if self.reward > 0 else (204, 0, 0, 255)
        text = 'Reward ' + "%+d" % round(self.reward, 3)
        self.viewer.display_text(text, color=color_reward,
                                 x=2 * scale_x, y=5 * scale_y)
        # render and wait (viz) time between frames
        self.viewer.render()

        # save gif
        if self.saveGif:
            image_data = pyglet.image.get_buffer_manager().get_color_buffer(
            ).get_image_data()
            data = image_data.get_data('RGB', image_data.width * 3)
            arr = np.array(bytearray(data)).astype('uint8')
            arr = np.flip(
                np.reshape(arr, (image_data.height, image_data.width, -1)), 0)
            im = Image.fromarray(arr).convert('P')
            self.gif_buffer.append(im)

            # NOTE(review): the gif is rewritten on every non-terminal frame;
            # confirm this is not meant to be `if self.terminal`.
            if not self.terminal:
                gifname = self.filename.split('.')[0] + '.gif'
                self.viewer.savegif(gifname, arr=self.gif_buffer,
                                    duration=self.viz)

        # save video frames, encode them with ffmpeg at episode end
        if self.saveVideo:
            dirname = 'tmp_video'
            if self.cnt <= 1:
                if os.path.isdir(dirname):
                    logger.warn(
                        """Log directory {} exists! Use 'd' to delete it. """.
                        format(dirname))
                    act = input("select action: d (delete) / q (quit): "
                                ).lower().strip()
                    if act == 'd':
                        shutil.rmtree(dirname, ignore_errors=True)
                    else:
                        raise OSError("Directory {} exits!".format(dirname))
                os.mkdir(dirname)

            frame = dirname + '/' + '%04d' % self.cnt + '.png'
            pyglet.image.get_buffer_manager().get_color_buffer().save(frame)
            if self.terminal:
                save_cmd = [
                    'ffmpeg', '-f', 'image2', '-framerate', '30',
                    '-pattern_type', 'sequence', '-start_number', '0', '-r',
                    '3', '-i', dirname + '/%04d.png', '-s', '1280x720',
                    '-vcodec', 'libx264', '-b:v', '2567k',
                    self.filename + '.mp4'
                ]
                subprocess.check_output(save_cmd)
                shutil.rmtree(dirname, ignore_errors=True)
class AtariPlayer(RLEnvironment):
    """
    A wrapper for the Atari emulator (ALE).

    The emulator is automatically restarted when a real game ends. Note that
    ``isOver`` returned by :meth:`action` may only signal the loss of a life
    (not game over) when ``live_lost_as_eoe`` is enabled.
    """

    def __init__(self, rom_file, viz=0, height_range=(None, None),
                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
                 live_lost_as_eoe=True):
        """
        :param rom_file: path to the rom. A bare name (no '/') that is not an
            existing file is resolved with ``get_dataset_path('atari_rom', ...)``.
        :param frame_skip: skip every k frames and repeat the action
        :param image_shape: (w, h)
        :param height_range: (h1, h2) to cut
        :param viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        :param nullop_start: start with a random number (in [0, nullop_start))
            of null ops. A value <= 0 disables the null-op start.
        :param live_lost_as_eoe: consider the loss of a life as the end of an
            episode. Useful for training.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setBool(b"showinfo", False)

            # frame skipping is done manually in action()
            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start
        self.height_range = height_range
        self.image_shape = image_shape

        self.current_episode_score = StatCounter()
        self.restart_episode()

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1], :].astype(
            'float32')
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        ret = cv2.resize(ret, self.image_shape)
        return ret.astype('uint8')  # to save some memory

    def get_action_space(self):
        return DiscreteActionSpace(len(self.actions))

    def finish_episode(self):
        """Record the score of the episode that just ended."""
        self.stats['score'].append(self.current_episode_score.sum)

    def restart_episode(self):
        """Reset the game, then play a random number of null-op (action 0) frames."""
        self.current_episode_score.reset()
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start. randint(0) raises ValueError (assuming get_rng
        # returns a numpy RandomState), so a non-positive nullop_start disables it.
        n = self.rng.randint(self.nullop_start) if self.nullop_start > 0 else 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def action(self, act):
        """
        :param act: an index of the action
        :returns: (reward, isOver)
        """
        oldlives = self.ale.lives()
        # bound up-front so that frame_skip == 0 cannot leave it undefined
        newlives = oldlives
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                # screen before the final repeated frame; max-pooled in current_state()
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        self.current_episode_score.feed(r)
        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives
        if isOver:
            self.finish_episode()
        if self.ale.game_over():
            # only a real game over restarts the emulator, not a lost life
            self.restart_episode()
        return (r, isOver)
class EnvRunner(object):
    """
    A class which is responsible for stepping the environment with epsilon-greedy,
    and fill the results to experience replay buffer.
    """

    def __init__(self, player, predictor, memory, history_len):
        """
        Args:
            player (gym.Env)
            predictor (callable): the model forward function which takes a
                state and returns the prediction.
            memory (ReplayMemory): the replay memory to store experience to.
            history_len (int): number of frames (the current one included)
                stacked along the last axis to form the predictor's input.
        """
        self.player = player
        self.num_actions = player.action_space.n
        self.predictor = predictor
        self.memory = memory
        self.state_shape = memory.state_shape
        self.dtype = memory.dtype
        self.history_len = history_len

        self._current_episode = []
        self._current_ob = player.reset()
        self._current_game_score = StatCounter()  # store per-step reward
        self.total_scores = []  # store per-game total score

        self.rng = get_rng(self)

    def step(self, exploration):
        """
        Run the environment for one step.
        If the episode ends, store the entire episode to the replay memory.

        Args:
            exploration (float): probability of taking a uniformly random action.
        """
        old_s = self._current_ob
        if self.rng.rand() <= exploration:
            act = self.rng.choice(range(self.num_actions))
        else:
            history = self.recent_state()
            history.append(old_s)
            history = np.stack(history, axis=-1)  # state_shape + (Hist,)

            # assume batched network
            history = np.expand_dims(history, axis=0)
            q_values = self.predictor(history)[0][0]  # this is the bottleneck
            act = np.argmax(q_values)

        self._current_ob, reward, isOver, info = self.player.step(act)
        self._current_game_score.feed(reward)
        self._current_episode.append(Experience(old_s, act, reward, isOver))

        if isOver:
            flush_experience = True
            if 'ale.lives' in info:  # if running Atari, do something special
                if info['ale.lives'] != 0:
                    # only record score and flush experience
                    # when a whole game is over (not when an episode is over)
                    flush_experience = False
            # keep the observation returned by reset(); otherwise the next step
            # would start from the terminal observation of the finished episode
            self._current_ob = self.player.reset()

            if flush_experience:
                self.total_scores.append(self._current_game_score.sum)
                self._current_game_score.reset()

                # Ensure that the whole episode of experience is continuous in the replay buffer
                with self.memory.writer_lock:
                    for exp in self._current_episode:
                        self.memory.append(exp)
                self._current_episode.clear()

    def recent_state(self):
        """
        Get the recent state (with stacked history) of the environment.

        Returns:
            a list of ``hist_len-1`` elements, each of shape ``self.state_shape``.
            Missing frames at the start of an episode are zero-filled.
        """
        expected_len = self.history_len - 1
        if expected_len <= 0:
            # ``seq[-0:]`` would be the whole list, not an empty one
            return []
        if len(self._current_episode) >= expected_len:
            return [k.state for k in self._current_episode[-expected_len:]]
        states = [np.zeros(self.state_shape, dtype=self.dtype)] * (
            expected_len - len(self._current_episode))
        states.extend([k.state for k in self._current_episode])
        return states
class ExpReplay(RNGDataFlow, Callback):
    """
    Implement experience replay in the paper
    `Human-level control through deep reinforcement learning
    <http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.

    This variant drives a two-stage environment. A stage-1 action is picked
    from ``env.action_space``; unless it is ``'trigger'``, it is followed by a
    stage-2 refinement loop over ``env.action_space_refine``. Stage-1 and
    stage-2 transitions are stored in two separate replay memories, and each
    yielded datapoint concatenates a stage-1 batch with a stage-2 batch.

    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

    This implementation assumes that state is batch-able, and the network takes batched inputs.
    """

    def __init__(self, predictor_io_names, predictor_refine_io_names, env,
                 state_shape, batch_size, memory_size, init_memory_size,
                 init_exploration, update_frequency):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to
                predict stage-1 Q values from (state, history).
            predictor_refine_io_names (tuple of list of str): input/output names
                to predict stage-2 (refine) Q values from (state, history).
            env: the two-stage environment. This class uses its ``reset``,
                ``step``, ``step_refine``, ``step_post``, ``focus_image``,
                ``history``, ``history_refine``, ``iou``, ``action_space`` and
                ``action_space_refine`` members.
            state_shape (tuple): shape of a single state, passed to both replay memories.
            batch_size (int): number of transitions per stage in each yielded batch.
            memory_size (int): capacity passed to each replay memory.
            init_memory_size (int): number of stage-1 transitions to collect
                before training starts.
            init_exploration (float): initial epsilon of the epsilon-greedy policy.
            update_frequency (int): number of new transitions to add to memory
                after sampling a batch of transitions for training.
        """
        init_memory_size = int(init_memory_size)

        # expose every constructor argument as an attribute of the same name
        for k, v in list(locals().items()):
            if k != 'self':
                setattr(self, k, v)
        self.exploration = init_exploration
        self.env = env
        self.rng = get_rng(self)
        # tell if memory has been initialized
        self._init_memory_flag = threading.Event()
        # a queue to receive notifications to populate memory
        self._populate_job_queue = queue.Queue(maxsize=5)

        self.mem = ReplayMemory(memory_size, state_shape)
        self.mem_refine = ReplayMemoryRefine(memory_size, state_shape)
        self.env.reset()
        self._current_ob, self._current_history = self.env.focus_image, self.env.history
        # stage 1 ar actions
        self._action_space = self.env.action_space
        # stage 2 actions
        self._action_space_refine = self.env.action_space_refine
        logger.info(
            "Number of Legal actions: stage-1-ar {}, stage-2 {}".format(
                len(self._action_space), len(self._action_space_refine)))
        self._player_scores = StatCounter()
        self._current_game_score = StatCounter()
        self.state_shape = state_shape

    def get_simulator_thread(self):
        # spawn a separate thread to run policy
        def populate_job_func():
            self._populate_job_queue.get()
            for _ in range(self.update_frequency):
                self._populate_exp()

        th = ShareSessionThread(LoopThread(populate_job_func, pausable=False))
        th.name = "SimulatorThread"
        return th

    def _init_memory(self):
        logger.info("Populating replay memory with epsilon={} ...".format(
            self.exploration))

        with get_tqdm(total=self.init_memory_size) as pbar:
            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
        self._init_memory_flag.set()

    def _populate_exp(self):
        """ populate a transition by epsilon-greedy"""
        old_s, old_history = self._current_ob, self._current_history

        if self.env.iou > 0.5:
            # forced termination: take the last stage-1 action (the same element
            # that index -1 would select). Store a non-negative index so that the
            # transition maps to a real action when the batch is used for training.
            act = len(self._action_space) - 1
        else:
            if self.rng.rand() <= self.exploration:
                act = self.rng.choice(range(len(self._action_space)))
            else:
                q_values = self.predictor(old_s[None, ...],
                                          old_history.reshape(1, -1))[0][0]
                act = np.argmax(q_values)

        if self._action_space[act] != 'trigger':
            # stage 1 move (its return value is not used), then stage 2 refinement
            self.env.step(self._action_space[act])
            refine_stop = False
            while not refine_stop:
                state_refine, history_refine = self.env.focus_image, self.env.history_refine
                if self.rng.rand() <= self.exploration:
                    act_refine = self.rng.choice(
                        range(len(self._action_space_refine)))
                else:
                    q_values = self.predictor_refine(
                        state_refine[None, ...],
                        history_refine.reshape(1, -1))[0][0]
                    act_refine = np.argmax(q_values)
                reward_refine, refine_stop = self.env.step_refine(
                    self._action_space_refine[act_refine])
                self.mem_refine.append(
                    Experience(state_refine, act_refine, reward_refine,
                               refine_stop, history_refine.reshape(-1)))
            reward, isOver = self.env.step_post()
        else:
            reward, isOver = self.env.step(self._action_space[act])

        self._current_game_score.feed(reward)
        if isOver:
            self._player_scores.feed(self._current_game_score.sum)
            self.env.reset()
            self._current_game_score.reset()
        self._current_ob, self._current_history = self.env.focus_image, self.env.history
        self.mem.append(
            Experience(old_s, act, reward, isOver, old_history.reshape(-1)))

    def get_data(self):
        # wait for memory to be initialized
        self._init_memory_flag.wait()

        while True:
            idx = self.rng.randint(
                self._populate_job_queue.maxsize * self.update_frequency,
                len(self.mem) - 1,
                size=self.batch_size)
            batch_exp = [self.mem.sample(i) for i in idx]
            # NOTE(review): indices are drawn from len(self.mem) but also used on
            # mem_refine, whose length can differ; confirm
            # ReplayMemoryRefine.sample tolerates that.
            batch_exp_refine = [self.mem_refine.sample(i) for i in idx]

            yield self._process_batch(batch_exp) + self._process_batch(
                batch_exp_refine)
            self._populate_job_queue.put(1)

    def _process_batch(self, batch_exp):
        """Convert a list of experience tuples into
        [state, action, reward, isOver, history] arrays."""
        state = np.asarray([e[0] for e in batch_exp], dtype='float32')
        action = np.asarray([e[1] for e in batch_exp], dtype='int32')
        reward = np.asarray([e[2] for e in batch_exp], dtype='float32')
        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
        history = np.asarray([e[4] for e in batch_exp], dtype='float32')
        return [state, action, reward, isOver, history]

    def _setup_graph(self):
        self.predictor = self.trainer.get_predictor(*self.predictor_io_names)
        self.predictor_refine = self.trainer.get_predictor(
            *self.predictor_refine_io_names)

    def _before_train(self):
        self._init_memory()
        self._simulator_th = self.get_simulator_thread()
        self._simulator_th.start()

    def _trigger(self):
        v = self._player_scores
        try:
            mean_score, max_score = v.average, v.max
            self.trainer.monitors.put_scalar('expreplay/mean_score', mean_score)
            self.trainer.monitors.put_scalar('expreplay/max_score', max_score)
        except Exception:
            logger.exception("Cannot log training scores.")
        v.reset()