def test_concat_arrays_to_chainerx(self):
    """Concatenate int and float arrays with ChainerX ``native:0`` as target.

    Each case is delegated to ``check_concat_arrays`` with the device
    string, the expected ``backend.ChainerxDevice`` and the expected dtype.
    """
    target = 'native:0'

    def _check(arrays, dtype):
        # A fresh expected device is built for every case.
        expected_device = backend.ChainerxDevice(chainerx.get_device(target))
        self.check_concat_arrays(arrays, target, expected_device, dtype)

    _check(self.int_arrays, numpy.int64)
    _check(self.float_arrays, numpy.float64)
def _chainerx_apply_fallback_preprocess(self, in_data, inputs):
    """Convert ChainerX input arrays into their fallback counterparts.

    ``in_data`` and ``inputs`` are walked in lockstep:

    * a ``None`` array stays ``None``;
    * if the input is a ``variable.Variable`` whose
      ``_chainerx_fallback_array`` is already set, that cached array is used;
    * otherwise the array is converted with ``backend.from_chainerx``.
      For ``Variable`` inputs, the converted array is stored back into
      ``_chainerx_fallback_array``.

    Args:
        in_data: Sequence of ChainerX arrays; entries may be ``None``.
        inputs: Sequence of the inputs corresponding to ``in_data``.

    Returns:
        tuple: ``(chainerx_in_data, fallback_in_data, device)``.
        ``chainerx_in_data`` is the ``in_data`` object that was passed in.
        ``fallback_in_data`` is a tuple of converted arrays, with ``None``
        entries preserved. ``device`` comes from the first non-``None``
        entry: the input's ``device`` when a cached array is reused,
        otherwise a ``backend.ChainerxDevice`` wrapping the array's device.
        It is ``None`` if every entry is ``None``.
    """
    chainerx_in_data = in_data
    converted = []
    device = None

    for chx_array, x in six.moves.zip(chainerx_in_data, inputs):
        if chx_array is None:
            converted.append(None)
            continue

        is_variable = isinstance(x, variable.Variable)
        if is_variable and x._chainerx_fallback_array is not None:
            # Reuse the fallback array already cached on the Variable.
            fallback = x._chainerx_fallback_array
            if device is None:
                device = x.device
        else:
            fallback = backend.from_chainerx(chx_array)
            if device is None:
                device = backend.ChainerxDevice(chx_array.device)
            # Store the conversion so the cached branch above can reuse it.
            if is_variable:
                x._chainerx_fallback_array = fallback

        converted.append(fallback)

    return chainerx_in_data, tuple(converted), device
def test_init(self, backend_config):
    """A ``ChainerxDevice`` built from a ChainerX device keeps that object."""
    chx_device = chainerx.get_device(backend_config.chainerx_device)
    wrapped = backend.ChainerxDevice(chx_device)
    self.check_device(wrapped, backend_config)
    # Identity, not mere equality: the wrapper must hold the same object.
    assert wrapped.device is chx_device
def test_from_fallback_device(self, backend_config):
    """Build a ``ChainerxDevice`` via ``from_fallback_device`` and check it.

    The input fallback device is taken from a temporary ``ChainerxDevice``,
    so this test also depends on ``ChainerxDevice.fallback_device``.
    """
    # Preparation: obtain a fallback device from a temporary wrapper.
    chx_device = chainerx.get_device(backend_config.chainerx_device)
    fallback_device = backend.ChainerxDevice(chx_device).fallback_device

    # Test
    restored = backend.ChainerxDevice.from_fallback_device(fallback_device)
    self.check_device(restored, backend_config)
    assert restored.fallback_device == fallback_device
def device(self):
    """Return the backend device for this configuration, creating it once.

    The first matching flag wins, checked in this order:

    * ``use_cuda``: ``backend.GpuDevice`` for ``cuda_device``;
    * ``use_chainerx``: ``backend.ChainerxDevice`` for ``chainerx_device``;
    * ``use_ideep`` not equal to ``'never'``: ``backend.Intel64Device``;
    * otherwise: ``backend.CpuDevice``.

    The created device is stored in ``self._device``. Any non-``None`` value
    already there is returned unchanged.
    """
    if self._device is not None:
        return self._device

    if self.use_cuda:
        chosen = backend.GpuDevice.from_device_id(self.cuda_device)
    elif self.use_chainerx:
        chosen = backend.ChainerxDevice(
            chainerx.get_device(self.chainerx_device))
    elif self.use_ideep != 'never':
        chosen = backend.Intel64Device()
    else:
        chosen = backend.CpuDevice()

    self._device = chosen
    return self._device
def test_concat_tuples_to_chainerx(self, backend_config):
    """Concatenate tuple examples with the ChainerX ``native:0`` device.

    The ChainerX device object itself is passed as the target. A
    ``backend.ChainerxDevice`` wrapping it is passed as the expected device.
    """
    chx_native = chainerx.get_device('native:0')
    tuple_arrays = self.get_tuple_arrays_to_concat(backend_config)
    expected_device = backend.ChainerxDevice(chx_native)
    self.check_concat_tuples(tuple_arrays, chx_native, expected_device)