def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        # Sharding the bias (instead of the weight) is not supported.
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
            fc1(torch.rand(10, 10).cuda(self.rank))

        # Input must have at least one dimension.
        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
            fc2(torch.tensor(1).cuda(self.rank))

        # Weight must be 2D.
        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(
            torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
            fc3(torch.rand(10, 10).cuda(self.rank))

        # Bias must be 1D.
        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
            fc4(torch.rand(10, 10).cuda(self.rank))

        # Last input dim must match the weight's input dim.
        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
                ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        # A weight sharded with an EnumerableShardingSpec is not supported by
        # the sharded linear implementation.
        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            )
        ])
        fc6.weight = empty(enumerable_spec, 10, 10)
        # Sharded Tensor metadata has a parenthesis imbalance issue when using re.compile.
        error_msg = (
            r"torch function 'linear', with args: (?s).* "
            r"and kwargs: None not supported for ShardedTensor!"
        )
        with self.assertRaisesRegex(RuntimeError, error_msg):
            fc6(torch.rand(10, 10).cuda(self.rank))

        # Only one local shard per rank is supported.
        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
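
# The cases above exercise the unsupported paths. For contrast, a minimal sketch of the
# supported path (assumptions: the same 4-GPU ChunkShardingSpec generated above, and that
# ShardedTensor is imported from torch.distributed._shard.sharded_tensor): sharding a 2D
# weight replaces the parameter with a ShardedTensor, and a valid 2D input then runs
# through the sharded linear dispatch without raising. This helper is illustrative only
# and is not part of the original test suite.
def _example_valid_sharded_linear(self, spec):
    fc = torch.nn.Linear(10, 10).cuda(self.rank)
    shard_parameter(fc, "weight", spec)
    self.assertIsInstance(fc.weight, ShardedTensor)
    # None of the errors exercised above should trigger for this input.
    return fc(torch.rand(10, 10).cuda(self.rank))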
def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim, dtype):
    # Use same seed.
    torch.manual_seed(0)
    local_linear = torch.nn.Linear(*linear_size, dtype=dtype).cuda(self.rank)
    sharded_linear = torch.nn.Linear(*linear_size, dtype=dtype)

    # Copy the weight and bias from the local linear.
    sharded_linear.weight = clone_module_parameter(local_linear, "weight")
    sharded_linear.bias = clone_module_parameter(local_linear, "bias")

    # Shard the parameter.
    shard_parameter(sharded_linear, "weight", spec)

    # Run sharded computation.
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
    reshard_spec = copy.deepcopy(spec)
    reshard_spec.dim = 0
    reshard_spec.placements.sort(key=lambda placement: placement.rank())
    sharded_linear = _collect_local_shard(
        _reshard_output(sharded_linear, reshard_spec))
    sharded_output = sharded_linear(inp)

    # Run local computation.
    local_output = local_linear(inp)

    # Verify.
    self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

    # Validate the torch.nn.functional.linear version.
    local_output = torch.nn.functional.linear(
        inp, local_linear.weight, local_linear.bias)
    sharded_output = torch.nn.functional.linear(
        inp, sharded_linear.weight, sharded_linear.bias)
    sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
    # When the local tensor only has one dimension, resharding adds one more
    # dimension, so we need to squeeze it out manually.
    if inp.dim() == 1:
        sharded_output = sharded_output.squeeze(reshard_spec.dim)
    self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-3)

    # Compute loss and run backward pass.
    local_output.sum().backward()
    sharded_output.sum().backward()
    local_grad = local_linear.weight.grad

    # Verify that both weight and bias in the sharded linear have non-None grad.
    sharded_weight = sharded_linear.weight.local_tensor()
    self.assertNotEqual(sharded_linear.bias.grad, None)
    self.assertNotEqual(sharded_weight.grad, None)

    # Shard the local linear's weight grad so that we can compare.
    dist.all_reduce(local_grad)
    (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test(
        local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank)
    local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size)
    local_bias_grad = local_linear.bias.grad
    dist.all_reduce(local_bias_grad)

    # Test backward gradient calculation.
    self.assertEqual(sharded_linear.bias.grad, local_bias_grad, atol=1e-3, rtol=1e-3)
    self.assertEqual(sharded_weight.grad, local_grad_narrowed, atol=1e-3, rtol=1e-3)

    # Test optimizer.
    previous = local_linear.weight.clone().detach()
    optim = torch.optim.SGD(local_linear.parameters(), lr=0.1)
    optim.step()
    self.assertNotEqual(previous, local_linear.weight)
    previous_sharded_weight = sharded_weight.clone()
    previous_sharded_bias = sharded_linear.bias.clone()
    sharded_optim = ShardedOptimizer(
        dict(named_params_with_sharded_tensor(sharded_linear)),
        torch.optim.SGD,
        lr=0.1,
    )
    sharded_optim.step()
    sharded_weight = sharded_linear.weight.local_tensor()
    local_weight_narrowed = local_linear.weight.narrow(
        sharded_dim, start_pos, chunk_size)
    self.assertEqual(sharded_weight.size(), local_weight_narrowed.size())
    self.assertNotEqual(previous_sharded_weight, sharded_weight)
    self.assertEqual(sharded_weight, local_weight_narrowed, atol=1e-3, rtol=1e-3)
    self.assertNotEqual(previous_sharded_bias, sharded_linear.bias)
    self.assertEqual(sharded_linear.bias, local_linear.bias, atol=1e-3, rtol=1e-3)
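
# A hypothetical driver for _run_sharded_linear, illustrating how the helper is meant to
# be invoked (the test name, shapes, and dtype are assumptions, not from the original
# file): specs generated for sharding dim 0 shard the weight column-wise, so sharded_dim
# is passed as 0 to match, and the last input dim must equal linear_size[0].
def test_sharded_linear_colwise_example(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        # 2D input: (batch, in_features) with in_features matching linear_size[0].
        self._run_sharded_linear(spec, [2, 17], [17, 12], 0, torch.float32)
        # 3D input exercising the batched path.
        self._run_sharded_linear(spec, [8, 21, 17], [17, 12], 0, torch.float32)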
def _shard_parameter(module, spec):
    shard_parameter(module.fc1, "weight", spec[0])
    shard_parameter(module.fc2, "weight", spec[1])
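
# Hypothetical usage sketch for _shard_parameter (the module class, spec choices, and
# helper below are assumptions for illustration): it expects a module exposing fc1/fc2
# Linear layers and a two-element sequence of sharding specs, typically one column-wise
# (dim=0) spec for fc1 and one row-wise (dim=1) spec for fc2, Megatron-style.
class _ExampleTwoLayerModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 8)


def _example_shard_two_layers(rank):
    model = _ExampleTwoLayerModel().cuda(rank)
    colwise_spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
    )
    rowwise_spec = ChunkShardingSpec(
        dim=1,
        placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
    )
    _shard_parameter(model, [colwise_spec, rowwise_spec])
    return model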