def forward(ctx, g, score, eids):
    # remember to save the graph to backward cache before making it
    # a local variable
    if not is_all(eids):
        g = g.edge_subgraph(eids.long())
    n_nodes = g.number_of_nodes()
    n_edges = g.number_of_edges()
    gidx = g._graph.get_immutable_gidx(utils.to_dgl_context(score.device))
    ctx.backward_cache = n_nodes, n_edges, gidx

    # g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
    smax = F.copy_reduce("max", gidx, TargetCode.EDGE, score, n_nodes)
    # g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
    out = F.binary_reduce("none", "sub", gidx, TargetCode.EDGE, TargetCode.DST,
                          score, smax, n_edges)
    # g.edata['out'] = th.exp(g.edata['out'])
    out = th.exp(out)
    # g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
    out_sum = F.copy_reduce("sum", gidx, TargetCode.EDGE, out, n_nodes)
    # g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
    out = F.binary_reduce("none", "div", gidx, TargetCode.EDGE, TargetCode.DST,
                          out, out_sum, n_edges)

    ctx.save_for_backward(out)
    return out
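# The block below is a hedged reference sketch, not part of the DGL source: a
# slow but obviously-correct per-destination softmax that illustrates (and can
# sanity-check) what the fused copy_reduce/binary_reduce pipeline in forward()
# computes. The helper name and the (dst, score) layout are assumptions of this
# snippet.
import torch as th


def edge_softmax_reference(dst, score):
    """Softmax of edge scores grouped by destination node.

    dst   : LongTensor of shape (E,), destination node id of each edge.
    score : Tensor of shape (E, ...), the edge logits.
    """
    out = th.empty_like(score)
    for v in th.unique(dst):
        mask = dst == v
        # Softmax over the incoming edges of node v only.
        out[mask] = th.softmax(score[mask], dim=0)
    return out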
def forward(ctx, g, score, eids):
    """Forward function.

    Pseudo-code:

    .. code:: python

        score = dgl.EData(g, score)
        score_max = score.dst_max()  # of type dgl.NData
        score = score - score_max  # edge_sub_dst, ret dgl.EData
        score_sum = score.dst_sum()  # of type dgl.NData
        out = score / score_sum  # edge_div_dst, ret dgl.EData
        return out.data
    """
    # remember to save the graph to backward cache before making it
    # a local variable
    if not is_all(eids):
        g = g.edge_subgraph(eids.long())
    n_nodes = g.number_of_dst_nodes()
    n_edges = g.number_of_edges()

    # TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
    # in PR #1139.  We should investigate further on what was actually
    # happening when implementing EdgeSoftmax with message passing API
    # instead of operators.
    score_context = utils.to_dgl_context(score.device)
    if isinstance(g, DGLGraph):
        gidx = g._graph.get_immutable_gidx(score_context)
    elif isinstance(g, DGLHeteroGraph):
        assert g._graph.number_of_etypes() == 1, \
            "EdgeSoftmax only support one edge type"
        gidx = g._graph.get_unitgraph(0, score_context)
    ctx.backward_cache = n_nodes, n_edges, gidx

    # g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
    smax = F.copy_reduce('max', gidx, TargetCode.EDGE, score, n_nodes)
    # g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
    out = F.binary_reduce('none', 'sub', gidx, TargetCode.EDGE, TargetCode.DST,
                          score, smax, n_edges)
    # g.edata['out'] = th.exp(g.edata['out'])
    out = th.exp(out)
    # g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
    out_sum = F.copy_reduce('sum', gidx, TargetCode.EDGE, out, n_nodes)
    # g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
    out = F.binary_reduce('none', 'div', gidx, TargetCode.EDGE, TargetCode.DST,
                          out, out_sum, n_edges)

    ctx.save_for_backward(out)
    return out
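# Hedged usage sketch: the forward/backward pair above is what backs DGL's
# public edge_softmax call. The import path below matches DGL 0.4-era releases
# but has moved between versions (newer releases expose it as
# dgl.nn.functional.edge_softmax), so treat it as an assumption.
import dgl
import torch
from dgl.nn.pytorch.softmax import edge_softmax

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 1, 1, 2])   # three edges into node 1, one into node 2
logits = torch.randn(g.number_of_edges(), 1)
attn = edge_softmax(g, logits)             # attention weights sum to 1 per destination node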
def backward(ctx, grad_out):
    """Backward function.

    Pseudo-code:

    .. code:: python

        g, out = ctx.backward_cache
        grad_out = dgl.EData(g, grad_out)
        out = dgl.EData(g, out)
        sds = out * grad_out  # type dgl.EData
        sds_sum = sds.dst_sum()  # type dgl.NData
        grad_score = sds - sds * sds_sum  # multiple expressions
        return grad_score.data
    """
    n_nodes, n_edges, gidx = ctx.backward_cache
    out, = ctx.saved_tensors

    # g.edata['grad_s'] = out * grad_out
    grad_s = out * grad_out
    # g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
    accum = F.copy_reduce("sum", gidx, TargetCode.EDGE, grad_s, n_nodes)
    # g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
    out = F.binary_reduce("none", "mul", gidx, TargetCode.EDGE, TargetCode.DST,
                          out, accum, n_edges)
    # grad_score = g.edata['grad_s'] - g.edata['out']
    grad_score = grad_s - out

    return None, grad_score, None
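# Hedged check, not part of the library: for a single destination node the
# backward above reduces to the standard softmax Jacobian-vector product
# grad_s = a * g - a * sum(a * g), which can be verified against autograd.
import torch as th

s = th.randn(5, requires_grad=True)
a = th.softmax(s, dim=0)                    # plays the role of `out`
g_up = th.randn(5)                          # upstream gradient, plays the role of `grad_out`
a.backward(g_up)
manual = a * g_up - a * (a * g_up).sum()    # same formula as grad_s - out * accum above
assert th.allclose(s.grad, manual, atol=1e-6)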
def gat_layer_dgl(feat, weight, attn_l, attn_r, in_feat_len, out_feat_len):
    # Note: g, num_v and n_edges are read from the enclosing (module) scope.
    # Linear projection and the two halves of the attention logits.
    feat2 = torch.mm(feat, weight)
    att_l = torch.mm(feat2, attn_l)
    att_r = torch.mm(feat2, attn_r)
    g.srcdata.update({'ft': feat2, 'el': att_l})
    g.dstdata.update({'er': att_r})
    # e_uv = el_u + er_v on every edge, then LeakyReLU and exp.
    g.apply_edges(fn.u_add_v('el', 'er', 'e'))
    e = torch.exp(F.leaky_relu(g.edata.pop('e'), 0.1))
    # Normalize per destination node with the low-level reduce operators.
    cont = utils.to_dgl_context(e.device)
    gidx = g._graph.get_immutable_gidx(cont)
    e_sum = backend.copy_reduce("sum", gidx, TargetCode.EDGE, e, num_v)
    att = backend.binary_reduce('none', 'div', gidx, TargetCode.EDGE,
                                TargetCode.DST, e, e_sum, n_edges)
    g.edata['a'] = att
    # Weighted aggregation of source features into each destination node.
    g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
    output = g.dstdata['ft']
    torch.cuda.synchronize()
    return output
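# Hedged driver for the benchmark kernel above, assuming a DGL 0.4-era build
# where DGLGraph exposes _graph.get_immutable_gidx and srcdata/dstdata. The
# module-level names g, num_v and n_edges are assumptions of this snippet,
# matching the globals gat_layer_dgl reads; sizes and the CUDA device are
# arbitrary choices for illustration.
import dgl
import torch

num_v, num_e = 1000, 5000
in_feat_len, out_feat_len = 128, 64

g = dgl.DGLGraph()
g.add_nodes(num_v)
g.add_edges(torch.randint(0, num_v, (num_e,)), torch.randint(0, num_v, (num_e,)))
n_edges = g.number_of_edges()

dev = 'cuda'
feat = torch.randn(num_v, in_feat_len, device=dev)
weight = torch.randn(in_feat_len, out_feat_len, device=dev)
attn_l = torch.randn(out_feat_len, 1, device=dev)
attn_r = torch.randn(out_feat_len, 1, device=dev)

out = gat_layer_dgl(feat, weight, attn_l, attn_r, in_feat_len, out_feat_len)
print(out.shape)   # (num_v, out_feat_len)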