import os
import pickle as cp

import torch
import torch.optim as optim
import torch.distributed as dist
from tqdm import tqdm

from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen


def main(rank):
    if cmd_args.gpu >= 0:
        set_device(rank)
    else:
        set_device(-1)
    setup_treelib(cmd_args)

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    if rank == 0 and cmd_args.model_dump is not None and os.path.isfile(cmd_args.model_dump):
        print('loading from', cmd_args.model_dump)
        model.load_state_dict(torch.load(cmd_args.model_dump))
    optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)

    # Make sure every worker starts from the same parameters.
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    graphs = []
    with open(os.path.join(cmd_args.data_dir, 'train-graphs.pkl'), 'rb') as f:
        while True:
            try:
                g = cp.load(f)
                TreeLib.InsertGraph(g)
            except:  # stop once the pickle stream is exhausted
                break
            graphs.append(g)

    for epoch in range(cmd_args.num_epochs):
        pbar = range(cmd_args.epoch_save)
        if rank == 0:
            pbar = tqdm(pbar)
        g = graphs[0]
        graph_ids = [0]
        blksize = cmd_args.blksize
        if blksize < 0 or blksize > len(g):
            blksize = len(g)

        for e_it in pbar:
            optimizer.zero_grad()
            # Number of pipeline stages needed to cover all rows of the graph.
            num_stages = len(g) // (blksize * cmd_args.num_proc) + (
                len(g) % (blksize * cmd_args.num_proc) > 0)
            states_prev = [None, None]
            list_caches = []
            prev_rank_last = None

            # Forward pass: each rank processes its block without building a graph,
            # then ships the row-summary states to the next rank in the pipeline.
            for stage in range(num_stages):
                local_st, local_num, rank_last = get_stage_stats(stage, blksize, rank, g)
                if local_num <= 0:
                    break
                with torch.no_grad():
                    fn_hc_bot, h_buf_list, c_buf_list = model.forward_row_trees(
                        graph_ids, list_node_starts=[local_st], num_nodes=local_num)
                if stage and rank == 0:
                    states_prev = recv_states(num_expect(local_st), prev_rank_last)
                if rank:
                    num_recv = num_expect(local_st)
                    states_prev = recv_states(num_recv, rank - 1)
                _, next_states = model.row_tree.forward_train(
                    *(fn_hc_bot(0)), h_buf_list[0], c_buf_list[0], *states_prev)
                list_caches.append(states_prev)
                if rank != rank_last:
                    # Hand the states to the next rank in this stage.
                    send_states(next_states, rank + 1)
                elif stage + 1 < num_stages:
                    # Last rank of the stage wraps around to rank 0 for the next stage.
                    send_states(next_states, 0)
                prev_rank_last = rank_last

            # Backward pass in reverse stage order, sending the gradients of the
            # cached input states back to the previous rank.
            tot_ll = torch.zeros(1).to(cmd_args.device)
            for stage in range(num_stages - 1, -1, -1):
                local_st, local_num, rank_last = get_stage_stats(stage, blksize, rank, g)
                if local_num <= 0:
                    continue
                prev_states = list_caches[stage]
                if prev_states[0] is not None:
                    for x in prev_states:
                        x.requires_grad = True
                ll, cur_states = model.forward_train(
                    graph_ids, list_node_starts=[local_st], num_nodes=local_num,
                    prev_rowsum_states=prev_states)
                tot_ll = tot_ll + ll
                loss = -ll / len(g)
                if stage + 1 == num_stages and rank == rank_last:
                    loss.backward()
                else:
                    top_grad = recv_states(cur_states[0].shape[0],
                                           rank + 1 if rank != rank_last else 0)
                    torch.autograd.backward([loss, *cur_states], [None, *top_grad])
                if prev_states[0] is not None:
                    grads = [x.grad.detach() for x in prev_states]
                    dst = rank - 1 if rank else cmd_args.num_proc - 1
                    send_states(grads, dst)

            # Aggregate the log-likelihood and the gradients across all workers.
            dist.all_reduce(tot_ll.data, op=dist.ReduceOp.SUM)
            for param in model.parameters():
                if param.grad is None:
                    param.grad = param.data.new(param.data.shape).zero_()
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            if cmd_args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cmd_args.grad_clip)
            optimizer.step()

            if rank == 0:
                loss = -tot_ll.item() / len(g)
                pbar.set_description('epoch %.2f, loss: %.4f' % (
                    epoch + (e_it + 1) / cmd_args.epoch_save, loss))
                torch.save(model.state_dict(),
                           os.path.join(cmd_args.save_dir, 'epoch-%d.ckpt' % epoch))

    print('done rank', rank)
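# The distributed loop above relies on four helpers that are not shown here:
# get_stage_stats, num_expect, send_states and recv_states.  The sketch below is
# NOT the original implementation; it only illustrates one way the point-to-point
# state exchange could be done with torch.distributed primitives, assuming the
# exchanged states are two tensors of shape [num_rows, cmd_args.embed_dim].

def send_states(states, dst):
    # Ship each row-summary state tensor to rank `dst`.
    for x in states:
        dist.send(x.detach().contiguous(), dst=dst)


def recv_states(num_rows, src):
    # Receive the matching hidden/cell state tensors from rank `src`.
    states = []
    for _ in range(2):  # hidden state, then cell state (assumed layout)
        buf = torch.zeros(num_rows, cmd_args.embed_dim).to(cmd_args.device)
        dist.recv(buf, src=src)
        states.append(buf)
    return states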
import random

import numpy as np
import networkx as nx
import torch
import torch.optim as optim

from bigg.common.configs import cmd_args, set_device
from bigg.model.tree_clib.tree_lib import setup_treelib, TreeLib
from bigg.model.tree_model import RecurTreeGen


if __name__ == '__main__':
    random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    set_device(cmd_args.gpu)
    setup_treelib(cmd_args)

    # Toy example: overfit a single Barabasi-Albert graph.
    train_graphs = [nx.barabasi_albert_graph(10, 2)]
    TreeLib.InsertGraph(train_graphs[0])

    max_num_nodes = max([len(gg.nodes) for gg in train_graphs])
    cmd_args.max_num_nodes = max_num_nodes

    model = RecurTreeGen(cmd_args).to(cmd_args.device)
    optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)

    for i in range(2):
        optimizer.zero_grad()
        ll, _ = model.forward_train([0])
        loss = -ll / max_num_nodes
        print('iter', i, 'loss', loss.item())
        loss.backward()
        optimizer.step()
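    # Optional follow-up (not in the original snippet): persist the toy model in the
    # same state_dict format that main() above loads through cmd_args.model_dump.
    # The file name 'toy-model.ckpt' is only an illustration.
    torch.save(model.state_dict(), 'toy-model.ckpt')
    model.load_state_dict(torch.load('toy-model.ckpt'))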
optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)

indices = list(range(len(train_graphs)))
for epoch in range(cmd_args.num_epochs):
    pbar = tqdm(range(cmd_args.epoch_save))
    optimizer.zero_grad()
    for idx in pbar:
        random.shuffle(indices)
        batch_indices = indices[:cmd_args.batch_size]
        num_nodes = sum([len(train_graphs[i]) for i in batch_indices])

        if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
            # The whole batch fits in one block: single forward/backward pass.
            ll, _ = model.forward_train(
                batch_indices,
                list_col_ranges=[(0, stats[i][1]) for i in batch_indices])
            loss = -ll / num_nodes
            loss.backward()
            loss = loss.item()
        else:
            # Batch too large: fall back to the blockwise (sqrt(n) memory)
            # forward/backward per graph.
            ll = 0.0
            for i in batch_indices:
                n = len(train_graphs[i])
                cur_ll, _ = sqrtn_forward_backward(
                    model, graph_ids=[i], list_node_starts=[0],
                    num_nodes=stats[i][0], blksize=cmd_args.blksize,
                    loss_scale=1.0 / n)
                ll += cur_ll
optimizer = optim.Adam(model.parameters(), lr=cmd_args.learning_rate, weight_decay=1e-4)

indices = list(range(len(train_graphs)))
if cmd_args.epoch_load is None:
    cmd_args.epoch_load = 0
for epoch in range(cmd_args.epoch_load, cmd_args.num_epochs):
    pbar = tqdm(range(cmd_args.epoch_save))
    optimizer.zero_grad()
    for idx in pbar:
        random.shuffle(indices)
        batch_indices = indices[:cmd_args.batch_size]
        num_nodes = sum([len(train_graphs[i]) for i in batch_indices])

        if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
            # Whole batch fits in one block.
            ll, _ = model.forward_train(batch_indices)
            loss = -ll / num_nodes
            loss.backward()
            loss = loss.item()
        else:
            # Otherwise process each graph with the blockwise (sqrt(n) memory)
            # forward/backward; loss_scale normalizes each graph's gradient
            # contribution by its own node count.
            ll = 0.0
            for i in batch_indices:
                n = len(train_graphs[i])
                cur_ll, _ = sqrtn_forward_backward(
                    model, graph_ids=[i], list_node_starts=[0],
                    num_nodes=n, blksize=cmd_args.blksize,
                    loss_scale=1.0 / n)
                ll += cur_ll
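# Illustration only (independent of bigg's sqrtn_forward_backward): accumulating
# per-block backward passes with a shared loss scale yields the same parameter
# gradients as one backward pass over the full, normalized loss.  The blockwise
# routine above additionally saves memory by recomputing activations per block.
import torch

w = torch.randn(4, requires_grad=True)
x = torch.randn(10, 4)

full_loss = (x @ w).pow(2).sum() / len(x)   # loss over all rows at once
full_loss.backward()
g_full = w.grad.clone()

w.grad = None
for blk in x.split(5):                      # two blocks of 5 rows each
    ((blk @ w).pow(2).sum() * (1.0 / len(x))).backward()

print(torch.allclose(g_full, w.grad))       # True: blockwise grads match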