Example #1
0
class All(arrdict.namedarrtuple(fields=('history', 'count'))):
    """Players need to submit 1s each turn; if they do it every turn they get +1, else 0.

    Seats move in round-robin order: move number ``count`` is made by seat
    ``count % n_seats`` and is recorded in ``history[..., count // n_seats, seat]``.
    An episode lasts ``n_seats * length`` moves. When it ends, every seat whose
    recorded moves are all 1 receives a reward of 1; the others receive 0.
    """
    @classmethod
    def initial(cls, n_envs=1, n_seats=1, length=4, device='cuda'):
        """Create ``n_envs`` fresh envs: an all -1 (unplayed) history and a zero move count."""
        return cls(history=torch.full((n_envs, length, n_seats),
                                      -1,
                                      dtype=torch.long,
                                      device=device),
                   count=torch.full((n_envs, ),
                                    0,
                                    dtype=torch.long,
                                    device=device))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Only set up derived attributes when the fields are real tensors
        # (presumably the fields can hold other values while arrdict machinery
        # builds intermediates).
        if not isinstance(self.count, torch.Tensor):
            return

        self.n_envs, self.length, self.n_seats = self.history.shape[-3:]
        self.device = self.count.device

        self.max_count = self.n_seats * self.length

        self.obs_space = heads.Tensor((1, ))
        self.action_space = heads.Masked(2)

        # Both actions are always legal.
        self.valid = torch.ones(self.count.shape + (2, ),
                                dtype=torch.bool,
                                device=self.device)
        self.seats = self.count % self.n_seats

        # Observation: fraction of the episode's moves made so far.
        self.obs = self.count[..., None].float() / self.max_count

        self.envs = torch.arange(self.n_envs, device=self.device)

        # Planted values for validation use
        self.logits = uniform_logits(self.valid)

        # Number of moves each seat has made so far this episode. Seat s moves
        # at counts s, s + n_seats, s + 2*n_seats, ..., so after `count` total
        # moves it has made ceil((count - s) / n_seats) of them (never negative).
        # With a single seat this is just `count`.
        seat_ids = torch.arange(self.n_seats, device=self.device)
        moves_made = (self.count[..., None] + self.n_seats - 1 - seat_ids) // self.n_seats

        # A seat can still be rewarded only if every move it has made was a 1.
        # Each of its remaining (length - moves made) moves then contributes a
        # factor of 1/2, i.e. a uniform choice between the two actions.
        n_ones = (self.history == 1).sum(-2)
        correct_so_far = n_ones == moves_made
        correct_to_go = 2**(n_ones - self.length).float()

        self.v = correct_so_far.float() * correct_to_go

    def step(self, actions):
        """Record each env's action for the seat to move, then score and reset finished episodes.

        Returns the successor world and an arrdict holding ``terminal`` (one bool
        per env) and ``rewards`` (per env and seat): 1. for each seat whose moves
        were all 1s in an episode that just ended, else 0.
        """
        history = self.history.clone()
        idx = self.count // self.n_seats
        history[self.envs, idx, self.seats] = actions
        count = self.count + 1

        terminal = (count == self.max_count)
        reward = (terminal[:, None] & (history == 1).all(-2)).float()
        transition = arrdict.arrdict(terminal=terminal, rewards=reward)

        # Reset finished envs on the fresh copies before building the next world.
        count[terminal] = 0
        history[terminal] = -1

        world = type(self)(history=history, count=count)
        return world, transition
Example #2
0
class Win(arrdict.namedarrtuple(fields=('envs', ))):
    """One-step one-seat win (+1)

    A single seat with a single, always-legal action; every step ends the
    episode with a reward of +1 and leaves the world unchanged.
    """
    @classmethod
    def initial(cls, n_envs=1, device='cuda'):
        """Build ``n_envs`` environments, each identified by its index."""
        env_ids = torch.arange(n_envs, device=device)
        return cls(envs=env_ids)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.envs, torch.Tensor):
            # Fields are not tensors (presumably an intermediate built by the
            # arrdict machinery), so there is nothing to derive.
            return

        self.n_seats = 1
        self.n_envs = len(self.envs)
        self.device = self.envs.device

        # Placeholder spaces: no observation features and a single action.
        self.obs_space = (0, )
        self.action_space = (1, )

        # One always-legal action per env; seat 0 is always the one to move.
        self.valid = torch.ones_like(self.envs[..., None], dtype=torch.bool)
        self.seats = torch.zeros_like(self.envs)

        # Planted values for validation: uniform logits and a value of 1.
        self.logits = uniform_logits(self.valid)
        self.v = torch.ones_like(self.valid[..., None], dtype=torch.float)

    def step(self, actions):
        """End every episode with +1 reward; ``actions`` is ignored and the world is returned as-is."""
        done = torch.ones_like(self.envs, dtype=torch.bool)
        payout = torch.ones_like(self.envs, dtype=torch.float)
        return self, arrdict.arrdict(terminal=done, rewards=payout)
Example #3
0
class WinnerLoser(arrdict.namedarrtuple(fields=('seats', ))):
    """Two-move game that seat 0 always wins: +1 to seat 0, -1 to seat 1.

    Seat 0 moves first (reward 0), then seat 1. Seat 1's move ends the episode
    and pays +1 to seat 0 and -1 to seat 1.
    """
    @classmethod
    def initial(cls, n_envs=1, device='cuda'):
        """Create ``n_envs`` environments with seat 0 to move."""
        return cls(seats=torch.zeros(n_envs, device=device, dtype=torch.int))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.seats, torch.Tensor):
            return

        self.device = self.seats.device
        self.n_envs = len(self.seats)
        self.n_seats = 2

        # Placeholder spaces: no observation features and a single action.
        self.obs_space = (0, )
        self.action_space = (1, )

        # The single action is always legal.
        self.valid = torch.ones((self.n_envs, 1),
                                dtype=torch.bool,
                                device=self.device)

        # Planted values for validation: uniform logits, and values of +1 for
        # seat 0 and -1 for seat 1 (the rewards the episode ends with).
        self.logits = uniform_logits(self.valid)
        self.v = torch.stack(
            [torch.ones_like(self.seats), -torch.ones_like(self.seats)],
            -1).float()

    def step(self, actions):
        """Hand the move to the other seat; when seat 1 has just moved, end the episode and pay out.

        ``actions`` is ignored since there is only one action.
        """
        terminal = (self.seats == 1)
        trans = arrdict.arrdict(terminal=terminal,
                                rewards=torch.stack(
                                    [terminal.float(), -terminal.float()], -1))
        return type(self)(seats=1 - self.seats), trans
Example #4
0
class MockGame(arrdict.namedarrtuple(fields=('count', 'history'))):
    """Two-seat mock game that records each env's actions into a fixed-length history.

    Seats alternate with the move count. An env's episode ends once its history
    is full, at which point its count wraps back to 0. Finished history rows are
    not cleared; the next episode overwrites them move by move.
    """
    @classmethod
    def initial(cls, n_envs=1, length=4, device='cuda'):
        """Create ``n_envs`` games with an all -1 history of ``length`` moves and a zero count."""
        blank_history = torch.full((n_envs, length), -1, dtype=torch.long, device=device)
        zero_count = torch.full((n_envs, ), 0, dtype=torch.long, device=device)
        return cls(history=blank_history, count=zero_count)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.count, torch.Tensor):
            return

        self.n_seats = 2
        self.device = self.count.device
        self.n_envs = self.count.shape[0]

        # Both actions are always legal.
        self.valid = torch.ones((*self.count.shape, 2), dtype=torch.bool, device=self.device)

    @property
    def seats(self):
        """Seat to move in each env: the move count modulo the number of seats."""
        return self.count % self.n_seats

    def step(self, actions):
        """Write each env's action at its current position and advance the move count.

        Returns the next world, a transition holding ``terminal``, and a list of
        the history rows of the envs whose episode just finished.
        """
        length = self.history.shape[1]
        position = self.count[:, None]

        history = self.history.clone()
        history.scatter_(1, position, actions[:, None])

        count = self.count + 1
        terminal = count == length
        count[terminal] = 0

        finished = list(history[terminal])
        world = type(self)(count=count, history=history)
        return world, arrdict.arrdict(terminal=terminal), finished
Example #5
0
class MockWorlds(arrdict.namedarrtuple(fields=('seats', 'cumulator', 'lengths'))):
    """Two-seat, four-step mock world whose outcome depends on accumulated skill.

    Seats alternate, starting from seat 0. Each step adds the acting seat's
    ``skills`` entry into that seat's column of ``cumulator``. On the fourth
    step of an episode:

    - With probability ``sigmoid((cumulator[:, 0] - cumulator[:, 1]) / 4)``,
      seat 0 gets +1 and seat 1 gets -1.
    - Otherwise the signs are flipped.

    After the fourth step the env is reset.
    """

    @classmethod
    def initial(cls, n_envs=1, device='cpu'):
        """Create ``n_envs`` fresh episodes: seat 0 to move, on step 1, with zero skill totals."""
        # Explicit long dtypes: `seats` is used as a scatter index in `step`.
        return cls(
            seats=torch.full((n_envs,), 0, dtype=torch.long, device=device),
            lengths=torch.full((n_envs,), 1, dtype=torch.long, device=device),
            cumulator=torch.full((n_envs, 2), 0., device=device))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.seats, torch.Tensor):
            return

        self.n_envs = len(self.seats)
        self.device = self.seats.device
        self.n_seats = 2

    def step(self, skills):
        """Credit ``skills`` to the acting seats and resolve any episode on its final step.

        Args:
            skills: per-env amount added to the acting seat's running total
                (presumably shape (n_envs,), given the ``[:, None]`` indexing).

        Returns:
            ``(new_world, transition)``. ``transition`` holds ``terminal`` (n_envs,)
            and integer ``rewards`` (n_envs, 2), with values in {-1, 0, +1}.
        """
        terminal = self.lengths == 4
        new_lengths = self.lengths + 1
        new_lengths[terminal] = 1

        new_cumulator = self.cumulator.scatter_add(
            1, self.seats[:, None], skills[:, None].float())

        # Probability that seat 0 wins, given the skill totals so far.
        rates = torch.sigmoid((new_cumulator[:, 0] - new_cumulator[:, 1])/4)
        # Draw on the world's device: a CPU draw can't be compared with `rates` on a GPU.
        seat0_wins = torch.rand((self.n_envs,), device=self.device) <= rates
        rewards = terminal * (2*seat0_wins - 1)
        rewards = torch.stack([rewards, -rewards], -1)

        new_cumulator[terminal] = 0

        new_seats = 1 - self.seats
        new_seats[terminal] = 0

        trans = arrdict.arrdict(terminal=terminal, rewards=rewards)
        return type(self)(seats=new_seats, lengths=new_lengths, cumulator=new_cumulator), trans
Example #6
0
class SequentialMatrix(
        arrdict.namedarrtuple(fields=('payoffs', 'moves', 'seats'))):
    """Two-seat matrix game played sequentially: seat 0 moves, then seat 1.

    ``payoffs[env, move0, move1]`` is the rewards row paid out once both seats
    have moved. The observation is seat 0's move, or -1 before it is made.
    """
    @classmethod
    def initial(cls, payoff, n_envs=1, device='cuda'):
        """Tile a single payoff table across ``n_envs`` fresh environments."""
        table = torch.as_tensor(payoff).to(device)
        tiled = table[None, ...].repeat(n_envs, 1, 1, 1)
        return cls(
            payoffs=tiled,
            seats=torch.zeros((n_envs, ), dtype=torch.int, device=device),
            moves=torch.full((n_envs, 2), -1, dtype=torch.int, device=device))

    @classmethod
    def dilemma(cls, *args, **kwargs):
        """Game built from the fixed 'dilemma' payoff table."""
        payoff = [[[0., 0.], [1., 0.]],
                  [[0., 1.], [.5, .5]]]
        return cls.initial(payoff, *args, **kwargs)

    @classmethod
    def antisymmetric(cls, *args, **kwargs):
        """Game built from the fixed 'antisymmetric' payoff table."""
        payoff = [[[1., 0.], [1., 1.]],
                  [[0., 0.], [0., .1]]]
        return cls.initial(payoff, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.payoffs, torch.Tensor):
            return

        self.n_seats = 2
        self.n_envs = self.seats.size(-1)
        self.device = self.seats.device

        self.obs_space = heads.Tensor((1, ))
        self.action_space = heads.Masked(2)

        # Seat 0's move so far; -1 means it has not moved yet this episode.
        self.obs = self.moves[..., [0]].float()
        # Both actions are always legal for whichever seat is to move.
        legal = torch.ones_like(self.seats, dtype=torch.bool)
        self.valid = torch.stack([legal, legal], -1)

        self.envs = torch.arange(self.n_envs, dtype=torch.long, device=self.device)

    def step(self, actions):
        """Record the moving seat's action; pay out and reset the envs where both seats have now moved."""
        next_seats = self.seats + 1
        terminal = next_seats == 2

        moves = self.moves.clone()
        moves[self.envs, self.seats.long()] = actions.int()
        completed = moves[terminal]
        self._stats(completed)

        rewards = torch.zeros_like(self.payoffs[:, 0, 0])
        rewards[terminal] = self.payoffs[self.envs[terminal],
                                         completed[:, 0].long(),
                                         completed[:, 1].long()]

        next_seats[terminal] = 0
        moves[terminal] = -1

        world = type(self)(payoffs=self.payoffs, seats=next_seats, moves=moves)
        return world, arrdict.arrdict(terminal=terminal, rewards=rewards)

    def _stats(self, moves):
        """Log how often each (seat-0 move, seat-1 move) outcome occurs among the completed ``moves``."""
        n_elements = moves.nelement()
        if n_elements == 0:
            return
        n_games = n_elements / 2
        first, second = moves[..., 0], moves[..., 1]
        for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
            hits = ((first == i) & (second == j)).sum()
            stats.mean(f'outcomes/{i}-{j}', hits, n_games)
Example #7
0
class Hex(arrdict.namedarrtuple(fields=('board', 'seats'))):
    """Batched Hex worlds.

    ``board`` is (..., boardsize, boardsize) uint8 and ``seats`` holds the seat
    to move in each env. The game logic itself lives in the ``cuda`` kernels.
    """
    @classmethod
    def initial(cls, n_envs, boardsize=11, device='cuda'):
        """Create ``n_envs`` empty ``boardsize`` x ``boardsize`` boards with seat 0 to move."""
        # As per OpenSpiel and convention, black plays first.
        return cls(board=torch.full((n_envs, boardsize, boardsize),
                                    0,
                                    device=device,
                                    dtype=torch.uint8),
                   seats=torch.full((n_envs, ),
                                    0,
                                    device=device,
                                    dtype=torch.int))

    @profiling.nvtx
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not isinstance(self.board, torch.Tensor):
            # Need this conditional to deal with the case where we're calling a method like `self.clone()`, and the
            # intermediate arrdict generated is full of methods, which will break this here init function.
            return

        self.n_seats = 2
        self.n_envs = self.board.shape[0]
        self.boardsize = self.board.shape[1]
        self.device = self.board.device

        self.obs_space = heads.Tensor((self.boardsize, self.boardsize, 2))
        self.action_space = heads.Masked(self.boardsize * self.boardsize)

        # Lazily-computed caches for `obs` and `valid`.
        self._obs = None
        self._valid = None

    @property
    def obs(self):
        """Observation from ``cuda.observe(board, seats)``; computed on first access and cached."""
        if self._obs is None:
            self._obs = cuda.observe(self.board, self.seats)
        return self._obs

    @property
    def valid(self):
        """Legal-move mask, flattened row-major over the two board axes.

        A cell is legal when all of its observation channels are 0.
        """
        if self._valid is None:
            shape = self.board.shape[:-2]
            self._valid = (self.obs == 0).all(-1).reshape(*shape, -1)
        return self._valid

    @profiling.nvtx
    def step(self, actions, reset=True):
        """Play one move in every env.

        Args:
            actions: either an (n_envs,) int tensor of flat cell indices, or an
                (n_envs, 2) int tensor of (row, column) pairs, each in
                [0, boardsize). Cells are indexed in row-major order from the
                top-left, so (row, col) is cell ``row * boardsize + col``.
            reset: if True, envs where any seat received a positive reward are
                marked terminal, cleared, and handed back to seat 0. If False,
                no env is marked terminal or reset.

        Returns:
            ``(new_world, transition)``. ``transition`` is an arrdict of
            ``terminal`` (n_envs,) bool and the ``rewards`` returned by
            ``cuda.step``.

        Raises:
            ValueError: if the board does not have exactly one batch dimension.
        """
        if self.board.ndim != 3:
            #TODO: Support stepping arbitrary batchings. Only needs a reshaping.
            raise ValueError(
                'You can only step a board with a single batch dimension')

        assert (0 <= actions).all(), 'You passed a negative action'
        if actions.ndim == 2:
            # Convert (row, col) pairs into row-major flat cell indices.
            actions = actions[:, 0] * self.boardsize + actions[:, 1]

        assert actions.shape == (self.n_envs, ), 'Expected exactly one action per env'
        assert self.valid.gather(1, actions[:, None]).squeeze(-1).all(), 'You passed an illegal action'

        new_board = self.board.clone()
        # cuda.step presumably plays the moves into new_board in place (nothing
        # else here writes the moves), and returns the rewards.
        rewards = cuda.step(new_board, self.seats.int(), actions.int())
        if reset:
            terminal = (rewards > 0).any(-1)
        else:
            terminal = torch.full((self.n_envs, ), False, device=self.device)

        # Clear finished boards and give the first move of the next game to seat 0.
        new_board[terminal] = 0

        new_seat = 1 - self.seats
        new_seat[terminal] = 0

        new_world = type(self)(board=new_board, seats=new_seat)

        transition = arrdict.arrdict(terminal=terminal, rewards=rewards)
        return new_world, transition

    @profiling.nvtx
    def __getitem__(self, x):
        # Just exists for profiling
        return super().__getitem__(x)

    @profiling.nvtx
    def __setitem__(self, x, y):
        # Just exists for profiling
        return super().__setitem__(x, y)

    @classmethod
    def plot_worlds(cls, worlds, e=None, ax=None, colors='obs', **kwargs):
        """Plot env ``e`` of ``worlds`` and return the figure it was drawn on.

        ``e`` defaults to the first env along every batch dimension. A new axes
        is created when ``ax`` is None.
        """
        e = (0, ) * (worlds.board.ndim - 2) if e is None else e
        board = worlds.board[e]

        ax = plt.subplots()[1] if ax is None else ax

        colors = color_board(board, colors)
        plot_board(colors, ax, **kwargs)

        return ax.figure

    def display(self, e=None, **kwargs):
        """Render env ``e`` to a figure, close it in pyplot, and return the figure.

        Closing it detaches the figure from pyplot's figure manager while the
        returned object can still be displayed.
        """
        # plot_worlds returns a Figure, not an Axes.
        fig = self.plot_worlds(arrdict.numpyify(arrdict.arrdict(self)),
                               e=e,
                               **kwargs)
        plt.close(fig)
        return fig