class All(arrdict.namedarrtuple(fields=('history', 'count'))):
    """Validation game: a seat is rewarded +1 only if it plays 1 on every turn.

    Fields:
        history: long tensor created as (n_envs, length, n_seats), filled
            with -1 until the corresponding move has been written.
        count: long tensor created as (n_envs,), the number of moves made so
            far in the current episode.
    """

    @classmethod
    def initial(cls, n_envs=1, n_seats=1, length=4, device='cuda'):
        """Build a batch of fresh games with an empty history and zero count."""
        blank_history = torch.full(
            (n_envs, length, n_seats), -1, dtype=torch.long, device=device)
        zero_count = torch.zeros((n_envs, ), dtype=torch.long, device=device)
        return cls(history=blank_history, count=zero_count)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the fields are not tensors.
        if not isinstance(self.count, torch.Tensor):
            return

        n_envs, length, n_seats = self.history.shape[-3:]
        self.n_envs, self.length, self.n_seats = n_envs, length, n_seats
        self.device = self.count.device
        # An episode lasts one move per seat per turn.
        self.max_count = self.n_seats * self.length

        self.obs_space = heads.Tensor((1, ))
        self.action_space = heads.Masked(2)

        # Both actions are always legal.
        self.valid = torch.ones(
            self.count.shape + (2, ), dtype=torch.bool, device=self.device)
        self.seats = self.count % self.n_seats

        # Observation is the fraction of the episode already played.
        self.obs = self.count.unsqueeze(-1).float() / self.max_count

        self.envs = torch.arange(self.n_envs, device=self.device)

        # Planted values for validation use
        self.logits = uniform_logits(self.valid)

        # Number of 1s each seat has recorded so far.
        ones_played = (self.history == 1).sum(-2)
        # Nonzero only where a seat's 1s equal the total move count.
        on_track = ones_played == self.count.unsqueeze(-1)
        # Halves for every 1 still missing from a full-length history.
        exponent = (ones_played - self.length).float()
        remaining_odds = 2**exponent

        self.v = on_track.float() * remaining_odds

    def step(self, actions):
        """Record `actions` for the current seats and advance every env.

        Returns:
            (world, transition): the next world, plus an arrdict holding the
            `terminal` mask and per-seat `rewards`. Terminal envs have their
            count and history reset in the returned world.
        """
        next_history = self.history.clone()
        turn = self.count // self.n_seats
        next_history[self.envs, turn, self.seats] = actions

        next_count = self.count + 1
        terminal = next_count == self.max_count

        # A seat is paid only at episode end and only if it played all 1s.
        all_ones = (next_history == 1).all(-2)
        reward = (terminal[:, None] & all_ones).float()
        transition = arrdict.arrdict(terminal=terminal, rewards=reward)

        # Reset finished envs so the returned world starts a new episode.
        next_count[terminal] = 0
        next_history[terminal] = -1

        return type(self)(history=next_history, count=next_count), transition
class Win(arrdict.namedarrtuple(fields=('envs', ))):
    """One-step, one-seat game in which every step terminates with reward +1.

    Fields:
        envs: tensor of env indices, created by `initial` via `torch.arange`.
    """

    @classmethod
    def initial(cls, n_envs=1, device='cuda'):
        """Build a batch of `n_envs` worlds on `device`."""
        env_ids = torch.arange(n_envs, device=device)
        return cls(envs=env_ids)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the field is not a tensor.
        if not isinstance(self.envs, torch.Tensor):
            return

        self.device = self.envs.device
        self.n_envs = len(self.envs)
        self.n_seats = 1

        self.obs_space = (0, )
        self.action_space = (1, )

        # A single action, always legal.
        self.valid = torch.ones_like(self.envs.unsqueeze(-1).bool())
        self.seats = torch.zeros_like(self.envs)

        self.logits = uniform_logits(self.valid)
        # NOTE(review): this adds a second trailing singleton axis on top of
        # `valid`'s; the sibling classes in this file use one value per seat
        # on the last axis only - confirm the extra axis is intended.
        self.v = torch.ones_like(self.valid.unsqueeze(-1).float())

    def step(self, actions):
        """Return this same world plus an all-terminal, all-ones transition.

        `actions` is ignored.
        """
        finished = torch.ones_like(self.envs, dtype=torch.bool)
        payout = torch.ones_like(self.envs, dtype=torch.float)
        trans = arrdict.arrdict(terminal=finished, rewards=payout)
        return self, trans
