Example #1
import os
import pickle as cp  # the training graphs are stored as a stream of pickled networkx graphs

import torch
import torch.distributed as dist
import torch.optim as optim
from tqdm import tqdm

from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen

# get_stage_stats, num_expect, recv_states and send_states are helper functions
# from the surrounding training script; they are not shown in this excerpt.


def main(rank):
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)
    model = RecurTreeGen(cmd_args).to(cmd_args.device)

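    # only rank 0 loads the checkpoint (if one exists); the parameters are then
    # broadcast from rank 0 so every process starts from identical weights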
    if rank == 0 and cmd_args.model_dump is not None and os.path.isfile(
            cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    for param in model.parameters():
        dist.broadcast(param.data, 0)

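    # read the training graphs from the pickle stream and register each one
    # with TreeLib before training starts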
    graphs = []
    with open(os.path.join(cmd_args.data_dir, 'train-graphs.pkl'), 'rb') as f:
        while True:
            try:
                g = cp.load(f)
                TreeLib.InsertGraph(g)
            except EOFError:  # end of the pickle stream
                break
            graphs.append(g)

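    # this loop trains on the first graph only; its rows are processed in
    # blocks of size blksize, pipelined across the participating processes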
    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            pbar = tqdm(pbar)

        g = graphs[0]
        graph_ids = [0]
        blksize = cmd_args.blksize
        if blksize < 0 or blksize > len(g):
            blksize = len(g)

        for e_it in pbar:
            optimizer.zero_grad()
            num_stages = len(g) // (blksize * cmd_args.num_proc) + (
                len(g) % (blksize * cmd_args.num_proc) > 0)
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None
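            # forward sweep (no gradients): each stage computes the row-tree
            # summary states for its block, caches the incoming boundary states
            # for the backward sweep, and hands its own states to the next rank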
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids,
                        list_node_starts=[local_st],
                        num_nodes=local_num)
                    if stage and rank == 0:
                        states_prev = recv_states(num_expect(local_st),
                                                  prev_rank_last)
                    if rank:
                        num_recv = num_expect(local_st)
                        states_prev = recv_states(num_recv, rank - 1)
                    _, next_states = model.row_tree.forward_train(
                        *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0],
                        *states_prev)
                    list_caches.append(states_prev)
                    if rank != rank_last:
                        send_states(next_states, rank + 1)
                    elif stage + 1 < num_stages:
                        send_states(next_states, 0)
                prev_rank_last = rank_last

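            # backward sweep: stages are replayed in reverse with gradients
            # enabled, receiving the gradient of the boundary states from the
            # downstream rank and sending gradients back upstream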
            tot_ll = torch.zeros(1).to(cmd_args.device)
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids,
                    list_node_starts=[local_st],
                    num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    loss.backward()
                else:
                    top_grad = recv_states(
                        cur_states[0].shape[0],
                        rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states],
                                            [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)
            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)

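            # sum gradients over all processes before taking the optimizer step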
            for param in model.parameters():
                if param.grad is None:
                    param.grad = param.data.new(param.data.shape).zero_()
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=cmd_args.grad_clip)
            optimizer.step()

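            # rank 0 reports the negative log-likelihood per node and saves a
            # checkpoint named after the current epoch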
            if rank == 0:
                loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' %
                                     (epoch +
                                      (e_it + 1) / cmd_args.epoch_save, loss))
                torch.save(
                    model.state_dict(),
                    os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))

    print('done rank', rank)
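
# The example above only defines main(rank). A minimal launcher sketch follows;
# it is not part of the original script and assumes the gloo backend, a
# single-node env:// rendezvous (MASTER_ADDR/MASTER_PORT defaults below are
# placeholders), and that cmd_args.num_proc gives the number of workers.
import torch.multiprocessing as mp


def _spawn_worker(rank):
    # hypothetical wrapper: set up the process group, run main, then clean up
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=cmd_args.num_proc)
    main(rank)
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_spawn_worker, nprocs=cmd_args.num_proc)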
Example #2
import numpy as np
import random
import networkx as nx
import torch
import torch.optim as optim
from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen

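# minimal sanity check: overfit the model on a single Barabasi-Albert graph
# for two gradient steps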
if __name__ == '__main__':
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    train_graphs = [nx.barabasi_albert_graph(10, 2)]
    TreeLib.InsertGraph(train_graphs[0])
    max_num_nodes = max([len(gg.nodes) for gg in train_graphs])
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    for i in range(2):
        optimizer.zero_grad()
        ll, _ = model.forward_train([0])
        loss = -ll / max_num_nodes
        print('iter', i, 'loss', loss.item())
        loss.backward()
        optimizer.step()
Example #3
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    indices = list(range(len(train_graphs)))
    for epoch in range(cmd_args.num_epochs):
        pbar = tqdm(range(cmd_args.epoch_save))

        optimizer.zero_grad()
        for idx in pbar:
            random.shuffle(indices)
            batch_indices = indices[:cmd_args.batch_size]

            num_nodes = sum([len(train_graphs[i]) for i in batch_indices])
            if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
                ll, _ = model.forward_train(batch_indices,
                                            list_col_ranges=[
                                                (0, stats[i][1])
                                                for i in batch_indices
                                            ])
                loss = -ll / num_nodes
                loss.backward()
                loss = loss.item()
            else:
                ll = 0.0
                for i in batch_indices:
                    n = len(train_graphs[i])
                    cur_ll, _ = sqrtn_forward_backward(
                        model,
                        graph_ids=[i],
                        list_node_starts=[0],
                        num_nodes=stats[i][0],
                        blksize=cmd_args.blksize,
Example #4
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    indices = list(range(len(train_graphs)))
    if cmd_args.epoch_load is None:
        cmd_args.epoch_load = 0
    for epoch in range(cmd_args.epoch_load, cmd_args.num_epochs):
        pbar = tqdm(range(cmd_args.epoch_save))

        optimizer.zero_grad()
        for idx in pbar:
            random.shuffle(indices)
            batch_indices = indices[:cmd_args.batch_size]

            num_nodes = sum([len(train_graphs[i]) for i in batch_indices])
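            # if the sampled batch fits within blksize (or blksize < 0), run a
            # single full forward/backward; otherwise each graph falls back to
            # sqrtn_forward_backward, which processes the rows in blocks so that
            # peak memory stays bounded, with loss_scale = 1/n normalizing the
            # per-graph log-likelihood by its number of nodes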
            if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
                ll, _ = model.forward_train(batch_indices)
                loss = -ll / num_nodes
                loss.backward()
                loss = loss.item()
            else:
                ll = 0.0
                for i in batch_indices:
                    n = len(train_graphs[i])
                    cur_ll, _ = sqrtn_forward_backward(
                        model,
                        graph_ids=[i],
                        list_node_starts=[0],
                        num_nodes=n,
                        blksize=cmd_args.blksize,
                        loss_scale=1.0 / n)
                    ll += cur_ll