Example #1
    def forward(ctx, g, score, eids):

        # remember to save the graph to backward cache before making it
        # a local variable
        if not is_all(eids):
            g = g.edge_subgraph(eids.long())

        n_nodes = g.number_of_nodes()
        n_edges = g.number_of_edges()
        gidx = g._graph.get_immutable_gidx(utils.to_dgl_context(score.device))
        ctx.backward_cache = n_nodes, n_edges, gidx

        # g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
        smax = F.copy_reduce("max", gidx, TargetCode.EDGE, score, n_nodes)
        # g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
        out = F.binary_reduce("none", "sub", gidx, TargetCode.EDGE,
                              TargetCode.DST, score, smax, n_edges)

        # g.edata['out'] = th.exp(g.edata['out'])
        out = th.exp(out)
        # g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
        out_sum = F.copy_reduce("sum", gidx, TargetCode.EDGE, out, n_nodes)
        # g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
        out = F.binary_reduce("none", "div", gidx, TargetCode.EDGE,
                              TargetCode.DST, out, out_sum, n_edges)

        ctx.save_for_backward(out)
        return out
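
For reference, the kernel calls above implement a numerically-stable softmax over all edges that share a destination node. Writing z_{uv} for the score on edge (u, v) and \mathcal{N}(v) for the in-neighbors of v, the returned value is

    m_v = \max_{w \in \mathcal{N}(v)} z_{wv}, \qquad
    a_{uv} = \frac{\exp(z_{uv} - m_v)}{\sum_{w \in \mathcal{N}(v)} \exp(z_{wv} - m_v)}
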
Example #2
    def forward(ctx, g, score, eids):
        """Forward function.
        Pseudo-code:
        .. code:: python
            score = dgl.EData(g, score)
            score_max = score.dst_max()  # of type dgl.NData
            score = score - score_max  # edge_sub_dst, ret dgl.EData
            score = score.exp()  # elementwise exp, ret dgl.EData
            score_sum = score.dst_sum()  # of type dgl.NData
            out = score / score_sum    # edge_div_dst, ret dgl.EData
            return out.data
        """
        # remember to save the graph to backward cache before making it
        # a local variable
        if not is_all(eids):
            g = g.edge_subgraph(eids.long())

        n_nodes = g.number_of_dst_nodes()
        n_edges = g.number_of_edges()

        # TODO(BarclayII): this is a temporary fix of memory leakage in PyTorch
        # in PR #1139.  We should investigate further on what was actually happening
        # when implementing EdgeSoftmax with message passing API instead of
        # operators.
        score_context = utils.to_dgl_context(score.device)
        if isinstance(g, DGLGraph):
            gidx = g._graph.get_immutable_gidx(score_context)
        elif isinstance(g, DGLHeteroGraph):
            assert g._graph.number_of_etypes() == 1, \
                "EdgeSoftmax only supports one edge type"
            gidx = g._graph.get_unitgraph(0, score_context)

        ctx.backward_cache = n_nodes, n_edges, gidx

        # g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
        smax = F.copy_reduce('max', gidx, TargetCode.EDGE, score, n_nodes)
        # g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
        out = F.binary_reduce('none', 'sub', gidx, TargetCode.EDGE,
                              TargetCode.DST, score, smax, n_edges)
        # g.edata['out'] = th.exp(g.edata['out'])
        out = th.exp(out)
        # g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
        out_sum = F.copy_reduce('sum', gidx, TargetCode.EDGE, out, n_nodes)
        # g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
        out = F.binary_reduce('none', 'div', gidx, TargetCode.EDGE,
                              TargetCode.DST, out, out_sum, n_edges)

        ctx.save_for_backward(out)
        return out
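
As a sanity check on the docstring's pseudo-code, here is a minimal pure-PyTorch sketch of the same per-destination softmax on a flat edge list. The function name and the (dst, score, n_nodes) calling convention are illustrative rather than part of DGL, and scatter_reduce with the "amax" reduction assumes PyTorch >= 1.12.

    import torch

    def edge_softmax_reference(dst, score, n_nodes):
        # dst[i] is the destination node id (int64) of edge i; score holds
        # the 1-D edge scores.  Per-destination max for numerical stability
        # (the `smax` step above).
        smax = torch.full((n_nodes,), float('-inf'), dtype=score.dtype)
        smax = smax.scatter_reduce(0, dst, score, reduce='amax',
                                   include_self=True)
        # exp(score - smax[dst]) mirrors the e_sub_v + exp steps
        out = torch.exp(score - smax[dst])
        # per-destination sum of exponentiated scores (the `out_sum` step)
        out_sum = torch.zeros(n_nodes, dtype=score.dtype).index_add_(0, dst, out)
        # e_div_v: normalize each edge by its destination's sum
        return out / out_sum[dst]

For example, edge_softmax_reference(torch.tensor([1, 1, 0]), torch.randn(3), 2) returns one probability per edge: edges 0 and 1 (both entering node 1) sum to 1, and edge 2 (the only edge entering node 0) is exactly 1.
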
Example #3
    def backward(ctx, grad_out):
        """Backward function.
        Pseudo-code:
        .. code:: python
            g, out = ctx.backward_cache
            grad_out = dgl.EData(g, grad_out)
            out = dgl.EData(g, out)
            sds = out * grad_out  # type dgl.EData
            sds_sum = sds.dst_sum()  # type dgl.NData
            grad_score = sds - out * sds_sum  # multiple expressions
            return grad_score.data
        """
        n_nodes, n_edges, gidx = ctx.backward_cache
        out, = ctx.saved_tensors

        # g.edata['grad_s'] = out * grad_out
        grad_s = out * grad_out
        # g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
        accum = F.copy_reduce("sum", gidx, TargetCode.EDGE, grad_s, n_nodes)
        # g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
        out = F.binary_reduce("none", "mul", gidx, TargetCode.EDGE,
                              TargetCode.DST, out, accum, n_edges)
        # grad_score = g.edata['grad_s'] - g.edata['out']
        grad_score = grad_s - out

        return None, grad_score, None
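
The backward pass above is the standard softmax Jacobian applied per destination node: with a the saved forward output and g the incoming gradient grad_out,

    \frac{\partial L}{\partial z_{uv}} = a_{uv}\Big( g_{uv} - \sum_{w \in \mathcal{N}(v)} a_{wv}\, g_{wv} \Big)

which is exactly grad_s - out * accum in the code. The first and third returned gradients are None because the graph and eids arguments are not differentiable.
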
Example #4
    def gat_layer_dgl(feat, weight, attn_l, attn_r, in_feat_len, out_feat_len):
        # project node features and compute the left/right attention terms
        # a_l^T (W h) and a_r^T (W h); `g`, `num_v` and `n_edges` are taken
        # from the enclosing scope
        feat2 = torch.mm(feat, weight)
        att_l = torch.mm(feat2, attn_l)
        att_r = torch.mm(feat2, attn_r)
        g.srcdata.update({'ft': feat2, 'el': att_l})
        g.dstdata.update({'er': att_r})
        # unnormalized attention logit on each edge: el_u + er_v
        g.apply_edges(fn.u_add_v('el', 'er', 'e'))
        e = torch.exp(F.leaky_relu(g.edata.pop('e'), 0.1))

        # normalize per destination node (edge softmax) using the low-level
        # copy_reduce / binary_reduce kernels
        cont = utils.to_dgl_context(e.device)
        gidx = g._graph.get_immutable_gidx(cont)
        e_sum = backend.copy_reduce("sum", gidx, TargetCode.EDGE, e, num_v)
        att = backend.binary_reduce('none', 'div', gidx, TargetCode.EDGE,
                                    TargetCode.DST, e, e_sum, n_edges)
        # weight source features by attention and sum them at destination nodes
        g.edata['a'] = att
        g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        output = g.dstdata['ft']
        torch.cuda.synchronize()
        return output
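
Up to the final torch.cuda.synchronize() (presumably there so the call can be timed accurately), the layer above computes a single-head GAT update. With W the projection matrix and a_l, a_r the two attention vectors,

    e_{uv} = \mathrm{LeakyReLU}\big(a_l^{\top} W h_u + a_r^{\top} W h_v\big), \qquad
    \alpha_{uv} = \frac{\exp(e_{uv})}{\sum_{w \in \mathcal{N}(v)} \exp(e_{wv})}, \qquad
    h_v' = \sum_{u \in \mathcal{N}(v)} \alpha_{uv}\, W h_u

Unlike Examples #1-#3, the softmax here exponentiates the logits directly without first subtracting a per-destination maximum.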