def evaluate(pair, n_envs=64*1024, device='cuda'):
    agents = {}
    for name, sd in pair.items():
        agent = agentfunc(device)
        agent.network.load_state_dict(sd)
        agents[name] = agent
    worlds = mix(worldfunc(n_envs, device=device))
    return arena.evaluate(worlds, agents)
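# A usage sketch (not from the original module): `pair` maps a display name to a *network*
# state dict (note the `agent.network.load_state_dict(sd)` call above), rather than to the
# full agent state dict (`sd['agent']`) loaded by the functions below. With two network
# snapshots `sd_a` and `sd_b` (placeholder names):
#
#   results = evaluate({'a': sd_a, 'b': sd_b}, n_envs=1024)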
def run():
    n_envs = 8
    world = worldfunc(n_envs)

    agent = agentfunc()
    agent.kwargs['n_nodes'] = 512
    sd = storage.load_snapshot('*perky-boxes*', 64)
    agent.load_state_dict(sd['agent'])

    mhx = mohex.MoHexAgent()

    trace = rollout(world, [agent, mhx])
    return trace
def compare(fst_run=-1, snd_run=-1, n_envs=256, device='cuda:1'):
    import pandas as pd
    from .main import worldfunc, agentfunc

    world = worldfunc(n_envs, device=device)

    fst = agentfunc(device=device)
    fst.load_state_dict(storing.select(storing.load_latest(fst_run), 'agent'))

    snd = agentfunc(device=device)
    snd.load_state_dict(storing.select(storing.load_latest(snd_run), 'agent'))

    bw = rollout(world, [fst, snd], n_reps=1)
    bw_wins = (bw.transitions.rewards[bw.transitions.terminal.cumsum(0) <= 1] == 1).sum(0)

    wb = rollout(world, [snd, fst], n_reps=1)
    wb_wins = (wb.transitions.rewards[wb.transitions.terminal.cumsum(0) <= 1] == 1).sum(0)

    # Rows: the colour fst played; cols: fst, snd
    wins = torch.stack([bw_wins, wb_wins.flipud()]).detach().cpu().numpy()

    return pd.DataFrame(wins/n_envs, ['black', 'white'], ['fst', 'snd'])
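# A sketch (not part of the original module) of how to read compare()'s output: each row is
# the colour `fst` played, each column an agent, and each cell the fraction of the n_envs
# games that agent won from that seating. Averaging fst's column gives a colour-balanced
# win rate. The run identifiers here are placeholders.
def fst_winrate(fst_run=-1, snd_run=-2, n_envs=256):
    df = compare(fst_run, snd_run, n_envs=n_envs)
    return df['fst'].mean()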
def mohex_benchmark(run='*gross-steams'):
    from boardlaw import mohex
    from boardlaw.main import worldfunc, agentfunc
    from boardlaw.arena import evaluate

    n_envs = 8
    worlds = worldfunc(n_envs)

    agent = agentfunc()
    agent.load_state_dict(storage.load_latest(run)['agent'])
    # agent.kwargs['n_nodes'] = 512
    # agent.kwargs['noise_eps'] = 0.

    mhx = mohex.MoHexAgent()

    return evaluate(worlds, {'boardlaw': agent, 'mohex': mhx})
def generate_state_dicts(run):
    n_envs = 24 * 1024
    buffer_len = 64
    device = 'cuda'

    #TODO: Restore league and sched when you go back to large boards
    worlds = mix(worldfunc(n_envs, device=device))
    agent = agentfunc(device)
    network = agent.network
    opt = torch.optim.Adam(network.parameters(), lr=1e-2, amsgrad=True)
    scaler = torch.cuda.amp.GradScaler()

    sd = storage.load_latest(run)
    agent.load_state_dict(sd['agent'])
    opt.load_state_dict(sd['opt'])

    state_dicts = [clone(network.state_dict())]

    buffer = []
    #TODO: Upgrade this to handle batches that are some multiple of the env count
    idxs = (torch.randint(buffer_len, (n_envs,), device=device), torch.arange(n_envs, device=device))
    for _ in range(8):

        # Collect experience
        while len(buffer) < buffer_len:
            with torch.no_grad():
                decisions = agent(worlds, value=True)
            new_worlds, transition = worlds.step(decisions.actions)

            buffer.append(arrdict.arrdict(
                worlds=worlds,
                decisions=decisions.half(),
                transitions=half(transition)).detach())

            worlds = new_worlds
            log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

        # Optimize
        chunk, buffer = as_chunk(buffer, n_envs)
        optimize(network, scaler, opt, chunk[idxs])
        log.info('learner stepped')

        state_dicts.append(clone(network.state_dict()))

    return state_dicts
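# A sketch (not part of the original module) chaining generate_state_dicts() into evaluate():
# the snapshots it returns are network state dicts, which is exactly what evaluate() expects
# in `pair`. The default run name and the 'before'/'after' labels are placeholders.
def snapshot_progress(run='*gross-steams', n_envs=1024):
    sds = generate_state_dicts(run)
    return evaluate({'before': sds[0], 'after': sds[-1]}, n_envs=n_envs)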