    def forward_row_trees(self, graph_ids, list_node_starts=None, num_nodes=-1, list_col_ranges=None):
        TreeLib.PrepareMiniBatch(graph_ids, list_node_starts, num_nodes, list_col_ranges)
        # embed trees
        all_ids = TreeLib.PrepareTreeEmbed()

        if not self.bits_compress:
            h_bot = torch.cat([self.empty_h0, self.leaf_h0], dim=0)
            c_bot = torch.cat([self.empty_c0, self.leaf_c0], dim=0)
            fn_hc_bot = lambda d: (h_bot, c_bot)
        else:
            binary_embeds, base_feat = TreeLib.PrepareBinary()
            fn_hc_bot = lambda d: (binary_embeds[d], binary_embeds[d]) if d < len(binary_embeds) else base_feat
        max_level = len(all_ids) - 1
        h_buf_list = [None] * (len(all_ids) + 1)
        c_buf_list = [None] * (len(all_ids) + 1)

        for d in range(len(all_ids) - 1, -1, -1):
            fn_ids = lambda i: all_ids[d][i]
            if d == max_level:
                h_buf = c_buf = None
            else:
                h_buf = h_buf_list[d + 1]
                c_buf = c_buf_list[d + 1]
            h_bot, c_bot = fn_hc_bot(d + 1)
            new_h, new_c = batch_tree_lstm2(h_bot, c_bot, h_buf, c_buf, fn_ids, self.lr2p_cell)
            h_buf_list[d] = new_h
            c_buf_list[d] = new_c
        return fn_hc_bot, h_buf_list, c_buf_list
    def forward_train(self, h_bot, c_bot, h_buf0, c_buf0, prev_rowsum_h,
                      prev_rowsum_c):
        # embed row tree
        tree_agg_ids = TreeLib.PrepareRowEmbed()
        row_embeds = [(self.init_h0, self.init_c0)]
        if h_bot is not None:
            row_embeds.append((h_bot, c_bot))
        if prev_rowsum_h is not None:
            row_embeds.append((prev_rowsum_h, prev_rowsum_c))
        if h_buf0 is not None:
            row_embeds.append((h_buf0, c_buf0))

        th_bot = h_bot
        tc_bot = c_bot
        for i, all_ids in enumerate(tree_agg_ids):
            fn_ids = lambda x: all_ids[x]
            if i:
                th_bot = tc_bot = None

            new_states = batch_tree_lstm3(th_bot, tc_bot, row_embeds[-1][0],
                                          row_embeds[-1][1], prev_rowsum_h,
                                          prev_rowsum_c, fn_ids,
                                          self.merge_cell)
            row_embeds.append(new_states)
        h_list, c_list = zip(*row_embeds)
        joint_h = torch.cat(h_list, dim=0)
        joint_c = torch.cat(c_list, dim=0)

        # get history representation
        init_select, all_ids, last_tos, next_ids, pos_info = TreeLib.PrepareRowSummary()
        cur_state = (joint_h[init_select], joint_c[init_select])
        ret_state = (joint_h[next_ids], joint_c[next_ids])
        hist_rnn_states = []
        hist_froms = []
        hist_tos = []
        for i, (done_from, done_to, proceed_from,
                proceed_input) in enumerate(all_ids):
            hist_froms.append(done_from)
            hist_tos.append(done_to)
            hist_rnn_states.append(cur_state)

            next_input = joint_h[proceed_input], joint_c[proceed_input]
            sub_state = cur_state[0][proceed_from], cur_state[1][proceed_from]
            cur_state = self.summary_cell(sub_state, next_input)
        hist_rnn_states.append(cur_state)
        hist_froms.append(None)
        hist_tos.append(last_tos)
        hist_h_list, hist_c_list = zip(*hist_rnn_states)
        pos_embed = self.pos_enc(pos_info)
        row_h = multi_index_select(hist_froms, hist_tos, *hist_h_list) + pos_embed
        row_c = multi_index_select(hist_froms, hist_tos, *hist_c_list) + pos_embed
        return (row_h, row_c), ret_state
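The history loop above depends on a multi_index_select helper that this snippet does not define. A minimal sketch of the semantics implied by its call sites, here and in Example 3 below (gather rows from each source tensor, passing the tensor through when its index list is None, and scatter them into one output), offered as an assumption rather than the library's actual implementation:

import torch

def multi_index_select_sketch(froms, tos, *mats):
    # Sketch only: inferred from the call sites, not the bigg source. Assumes
    # the destination index arrays in `tos` exactly partition the output rows.
    num_rows = sum(len(t) for t in tos)
    out = mats[0].new_zeros(num_rows, mats[0].shape[-1])
    for frm, to, mat in zip(froms, tos, mats):
        rows = mat if frm is None else mat[frm]  # frm=None passes mat through
        out[torch.as_tensor(to, dtype=torch.long)] = rows
    return out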
Example 3
import pickle as cp  # cp is assumed to be the pickle module
from bigg.model.tree_clib.tree_lib import TreeLib

def load_graphs(graph_pkl):
    # Graphs are pickled one per dump() call; EOFError marks end of stream.
    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                break
            graphs.append(g)
    for g in graphs:
        TreeLib.InsertGraph(g)
    return graphs
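load_graphs reads a stream of individually pickled graphs, so the matching writer (a sketch; the file name and graphs are illustrative) appends one dump per graph to the same file:

import pickle as cp
import networkx as nx

# Hypothetical writer matching load_graphs: one dump() call per graph, so the
# reader's EOFError cleanly terminates the loop.
with open('train-graphs.pkl', 'wb') as f:
    for g in [nx.path_graph(4), nx.cycle_graph(5)]:
        cp.dump(g, f, protocol=cp.HIGHEST_PROTOCOL)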
    def forward_train(self, graph_ids, list_node_starts=None, num_nodes=-1, prev_rowsum_states=[None, None], list_col_ranges=None):
        fn_hc_bot, h_buf_list, c_buf_list = self.forward_row_trees(graph_ids, list_node_starts, num_nodes, list_col_ranges)
        row_states, next_states = self.row_tree.forward_train(*(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0], *prev_rowsum_states)

        # make prediction
        logit_has_edge = self.pred_has_ch(row_states[0])
        has_ch, _ = TreeLib.GetChLabel(0, dtype=bool)  # np.bool is removed in modern NumPy
        ll = self.binary_ll(logit_has_edge, has_ch)
        # has_ch_idx
        cur_states = (row_states[0][has_ch], row_states[1][has_ch])

        lv = 0
        while True:
            is_nonleaf = TreeLib.QueryNonLeaf(lv)
            if is_nonleaf is None or np.sum(is_nonleaf) == 0:
                break
            cur_states = (cur_states[0][is_nonleaf], cur_states[1][is_nonleaf])
            left_logits = self.pred_has_left(cur_states[0], lv)
            has_left, num_left = TreeLib.GetChLabel(-1, lv)
            left_update = self.topdown_left_embed[has_left] + self.tree_pos_enc(num_left)
            left_ll, float_has_left = self.binary_ll(left_logits, has_left, need_label=True, reduction='sum')
            ll = ll + left_ll

            cur_states = self.cell_topdown(left_update, cur_states, lv)

            left_ids = TreeLib.GetLeftRootStates(lv)
            h_bot, c_bot = fn_hc_bot(lv + 1)
            if lv + 1 < len(h_buf_list):
                h_next_buf, c_next_buf = h_buf_list[lv + 1], c_buf_list[lv + 1]
            else:
                h_next_buf = c_next_buf = None
            left_subtree_states = tree_state_select(h_bot, c_bot,
                                                    h_next_buf, c_next_buf,
                                                    lambda: left_ids)

            has_right, num_right = TreeLib.GetChLabel(1, lv)
            right_pos = self.tree_pos_enc(num_right)
            left_subtree_states = [x + right_pos for x in left_subtree_states]
            topdown_state = self.l2r_cell(cur_states, left_subtree_states, lv)

            right_logits = self.pred_has_right(topdown_state[0], lv)
            right_update = self.topdown_right_embed[has_right]
            topdown_state = self.cell_topright(right_update, topdown_state, lv)
            right_ll = self.binary_ll(right_logits, has_right, reduction='none') * float_has_left
            ll = ll + torch.sum(right_ll)
            lr_ids = TreeLib.GetLeftRightSelect(lv, np.sum(has_left), np.sum(has_right))
            new_states = []
            for i in range(2):
                new_s = multi_index_select([lr_ids[0], lr_ids[2]], [lr_ids[1], lr_ids[3]],
                                            cur_states[i], topdown_state[i])
                new_states.append(new_s)
            cur_states = tuple(new_states)
            lv += 1

        return ll, next_states
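forward_train also relies on a binary_ll helper not defined in this snippet. A minimal sketch consistent with its call sites (scalar 'sum' or per-entry 'none' reductions, optionally returning the float labels): the Bernoulli log-likelihood of the 0/1 labels under the predicted logits, i.e. the negated binary cross-entropy. This is an assumption about the interface, not the bigg implementation.

import torch
import torch.nn.functional as F

def binary_ll_sketch(logits, labels, need_label=False, reduction='sum'):
    # Sketch only: signature inferred from the calls above.
    float_labels = torch.as_tensor(labels, dtype=logits.dtype,
                                   device=logits.device).view(logits.shape)
    # log p(label | logit) = -BCE-with-logits; 'none' keeps per-entry values.
    ll = -F.binary_cross_entropy_with_logits(logits, float_labels,
                                             reduction=reduction)
    if need_label:
        return ll, float_labels
    return ll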
Example 5
def load_graphs(graph_pkl, stats_pkl):
    # cp is the pickle module (see Example 3); one pickled graph per dump().
    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                break
            graphs.append(g)
    with open(stats_pkl, 'rb') as f:
        stats = cp.load(f)
    assert len(graphs) == len(stats)

    for g, stat in zip(graphs, stats):
        n, m = stat
        TreeLib.InsertGraph(g, bipart_stats=(n, m))
    return graphs, stats
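The graphs pickle is streamed one graph per dump as before, while the stats pickle is read with a single cp.load, so it is assumed to hold all (n, m) bipartite size pairs in one dump. A hypothetical writer:

import pickle as cp

# Hypothetical: one (n, m) pair per graph, all written in a single dump so
# that one cp.load(f) recovers the whole list.
stats = [(3, 4), (5, 2)]
with open('stats.pkl', 'wb') as f:
    cp.dump(stats, f, protocol=cp.HIGHEST_PROTOCOL)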
Example 6
import os
import numpy as np
import pickle as cp  # assumed alias for the pickle module
import torch
import torch.optim as optim
import torch.distributed as dist
from tqdm import tqdm
from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen

def main(rank):
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)
    model = RecurTreeGen(cmd_args).to(cmd_args.device)

    if rank == 0 and cmd_args.model_dump is not None and os.path.isfile(
            cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    graphs = []
    with open(os.path.join(cmd_args.data_dir, 'train-graphs.pkl'), 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:
                break
            TreeLib.InsertGraph(g)
            graphs.append(g)

    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            pbar = tqdm(pbar)

        g = graphs[0]
        graph_ids = [0]
        blksize = cmd_args.blksize
        if blksize < 0 or blksize > len(g):
            blksize = len(g)

        for e_it in pbar:
            optimizer.zero_grad()
            num_stages = len(g) // (blksize * cmd_args.num_proc) + (
                len(g) % (blksize * cmd_args.num_proc) > 0)
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids,
                        list_node_starts=[local_st],
                        num_nodes=local_num)
                    if stage and rank == 0:
                        states_prev = recv_states(num_expect(local_st),
                                                  prev_rank_last)
                    if rank:
                        num_recv = num_expect(local_st)
                        states_prev = recv_states(num_recv, rank - 1)
                    _, next_states = model.row_tree.forward_train(
                        *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0],
                        *states_prev)
                    list_caches.append(states_prev)
                    if rank != rank_last:
                        send_states(next_states, rank + 1)
                    elif stage + 1 < num_stages:
                        send_states(next_states, 0)
                prev_rank_last = rank_last

            tot_ll = torch.zeros(1).to(cmd_args.device)
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(
                    stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids,
                    list_node_starts=[local_st],
                    num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    loss.backward()
                else:
                    top_grad = recv_states(
                        cur_states[0].shape[0],
                        rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states],
                                            [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)
            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)

            for param in model.parameters():
                if param.grad is None:
                    param.grad = param.data.new(param.data.shape).zero_()
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=cmd_args.grad_clip)
            optimizer.step()

            if rank == 0:
                loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' %
                                     (epoch +
                                      (e_it + 1) / cmd_args.epoch_save, loss))
                torch.save(
                    model.state_dict(),
                    os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))

    print('done rank', rank)
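main() assumes several pipeline helpers (get_stage_stats, num_expect, send_states, recv_states) defined elsewhere. As one illustration, the row partitioning that get_stage_stats appears to implement can be sketched from its call sites; the signature and block layout below are assumptions, not the actual helper:

def get_stage_stats_sketch(stage, blksize, rank, g, num_proc=1):
    # Stage s covers num_proc consecutive blocks of blksize rows of graph g,
    # and rank r owns block (s * num_proc + r).
    local_st = (stage * num_proc + rank) * blksize
    local_num = min(blksize, len(g) - local_st)  # <= 0 once past the last row
    # rank_last: highest rank that still receives rows in this stage.
    rows_left = len(g) - stage * num_proc * blksize
    rank_last = min(num_proc, (rows_left + blksize - 1) // blksize) - 1
    return local_st, local_num, rank_last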
Example 7
import numpy as np
import random
import torch
import torch.optim as optim
import networkx as nx
from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen

if __name__ == '__main__':
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    train_graphs = [nx.barabasi_albert_graph(10, 2)]
    TreeLib.InsertGraph(train_graphs[0])
    max_num_nodes = max([len(gg.nodes) for gg in train_graphs])
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    optimizer = optim.Adam(model.parameters(),
                           lr=cmd_args.learning_rate,
                           weight_decay=1e-4)
    for i in range(2):
        optimizer.zero_grad()
        ll, _ = model.forward_train([0])
        loss = -ll / max_num_nodes
        print('iter', i, 'loss', loss.item())
        loss.backward()
        optimizer.step()
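After the overfitting loop, the trained weights can be persisted the same way main() does in Example 6 (the checkpoint path here is illustrative):

torch.save(model.state_dict(), 'ba-overfit.ckpt')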