# NOTE(review): this span arrived whitespace-mangled and truncated at both
# ends: the opening `if` header and the tail of the gradient-accumulation
# branch were cut off. Both are reconstructed from the byte-parallel training
# step elsewhere in this file — confirm against version control.
#
# Training step that restricts the likelihood computation of each graph i to
# columns [0, stats[i][1]) (stats[i] presumably is (num_nodes, num_cols) for
# graph i — TODO confirm against where `stats` is built).
if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
    # Whole batch fits in one block: single forward/backward over all graphs.
    ll, _ = model.forward_train(
        batch_indices,
        list_col_ranges=[(0, stats[i][1]) for i in batch_indices])
    loss = -ll / num_nodes
    loss.backward()
    loss = loss.item()  # keep only the scalar for logging
else:
    # Batch too large for one pass: run each graph separately through the
    # sqrt(n)-memory helper, which performs backward() internally (scaled by
    # loss_scale) and returns that graph's log-likelihood contribution.
    ll = 0.0
    for i in batch_indices:
        n = len(train_graphs[i])
        cur_ll, _ = sqrtn_forward_backward(
            model,
            graph_ids=[i],
            list_node_starts=[0],
            num_nodes=stats[i][0],
            blksize=cmd_args.blksize,
            loss_scale=1.0 / n,
            list_col_ranges=[(0, stats[i][1])])
        ll += cur_ll
    loss = -ll / num_nodes
# Removed an unreachable `if False:` debugging block that recomputed the loss
# for train_graphs[0], printed it, and called sys.exit() — dead code.
# Apply accumulated gradients every `accum_grad` mini-batches.
if (idx + 1) % cmd_args.accum_grad == 0:
    if cmd_args.grad_clip > 0:
        # Reconstructed from the parallel step — TODO confirm.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=cmd_args.grad_clip)
# One optimization step over the current mini-batch of training graphs.
batch_indices = indices[:cmd_args.batch_size]
num_nodes = sum(len(train_graphs[g]) for g in batch_indices)

if cmd_args.blksize < 0 or num_nodes <= cmd_args.blksize:
    # Small enough for a single forward/backward over the whole batch.
    batch_ll, _ = model.forward_train(batch_indices)
    loss = -batch_ll / num_nodes
    loss.backward()
    loss = loss.item()  # keep only the scalar for display below
else:
    # Too large for one pass: handle graphs one at a time with the
    # sqrt(n)-memory helper, which runs backward() internally (scaled
    # by loss_scale) and returns that graph's log-likelihood.
    total_ll = 0.0
    for g in batch_indices:
        graph_size = len(train_graphs[g])
        part_ll, _ = sqrtn_forward_backward(
            model,
            graph_ids=[g],
            list_node_starts=[0],
            num_nodes=graph_size,
            blksize=cmd_args.blksize,
            loss_scale=1.0 / graph_size)
        total_ll += part_ll
    loss = -total_ll / num_nodes

# Step the optimizer only every `accum_grad` mini-batches, optionally
# clipping the accumulated gradient norm first.
if (idx + 1) % cmd_args.accum_grad == 0:
    if cmd_args.grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=cmd_args.grad_clip)
    optimizer.step()
    optimizer.zero_grad()

# NOTE(review): the fractional-epoch display divides by cmd_args.epoch_save,
# which presumably equals the number of batches per epoch — verify.
pbar.set_description(
    'epoch %.2f, loss: %.4f'
    % (epoch + (idx + 1) / cmd_args.epoch_save, loss))