class WinnerLoser(arrdict.namedarrtuple(fields=('seats', ))):
    """First seat wins each turn and gets +1; second loses and gets -1.

    Fields:
        seats: integer tensor of shape (n_envs,) holding the seat to move
            (0 or 1) in each env.
    """

    @classmethod
    def initial(cls, n_envs=1, device='cuda'):
        """Build a batch of `n_envs` worlds with seat 0 to move."""
        # NOTE(review): the original call was truncated after `device=device,`.
        # An integer dtype is restored because `step` computes `1 - seats` and
        # `seats == 1`; confirm `torch.int` against upstream history.
        return cls(seats=torch.zeros(n_envs, device=device, dtype=torch.int))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the field is not a tensor.
        if not isinstance(self.seats, torch.Tensor):
            return

        self.device = self.seats.device
        self.n_envs = len(self.seats)
        self.n_seats = 2

        self.obs_space = (0, )
        self.action_space = (1, )

        # A single action, always legal.
        self.valid = torch.ones((self.n_envs, 1),
                                dtype=torch.bool,
                                device=self.device)
        # Self-assignment kept from the original; presumably it pins the field
        # as a plain instance attribute - TODO confirm it is needed.
        self.seats = self.seats

        # Planted policy and values: seat 0 is worth +1, seat 1 is worth -1.
        self.logits = uniform_logits(self.valid)
        self.v = torch.stack(
            [torch.ones_like(self.seats), -torch.ones_like(self.seats)],
            -1).float()

    def step(self, actions):
        """Pass the move to the other seat; `actions` is ignored.

        Returns:
            (world, transition): the next world with seats flipped, plus an
            arrdict whose `terminal` mask is set where seat 1 just moved and
            whose `rewards` are (+1, -1) on those envs and zero elsewhere.
        """
        terminal = (self.seats == 1)
        payout = terminal.float()
        trans = arrdict.arrdict(terminal=terminal,
                                rewards=torch.stack([payout, -payout], -1))
        return type(self)(seats=1 - self.seats), trans
class MockGame(arrdict.namedarrtuple(fields=('count', 'history'))):
    """Two-seat mock game that records every action until the history is full.

    Fields:
        count: long tensor created as (n_envs,), moves made so far this
            episode.
        history: long tensor created as (n_envs, length), -1 where no action
            has been written yet.
    """

    @classmethod
    def initial(cls, n_envs=1, length=4, device='cuda'):
        """Build a batch of fresh games with an empty history and zero count."""
        blank_history = torch.full(
            (n_envs, length), -1, dtype=torch.long, device=device)
        zero_count = torch.full(
            (n_envs, ), 0, dtype=torch.long, device=device)
        return cls(history=blank_history, count=zero_count)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the fields are not tensors.
        if not isinstance(self.count, torch.Tensor):
            return

        self.n_envs = self.count.shape[0]
        self.device = self.count.device
        self.n_seats = 2

        # Both actions are always legal.
        self.valid = torch.ones(
            self.count.shape + (2, ), dtype=torch.bool, device=self.device)

    @property
    def seats(self):
        """Seat to move in each env: the move count modulo the seat count."""
        return self.count % self.n_seats

    def step(self, actions):
        """Write `actions` into the history at the current count and advance.

        Returns:
            (world, transition, finished): the next world, an arrdict holding
            the `terminal` mask, and a list of the history rows of the envs
            that just terminated.
        """
        # Out-of-place scatter leaves this world's own history untouched.
        next_history = self.history.scatter(
            1, self.count[:, None], actions[:, None])

        next_count = self.count + 1
        terminal = next_count == self.history.shape[1]
        transition = arrdict.arrdict(terminal=terminal)

        # Only the count is reset; the history keeps its recorded actions.
        next_count[terminal] = 0

        world = type(self)(count=next_count, history=next_history)
        finished = list(next_history[terminal])
        return world, transition, finished
