def get_device_from_array(*arrays):
    """Returns the device that the given arrays live on.

    Each array is inspected in turn, and the device of the first one that is
    recognized as a non-NumPy array is returned. If no array is recognized,
    the CPU device is returned.

    .. note::

        Unlike :func:`get_array_module`, this method does not recognize
        :class:`~chainer.Variable` objects.
        To get the device of a :class:`~chainer.Variable` instance ``v``,
        call ``get_device_from_array(v.array)`` instead.

    Args:
        arrays (array or list of arrays):
            Arrays to determine the device from. If multiple arrays are
            given, the device corresponding to the first array which is not
            a NumPy array is returned.

    Returns:
        chainer.Device: Device instance.

    """
    # NOTE(review): another definition of this function appears later in
    # this source and would shadow this one -- confirm which is intended.
    for candidate in arrays:
        # Probe order matters: GPU first, then ChainerX, then Intel64.
        gpu_device = GpuDevice.from_array(candidate)
        if gpu_device is not None:
            return gpu_device

        if isinstance(candidate, chainerx.ndarray):
            return ChainerxDevice(candidate.device)

        intel_device = Intel64Device.from_array(candidate)
        if intel_device is not None:
            return intel_device

    # Nothing matched a non-NumPy backend, so fall back to the CPU device.
    return CpuDevice()
def get_device_from_array(*arrays):
    """Returns the device that the given arrays live on.

    For each array, the GPU, ChainerX and Intel64 backends are probed in
    that order. The first device found is returned. If no array belongs to
    any of those backends, the CPU device is returned.

    .. note::

        Unlike :func:`get_array_module`, this method does not recognize
        :class:`~chainer.Variable` objects.
        To get the device of a :class:`~chainer.Variable` instance ``v``,
        call ``get_device_from_array(v.array)`` instead.

    Args:
        arrays (array or list of arrays):
            Arrays to determine the device from. If multiple arrays are
            given, the device corresponding to the first array which is not
            a NumPy array is returned.

    Returns:
        chainer.backend.Device: Device instance.

    """
    def _probe_chainerx(arr):
        # ChainerX arrays carry their device directly; wrap it.
        if isinstance(arr, chainerx.ndarray):
            return ChainerxDevice(arr.device)
        return None

    for candidate in arrays:
        # Each probe yields a device or None; the first hit wins.
        for probe in (GpuDevice.from_array,
                      _probe_chainerx,
                      Intel64Device.from_array):
            found = probe(candidate)
            if found is not None:
                return found

    # No array matched a non-NumPy backend.
    return CpuDevice()
def _get_device_compat(device_spec):
    # Backward-compatible variant of ``get_device``.
    # A value of one of ``cuda._integer_types`` is treated as a legacy CUDA
    # device index: a negative value selects the CPU (NumPy) device, and a
    # non-negative one selects that GPU. Any other spec is forwarded to
    # ``get_device`` unchanged.
    # Returns chainer.Device.
    if not isinstance(device_spec, cuda._integer_types):
        return get_device(device_spec)
    if device_spec >= 0:
        return GpuDevice.from_device_id(device_spec)
    return CpuDevice()