def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate ``network`` onto each device and create a Context per device.

    Args:
      network: Either a ``torch.nn.Module`` instance or a callable that builds
        one when invoked with no arguments. It is instantiated (if needed)
        exactly once and then deep-copied onto every device, so all replicas
        start from identical parameters.
      device_ids: Iterable of device identifiers accepted by
        ``torch.device``. Defaults to ``xm.get_xla_supported_devices()``.
        The iterable is materialized once, so one-shot iterators are fine.
      batchdim: Stored as ``self._batchdim``; presumably the input dimension
        along which batches are split across devices — confirm against the
        rest of the class.
      drop_last: Stored as ``self._drop_last``; presumably whether an
        incomplete trailing batch is dropped — confirm against callers.

    With more than one device, ``self._replication`` is an
    ``xm.Replication``; otherwise it is None. When no device is available,
    the un-copied module is used as-is on the device reported by
    ``self._get_model_device`` and ``self._native_run`` is set to True.
    """
    if device_ids is None:
        # NOTE(review): `or []` guards against the device query returning
        # None (rather than an empty list) when nothing is found, so that
        # case reaches the native-run fallback below instead of crashing in
        # list(None). Confirm against the torch_xla version in use.
        device_ids = xm.get_xla_supported_devices() or []
    # Materialize once; everything below reads this list rather than the
    # raw argument so a generator is not silently exhausted.
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    self._native_run = False
    # Cross-device replication is only set up with more than one device.
    if len(self._device_ids) > 1:
        replication_devices = xm.xla_replication_devices(self._device_ids)
        self._replication = xm.Replication(self._device_ids, replication_devices)
    else:
        self._replication = None
    self._models = []
    self._contexts = []
    module = network if isinstance(network, torch.nn.Module) else network()
    for device in self._device_ids:
        torch_device = torch.device(device)
        self._models.append(deepcopy(module).to(device=torch_device))
        self._contexts.append(Context(torch_device))
    if not self._models:
        # No XLA device, push a vanilla network in.
        device = self._get_model_device(module)
        self._models.append(module)
        self._device_ids.append(device)
        self._contexts.append(Context(torch.device(device)))
        self._native_run = True
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate ``network`` onto each device.

    Args:
      network: A callable that builds a ``torch.nn.Module`` when invoked with
        no arguments, or an already-built ``torch.nn.Module`` instance. The
        network is built once and deep-copied onto every device, so all
        replicas start from identical parameters (previously each device
        called ``network()`` separately and got its own initialization).
      device_ids: Iterable of device identifiers accepted by
        ``torch.device``. Defaults to ``xm.get_xla_supported_devices()``.
        The iterable is materialized once, so one-shot iterators are fine.
      batchdim: Stored as ``self._batchdim``; presumably the input dimension
        along which batches are split across devices — confirm.
      drop_last: Stored as ``self._drop_last``; presumably whether an
        incomplete trailing batch is dropped — confirm.

    ``self._replication`` is an ``xm.Replication`` whenever at least one
    device is present, else None. With no device, the single un-copied
    module is used as-is.
    """
    # Local import keeps this block self-contained.
    from copy import deepcopy

    if device_ids is None:
        # NOTE(review): `or []` guards against the device query returning
        # None when nothing is found — confirm against torch_xla in use.
        device_ids = xm.get_xla_supported_devices() or []
    self._batchdim = batchdim
    self._drop_last = drop_last
    # Materialize once and iterate this list, not the raw argument, so a
    # generator is not exhausted before the replicas are built.
    self._device_ids = list(device_ids)
    self._replication = (
        xm.Replication(self._device_ids) if self._device_ids else None)
    module = network if isinstance(network, torch.nn.Module) else network()
    self._models = [
        deepcopy(module).to(device=torch.device(device))
        for device in self._device_ids
    ]
    if not self._models:
        # No XLA device, push a vanilla network in.
        self._models.append(module)
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate ``network`` onto each device.

    Args:
      network: Either a ``torch.nn.Module`` instance or a callable that builds
        one when invoked with no arguments. It is instantiated (if needed)
        exactly once and deep-copied onto every device, so all replicas start
        from identical parameters.
      device_ids: Iterable of device identifiers accepted by
        ``torch.device``. Defaults to ``xm.get_xla_supported_devices()``.
        The iterable is materialized once, so one-shot iterators are fine.
      batchdim: Stored as ``self._batchdim``; presumably the input dimension
        along which batches are split across devices — confirm.
      drop_last: Stored as ``self._drop_last``; presumably whether an
        incomplete trailing batch is dropped — confirm.

    ``self._replication`` is an ``xm.Replication`` only when
    ``xm.xla_replication_devices`` returns a truthy result, else None. With
    no device, the single un-copied module is used as-is.
    """
    if device_ids is None:
        # NOTE(review): `or []` guards against the device query returning
        # None when nothing is found — confirm against torch_xla in use.
        device_ids = xm.get_xla_supported_devices() or []
    # Materialize once; iterate this list below, not the raw argument.
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    replication_devices = (
        xm.xla_replication_devices(self._device_ids)
        if self._device_ids else None)
    self._replication = (
        xm.Replication(self._device_ids, replication_devices)
        if replication_devices else None)
    self._models = []
    module = network if isinstance(network, torch.nn.Module) else network()
    for device in self._device_ids:
        device_module = deepcopy(module).to(device=torch.device(device))
        self._models.append(device_module)
    if not self._models:
        # No XLA device, push a vanilla network in. Reuse the module built
        # above: calling `network()` here would invoke forward() with no
        # arguments when `network` is already an nn.Module instance.
        self._models.append(module)