Example #1
import ctypes

import mxnet as mx
import numpy as np
from mpi4py import MPI

# transform_layout, bn_group_to_sync_depth, handler_bytes, anti_gc,
# check_call and _LIB are assumed to be defined elsewhere in the
# surrounding module (this helper targets NVIDIA's MXNet fork).
def batchnorm_add_relu(rank, data, addend, io_layout, batchnorm_layout, bn_group, local_gpus, local_comm, **kwargs):
    # Transpose as needed to batchnorm_layout
    transposed_data_as_needed = transform_layout(data, io_layout, batchnorm_layout)
    transposed_addend_as_needed = transform_layout(addend, io_layout, batchnorm_layout)
    bn_axis = 3 if batchnorm_layout == 'NHWC' else 1

    xbuf_ptr = (ctypes.c_void_p * local_gpus)()

    if bn_group > 1:
        sync_depth = bn_group_to_sync_depth(bn_group)
        if local_comm is not None:
            # Create this rank's buffer, then exchange CUDA IPC handles
            # with the other local GPUs via MPI Allgather.
            handler = np.zeros(handler_bytes(), dtype=np.byte)
            check_call(_LIB.MXInitXBufSingle(rank, sync_depth, xbuf_ptr, handler.ctypes.data_as(ctypes.c_void_p)))
            handlers = np.asarray([np.zeros(handler_bytes(), dtype=np.byte)] * local_gpus)
            local_comm.Allgather([handler, handler_bytes(), MPI.BYTE], [handlers, handler_bytes(), MPI.BYTE])
            check_call(_LIB.MXOpenIpcHandles(rank, local_gpus, sync_depth, xbuf_ptr, handlers.ctypes.data_as(ctypes.c_void_p)))
        else:
            check_call(_LIB.MXInitXBuf(local_gpus, sync_depth, xbuf_ptr))
        else:
            check_call(_LIB.MXInitXBuf(local_gpus, sync_depth, xbuf_ptr))
   
    anti_gc.append(xbuf_ptr)  # keep the buffer alive as long as the symbol
    batchnormed = mx.sym.BatchNormAddRelu(data=transposed_data_as_needed,
                                          addend=transposed_addend_as_needed,
                                          axis=bn_axis, bn_group=bn_group,
                                          xbuf_ptr=ctypes.addressof(xbuf_ptr),
                                          **kwargs)
    # Transpose back to i/o layout as needed
    return transform_layout(batchnormed, batchnorm_layout, io_layout)
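
A minimal usage sketch, not from the original source: conv_out and residual are placeholder symbols, and the eps/momentum/fix_gamma/name keyword arguments are assumed to be among the BatchNorm options forwarded through **kwargs. With bn_group=1 no cross-GPU buffers are initialized, so no MPI communicator is needed.

import mxnet as mx

data = mx.sym.Variable('conv_out')
addend = mx.sym.Variable('residual')
out = batchnorm_add_relu(rank=0, data=data, addend=addend,
                         io_layout='NCHW', batchnorm_layout='NHWC',
                         bn_group=1, local_gpus=1, local_comm=None,
                         eps=1e-5, momentum=0.9, fix_gamma=False,
                         name='bn_add_relu0')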
Example #2
import ctypes
import math

import numpy as np
import horovod.mxnet as hvd

# USE_MPI4PY, MPI, handler_bytes, anti_gc, check_call, _LIB and _Null
# are assumed module-level definitions (mpi4py is imported as MPI when
# available; this helper targets NVIDIA's MXNet fork).
def _init_gbn_buffers(bn_group):
    assert bn_group >= 1, "bn_group can't be smaller than 1"
    if bn_group == 1:
        return _Null

    sync_depth = int(math.log2(bn_group))  # required sync steps; assumes bn_group is a power of two
    if USE_MPI4PY:
        global_comm = MPI.COMM_WORLD
        local_comm = global_comm.Split_type(MPI.COMM_TYPE_SHARED)
        local_gpus = local_comm.Get_size()
        xbuf_ptr = (ctypes.c_void_p * local_gpus)()
        rank = hvd.local_rank()
        handler = np.zeros(handler_bytes(), dtype=np.byte)
        check_call(
            _LIB.MXInitXBufSingle(rank, sync_depth, xbuf_ptr,
                                  handler.ctypes.data_as(ctypes.c_void_p)))
        handlers = np.asarray([np.zeros(handler_bytes(), dtype=np.byte)] *
                              local_gpus)
        local_comm.Allgather([handler, handler_bytes(), MPI.BYTE],
                             [handlers, handler_bytes(), MPI.BYTE])
        check_call(
            _LIB.MXOpenIpcHandles(rank, local_gpus, sync_depth, xbuf_ptr,
                                  handlers.ctypes.data_as(ctypes.c_void_p)))
    else:
        local_gpus = hvd.local_size()
        xbuf_ptr = (ctypes.c_void_p * local_gpus)()
        check_call(_LIB.MXInitXBuf(local_gpus, sync_depth, xbuf_ptr))

    anti_gc.append(xbuf_ptr)
    return ctypes.addressof(xbuf_ptr)
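
A hypothetical call site, assuming the extended mx.sym.BatchNorm of NVIDIA's MXNet fork that accepts bn_group and xbuf_ptr parameters (conv_out is a placeholder symbol). When bn_group == 1 the function returns _Null, so the extra arguments degrade to a no-op.

import mxnet as mx

conv_out = mx.sym.Variable('conv_out')
xbuf_ptr = _init_gbn_buffers(bn_group=4)  # sync BN statistics across groups of 4 GPUs
bn = mx.sym.BatchNorm(data=conv_out, bn_group=4, xbuf_ptr=xbuf_ptr,
                      eps=1e-5, momentum=0.9, name='gbn0')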