Example 1
def run(boardsize,
        width,
        depth,
        desc,
        nodes=64,
        c_puct=1 / 16,
        lr=1e-3,
        n_envs=32 * 1024):
    buffer_len = 64

    worlds = learning.mix(hex.Hex.initial(n_envs, boardsize))
    network = networks.FCModel(worlds.obs_space,
                               worlds.action_space,
                               width=width,
                               depth=depth).to(worlds.device)
    agent = mcts.MCTSAgent(network, n_nodes=nodes, c_puct=c_puct)

    opt = torch.optim.Adam(network.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler()

    run = runs.new_run(description=desc,
                       params=dict(boardsize=worlds.boardsize,
                                   width=width,
                                   depth=depth,
                                   nodes=nodes,
                                   c_puct=c_puct,
                                   lr=lr,
                                   n_envs=n_envs))

    archive.archive(run)

    storer = storage.TimeStorer(run, agent)
    noise = noisescales.NoiseScales(agent, buffer_len)

    buffer = []
    with logs.to_run(run), stats.to_run(run), \
            arena.live.run(run):
        #TODO: Upgrade this to handle batches that are some multiple of the env count
        idxs = (torch.randint(buffer_len, (n_envs, ), device='cuda'),
                torch.arange(n_envs, device='cuda'))
        while True:

            # Collect experience
            while len(buffer) < buffer_len:
                with torch.no_grad():
                    decisions = agent(worlds, value=True)
                new_worlds, transition = worlds.step(decisions.actions)

                buffer.append(
                    arrdict.arrdict(
                        worlds=worlds,
                        decisions=decisions.half(),
                        transitions=learning.half(transition)).detach())

                worlds = new_worlds

                log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

            # Optimize
            chunk, buffer = as_chunk(buffer, n_envs)
            optimize(network, scaler, opt, chunk[idxs])
            log.info('learner stepped')

            stats.gpu(worlds.device, 15)

            noise.step(chunk)
            finish = storer.step(agent, len(idxs[0]))
            if finish:
                break

        log.info('Finished')
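The `idxs` pair built before the loop draws one random buffer slot per environment for each learner step. `as_chunk` is not shown here, so the (buffer_len, n_envs, ...) layout below is an assumption; this is a minimal sketch of the same advanced-indexing pattern on a plain tensor:

import torch

buffer_len, n_envs = 64, 8
chunk = torch.arange(buffer_len * n_envs).reshape(buffer_len, n_envs)

# One random time index per environment, paired with that environment's column index.
idxs = (torch.randint(buffer_len, (n_envs,)),
        torch.arange(n_envs))

sample = chunk[idxs]  # shape (n_envs,): one randomly chosen slot per environment

In the training loop above the chunk is an arrdict rather than a bare tensor; since arrdict indexing is broadcast over its leaf tensors, the same index pair selects a matching slice from every field at once.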
Example 2
 def reset(self):
     reset = self.core.env_full(True)
     self._reset(reset)
     obs, reward = self._observe(reset)
     return arrdict.arrdict(obs=obs, reset=reset, reward=reward)
Example 3
 def forward(self, worlds):
     neck = self.body(worlds.obs)
     return arrdict.arrdict(
         logits=self.policy(neck, worlds.valid), 
         v=self.value(neck, worlds.valid, worlds.seats))
Example 4
def run(boardsize, width, depth, timelimit, desc):
    buffer_len = 64
    n_envs = 32 * 1024

    #TODO: Restore league and sched when you go back to large boards
    worlds = mix(hex.Hex.initial(n_envs, boardsize))
    network = networks.FCModel(worlds.obs_space,
                               worlds.action_space,
                               width=width,
                               depth=depth).to(worlds.device)
    agent = mcts.MCTSAgent(network)

    opt = torch.optim.Adam(network.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    parent = warm_start(agent, opt, scaler, '')

    run = runs.new_run(description=desc,
                       params=dict(boardsize=worlds.boardsize,
                                   width=width,
                                   depth=depth,
                                   parent=parent))

    archive.archive(run)

    buffer = []
    with logs.to_run(run), stats.to_run(run), \
            arena.mohex.run(run):
        #TODO: Upgrade this to handle batches that are some multiple of the env count
        idxs = (torch.randint(buffer_len, (n_envs, ), device='cuda'),
                torch.arange(n_envs, device='cuda'))
        for _ in time_limited_loop(timelimit):

            # Collect experience
            while len(buffer) < buffer_len:
                with torch.no_grad():
                    decisions = agent(worlds, value=True)
                new_worlds, transition = worlds.step(decisions.actions)

                buffer.append(
                    arrdict.arrdict(worlds=worlds,
                                    decisions=decisions.half(),
                                    transitions=half(transition)).detach())

                worlds = new_worlds

                log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

            # Optimize
            chunk, buffer = as_chunk(buffer, n_envs)
            optimize(network, scaler, opt, chunk[idxs])
            log.info('learner stepped')

            sd = storage.state_dicts(agent=agent, opt=opt, scaler=scaler)
            storage.throttled_latest(run, sd, 60)
            storage.throttled_snapshot(run, sd, 900)
            storage.throttled_raw(run, 'model', lambda: pickle.dumps(network),
                                  900)
            stats.gpu(worlds.device, 15)

        log.info('Finished; saving final state dict')
        sd = storage.state_dicts(agent=agent, opt=opt, scaler=scaler)
        storage.save_latest(run, sd)
Example 5
def render(core):
    render = unpack(cuda.render(core.scenery, core.agents))
    render = arrdict.arrdict({k: v.unsqueeze(2) for k, v in render.items()})
    render['screen'] = render.screen.permute(0, 1, 4, 2, 3)
    return render
Example 6
 def step(self, actions):
     terminal = (self.seats == 1)
     trans = arrdict.arrdict(terminal=terminal,
                             rewards=torch.stack(
                                 [terminal.float(), -terminal.float()], -1))
     return type(self)(seats=1 - self.seats), trans
Example 7
 def step(self, decision):
     reset = self._reset()
     self._movement(collapse(decision, self.core.n_agents))
     obs, reward = self._observe()
     return arrdict.arrdict(obs=expand(obs), reward=reward, reset=reset)
Example 8
 def reset(self):
     self.spawner(self.core.agent_full(True))
     return arrdict.arrdict(obs=self.rgb())
Example 9
 def step(self, decision):
     self.movement(decision)
     return arrdict.arrdict(obs=self.rgb())
Example 10
 def __call__(self, world):
     id = torch.full((world.n_envs, ),
                     self.id,
                     device=world.device,
                     dtype=torch.long)
     return arrdict.arrdict(actions=id)
Example 11
    def __init__(self,
                 world,
                 n_nodes=64,
                 c_puct=1 / 16,
                 noise_eps=.25,
                 alpha_scale=10):
        """
        c_puct high: concentrates on prior
        c_puct low: concentrates on value
        """
        self.device = world.device
        self.n_envs = world.n_envs
        self.n_nodes = n_nodes
        self.n_seats = world.n_seats
        assert n_nodes > 0, 'MCTS requires at least one node'

        self.envs = torch.arange(world.n_envs, device=self.device)

        self.n_actions = np.prod(world.action_space)
        self.tree = arrdict.arrdict(
            children=self.envs.new_full(
                (world.n_envs, self.n_nodes, self.n_actions), -1, dtype=torch.short),
            parents=self.envs.new_full(
                (world.n_envs, self.n_nodes), -1, dtype=torch.short),
            relation=self.envs.new_full(
                (world.n_envs, self.n_nodes), -1, dtype=torch.short))

        self.worlds = arrdict.stack([world for _ in range(self.n_nodes)], 1)

        self.transitions = arrdict.arrdict(
            rewards=torch.full(
                (world.n_envs, self.n_nodes, self.n_seats), 0.,
                device=self.device, dtype=torch.half),
            terminal=torch.full(
                (world.n_envs, self.n_nodes), False,
                device=self.device, dtype=torch.bool))

        self.decisions = arrdict.arrdict(
            logits=torch.full((world.n_envs, self.n_nodes, self.n_actions),
                              np.nan,
                              device=self.device,
                              dtype=torch.half),
            v=torch.full((world.n_envs, self.n_nodes, self.n_seats),
                         np.nan,
                         device=self.device,
                         dtype=torch.half))

        self.stats = arrdict.arrdict(
            n=torch.full((world.n_envs, self.n_nodes),
                         0,
                         device=self.device,
                         dtype=torch.short),
            w=torch.full((world.n_envs, self.n_nodes, self.n_seats),
                         0.,
                         device=self.device,
                         dtype=torch.half))

        self.sim = torch.tensor(0, device=self.device, dtype=torch.long)
        self.worlds[:, 0] = world

        # https://github.com/LeelaChessZero/lc0/issues/694
        # Larger c_puct -> greater regularization
        self.c_puct = torch.full((world.n_envs, ),
                                 c_puct,
                                 device=self.device,
                                 dtype=torch.half)

        self.noise_eps = noise_eps
        self.alpha_scale = alpha_scale
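The docstring and the linked lc0 issue describe `c_puct` as a trade-off between the network prior and the observed values, with larger values acting as stronger regularization toward the prior. The selection rule itself does not appear in these examples, so the sketch below is an assumption: it shows the classic AlphaZero-style PUCT score purely to make the trade-off concrete, not the rule this implementation necessarily uses.

import torch

def puct_scores(q, prior, n, c_puct):
    """Classic AlphaZero-style PUCT score (illustrative only).

    q:      (n_actions,) mean value of each child
    prior:  (n_actions,) policy probabilities from the network
    n:      (n_actions,) visit counts of each child
    """
    n = n.float()
    # Large c_puct: the prior-weighted exploration term dominates.
    # Small c_puct: the value estimates q dominate.
    return q + c_puct * prior * torch.sqrt(n.sum()) / (1 + n)

# e.g. action = puct_scores(q, prior, n, c_puct=1/16).argmax()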
Example 12
def sensible_way(gs, Bsmall):
    S = Bsmall * (gs - gs.mean(0, keepdims=True)).pow(2).mean()
    G2 = gs.mean(0).pow(2).mean()
    return arrdict.arrdict(S=S, G2=G2, B=(S / G2)).item()
Example 13
 def state(self, e):
     """Returns the state of this module on sub-env ``e``. The state is a :ref:`arrdict <dotdicts>` of 
     the agents' lifespans and max lifespans as (n_agent,)-tensors."""
     return arrdict.arrdict(lifespan=self._lifespans[e], max_lifespans=self._max_lifespans[e]).clone()
Example 14
def unpack(d):
    """Unpacks :mod:`~megastep.cuda` datastructures into :ref:`arrdicts <dotdicts>` with the same attributes."""
    if isinstance(d, torch.Tensor):
        return d
    return arrdict.arrdict({k: unpack(getattr(d, k)) for k in dir(d) if not k.startswith('_')})
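`unpack` recurses over an object's public attributes and mirrors them as an arrdict, stopping at tensors. A minimal illustration, using a purely hypothetical stand-in class in place of a real `megastep.cuda` structure:

import torch

class FakeCudaStruct:
    # Hypothetical stand-in for a megastep.cuda datastructure.
    def __init__(self):
        self.screen = torch.zeros(2, 3)
        self.depth = torch.zeros(2, 1)

d = unpack(FakeCudaStruct())  # `unpack` as defined above
# d.screen and d.depth mirror the object's public tensor attributes; any
# non-tensor attribute would itself be unpacked recursively.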
Example 15
 def display(self, e=None, **kwargs):
     ax = self.plot_worlds(arrdict.numpyify(arrdict.arrdict(self)),
                           e=e,
                           **kwargs)
     plt.close(ax.figure)
     return ax
Example 16
 def forward(self, world):
     logits = self.policy(world.obs)
     actions = self.output.sample(logits)
     return arrdict.arrdict(logits=logits, actions=actions)
Example 17
 def step(self, actions):
     trans = arrdict.arrdict(terminal=torch.ones_like(self.envs.bool()),
                             rewards=torch.ones_like(self.envs.float()))
     return self, trans
Example 18
 def state(self, d):
     return arrdict.arrdict(length=self._lengths[d],
                            max_length=self._max_lengths[d]).clone()
Example 19
 def __call__(self, world, value=False):
     return arrdict.arrdict(logits=world.logits, v=world.v)
Example 20
def unpack(d):
    if isinstance(d, torch.Tensor):
        return d
    return arrdict.arrdict(
        {k: unpack(getattr(d, k))
         for k in dir(d) if not k.startswith('_')})
Example 21
 def reset(self):
     reset = self._reset(self.core.agent_full(True))
     obs, reward = self._observe()
     return arrdict.arrdict(obs=expand(obs), reward=reward, reset=reset)
Example 22
def solve(n, w, **kwargs):
    if isinstance(n, pd.DataFrame):
        return arrdict.arrdict({
            k: common.pandify(v, n.index)
            for k, v in solve(n.values, w.values, **common.numpyify(kwargs)).items()})
    return _solve(n, w, **kwargs)