# NOTE(review): this chunk was flattened onto a single physical line; it has
# been reformatted for readability only (no token changes).  It is the tail of
# the TorchPythonFunction operator class -- the `class` header and the
# beginning of torch_wrapper are outside this view -- plus the module-level
# fn-API registration at the end.

# --- tail of torch_wrapper (its `def` line is not visible in this chunk) ---
# Per-sample path: wrap `func` so each sample is converted via DLPack to a
# torch tensor before, and back after, the user function runs.
return ops.PythonFunction.function_wrapper_per_sample(
    func, self.num_outputs, torch_dlpack.from_dlpack, torch_dlpack.to_dlpack,
    *args)

def __call__(self, *inputs, **kwargs):
    """Run the operator inside the current pipeline.

    Lazily creates a dedicated CUDA stream on the pipeline's device the first
    time the operator is called, then delegates to the base-class __call__.

    Raises (via ``Pipeline._raise_no_current_pipeline``) when no pipeline is
    current.
    """
    pipeline = Pipeline.current()
    if pipeline is None:
        Pipeline._raise_no_current_pipeline("TorchPythonFunction")
    if self.stream is None:
        # Created here rather than in __init__ because the device id is only
        # known once a pipeline is current.
        self.stream = torch.cuda.Stream(device=pipeline.device_id)
    return super(TorchPythonFunction, self).__call__(*inputs, **kwargs)

def __init__(self, function, num_outputs=1, device='cpu', batch_processing=False,
             **kwargs):
    """Create a DALI operator that runs ``function`` on torch tensors.

    Parameters
    ----------
    function : callable
        User function operating on torch tensors.
    num_outputs : int, default 1
        Number of outputs ``function`` produces.
    device : str, default 'cpu'
        Placement of the operator ('cpu' or 'gpu' -- confirm against the
        base-class docs; only forwarded here).
    batch_processing : bool, default False
        Whether ``function`` receives whole batches or single samples
        (forwarded to torch_wrapper and the base class).
    **kwargs
        Forwarded to the base-class constructor.
    """
    self.stream = None  # created lazily in __call__, once a pipeline exists
    # Route execution through the DLPack-based python-function implementation;
    # torch_wrapper converts the DLPack inputs to/from torch around `function`.
    super(TorchPythonFunction, self).__init__(
        impl_name="DLTensorPythonFunctionImpl",
        function=lambda *ins: self.torch_wrapper(
            batch_processing, function, device, *ins),
        num_outputs=num_outputs, device=device,
        batch_processing=batch_processing, **kwargs)

# Expose the operator in this module's functional (fn) API.
ops._wrap_op(TorchPythonFunction, "fn", __name__)
# NOTE(review): this chunk was flattened onto a single physical line; it has
# been reformatted for readability only (no token changes).  It contains the
# tail of a generated run callback (its `def` is outside this view, inside
# NumbaFunction.__init__), the remainder of __init__'s attribute setup, and
# the module-level fn-API registration.

# --- tail of the generated run callback (definition starts outside this
#     view; `in_shapes_np[4][0])` closes an in4 = in4_lambda(...) call whose
#     opening is not visible here) ---
                         in_shapes_np[4][0])
if num_ins >= 6:
    # Wrap the 6th input buffer (raw address + shape) for the user function.
    in5 = in5_lambda(address_as_void_pointer(in_arr[5]),
                     in_shapes_np[5][0])
# Invoke the user-provided run function with all (up to 6) outputs and
# inputs; slots beyond num_ins presumably hold placeholder values set by the
# unseen part of this generated function -- confirm against full source.
run_fn_lambda(run_fn, out0, out1, out2, out3, out4, out5,
              in0, in1, in2, in3, in4, in5)

# --- remainder of NumbaFunction.__init__ ---
self._impl_name = "NumbaFuncImpl"
self._schema = _b.GetSchema(self._impl_name)
self._spec = _b.OpSpec(self._impl_name)
self._device = device

# Split kwargs into spec-time arguments (added to the OpSpec below) and
# call-time arguments (kept for later use).
kwargs, self._call_args = ops._separate_kwargs(kwargs)

for key, value in kwargs.items():
    self._spec.AddArg(key, value)

# Addresses of the compiled callbacks handed to the backend NumbaFuncImpl.
self.run_fn = run_cfunc.address
self.setup_fn = setup_fn_address
# Type/dimensionality metadata describing the user function's signature.
self.out_types = out_types
self.in_types = in_types
self.outs_ndim = outs_ndim
self.ins_ndim = ins_ndim
self.num_outputs = len(out_types)
self.batch_processing = batch_processing
# NOTE(review): presumably keeps the operator in the graph even when its
# outputs are unused -- confirm against DALI's `preserve` semantics.
self._preserve = True

# Expose the operator under fn.experimental in the numba plugin module.
ops._wrap_op(NumbaFunction, "fn.experimental", "nvidia.dali.plugin.numba")