Example #1
def _send_kv_msg(sender, msg, recv_id):
    """Send kvstore message.

    Parameters
    ----------
    sender : ctypes.c_void_p
        C sender handle
    msg : KVStoreMsg
        kvstore message
    recv_id : int
        receiver's ID
    """
    if msg.type == KVMsgType.PULL:
        # Pull request: carries only the IDs of the rows to fetch.
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank,
                              msg.name, tensor_id)
    elif msg.type in (KVMsgType.INIT, KVMsgType.GET_SHAPE_BACK):
        # Shape-bearing messages: carry a tensor shape instead of data.
        tensor_shape = F.zerocopy_to_dgl_ndarray(msg.shape)
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank,
                              msg.name, tensor_shape)
    elif msg.type in (KVMsgType.IP_ID, KVMsgType.GET_SHAPE):
        # Name-only messages: no tensor payload at all.
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank,
                              msg.name)
    elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
        # Control messages: only the type and sender rank are needed.
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank)
    else:
        # Remaining types (e.g. push messages): carry both IDs and a data payload.
        tensor_id = F.zerocopy_to_dgl_ndarray(msg.id)
        data = F.zerocopy_to_dgl_ndarray(msg.data)
        _CAPI_SenderSendKVMsg(sender, int(recv_id), msg.type.value, msg.rank,
                              msg.name, tensor_id, data)
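
The dispatch above keys on msg.type: a PULL carries only the IDs to fetch, INIT and GET_SHAPE_BACK carry a shape tensor, IP_ID and GET_SHAPE carry only a name, FINAL and BARRIER carry nothing, and every other type carries IDs plus a data payload. As a minimal, hypothetical sketch (not from the source), here is how a pull request might be assembled; the namedtuple mirrors exactly the attributes accessed above, while DGL's actual KVStoreMsg may define additional fields, and sender stands for an already-initialized C sender handle.

from collections import namedtuple
import torch

# Field list mirrors the attribute accesses in _send_kv_msg above; the real
# KVStoreMsg in DGL may carry more fields.
KVStoreMsg = namedtuple('KVStoreMsg', ['type', 'rank', 'name', 'id', 'data', 'shape'])

pull_msg = KVStoreMsg(
    type=KVMsgType.PULL,         # routed to the first branch above
    rank=0,                      # this client's rank
    name='entity_embed',         # illustrative tensor name
    id=torch.tensor([0, 3, 7]),  # global IDs to fetch
    data=None,                   # a pull request carries no payload
    shape=None)
_send_kv_msg(sender, pull_msg, recv_id=0)  # sender: C handle from network setup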
Example #2
def _clear_kv_msg(msg):
    """Clear data of kvstore message

    Parameters
    ----------
    msg : KVStoreMsg
        kvstore message
    """
    if msg.data is not None:
        # Wait for any pending computation on the tensor to finish, then
        # free the underlying buffer on the C side.
        F.sync()
        data = F.zerocopy_to_dgl_ndarray(msg.data)
        _CAPI_DeleteNDArrayData(data)
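
The intended pattern, inferred from the body above: once a data-bearing message has been handed to the C sender, release the tensor buffer it referenced; the F.sync() call makes sure any pending computation on msg.data has finished before the underlying NDArray storage is freed. A hypothetical continuation of the sketch from example #1, assuming KVMsgType also defines a PUSH member (as the catch-all branch in _send_kv_msg suggests):

push_msg = KVStoreMsg(
    type=KVMsgType.PUSH,             # handled by the catch-all branch above
    rank=0,
    name='entity_embed',
    id=torch.tensor([0, 3, 7]),      # rows being pushed
    data=torch.zeros(3, 16),         # payload to write on the server
    shape=None)
_send_kv_msg(sender, push_msg, recv_id=0)
_clear_kv_msg(push_msg)              # free the payload buffer after the send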
Example #3
def _fast_pull(name, id_tensor,
               machine_count, group_count, machine_id, client_id,
               partition_book, g2l, local_data,
               sender, receiver):
    """ Pull message

    Parameters
    ----------
    name : str
        data name string
    id_tensor : tensor
        tensor of ID
    machine_count : int
        count of total machine
    group_count : int
        count of server group
    machine_id : int
        current machine id
    client_id : int
        current client ID
    partition_book : tensor
        tensor of partition book
    g2l : tensor
        tensor of global2local
    local_data : tensor
        tensor of local shared data
    sender : ctypes.c_void_p
        C Sender handle
    receiver : ctypes.c_void_p
        C Receiver handle

    Return
    ------
    tensor
        target tensor
    """
    if g2l is not None:
        # A global-to-local ID mapping is available, so pass the 'has_g2l'
        # flag and the mapping itself to the C API.
        res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
                                    F.zerocopy_to_dgl_ndarray(id_tensor),
                                    F.zerocopy_to_dgl_ndarray(partition_book),
                                    F.zerocopy_to_dgl_ndarray(local_data),
                                    sender, receiver, 'has_g2l',
                                    F.zerocopy_to_dgl_ndarray(g2l))
    else:
        # No mapping: global IDs already index local_data directly.
        res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
                                    F.zerocopy_to_dgl_ndarray(id_tensor),
                                    F.zerocopy_to_dgl_ndarray(partition_book),
                                    F.zerocopy_to_dgl_ndarray(local_data),
                                    sender, receiver, 'no_g2l')

    return F.zerocopy_from_dgl_ndarray(res_tensor)
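
For orientation, a hypothetical single-machine call (all tensor contents and names are illustrative; sender and receiver stand for C handles created during network setup). With one machine and a partition book that assigns every ID to machine 0, each requested ID resolves against the local shard and no global-to-local mapping is needed:

import torch

local_data = torch.rand(100, 16)                      # the local data shard
partition_book = torch.zeros(100, dtype=torch.int64)  # every ID maps to machine 0
ids = torch.tensor([2, 5, 11])

rows = _fast_pull('entity_embed', ids,
                  machine_count=1, group_count=1,
                  machine_id=0, client_id=0,
                  partition_book=partition_book,
                  g2l=None,                 # takes the 'no_g2l' path above
                  local_data=local_data,
                  sender=sender, receiver=receiver)
# rows has shape (3, 16): one row of local_data per requested ID.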
Example #4
import torch
import dgl
import dgl.backend as F

# Random graph with 10 nodes and 15 edges, int32 IDs, on GPU 0.
g = dgl.rand_graph(10, 15).int().to(torch.device(0))
gidx = g._graph  # underlying C graph index
u = torch.rand((10, 2, 8), device=torch.device(0))  # source-node features
v = torch.rand((10, 2, 8), device=torch.device(0))  # destination-node features

# Reference result: per-edge dot product over the last dimension,
# shape (15, 2, 1).
e = dgl.ops.gsddmm(g, 'dot', u, v)
print(e)

# Recompute the same SDDMM with the FeatGraph kernel, writing into a
# preallocated output tensor.
e = torch.zeros((15, 2, 1), device=torch.device(0))
u = F.zerocopy_to_dgl_ndarray(u)
v = F.zerocopy_to_dgl_ndarray(v)
e = F.zerocopy_to_dgl_ndarray_for_write(e)
dgl.sparse._CAPI_FG_LoadModule("../build/featgraph/libfeatgraph_kernels.so")
dgl.sparse._CAPI_FG_SDDMMTreeReduction(gidx, u, v, e)
print(e)
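
As a sanity check, one could convert the NDArrays back to torch tensors and compare the FeatGraph result against the gsddmm reference; this sketch assumes the tree-reduction kernel computes the same edge-wise dot product, which is what the script's side-by-side prints are meant to show:

u_t = F.zerocopy_from_dgl_ndarray(u)
v_t = F.zerocopy_from_dgl_ndarray(v)
e_t = F.zerocopy_from_dgl_ndarray(e)
print(torch.allclose(dgl.ops.gsddmm(g, 'dot', u_t, v_t), e_t))  # expect True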