Exemplo n.º 1
0
 def setup_net(self):
     """Build the segmentation U-Net on ``self.device`` and load its weights.

     The weights come from a fixed checkpoint path and are read from the
     checkpoint's ``'state_dict'`` entry.
     """
     self.net = Unet(1, [8, 16, 32, 64], 1).to(self.device)
     checkpoint_path = 'seg_net/checkpoints/checkpoint_364500.pth.tar'
     # Restore tensors directly onto our device so loading does not depend
     # on which GPU the checkpoint was saved from.
     checkpoint = torch.load(checkpoint_path, map_location=self.device)
     self.net.load_state_dict(checkpoint['state_dict'])
     # The net is only used for inference: switch dropout / batch-norm
     # layers (if any) to evaluation behaviour.
     self.net.eval()
Exemplo n.º 2
0
class Environment:
    """Tree-structured zoom/refine segmentation environment over one volume.

    The agent walks a tree of boxes (``self.state.node``). Actions below
    ``tree_base`` address a child box (child id ``node.id * TREE_BASE +
    action``). Action ``tree_base`` runs the segmentation net on the current
    box and is rewarded by ``predict``. Higher actions mark the current node
    as taboo. ``Node.step`` itself is defined elsewhere.
    """

    def __init__(self, raw_list, lbl_list):
        self.action_space = ActionSpace(NUM_ACTIONS)
        self.observation_space = ObservationSpace()

        self.obs_shape = RL_NET_SIZE
        self.mask_shape = RL_NET_SIZE[:3]
        # Despite the names, both are indexed below as single 3-D volumes.
        self.raw_list = raw_list
        self.lbl_list = lbl_list
        self.viewed = {}
        self.tree_base = TREE_BASE

        # Region shown in the "full" observation channels; node coordinates
        # are offset by this 80-voxel margin (see location_mask / reset).
        self.base_start = (80, 80, 80)
        self.base_size = ORIGIN_SIZE

        self.device = torch.device("cuda:1")
        self.setup_net()

    @staticmethod
    def _region(start, size, pad=0):
        """Return the 3-D slice tuple for the box ``start .. start + size``, grown by ``pad``."""
        return tuple(
            slice(start[i] - pad, start[i] + size[i] + pad) for i in range(3))

    def setup_net(self):
        """Build the segmentation U-Net on ``self.device`` and load its weights."""
        self.net = Unet(1, [8, 16, 32, 64], 1).to(self.device)
        checkpoint_path = 'seg_net/checkpoints/checkpoint_364500.pth.tar'
        # Restore tensors straight onto our device, independent of the GPU
        # the checkpoint was saved from.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(checkpoint['state_dict'])
        # Inference only (predict runs under no_grad): evaluation mode for
        # dropout / batch-norm layers, if any.
        self.net.eval()

    def calculate_score(self, gt_lbl, mask, cell_id, thres=128):
        """Dice overlap between label ``cell_id`` in ``gt_lbl`` and ``mask > thres``.

        Args:
            gt_lbl: label volume.
            mask: prediction volume of the same shape (0..255 scale).
            cell_id: label value to compare against.
            thres: binarisation threshold for ``mask`` (default 128).

        Returns:
            Dice coefficient in [0, 1]; 0.0 when both masks are empty
            (the bare formula would divide by zero and yield NaN).
        """
        gt_mask = (gt_lbl == cell_id).astype(np.int32)
        mask = (mask > thres).astype(np.int32)
        denom = np.sum(mask) + np.sum(gt_mask)
        if denom == 0:
            return 0.0
        return np.sum(mask[gt_mask == 1]) * 2.0 / denom

    def get_cell_id(self, gt_lbl):
        """Return the most frequent non-zero label in ``gt_lbl`` (0 = background).

        If the patch holds only background, 0 is returned.
        """
        ids, count = np.unique(gt_lbl, return_counts=True)
        # Exclude background by value: ids[0] is only 0 when background occurs.
        count[ids == 0] = -1
        return ids[np.argmax(count)]

    def predict(self):
        """Segment the current box, merge it into ``self.mask`` and score it.

        Returns:
            Dice score of the prediction against the largest ground-truth
            cell of the box (0 when below 0.6), or 0 if that cell was already
            scored in this episode.
        """
        st = self.state.node.start
        sz = self.state.node.size
        padding = 30  # ground-truth context (voxels) around the box when scoring
        box = self._region(st, sz)
        with torch.no_grad():
            raw_patch = resize(self.raw_list[box],
                               SEG_NET_SIZE[:3],
                               order=0,
                               mode='reflect',
                               preserve_range=True)
            # Add batch and channel axes; scale intensities by 1/255
            # (presumably uint8 raw data -- confirm).
            raw_patch = np.expand_dims(np.expand_dims(raw_patch, 0), 0).astype(
                np.float32) / 255.0
            new_mask = self.net(torch.tensor(raw_patch).to(self.device))
            new_mask = np.squeeze(new_mask.cpu().numpy()) * 255

        new_mask = resize(new_mask,
                          sz,
                          order=0,
                          mode='reflect',
                          preserve_range=True).astype(np.uint8)
        # Accumulate predictions voxel-wise.
        self.mask[box] = np.maximum(self.mask[box], new_mask)

        gt_lbl = self.lbl_list[self._region(st, sz, padding)]
        new_mask = np.pad(new_mask,
                          padding,
                          mode='constant',
                          constant_values=0)
        cell_id = self.get_cell_id(self.lbl_list[box])
        # Handle revisiting an already-segmented cell: no second reward.
        if cell_id in self.viewed:
            return 0
        self.viewed[cell_id] = True
        score = self.calculate_score(gt_lbl, new_mask, cell_id)
        if score < 0.6:
            score = 0
        return score

    def update_taboo(self):
        """Mark the current node id as taboo."""
        self.taboo[self.state.node.id] = True

    def step(self, action):
        """Apply ``action``; return ``(observation, reward, done, info)``."""
        done = False
        reward = 0
        state = copy.deepcopy(self.state)
        self.set_state(state)

        action = Action(action)
        info = {'ale.lives': 1}

        # Handle revisiting a visited node: a child-addressing action into a
        # taboo child is redirected to action tree_base + 1.
        taboo_action = False
        if 0 <= action.val < self.tree_base:
            if (self.state.node.id * TREE_BASE + action.val) in self.taboo:
                taboo_action = True
                action.val = self.tree_base + 1

        if action.val == self.tree_base:
            reward = self.predict()
        if action.val >= self.tree_base:
            self.update_taboo()
        if (action.val < self.tree_base
                and state.node.level == MAX_LV) or taboo_action:
            self.update_taboo()

        # Leaving the root ends the episode.
        if action.val >= self.tree_base and state.node.level == 0:
            done = True
            info = {'ale.lives': 0}
            return self.observation(), reward, done, info

        state.node.step(action.val)

        self.set_state(state)
        return self.observation(), reward, done, info

    def sample_action(self):
        return self.action_space.sample_action()

    def reset(self):
        """Start a new episode at the root node; return the first observation."""
        self.state = State(Node())
        # NOTE(review): the mask is sized from RL_NET_SIZE[:3] (+80 margin) but
        # is indexed with node coordinates spanning base_size = ORIGIN_SIZE;
        # this only lines up if RL_NET_SIZE[:3] == ORIGIN_SIZE -- confirm.
        self.mask = np.zeros(self.mask_shape, dtype=np.uint8)
        self.mask = np.pad(self.mask, 80, mode='constant', constant_values=0)
        self.taboo = {}
        self.viewed = {}
        return self.observation()

    def set_state(self, state):
        """Store a deep copy of ``state`` as the current state."""
        self.state = copy.deepcopy(state)

    def location_mask(self):
        """ORIGIN_SIZE uint8 volume with the current box (shifted by -80) set to 255."""
        location = np.zeros(ORIGIN_SIZE, dtype=np.uint8)
        st = self.state.node.start
        shifted = [st[0] - 80, st[1] - 80, st[2] - 80]
        location[self._region(shifted, self.state.node.size)] = 255
        return location

    def observation(self):
        """Return the stacked observation of shape (D, H, W, 5).

        Channels, in order: full raw view, full mask view, current-box
        location, raw patch of the box, mask patch of the box -- each resized
        to RL_NET_SIZE[:3].
        """
        box = self._region(self.state.node.start, self.state.node.size)
        base = self._region(self.base_start, self.base_size)

        def _fit(volume):
            # Nearest-neighbour resize to the RL grid, then add a channel axis.
            resized = resize(volume,
                             RL_NET_SIZE[:3],
                             order=0,
                             mode='wrap',
                             preserve_range=True)
            return np.expand_dims(resized, -1)

        channels = [
            self.raw_list[base],
            self.mask[base],
            self.location_mask(),
            self.raw_list[box],
            self.mask[box],
        ]
        return np.concatenate([_fit(c) for c in channels], -1)

    def concat_last_dim_2_x(self, arr):
        """Lay the channels of an H x W x C array side by side: H x (W*C)."""
        assert (len(arr.shape) == 3)
        return np.concatenate([arr[..., i] for i in range(arr.shape[-1])], -1)

    def render(self):
        """Return a uint8 RGB image of three orthogonal slices through the observation.

        The slices pass through the centre of the current box, mapped per axis
        from the base region into observation coordinates.
        """
        ret = self.observation()

        st = self.state.node.start
        sz = self.state.node.size
        obs_shape = ret.shape[:3]
        center = []
        for axis in range(3):
            rel = st[axis] + sz[axis] // 2 - self.base_start[axis]
            idx = int(1.0 * rel / self.base_size[axis] * obs_shape[axis])
            # Clamp so the slice index is always valid for the observation.
            center.append(min(max(idx, 0), obs_shape[axis] - 1))
        z, y, x = center

        yx = self.concat_last_dim_2_x(ret[z, :, :, :])
        zx = self.concat_last_dim_2_x(ret[:, y, :, :])
        zy = self.concat_last_dim_2_x(ret[:, :, x, :])
        # Concatenate to Y dim
        ret = np.concatenate([yx, zx, zy], 0)
        ret = np.repeat(np.expand_dims(ret, -1), 3, axis=-1)

        return ret.astype(np.uint8)
