def __init__(self, width: int, height: int, num_snakes: int, stacked_frames: int):
    """Configure a Battlesnake environment for one or more snakes with frame stacking.

    Args:
        width: Board width, forwarded to ``State``.
        height: Board height, forwarded to ``State``.
        num_snakes: Number of snakes; the fruit count is set to the same value.
        stacked_frames: Number of stacked frames, forwarded to ``State`` and
            used as the last dimension of each snake's observation box.
    """
    self.width, self.height = width, height
    self.num_snakes = num_snakes
    # The fruit count is deliberately tied to the snake count in this variant.
    self.num_fruits = num_snakes
    self.stacked_frames = stacked_frames

    self.state = State(
        width=self.width,
        height=self.height,
        num_snakes=self.num_snakes,
        num_fruits=self.num_fruits,
        stacked_frames=self.stacked_frames,
    )

    # Three discrete actions per snake.
    self.action_space = spaces.Discrete(3)

    # Observation space for a single snake: uint8 array of shape
    # (width, height, stacked_frames).
    per_snake_space = spaces.Box(
        low=0,
        high=255,
        shape=(self.width, self.height, self.stacked_frames),
        dtype=np.uint8,
    )
    self.obs_space = per_snake_space

    if num_snakes == 1:
        self.observation_space = per_snake_space
    else:
        # One entry per snake, keyed "0", "1", ...; all share the same Box.
        self.observation_space = spaces.Dict(
            {str(idx): per_snake_space for idx in range(num_snakes)}
        )
def reset(self):
    """Start a new episode and return the initial observation.

    Replaces ``self.state`` with a fresh ``State`` built from the stored
    board configuration and returns ``observe()`` of that new state.
    """
    # NOTE(review): no ``stacked_frames`` argument is forwarded to ``State``
    # here; confirm the owning class does not expect frame stacking to be
    # preserved across resets.
    board_config = dict(
        width=self.width,
        height=self.height,
        num_snakes=self.num_snakes,
        num_fruits=self.num_fruits,
    )
    self.state = State(**board_config)
    return self.state.observe()
def reset(self):
    """Start a new episode and return the initial observation(s).

    Returns:
        With exactly one snake, ``State.observe()`` of the new state.
        Otherwise a dict mapping each snake index as a string ("0", "1", ...)
        to ``State.observe(index)``.
    """
    self.state = State(
        width=self.width,
        height=self.height,
        num_snakes=self.num_snakes,
        num_fruits=self.num_fruits,
        stacked_frames=self.stacked_frames,
    )
    if self.num_snakes != 1:
        return {
            str(idx): self.state.observe(idx) for idx in range(self.num_snakes)
        }
    return self.state.observe()
def reset(self):
    """Rebuild the underlying environment's state and return its first observation.

    Board settings (size, snake and fruit counts) are read from
    ``self.unwrapped``. The stacking depth comes from this object's own
    ``num_stacked_frames``.
    """
    base_env = self.unwrapped
    base_env.state = State(
        width=base_env.width,
        height=base_env.height,
        num_snakes=base_env.num_snakes,
        num_fruits=base_env.num_fruits,
        stacked_frames=self.num_stacked_frames,
    )
    return base_env.state.observe()
def reset(self):
    """Rebuild the underlying environment's state and return a reordered observation.

    Board and observation-window settings are read from ``self.unwrapped``.
    The stacking depth comes from this object's ``num_stacked_frames``.

    Returns:
        ``State.observe()`` with its leading axis moved to the last position.
    """
    base_env = self.unwrapped
    base_env.state = State(
        width=base_env.width,
        height=base_env.height,
        num_snakes=base_env.num_snakes,
        num_fruits=base_env.num_fruits,
        stacked_frames=self.num_stacked_frames,
        window_width=base_env.window_width,
        window_height=base_env.window_height,
    )
    # NOTE(review): presumably converts channels-first to channels-last
    # (leading axis = stacked frames); confirm against State.observe().
    return np.moveaxis(base_env.state.observe(), 0, -1)
