def render(self, mode='human', close=False):
    """Draw the board and the most recent action in a GUI window.

    Nothing happens when ``close`` is true or ``mode`` is not ``'human'``.
    The ``ShipGui`` is created the first time a frame is requested and is
    cached on ``self.gui``; a frame is only drawn once at least one step
    has been taken (``self.t > 0``).

    Args:
        mode: Rendering mode; only ``'human'`` produces output.
        close: When true, return immediately without rendering.
    """
    if close or mode != 'human':
        return

    if not hasattr(self, "gui"):
        # Collect the grid index of every cell covered by a ship: the
        # start cell plus ``ship.length`` further cells along its direction.
        obj_pos = []
        for ship in self.state.ships:
            step = Compass.get_coord(ship.direction)
            pos = ship.pos
            obj_pos.append(self.grid.get_index(pos))
            for _ in range(ship.length):
                # Rebind instead of ``+=`` so that a mutable coordinate
                # type cannot have ``ship.pos`` shifted in place.
                pos = pos + step
                obj_pos.append(self.grid.get_index(pos))
        self.gui = ShipGui(board_size=self.grid.get_size, obj_pos=obj_pos)

    if self.t > 0:
        msg = "A: {!s}T: {!s}Rw :{!s}".format(
            self.grid.get_coord(self.last_action), self.t, self.tot_rw)
        self.gui.render(state=self.last_action, msg=msg)
