Exemplo n.º 1
0
    def WorkerDeviceInModelSplit(self, device_index):
        """Returns the device to use for 'device_index' for the current model split.

        The current split is read from `py_utils.GetModelSplit()`. In the
        flattened `self.available_devices`, split `s` owns the run of
        `self.num_devices_per_split` entries starting at
        `s * self.num_devices_per_split`. `device_index` is taken modulo
        `self.num_devices_per_split`, so indices outside that range wrap
        around within the split.

        Args:
          device_index: An int, the device index within 'model_split'.

        Returns:
          A string. The device to place ops onto, or '' when there are no
          available devices.

        Raises:
          ValueError: if the current model split is negative or not less than
            `self.num_splits_per_client`.
        """
        devices = self.available_devices.reshape([-1]).tolist()
        if not devices:
            return ''
        model_split = py_utils.GetModelSplit()
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and a negative split would otherwise silently index
        # from the end of `devices` via Python's negative indexing.
        if not 0 <= model_split < self.num_splits_per_client:
            raise ValueError(
                'model_split %d out of range [0, %d)' %
                (model_split, self.num_splits_per_client))
        devices_per_split = self.num_devices_per_split
        return devices[devices_per_split * model_split +
                       device_index % devices_per_split]
Exemplo n.º 2
0
 def testModelSplit(self):
   """Nested ModelSplit contexts set the split and restore the outer one on exit."""
   with py_utils.ModelSplit(2):
     assert py_utils.GetModelSplit() == 2
     with py_utils.ModelSplit(3):
       assert py_utils.GetModelSplit() == 3
     # Leaving the inner context must restore the enclosing split (2),
     # not jump straight back to the default.
     assert py_utils.GetModelSplit() == 2
   # Leaving the outermost context restores the default split of 0.
   assert py_utils.GetModelSplit() == 0