示例#1
0
def build_model(rank):
    """Create a RecurTreeGen model and make its weights identical on all ranks.

    Rank 0 restores weights from ``cmd_args.model_dump`` when that option is
    set and points at an existing file. Afterwards every parameter is
    broadcast from rank 0, so all processes hold rank 0's values.

    Args:
        rank: rank of the calling process in the default process group.

    Returns:
        The model, moved to ``cmd_args.device``.
    """
    net = RecurTreeGen(cmd_args).to(cmd_args.device)
    dump_path = cmd_args.model_dump
    restore_here = (rank == 0 and dump_path is not None
                    and os.path.isfile(dump_path))
    if restore_here:
        print('loading from', dump_path)
        net.load_state_dict(torch.load(dump_path))
    # dist.broadcast is a collective call: every rank has to issue it for each
    # parameter in the same order, otherwise processes can hang.
    for weight in net.parameters():
        dist.broadcast(weight.data, 0)
    return net
def _read_train_graphs(path):
    """Read back-to-back pickled graphs from ``path`` and register each with TreeLib.

    Reading stops cleanly at end of file; any other error (corrupt pickle,
    failed ``TreeLib.InsertGraph``) propagates instead of being swallowed.

    NOTE(review): unpickling can execute arbitrary code; only point
    ``cmd_args.data_dir`` at trusted data.

    Args:
        path: path of the pickle stream to read.

    Returns:
        List of graphs in file order.
    """
    graphs = []
    with open(path, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                # Normal termination: the stream is exhausted.
                break
            TreeLib.InsertGraph(g)
            graphs.append(g)
    return graphs


def main(rank):
    """Train RecurTreeGen on the first training graph as one pipeline rank.

    The rows of the graph are split into stages. In every stage each of the
    ``cmd_args.num_proc`` ranks handles a block of up to ``blksize`` rows.
    During the forward sweep each rank passes its row-sum states to the next
    rank; the last active rank passes them back to rank 0 for the next stage.
    The backward sweep walks the stages in reverse, exchanging gradients of
    those states in the opposite direction. Parameter gradients are then
    summed across ranks before each optimizer step.

    Args:
        rank: rank of this process in the default process group.

    Raises:
        ValueError: if ``train-graphs.pkl`` contains no graphs.
    """
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)
    # Same construct / optional-load / broadcast sequence build_model performs.
    # The optimizer only keeps references to the parameter tensors, so creating
    # it after the (in-place) broadcast is equivalent.
    model = build_model(rank)
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)

    train_path = os.path.join(cmd_args.data_dir, 'train-graphs.pkl')
    graphs = _read_train_graphs(train_path)
    if not graphs:
        raise ValueError('no training graphs found in %s' % train_path)

    # Only the first graph is trained on; these values never change in the loop.
    g = graphs[0]
    graph_ids = [0]
    blksize = cmd_args.blksize
    if blksize < 0 or blksize > len(g):
        blksize = len(g)
    # ceil(len(g) / (blksize * num_proc)) pipeline stages.
    num_stages = len(g) // (blksize * cmd_args.num_proc) + (
        len(g) % (blksize * cmd_args.num_proc) > 0)

    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            pbar = tqdm(pbar)

        for e_it in pbar:
            optimizer.zero_grad()

            # ---- Forward sweep, no autograd ----
            # Each rank records the incoming states it received per stage, so
            # the backward sweep can recompute its block with autograd enabled.
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids,
                        list_node_starts=[local_st],
                        num_nodes=local_num)
                    if stage and rank == 0:
                        # Rank 0 continues from the last active rank of the
                        # previous stage.
                        states_prev = recv_states(num_expect(local_st),
                                                  prev_rank_last)
                    if rank:
                        num_recv = num_expect(local_st)
                        states_prev = recv_states(num_recv, rank - 1)
                    _, next_states = model.row_tree.forward_train(
                        *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0],
                        *states_prev)
                    list_caches.append(states_prev)
                    if rank != rank_last:
                        send_states(next_states, rank + 1)
                    elif stage + 1 < num_stages:
                        # The last active rank hands off to rank 0 for the
                        # next stage.
                        send_states(next_states, 0)
                prev_rank_last = rank_last

            # ---- Backward sweep, stages in reverse ----
            tot_ll = torch.zeros(1).to(cmd_args.device)
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    # Turn the cached inputs into grad-tracking leaves so their
                    # gradients can be sent back upstream.
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids,
                    list_node_starts=[local_st],
                    num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    # Final block of the pipeline: nothing downstream.
                    loss.backward()
                else:
                    # Gradient w.r.t. cur_states, sent back by the rank that
                    # consumed them in the forward sweep.
                    top_grad = recv_states(
                        cur_states[0].shape[0],
                        rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states],
                                            [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)
            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)

            for param in model.parameters():
                if param.grad is None:
                    # all_reduce is collective: give untouched parameters an
                    # explicit zero gradient so every rank reduces every tensor.
                    param.grad = torch.zeros_like(param.data)
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=cmd_args.grad_clip)
            optimizer.step()

            if rank == 0:
                avg_loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' %
                                     (epoch +
                                      (e_it + 1) / cmd_args.epoch_save,
                                      avg_loss))

        if rank == 0:
            # Written once per epoch. Saving after every iteration re-serialized
            # the model to this same path, and only the final write survived.
            torch.save(
                model.state_dict(),
                os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))

    print('done rank', rank)
