Example 1
def test(run, snapshot=-1, **kwargs):
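    # Pits an MCTS agent against a DummyAgent baseline built from the same
    # snapshot, playing both seatings, and reports win-rate log-odds ('elo')
    # plus averaged KL and relative-entropy statistics of the decisions.
    # count_wins, kl_div and rel_entropy are helpers from the surrounding module.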
    boardsize = runs.info(run)['boardsize']
    worlds = hex.Hex.initial(n_envs=1024, boardsize=boardsize)

    network = storage.load_raw(run, 'model')
    sd = storage.load_snapshot(run, n=snapshot)['agent']
    network.load_state_dict(storage.expand(sd)['network'])
    A = mcts.MCTSAgent(network.cuda(), **kwargs)

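    # Second, independent copy of the same snapshot for the baseline agent.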
    network = storage.load_raw(run, 'model')
    sd = storage.load_snapshot(run, n=snapshot)['agent']
    network.load_state_dict(storage.expand(sd)['network'])
    B = mcts.DummyAgent(network.cuda())

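    # Play both seatings so first-move advantage cancels out of the win counts.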
    fst = analysis.rollout(worlds, [A, B], n_reps=1, eval=False)
    snd = analysis.rollout(worlds, [B, A], n_reps=1, eval=False)

    wins = count_wins(fst.transitions) + count_wins(snd.transitions).flipud()

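    # wins[0] tallies the MCTS agent's wins across both seatings; 'elo' is the
    # log-odds of its win rate.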
    rate = wins[0] / wins.sum()
    elo = torch.log(rate) - torch.log(1 - rate)

    kl = (kl_div(fst.decisions['0']) + kl_div(snd.decisions['0'])) / 2
    ent = (rel_entropy(fst.decisions['0']) +
           rel_entropy(snd.decisions['0'])) / 2
    return {'elo': elo.item(), 'kl': kl.item(), 'ent': ent.item()}
Example 2
def snapshot_kl_divs(run):
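    # Evaluates every snapshot of a run on one fixed batch of worlds and
    # returns a DataFrame of pairwise KL divergences between their policies.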
    import pandas as pd
    from pavlov import runs, storage
    from boardlaw import hex
    from boardlaw.main import mix
    import torch
    from tqdm.auto import tqdm

    m = storage.load_raw(run, 'model')
    worlds = mix(hex.Hex.initial(n_envs=16*1024, boardsize=runs.info(run)['params']['boardsize']))

    logits = {}
    for idx in tqdm(storage.snapshots(run)):
        sd = storage.load_snapshot(run, idx)['agent']
        m.load_state_dict(storage.expand(sd)['network'])
        logits[idx] = m(worlds).logits.detach()
        
    kldivs = {}
    for i in logits:
        for j in logits:
            li = logits[i]
            lj = logits[j]
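            # p_i * (log p_i - log p_j), summed over actions, is KL(P_i || P_j);
            # non-finite terms are zeroed out before summing.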
            terms = -li.exp().mul(lj - li)
            mask = torch.isfinite(terms)
            kldiv = terms.where(mask, torch.zeros_like(terms)).sum(-1)
            kldivs[i, j] = kldiv.mean().item()
    df = pd.Series(kldivs).unstack()

    return df
Example 3
def adam_way(run, i, Bsmall):
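    # Estimates gradient moment statistics from the Adam buffers stored with
    # snapshot i of a run; Bsmall is taken to be the training batch size.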
    sd = storage.load_snapshot(run, i)
    beta1, beta2 = sd['opt']['param_groups'][0]['betas']
    step = sd['opt']['state'][0]['step']

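    # Adam's bias-correction factors for the first and second moment estimates.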
    m_bias = 1 - beta1**step
    v_bias = 1 - beta2**step

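    # Bias-corrected first (m) and second (v) moments, flattened over all params.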
    opt = sd['opt']['state']
    m = 1 / m_bias * torch.cat([s['exp_avg'].flatten() for s in opt.values()])
    v = 1 / v_bias * torch.cat(
        [s['exp_avg_sq'].flatten() for s in opt.values()])

    # Follows from chasing the var through the defn of m
    inflator = (1 - beta1**2) / (1 - beta1)**2

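    # Noise-scale-style summary: S scales the moment variance by Bsmall, G2 is
    # the (inflated) mean squared gradient, and B = S/G2 is their ratio.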
    S = Bsmall * (v.mean() - m.pow(2).mean())
    G2 = inflator * m.pow(2).mean()

    return arrdict.arrdict(S=S,
                           G2=G2,
                           B=(S / G2),
                           v=v.mean(),
                           m=m.mean(),
                           m2=m.pow(2).mean(),
                           step=torch.as_tensor(step)).item()
Example 4
def run():
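    # Rolls out a snapshot of the '*perky-boxes*' run (n_nodes raised to 512)
    # against MoHex; worldfunc, agentfunc and rollout come from the
    # surrounding module.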
    n_envs = 8
    world = worldfunc(n_envs)
    agent = agentfunc()
    agent.network = agent.network
    agent.kwargs['n_nodes'] = 512

    sd = storage.load_snapshot('*perky-boxes*', 64)
    agent.load_state_dict(sd['agent'])

    mhx = mohex.MoHexAgent()
    trace = rollout(world, [agent, mhx])
Example 5
def agent(run, idx=None, device='cpu'):
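    # Rebuilds an MCTSAgent from a run's saved model and either the latest or a
    # specific snapshot's state dict; returns None if either file is missing.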
    try:
        network = storage.load_raw(run, 'model', device)
        agent = MCTSAgent(network)

        if idx is None:
            sd = storage.load_latest(run)
        else:
            sd = storage.load_snapshot(run, idx)
        agent.load_state_dict(sd['agent'])

        return agent
    except IOError:
        return None
Example 6
def snapshot_data(new_runs):
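    # Gathers the sample and FLOP counters stored with each snapshot of each
    # run into a flat DataFrame indexed by (run, idx).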
    snapshots = {}
    for _, r in tqdm(list(new_runs.iterrows()), desc='snapshots'):
        for i, s in storage.snapshots(r.run).items():
            stored = storage.load_snapshot(r.run, i)
            if 'n_samples' in stored:
                snapshots[r.run, i] = {
                    'samples': stored['n_samples'],
                    'flops': stored['n_flops']
                }
    snapshots = (pd.DataFrame.from_dict(
        snapshots,
        orient='index').rename_axis(index=('run', 'idx')).reset_index())
    # snapshots['id'] = snapshots.index.to_series()
    return snapshots
Example 7
def load(run):
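    # Loads per-snapshot training losses and model-size info for a run, keyed
    # by (boardsize, depth, width).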
    snapshots = pd.DataFrame.from_dict(storage.snapshots(run), orient='index')
    info, losses = {}, {}
    for i, row in snapshots.iterrows():
        losses[row.boardsize, row.depth,
               row.width] = storage.load_snapshot(run, i)['losses']
        info[row.boardsize, row.depth, row.width] = {
            'macs': row.n_macs,
            'params': row.n_params
        }
    losses = pd.DataFrame(losses)
    losses.index.name = 'step'
    losses.columns.names = ('boardsize', 'depth', 'width')

    info = pd.DataFrame(info)
    info.columns.names = ('boardsize', 'depth', 'width')
    return losses, info
Example 8
def gradients(run, i, n_envs=16 * 1024, buffer_len=64, device='cuda'):
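    # Generator: restores the network, optimiser and AMP scaler from snapshot i
    # of a run, replays the collect/optimise loop, and yields the flattened
    # gradient vector after every learner step. mix, half, as_chunk, optimize
    # and log come from the surrounding training module.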

    #TODO: Restore league and sched when you go back to large boards
    worlds = mix(Hex.initial(n_envs, device=device))
    network = storage.load_raw(run, 'model')
    agent = MCTSAgent(network)

    opt = torch.optim.Adam(network.parameters(), lr=0., amsgrad=True)
    scaler = torch.cuda.amp.GradScaler()

    sd = storage.load_snapshot(run, i)
    agent.load_state_dict(sd['agent'])
    opt.load_state_dict(sd['opt'])
    scaler.load_state_dict(sd['scaler'])

    buffer = []

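    # Index pair (random timestep, env) used to subsample the stacked chunk below.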
    idxs = (torch.randint(buffer_len, (n_envs,), device=device),
            torch.arange(n_envs, device=device))
    while True:

        # Collect experience
        while len(buffer) < buffer_len:
            with torch.no_grad():
                decisions = agent(worlds, value=True)
            new_worlds, transition = worlds.step(decisions.actions)

            buffer.append(
                arrdict.arrdict(worlds=worlds,
                                decisions=decisions.half(),
                                transitions=half(transition)).detach())

            worlds = new_worlds

            log.info(f'({len(buffer)}/{buffer_len}) actor stepped')

        # Optimize
        chunk, buffer = as_chunk(buffer, n_envs)
        optimize(network, scaler, opt, chunk[idxs])
        log.info('learner stepped')

        yield torch.cat([
            p.grad.flatten() for p in network.parameters()
            if p.grad is not None
        ])
Example 9
def agent(run, idx=None, device='cpu'):
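    # Variant of the loader above that distinguishes a missing model file from
    # a missing state dict, logging a warning for each case.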
    try:
        network = storage.load_raw(run, 'model', device)
    except IOError:
        log.warning(f'No model file for "{run}"')
        return None

    agent = MCTSAgent(network)

    try:
        if idx is None:
            sd = storage.load_latest(run)
        else:
            sd = storage.load_snapshot(run, idx)
    except IOError:
        log.warning(f'No state dict file for "{run}"')
        return None

    agent.load_state_dict(sd['agent'])

    return agent