Example #1
import sys

import torch
import torch_xla.core.xla_model as xm


def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) == 'TPU':
        slots_per_device = 4
        size = slots_per_device * xm.xrt_world_size()
        ordinal = xm.get_ordinal()
        # Each core contributes a tensor filled with its own ordinal.
        value = torch.tensor([ordinal] * size,
                             dtype=torch.int32,
                             device=device)
        result_tensor = xm.all_to_all(value,
                                      split_dimension=0,
                                      concat_dimension=0,
                                      split_count=xm.xrt_world_size())

        # After the exchange, slots [i*4, (i+1)*4) on every core should hold
        # core i's ordinal.
        result = result_tensor.cpu().tolist()
        for i in range(0, xm.xrt_world_size()):
            expected = [i] * slots_per_device
            if expected != result[i * slots_per_device:(i + 1) *
                                  slots_per_device]:
                print('Wrong result from core {}: {}'.format(i, result),
                      file=sys.stderr)
                sys.exit(1)
    else:
        print('Default device {} is not a TPU device'.format(device),
              file=sys.stderr)
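
For context, a minimal launch sketch (not part of the original example): torch_xla's multiprocessing entry point spawns _mp_fn once per TPU core and passes the process index as its first argument.

# Minimal launch sketch, assuming the torch_xla multiprocessing entry point.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # Starts one process per available TPU core; each runs _mp_fn(index).
    xmp.spawn(_mp_fn, args=())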
Example #2
def all_to_all(self, collectiveArgs, retFlag=False):
    # Positional arguments to xm.all_to_all: input tensor, split_dimension,
    # concat_dimension, split_count (the world size).
    retObj = xm.all_to_all(collectiveArgs.ipTensor, 0, 0,
                           collectiveArgs.world_size)
    collectiveArgs.opTensor = retObj
    if collectiveArgs.asyncOp:
        # Record the result so the caller can synchronize on it later.
        collectiveArgs.waitObj.append(retObj)
    if retFlag:
        return retObj
Example #3
def broadcast_tensor(tensor, src=0):
    # `get_world_size` and `is_xla` are helpers defined elsewhere in the
    # surrounding module; `xm` is torch_xla.core.xla_model and `dist` is
    # torch.distributed.
    world_size = get_world_size()
    if world_size < 2:
        return tensor

    with torch.no_grad():
        if is_xla():
            # Emulate a broadcast with all_to_all: every rank sends a copy of
            # its own tensor to every peer, and indexing [0] keeps the copy
            # received from rank 0.
            tensor = xm.all_to_all(
                tensor.repeat([world_size, 1]),
                split_dimension=0,
                concat_dimension=0,
                split_count=world_size,
            )[0]
        else:
            # Note: both branches broadcast from rank 0, so `src` is unused here.
            dist.broadcast(tensor, src=0)

    return tensor
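
A hypothetical call, just to show intent (the tensor values below are made up, not from the original source):

import torch

# Hypothetical usage: after the call, every worker holds rank 0's copy.
stats = torch.tensor([0.5, 128.0])   # per-rank values before the broadcast
stats = broadcast_tensor(stats)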
Example #4
def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor."""
    # `use_xla`, `get_world_size`, `xm` (torch_xla.core.xla_model) and
    # `dist` (torch.distributed) come from the surrounding module.
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    if use_xla():
        # On XLA, the group is encoded as a ("tpu", replica_groups) tuple;
        # its second element is forwarded directly to xm.all_to_all.
        assert isinstance(group, tuple) and group[0] == "tpu"
        return xm.all_to_all(
            tensor,
            split_dimension=0,
            concat_dimension=0,
            split_count=split_count,
            groups=group[1],
        )
    else:
        output = torch.zeros_like(tensor)
        dist.all_to_all_single(output, tensor, group=group)
        return output
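
To illustrate the data movement, here is a sketch (not from the original source): it assumes 4 workers, the torch.distributed fallback path of the helper above, and an already-initialized process group that supports all_to_all; how the group argument is resolved depends on the module's get_world_size helper.

# Illustrative sketch only; assumes 4 workers and the torch.distributed
# fallback path of the all_to_all() helper above.
import torch
import torch.distributed as dist

world_size = 4
rank = dist.get_rank()
inp = torch.arange(world_size, dtype=torch.float32) + rank * 10
# rank 0 holds [0, 1, 2, 3], rank 1 holds [10, 11, 12, 13], ...
out = all_to_all(inp, group=dist.group.WORLD)
# Chunk j of each rank's input is sent to rank j, so on rank 2 the result
# is [2., 12., 22., 32.]: one element from every rank, in rank order.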
Example #5
def all_to_all(self, collectiveArgs, retFlag=False):
    retObj = xm.all_to_all(collectiveArgs.ipTensor, 0, 0,
                           collectiveArgs.world_size)
    collectiveArgs.opTensor = retObj
    if retFlag:
        return retObj