示例#1
0
文件: network.py 项目: zhcomeon/dgl
def _recv_kv_msg(receiver):
    """Receive a kvstore message and wrap it in a ``KVStoreMsg``.

    Which payload fields are read from the C message depends on its type:

    * ``FINAL`` / ``BARRIER``: only ``type`` and ``rank``.
    * ``IP_ID``: additionally ``name``.
    * ``PULL``: additionally ``name`` and ``id``.
    * any other type: ``name``, ``id`` and ``data``.

    Fields that are not read are left as ``None``.

    Parameters
    ----------
    receiver : ctypes.c_void_p
        C Receiver handle

    Return
    ------
    KVStoreMsg
        kvstore message; ``c_ptr`` holds the underlying C message handle.
    """
    # NOTE(review): every other C call in this function uses the ``_CAPI_``
    # prefix; confirm this one really is registered without the underscore.
    msg_ptr = CAPI_ReceiverRecvKVMsg(receiver)
    # The raw type code goes through KVMsgType(...); any rejection of an
    # unrecognised code has to happen here (presumably ValueError if
    # KVMsgType is an enum.Enum -- confirm). Past this point every type is
    # handled, so there is no separate "unknown type" path.
    msg_type = KVMsgType(_CAPI_ReceiverGetKVMsgType(msg_ptr))
    rank = _CAPI_ReceiverGetKVMsgRank(msg_ptr)

    # Each level of payload implies the previous one: data => id => name.
    has_name = msg_type not in (KVMsgType.FINAL, KVMsgType.BARRIER)
    has_id = has_name and msg_type != KVMsgType.IP_ID
    has_data = has_id and msg_type != KVMsgType.PULL

    # Fetch in the same order as before (name, then id, then data).
    name = _CAPI_ReceiverGetKVMsgName(msg_ptr) if has_name else None
    tensor_id = None
    if has_id:
        tensor_id = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgID(msg_ptr))
    data = None
    if has_data:
        data = F.zerocopy_from_dgl_ndarray(_CAPI_ReceiverGetKVMsgData(msg_ptr))

    return KVStoreMsg(
        type=msg_type,
        rank=rank,
        name=name,
        id=tensor_id,
        data=data,
        c_ptr=msg_ptr)
示例#2
0
def _fast_pull(name, id_tensor,
               machine_count, group_count, machine_id, client_id,
               partition_book, g2l, local_data,
               sender, receiver):
    """Pull data for a set of IDs through the ``_CAPI_FastPull`` C routine.

    All tensor arguments are handed to C as zero-copy DGL NDArrays. A string
    flag tells the C side whether a global-to-local mapping follows:
    ``'has_g2l'`` plus the converted ``g2l`` tensor, or ``'no_g2l'`` alone.
    The NDArray returned by C is converted back to a framework tensor
    without copying.

    Parameters
    ----------
    name : str
        data name string
    id_tensor : tensor
        tensor of IDs to pull
    machine_count : int
        total number of machines
    group_count : int
        number of server groups
    machine_id : int
        ID of the current machine
    client_id : int
        ID of the current client
    partition_book : tensor
        partition book tensor
    g2l : tensor or None
        global-to-local ID mapping; ``None`` if there is no mapping
    local_data : tensor
        locally shared data tensor
    sender : ctypes.c_void_p
        C Sender handle
    receiver : ctypes.c_void_p
        C Receiver handle

    Return
    ------
    tensor
        target tensor produced by ``_CAPI_FastPull``
    """
    to_nd = F.zerocopy_to_dgl_ndarray
    # Positional layout expected by _CAPI_FastPull; the g2l flag goes last.
    call_args = [name, machine_id, machine_count, group_count, client_id,
                 to_nd(id_tensor),
                 to_nd(partition_book),
                 to_nd(local_data),
                 sender, receiver]
    if g2l is None:
        call_args.append('no_g2l')
    else:
        call_args.extend(('has_g2l', to_nd(g2l)))

    pulled = _CAPI_FastPull(*call_args)
    return F.zerocopy_from_dgl_ndarray(pulled)