Example 1
def _threshold_and_support_graph(gidx: HeteroGraphIndex, scores: Tensor,
                                 end_n_ids: Tensor):
    """Find the threshold for each node and its edges"""
    in_degrees = _gspmm(gidx, "copy_rhs", "sum", None,
                        torch.ones_like(scores))[0]
    cum_in_degrees = torch.cat(
        [in_degrees.new_zeros(1),
         in_degrees.cumsum(dim=0)[:-1]], dim=0)

    # perform sort on edges for each node
    sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(
        scores, end_n_ids, in_degrees, cum_in_degrees)
    cumsum_scores = cumsum_scores - 1.
    support = rhos * sorted_scores > cumsum_scores
    support = support[
        dense_reverse_perm]  # from sorted order to unsorted order
    support = support[reverse_perm]  # from src-dst order to eid order

    support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
    support_size = support_size.long()
    idx = support_size + cum_in_degrees - 1

    # mask invalid indices: e.g., if the batch does not start at 0 or is not contiguous, the index can be negative
    mask = idx < 0
    idx[mask] = 0
    tau = cumsum_scores.gather(0, idx.long())
    tau /= support_size.to(scores.dtype)

    return tau, support_size
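
For reference, the threshold computed above is the standard sparsemax threshold: with a node's scores sorted in descending order, the support size k is the largest k with 1 + k * z_(k) > sum_{j<=k} z_(j), and tau = (sum_{j<=k} z_(j) - 1) / k. A minimal dense sketch for a single node's 1-D score vector (an illustration, not part of the original project):

import torch

def sparsemax_threshold_dense(z: torch.Tensor):
    # z holds the scores of all edges incident to one node
    z_sorted, _ = torch.sort(z, descending=True)
    cumsum = z_sorted.cumsum(dim=0) - 1.0
    rhos = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    support = rhos * z_sorted > cumsum   # 1 + k * z_(k) > sum_{j<=k} z_(j)
    k = support.sum()                    # support size
    tau = cumsum[k - 1] / k              # threshold used in ReLU(z - tau)
    return tau, k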
Example 2
def calcRepulsiveTdist(graph, embed, nV):
    # per-edge difference vector between source and destination embeddings
    D = _gsddmm(graph._graph,
                "sub",
                embed[:nV],
                embed[:nV],
                lhs_target='u',
                rhs_target='v')
    if torch.isnan(D).any():
        print("D got problem")
    # per-edge squared distance d = D . D
    d = _gsddmm(graph._graph, 'dot', D, D, lhs_target='e', rhs_target='e')
    if torch.isnan(d).any():
        print("d got problem")
    # repulsive weight from the t-distribution kernel
    E = 2.0 / ((d) * (1.0 + d))
    if torch.isnan(E).any():
        print("First E got problem")
    E = _gsddmm(graph._graph, "mul", D, E, lhs_target='e', rhs_target='e')
    if torch.isnan(E).any():
        print("Second E got problem")
    E = allTdistScale(E)
    E = 0.02 * E
    if torch.isnan(E).any():
        print("Scaling E got problem")
    # sum the per-edge repulsive displacements at each source node
    outputr = _gspmm(graph._graph.reverse(), "copy_rhs", "sum", E, E)[0]
    if torch.isnan(outputr).any():
        print("outputr in recals got problem")
    return outputr
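
A hypothetical call sketch for the function above (toy graph and embedding; assumes torch, dgl, DGL's internal _gsddmm/_gspmm kernels, and the project's allTdistScale helper are all importable):

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))    # toy directed 3-cycle
embed = torch.rand(g.num_nodes(), 2)     # 2-D layout coordinates
rep = calcRepulsiveTdist(g, embed, g.num_nodes())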
Example 3
def calcRepulsive(graph, embed, sm_table, V):
    # per-edge dot product between source and destination embeddings
    X = _gsddmm(graph._graph,
                'dot',
                embed[:V],
                embed[:V],
                lhs_target='u',
                rhs_target='v')
    # table-based sigmoid transform of the per-edge scores
    Y = allFastSigmoid(X, sm_table, 0)
    # sum the Y-weighted destination embeddings at each source node
    output = _gspmm(graph._graph.reverse(), "mul", "sum", embed[:V], Y)[0]
    return output
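
A dense, edge-by-edge equivalent of this dot -> sigmoid -> weighted-sum pipeline may help read the kernel calls (illustrative sketch only; torch.sigmoid stands in for the table-based allFastSigmoid):

import torch

def calc_repulsive_dense(src, dst, embed):
    # for each edge (u, v): y = sigmoid(x_u . x_v); out[u] += y * x_v
    out = torch.zeros_like(embed)
    for u, v in zip(src.tolist(), dst.tolist()):
        y = torch.sigmoid(embed[u] @ embed[v])
        out[u] += y * embed[v]
    return out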
Example 4
def dglusingtdistribution(batchgraphs, embed, iterations=1, lrate=1.0):
    it = 0
    totalktime = 0
    kerneltime = []
    purektime = 0
    while it < iterations:
        for [graph, s, e] in batchgraphs:
            #just to make sure: dim(sparse graph) == dim(embedding)
            pV = len(graph.nodes())
            start = time.time()
            #SDDMM
            D = _gsddmm(graph._graph,
                        "sub",
                        embed[:pV],
                        embed[:pV],
                        lhs_target='u',
                        rhs_target='v')
            d = _gsddmm(graph._graph,
                        'dot',
                        D,
                        D,
                        lhs_target='e',
                        rhs_target='e')
            end1 = time.time()
            #NonlinearTransformation
            #E = - 2.0 / (1.0 + d) # for t-distribution
            E = 1.0 + 1.0 / d  # for FR-model
            start1 = time.time()
            #SDDMM
            E = _gsddmm(graph._graph,
                        "mul",
                        D,
                        E,
                        lhs_target='e',
                        rhs_target='e')
            end2 = time.time()
            #scaling
            #E = allscale(E) # for t-distribution
            start2 = time.time()
            #SPMM
            outputa = _gspmm(graph._graph.reverse(), "copy_rhs", "sum", E,
                             E)[0]
            end = time.time()
            #embed[:pV] = embed[:pV] + lrate * outputa[:pV]
            #purektime += end1 - start + end2 - start1 + end - start2
            kerneltime.append(end1 - start + start1 - end1 + end2 - start1 +
                              end - start2)
            totalktime += end - start
        it += 1
    print("Total GDL Kernel Time:", totalktime, "s", ", Kernel T:", kerneltime,
          "Avg. Kernel T:",
          sum(kerneltime) / len(kerneltime))
    print("dglTime:", len(embed), ":DIM", len(embed[0]), ":avgtime:",
          sum(kerneltime) / len(kerneltime))
    return embed
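
A hypothetical driver for this timing loop (toy inputs, illustrative only; assumes time, torch, dgl, and DGL's internal _gsddmm/_gspmm kernels are importable):

import dgl
import torch

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))   # toy directed ring
embed = torch.rand(g.num_nodes(), 2)
# each batch entry is a [graph, start, end] triple, as the loop above unpacks
embed = dglusingtdistribution([[g, 0, g.num_nodes()]], embed, iterations=2)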
Example 5
    def backward(ctx, grad_out):
        gidx = ctx.backward_cache
        supp_size, out = ctx.saved_tensors
        grad_in = grad_out.clone()

        # grad for ReLU
        grad_in[out == 0] = 0

        # dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
        v_hat = _gspmm(gidx, "copy_rhs", "sum", None,
                       grad_in)[0] / supp_size.to(out.dtype)
        grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
        grad_in = torch.where(out != 0, grad_in_modify, grad_in)
        del gidx
        torch.cuda.empty_cache()

        return None, grad_in, None, None, None
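
A dense, single-node reference for the same gradient rule dL/dv_i = dL/do_i - (1/k) * sum_{j in support} dL/do_j (illustrative sketch, not the author's code):

import torch

def sparsemax_grad_dense(grad_out: torch.Tensor, out: torch.Tensor):
    # out is the sparsemax output for one node; zeros lie outside the support
    grad_in = grad_out.clone()
    grad_in[out == 0] = 0                      # ReLU part of the gradient
    support = out != 0
    v_hat = grad_in[support].sum() / support.sum()
    grad_in[support] -= v_hat                  # subtract the mean over the support
    return grad_in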
Example 6
def dglusingsigmoid(batchgraphs, embed, iterations=1, lrate=1.0):
    it = 0
    totaltime = []
    kerneltime = []
    purekerneltime = 0
    #sm_table = init_SM_TABLE()
    #print(sm_table)
    while it < iterations:
        start = time.time()
        for [graph, s, e] in batchgraphs:
            ngraph = negativeSamples(graph, s, e)
            nV = len(ngraph.nodes())
            #just to make sure: dim(sparse graph) == dim(embedding)
            pV = len(graph.nodes())
            #SDDMM operation
            kstart = time.time()
            X = _gsddmm(graph._graph,
                        'dot',
                        embed[:pV],
                        embed[:pV],
                        lhs_target='u',
                        rhs_target='v')
            kend1 = time.time()
            #non-linear transformation
            Y = 1.0 * X  #- 1.0 / (1 + torch.exp(-X))
            #Y = allFastSigmoid(X, sm_table)
            #SPMM operation
            kstart1 = time.time()
            outputa = _gspmm(graph._graph.reverse(), "mul", "sum", embed[:pV],
                             Y)[0]
            kend = time.time()
            #outputr = calcRepulsive(ngraph, embed, sm_table, nV)
            #embed[s:e] = lrate * outputa[s:e]
            #embed[s:e] = embed[s:e] + lrate * outputr[s:e]
            purekerneltime += (kend1 - kstart) + (kend - kstart1)
            kerneltime.append((kend1 - kstart) + (kend - kstart1) +
                              (kstart1 - kend1))
        end = time.time()
        totaltime.append(end - start)
        it += 1
    print("GDL Total Time:", sum(totaltime), "s", kerneltime, ", Kernel Time:",
          sum(kerneltime), 'Avg. Kernel T:',
          sum(kerneltime) / len(kerneltime))
    print("dglTime:", len(embed), ":DIM", len(embed[0]), ":avgtime:",
          sum(kerneltime) / len(kerneltime))
    return embed
Example 7
    def forward(ctx, gidx: HeteroGraphIndex, scores: Tensor, eids: Tensor,
                end_n_ids: Tensor, norm_by: str):
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == "src":
            gidx = gidx.reverse()

        # use feat - max(feat) for numerical stability.
        scores = scores.float()
        scores_max = _gspmm(gidx, "copy_rhs", "max", None, scores)[0]
        scores = _gsddmm(gidx, "sub", scores, scores_max, "e", "v")

        # find threshold for each node and perform ReLU(u-t(u)) operation.
        tau, supp_size = _threshold_and_support_graph(gidx, scores, end_n_ids)
        out = torch.clamp(_gsddmm(gidx, "sub", scores, tau, "e", "v"), min=0)
        ctx.backward_cache = gidx
        ctx.save_for_backward(supp_size, out)
        torch.cuda.empty_cache()
        return out
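
Example 5 and Example 7 read like the two halves of a torch.autograd.Function implementing an edge-wise sparsemax; a hypothetical wrapper and call (class and variable names are illustrative, not the original project's API):

import torch
from torch.autograd import Function

class EdgeSparsemax(Function):
    @staticmethod
    def forward(ctx, gidx, scores, eids, end_n_ids, norm_by):
        ...   # body as in Example 7

    @staticmethod
    def backward(ctx, grad_out):
        ...   # body as in Example 5

# usage sketch: scores holds one value per edge of a DGLGraph g
# out = EdgeSparsemax.apply(g._graph, scores, eids, end_n_ids, "dst")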
Example 8
def dglusingGCN(batchgraphs, embed, iterations=1, lrate=1.0):
    it = 0
    totaltime = []
    kerneltime = []
    while it < iterations:
        for [graph, s, e] in batchgraphs:
            ngraph = negativeSamples(graph, s, e)
            nV = len(ngraph.nodes())
            pV = len(graph.nodes())
            kstart = time.time()
            # copy each edge's destination-node embedding onto the edge
            E = _gsddmm(graph._graph,
                        "copy_rhs",
                        embed[:pV],
                        embed[:pV],
                        lhs_target='u',
                        rhs_target='v')
            # sum the copied embeddings at each source node of the original graph
            output = _gspmm(graph._graph.reverse(), "copy_rhs", "sum", E, E)[0]
            kend = time.time()
            kerneltime.append(kend - kstart)
        it += 1
    print("dglusingGCN")
    print("dglTime:", len(embed), ":DIM", len(embed[0]), ":avgtime:",
          sum(kerneltime) / len(kerneltime))
    return output