Example #1
    def _move_ghost(self, g, ghost_range):
        # chase the agent when it is nearby, flee instead while a power-up is
        # active, and wander randomly otherwise
        if Grid.manhattan_distance(self.state.agent_pos,
                                   self.state.ghosts[g].pos) < ghost_range:
            if self.state.power_step > 0:
                self._move_defensive(g)
            else:
                self._move_aggressive(g)
        else:
            self._move_random(g)
Example #2
File: rock.py Project: muthissar/gym_pomdp
    def _select_target(rock_state, x_size):
        best_dist = x_size * 2
        best_rock = -1  # Coord(-1, -1)
        for idx, rock in enumerate(rock_state.rocks):
            if rock.status != 0 and rock.count >= 0:
                d = Grid.manhattan_distance(rock_state.agent_pos, rock.pos)
                if d < best_dist:
                    best_dist = d
                    best_rock = idx  # rock.pos
        return best_rock
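A minimal, self-contained sketch of the same selection rule, with plain tuples and a local Manhattan-distance helper standing in for Coord/Rock/RockState (FakeRock, manhattan and select_target are hypothetical names, not part of gym_pomdp): it returns the index of the closest rock that is still on the board (status != 0) and not believed bad (count >= 0).

from collections import namedtuple

FakeRock = namedtuple("FakeRock", ["pos", "status", "count"])

def manhattan(a, b):
    return abs(a[0] - b[0]) + abs(a[1] - b[1])

def select_target(agent_pos, rocks, x_size):
    best_dist, best_rock = x_size * 2, -1
    for idx, rock in enumerate(rocks):
        if rock.status != 0 and rock.count >= 0:
            d = manhattan(agent_pos, rock.pos)
            if d < best_dist:
                best_dist, best_rock = d, idx
    return best_rock

rocks = [FakeRock((0, 6), status=1, count=0),
         FakeRock((2, 2), status=1, count=-1),  # skipped: believed bad
         FakeRock((3, 1), status=0, count=2)]   # skipped: already collected
print(select_target((1, 1), rocks, x_size=7))   # -> 0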
Example #3
File: rock.py Project: muthissar/gym_pomdp
    def __init__(self, board_size=7, num_rocks=8, use_heuristic=False):

        assert board_size in list(
            config.keys()) and num_rocks in config[board_size]['size']

        self.num_rocks = num_rocks
        self._use_heuristic = use_heuristic

        self._rock_pos = [
            Coord(*rock) for rock in config[board_size]['rock_pos']
        ]
        self._agent_pos = Coord(*config[board_size]['init_pos'])
        self.grid = Grid(board_size, board_size)

        for idx, rock in enumerate(self._rock_pos):
            self.grid.board[rock] = idx

        self.action_space = Discrete(len(Action) + self.num_rocks)
        self.observation_space = Discrete(len(Obs))
        self._discount = .95
        self._reward_range = 20
        self._penalization = -100
        self._query = 0
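The constructor relies on a module-level config mapping. From the accesses above, an entry looks roughly like the sketch below; the coordinates are invented for illustration and are not the project's actual layout.

config_example = {
    7: {
        'size': [8],                          # allowed values of num_rocks
        'init_pos': (0, 3),                   # agent start, unpacked into Coord(x, y)
        'rock_pos': [(1, 0), (2, 4), (3, 2), (4, 6),
                     (5, 1), (5, 5), (6, 3), (6, 6)],
    }
}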
Example #4
    def _move_aggressive(self, g, chase_prob=.75):
        if not np.random.binomial(1, p=chase_prob):
            return self._move_random(g)

        best_dist = self.grid.x_size + self.grid.y_size
        best_pos = self.state.ghosts[g].pos
        best_dir = -1
        for d in range(self.action_space.n):
            dist = Grid.directional_distance(self.state.agent_pos,
                                             self.state.ghosts[g].pos, d)
            new_pos = self._next_pos(self.state.ghosts[g].pos, d)
            if dist <= best_dist and new_pos.is_valid() and can_move(
                    self.state.ghosts[g], d):
                best_pos = new_pos
                best_dist = dist
                best_dir = d

        self.state.ghosts[g].update(best_pos, best_dir)
Example #5
    def _move_defensive(self, g, defensive_prob=.5):
        if np.random.binomial(
                1, defensive_prob) and self.state.ghosts[g].direction >= 0:
            self.state.ghosts[g].direction = -1

        best_dist = self.grid.x_size + self.grid.y_size
        best_pos = self.state.ghosts[g].pos
        best_dir = -1
        for d in range(self.action_space.n):
            dist = Grid.directional_distance(self.state.agent_pos,
                                             self.state.ghosts[g].pos, d)
            new_pos = self._next_pos(self.state.ghosts[g].pos, d)
            if dist >= best_dist and new_pos.is_valid() and can_move(
                    self.state.ghosts[g], d):
                best_pos = new_pos
                best_dist = dist
                best_dir = d

        self.state.ghosts[g].update(best_pos, best_dir)
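Examples #4 and #5 share the same scan over candidate directions and differ only in the random chase/flee gate at the top and in the comparison against best_dist: <= ends on the direction with the smallest directional distance to the agent (chase), >= on the largest (flee). A hypothetical, self-contained sketch of that shared selection (pick_direction and the candidate tuples are illustrative only, not gym_pomdp code):

def pick_direction(candidates, aggressive=True):
    """candidates: iterable of (direction, distance_to_agent, legal)."""
    best_dir, best_dist = -1, None
    for d, dist, legal in candidates:
        if not legal:
            continue
        if best_dist is None or (dist <= best_dist if aggressive
                                 else dist >= best_dist):
            best_dir, best_dist = d, dist
    return best_dir

cands = [(0, 3, True), (1, 5, True), (2, 1, False), (3, 4, True)]
print(pick_direction(cands, aggressive=True))   # -> 0 (closest legal direction)
print(pick_direction(cands, aggressive=False))  # -> 1 (farthest legal direction)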
Example #6
def can_move(ghost, d):
    # a move is illegal if it would reverse the ghost's current heading
    return Grid.opposite(d) != ghost.direction
Example #7
    def _hear_ghost(poc_state, hear_range=2):
        for ghost in poc_state.ghosts:
            if Grid.manhattan_distance(ghost.pos,
                                       poc_state.agent_pos) <= hear_range:
                return True
        return False
Example #8
File: rock.py Project: muthissar/gym_pomdp
    def _efficiency(agent_pos, rock_pos, hed=20):
        # TODO check me
        d = Grid.euclidean_distance(agent_pos, rock_pos)
        eff = (1 + pow(2, -d / hed)) * .5
        return eff
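The value returned above is used as the probability that a rock check reports the true status: with hed=20 it is 1.0 at distance 0 and decays toward 0.5 (a pure coin flip) as the Euclidean distance grows. A quick numerical check:

for d in (0, 20, 200):
    eff = (1 + pow(2, -d / 20)) * .5
    print(d, round(eff, 3))   # 0 -> 1.0, 20 -> 0.75, 200 -> ~0.5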
Example #9
File: rock.py Project: muthissar/gym_pomdp
class RockEnv(Env):
    metadata = {"render.modes": ["human", "ansi"]}

    def __init__(self, board_size=7, num_rocks=8, use_heuristic=False):

        assert board_size in list(
            config.keys()) and num_rocks in config[board_size]['size']

        self.num_rocks = num_rocks
        self._use_heuristic = use_heuristic

        self._rock_pos = [
            Coord(*rock) for rock in config[board_size]['rock_pos']
        ]
        self._agent_pos = Coord(*config[board_size]['init_pos'])
        self.grid = Grid(board_size, board_size)

        for idx, rock in enumerate(self._rock_pos):
            self.grid.board[rock] = idx

        self.action_space = Discrete(len(Action) + self.num_rocks)
        self.observation_space = Discrete(len(Obs))
        self._discount = .95
        self._reward_range = 20
        self._penalization = -100
        self._query = 0

    def seed(self, seed=None):
        np.random.seed(seed)

    def step(self, action):

        assert self.action_space.contains(action)
        assert self.done is False

        self.last_action = action
        self._query += 1

        reward = 0
        ob = Obs.NULL.value

        if action < Action.SAMPLE.value:
            if action == Action.EAST.value:
                if self.state.agent_pos.x + 1 < self.grid.x_size:
                    self.state.agent_pos += Moves.EAST.value
                else:
                    reward = 10
                    self.done = True
                    return ob, reward, self.done, {
                        "state": self._encode_state(self.state)
                    }
            elif action == Action.NORTH.value:
                if self.state.agent_pos.y + 1 < self.grid.y_size:
                    self.state.agent_pos += Moves.NORTH.value
                else:
                    reward = self._penalization
            elif action == Action.SOUTH.value:
                if self.state.agent_pos.y - 1 >= 0:
                    self.state.agent_pos += Moves.SOUTH.value
                else:
                    reward = self._penalization
            elif action == Action.WEST.value:
                if self.state.agent_pos.x - 1 >= 0:
                    self.state.agent_pos += Moves.WEST.value
                else:
                    reward = self._penalization
            else:
                raise NotImplementedError()

        if action == Action.SAMPLE.value:
            rock = self.grid[self.state.agent_pos]
            if rock >= 0 and not self.state.rocks[
                    rock].status == 0:  # collected
                if self.state.rocks[rock].status == 1:
                    reward = 10
                else:
                    reward = -10
                self.state.rocks[rock].status = 0
            else:
                reward = self._penalization

        if action > Action.SAMPLE.value:
            rock = action - Action.SAMPLE.value - 1
            assert rock < self.num_rocks

            ob = self._sample_ob(self.state.agent_pos, self.state.rocks[rock])

            self.state.rocks[rock].measured += 1

            eff = self._efficiency(self.state.agent_pos,
                                   self.state.rocks[rock].pos)

            if ob == Obs.GOOD.value:
                self.state.rocks[rock].count += 1
                self.state.rocks[rock].lkv *= eff
                self.state.rocks[rock].lkw *= (1 - eff)
            else:
                self.state.rocks[rock].count -= 1
                self.state.rocks[rock].lkw *= eff
                self.state.rocks[rock].lkv *= (1 - eff)

            denom = (.5 * self.state.rocks[rock].lkv) + (
                .5 * self.state.rocks[rock].lkw)
            self.state.rocks[rock].prob_valuable = (
                .5 * self.state.rocks[rock].lkv) / denom

        self.done = self._penalization == reward
        return ob, reward, self.done, {"state": self._encode_state(self.state)}

    def _decode_state(self, state, as_array=False):

        agent_pos = Coord(*state['agent_pos'])
        rock_state = RockState(agent_pos)
        for r in state['rocks']:
            rock = Rock(pos=0)
            rock.__dict__.update(r)
            rock_state.rocks.append(rock)

        if as_array:
            rocks = []
            for rock in rock_state.rocks:
                rocks.append(rock.status)

            return np.concatenate([[self.grid.get_index(agent_pos)], rocks])

        return rock_state

    def _encode_state(self, state):
        # use dictionary for state encoding

        return _encode_dict(state)
        # rocks can take 3 values: -1, 1, 0 if collected

    def render(self, mode='human', close=False):
        if close:
            return
        if mode == "human":
            if not hasattr(self, "gui"):
                start_pos = self.grid.get_index(self.state.agent_pos)
                obj_pos = [(self.grid.get_index(rock.pos), rock.status)
                           for rock in self.state.rocks]
                self.gui = RockGui((self.grid.x_size, self.grid.y_size),
                                   start_pos=start_pos,
                                   obj=obj_pos)

            if self.last_action > Action.SAMPLE.value:
                rock = self.last_action - Action.SAMPLE.value - 1
                print("Rock S: {} P:{}".format(self.state.rocks[rock].status,
                                               self.state.rocks[rock].pos))
            # msg = "Action : " + action_to_str(self.last_action) + " Step: " + str(self.t) + " Rw: " + str(self.total_rw)
            agent_pos = self.grid.get_index(self.state.agent_pos)
            self.gui.render(agent_pos)

    def reset(self):
        self.done = False
        self._query = 0
        self.last_action = Action.SAMPLE.value
        self.state = self._get_init_state(should_encode=False)
        return Obs.NULL.value

    def _set_state(self, state):
        self.done = False
        self.state = self._decode_state(state)

    def close(self):
        self.render(close=True)

    def _compute_prob(self, action, next_state, ob):

        next_state = self._decode_state(next_state)

        if action <= Action.SAMPLE.value:
            return int(ob == Obs.NULL.value)

        eff = self._efficiency(
            next_state.agent_pos,
            next_state.rocks[action - Action.SAMPLE.value - 1].pos)

        if ob == Obs.GOOD.value and next_state.rocks[action -
                                                     Action.SAMPLE.value -
                                                     1].status == 1:
            return eff
        elif ob == Obs.BAD.value and next_state.rocks[action -
                                                      Action.SAMPLE.value -
                                                      1].status == -1:
            return eff
        else:
            return 1 - eff

    def _get_init_state(self, should_encode=True):

        rock_state = RockState(self._agent_pos)
        for idx in range(self.num_rocks):
            rock_state.rocks.append(Rock(self._rock_pos[idx]))
        return self._encode_state(rock_state) if should_encode else rock_state

    def _generate_legal(self):
        legal = [Action.EAST.value]  # can always go east
        if self.state.agent_pos.y + 1 < self.grid.y_size:
            legal.append(Action.NORTH.value)

        if self.state.agent_pos.y - 1 >= 0:
            legal.append(Action.SOUTH.value)
        if self.state.agent_pos.x - 1 >= 0:
            legal.append(Action.WEST.value)

        rock = self.grid[self.state.agent_pos]
        if rock >= 0 and self.state.rocks[rock].status != 0:
            legal.append(Action.SAMPLE.value)

        for rock in self.state.rocks:
            assert self.grid[rock.pos] != -1
            if rock.status != 0:
                legal.append(self.grid[rock.pos] + 1 + Action.SAMPLE.value)
        return legal

    def _generate_preferred(self, history):
        if not self._use_heuristic:
            return self._generate_legal()

        actions = []

        # sample rocks with high likelihood of being good
        rock = self.grid[self.state.agent_pos]
        if rock >= 0 and self.state.rocks[rock].status != 0 and history.size:
            total = 0
            # history
            for t in range(history.size):
                if history[t].action == rock + 1 + Action.SAMPLE.value:
                    if history[t].ob == Obs.GOOD.value:
                        total += 1
                    elif history[t].ob == Obs.BAD.value:
                        total -= 1
            if total > 0:
                actions.append(Action.SAMPLE.value)
                return actions

        # process the rocks

        all_bad = True
        direction = {
            "north": False,
            "south": False,
            "west": False,
            "east": False
        }
        for idx in range(self.num_rocks):
            rock = self.state.rocks[idx]
            if rock.status != 0:
                total = 0
                for t in range(history.size):
                    if history[t].action == idx + 1 + Action.SAMPLE.value:
                        if history[t].ob == Obs.GOOD.value:
                            total += 1
                        elif history[t].ob == Obs.BAD.value:
                            total -= 1
                if total >= 0:
                    all_bad = False

                    if rock.pos.y > self.state.agent_pos.y:
                        direction['north'] = True
                    elif rock.pos.y < self.state.agent_pos.y:
                        direction['south'] = True
                    elif rock.pos.x < self.state.agent_pos.x:
                        direction['west'] = True
                    elif rock.pos.x > self.state.agent_pos.x:
                        direction['east'] = True

        if all_bad:
            actions.append(Action.EAST.value)
            return actions

        # generate a random legal move
        # do not measure a collected rock
        # do not measure a rock too often
        # do not measure clearly bad rocks
        # don't move in a direction that puts you closer to bad rocks
        # never sample a rock

        if self.state.agent_pos.y + 1 < self.grid.y_size and direction['north']:
            actions.append(Action.NORTH.value)

        if direction['east']:
            actions.append(Action.EAST.value)

        if self.state.agent_pos.y - 1 >= 0 and direction['south']:
            actions.append(Action.SOUTH.value)

        if self.state.agent_pos.x - 1 >= 0 and direction['west']:
            actions.append(Action.WEST.value)

        for idx, rock in enumerate(self.state.rocks):
            if not rock.status == 0 and rock.measured < 5 and abs(
                    rock.count) < 2 and 0 < rock.prob_valuable < 1:
                actions.append(idx + 1 + Action.SAMPLE.value)

        if len(actions) == 0:
            return self._generate_legal()

        return actions

    def __dict2np__(self, state):
        idx = self.grid.get_index(Coord(*state['agent_pos']))
        rocks = []
        for rock in state['rocks']:
            rocks.append(rock['status'])
        return np.concatenate([[idx], rocks])

    @staticmethod
    def _efficiency(agent_pos, rock_pos, hed=20):
        # TODO check me
        d = Grid.euclidean_distance(agent_pos, rock_pos)
        eff = (1 + pow(2, -d / hed)) * .5
        return eff

    @staticmethod
    def _select_target(rock_state, x_size):
        best_dist = x_size * 2
        best_rock = -1  # Coord(-1, -1)
        for idx, rock in enumerate(rock_state.rocks):
            if rock.status != 0 and rock.count >= 0:
                d = Grid.manhattan_distance(rock_state.agent_pos, rock.pos)
                if d < best_dist:
                    best_dist = d
                    best_rock = idx  # rock.pos
        return best_rock

    @staticmethod
    def _sample_ob(agent_pos, rock, hed=20):
        eff = RockEnv._efficiency(agent_pos, rock.pos, hed=hed)
        if np.random.binomial(1, eff):
            return Obs.GOOD.value if rock.status == 1 else Obs.BAD.value
        else:
            return Obs.BAD.value if rock.status == 1 else Obs.GOOD.value
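The lkv/lkw bookkeeping in step() is a running likelihood for the two hypotheses "rock is valuable" and "rock is worthless": a GOOD reading multiplies lkv by the check's efficiency and lkw by its error rate, and prob_valuable is the posterior under the uniform 0.5/0.5 prior. A small worked example with two GOOD readings at illustrative efficiencies of 0.9 and 0.75:

lkv = lkw = 1.0                 # likelihood given valuable / worthless
for eff in (0.9, 0.75):         # both checks returned Obs.GOOD
    lkv *= eff
    lkw *= (1 - eff)
prob_valuable = (.5 * lkv) / (.5 * lkv + .5 * lkw)
print(round(prob_valuable, 3))  # -> 0.964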
Example #10
    def __init__(self,
                 board_size=7,
                 num_rocks=8,
                 use_heuristic=False,
                 observation='o',
                 stay_inside=False):
        """

        :param board_size: int board is a square of board_size x board_size
        :param num_rocks: int number of rocks on board
        :param use_heuristic: bool usage unclear
        :param observation: str must be one of
                                'o': observed value only
                                'po': position of the agent + the above
                                'poa': the above + the action taken
        """

        assert board_size in list(config.keys()) and \
               num_rocks == len(config[board_size]["rock_pos"])

        self.num_rocks = num_rocks
        self._use_heuristic = use_heuristic

        self._rock_pos = \
            [Coord(*rock) for rock in config[board_size]['rock_pos']]
        self._agent_pos = Coord(*config[board_size]['init_pos'])
        self.grid = Grid(board_size, board_size)

        for idx, rock in enumerate(self._rock_pos):
            self.grid.board[rock] = idx

        self.action_space = Discrete(len(Action) + self.num_rocks)
        self._discount = .95
        self._reward_range = 20
        self._penalization = -100
        self._query = 0
        if stay_inside:
            self._out_of_bounds_penalty = 0
        else:
            self._out_of_bounds_penalty = self._penalization

        self.state = None
        self.last_action = None
        self.done = False

        self.gui = None

        assert observation in ['o', 'oa', 'po', 'poa']
        if observation == 'o':
            self._make_obs = lambda obs, a: obs
            self.observation_space = Discrete(len(Obs))
        elif observation == 'oa':
            self._make_obs = self._oa
            self.observation_space =\
                Box(low=0,
                    high=np.append(max(Obs), np.ones(self.action_space.n)),
                    dtype=int)

        elif observation == 'po':
            self._make_obs = self._po
            self.observation_space = \
                Box(low=0,
                    high=np.append(np.ones(self.grid.n_tiles), max(Obs)),
                    dtype=int)

        elif observation == 'poa':
            self._make_obs = self._poa
            self.observation_space = \
                Box(low=0,
                    high=np.concatenate((np.ones(self.grid.n_tiles),
                                         [max(Obs)],
                                        np.ones(self.action_space.n))),
                    dtype=int)
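A usage sketch for the observation modes accepted by this constructor; the import path below is an assumption about the project layout (the environment may instead be registered and created through gym.make).

from gym_pomdp.envs.rock import RockEnv  # assumed module path

env = RockEnv(board_size=7, num_rocks=8, observation='poa')
ob = env.reset()                 # one-hot agent tile + observed value + one-hot action
ob, reward, done, info = env.step(env.action_space.sample())
print(ob.shape, reward, done)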
Example #11
class RockEnv(Env):
    metadata = {"render.modes": ["human", "ansi"]}

    def __init__(self,
                 board_size=7,
                 num_rocks=8,
                 use_heuristic=False,
                 observation='o',
                 stay_inside=False):
        """

        :param board_size: int board is a square of board_size x board_size
        :param num_rocks: int number of rocks on board
        :param use_heuristic: bool usage unclear
        :param observation: str must be one of
                                'o': observed value only
                                'po': position of the agent + the above
                                'poa': the above + the action taken
        """

        assert board_size in list(config.keys()) and \
               num_rocks == len(config[board_size]["rock_pos"])

        self.num_rocks = num_rocks
        self._use_heuristic = use_heuristic

        self._rock_pos = \
            [Coord(*rock) for rock in config[board_size]['rock_pos']]
        self._agent_pos = Coord(*config[board_size]['init_pos'])
        self.grid = Grid(board_size, board_size)

        for idx, rock in enumerate(self._rock_pos):
            self.grid.board[rock] = idx

        self.action_space = Discrete(len(Action) + self.num_rocks)
        self._discount = .95
        self._reward_range = 20
        self._penalization = -100
        self._query = 0
        if stay_inside:
            self._out_of_bounds_penalty = 0
        else:
            self._out_of_bounds_penalty = self._penalization

        self.state = None
        self.last_action = None
        self.done = False

        self.gui = None

        assert observation in ['o', 'oa', 'po', 'poa']
        if observation == 'o':
            self._make_obs = lambda obs, a: obs
            self.observation_space = Discrete(len(Obs))
        elif observation == 'oa':
            self._make_obs = self._oa
            self.observation_space =\
                Box(low=0,
                    high=np.append(max(Obs), np.ones(self.action_space.n)),
                    dtype=int)

        elif observation == 'po':
            self._make_obs = self._po
            self.observation_space = \
                Box(low=0,
                    high=np.append(np.ones(self.grid.n_tiles), max(Obs)),
                    dtype=int)

        elif observation == 'poa':
            self._make_obs = self._poa
            self.observation_space = \
                Box(low=0,
                    high=np.concatenate((np.ones(self.grid.n_tiles),
                                         [max(Obs)],
                                        np.ones(self.action_space.n))),
                    dtype=int)

    def seed(self, seed=None):
        np.random.seed(seed)

    def step(self, action: int):
        err_msg = "%r (%s) invalid" % (action, type(action))
        assert self.action_space.contains(action), err_msg
        assert self.done is False

        self.last_action = action
        self._query += 1

        reward = 0
        ob = Obs.NULL

        if action < Action.SAMPLE:
            if action == Action.EAST:
                if self.state.agent_pos.x + 1 < self.grid.x_size:
                    self.state.agent_pos += Moves.EAST.value
                else:
                    reward = 10
                    self.done = True
                    ob = self._make_obs(ob, action)
                    return ob, reward, self.done, {
                        "state": self._encode_state(self.state)
                    }
            elif action == Action.NORTH:
                if self.state.agent_pos.y + 1 < self.grid.y_size:
                    self.state.agent_pos += Moves.NORTH.value
                else:
                    reward = self._out_of_bounds_penalty
            elif action == Action.SOUTH:
                if self.state.agent_pos.y - 1 >= 0:
                    self.state.agent_pos += Moves.SOUTH.value
                else:
                    reward = self._out_of_bounds_penalty
            elif action == Action.WEST:
                if self.state.agent_pos.x - 1 >= 0:
                    self.state.agent_pos += Moves.WEST.value
                else:
                    reward = self._out_of_bounds_penalty
            else:
                raise NotImplementedError()

        if action == Action.SAMPLE:
            rock = self.grid[self.state.agent_pos]
            if rock >= 0 and not self.state.rocks[
                    rock].status == 0:  # collected
                if self.state.rocks[rock].status == 1:
                    reward = 10
                else:
                    reward = -10
                self.state.rocks[rock].status = 0
            else:
                reward = self._penalization

        if action > Action.SAMPLE:
            rock = action - Action.SAMPLE - 1
            assert rock < self.num_rocks

            ob = self._sample_ob(self.state.agent_pos, self.state.rocks[rock])

            self.state.rocks[rock].measured += 1

            eff = self._efficiency(self.state.agent_pos,
                                   self.state.rocks[rock].pos)

            if ob == Obs.GOOD:
                self.state.rocks[rock].count += 1
                self.state.rocks[rock].lkv *= eff
                self.state.rocks[rock].lkw *= (1 - eff)
            else:
                self.state.rocks[rock].count -= 1
                self.state.rocks[rock].lkw *= eff
                self.state.rocks[rock].lkv *= (1 - eff)

            denominator = (.5 * self.state.rocks[rock].lkv) + (
                .5 * self.state.rocks[rock].lkw) + 1e-10
            self.state.rocks[rock].prob_valuable = \
                (.5 * self.state.rocks[rock].lkv) / denominator

        self.done = self._penalization == reward
        ob = self._make_obs(ob, action)
        return ob, reward, self.done, {"state": self._encode_state(self.state)}

    def _decode_state(self, state, as_array=False):

        agent_pos = Coord(*state['agent_pos'])
        rock_state = RockState(agent_pos)
        for r in state['rocks']:
            rock = Rock(pos=0)
            rock.__dict__.update(r)
            rock_state.rocks.append(rock)

        if as_array:
            rocks = []
            for rock in rock_state.rocks:
                rocks.append(rock.status)

            return np.concatenate([[self.grid.get_index(agent_pos)], rocks])

        return rock_state

    @staticmethod
    def _encode_state(state):
        # use dictionary for state encoding

        return _encode_dict(state)
        # rocks can take 3 values: -1, 1, 0 if collected

    def render(self, mode='human', close=False):
        if close:
            return
        if mode == "human":
            msg = None
            if self.gui is None:
                start_pos = self.grid.get_index(self.state.agent_pos)
                obj_pos = [(self.grid.get_index(rock.pos), rock.status)
                           for rock in self.state.rocks]
                self.gui = RockGui((self.grid.x_size, self.grid.y_size),
                                   start_pos=start_pos,
                                   obj=obj_pos)

            if self.last_action > Action.SAMPLE:
                rock = self.last_action - Action.SAMPLE - 1
                msg = "Rock S: {} P:{}".format(self.state.rocks[rock].status,
                                               self.state.rocks[rock].pos)
            agent_pos = self.grid.get_index(self.state.agent_pos)
            self.gui.render(agent_pos, msg)

    def reset(self):
        self.done = False
        self._query = 0
        self.last_action = Action.SAMPLE
        self.state = self._get_init_state(should_encode=False)
        return self._make_obs(Obs.NULL, self.last_action)

    def _set_state(self, state):
        self.done = False
        self.state = self._decode_state(state)

    def close(self):
        self.render(close=True)

    def _compute_prob(self, action, next_state, ob):

        next_state = self._decode_state(next_state)

        if action <= Action.SAMPLE:
            return int(ob == Obs.NULL)

        eff = self._efficiency(
            next_state.agent_pos,
            next_state.rocks[action - Action.SAMPLE - 1].pos)

        if ob == Obs.GOOD and next_state.rocks[action - Action.SAMPLE -
                                               1].status == 1:
            return eff
        elif ob == Obs.BAD and next_state.rocks[action - Action.SAMPLE -
                                                1].status == -1:
            return eff
        else:
            return 1 - eff

    def _get_init_state(self, should_encode=True):

        rock_state = RockState(self._agent_pos)
        for idx in range(self.num_rocks):
            rock_state.rocks.append(Rock(self._rock_pos[idx]))
        return self._encode_state(rock_state) if should_encode else rock_state

    def _generate_legal(self):
        legal = [Action.EAST]  # can always go east
        if self.state.agent_pos.y + 1 < self.grid.y_size:
            legal.append(Action.NORTH)

        if self.state.agent_pos.y - 1 >= 0:
            legal.append(Action.SOUTH)
        if self.state.agent_pos.x - 1 >= 0:
            legal.append(Action.WEST)

        rock = self.grid[self.state.agent_pos]
        if rock >= 0 and self.state.rocks[rock].status != 0:
            legal.append(Action.SAMPLE)

        for rock in self.state.rocks:
            assert self.grid[rock.pos] != -1
            if rock.status != 0:
                legal.append(self.grid[rock.pos] + 1 + Action.SAMPLE)
        return legal

    def _generate_preferred(self, history):
        if not self._use_heuristic:
            return self._generate_legal()

        actions = []

        # sample rocks with high likelihood of being good
        rock = self.grid[self.state.agent_pos]
        if rock >= 0 and self.state.rocks[rock].status != 0 and history.size:
            total = 0
            # history
            for t in range(history.size):
                if history[t].action == rock + 1 + Action.SAMPLE:
                    if history[t].ob == Obs.GOOD:
                        total += 1
                    elif history[t].ob == Obs.BAD:
                        total -= 1
            if total > 0:
                actions.append(Action.SAMPLE)
                return actions

        # process the rocks

        all_bad = True
        direction = {
            "north": False,
            "south": False,
            "west": False,
            "east": False
        }
        for idx in range(self.num_rocks):
            rock = self.state.rocks[idx]
            if rock.status != 0:
                total = 0
                for t in range(history.size):
                    if history[t].action == idx + 1 + Action.SAMPLE:
                        if history[t].ob == Obs.GOOD:
                            total += 1
                        elif history[t].ob == Obs.BAD:
                            total -= 1
                if total >= 0:
                    all_bad = False

                    if rock.pos.y > self.state.agent_pos.y:
                        direction['north'] = True
                    elif rock.pos.y < self.state.agent_pos.y:
                        direction['south'] = True
                    elif rock.pos.x < self.state.agent_pos.x:
                        direction['west'] = True
                    elif rock.pos.x > self.state.agent_pos.x:
                        direction['east'] = True

        if all_bad:
            actions.append(Action.EAST)
            return actions

        # generate a random legal move
        # do not measure a collected rock
        # do not measure a rock too often
        # do not measure clearly bad rocks
        # don't move in a direction that puts you closer to bad rocks
        # never sample a rock

        if self.state.agent_pos.y + 1 < self.grid.y_size and\
                direction['north']:
            actions.append(Action.NORTH)

        if direction['east']:
            actions.append(Action.EAST)

        if self.state.agent_pos.y - 1 >= 0 and direction['south']:
            actions.append(Action.SOUTH)

        if self.state.agent_pos.x - 1 >= 0 and direction['west']:
            actions.append(Action.WEST)

        for idx, rock in enumerate(self.state.rocks):
            if not rock.status == 0 and rock.measured < 5 and abs(
                    rock.count) < 2 and 0 < rock.prob_valuable < 1:
                actions.append(idx + 1 + Action.SAMPLE)

        if len(actions) == 0:
            return self._generate_legal()

        return actions

    def __dict2np__(self, state):
        idx = self.grid.get_index(Coord(*state['agent_pos']))
        rocks = []
        for rock in state['rocks']:
            rocks.append(rock['status'])
        return np.concatenate([[idx], rocks])

    @staticmethod
    def _efficiency(agent_pos, rock_pos, hed=20):
        # TODO check me
        d = Grid.euclidean_distance(agent_pos, rock_pos)
        eff = (1 + pow(2, -d / hed)) * .5
        return eff

    @staticmethod
    def _select_target(rock_state, x_size):
        best_dist = x_size * 2
        best_rock = -1  # Coord(-1, -1)
        for idx, rock in enumerate(rock_state.rocks):
            if rock.status != 0 and rock.count >= 0:
                d = Grid.manhattan_distance(rock_state.agent_pos, rock.pos)
                if d < best_dist:
                    best_dist = d
                    best_rock = idx  # rock.pos
        return best_rock

    @staticmethod
    def _sample_ob(agent_pos, rock, hed=20):
        eff = RockEnv._efficiency(agent_pos, rock.pos, hed=hed)
        if np.random.binomial(1, eff):
            return Obs.GOOD if rock.status == 1 else Obs.BAD
        else:
            return Obs.BAD if rock.status == 1 else Obs.GOOD

    def _po(self, o, _):
        obs = np.zeros(self.observation_space.shape[0])
        obs[self.grid.x_size * self.state.agent_pos.y +
            self.state.agent_pos.x] = 1.
        obs[self.grid.n_tiles] = o
        return obs

    def _poa(self, o, a):
        obs = self._po(o, a)
        obs[self.grid.n_tiles + a] = 1.
        return obs

    def _oa(self, o, a):
        obs = np.zeros(self.observation_space.shape[0])
        obs[0] = o
        obs[1 + a] = 1.
        return obs
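For reference, the three encoders at the end produce flat vectors: _oa is [observed value, one-hot action], _po is [one-hot agent tile, observed value], and _poa appends a one-hot action to _po. The sketch below only tallies the resulting lengths; it assumes len(Action) == 5 (four moves plus SAMPLE), which is not stated in the excerpt.

def obs_length(mode, n_tiles, n_actions):
    return {'o': 1,
            'oa': 1 + n_actions,
            'po': n_tiles + 1,
            'poa': n_tiles + 1 + n_actions}[mode]

n_tiles = 7 * 7        # board_size = 7
n_actions = 5 + 8      # assumed len(Action) == 5, plus one check action per rock
for mode in ('o', 'oa', 'po', 'poa'):
    print(mode, obs_length(mode, n_tiles, n_actions))
# o 1, oa 14, po 50, poa 63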