def WorkerDeviceInModelSplit(self, device_index):
  """Maps a per-split device index to a device for the current model split.

  All available devices are flattened into one list. The current model split
  (from py_utils.GetModelSplit()) selects a contiguous window of
  self.num_devices_per_split entries in that list, starting at
  num_devices_per_split * model_split. device_index is wrapped modulo the
  window size before indexing into it.

  Args:
    device_index: An int, the device index within 'model_split'.

  Returns:
    A string. The device to place ops onto, or '' when there are no
    available devices.
  """
  flat_devices = self.available_devices.reshape([-1]).tolist()
  # No devices configured: nothing to place onto.
  if not flat_devices:
    return ''

  split = py_utils.GetModelSplit()
  # NOTE(review): this range check is an `assert`, so it is skipped when
  # Python runs with -O.
  assert split < self.num_splits_per_client, (
      '%d %d' % (split, self.num_splits_per_client))

  per_split = self.num_devices_per_split
  window_start = per_split * split
  return flat_devices[window_start + device_index % per_split]
def testModelSplit(self):
  """ModelSplit(n) makes GetModelSplit() return n, also when nested.

  After the ModelSplit scopes exit, GetModelSplit() is expected to be 0.
  """
  with py_utils.ModelSplit(2):
    outer_split = py_utils.GetModelSplit()
    assert outer_split == 2
    with py_utils.ModelSplit(3):
      inner_split = py_utils.GetModelSplit()
      assert inner_split == 3
  split_after_scopes = py_utils.GetModelSplit()
  assert split_after_scopes == 0