Exemplo n.º 3
0
class Environment:
    """Hierarchical zoom / refine / zoom-out segmentation environment.

    Actions ``0 .. TREE_BASE - 1`` zoom into a child box, ``TREE_BASE``
    refines (segments) the current box and ``TREE_BASE + 1`` zooms out.
    ``self.history`` records per node id whether it was refined and which
    children were entered (bit mask in ``'zoomed'``).
    """

    def __init__(self, raw, lbl, SEG_checkpoints_path):
        self.action_space = ActionSpace(NUM_ACTIONS)
        self.observation_space = ObservationSpace(RL_NET_SIZE)
        self.obs_shape = RL_NET_SIZE
        # Zero-pad both volumes by BORDER_SIZE (presumably so boxes and the
        # scoring context around them stay in bounds near the edges).
        self.raw = np.pad(raw,
                          BORDER_SIZE,
                          mode='constant',
                          constant_values=0)
        self.lbl = np.pad(lbl,
                          BORDER_SIZE,
                          mode='constant',
                          constant_values=0)

        self.base_start = (0, 0, 0)
        self.base_size = ORIGIN_SIZE
        self.thres = 128  # binarisation threshold used by calculate_score
        self.metric = rand_score
        self.vol_size = self.raw.shape

        self.rng = np.random.RandomState(time_seed())
        self.device = torch.device("cuda:1")
        self.setup_nets(SEG_checkpoints_path)

    @staticmethod
    def _region(start, size, pad=0):
        """Return the 3-D slice tuple for the box ``start .. start + size``, grown by ``pad``."""
        return tuple(
            slice(start[i] - pad, start[i] + size[i] + pad) for i in range(3))

    def _mask_start(self):
        """Current box start relative to the node's ``origin_start`` (mask coordinates)."""
        st = self.state.node.start
        ost = self.state.node.origin_start
        return [st[0] - ost[0], st[1] - ost[1], st[2] - ost[2]]

    def setup_nets(self, path):
        """Build the segmentation U-Net on ``self.device`` and load ``path``."""
        self.seg_net_size = SEG_NET_SIZE
        self.net = Unet(SEG_CHANNELS, FEATURES, 1).to(self.device)
        # Restore tensors straight onto our device, independent of the GPU
        # the checkpoint was saved from.
        checkpoint = torch.load(path, map_location=self.device)
        self.net.load_state_dict(checkpoint['state_dict'])
        self.net.eval()

    def reset(self):
        """Place the root box at a random position; return the first observation.

        The start is drawn inside the padded volume, with extra margins of 50
        (y) and 180 (x) voxels (dataset-specific -- TODO confirm why).
        """
        z0 = self.rng.randint(BORDER_SIZE,
                              self.vol_size[0] - ORIGIN_SIZE[0] - BORDER_SIZE)
        y0 = self.rng.randint(
            50 + BORDER_SIZE,
            self.vol_size[1] - ORIGIN_SIZE[1] - BORDER_SIZE - 50)
        x0 = self.rng.randint(
            180 + BORDER_SIZE,
            self.vol_size[2] - ORIGIN_SIZE[2] - BORDER_SIZE - 180)
        self.state = State(Node(start=[z0, y0, x0]))
        self.mask = np.zeros(ORIGIN_SIZE, dtype=np.uint8)
        self.history = {}
        self.viewed = {}
        self.history[self.state.node.id] = {'refined': False, 'zoomed': 0}
        self.stack = []
        return self.observation()

    def get_cell_id(self, gt_lbl):
        """Return the most frequent non-zero label in ``gt_lbl`` (0 = background).

        If the patch holds only background, 0 is returned.
        """
        ids, count = np.unique(gt_lbl, return_counts=True)
        # Exclude background by value: ids[0] is only 0 when background occurs.
        count[ids == 0] = -1
        return ids[np.argmax(count)]

    def calculate_score(self, gt_lbl, mask, cell_id):
        """Dice overlap between label ``cell_id`` in ``gt_lbl`` and ``mask > self.thres``.

        Returns 0.0 when both masks are empty (the bare formula would divide
        by zero and yield NaN).
        """
        gt_mask = (gt_lbl == cell_id).astype(np.int32)
        mask = (mask > self.thres).astype(np.int32)
        denom = np.sum(mask) + np.sum(gt_mask)
        if denom == 0:
            return 0.0
        return np.sum(mask[gt_mask == 1]) * 2.0 / denom

    def refine(self):
        """Segment the current box, merge it into ``self.mask`` and score it.

        Returns:
            Dice score against the largest ground-truth cell of the box (0 when
            below 0.6), or 0 if that cell was already scored this episode.
        """
        st = self.state.node.start
        sz = self.state.node.size
        padding = UPDATE_MASK_PADDING
        box = self._region(st, sz)
        mask_box = self._region(self._mask_start(), sz)

        with torch.no_grad():
            raw_patch = resize(self.raw[box],
                               self.seg_net_size,
                               order=0,
                               mode='reflect',
                               preserve_range=True)
            # Add batch and channel axes in front.
            # NOTE(review): intensities are fed unscaled (no /255 here) --
            # confirm this matches how the checkpoint was trained.
            raw_patch = np.expand_dims(np.expand_dims(raw_patch, 0), 0)
            new_mask = self.net(
                torch.tensor(raw_patch,
                             device=self.device,
                             dtype=torch.float32))
            new_mask = np.squeeze(new_mask.cpu().numpy()) * 255

        new_mask = resize(new_mask,
                          sz,
                          order=0,
                          mode='reflect',
                          preserve_range=True).astype(np.uint8)
        # Accumulate predictions voxel-wise in mask coordinates.
        self.mask[mask_box] = np.maximum(self.mask[mask_box], new_mask)

        gt_lbl = self.lbl[self._region(st, sz, padding)]
        new_mask = np.pad(new_mask,
                          padding,
                          mode='constant',
                          constant_values=0)

        cell_id = self.get_cell_id(self.lbl[box])
        if cell_id in self.viewed:
            return 0
        self.viewed[cell_id] = True
        score = self.calculate_score(gt_lbl, new_mask, cell_id)
        if score < 0.6:
            score = 0
        return score

    def handle_zoomin(self, action):
        """Zoom into child ``action``; return ``(observation, reward, done, info)``."""
        next_node_id = self.state.node.id * TREE_BASE + action
        # Revisiting a visited child is turned into a zoom-out.
        if next_node_id in self.history:
            return self.handle_zoomout()

        self.history[next_node_id] = {'refined': False, 'zoomed': 0}
        # Bit ``action`` records that this child has been entered.
        self.history[self.state.node.id]['zoomed'] |= 2**action

        if self.state.node.level == MAX_LV - 1:
            # Deepest level: zoom in, refine, then zoom out again with no
            # stack operation.
            self.history[next_node_id]['refined'] = True
            self.state.node.step(action)
            reward = self.refine()
            self.handle_zoomout(pop_last=False)
            info = {'up_level': False, 'ale.lives': 1, 'down_level': False}
            return self.observation(), reward, False, info

        info = {'up_level': False, 'ale.lives': 1, 'down_level': True}
        self.stack += [copy.deepcopy(self.state)]
        self.state.node.step(action)
        return self.observation(), 0, False, info

    def handle_refine(self):
        """Refine the current box once; a repeated refine becomes a zoom-out."""
        if self.history[self.state.node.id]['refined']:
            return self.handle_zoomout()

        reward = self.refine()
        self.history[self.state.node.id]['refined'] = True
        info = {'up_level': False, 'ale.lives': 1, 'down_level': False}
        return self.observation(), reward, False, info

    def handle_zoomout(self, pop_last=True):
        """Step to the parent box; always reports ``done=True`` and reward 0.

        The observation is taken *before* stepping out. ``ale.lives`` is 0
        only when leaving from the root (level 0).
        """
        observation = self.observation()
        not_root = self.state.node.level != 0
        info = {
            'up_level': not_root,
            'ale.lives': 1 if not_root else 0,
            'down_level': False
        }
        if pop_last and self.stack:
            self.stack.pop()
        self.state.node.step(TREE_BASE)
        return observation, 0, True, info

    def step(self, action):
        """Dispatch ``action``; return ``(observation, reward, done, info)``.

        Raises:
            ValueError: if ``action`` is not in ``0 .. TREE_BASE + 1``.
        """
        # set_state deep-copies, so the handlers below mutate a fresh state
        # object rather than one a caller may still hold.
        self.set_state(self.state)

        if 0 <= action < TREE_BASE:
            return self.handle_zoomin(action)
        if action == TREE_BASE:
            return self.handle_refine()
        if action == TREE_BASE + 1:
            return self.handle_zoomout()
        raise ValueError(
            f'invalid action {action!r}; expected 0..{TREE_BASE + 1}')

    def cell_count(self):
        """Number of distinct labels (background included) in the current box."""
        box = self._region(self.state.node.start, self.state.node.size)
        return len(np.unique(self.lbl[box]))

    def debug(self):
        print('------------------------DEBUG')
        print(self.state.node.id)
        print(self.history[self.state.node.id]['refined'])
        print(self.history[self.state.node.id]['zoomed'])
        print('-----------------------------')

    def sample_action(self):
        return self.action_space.sample_action()

    def get_zoomed_mask(self, zoomed):
        """Return an RL_NET_SIZE[:3] uint8 map, 255 over each child whose bit is set in ``zoomed``."""
        ret = np.zeros(RL_NET_SIZE[:3], dtype=np.uint8)
        size = [RL_NET_SIZE[0] // 3, RL_NET_SIZE[1] // 3, RL_NET_SIZE[2] // 3]
        for i in range(TREE_BASE):
            if ((2**i) & zoomed) != 0:
                z0 = int(DZ[i] * 0.5 * RL_NET_SIZE[0]) + RL_NET_SIZE[0] // 12
                y0 = int(DY[i] * 0.5 * RL_NET_SIZE[1]) + RL_NET_SIZE[1] // 12
                x0 = int(DX[i] * 0.5 * RL_NET_SIZE[2]) + RL_NET_SIZE[2] // 12
                ret[z0:z0 + size[0], y0:y0 + size[1], x0:x0 + size[2]] = 255
        return ret

    def set_state(self, state):
        """Store a deep copy of ``state`` as the current state."""
        self.state = copy.deepcopy(state)

    def observation(self):
        """Return the stacked observation (RL_NET_SIZE[:3] plus 4 channels).

        Channels, in order: raw patch, mask patch, refined flag (all 255 or all
        0), zoomed-children map.
        """
        sz = self.state.node.size
        box = self._region(self.state.node.start, sz)
        mask_box = self._region(self._mask_start(), sz)
        history = self.history[self.state.node.id]

        if history['refined']:
            refined_mask = np.ones(RL_NET_SIZE[:3], dtype=np.uint8) * 255
        else:
            refined_mask = np.zeros(RL_NET_SIZE[:3], dtype=np.uint8)

        def _fit(volume):
            # Nearest-neighbour resize to the RL grid as uint8.
            return resize(volume,
                          RL_NET_SIZE[:3],
                          order=0,
                          mode='reflect',
                          preserve_range=True).astype(np.uint8)

        channels = [
            _fit(self.raw[box]),
            _fit(self.mask[mask_box]),
            refined_mask,
            self.get_zoomed_mask(history['zoomed']),
        ]
        return np.concatenate([np.expand_dims(c, -1) for c in channels], -1)

    def concat_last_dim_2_x(self, arr):
        """Lay the channels of an H x W x C array side by side: H x (W*C)."""
        assert (len(arr.shape) == 3)
        return np.concatenate([arr[..., i] for i in range(arr.shape[-1])], -1)

    def render(self):
        """Return a uint8 RGB image of the three central orthogonal slices."""
        ret = self.observation()
        z, y, x = (n // 2 for n in ret.shape[:3])
        yx = self.concat_last_dim_2_x(ret[z, :, :, :])
        zx = self.concat_last_dim_2_x(ret[:, y, :, :])
        zy = self.concat_last_dim_2_x(ret[:, :, x, :])
        # Concatenate to Y dim
        ret = np.concatenate([yx, zx, zy], 0)
        ret = np.repeat(np.expand_dims(ret, -1), 3, axis=-1)
        return ret.astype(np.uint8)
Exemplo n.º 4
0
 def setup_nets(self, path):
     """Build the segmentation U-Net on ``self.device`` and load weights from ``path``.

     The weights are read from the checkpoint's ``'state_dict'`` entry, and the
     net is put into evaluation mode for inference.
     """
     self.seg_net_size = SEG_NET_SIZE
     self.net = Unet(SEG_CHANNELS, FEATURES, 1).to(self.device)
     # Restore tensors straight onto our device so loading does not depend
     # on which GPU the checkpoint was saved from.
     checkpoint = torch.load(path, map_location=self.device)
     self.net.load_state_dict(checkpoint['state_dict'])
     self.net.eval()
Exemplo n.º 5
0
class Environment:
    def __init__(self, raw_list, lbl_list):
        """Create the environment over one raw / label volume pair.

        Sets up the action and observation spaces, the bookkeeping for scored
        cells, the "full view" region and the device, then loads the
        segmentation net via ``setup_net``.
        """
        # Spaces and shapes exposed to the agent.
        self.action_space = ActionSpace(NUM_ACTIONS)
        self.observation_space = ObservationSpace()
        self.obs_shape = RL_NET_SIZE
        self.mask_shape = RL_NET_SIZE[:3]

        # Data volumes and per-episode bookkeeping.
        self.raw_list, self.lbl_list = raw_list, lbl_list
        self.viewed = {}
        self.tree_base = TREE_BASE

        # Region covered by the whole-volume observation channels.
        self.base_start, self.base_size = (80, 80, 80), ORIGIN_SIZE

        self.device = torch.device("cuda:1")
        self.setup_net()

    def setup_net(self):
        """Build the segmentation U-Net on ``self.device`` and load its weights.

        The weights come from a fixed checkpoint path, read from its
        ``'state_dict'`` entry.
        """
        self.net = Unet(1, [8, 16, 32, 64], 1).to(self.device)
        checkpoint_path = 'seg_net/checkpoints/checkpoint_364500.pth.tar'
        # Restore tensors straight onto our device, independent of the GPU
        # the checkpoint was saved from.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.net.load_state_dict(checkpoint['state_dict'])
        # The net is only used for inference: evaluation mode for dropout /
        # batch-norm layers, if any.
        self.net.eval()

    def calculate_score(self, gt_lbl, mask, cell_id, thres=128):
        """Return the Dice overlap of label ``cell_id`` in ``gt_lbl`` and ``mask > thres``.

        Args:
            gt_lbl: label volume.
            mask: prediction volume of the same shape (0..255 scale).
            cell_id: label value to compare against.
            thres: binarisation threshold for ``mask`` (default 128).

        Returns:
            Dice coefficient in [0, 1]; 0.0 when both masks are empty
            (the bare formula would divide by zero and yield NaN).
        """
        gt_mask = (gt_lbl == cell_id).astype(np.int32)
        mask = (mask > thres).astype(np.int32)
        denom = np.sum(mask) + np.sum(gt_mask)
        if denom == 0:
            return 0.0
        return np.sum(mask[gt_mask == 1]) * 2.0 / denom

    def get_cell_id(self, gt_lbl):
        """Return the most frequent non-zero label in ``gt_lbl`` (0 = background).

        Returns 0 if the patch contains only background.
        """
        ids, count = np.unique(gt_lbl, return_counts=True)
        # Exclude background by value: ids[0] is only 0 when background occurs,
        # so blanking index 0 would wrongly drop a real cell otherwise.
        count[ids == 0] = -1
        return ids[np.argmax(count)]

    def predict(self):
        """Segment the current node's box, merge the result into ``self.mask`` and score it.

        Returns:
            The Dice score of the new prediction against the largest
            ground-truth cell inside the box (scores below 0.6 become 0), or
            0 when that cell has already been scored in this episode.
        """
        start = self.state.node.start
        size = self.state.node.size
        pad = 30  # ground-truth context (voxels) kept around the box for scoring
        box = tuple(slice(start[i], start[i] + size[i]) for i in range(3))
        padded_box = tuple(
            slice(start[i] - pad, start[i] + size[i] + pad) for i in range(3))

        with torch.no_grad():
            patch = resize(self.raw_list[box], SEG_NET_SIZE[:3], order=0,
                           mode='reflect', preserve_range=True)
            # Batch and channel axes in front, intensities scaled by 1/255.
            batch = patch[np.newaxis, np.newaxis].astype(np.float32) / 255.0
            prediction = self.net(torch.tensor(batch).to(self.device))
            prediction = np.squeeze(prediction.cpu().numpy()) * 255

        prediction = resize(prediction, size, order=0, mode='reflect',
                            preserve_range=True).astype(np.uint8)
        # Keep the voxel-wise maximum of the old mask and the new prediction.
        self.mask[box] = np.maximum(self.mask[box], prediction)

        context = self.lbl_list[padded_box]
        prediction = np.pad(prediction, pad, mode='constant', constant_values=0)
        cell_id = self.get_cell_id(self.lbl_list[box])

        # A cell that was already segmented earns nothing the second time.
        if cell_id in self.viewed:
            return 0
        self.viewed[cell_id] = True
        score = self.calculate_score(context, prediction, cell_id)
        return 0 if score < 0.6 else score

    def update_taboo (self):
        self.taboo [self.state.node.id] = True
        
    def step (self, action):
        done = False
        reward = 0
        state = copy.deepcopy (self.state)
        self.set_state (state)
        prev_state = copy.deepcopy (state)

        action = Action (action)
        info = { 'ale.lives': 1}
        
        # Handle revisiting a visited node
        taboo_action = False
        if (0 <= action.val < self.tree_base):
            if (self.state.node.id * TREE_BASE + action.val) in self.taboo:
                taboo_action = True
                action.val = self.tree_base + 1

        if action.val == self.tree_base: