def batchnorm_add_relu(rank, data, addend, io_layout, batchnorm_layout, bn_group,
                       local_gpus, local_comm, **kwargs):
    # Transpose as needed to batchnorm_layout
    transposed_data_as_needed = transform_layout(data, io_layout, batchnorm_layout)
    transposed_addend_as_needed = transform_layout(addend, io_layout, batchnorm_layout)
    bn_axis = 3 if batchnorm_layout == 'NHWC' else 1

    xbuf_ptr = (ctypes.c_void_p * local_gpus)()
    if bn_group > 1:
        sync_depth = bn_group_to_sync_depth(bn_group)
        if local_comm is not None:
            # Create this rank's exchange buffer, gather every local GPU's IPC handle
            # over MPI, then open the handles so the fused kernel can sync across GPUs.
            handler = np.zeros(handler_bytes(), dtype=np.byte)
            check_call(_LIB.MXInitXBufSingle(rank, sync_depth, xbuf_ptr,
                                             handler.ctypes.data_as(ctypes.c_void_p)))
            handlers = np.asarray([np.zeros(handler_bytes(), dtype=np.byte)] * local_gpus)
            local_comm.Allgather([handler, handler_bytes(), MPI.BYTE],
                                 [handlers, handler_bytes(), MPI.BYTE])
            check_call(_LIB.MXOpenIpcHandles(rank, local_gpus, sync_depth, xbuf_ptr,
                                             handlers.ctypes.data_as(ctypes.c_void_p)))
        else:
            check_call(_LIB.MXInitXBuf(local_gpus, sync_depth, xbuf_ptr))

    # Keep the ctypes buffer alive for as long as the symbol may reference it
    anti_gc.append(xbuf_ptr)
    batchnormed = mx.sym.BatchNormAddRelu(data=transposed_data_as_needed,
                                          addend=transposed_addend_as_needed,
                                          axis=bn_axis,
                                          bn_group=bn_group,
                                          xbuf_ptr=ctypes.addressof(xbuf_ptr),
                                          **kwargs)

    # Transpose back to i/o layout as needed
    return transform_layout(batchnormed, batchnorm_layout, io_layout)
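
# Illustrative sketch (not part of the original module): one way batchnorm_add_relu
# could be wired into a residual-style tail of a symbolic network. It assumes MPI and
# Horovod are already initialized and that `local_comm` is a shared-memory split of
# MPI.COMM_WORLD, as in _init_gbn_buffers below. The function name, the input variable
# names, and the BatchNorm-style kwargs (eps, momentum, fix_gamma) forwarded to
# BatchNormAddRelu are assumptions for demonstration only.
def _example_residual_tail(rank, local_gpus, local_comm, bn_group=2):
    data = mx.sym.Variable('data')          # main-branch activations (NCHW at the i/o boundary)
    shortcut = mx.sym.Variable('shortcut')  # skip-connection tensor of matching shape
    # Fused BN + elementwise add + ReLU, synchronizing statistics across `bn_group` GPUs
    return batchnorm_add_relu(rank, data, shortcut,
                              io_layout='NCHW', batchnorm_layout='NHWC',
                              bn_group=bn_group, local_gpus=local_gpus,
                              local_comm=local_comm,
                              eps=2e-5, momentum=0.9, fix_gamma=False,
                              name='bn_add_relu0')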
def _init_gbn_buffers(bn_group):
    assert bn_group >= 1, 'bn_group can\'t be smaller than 1'
    if bn_group == 1:
        return _Null

    sync_depth = int(math.log2(bn_group))  # required sync steps
    if USE_MPI4PY:
        global_comm = MPI.COMM_WORLD
        local_comm = global_comm.Split_type(MPI.COMM_TYPE_SHARED)
        local_gpus = local_comm.Get_size()
        xbuf_ptr = (ctypes.c_void_p * local_gpus)()
        rank = hvd.local_rank()
        # Create this rank's exchange buffer, gather all local IPC handles, open them
        handler = np.zeros(handler_bytes(), dtype=np.byte)
        check_call(_LIB.MXInitXBufSingle(rank, sync_depth, xbuf_ptr,
                                         handler.ctypes.data_as(ctypes.c_void_p)))
        handlers = np.asarray([np.zeros(handler_bytes(), dtype=np.byte)] * local_gpus)
        local_comm.Allgather([handler, handler_bytes(), MPI.BYTE],
                             [handlers, handler_bytes(), MPI.BYTE])
        check_call(_LIB.MXOpenIpcHandles(rank, local_gpus, sync_depth, xbuf_ptr,
                                         handlers.ctypes.data_as(ctypes.c_void_p)))
    else:
        local_gpus = hvd.local_size()
        xbuf_ptr = (ctypes.c_void_p * local_gpus)()
        check_call(_LIB.MXInitXBuf(local_gpus, sync_depth, xbuf_ptr))

    # Keep the ctypes buffer alive for as long as the op may reference it
    anti_gc.append(xbuf_ptr)
    return ctypes.addressof(xbuf_ptr)
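
# Illustrative sketch (not part of the original module): the address returned by
# _init_gbn_buffers can be forwarded as the `xbuf_ptr` attribute of the fused op,
# mirroring what batchnorm_add_relu above does with its own buffer. The function
# name `_example_gbn_symbol` and its arguments are hypothetical.
def _example_gbn_symbol(data, addend, bn_group=4):
    # _Null is returned when bn_group == 1, i.e. no cross-GPU synchronization buffer
    xbuf_ptr = _init_gbn_buffers(bn_group)
    return mx.sym.BatchNormAddRelu(data=data, addend=addend, axis=1,
                                   bn_group=bn_group, xbuf_ptr=xbuf_ptr,
                                   name='gbn_add_relu0')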