Exemple #1
0
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
   """Create one replica of ``network`` per device.

   Args:
     network: a ``torch.nn.Module`` instance, or a zero-argument callable that
       returns one. It is built (if callable) exactly once; each device then
       gets a ``deepcopy`` of that single module moved to the device, so all
       replicas start from the same parameters.
     device_ids: iterable of device identifiers accepted by ``torch.device``.
       Defaults to ``xm.get_xla_supported_devices()``. Any iterable works,
       including a one-shot iterator, because it is materialized once.
     batchdim: stored as ``self._batchdim``; presumably the dimension along
       which input batches are split (used outside this constructor).
     drop_last: stored as ``self._drop_last``; used outside this constructor.

   When no device ids are available, the original (uncopied) module is used
   directly. Its device, as reported by ``self._get_model_device``, is recorded
   as the sole device id, and ``self._native_run`` is set to True.
   """
   if device_ids is None:
     device_ids = xm.get_xla_supported_devices()
   # Materialize once. ``device_ids`` may be a one-shot iterator, so every
   # later pass must use this list, never the argument itself.
   self._device_ids = list(device_ids)
   self._batchdim = batchdim
   self._drop_last = drop_last
   self._native_run = False
   if len(self._device_ids) > 1:
     replication_devices = xm.xla_replication_devices(self._device_ids)
     self._replication = xm.Replication(self._device_ids, replication_devices)
   else:
     # With zero or one device, no cross-device replication object is needed.
     self._replication = None
   self._models = []
   self._contexts = []
   module = network if isinstance(network, torch.nn.Module) else network()
   for device in self._device_ids:
     torch_device = torch.device(device)
     # deepcopy so that every replica starts from identical weights.
     self._models.append(deepcopy(module).to(device=torch_device))
     self._contexts.append(Context(torch_device))
   if not self._models:
     # No XLA device, push a vanilla network in.
     device = self._get_model_device(module)
     self._models.append(module)
     self._device_ids.append(device)
     self._contexts.append(Context(torch.device(device)))
     self._native_run = True
Exemple #2
0
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
     """Create one replica of ``network`` per device.

     Args:
         network: a zero-argument callable returning a ``torch.nn.Module``,
             or (additionally accepted) an ``nn.Module`` instance. It is
             built exactly once; each device receives a ``deepcopy`` of that
             module moved to the device, so all replicas share the same
             starting parameters.
         device_ids: iterable of device identifiers accepted by
             ``torch.device``. Defaults to ``xm.get_xla_supported_devices()``.
             It is materialized into a list once, so one-shot iterators work.
         batchdim: stored as ``self._batchdim``; presumably the batch-split
             dimension (used outside this constructor).
         drop_last: stored as ``self._drop_last``; used outside this
             constructor.

     When no device ids are available, the single built module is stored
     unmodified as the only entry of ``self._models``.
     """
     # Local import: this variant did not previously use deepcopy.
     from copy import deepcopy

     if device_ids is None:
         device_ids = xm.get_xla_supported_devices()
     self._batchdim = batchdim
     self._drop_last = drop_last
     # Materialize once. Later loops must use this list, because the argument
     # may be an iterator that list() has already exhausted.
     self._device_ids = list(device_ids)
     self._replication = (xm.Replication(self._device_ids)
                          if self._device_ids else None)
     self._models = []
     # Build the network once and copy it per device. Calling ``network()``
     # per device would create independent instances, which may be
     # initialized with different parameters.
     module = network if isinstance(network, torch.nn.Module) else network()
     for device in self._device_ids:
         self._models.append(deepcopy(module).to(device=torch.device(device)))
     if not self._models:
         # No XLA device, push a vanilla network in.
         self._models.append(module)
Exemple #3
0
 def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
   """Create one replica of ``network`` per device.

   Args:
     network: a ``torch.nn.Module`` instance, or a zero-argument callable that
       returns one. It is built (if callable) exactly once; each device then
       gets a ``deepcopy`` of that module moved to the device.
     device_ids: iterable of device identifiers accepted by ``torch.device``.
       Defaults to ``xm.get_xla_supported_devices()``. It is materialized into
       a list once, so one-shot iterators work.
     batchdim: stored as ``self._batchdim``; presumably the batch-split
       dimension (used outside this constructor).
     drop_last: stored as ``self._drop_last``; used outside this constructor.

   A replication object is created only when ``xm.xla_replication_devices``
   returns a truthy value for a non-empty device list. When no device ids are
   available, the built module itself is stored as the only model.
   """
   if device_ids is None:
     device_ids = xm.get_xla_supported_devices()
   # Materialize once; all later passes use this list, not the argument.
   self._device_ids = list(device_ids)
   self._batchdim = batchdim
   self._drop_last = drop_last
   replication_devices = (
       xm.xla_replication_devices(self._device_ids)
       if self._device_ids else None)
   self._replication = (
       xm.Replication(self._device_ids, replication_devices)
       if replication_devices else None)
   self._models = []
   module = network if isinstance(network, torch.nn.Module) else network()
   for device in self._device_ids:
     # deepcopy so that every replica starts from identical weights.
     self._models.append(deepcopy(module).to(device=torch.device(device)))
   if not self._models:
     # No XLA device, push a vanilla network in. Reuse the module built above:
     # calling ``network()`` here would invoke the module's forward() with no
     # inputs whenever ``network`` is already an nn.Module instance.
     self._models.append(module)