class BattleShipEnv(Env):
    """Battleship environment built on the ``Env`` base class.

    Each action is the index of a grid tile to fire at.  Ships are placed
    at random on reset so that no ship touches another, and the episode
    ends once every occupied cell has been hit.
    """

    metadata = {"render.modes": ["human", "ansi"]}

    def __init__(self, board_size=(5, 5), ship_sizes=(5, 4, 3, 3, 2)):
        """Create the grid and the action/observation spaces.

        Args:
            board_size: Size passed to ``BattleGrid``.
            ship_sizes: Iterable with the ``length`` of each ship to place.
        """
        self.grid = BattleGrid(board_size)
        self.action_space = Discrete(self.grid.n_tiles)
        self.observation_space = Discrete(len(Obs))
        self.num_obs = 2
        self._reward_range = self.action_space.n / 4.
        self._discount = 1.
        # Copy into a fresh list so instances never share (or leak changes
        # into) the default or a caller-owned sequence.
        self.ship_sizes = list(ship_sizes)
        # Only read by the sanity assertion in ``step``; the live counter
        # is ``self.state.total_remaining``.
        self.total_remaining = sum(self.ship_sizes)

    def seed(self, seed=None):
        """Seed the global ``np.random`` generator."""
        np.random.seed(seed)

    def _compute_prob(self, action, next_state, ob):
        """Return 1 if observation ``ob`` is consistent with ``action``, else 0.

        The check is made against the current grid; ``next_state`` is not
        consulted.
        """
        action_pos = self.grid.get_coord(action)
        cell = self.grid[action_pos]
        if ob == Obs.NULL.value and cell.visited:
            return 1
        if ob == Obs.HIT.value and cell.occupied:
            return 1
        return int(ob == Obs.NULL.value)

    def step(self, action):
        """Fire at the tile with index ``action``.

        Rewards: -10 for firing at an already visited cell, otherwise -1
        per shot, plus ``self.grid.n_tiles`` when the last occupied cell is
        hit (which also ends the episode).

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``info`` holds the
            current state under the key ``"state"``.
        """
        assert self.done is False
        assert self.action_space.contains(action)
        assert self.total_remaining > 0

        self.last_action = action
        self.t += 1
        action_pos = self.grid.get_coord(action)
        cell = self.grid[action_pos]
        reward = 0

        if cell.visited:
            reward -= 10
            obs = Obs.NULL.value
        else:
            if cell.occupied:
                reward -= 1
                # NOTE(review): literal 1 is reported for a hit; presumably
                # equal to Obs.HIT.value - confirm before replacing it.
                obs = 1
                self.state.total_remaining -= 1
                # Directions 4..7 are used here as the diagonal neighbours
                # (assumed from the ``diagonal`` flag they clear).
                for d in range(4, 8):
                    neighbor = action_pos + Compass.get_coord(d)
                    # Same bounds guard that ``collision`` applies before
                    # indexing the grid with a neighbouring coordinate.
                    if not self.grid.is_inside(neighbor):
                        continue
                    neighbor_cell = self.grid[neighbor]
                    if neighbor_cell:
                        neighbor_cell.diagonal = False
            else:
                reward -= 1
                obs = Obs.NULL.value
            cell.visited = True

        if self.state.total_remaining == 0:
            reward += self.grid.n_tiles
            self.done = True

        self.tot_rw += reward
        return obs, reward, self.done, {"state": self.state}

    def _set_state(self, state):
        """Reset the episode bookkeeping, then install ``state``.

        NOTE(review): ``reset`` rebuilds the grid from a freshly sampled
        layout, so the grid is not derived from ``state`` - confirm this is
        what callers expect.
        """
        self.reset()
        self.state = state

    def close(self):
        """No-op; nothing is released here."""
        return

    def reset(self):
        """Start a new episode and return the initial observation."""
        self.done = False
        self.tot_rw = 0
        self.t = 0
        self.last_action = -1
        self.state = self._get_init_state()
        return Obs.NULL.value

    def render(self, mode='human', close=False):
        """Draw the board and the most recent action in a GUI window.

        Nothing happens when ``close`` is true or ``mode`` is not
        ``'human'``.  The ``ShipGui`` is created on first use and cached on
        ``self.gui``; a frame is only drawn once ``self.t > 0``.
        """
        if close or mode != 'human':
            return

        if not hasattr(self, "gui"):
            # Grid indices of every ship cell: the start cell plus
            # ``ship.length`` further cells along the ship's direction.
            obj_pos = []
            for ship in self.state.ships:
                step = Compass.get_coord(ship.direction)
                pos = ship.pos
                obj_pos.append(self.grid.get_index(pos))
                for _ in range(ship.length):
                    # Rebind instead of ``+=`` so a mutable coordinate
                    # type cannot have ``ship.pos`` shifted in place.
                    pos = pos + step
                    obj_pos.append(self.grid.get_index(pos))
            self.gui = ShipGui(board_size=self.grid.get_size,
                               obj_pos=obj_pos)

        if self.t > 0:
            msg = "A: {!s}T: {!s}Rw :{!s}".format(
                self.grid.get_coord(self.last_action), self.t, self.tot_rw)
            self.gui.render(state=self.last_action, msg=msg)

    def _generate_legal(self):
        """Return the indices of all tiles that have not been visited yet."""
        return [
            action
            for action in range(self.action_space.n)
            if not self.grid[self.grid.get_coord(action)].visited
        ]

    def _get_init_state(self):
        """Rebuild the board and place one ship per entry of ``ship_sizes``.

        Each ship is re-sampled until ``collision`` accepts its placement.
        NOTE(review): there is no retry limit, so a configuration whose
        ships cannot fit on the board never returns.
        """
        bsstate = ShipState()
        self.grid.build_board()
        for length in self.ship_sizes:
            while True:
                ship = Ship(coord=self.grid.sample(), length=length)
                if not self.collision(ship, self.grid, bsstate):
                    break
            self.mark_ship(ship, self.grid, bsstate)
            bsstate.ships.append(ship)
        return bsstate

    @staticmethod
    def mark_ship(ship, grid, state):
        """Mark the cells covered by ``ship`` as occupied on ``grid``.

        ``ship.length + 1`` cells are marked, starting at ``ship.pos`` and
        advancing along ``ship.direction``.  ``state.total_remaining`` is
        incremented for every marked cell that has not been visited.
        """
        step = Compass.get_coord(ship.direction)
        pos = ship.pos
        for _ in range(ship.length + 1):
            cell = grid[pos]
            assert not cell.occupied
            cell.occupied = True
            if not cell.visited:
                state.total_remaining += 1
            # Rebind instead of ``+=`` so ``ship.pos`` is never mutated.
            pos = pos + step

    @staticmethod
    def collision(ship, grid, state):
        """Return True if ``ship`` cannot be placed on ``grid``.

        A placement is rejected when the ship would leave the grid, overlap
        an occupied cell, or have an occupied cell among the eight
        neighbours of any of its cells.  ``state`` is not consulted.
        """
        step = Compass.get_coord(ship.direction)
        pos = ship.pos
        for _ in range(ship.length):
            # The cell after this one must still be on the board.
            if not grid.is_inside(pos + step):
                return True
            if grid[pos].occupied:
                return True
            for adj in range(8):
                coord = pos + Compass.get_coord(adj)
                if grid.is_inside(coord) and grid[coord].occupied:
                    return True
            # Rebind instead of ``+=`` so ``ship.pos`` is never mutated.
            pos = pos + step
        return False