class MockWorlds(arrdict.namedarrtuple(fields=('seats', 'cumulator', 'lengths'))):
    """Two-seat mock game whose outcome is sampled from accumulated skill.

    Seats alternate. Each step adds the supplied skill to the moving seat's
    cumulator. On the step where an env's length counter equals 4, the episode
    ends and a +1/-1 result is sampled, with seat 0's win probability given by
    a sigmoid of the cumulator difference divided by 4.

    Fields:
        seats: tensor created as (n_envs,), the seat to move in each env.
        cumulator: float tensor created as (n_envs, 2), per-seat sum of skills
            for the current episode.
        lengths: tensor created as (n_envs,), 1-based step counter for the
            current episode.
    """

    @classmethod
    def initial(cls, n_envs=1, device='cpu'):
        """Build a batch of fresh worlds: seat 0 to move, zero cumulators."""
        return cls(
            seats=torch.full((n_envs, ), 0, device=device),
            lengths=torch.full((n_envs, ), 1, device=device),
            cumulator=torch.full((n_envs, 2), 0., device=device))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the fields are not tensors.
        if not isinstance(self.seats, torch.Tensor):
            return

        self.n_envs = len(self.seats)
        self.device = self.seats.device
        self.n_seats = 2

    def step(self, skills):
        """Add `skills` to the moving seats' cumulators and advance.

        Args:
            skills: one value per env; it is cast to float and added to the
                cumulator of that env's moving seat.

        Returns:
            (world, transition): the next world, plus an arrdict holding the
            `terminal` mask and `rewards` stacked as (seat 0, seat 1) on the
            last axis. Rewards are zero on non-terminal envs.
        """
        terminal = self.lengths == 4
        new_lengths = self.lengths + 1
        new_lengths[terminal] = 1

        new_cumulator = self.cumulator.scatter_add(
            1, self.seats[:, None], skills[:, None].float())

        # Seat 0's win probability grows with its lead in accumulated skill.
        rates = torch.sigmoid((new_cumulator[:, 0] - new_cumulator[:, 1]) / 4)
        # Sample on the world's own device: a default-device sample cannot be
        # compared with `rates` when the world lives on another device.
        draws = torch.rand((self.n_envs, ), device=self.device)
        rewards = terminal * (2 * (draws <= rates) - 1)
        rewards = torch.stack([rewards, -rewards], -1)

        # Finished envs start the next episode from scratch with seat 0.
        new_cumulator[terminal] = 0
        new_seats = 1 - self.seats
        new_seats[terminal] = 0

        trans = arrdict.arrdict(terminal=terminal, rewards=rewards)
        world = type(self)(seats=new_seats,
                           lengths=new_lengths,
                           cumulator=new_cumulator)
        return world, trans
