Example #1: TorchPythonFunction (nvidia.dali.plugin.pytorch)
            # Per-sample mode: each DALI tensor is converted to a torch.Tensor
            # via DLPack, passed to func, and the results are converted back.
            return ops.PythonFunction.function_wrapper_per_sample(
                func, self.num_outputs, torch_dlpack.from_dlpack,
                torch_dlpack.to_dlpack, *args)

    def __call__(self, *inputs, **kwargs):
        pipeline = Pipeline.current()
        if pipeline is None:
            Pipeline._raise_no_current_pipeline("TorchPythonFunction")
        # Lazily create a dedicated CUDA stream on the pipeline's device;
        # the wrapped torch function is executed on this stream.
        if self.stream is None:
            self.stream = torch.cuda.Stream(device=pipeline.device_id)
        return super(TorchPythonFunction, self).__call__(*inputs, **kwargs)

    def __init__(self,
                 function,
                 num_outputs=1,
                 device='cpu',
                 batch_processing=False,
                 **kwargs):
        self.stream = None
        # Delegate to the DLPack-based Python-function operator, wrapping the
        # user function so that it receives and returns torch tensors.
        super(TorchPythonFunction,
              self).__init__(impl_name="DLTensorPythonFunctionImpl",
                             function=lambda *ins: self.torch_wrapper(
                                 batch_processing, function, device, *ins),
                             num_outputs=num_outputs,
                             device=device,
                             batch_processing=batch_processing,
                             **kwargs)


ops._wrap_op(TorchPythonFunction, "fn", __name__)
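
A minimal usage sketch, assuming the wrapper registered above is exposed as nvidia.dali.plugin.pytorch.fn.torch_python_function; the file_root path and the flip function are placeholder choices. Python-function operators require synchronous execution, hence exec_async=False and exec_pipelined=False.

import torch
from nvidia.dali import pipeline_def, fn
import nvidia.dali.plugin.pytorch as dali_torch

@pipeline_def(batch_size=4, num_threads=2, device_id=0,
              exec_async=False, exec_pipelined=False)
def torch_func_pipeline():
    # Placeholder data source; any image-producing operator works here.
    jpegs, _ = fn.readers.file(file_root="images/")
    images = fn.decoders.image(jpegs, device="mixed")
    # The wrapped function receives each sample as a torch.Tensor (converted
    # via DLPack) and must return torch tensors; here it flips vertically.
    return dali_torch.fn.torch_python_function(
        images,
        function=lambda t: torch.flip(t, dims=[0]),
        num_outputs=1,
        device="gpu")

Building and running this pipeline yields regular DALI GPU batches whose contents were produced by the torch function.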
Example #2: NumbaFunction (nvidia.dali.plugin.numba)
                                     in_shapes_np[4][0])
                if num_ins >= 6:
                    in5 = in5_lambda(address_as_void_pointer(in_arr[5]),
                                     in_shapes_np[5][0])

                # Call the user-supplied Numba run function with the wrapped
                # output and input arrays.
                run_fn_lambda(run_fn, out0, out1, out2, out3, out4, out5, in0,
                              in1, in2, in3, in4, in5)

        self._impl_name = "NumbaFuncImpl"
        self._schema = _b.GetSchema(self._impl_name)
        self._spec = _b.OpSpec(self._impl_name)
        self._device = device

        # Separate the arguments meant for the operator spec from the ones
        # that have to be passed at call time.
        kwargs, self._call_args = ops._separate_kwargs(kwargs)

        for key, value in kwargs.items():
            self._spec.AddArg(key, value)

        # Store the addresses of the compiled Numba cfuncs together with the
        # type/ndim metadata required by the backend implementation.
        self.run_fn = run_cfunc.address
        self.setup_fn = setup_fn_address
        self.out_types = out_types
        self.in_types = in_types
        self.outs_ndim = outs_ndim
        self.ins_ndim = ins_ndim
        self.num_outputs = len(out_types)
        self.batch_processing = batch_processing
        self._preserve = True


ops._wrap_op(NumbaFunction, "fn.experimental", "nvidia.dali.plugin.numba")
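
A minimal usage sketch, assuming the wrapper registered above is exposed as nvidia.dali.plugin.numba.fn.experimental.numba_function; the external_source data and the doubling run function are placeholder choices.

import numpy as np
from nvidia.dali import pipeline_def, fn
import nvidia.dali.types as dali_types
from nvidia.dali.plugin.numba.fn.experimental import numba_function

def double_sample(out0, in0):
    # Per-sample run function (batch_processing=False): write into out0.
    out0[:] = 2 * in0[:]

@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def numba_func_pipeline():
    data = fn.external_source(
        source=lambda: [np.full((16,), i, dtype=np.uint8) for i in range(4)],
        batch=True)
    return numba_function(
        data,
        run_fn=double_sample,
        out_types=[dali_types.UINT8],
        in_types=[dali_types.UINT8],
        outs_ndim=[1],
        ins_ndim=[1],
        batch_processing=False)

Without a setup_fn, each output is assumed to keep the shape of the corresponding input, which is why double_sample can fill out0 directly.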