def __init__(
    self,
    width: int,
    height: int,
    num_fruits: int = 1,
    sparse_rewards: bool = False,
    window_width: int = 18,
    window_height: int = 18,
):
    """Configure a single-snake Battlesnake environment.

    Args:
        width: Board width, forwarded to ``State``.
        height: Board height, forwarded to ``State``.
        num_fruits: Fruit count, forwarded to ``State``.
        sparse_rewards: Stored flag selecting the reward scheme.
        window_width: Observation window width; forwarded to ``State`` and
            used as the second dimension of ``observation_space``.
        window_height: Observation window height; forwarded to ``State`` and
            used as the third dimension of ``observation_space``.
    """
    # Board and observation-window geometry.
    self.width, self.height = width, height
    self.window_width, self.window_height = window_width, window_height

    # Reward scheme and board contents; this variant always has one snake.
    self.sparse_rewards = sparse_rewards
    self.num_fruits = num_fruits
    self.num_snakes = 1

    # NOTE: graphical rendering is disabled; the hook is kept for reference:
    # if "DISPLAY" in os.environ:
    #     self.game_renderer = GameRenderer(width, height, self.num_snakes)
    self.state = State(
        width=self.width,
        height=self.height,
        num_snakes=self.num_snakes,
        num_fruits=self.num_fruits,
        window_width=self.window_width,
        window_height=self.window_height,
    )

    self.action_space = spaces.Discrete(3)
    obs_shape = (1, self.window_width, self.window_height)
    self.observation_space = spaces.Box(
        low=0, high=255, shape=obs_shape, dtype=np.uint8
    )