示例#3
0
    setup_treelib(cmd_args)

    train_graphs, stats = load_graphs(
        os.path.join(cmd_args.data_dir, 'train-graphs.pkl'),
        os.path.join(cmd_args.data_dir, 'train-graph-stats.pkl'))
    max_left = 0
    max_right = 0
    for n, m in stats:
        max_left = max(max_left, n)
        max_right = max(max_right, m)
    max_num_nodes = max(max_left, max_right)
    print('max # nodes:', max_num_nodes, 'max_left:', max_left, 'max_right:',
          max_right)
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    if cmd_args.model_dump is not None and os.path.isfile(cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))

    if cmd_args.eval_folder is not None:
        g_idx = 0
        eval_dir = os.path.join(cmd_args.data_dir,
                                '../sat-%s' % cmd_args.eval_folder)
        print('loading eval from', eval_dir)
        test_graphs, stats = load_graphs(
            os.path.join(eval_dir, 'test-graphs.pkl'),
            os.path.join(eval_dir, 'test-graph-stats.pkl'))
        out_dir = os.path.join(
            cmd_args.save_dir, '%s-pred_formulas-e-%d-g-%.2f' %
            (cmd_args.eval_folder, cmd_args.epoch_load, cmd_args.greedy_frac))
示例#4
0
import numpy as np
import random
import networkx as nx
from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen

if __name__ == '__main__':
    # Seed the Python, torch and numpy RNGs from the same configured seed.
    seed = cmd_args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    # The training set is a single 10-node Barabasi-Albert graph.
    toy_graph = nx.barabasi_albert_graph(10, 2)
    train_graphs = [toy_graph]
    TreeLib.InsertGraph(toy_graph)
    max_num_nodes = max(len(gg.nodes) for gg in train_graphs)
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    # Two optimisation steps on graph id 0, printing the normalised loss.
    for step in range(2):
        optimizer.zero_grad()
        ll, _ = model.forward_train([0])
        loss = -ll / max_num_nodes
        print('iter', step, 'loss', loss.item())
        loss.backward()
        optimizer.step()
示例#5
0
    return graphs


if __name__ == '__main__':
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    train_graphs = load_graphs(
        os.path.join(cmd_args.data_dir, 'train-graphs.pkl'))
    max_num_nodes = max([len(gg.nodes) for gg in train_graphs])
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    if cmd_args.model_dump is not None and os.path.isfile(cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))

    if cmd_args.phase != 'train':
        # get num nodes dist
        num_node_dist = get_node_dist(train_graphs)
        gt_graphs = load_graphs(
            os.path.join(cmd_args.data_dir, '%s-graphs.pkl' % cmd_args.phase))
        print('# gt graphs', len(gt_graphs))
        gen_graphs = []
        with torch.no_grad():
            for _ in tqdm(range(cmd_args.num_test_gen)):
                num_nodes = np.argmax(np.random.multinomial(1, num_node_dist))
                _, pred_edges, _ = model(num_nodes, display=cmd_args.display)