def get_post_init_ops(self):
  """Broadcast initialized values of variables to other devices.

  Returns:
    At task 0 device 0, broadcast_send.
    At all other devices and tasks, broadcast_recv.
  """
  global_vars = tf.global_variables()
  post_init_ops = []
  # Gather variables into same-var-different-device groups.
  vars_by_suffix = dict()
  for v in global_vars:
    split_name = v.name.split('/')
    mo = re.match(r'v(\d+)$', split_name[0])
    if mo:
      # Strip the per-device 'vN' prefix so replicas of the same variable
      # share a suffix.
      suffix = '/'.join(split_name[1:])
      if suffix in vars_by_suffix:
        vars_by_suffix[suffix].append(v)
      else:
        vars_by_suffix[suffix] = [v]
  # Generate broadcast ops for each such group.
  for suffix in sorted(vars_by_suffix):
    vlist = vars_by_suffix[suffix]
    assert self._num_gpus == len(vlist)
    devices = [v.device for v in vlist]
    # NOTE: this key should generate the same value for all tasks.
    group_key = allreduce.collective_group_key(devices)
    group_size = self._num_workers * len(devices)
    instance_key = self._get_instance_key(suffix)
    for v in vlist:
      split_name = v.name.split('/')
      mo = re.match(r'v(\d+)$', split_name[0])
      if mo:
        device_id = int(mo.group(1))
        if self._task_id == 0 and device_id == 0:
          # The designated source re-assigns itself the broadcast result so
          # that every group member participates in the collective op.
          with tf.device(v.device):
            bcast_send = allreduce.broadcast_send(
                v, v.shape, v.dtype, group_size, group_key, instance_key)
            post_init_ops.append(v.assign(bcast_send))
        else:
          with tf.device(v.device):
            bcast_recv = allreduce.broadcast_recv(
                v.shape, v.dtype, group_size, group_key, instance_key)
            post_init_ops.append(v.assign(bcast_recv))
  return post_init_ops
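A minimal sketch of how these post-init ops might be driven, assuming a TF1-style graph and Session; the run_post_init helper and its setup are illustrative assumptions, not part of the variable manager above:

import tensorflow as tf

def run_post_init(var_mgr):
  # Hypothetical driver: initialize variables locally first, then run the
  # broadcast ops so every task and device converges to task 0's values.
  post_init_ops = var_mgr.get_post_init_ops()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Every task must run these ops: each collective broadcast blocks
    # until all group_size participants have joined.
    sess.run(post_init_ops)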
def testGroupKey(self):
  # Same device set on different tasks: the task should not affect the key.
  d0 = [
      '/job:worker/replica:0/task:0/device:GPU:1',
      '/job:worker/replica:0/task:0/device:GPU:0',
      '/job:worker/replica:0/task:0/device:GPU:3',
  ]
  d1 = [
      '/job:worker/replica:0/task:1/device:GPU:1',
      '/job:worker/replica:0/task:1/device:GPU:0',
      '/job:worker/replica:0/task:1/device:GPU:3',
  ]
  # Same device set in a different order: ordering should not affect the key.
  d2 = [
      '/job:worker/replica:0/task:1/device:GPU:1',
      '/job:worker/replica:0/task:1/device:GPU:3',
      '/job:worker/replica:0/task:1/device:GPU:0',
  ]
  # A different device set (GPU 2 instead of GPU 0): the key should change.
  d3 = [
      '/job:worker/replica:0/task:1/device:GPU:1',
      '/job:worker/replica:0/task:1/device:GPU:3',
      '/job:worker/replica:0/task:1/device:GPU:2',
  ]
  # Same device set as d3, written without the replica component.
  d4 = [
      '/job:worker/task:0/device:GPU:1',
      '/job:worker/task:0/device:GPU:2',
      '/job:worker/task:0/device:GPU:3',
  ]
  # CPU devices: permutations of the same set share a key, but a CPU set
  # must not collide with a GPU set.
  d5 = [
      '/job:worker/task:0/device:CPU:1',
      '/job:worker/task:0/device:CPU:2',
  ]
  d6 = [
      '/job:worker/task:0/device:CPU:2',
      '/job:worker/task:0/device:CPU:1',
  ]
  g0 = allreduce.collective_group_key(d0)
  g1 = allreduce.collective_group_key(d1)
  g2 = allreduce.collective_group_key(d2)
  g3 = allreduce.collective_group_key(d3)
  g4 = allreduce.collective_group_key(d4)
  g5 = allreduce.collective_group_key(d5)
  g6 = allreduce.collective_group_key(d6)
  self.assertEqual(g0, g1)
  self.assertEqual(g0, g2)
  self.assertNotEqual(g0, g3)
  self.assertEqual(g3, g4)
  self.assertEqual(g5, g6)
  self.assertNotEqual(g4, g5)
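These assertions pin down the contract of allreduce.collective_group_key: the key depends only on the set of (device type, device index) pairs, ignoring job/replica/task fields and list order. Below is a minimal sketch of one implementation that would satisfy the test; the hash-based scheme and the name collective_group_key_sketch are assumptions, not the library's actual code:

import hashlib
import tensorflow as tf

def collective_group_key_sketch(devices):
  # Canonicalize: keep only device type and index, sorted so permutations
  # of the same device set map to the same string.
  parsed = [tf.DeviceSpec.from_string(d) for d in devices]
  names = sorted('%s:%d' % (d.device_type, d.device_index) for d in parsed)
  # Hash the canonical form to a stable integer key. (A real implementation
  # might instead assign keys from a table; this is only a sketch.)
  return int(hashlib.md5(','.join(names).encode('utf-8')).hexdigest()[:8], 16)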