def testContainsIndexedSlices_PerDeviceMapOutput(self):
  """contains_indexed_slices detects IndexedSlices nested inside MapOutputs.

  Builds a PerDevice whose per-device values are MapOutput wrappers, each
  holding one value produced by math_ops._as_indexed_slices, and asserts that
  cross_tower_utils.contains_indexed_slices reports True for it.
  """
  # The gpu operand is created before the cpu operand.
  gpu_slices = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  cpu_slices = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  device_outputs = {
      "/gpu:0": value_lib.MapOutput([gpu_slices]),
      "/cpu:0": value_lib.MapOutput([cpu_slices]),
  }
  per_device = value_lib.PerDevice(device_outputs)
  found = cross_tower_utils.contains_indexed_slices(per_device)
  self.assertTrue(found)
예제 #2
0
 def map(self, map_over, fn, *args, **kwargs):
   """Applies `fn` to every element of `map_over`, spreading calls over devices.

   Element number k runs under
   ``ops.device(self._devices[k % len(self._devices)])``, so devices are
   assigned round-robin in input order. Before each call the extra positional
   and keyword arguments are resolved for that device with
   `values.select_device_mirrored`.

   Args:
     map_over: Iterable of inputs; each one is passed as the first argument
       to `fn`.
     fn: Callable invoked once per element.
     *args: Extra positional arguments, resolved per device.
     **kwargs: Extra keyword arguments, resolved per device.

   Returns:
     A `values.PerDevice` keyed by device, holding a `values.MapOutput` of the
     results produced on that device, in input order. Devices that received
     no elements do not appear.
   """
   # TODO(josh11b): In eager mode, use one thread per device.
   results_by_device = {}
   for position, element in enumerate(map_over):
     device = self._devices[position % len(self._devices)]
     with ops.device(device):
       outcome = fn(element,
                    *values.select_device_mirrored(device, args),
                    **values.select_device_mirrored(device, kwargs))
       results_by_device.setdefault(device, []).append(outcome)
   # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput
   # in addition to PerDevice data.
   wrapped = {}
   for device, outputs in results_by_device.items():
     wrapped[device] = values.MapOutput(outputs)
   return values.PerDevice(wrapped)
예제 #3
0
 def map(self, map_over, fn, *args, **kwargs):
     """Applies `fn` to every element of `map_over`, round-robin over devices.

     Element number k of `map_over` is processed under
     ``ops.device(self._devices[k % len(self._devices)])``. Before each call
     the extra positional and keyword arguments are resolved for that device
     with `values.select_device_mirrored`.

     Args:
         map_over: Iterable of inputs; each one is passed as the first
             argument to `fn`.
         fn: Callable invoked once per element.
         *args: Extra positional arguments, resolved per device.
         **kwargs: Extra keyword arguments, resolved per device.

     Returns:
         A `values.PerDevice` keyed by device, holding a `values.MapOutput` of
         the results produced on that device, in input order. Devices that
         received no elements do not appear.
     """
     # TODO (josh11b): In eager mode, use one thread per device. id:1098
     # https://github.com/imdone/tensorflow/issues/1099
     index = {}
     # enumerate supplies the element position that drives the round-robin
     # device choice; a hand-maintained counter must be advanced every
     # iteration or all elements would land on self._devices[0].
     for i, m in enumerate(map_over):
         d = self._devices[i % len(self._devices)]
         with ops.device(d):
             result = fn(m, *values.select_device_mirrored(d, args),
                         **values.select_device_mirrored(d, kwargs))
             index.setdefault(d, []).append(result)
     # TODO (josh11b): Need a values.regroup equivalent that handles MapOutput
     # in addition to PerDevice data. id:1079
     # https://github.com/imdone/tensorflow/issues/1080
     return values.PerDevice(
         {k: values.MapOutput(v)
          for k, v in index.items()})
예제 #4
0
 def map(self, map_over, fn, *args, **kwargs):
     """Applies `fn` to each element of `map_over` on `self._device`.

     Every call ``fn(m, *args, **kwargs)`` runs inside
     ``ops.device(self._device)``; `args` and `kwargs` are forwarded
     unchanged.

     Returns:
         A `values.MapOutput` wrapping the list of results, in input order.
     """
     with ops.device(self._device):
         outputs = []
         for element in map_over:
             outputs.append(fn(element, *args, **kwargs))
         return values.MapOutput(outputs)