def run(boardsize, width, depth, desc, nodes=64, c_puct=1/16, lr=1e-3, n_envs=32*1024):
    buffer_len = 64

    worlds = learning.mix(hex.Hex.initial(n_envs, boardsize))
    network = networks.FCModel(worlds.obs_space, worlds.action_space, width=width, depth=depth).to(worlds.device)
    agent = mcts.MCTSAgent(network, n_nodes=nodes, c_puct=c_puct)
    opt = torch.optim.Adam(network.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler()

    run = runs.new_run(
        description=desc,
        params=dict(boardsize=worlds.boardsize, width=width, depth=depth,
                    nodes=nodes, c_puct=c_puct, lr=lr, n_envs=n_envs))
    archive.archive(run)
    storer = storage.TimeStorer(run, agent)
    noise = noisescales.NoiseScales(agent, buffer_len)

    buffer = []
    with logs.to_run(run), stats.to_run(run), arena.live.run(run):
        #TODO: Upgrade this to handle batches that are some multiple of the env count
        idxs = (torch.randint(buffer_len, (n_envs,), device='cuda'),
                torch.arange(n_envs, device='cuda'))
        while True:

            # Collect experience
            while len(buffer) < buffer_len:
                with torch.no_grad():
                    decisions = agent(worlds, value=True)
                new_worlds, transition = worlds.step(decisions.actions)
                buffer.append(arrdict.arrdict(
                    worlds=worlds,
                    decisions=decisions.half(),
                    transitions=learning.half(transition)).detach())
                worlds = new_worlds
                log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

            # Optimize
            chunk, buffer = as_chunk(buffer, n_envs)
            optimize(network, scaler, opt, chunk[idxs])
            log.info('learner stepped')

            stats.gpu(worlds.device, 15)
            noise.step(chunk)

            finish = storer.step(agent, len(idxs[0]))
            if finish:
                break

    log.info('Finished')
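
# `as_chunk` isn't shown in this section. A minimal sketch of what the loop
# above needs from it (an assumption, not the repo's implementation): stack
# the buffered steps time-major and hand back an emptied buffer, so that
# `chunk[idxs]` can sample one random timestep per env.
def as_chunk(buffer, n_envs):
    chunk = arrdict.stack(buffer)  # (buffer_len, n_envs, ...)
    return chunk, []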
def reset(self):
    reset = self.core.env_full(True)
    self._reset(reset)
    obs, reward = self._observe(reset)
    return arrdict.arrdict(obs=obs, reset=reset, reward=reward)
def forward(self, worlds):
    neck = self.body(worlds.obs)
    return arrdict.arrdict(
        logits=self.policy(neck, worlds.valid),
        v=self.value(neck, worlds.valid, worlds.seats))
def run(boardsize, width, depth, timelimit, desc):
    buffer_len = 64
    n_envs = 32*1024

    #TODO: Restore league and sched when you go back to large boards
    worlds = mix(hex.Hex.initial(n_envs, boardsize))
    network = networks.FCModel(worlds.obs_space, worlds.action_space, width=width, depth=depth).to(worlds.device)
    agent = mcts.MCTSAgent(network)
    opt = torch.optim.Adam(network.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    parent = warm_start(agent, opt, scaler, '')

    run = runs.new_run(
        description=desc,
        params=dict(boardsize=worlds.boardsize, width=width, depth=depth, parent=parent))
    archive.archive(run)

    buffer = []
    with logs.to_run(run), stats.to_run(run), arena.mohex.run(run):
        #TODO: Upgrade this to handle batches that are some multiple of the env count
        idxs = (torch.randint(buffer_len, (n_envs,), device='cuda'),
                torch.arange(n_envs, device='cuda'))
        for _ in time_limited_loop(timelimit):

            # Collect experience
            while len(buffer) < buffer_len:
                with torch.no_grad():
                    decisions = agent(worlds, value=True)
                new_worlds, transition = worlds.step(decisions.actions)
                buffer.append(arrdict.arrdict(
                    worlds=worlds,
                    decisions=decisions.half(),
                    transitions=half(transition)).detach())
                worlds = new_worlds
                log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

            # Optimize
            chunk, buffer = as_chunk(buffer, n_envs)
            optimize(network, scaler, opt, chunk[idxs])
            log.info('learner stepped')

            sd = storage.state_dicts(agent=agent, opt=opt, scaler=scaler)
            storage.throttled_latest(run, sd, 60)
            storage.throttled_snapshot(run, sd, 900)
            storage.throttled_raw(run, 'model', lambda: pickle.dumps(network), 900)
            stats.gpu(worlds.device, 15)

    log.info('Finished; saving final state dict')
    sd = storage.state_dicts(agent=agent, opt=opt, scaler=scaler)
    storage.save_latest(run, sd)
def render(core):
    render = unpack(cuda.render(core.scenery, core.agents))
    render = arrdict.arrdict({k: v.unsqueeze(2) for k, v in render.items()})
    render['screen'] = render.screen.permute(0, 1, 4, 2, 3)
    return render
def step(self, actions):
    terminal = (self.seats == 1)
    trans = arrdict.arrdict(
        terminal=terminal,
        rewards=torch.stack([terminal.float(), -terminal.float()], -1))
    return type(self)(seats=1 - self.seats), trans
def step(self, decision):
    reset = self._reset()
    self._movement(collapse(decision, self.core.n_agents))
    obs, reward = self._observe()
    return arrdict.arrdict(obs=expand(obs), reward=reward, reset=reset)
def reset(self):
    self.spawner(self.core.agent_full(True))
    return arrdict.arrdict(obs=self.rgb())
def step(self, decision):
    self.movement(decision)
    return arrdict.arrdict(obs=self.rgb())
def __call__(self, world):
    id = torch.full((world.n_envs,), self.id, device=world.device, dtype=torch.long)
    return arrdict.arrdict(actions=id)
def __init__(self, world, n_nodes=64, c_puct=1/16, noise_eps=.25, alpha_scale=10):
    """
    c_puct high: concentrates on prior
    c_puct low: concentrates on value
    """
    self.device = world.device
    self.n_envs = world.n_envs
    self.n_nodes = n_nodes
    self.n_seats = world.n_seats
    assert n_nodes > 0, 'MCTS requires at least one node'

    self.envs = torch.arange(world.n_envs, device=self.device)

    self.n_actions = np.prod(world.action_space)
    # Tree topology: children[e, n, a] is the node reached from node n by
    # action a, parents[e, n] is n's parent, and relation[e, n] is the action
    # that led from the parent to n. -1 marks 'not yet expanded'.
    self.tree = arrdict.arrdict(
        children=self.envs.new_full((world.n_envs, self.n_nodes, self.n_actions), -1, dtype=torch.short),
        parents=self.envs.new_full((world.n_envs, self.n_nodes), -1, dtype=torch.short),
        relation=self.envs.new_full((world.n_envs, self.n_nodes), -1, dtype=torch.short))

    self.worlds = arrdict.stack([world for _ in range(self.n_nodes)], 1)

    self.transitions = arrdict.arrdict(
        rewards=torch.full((world.n_envs, self.n_nodes, self.n_seats), 0., device=self.device, dtype=torch.half),
        terminal=torch.full((world.n_envs, self.n_nodes), False, device=self.device, dtype=torch.bool))

    self.decisions = arrdict.arrdict(
        logits=torch.full((world.n_envs, self.n_nodes, self.n_actions), np.nan, device=self.device, dtype=torch.half),
        v=torch.full((world.n_envs, self.n_nodes, self.n_seats), np.nan, device=self.device, dtype=torch.half))

    self.stats = arrdict.arrdict(
        n=torch.full((world.n_envs, self.n_nodes), 0, device=self.device, dtype=torch.short),
        w=torch.full((world.n_envs, self.n_nodes, self.n_seats), 0., device=self.device, dtype=torch.half))

    self.sim = torch.tensor(0, device=self.device, dtype=torch.long)
    self.worlds[:, 0] = world

    # https://github.com/LeelaChessZero/lc0/issues/694
    # Larger c_puct -> greater regularization
    self.c_puct = torch.full((world.n_envs,), c_puct, device=self.device, dtype=torch.half)

    self.noise_eps = noise_eps
    self.alpha_scale = alpha_scale
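
# A minimal usage sketch for the constructor above, assuming it belongs to
# boardlaw's `mcts.MCTS` class and that `hex.Hex` is importable as in the
# training loops above. Per the docstring, c_puct trades prior against value:
from boardlaw import hex, mcts

worlds = hex.Hex.initial(n_envs=4, boardsize=5)
prior_heavy = mcts.MCTS(worlds, n_nodes=64, c_puct=1/4)   # leans on the network's prior
value_heavy = mcts.MCTS(worlds, n_nodes=64, c_puct=1/64)  # leans on backed-up values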
def sensible_way(gs, Bsmall):
    S = Bsmall * (gs - gs.mean(0, keepdims=True)).pow(2).mean()
    G2 = gs.mean(0).pow(2).mean()
    return arrdict.arrdict(S=S, G2=G2, B=(S/G2)).item()
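
# The function above is the gradient-noise-scale estimator from McCandlish et
# al.'s "An Empirical Model of Large-Batch Training": given per-batch gradient
# estimates g_i from batches of size Bsmall, S estimates the gradient variance
# and G2 the squared mean gradient, so B = S/G2 is the noise scale, roughly
# the batch size beyond which returns diminish. A hedged toy check, assuming
# `gs` stacks the per-batch gradient estimates along dim 0:
import torch

Bsmall = 32
true_grad = torch.randn(1000)
gs = true_grad + torch.randn(16, 1000)/Bsmall**.5  # 16 noisy per-batch estimates
print(sensible_way(gs, Bsmall).B)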
def state(self, e):
    """Returns the state of this module on sub-env ``e``.

    The state is an :ref:`arrdict <dotdicts>` of the agents' lifespans and
    max lifespans as (n_agent,)-tensors."""
    return arrdict.arrdict(
        lifespan=self._lifespans[e],
        max_lifespans=self._max_lifespans[e]).clone()
def unpack(d):
    """Unpacks :mod:`~megastep.cuda` datastructures into :ref:`arrdicts <dotdicts>` with the same attributes."""
    if isinstance(d, torch.Tensor):
        return d
    return arrdict.arrdict({k: unpack(getattr(d, k)) for k in dir(d) if not k.startswith('_')})
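
# Behaviour sketch for `unpack`: tensors pass through unchanged, anything
# else is traversed attribute-by-attribute. A toy stand-in for a cuda
# datastructure (hypothetical class, not part of megastep):
import torch

class Toy:
    screen = torch.zeros(2, 3)
    depth = torch.zeros(2)

unpack(Toy())  # -> arrdict(screen=Tensor(2, 3), depth=Tensor(2))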
def display(self, e=None, **kwargs):
    ax = self.plot_worlds(arrdict.numpyify(arrdict.arrdict(self)), e=e, **kwargs)
    plt.close(ax.figure)
    return ax
def forward(self, world):
    logits = self.policy(world.obs)
    actions = self.output.sample(logits)
    return arrdict.arrdict(logits=logits, actions=actions)
def step(self, actions):
    trans = arrdict.arrdict(
        terminal=torch.ones_like(self.envs.bool()),
        rewards=torch.ones_like(self.envs.float()))
    return self, trans
def state(self, d):
    return arrdict.arrdict(
        length=self._lengths[d],
        max_length=self._max_lengths[d]).clone()
def __call__(self, world, value=False):
    return arrdict.arrdict(logits=world.logits, v=world.v)
def unpack(d):
    if isinstance(d, torch.Tensor):
        return d
    return arrdict.arrdict({k: unpack(getattr(d, k)) for k in dir(d) if not k.startswith('_')})
def reset(self):
    reset = self._reset(self.core.agent_full(True))
    obs, reward = self._observe()
    return arrdict.arrdict(obs=expand(obs), reward=reward, reset=reset)
def solve(n, w, **kwargs):
    if isinstance(n, pd.DataFrame):
        return arrdict.arrdict({
            k: common.pandify(v, n.index)
            for k, v in solve(n.values, w.values, **common.numpyify(kwargs)).items()})
    return _solve(n, w, **kwargs)
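
# Hedged usage sketch for the wrapper above, assuming the activelo-style
# convention that `n` counts games played between each pair of agents and `w`
# counts wins. DataFrame inputs come back pandified on the same index:
import pandas as pd

agents = ['a', 'b']
n = pd.DataFrame([[0, 10], [10, 0]], index=agents, columns=agents)
w = pd.DataFrame([[0, 7], [3, 0]], index=agents, columns=agents)
soln = solve(n, w)  # fields of whatever `_solve` returns, indexed by agent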