Example #1
 def __init__(self):
     default_device = '/' + os.path.basename(tf.zeros(()).device)
     default_device = ComputeDevice(self, default_device,
                                    default_device.split(":")[-2], -1, -1,
                                    "", default_device)
     for device in self.list_devices():
         if device.name == default_device.name:
             default_device = device
     Backend.__init__(self, "TensorFlow", default_device)
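A minimal standalone sketch (not part of the backend above; TensorFlow 2.x eager mode assumed) of the device-string parsing that Example #1 performs to pick the default device. The variable names below are illustrative only.

 import os
 import tensorflow as tf

 # tf.zeros(()).device reports where TensorFlow placed the tensor,
 # e.g. '/job:localhost/replica:0/task:0/device:CPU:0'
 full_name = tf.zeros(()).device
 short_name = '/' + os.path.basename(full_name)   # -> '/device:CPU:0'
 device_type = short_name.split(":")[-2]          # -> 'CPU'
 print(full_name, short_name, device_type)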
Example #2
 def list_devices(self, device_type: str or None = None) -> List[ComputeDevice]:
     types = ['cpu', 'gpu', 'tpu'] if device_type is None else [device_type.lower()]
     devices = []
     for device_type in types:
         try:
             for jax_dev in jax.devices(device_type):
                 devices.append(ComputeDevice(self, jax_dev.device_kind, jax_dev.platform.upper(), -1, -1, f"id={jax_dev.id}", jax_dev))
         except RuntimeError as err:
             pass  # this is just Jax not finding anything. jaxlib.xla_client._get_local_backends() could help but isn't currently available on GitHub actions
     return devices
Example #3
 def list_devices(self,
                  device_type: str or None = None) -> List[ComputeDevice]:
     devices = []
     for jax_dev in jax.devices():
         jax_dev_type = jax_dev.platform.upper()
         if device_type is None or device_type == jax_dev_type:
             description = f"id={jax_dev.id}"
             devices.append(
                 ComputeDevice(self, jax_dev.device_kind, jax_dev_type, -1,
                               -1, description, jax_dev))
     return devices
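Both Jax variants (Examples #2 and #3) read the same three attributes from jax.devices(). A minimal standalone sketch of that call, not taken from the library above:

 import jax

 # jax.devices() lists devices of the default platform; passing 'cpu',
 # 'gpu' or 'tpu' (as in Example #2) restricts it to one platform and
 # raises RuntimeError if that platform is unavailable.
 for jax_dev in jax.devices():
     print(jax_dev.platform.upper(),   # e.g. 'CPU' or 'GPU'
           jax_dev.device_kind,        # human-readable kind of device
           jax_dev.id)                 # integer id used in the description string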
Example #4
 def list_devices(self,
                  device_type: str or None = None) -> List[ComputeDevice]:
     tf_devices = device_lib.list_local_devices()
     devices = []
     for device in tf_devices:
         if device_type in (None, device.device_type):
             devices.append(
                 ComputeDevice(self,
                               device.name,
                               device.device_type,
                               device.memory_limit,
                               processor_count=-1,
                               description=str(device),
                               ref=tf.device(device.name)))
     return devices
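A minimal standalone sketch of the device listing used in Example #4; device_lib lives in TensorFlow's non-public tensorflow.python.client module, which the example itself relies on.

 import tensorflow as tf
 from tensorflow.python.client import device_lib

 for dev in device_lib.list_local_devices():
     print(dev.name,          # e.g. '/device:CPU:0' or '/device:GPU:0'
           dev.device_type,   # 'CPU', 'GPU', ...
           dev.memory_limit)  # memory limit in bytes
     with tf.device(dev.name):   # the same handle Example #4 stores as ref
         tf.zeros(())            # ops created here are pinned to dev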
Example #5
 def list_devices(self, device_type: str or None = None) -> List[ComputeDevice]:
     devices = []
     if device_type in (None, 'CPU'):
         devices.append(self.cpu)
     if device_type in (None, 'GPU'):
         for index in range(torch.cuda.device_count()):
             properties = torch.cuda.get_device_properties(index)
             devices.append(ComputeDevice(self,
                                          properties.name,
                                          'GPU',
                                          properties.total_memory,
                                          properties.multi_processor_count,
                                          f"compute capability {properties.major}.{properties.minor}",
                                          ref=f'cuda:{index}'))
     return devices
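A minimal standalone sketch of the torch.cuda queries behind Example #5; the 'cuda:{index}' strings are the refs it stores.

 import torch

 if torch.cuda.is_available():
     for index in range(torch.cuda.device_count()):
         props = torch.cuda.get_device_properties(index)
         print(f'cuda:{index}',
               props.name,                   # e.g. 'NVIDIA GeForce RTX 3090'
               props.total_memory,           # device memory in bytes
               props.multi_processor_count,  # number of streaming multiprocessors
               f'compute capability {props.major}.{props.minor}')
 else:
     print('no CUDA devices; only the CPU device is available')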
Example #6
 def __init__(self):
     self.cpu = ComputeDevice(self, "CPU", 'CPU', -1, -1, "", ref='cpu')
     Backend.__init__(self, 'PyTorch', default_device=self.cpu)
Example #7
 def __init__(self):
     cpu = NUMPY.cpu
     self.cpu = ComputeDevice(self, "CPU", 'CPU', cpu.memory, cpu.processor_count, cpu.description, ref='cpu')
     gpus = self.list_devices('GPU')
     Backend.__init__(self, 'PyTorch', default_device=gpus[0] if gpus else cpu)