class SequentialMatrix(
        arrdict.namedarrtuple(fields=('payoffs', 'moves', 'seats'))):
    """Two-seat matrix game in which the seats move one after the other.

    Seat 0 moves, then seat 1 moves. The episode then ends and both seats are
    paid from the payoff table entry selected by the two moves.

    Fields:
        payoffs: the payoff table repeated once per env; `step` indexes it as
            [env, seat-0 move, seat-1 move].
        moves: integer tensor created as (n_envs, 2), -1 where a seat has not
            moved yet.
        seats: integer tensor created as (n_envs,), the seat to move.
    """

    @classmethod
    def initial(cls, payoff, n_envs=1, device='cuda'):
        """Build a batch of fresh games that all share the table `payoff`."""
        # NOTE(review): the dtype arguments were lost from the original source
        # (it read `(n_envs, ),, device=...`). `torch.int` is restored because
        # `step` casts these tensors with `.long()` before indexing - confirm
        # against upstream history.
        return cls(
            payoffs=torch.as_tensor(payoff).to(device)[None, ...].repeat(
                n_envs, 1, 1, 1),
            seats=torch.zeros((n_envs, ), dtype=torch.int, device=device),
            moves=torch.full((n_envs, 2), -1, dtype=torch.int, device=device))

    @classmethod
    def dilemma(cls, *args, **kwargs):
        """Build games using the fixed 'dilemma' payoff table."""
        return cls.initial([[[0., 0.], [1., 0.]], [[0., 1.], [.5, .5]]],
                           *args, **kwargs)

    @classmethod
    def antisymmetric(cls, *args, **kwargs):
        """Build games using the fixed 'antisymmetric' payoff table."""
        return cls.initial([[[1., 0.], [1., 1.]], [[0., 0.], [0., .1]]],
                           *args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the derived attributes when the fields are not tensors.
        if not isinstance(self.payoffs, torch.Tensor):
            return

        self.n_envs = self.seats.size(-1)
        self.n_seats = 2
        self.device = self.seats.device

        self.obs_space = heads.Tensor((1, ))
        self.action_space = heads.Masked(2)

        # The observation is seat 0's move (-1 if it has not moved yet).
        self.obs = self.moves[..., [0]].float()
        # Both actions are always legal.
        self.valid = torch.stack(
            [torch.ones_like(self.seats, dtype=torch.bool)] * 2, -1)

        self.envs = torch.arange(self.n_envs,
                                 device=self.device,
                                 dtype=torch.long)

    def step(self, actions):
        """Record `actions` for the moving seats and pay out finished games.

        Returns:
            (world, transitions): the next world, plus an arrdict holding the
            `terminal` mask and `rewards`. Rewards are zero on non-terminal
            envs. Terminal envs have their seats and moves reset in the
            returned world.
        """
        seats = self.seats + 1
        terminal = (seats == 2)

        moves = self.moves.clone()
        # The original line assigned the (None) result of `_stats` here and
        # never used `actions`. Store the actions, cast to the storage dtype,
        # then record the outcome statistics as a separate call.
        moves[self.envs, self.seats.long()] = actions.to(moves.dtype)
        self._stats(moves[terminal])

        rewards = torch.zeros_like(self.payoffs[:, 0, 0])
        rewards[terminal] = self.payoffs[self.envs[terminal],
                                         moves[terminal, 0].long(),
                                         moves[terminal, 1].long()]

        # Reset finished envs so the returned world starts a new episode.
        seats[terminal] = 0
        moves[terminal] = -1

        world = type(self)(payoffs=self.payoffs, seats=seats, moves=moves)
        transitions = arrdict.arrdict(terminal=terminal, rewards=rewards)
        return world, transitions

    def _stats(self, moves):
        """Report how often each (seat-0, seat-1) move pair occurs in `moves`."""
        if not moves.nelement():
            return
        for i in range(2):
            for j in range(2):
                count = ((moves[..., 0] == i) & (moves[..., 1] == j)).sum()
                # Two elements per game, so nelement / 2 is the game count.
                stats.mean(f'outcomes/{i}-{j}', count, moves.nelement() / 2)
class Hex(arrdict.namedarrtuple(fields=('board', 'seats'))):
    """Batched Hex worlds; the board logic lives in the `cuda` module.

    Fields:
        board: uint8 tensor created as (n_envs, boardsize, boardsize), all
            zeros for a fresh game.
        seats: integer tensor created as (n_envs,), the seat to move.
    """

    @classmethod
    def initial(cls, n_envs, boardsize=11, device='cuda'):
        """Build a batch of empty boards with seat 0 to move."""
        # As per OpenSpiel and convention, black plays first.
        # NOTE(review): the `seats` call was truncated after `device=device,`
        # in the original source. An integer dtype is restored; confirm
        # `torch.int` against upstream history.
        return cls(board=torch.full((n_envs, boardsize, boardsize),
                                    0,
                                    device=device,
                                    dtype=torch.uint8),
                   seats=torch.full((n_envs, ),
                                    0,
                                    device=device,
                                    dtype=torch.int))

    @profiling.nvtx
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.board, torch.Tensor):
            # Need this conditional to deal with the case where we're calling a method like `self.clone()`, and the
            # intermediate arrdict generated is full of methods, which will break this here init function.
            return

        self.n_seats = 2
        self.n_envs = self.board.shape[0]
        self.boardsize = self.board.shape[1]
        self.device = self.board.device

        self.obs_space = heads.Tensor((self.boardsize, self.boardsize, 2))
        self.action_space = heads.Masked(self.boardsize * self.boardsize)

        # Lazily-computed caches behind the `obs` and `valid` properties.
        self._obs = None
        self._valid = None

    @property
    def obs(self):
        """Observation from `cuda.observe`, computed once and cached."""
        if self._obs is None:
            self._obs = cuda.observe(self.board, self.seats)
        return self._obs

    @property
    def valid(self):
        """Mask of legal actions, flattened over the two board axes.

        A cell is legal when every entry of its observation is zero.
        """
        if self._valid is None:
            shape = self.board.shape[:-2]
            self._valid = (self.obs == 0).all(-1).reshape(*shape, -1)
        return self._valid

    @profiling.nvtx
    def step(self, actions, reset=True):
        """Play `actions` and return the resulting worlds.

        Args:
            actions: (n_env, 2)-int tensor between (0, 0) and (boardsize, boardsize). Cells are indexed in row-major
                order from the top-left. A (n_env,) tensor of already-flattened cell indices is also accepted.
            reset: when true, envs that received a positive reward are flagged terminal and reset to an empty board.
                When false, no env is flagged terminal.

        Returns:
            (new_world, transition): the next world, plus an arrdict holding the `terminal` mask and the `rewards`
            returned by `cuda.step`.

        Raises:
            ValueError: if the board has anything other than a single batch dimension.
        """
        if self.board.ndim != 3:
            #TODO: Support stepping arbitrary batchings. Only needs a reshaping.
            raise ValueError(
                'You can only step a board with a single batch dimension')

        assert (0 <= actions).all(), 'You passed a negative action'
        if actions.ndim == 2:
            # Flatten (row, col) pairs to row-major cell indices.
            actions = actions[:, 0] * self.boardsize + actions[:, 1]

        assert actions.shape == (self.n_envs, )
        assert self.valid.gather(1, actions[:, None]).squeeze(-1).all()

        new_board = self.board.clone()
        # NOTE(review): the argument list of this call was truncated in the
        # original source (`cuda.step(new_board,,`). It is reconstructed as
        # (board, seats, actions) with int casts; verify against the `cuda`
        # extension's signature before relying on it.
        rewards = cuda.step(new_board, self.seats.int(), actions.int())
        terminal = (rewards > 0).any(-1) if reset else torch.full(
            (self.n_envs, ), False, device=self.device)

        # Finished envs restart from an empty board with seat 0 to move.
        new_board[terminal] = 0
        new_seat = 1 - self.seats
        new_seat[terminal] = 0

        new_world = type(self)(board=new_board, seats=new_seat)
        transition = arrdict.arrdict(terminal=terminal, rewards=rewards)
        return new_world, transition

    @profiling.nvtx
    def __getitem__(self, x):
        # Just exists for profiling
        return super().__getitem__(x)

    @profiling.nvtx
    def __setitem__(self, x, y):
        # Just exists for profiling
        return super().__setitem__(x, y)

    @classmethod
    def plot_worlds(cls, worlds, e=None, ax=None, colors='obs', **kwargs):
        """Draw one board from `worlds` and return the figure it was drawn on.

        Args:
            worlds: object with a `board` array; `display` passes a
                numpyified copy of the world.
            e: index of the board to draw; defaults to the first board along
                every batch axis.
            ax: axes to draw on; a new figure is created when omitted.
            colors: passed to `color_board` together with the board.
            **kwargs: forwarded to `plot_board`.
        """
        e = (0, ) * (worlds.board.ndim - 2) if e is None else e
        board = worlds.board[e]

        ax = plt.subplots()[1] if ax is None else ax

        colors = color_board(board, colors)
        plot_board(colors, ax, **kwargs)
        return ax.figure

    def display(self, e=None, **kwargs):
        """Plot this world, close the figure, and return it."""
        # `plot_worlds` returns a figure, not an axes; the `.figure` access
        # below is kept exactly as the original code had it.
        fig = self.plot_worlds(arrdict.numpyify(arrdict.arrdict(self)),
                               e=e,
                               **kwargs)
        plt.close(fig.figure)
        return fig