Code example #1: dist.all_reduce with the 'xla' backend
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()

        dist.init_process_group('xla', world_size=world_size, rank=rank)

        ones = torch.ones((2, 3))
        xones = ones.to(device)
        dist.all_reduce(xones)
        expected = torch.ones((2, 3)) * world_size
        assert torch.all(xones.cpu() == expected), f'{xones} != {expected}'
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
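The `_mp_fn` examples in this listing come from the pytorch/xla test suite and assume the usual test imports, with one process spawned per XLA device. Below is a minimal launcher sketch; the exact import list and the `if __name__ == '__main__'` guard are assumptions added here, not part of the original snippets:

import sys

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # importing this registers the 'xla' backend for dist.init_process_group
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
  # Spawn one process per XLA device; each process runs _mp_fn(index).
  xmp.spawn(_mp_fn, args=())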
Code example #2: xm.all_reduce(xm.REDUCE_SUM) on a list of tensors
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) != 'CPU':
        ones = torch.ones((2, 3))
        twos = ones + 1.0
        xones = ones.to(device)
        xtwos = twos.to(device)
        xm.all_reduce(xm.REDUCE_SUM, [xones, xtwos])

        if (not xones.cpu().allclose(ones * float(xm.xrt_world_size())) or
                not xtwos.cpu().allclose(twos * float(xm.xrt_world_size()))):
            print('CrossReplicaSum produced wrong reductions', file=sys.stderr)
            print(xones, file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} does not support replication'.format(device),
              file=sys.stderr)
Code example #3: dist.reduce_scatter with the 'xla' backend
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()

        dist.init_process_group('xla', world_size=world_size, rank=rank)

        input_size = (32, 3)
        inputs = torch.ones(input_size).split(input_size[0] // world_size)
        output = torch.zeros_like(inputs[0])
        xinputs = [i.to(device) for i in inputs]
        xoutput = output.to(device)
        dist.reduce_scatter(xoutput, xinputs)
        expected = torch.ones_like(inputs[0]) * world_size
        assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}'
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
Code example #4: dist.all_gather with the 'xla' backend
def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    input = torch.ones((2, 3)) * rank
    outputs = [torch.zeros_like(input)] * world_size
    xinput = input.to(device)
    xoutputs = [o.to(device) for o in outputs]
    xoutput0 = xoutputs[0]
    dist.all_gather(xoutputs, xinput)
    for i, o in enumerate(xoutputs):
      expected = torch.ones((2, 3)) * i
      assert torch.all(o.cpu() == expected), f'{o} != {expected}'
    expected0 = torch.zeros_like(input)
    assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)
Code example #5: xm.collective_permute over a ring of cores
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) == 'TPU':
        world_size = xm.xrt_world_size()
        ordinal = xm.get_ordinal()
        value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device)
        pairs = []
        for i in range(1, world_size):
            pairs.append([i - 1, i])
        pairs.append([world_size - 1, 0])
        result_tensor = xm.collective_permute(value, pairs)

        result = result_tensor.cpu().tolist()
        expected = [ordinal - 1] * 100 if ordinal != 0 else [world_size -
                                                             1] * 100

        if result != expected:
            print('Wrong result from core {}: {}'.format(ordinal, result),
                  file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} is not a TPU device'.format(device),
              file=sys.stderr)
Code example #6: repeated dist.all_reduce calls
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        world_size = xm.xrt_world_size()
        rank = xm.get_ordinal()

        dist.init_process_group('xla', world_size=world_size, rank=rank)

        num_all_reduces = 20
        xinputs_list = []
        for i in range(num_all_reduces):
            inputs = torch.ones((2, 3)) * i
            xinputs = inputs.to(device)
            xinputs_list.append(xinputs)
            dist.all_reduce(xinputs)
        for i in range(num_all_reduces):
            expected = torch.ones((2, 3)) * i * world_size
            xinputs = xinputs_list[i]
            assert torch.all(xinputs.cpu() ==
                             expected), f'trial {i}, {xinputs} != {expected}'
    else:
        print('Default device {} is not a TPU or GPU device'.format(device),
              file=sys.stderr)
Code example #7: xm.all_gather with one and two replica groups
File: test_mp_all_gather.py Project: pytorch/xla
def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    # Testing with a single replica group
    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
    result = xm.all_gather(ordinal_tensor, dim=0)

    cpu_result = result.cpu()
    expected = torch.arange(0, world_size, dtype=torch.float)
    if not cpu_result.allclose(expected):
      print('xm.all_gather() produced wrong reductions', file=sys.stderr)
      print(f'[{index}] {cpu_result}', file=sys.stderr)
      sys.exit(1)

    # Testing with two replica groups
    if world_size % 2 == 0 and world_size > 1:
      mp_groups = [[n for n in range(world_size) if n % 2 == 0],
                   [n for n in range(world_size) if n % 2 == 1]]
      group_size = len(mp_groups[0])
      replica_id = int(index % 2 == 1)

      result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)

      cpu_result = result.cpu()
      expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
      if not cpu_result.allclose(expected):
        print('xm.all_gather() produced wrong reductions', file=sys.stderr)
        print(f'[{index}] {cpu_result}', file=sys.stderr)
        sys.exit(1)
    else:
      print(
          f'Failed to create two replica groups with {world_size} replicas',
          file=sys.stderr)

  else:
    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
Code example #8: instantiate_test, skipping tests and dtypes unsupported by the XLA device
  def instantiate_test(cls, name, test):
    test_name = name + '_' + cls.device_type
    class_name = cls.__name__
    real_device_type = xm.xla_device_hw(str(xm.xla_device()))
    assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
    disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]

    @wraps(test)
    def disallowed_test(self, test=test):
      raise unittest.SkipTest('skipped on XLA')
      return test(self, cls.device_type)

    if (match_name(test_name, disabled_torch_tests[class_name]) or
        match_name(name, disabled_torch_tests[class_name])):
      assert not hasattr(
          cls, test_name), 'Redefinition of test {0}'.format(test_name)
      setattr(cls, test_name, disallowed_test)
    else:  # Test is allowed
      dtype_combinations = cls._get_dtypes(test)
      if dtype_combinations is None:  # Tests without dtype variants are instantiated as usual
        super().instantiate_test(name, test)
      else:  # Tests with dtype variants have unsupported dtypes skipped
        # Sets default precision for floating types to bfloat16 precision
        if not hasattr(test, 'precision_overrides'):
          test.precision_overrides = {}
        xla_dtypes = []
        for dtype_combination in dtype_combinations:
          if type(dtype_combination) == torch.dtype:
            dtype_combination = (dtype_combination,)
          dtype_test_name = test_name
          skipped = False
          for dtype in dtype_combination:
            dtype_test_name += '_' + str(dtype).split('.')[1]
          for dtype in dtype_combination:
            if dtype in cls.unsupported_dtypes:
              reason = 'XLA does not support dtype {0}'.format(str(dtype))

              @wraps(test)
              def skipped_test(self, *args, reason=reason, **kwargs):
                raise unittest.SkipTest(reason)

              assert not hasattr(
                  cls, dtype_test_name), 'Redefinition of test {0}'.format(
                      dtype_test_name)
              setattr(cls, dtype_test_name, skipped_test)
              skipped = True
              break
            if dtype in [torch.float, torch.double, torch.bfloat16]:
              floating_precision = XLATestBase._alt_lookup(
                  TORCH_TEST_PRECIIONS,
                  [dtype_test_name, test_name, test.__name__],
                  DEFAULT_FLOATING_PRECISION)
              if dtype not in test.precision_overrides or test.precision_overrides[
                  dtype] < floating_precision:
                test.precision_overrides[dtype] = floating_precision

          if match_name(dtype_test_name, disabled_torch_tests[class_name]):
            skipped = True
            setattr(cls, dtype_test_name, disallowed_test)
          if not skipped:
            xla_dtypes.append(
                dtype_combination
                if len(dtype_combination) > 1 else dtype_combination[0])
        if len(xla_dtypes) != 0:
          test.dtypes[cls.device_type] = xla_dtypes
          super().instantiate_test(name, test)
Code example #9: checking that the default XLA device is a TPU
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"
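A hypothetical way to run this check standalone; the driver block below is not part of the original example and assumes the `sys` and `xm` imports shown earlier:

if __name__ == "__main__":
    # Exit with status 0 when the default XLA device is a TPU, 1 otherwise.
    sys.exit(0 if test_xla_device_is_a_tpu() else 1)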