class BattlesnakeEnv(gym.Env):
    """Single-snake Battlesnake environment exposing the classic gym API.

    The game lives in a ``State`` object that is rebuilt on each ``reset``.
    ``sparse_rewards`` selects between two reward schemes:

    * sparse: only a loss (collision or starvation) or a win is rewarded;
    * dense: collisions, wins, eaten fruit and starvation each map to their
      own ``Reward`` value.
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        width: int,
        height: int,
        num_fruits: int = 1,
        sparse_rewards: bool = False,
        window_width: int = 18,
        window_height: int = 18,
    ):
        """Store the configuration and create the initial game state.

        Args:
            width: Board width, forwarded to ``State``.
            height: Board height, forwarded to ``State``.
            num_fruits: Fruit count, forwarded to ``State``.
            sparse_rewards: Use win/loss-only rewards instead of dense ones.
            window_width: Observation window width; also the second
                dimension of ``observation_space``.
            window_height: Observation window height; also the third
                dimension of ``observation_space``.
        """
        self.width, self.height = width, height
        self.window_width, self.window_height = window_width, window_height
        self.sparse_rewards = sparse_rewards
        self.num_fruits = num_fruits
        self.num_snakes = 1  # this environment always has exactly one snake

        # NOTE: graphical rendering is disabled; the hook is kept for reference:
        # if "DISPLAY" in os.environ:
        #     self.game_renderer = GameRenderer(width, height, self.num_snakes)
        self.state = self._fresh_state()

        self.action_space = spaces.Discrete(3)
        obs_shape = (1, self.window_width, self.window_height)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=obs_shape, dtype=np.uint8
        )

    def _fresh_state(self):
        """Return a new ``State`` built from the stored board/window settings."""
        return State(
            width=self.width,
            height=self.height,
            num_snakes=self.num_snakes,
            num_fruits=self.num_fruits,
            window_width=self.window_width,
            window_height=self.window_height,
        )

    def reset(self):
        """Start a new episode and return its first observation."""
        self.state = self._fresh_state()
        return self.state.observe()

    def step(self, action: Union[List[int], int]):
        """Advance the game by one move.

        Returns:
            ``(observation, reward, done, info)``. ``observation`` is ``None``
            once the episode has ended, and ``info`` is always ``{}``.
        """
        fruit_eaten, collided, starved, won = self.state.move_snakes(action)
        reward, done = self._evaluate_reward(fruit_eaten, collided, starved, won)
        # Dense rewards report a win as non-terminal, so force termination here.
        done = done or won
        # NOTE(review): returning None at episode end deviates from the usual
        # gym contract; presumably the training loop relies on it - confirm
        # before changing.
        observation = None if done else self.state.observe()
        return observation, reward, done, {}

    def render(self, mode="human"):
        """Print the current observation to stdout; ``mode`` is ignored."""
        # Disabled graphical path, kept for reference:
        # if "DISPLAY" in os.environ:
        #     self.game_renderer.display(self.state)
        print(self.state.observe())

    def _evaluate_reward(
        self, fruit_eaten: bool, collided: bool, starved: bool, won: bool
    ):
        """Translate one move's outcome flags into ``(reward, terminal)``.

        In dense mode a win yields ``Reward.won`` but is not flagged terminal
        here; ``step`` makes the episode end on a win separately.
        """
        if self.sparse_rewards:
            if collided or starved:
                return Reward.lost.value, True
            if won:
                return Reward.won.value, True
            return Reward.nothing.value, False

        if collided:
            return Reward.collision.value, True
        if won:
            return Reward.won.value, False
        if fruit_eaten:
            return Reward.fruit.value, False
        if starved:
            return Reward.starve.value, True
        return Reward.nothing.value, False
class BattlesnakeEnv(MultiAgentEnv):
    """Battlesnake environment for one or more snakes with stacked frames.

    With a single snake the environment behaves like a single-agent one:
    ``reset``/``step`` return a bare observation and scalar reward/done.

    With several snakes, observations, rewards and done flags are dicts.
    ``reset`` keys them by the snake index as a string; ``step`` keys them by
    the keys of the incoming action dict. The done dict also carries an
    ``"__all__"`` entry that is true once at most one snake is still alive.
    """

    def __init__(self, width: int, height: int, num_snakes: int, stacked_frames: int):
        """Store the configuration, build the initial state and the spaces.

        Args:
            width: Board width, forwarded to ``State``.
            height: Board height, forwarded to ``State``.
            num_snakes: Number of snakes; the fruit count is set to the same value.
            stacked_frames: Number of stacked frames, forwarded to ``State``
                and used as the last dimension of each observation box.
        """
        self.width = width
        self.height = height
        self.num_snakes = num_snakes
        self.num_fruits = num_snakes  # fruit count mirrors the snake count
        self.stacked_frames = stacked_frames
        self.state = self._fresh_state()

        self.action_space = spaces.Discrete(3)
        # Per-snake observation: uint8 array of shape (width, height, stacked_frames).
        self.obs_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.width, self.height, self.stacked_frames),
            dtype=np.uint8,
        )
        if num_snakes == 1:
            self.observation_space = self.obs_space
        else:
            self.observation_space = spaces.Dict(
                {str(idx): self.obs_space for idx in range(num_snakes)}
            )

    def _fresh_state(self):
        """Return a new ``State`` built from the stored configuration."""
        return State(
            width=self.width,
            height=self.height,
            num_snakes=self.num_snakes,
            num_fruits=self.num_fruits,
            stacked_frames=self.stacked_frames,
        )

    def reset(self):
        """Start a new episode and return the initial observation(s)."""
        self.state = self._fresh_state()
        if self.num_snakes != 1:
            return {
                str(idx): self.state.observe(idx) for idx in range(self.num_snakes)
            }
        return self.state.observe()

    def step(self, action: Union[Dict[str, int], int]):
        """Apply the action(s) and return ``(observation, reward, done, info)``.

        ``info`` is always ``{}``.
        """
        outcomes = self.state.move_snakes(action)
        rewards, terminals = self._evaluate_reward(outcomes)

        if self.num_snakes == 1:
            return self.state.observe(), rewards[0], terminals[0], {}

        agent_ids = list(action.keys())
        observations = {aid: self.state.observe(int(aid)) for aid in agent_ids}
        # NOTE(review): rewards/terminals are paired with the action keys by
        # position. This is only correct if move_snakes reports outcomes in
        # the same order as ``action``'s keys - confirm, especially when dead
        # snakes are absent from ``action``.
        reward_by_agent = dict(zip(agent_ids, rewards))
        done_by_agent = dict(zip(agent_ids, terminals))

        dead_count = sum(1 for snake in self.state.snakes if snake.is_dead())
        done_by_agent["__all__"] = self.num_snakes - dead_count <= 1
        return observations, reward_by_agent, done_by_agent, {}

    def render(self):
        """Return the current (default) observation instead of drawing it."""
        return self.state.observe()

    def _evaluate_reward(self, data):
        """Compute per-snake rewards and terminal flags.

        Args:
            data: Six parallel per-snake sequences, transposed with
                ``zip(*data)`` into (fruit_eaten, collided, starved, won,
                ate_enemy, action_corrected) tuples.

        Returns:
            ``(rewards, terminals)`` lists in the same order as ``data``.
        """
        rewards = []
        terminals = []
        for fruit_eaten, collided, starved, won, ate_enemy, action_corrected in zip(
            *data
        ):
            if collided or starved:
                reward, terminal = Reward.lost.value, True
            elif won:
                reward, terminal = Reward.won.value, True
            elif fruit_eaten:
                reward, terminal = Reward.fruit.value, False
            elif ate_enemy:
                reward, terminal = Reward.ate_enemy.value, False
            elif action_corrected:
                reward, terminal = Reward.action_corrected.value, False
            else:
                reward, terminal = Reward.nothing.value, False
            rewards.append(reward)
            terminals.append(terminal)
        return rewards, terminals