Exemple #1
0
def setup_dist(rank):
    """Prepare one worker process: pick its device, init treelib, seed RNGs.

    Args:
        rank: index of this process. It is used as the device id when
            ``cmd_args.gpu >= 0`` and is added to ``cmd_args.seed`` so that
            every rank gets a different random seed.
    """
    # A negative ``cmd_args.gpu`` means "no GPU"; ``set_device`` is then
    # called with -1 instead of the rank.
    device_id = rank if cmd_args.gpu >= 0 else -1
    set_device(device_id)
    setup_treelib(cmd_args)

    # Same seed offset for all three random number generators.
    rank_seed = cmd_args.seed + rank
    random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
def main(rank):
    """Train ``RecurTreeGen`` on a single graph split across processes.

    Outline of what this function does:

    1. Select the device for this rank and initialise treelib.
    2. Build the model; rank 0 optionally restores ``cmd_args.model_dump``,
       then every parameter is broadcast from rank 0.
    3. Read all graphs from ``<data_dir>/train-graphs.pkl`` and register them
       with ``TreeLib``. Only ``graphs[0]`` is used for training.
    4. For each iteration, run a gradient-free forward sweep over the stages
       to obtain the incoming states of every local block, then a second
       sweep in reverse stage order that recomputes each block with autograd
       and exchanges state gradients with neighbouring ranks.
    5. Sum parameter gradients across ranks, optionally clip, and step.
       Rank 0 reports the loss and writes a checkpoint after every iteration.

    Args:
        rank: index of this process; also used as the device id when
            ``cmd_args.gpu >= 0``.

    Raises:
        IndexError: if the graph file contains no graphs (``graphs[0]``).
        Exception: any error other than ``EOFError`` raised while reading
            the graph file or inserting a graph is propagated.
    """
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)
    model = RecurTreeGen(cmd_args).to(cmd_args.device)

    # Only rank 0 reads the checkpoint; the broadcast below copies its
    # parameters to every other rank.
    if rank == 0 and cmd_args.model_dump is not None and os.path.isfile(
            cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    # The file holds a sequence of objects dumped back to back, so keep
    # loading until the end of the stream.
    # NOTE(review): `cp` looks like a pickle module -- only load trusted
    # files with it.
    graphs = []
    with open(os.path.join(cmd_args.data_dir, 'train-graphs.pkl'), 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                # Normal termination: no more objects in the file. Any other
                # error (corrupt data, interrupt) is deliberately not caught
                # so that training never runs on a silently truncated list.
                break
            TreeLib.InsertGraph(g)
            graphs.append(g)

    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            # Only rank 0 shows a progress bar.
            pbar = tqdm(pbar)

        # Training uses the first graph only.
        g = graphs[0]
        graph_ids = [0]
        blksize = cmd_args.blksize
        # A negative or too-large block size falls back to the whole graph.
        if blksize < 0 or blksize > len(g):
            blksize = len(g)

        for e_it in pbar:
            optimizer.zero_grad()
            # Ceiling division: each stage covers blksize * num_proc nodes.
            num_stages = len(g) // (blksize * cmd_args.num_proc) + (
                len(g) % (blksize * cmd_args.num_proc) > 0)
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None
            # Forward sweep without autograd: compute the states entering
            # each local block and cache them per stage.
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids,
                        list_node_starts=[local_st],
                        num_nodes=local_num)
                    # Rank 0 takes its input states from the last rank of the
                    # previous stage; other ranks take them from rank - 1.
                    if stage and rank == 0:
                        states_prev = recv_states(num_expect(local_st),
                                                  prev_rank_last)
                    if rank:
                        num_recv = num_expect(local_st)
                        states_prev = recv_states(num_recv, rank - 1)
                    _, next_states = model.row_tree.forward_train(
                        *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0],
                        *states_prev)
                    list_caches.append(states_prev)
                    # Hand the outgoing states to the next rank, or back to
                    # rank 0 when this rank closes a non-final stage.
                    if rank != rank_last:
                        send_states(next_states, rank + 1)
                    elif stage + 1 < num_stages:
                        send_states(next_states, 0)
                prev_rank_last = rank_last

            tot_ll = torch.zeros(1).to(cmd_args.device)
            # Backward sweep in reverse stage order: recompute each block
            # with autograd from its cached input states.
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    # Track gradients w.r.t. the received states so they can
                    # be sent back to the rank that produced them.
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids,
                    list_node_starts=[local_st],
                    num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    # Last block overall: nothing downstream sends gradients.
                    loss.backward()
                else:
                    # Gradients of the outgoing states arrive from the rank
                    # that consumed them in the forward sweep.
                    top_grad = recv_states(
                        cur_states[0].shape[0],
                        rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states],
                                            [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)
            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)

            # Every rank must take part in each all_reduce, so parameters
            # that received no gradient locally get a zero gradient first.
            for param in model.parameters():
                if param.grad is None:
                    param.grad = param.data.new(param.data.shape).zero_()
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=cmd_args.grad_clip)
            optimizer.step()

            if rank == 0:
                loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' %
                                     (epoch +
                                      (e_it + 1) / cmd_args.epoch_save, loss))
                # The checkpoint for the current epoch is overwritten after
                # every iteration.
                torch.save(
                    model.state_dict(),
                    os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))

    print('done rank', rank)
Exemple #3
0
def get_ordered_edges(g, offset):
    """Return the edges of ``g`` as ``(low, high - offset)`` pairs.

    Each edge from ``g.edges()`` is oriented so that its smaller endpoint
    comes first, ``offset`` is subtracted from the larger endpoint, and the
    resulting pairs are sorted by their first element only. The sort is
    stable, so pairs sharing a first element keep the order in which
    ``g.edges()`` yielded them.

    Args:
        g: object exposing ``edges()`` that yields indexable edges.
        offset: value subtracted from the second endpoint of every edge.

    Returns:
        A list of 2-tuples sorted by their first element.
    """
    oriented = [
        (edge[1], edge[0]) if edge[0] > edge[1] else (edge[0], edge[1])
        for edge in g.edges()
    ]
    shifted = [(low, high - offset) for low, high in oriented]
    return sorted(shifted, key=lambda pair: pair[0])


if __name__ == '__main__':
    # Seed all three random number generators from the same configured seed.
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    # Load the training graphs together with their statistics file.
    train_graphs, stats = load_graphs(
        os.path.join(cmd_args.data_dir, 'train-graphs.pkl'),
        os.path.join(cmd_args.data_dir, 'train-graph-stats.pkl'))
    # Each entry of `stats` unpacks into a pair (n, m); track the maximum of
    # each component separately.
    # NOTE(review): the left/right naming suggests the two sides of a
    # bipartite graph -- confirm against load_graphs.
    max_left = 0
    max_right = 0
    for n, m in stats:
        max_left = max(max_left, n)
        max_right = max(max_right, m)
    # The larger of the two maxima is stored on cmd_args as max_num_nodes.
    max_num_nodes = max(max_left, max_right)
    print('max # nodes:', max_num_nodes, 'max_left:', max_left, 'max_right:',
          max_right)
    cmd_args.max_num_nodes = max_num_nodes