def run_all_to_all_3devices():
  """Issue a three-participant all_to_all_v3 collective, one op per device.

  For each participant, a communicator is initialized on that device with the
  shared ``group_key``, ``group_size`` and ``communication`` hint. Then
  ``all_to_all_v3`` is called on a three-element float constant.

  The user-assigned ranks do not follow device order. ``dev0`` joins as rank 1,
  ``dev1`` as rank 0, and ``dev2`` as rank 2.
  NOTE(review): this looks deliberate (it exercises non-sequential rank
  assignment); confirm against the enclosing test before changing it.

  ``dev0``, ``dev1``, ``dev2``, ``group_key``, ``group_size`` and
  ``communication`` are free names resolved from an enclosing scope that is
  not visible here.

  Returns:
    A list of the three ``all_to_all_v3`` outputs, ordered dev0, dev1, dev2.
  """
  # (device, user rank, per-device input) for each participant. Ops are
  # created in exactly this order.
  participants = (
      (dev0, 1, [1.0, 2.0, 3.0]),
      (dev1, 0, [4.0, 5.0, 6.0]),
      (dev2, 2, [7.0, 8.0, 9.0]),
  )
  outputs = []
  for device, rank, payload in participants:
    with ops.device(device):
      handle = _collective_ops.initialize_communicator(
          group_key=group_key,
          rank=rank,
          group_size=group_size,
          communication_hint=communication)
      outputs.append(
          _collective_ops.all_to_all_v3(handle, constant_op.constant(payload)))
  return outputs
def run_all_to_all_2devices():
  """Issue a two-participant all_to_all_v3 collective, one op per device.

  ``dev0`` joins as rank 0 and ``dev1`` as rank 1. Each initializes a
  communicator with the shared ``group_key``, ``group_size`` and
  ``communication`` hint, then calls ``all_to_all_v3`` on a two-element
  input. The inputs are passed as plain Python lists, not wrapped in a
  constant, so conversion is left to the op.

  ``dev0``, ``dev1``, ``group_key``, ``group_size`` and ``communication`` are
  free names resolved from an enclosing scope that is not visible here.

  Returns:
    A list of the two ``all_to_all_v3`` outputs, ordered dev0, dev1.
  """
  # (device, user rank, per-device input) for each participant. Ops are
  # created in exactly this order.
  participants = (
      (dev0, 0, [1.0, 3.0]),
      (dev1, 1, [2.0, 4.0]),
  )
  outputs = []
  for device, rank, payload in participants:
    with ops.device(device):
      handle = _collective_ops.initialize_communicator(
          group_key=group_key,
          rank=rank,
          group_size=group_size,
          communication_hint=communication)
      outputs.append(_collective_ops.all_to_all_v3(handle, payload))
  return outputs