def wrapper(*args):
    """Run the compiled Metal kernel over ``args`` and return its output.

    Each argument is converted to numpy and copied into a Metal buffer. Buffers
    are cached by name on ``storage`` and reused while their byte size is
    unchanged. The kernel is dispatched with one thread per element of
    ``args[0]``, and the call blocks until the GPU has finished.

    Returns a freshly allocated numpy array with dtype ``dtype_out``. It does
    NOT alias the cached Metal output buffer, so the next call (which reuses
    that buffer) cannot overwrite a previously returned result.
    """
    args = [vaex.array_types.to_numpy(ar) for ar in args]
    nitems = len(args[0])
    if nitems == 0:
        # Nothing to compute. Avoid requesting a zero-byte Metal buffer and
        # dispatching an empty grid, which the original code did and then failed on.
        return np.empty(0, dtype=dtype_out)

    # The device cannot produce float64 (see the warning below), so the output
    # buffer holds float32 in that case and is cast to dtype_out when copied out.
    # NOTE(review): assumes the kernel writes float32 when dtype_out is float64 — confirm.
    buffer_dtype = np.dtype("float32") if dtype_out.name == "float64" else dtype_out

    def getbuf(name, value=None, dtype=np.dtype("float32"), N=None):
        """Return the cached Metal buffer ``name``, reallocating it when its size changes.

        If ``value`` is given, N and dtype are taken from it (float64 is
        downcast to float32) and its contents are copied into the buffer.
        Otherwise a buffer for ``N`` items of ``dtype`` is returned without
        writing any data into it.
        """
        buf = getattr(storage, name, None)
        if value is not None:
            N = len(value)
            dtype = value.dtype
            if dtype.name == "float64":
                warnings.warn("Casting input argument from float64 to float32 since Metal does not support float64")
                dtype = np.dtype("float32")
        nbytes = N * dtype.itemsize
        if buf is not None and buf.length() != nbytes:
            # cached buffer does not match this chunk's size; allocate a new one
            buf = None
        if buf is None:
            buf = self.device.newBufferWithLength_options_(nbytes, 0)
            setattr(storage, name, buf)
        if value is not None:
            # write the host data straight into the buffer's memory
            host_view = np.frombuffer(buf.contents().as_buffer(buf.length()), dtype=dtype)
            host_view[:] = value.astype(dtype, copy=False)
        return buf

    input_buffers = [getbuf(name, chunk) for name, chunk in zip(self.arguments, args)]
    output_buffer = getbuf('vaex_output', N=nitems, dtype=buffer_dtype)
    buffers = input_buffers + [output_buffer]

    command_buffer = command_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.setComputePipelineState_(state)
    # inputs bind to indices 0..k-1 in argument order; the output is bound last
    for i, buf in enumerate(buffers):
        encoder.setBuffer_offset_atIndex_(buf, 0, i)

    # one thread per element
    tpgrid = Metal.MTLSize(width=nitems, height=1, depth=1)
    # state.threadExecutionWidth() == 32 on M1 max
    # state.maxTotalThreadsPerThreadgroup() == 1024 on M1 max
    exec_width = state.threadExecutionWidth()
    tptgroup = Metal.MTLSize(width=exec_width, height=state.maxTotalThreadsPerThreadgroup() // exec_width, depth=1)
    # a 1x1x1 threadgroup is simpler and was observed to give the same performance
    encoder.dispatchThreads_threadsPerThreadgroup_(tpgrid, tptgroup)
    encoder.endEncoding()
    command_buffer.commit()
    command_buffer.waitUntilCompleted()

    output_view = output_buffer.contents().as_buffer(output_buffer.length())
    # astype() always copies: the Metal buffer is cached on `storage` and will be
    # overwritten (or replaced) by the next call, so the result must not alias it.
    return np.frombuffer(output_view, dtype=buffer_dtype).astype(dtype_out)
def test_structs(self):
    """Default-constructed Metal geometry structs are zeroed and correctly nested."""
    origin = Metal.MTLOrigin()
    for attr in ("x", "y", "z"):
        self.assertEqual(getattr(origin, attr), 0)

    size = Metal.MTLSize()
    for attr in ("width", "height", "depth"):
        self.assertEqual(getattr(size, attr), 0)

    # a region is composed of an origin and a size struct
    region = Metal.MTLRegion()
    self.assertIsInstance(region.origin, Metal.MTLOrigin)
    self.assertIsInstance(region.size, Metal.MTLSize)

    sample = Metal.MTLSamplePosition()
    for attr in ("x", "y"):
        self.assertEqual(getattr(sample, attr), 0.0)