# NOTE: assumes `re`, `tensorflow as tf`, and the benchmark's `allreduce`
# helper module are imported in the enclosing file.
def get_post_init_ops(self):
        """Broadcast initialized values of variables to other devices.

    Returns:
      At task 0 device 0, broadcast_send.
      At all other devices and tasks, broadcast_recv.
    """
        global_vars = tf.global_variables()
        post_init_ops = []
        # Gather variables into same-var-different-device groups.
        vars_by_suffix = dict()
        for v in global_vars:
            split_name = v.name.split('/')
            # Replicated variables are scoped as 'v<device_id>/...'; group the
            # per-device copies of each variable by the shared suffix.
            mo = re.match(r'v(\d+)$', split_name[0])
            if mo:
                suffix = '/'.join(split_name[1:])
                vars_by_suffix.setdefault(suffix, []).append(v)
        # Generate broadcast ops for each such group.
        for suffix in sorted(vars_by_suffix):
            vlist = vars_by_suffix[suffix]
            assert self._num_gpus == len(vlist)
            devices = [v.device for v in vlist]
            # NOTE: this key should generate the same value for all tasks
            group_key = allreduce.collective_group_key(devices)
            group_size = self._num_workers * len(devices)
            instance_key = self._get_instance_key(suffix)
            for v in vlist:
                split_name = v.name.split('/')
                mo = re.match(r'v(\d+)$', split_name[0])
                if mo:
                    device_id = int(mo.group(1))
                    if self._task_id == 0 and device_id == 0:
                        with tf.device(v.device):
                            bcast_send = allreduce.broadcast_send(
                                v, v.shape, v.dtype, group_size, group_key,
                                instance_key)
                            post_init_ops.append(v.assign(bcast_send))
                    else:
                        with tf.device(v.device):
                            bcast_recv = allreduce.broadcast_recv(
                                v.shape, v.dtype, group_size, group_key,
                                instance_key)
                            post_init_ops.append(v.assign(bcast_recv))
        return post_init_ops
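For context, here is a minimal usage sketch showing how the returned ops are typically consumed. It assumes a hypothetical `variable_mgr` instance of the enclosing class and simplified TF1-style graph execution; a real multi-worker run would point each task's session at its own server target.

import tensorflow as tf

post_init_ops = variable_mgr.get_post_init_ops()
with tf.Session() as sess:
    # Each task first initializes its own copies of the variables...
    sess.run(tf.global_variables_initializer())
    # ...then task 0 / device 0 broadcasts its values and all others receive.
    sess.run(post_init_ops)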
Example #2
def testGroupKey(self):
    d0 = [
        '/job:worker/replica:0/task:0/device:GPU:1',
        '/job:worker/replica:0/task:0/device:GPU:0',
        '/job:worker/replica:0/task:0/device:GPU:3',
    ]
    d1 = [
        '/job:worker/replica:0/task:1/device:GPU:1',
        '/job:worker/replica:0/task:1/device:GPU:0',
        '/job:worker/replica:0/task:1/device:GPU:3',
    ]
    d2 = [
        '/job:worker/replica:0/task:1/device:GPU:1',
        '/job:worker/replica:0/task:1/device:GPU:3',
        '/job:worker/replica:0/task:1/device:GPU:0',
    ]
    d3 = [
        '/job:worker/replica:0/task:1/device:GPU:1',
        '/job:worker/replica:0/task:1/device:GPU:3',
        '/job:worker/replica:0/task:1/device:GPU:2',
    ]
    d4 = [
        '/job:worker/task:0/device:GPU:1',
        '/job:worker/task:0/device:GPU:2',
        '/job:worker/task:0/device:GPU:3',
    ]
    d5 = [
        '/job:worker/task:0/device:CPU:1',
        '/job:worker/task:0/device:CPU:2'
    ]
    d6 = [
        '/job:worker/task:0/device:CPU:2',
        '/job:worker/task:0/device:CPU:1'
    ]
    g0 = allreduce.collective_group_key(d0)
    g1 = allreduce.collective_group_key(d1)
    g2 = allreduce.collective_group_key(d2)
    g3 = allreduce.collective_group_key(d3)
    g4 = allreduce.collective_group_key(d4)
    g5 = allreduce.collective_group_key(d5)
    g6 = allreduce.collective_group_key(d6)
    self.assertEqual(g0, g1)
    self.assertEqual(g0, g2)
    self.assertNotEqual(g0, g3)
    self.assertEqual(g3, g4)
    self.assertEqual(g5, g6)
    self.assertNotEqual(g4, g5)
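The assertions above pin down the contract rather than the implementation: the group key must depend only on the set of device types and indices, not on their ordering or on the job/replica/task fields. Below is a simplified sketch of a key function that satisfies those assertions; the names `group_key_sketch` and `_group_key_table` are illustrative and not the actual `allreduce.collective_group_key` implementation.

import tensorflow as tf

_group_key_table = {}  # canonical device-set string -> small integer key

def group_key_sketch(devices):
    """Return the same integer for any ordering of the same device set."""
    specs = [tf.DeviceSpec.from_string(d) for d in devices]
    # Drop job/replica/task, keep only device type and index, and sort so
    # that permutations of the same device set canonicalize identically.
    names = sorted('%s:%d' % (s.device_type, s.device_index) for s in specs)
    canonical = ','.join(names)
    if canonical not in _group_key_table:
        _group_key_table[canonical] = len(_group_key_table) + 1
    return _group_key_table[canonical]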