# Common imports for the test snippets below.
import sys

import torch
import torch.distributed as dist

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' torch.distributed backend


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    ones = torch.ones((2, 3))
    xones = ones.to(device)
    dist.all_reduce(xones)

    expected = torch.ones((2, 3)) * world_size
    assert torch.all(xones.cpu() == expected), f'{xones} != {expected}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) != 'CPU':
    ones = torch.ones((2, 3))
    twos = ones + 1.0
    xones = ones.to(device)
    xtwos = twos.to(device)
    xm.all_reduce(xm.REDUCE_SUM, [xones, xtwos])

    if (not xones.cpu().allclose(ones * float(xm.xrt_world_size())) or
        not xtwos.cpu().allclose(twos * float(xm.xrt_world_size()))):
      print('CrossReplicaSum produced wrong reductions', file=sys.stderr)
      print(xones, file=sys.stderr)
      sys.exit(1)
  else:
    print(
        'Default device {} does not support replication'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    input_size = (32, 3)
    inputs = torch.ones(input_size).split(input_size[0] // world_size)
    output = torch.zeros_like(inputs[0])
    xinputs = [i.to(device) for i in inputs]
    xoutput = output.to(device)
    dist.reduce_scatter(xoutput, xinputs)

    expected = torch.ones_like(inputs[0]) * world_size
    assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    input = torch.ones((2, 3)) * rank
    outputs = [torch.zeros_like(input)] * world_size
    xinput = input.to(device)
    xoutputs = [o.to(device) for o in outputs]
    # Keep a reference to the first output tensor before the collective runs.
    xoutput0 = xoutputs[0]
    dist.all_gather(xoutputs, xinput)

    for i, o in enumerate(xoutputs):
      expected = torch.ones((2, 3)) * i
      assert torch.all(o.cpu() == expected), f'{o} != {expected}'
    expected0 = torch.zeros_like(input)
    assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) == 'TPU':
    world_size = xm.xrt_world_size()
    ordinal = xm.get_ordinal()
    value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device)
    # Build [source, target] pairs so each core sends its data to the next
    # core, with the last core wrapping around to core 0.
    pairs = []
    for i in range(1, world_size):
      pairs.append([i - 1, i])
    pairs.append([world_size - 1, 0])
    result_tensor = xm.collective_permute(value, pairs)

    result = result_tensor.cpu().tolist()
    expected = [ordinal - 1] * 100 if ordinal != 0 else [world_size - 1] * 100
    if result != expected:
      print(
          'Wrong result from core {}: {}'.format(ordinal, result),
          file=sys.stderr)
      sys.exit(1)
  else:
    print(
        'Default device {} is not a TPU device'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    num_all_reduces = 20
    xinputs_list = []
    for i in range(num_all_reduces):
      inputs = torch.ones((2, 3)) * i
      xinputs = inputs.to(device)
      xinputs_list.append(xinputs)
      dist.all_reduce(xinputs)

    for i in range(num_all_reduces):
      expected = torch.ones((2, 3)) * i * world_size
      xinputs = xinputs_list[i]
      assert torch.all(
          xinputs.cpu() == expected), f'trial {i}, {xinputs} != {expected}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)

def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    # Testing with a single replica group
    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
    result = xm.all_gather(ordinal_tensor, dim=0)

    cpu_result = result.cpu()
    expected = torch.arange(0, world_size, dtype=torch.float)
    if not cpu_result.allclose(expected):
      print('xm.all_gather() produced wrong reductions', file=sys.stderr)
      print(f'[{index}] {cpu_result}', file=sys.stderr)
      sys.exit(1)

    # Testing with two replica groups
    if world_size % 2 == 0 and world_size > 1:
      mp_groups = [[n for n in range(world_size) if n % 2 == 0],
                   [n for n in range(world_size) if n % 2 == 1]]
      group_size = len(mp_groups[0])
      replica_id = int(index % 2 == 1)

      result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)

      cpu_result = result.cpu()
      expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
      if not cpu_result.allclose(expected):
        print('xm.all_gather() produced wrong reductions', file=sys.stderr)
        print(f'[{index}] {cpu_result}', file=sys.stderr)
        sys.exit(1)
    else:
      print(
          f'Failed to create two replica groups with {world_size} replicas',
          file=sys.stderr)
  else:
    print(f'{device} is not a TPU or GPU device', file=sys.stderr)

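# Each `_mp_fn(index)` entrypoint above is written for PyTorch/XLA's
# multi-process launcher. A minimal launch sketch follows; it is an assumption
# about how these snippets are driven, not part of the snippets themselves.
# `xmp.spawn` starts one process per XLA device and passes each process its
# index as the first argument.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
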
def instantiate_test(cls, name, test):
  # Intended to run as a (class)method on a device-type test base class:
  # note the `cls` parameter and the `super().instantiate_test(...)` calls.
  test_name = name + '_' + cls.device_type
  class_name = cls.__name__
  real_device_type = xm.xla_device_hw(str(xm.xla_device()))
  assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
  disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]

  @wraps(test)
  def disallowed_test(self, test=test):
    raise unittest.SkipTest('skipped on XLA')
    return test(self, cls.device_type)

  if (match_name(test_name, disabled_torch_tests[class_name]) or
      match_name(name, disabled_torch_tests[class_name])):
    assert not hasattr(
        cls, test_name), 'Redefinition of test {0}'.format(test_name)
    setattr(cls, test_name, disallowed_test)
  else:  # Test is allowed
    dtype_combinations = cls._get_dtypes(test)
    if dtype_combinations is None:
      # Tests without dtype variants are instantiated as usual
      super().instantiate_test(name, test)
    else:
      # Tests with dtype variants have unsupported dtypes skipped
      # Sets default precision for floating types to bfloat16 precision
      if not hasattr(test, 'precision_overrides'):
        test.precision_overrides = {}
      xla_dtypes = []
      for dtype_combination in dtype_combinations:
        if type(dtype_combination) == torch.dtype:
          dtype_combination = (dtype_combination,)
        dtype_test_name = test_name
        skipped = False
        for dtype in dtype_combination:
          dtype_test_name += '_' + str(dtype).split('.')[1]
        for dtype in dtype_combination:
          if dtype in cls.unsupported_dtypes:
            reason = 'XLA does not support dtype {0}'.format(str(dtype))

            @wraps(test)
            def skipped_test(self, *args, reason=reason, **kwargs):
              raise unittest.SkipTest(reason)

            assert not hasattr(
                cls, dtype_test_name), 'Redefinition of test {0}'.format(
                    dtype_test_name)
            setattr(cls, dtype_test_name, skipped_test)
            skipped = True
            break
          if dtype in [torch.float, torch.double, torch.bfloat16]:
            floating_precision = XLATestBase._alt_lookup(
                TORCH_TEST_PRECIIONS,
                [dtype_test_name, test_name, test.__name__],
                DEFAULT_FLOATING_PRECISION)
            if (dtype not in test.precision_overrides or
                test.precision_overrides[dtype] < floating_precision):
              test.precision_overrides[dtype] = floating_precision
        if match_name(dtype_test_name, disabled_torch_tests[class_name]):
          skipped = True
          setattr(cls, dtype_test_name, disallowed_test)
        if not skipped:
          xla_dtypes.append(
              dtype_combination if len(dtype_combination) > 1 else
              dtype_combination[0])
      if len(xla_dtypes) != 0:
        test.dtypes[cls.device_type] = xla_dtypes
        super().instantiate_test(name, test)

def test_xla_device_is_a_tpu():
  """Return True if the default XLA device is a TPU."""
  device = xm.xla_device()
  device_type = xm.xla_device_hw(device)
  return device_type == "TPU"