def forward_row_trees(self, graph_ids, list_node_starts=None, num_nodes=-1, list_col_ranges=None):
    TreeLib.PrepareMiniBatch(graph_ids, list_node_starts, num_nodes, list_col_ranges)
    # embed trees
    all_ids = TreeLib.PrepareTreeEmbed()

    if not self.bits_compress:
        h_bot = torch.cat([self.empty_h0, self.leaf_h0], dim=0)
        c_bot = torch.cat([self.empty_c0, self.leaf_c0], dim=0)
        fn_hc_bot = lambda d: (h_bot, c_bot)
    else:
        binary_embeds, base_feat = TreeLib.PrepareBinary()
        fn_hc_bot = lambda d: (binary_embeds[d], binary_embeds[d]) if d < len(binary_embeds) else base_feat
    max_level = len(all_ids) - 1
    h_buf_list = [None] * (len(all_ids) + 1)
    c_buf_list = [None] * (len(all_ids) + 1)

    # bottom-up pass: states at level d are built by merging child states from level d + 1
    for d in range(len(all_ids) - 1, -1, -1):
        fn_ids = lambda i: all_ids[d][i]
        if d == max_level:
            h_buf = c_buf = None
        else:
            h_buf = h_buf_list[d + 1]
            c_buf = c_buf_list[d + 1]
        h_bot, c_bot = fn_hc_bot(d + 1)
        new_h, new_c = batch_tree_lstm2(h_bot, c_bot, h_buf, c_buf, fn_ids, self.lr2p_cell)
        h_buf_list[d] = new_h
        c_buf_list[d] = new_c
    return fn_hc_bot, h_buf_list, c_buf_list
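
# A minimal illustration (not part of the library) of the bottom-up pattern in
# forward_row_trees: each level's buffer is produced by merging pairs of child
# states from the level below, then feeds the level above. The tanh merge is a
# hypothetical stand-in for batch_tree_lstm2 + self.lr2p_cell.
import torch

def _toy_bottom_up(leaf_states, num_levels):
    buf = leaf_states
    buf_list = [None] * num_levels
    for d in range(num_levels - 1, -1, -1):
        left, right = buf[0::2], buf[1::2]  # pair up siblings at level d + 1
        buf = torch.tanh(left + right)      # stand-in for the LSTM merge cell
        buf_list[d] = buf
    return buf_list

# _toy_bottom_up(torch.randn(8, 4), 3)[0].shape == (1, 4): the root state summarizes all 8 leaves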
def forward_train(self, h_bot, c_bot, h_buf0, c_buf0, prev_rowsum_h, prev_rowsum_c):
    # embed row tree
    tree_agg_ids = TreeLib.PrepareRowEmbed()
    row_embeds = [(self.init_h0, self.init_c0)]
    if h_bot is not None:
        row_embeds.append((h_bot, c_bot))
    if prev_rowsum_h is not None:
        row_embeds.append((prev_rowsum_h, prev_rowsum_c))
    if h_buf0 is not None:
        row_embeds.append((h_buf0, c_buf0))
    th_bot = h_bot
    tc_bot = c_bot
    for i, all_ids in enumerate(tree_agg_ids):
        fn_ids = lambda x: all_ids[x]
        if i:
            th_bot = tc_bot = None
        new_states = batch_tree_lstm3(th_bot, tc_bot,
                                      row_embeds[-1][0], row_embeds[-1][1],
                                      prev_rowsum_h, prev_rowsum_c,
                                      fn_ids, self.merge_cell)
        row_embeds.append(new_states)
    h_list, c_list = zip(*row_embeds)
    joint_h = torch.cat(h_list, dim=0)
    joint_c = torch.cat(c_list, dim=0)

    # get history representation
    init_select, all_ids, last_tos, next_ids, pos_info = TreeLib.PrepareRowSummary()
    cur_state = (joint_h[init_select], joint_c[init_select])
    ret_state = (joint_h[next_ids], joint_c[next_ids])
    hist_rnn_states = []
    hist_froms = []
    hist_tos = []
    for i, (done_from, done_to, proceed_from, proceed_input) in enumerate(all_ids):
        hist_froms.append(done_from)
        hist_tos.append(done_to)
        hist_rnn_states.append(cur_state)

        next_input = joint_h[proceed_input], joint_c[proceed_input]
        sub_state = cur_state[0][proceed_from], cur_state[1][proceed_from]
        cur_state = self.summary_cell(sub_state, next_input)
    hist_rnn_states.append(cur_state)
    hist_froms.append(None)
    hist_tos.append(last_tos)
    hist_h_list, hist_c_list = zip(*hist_rnn_states)

    pos_embed = self.pos_enc(pos_info)
    row_h = multi_index_select(hist_froms, hist_tos, *hist_h_list) + pos_embed
    row_c = multi_index_select(hist_froms, hist_tos, *hist_c_list) + pos_embed
    return (row_h, row_c), ret_state
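
# Hedged sketch of the multi_index_select convention relied on above: rows
# idx_froms[i] of mats[i] are scattered into rows idx_tos[i] of a single
# output, and a None entry in idx_froms means "take the whole tensor" (as in
# the final history step). This illustrates the assumed semantics only; it is
# not the library's implementation.
import torch

def _multi_index_select_sketch(idx_froms, idx_tos, *mats):
    n_out = sum(len(t) for t in idx_tos)
    out = mats[0].new_zeros(n_out, mats[0].shape[1])
    for ids_from, ids_to, mat in zip(idx_froms, idx_tos, mats):
        if ids_from is None:
            out[ids_to] = mat
        else:
            out[ids_to] = mat[ids_from]
    return out

# a, b = torch.zeros(3, 2), torch.ones(2, 2)
# _multi_index_select_sketch([[0, 2], None], [[0, 1], [2, 3]], a, b)
# -> rows 0-1 come from a[0] and a[2]; rows 2-3 are all of b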
def load_graphs(graph_pkl):
    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:  # end of the streamed pickle records
                break
            graphs.append(g)
    for g in graphs:
        TreeLib.InsertGraph(g)
    return graphs
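
# Counterpart writer for the streaming-pickle format read above (file name and
# graphs are illustrative): each graph is dumped as its own record, so the
# loader can call cp.load in a loop until EOFError.
import pickle as cp
import networkx as nx

def save_graphs(graphs, graph_pkl):
    with open(graph_pkl, 'wb') as f:
        for g in graphs:
            cp.dump(g, f, cp.HIGHEST_PROTOCOL)

# save_graphs([nx.path_graph(4), nx.star_graph(3)], 'train-graphs.pkl')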
def forward_train(self, graph_ids, list_node_starts=None,
                  num_nodes=-1, prev_rowsum_states=[None, None], list_col_ranges=None):
    fn_hc_bot, h_buf_list, c_buf_list = self.forward_row_trees(graph_ids, list_node_starts, num_nodes, list_col_ranges)
    row_states, next_states = self.row_tree.forward_train(*(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0], *prev_rowsum_states)

    # make prediction
    logit_has_edge = self.pred_has_ch(row_states[0])
    has_ch, _ = TreeLib.GetChLabel(0, dtype=bool)  # np.bool is removed in modern NumPy
    ll = self.binary_ll(logit_has_edge, has_ch)

    # has_ch_idx
    cur_states = (row_states[0][has_ch], row_states[1][has_ch])

    # top-down pass: decode left/right children level by level until only leaves remain
    lv = 0
    while True:
        is_nonleaf = TreeLib.QueryNonLeaf(lv)
        if is_nonleaf is None or np.sum(is_nonleaf) == 0:
            break
        cur_states = (cur_states[0][is_nonleaf], cur_states[1][is_nonleaf])
        left_logits = self.pred_has_left(cur_states[0], lv)
        has_left, num_left = TreeLib.GetChLabel(-1, lv)
        left_update = self.topdown_left_embed[has_left] + self.tree_pos_enc(num_left)
        left_ll, float_has_left = self.binary_ll(left_logits, has_left, need_label=True, reduction='sum')
        ll = ll + left_ll
        cur_states = self.cell_topdown(left_update, cur_states, lv)

        left_ids = TreeLib.GetLeftRootStates(lv)
        h_bot, c_bot = fn_hc_bot(lv + 1)
        if lv + 1 < len(h_buf_list):
            h_next_buf, c_next_buf = h_buf_list[lv + 1], c_buf_list[lv + 1]
        else:
            h_next_buf = c_next_buf = None
        left_subtree_states = tree_state_select(h_bot, c_bot,
                                                h_next_buf, c_next_buf,
                                                lambda: left_ids)

        has_right, num_right = TreeLib.GetChLabel(1, lv)
        right_pos = self.tree_pos_enc(num_right)
        left_subtree_states = [x + right_pos for x in left_subtree_states]
        topdown_state = self.l2r_cell(cur_states, left_subtree_states, lv)

        right_logits = self.pred_has_right(topdown_state[0], lv)
        right_update = self.topdown_right_embed[has_right]
        topdown_state = self.cell_topright(right_update, topdown_state, lv)
        # a right child only contributes likelihood where a left child exists
        right_ll = self.binary_ll(right_logits, has_right, reduction='none') * float_has_left
        ll = ll + torch.sum(right_ll)

        lr_ids = TreeLib.GetLeftRightSelect(lv, np.sum(has_left), np.sum(has_right))
        new_states = []
        for i in range(2):
            new_s = multi_index_select([lr_ids[0], lr_ids[2]], [lr_ids[1], lr_ids[3]],
                                       cur_states[i], topdown_state[i])
            new_states.append(new_s)
        cur_states = tuple(new_states)
        lv += 1
    return ll, next_states
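
# Hedged sketch of the binary_ll helper used throughout forward_train, assuming
# it returns the Bernoulli log-likelihood of the 0/1 labels under the given
# logits (the negated binary cross-entropy), optionally with the float labels.
# The signature is inferred from the call sites above; treat this as an
# approximation, not the repository's exact implementation.
import torch
import torch.nn.functional as F

def _binary_ll_sketch(pred_logits, np_label, need_label=False, reduction='sum'):
    pred_logits = pred_logits.view(-1, 1)
    label = torch.as_tensor(np_label, dtype=torch.float32,
                            device=pred_logits.device).view(-1, 1)
    ll = -F.binary_cross_entropy_with_logits(pred_logits, label, reduction=reduction)
    if need_label:
        return ll, label
    return ll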
def load_graphs(graph_pkl, stats_pkl):
    graphs = []
    with open(graph_pkl, 'rb') as f:
        while True:
            try:
                g = cp.load(f)
            except EOFError:  # end of the streamed pickle records
                break
            graphs.append(g)
    with open(stats_pkl, 'rb') as f:
        stats = cp.load(f)
    assert len(graphs) == len(stats)
    for g, stat in zip(graphs, stats):
        n, m = stat
        TreeLib.InsertGraph(g, bipart_stats=(n, m))
    return graphs, stats
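
# Illustrative writer for the paired files read above (names are assumptions):
# graphs go into the streaming pickle one record at a time, while the (n, m)
# bipartite sizes are written as a single list in the stats file.
import pickle as cp

def save_bipartite(graphs, sizes, graph_pkl, stats_pkl):
    assert len(graphs) == len(sizes)
    with open(graph_pkl, 'wb') as f:
        for g in graphs:
            cp.dump(g, f, cp.HIGHEST_PROTOCOL)
    with open(stats_pkl, 'wb') as f:
        cp.dump(sizes, f, cp.HIGHEST_PROTOCOL)  # list of (n, m) tuples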
def main(rank):
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    if rank == 0 and cmd_args.model_dump is not None and os.path.isfile(cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))
    optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)
    # make sure every worker starts from rank 0's parameters
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    graphs = []
    with open(os.path.join(cmd_args.data_dir, 'train-graphs.pkl'), 'rb') as f:
        while True:
            try:
                g = cp.load(f)
                TreeLib.InsertGraph(g)
            except EOFError:  # end of the streamed pickle records
                break
            graphs.append(g)

    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            pbar = tqdm(pbar)
        g = graphs[0]
        graph_ids = [0]
        blksize = cmd_args.blksize
        if blksize < 0 or blksize > len(g):
            blksize = len(g)

        for e_it in pbar:
            optimizer.zero_grad()
            num_stages = len(g) // (blksize * cmd_args.num_proc) + (len(g) % (blksize * cmd_args.num_proc) > 0)

            # forward pass over the pipeline stages, caching boundary states without autograd
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids, list_node_starts=[local_st], num_nodes=local_num)
                    if stage and rank == 0:
                        states_prev = recv_states(num_expect(local_st), prev_rank_last)
                    if rank:
                        num_recv = num_expect(local_st)
                        states_prev = recv_states(num_recv, rank - 1)
                    _, next_states = model.row_tree.forward_train(
                        *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0], *states_prev)
                    list_caches.append(states_prev)
                    if rank != rank_last:
                        send_states(next_states, rank + 1)
                    elif stage + 1 < num_stages:
                        send_states(next_states, 0)
                prev_rank_last = rank_last

            # backward pass: replay the stages in reverse with gradients enabled
            tot_ll = torch.zeros(1).to(cmd_args.device)
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids, list_node_starts=[local_st], num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    loss.backward()
                else:
                    # receive the gradient of the downstream stage w.r.t. our boundary states
                    top_grad = recv_states(cur_states[0].shape[0],
                                           rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states], [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)

            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)
            for param in model.parameters():
                if param.grad is None:
                    param.grad = param.data.new(param.data.shape).zero_()
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cmd_args.grad_clip)
            optimizer.step()
            if rank == 0:
                loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' % (epoch + (e_it + 1) / cmd_args.epoch_save, loss))
        if rank == 0:
            # save once per epoch instead of once per inner iteration
            torch.save(model.state_dict(), os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))
    print('done rank', rank)
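
# Hedged sketch of the send_states / recv_states helpers the pipeline above
# relies on, assuming they move an (h, c) pair between adjacent stages with
# blocking point-to-point transfers. The hidden_dim argument and the CPU
# staging are assumptions, not the repository's actual implementation.
import torch
import torch.distributed as dist

def _send_states_sketch(states, dst):
    for x in states:
        dist.send(x.detach().cpu(), dst=dst)

def _recv_states_sketch(num_rows, src, hidden_dim, device):
    states = []
    for _ in range(2):  # one buffer each for h and c
        buf = torch.zeros(num_rows, hidden_dim)
        dist.recv(buf, src=src)
        states.append(buf.to(device))
    return states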
import numpy as np
import random
import networkx as nx
import torch
import torch.optim as optim

from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen


if __name__ == '__main__':
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    # a single toy graph is enough to smoke-test the training loop
    train_graphs = [nx.barabasi_albert_graph(10, 2)]
    TreeLib.InsertGraph(train_graphs[0])

    max_num_nodes = max([len(gg.nodes) for gg in train_graphs])
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)

    for i in range(2):
        optimizer.zero_grad()
        ll, _ = model.forward_train([0])
        loss = -ll / max_num_nodes
        print('iter', i, 'loss', loss.item())
        loss.backward()
        optimizer.step()
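
    # Optional follow-up (a sketch, not part of the original script): re-score
    # the same graph without gradients to check that the two updates reduced
    # the loss. Uses only forward_train, already exercised in the loop above.
    with torch.no_grad():
        ll, _ = model.forward_train([0])
        print('final loss', (-ll / max_num_nodes).item())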