def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
    # Use same seed.
    torch.manual_seed(0)
    local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)
    sharded_linear = torch.nn.Linear(*linear_size)

    # Copy the weights and bias from local linear
    sharded_linear.weight = local_linear.weight
    sharded_linear.bias = local_linear.bias

    # Shard the parameter.
    shard_parameter(sharded_linear, "weight", spec)

    # Run sharded computation
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.rand(*input_size).cuda(self.rank)
    sharded_output = sharded_linear(inp)

    # Run local computation
    local_output = local_linear(inp)

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.linear version.
    local_output = torch.nn.functional.linear(
        inp, local_linear.weight, local_linear.bias
    )
    sharded_output = torch.nn.functional.linear(
        inp, sharded_linear.weight, sharded_linear.bias
    )
    self.assertEqual(local_output, sharded_output)
def _run_sharded_embedding(self, spec, input_size, num_embeddings, embedding_dim):
    # Use same seed.
    torch.manual_seed(0)
    local_embedding = torch.nn.Embedding(num_embeddings, embedding_dim).cuda(
        self.rank
    )
    sharded_embedding = torch.nn.Embedding(num_embeddings, embedding_dim)

    # Copy the weights from local embedding
    sharded_embedding.weight = local_embedding.weight

    # Shard the parameter.
    shard_parameter(sharded_embedding, "weight", spec)

    # Run sharded computation
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.randint(num_embeddings, tuple(input_size)).cuda(self.rank)
    sharded_output = sharded_embedding(inp)

    # Run local computation
    local_output = local_embedding(inp)

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.embedding version.
    local_output = torch.nn.functional.embedding(inp, local_embedding.weight)
    sharded_output = torch.nn.functional.embedding(inp, sharded_embedding.weight)
    self.assertEqual(local_output, sharded_output)
def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
    # Use same seed.
    torch.manual_seed(0)
    local_linear = torch.nn.Linear(*linear_size).cuda(self.rank)
    sharded_linear = torch.nn.Linear(*linear_size)

    # Copy the weights and bias from local linear
    sharded_linear.weight = torch.nn.Parameter(
        local_linear.weight.detach().clone()
    )
    sharded_linear.bias = torch.nn.Parameter(local_linear.bias.detach().clone())

    # Shard the parameter.
    shard_parameter(sharded_linear, "weight", spec)

    # Run sharded computation
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.rand(*input_size).cuda(self.rank)
    sharded_output = sharded_linear(inp)

    # Run local computation
    local_output = local_linear(inp)

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.linear version.
    local_output = torch.nn.functional.linear(
        inp, local_linear.weight, local_linear.bias
    )
    sharded_output = torch.nn.functional.linear(
        inp, sharded_linear.weight, sharded_linear.bias
    )
    self.assertEqual(local_output, sharded_output)

    # Compute loss and run backward pass.
    local_output.sum().backward()
    sharded_output.sum().backward()
    local_grad = local_linear.weight.grad

    # Verify that both the weight and bias of the sharded linear have non-None grads.
    sharded_weight = sharded_linear.weight.local_shards()[0].tensor
    self.assertNotEqual(sharded_linear.bias.grad, None)
    self.assertNotEqual(sharded_weight.grad, None)

    # Shard the local linear's weight grad so that we can compare.
    dist.all_reduce(local_grad)
    (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test(
        local_linear.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
    )
    local_grad_narrowed = local_grad.narrow(sharded_dim, start_pos, chunk_size)
    local_weight_narrowed = local_linear.weight.narrow(
        sharded_dim, start_pos, chunk_size
    )

    # Test backward gradient calculation.
    self.assertEqual(sharded_linear.bias.grad, local_linear.bias.grad)
    self.assertEqual(sharded_weight.grad, local_grad_narrowed)
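# A minimal sketch of how a test might drive _run_sharded_linear above. The test
# name, spec dim, and concrete sizes are illustrative assumptions;
# generate_chunk_sharding_specs_for_test is the same helper already used by
# test_sharded_linear_errors below, and sharded_dim matches the dim the specs
# were generated for so the gradient narrowing lines up.
def test_sharded_linear_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        self._run_sharded_linear(spec, [2, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [8, 21], [21, 11], 0)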
def shard_parameter(self):
    rowwise_sharding_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    colwise_sharding_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
    sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec)
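# For context, a minimal sketch of the kind of two-layer module the
# shard_parameter method above could belong to. The class name, layer sizes,
# and forward pass are illustrative assumptions, not taken from the original.
class SimpleShardedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # linear1.weight is sharded row-wise (dim=0) and linear2.weight
        # column-wise (dim=1) by the shard_parameter() method above.
        self.linear1 = torch.nn.Linear(17, 12)
        self.linear2 = torch.nn.Linear(12, 29)

    def forward(self, inp):
        return self.linear2(self.linear1(inp))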
def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(
            TypeError, 'input and bias need to be torch.Tensor'
        ):
            fc1(torch.rand(10, 10).cuda(self.rank))

        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Input needs to have at least 1 dim'
        ):
            fc2(torch.tensor(1).cuda(self.rank))

        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Weight needs to have exactly 2 dims'
        ):
            fc3(torch.rand(10, 10).cuda(self.rank))

        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Bias needs to have exactly 1 dim'
        ):
            fc4(torch.rand(10, 10).cuda(self.rank))

        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'
        ):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])
        fc6.weight = empty(enumerable_spec, 10, 10)
        with self.assertRaisesRegex(
            ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'
        ):
            fc6(torch.rand(10, 10).cuda(self.rank))

        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def _run_sharded_embedding(
    self,
    spec,
    input_size,
    num_embeddings,
    embedding_dim,
    sharded_dim=None,
    max_norm=None,
    norm_type=2.0,
    padding_idx=None,
):
    # Use same seed.
    torch.manual_seed(0)
    local_embedding = torch.nn.Embedding(
        num_embeddings,
        embedding_dim,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    ).cuda(self.rank)
    sharded_embedding = torch.nn.Embedding(
        num_embeddings,
        embedding_dim,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    )

    # Copy the weights from local embedding
    sharded_embedding.weight = torch.nn.Parameter(
        local_embedding.weight.detach().clone()
    )

    # Shard the parameter.
    shard_parameter(sharded_embedding, "weight", spec)

    # Run sharded computation
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)
    sharded_output = sharded_embedding(inp)

    # If max_norm is set, we need to ensure that the renorm has been applied
    # across inputs from all ranks.
    if max_norm is not None:
        gathered_inputs = [torch.zeros_like(inp) for _ in range(TEST_GPU_NUM)]
        dist.all_gather(gathered_inputs, inp)
        unique_inp = torch.unique(torch.cat(gathered_inputs))
        local_embedding(unique_inp)

    # Run local computation
    local_output = local_embedding(inp)

    # Compare the local weight with the sharded one to ensure the renorm was
    # applied as expected.
    if max_norm is not None:
        sharded_weight = sharded_embedding.weight.local_shards()[0].tensor
        (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test(
            local_embedding.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
        )
        local_weight_narrowed = local_embedding.weight.narrow(
            sharded_dim, start_pos, chunk_size
        )
        self.assertEqual(local_weight_narrowed, sharded_weight)

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.embedding version.
    local_output = torch.nn.functional.embedding(
        inp,
        local_embedding.weight,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    )
    sharded_output = torch.nn.functional.embedding(
        inp,
        sharded_embedding.weight,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    )
    self.assertEqual(local_output, sharded_output)
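# A minimal sketch of how _run_sharded_embedding above might be driven. The test
# name, sizes, and max_norm value are illustrative assumptions; when max_norm is
# set, sharded_dim must match the dim the specs were generated for so the
# renormed local weight can be narrowed for comparison against the local shard.
def test_sharded_embedding_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        self._run_sharded_embedding(spec, [5, 4], 17, 12)
        self._run_sharded_embedding(
            spec, [6, 7, 6], 21, 11, sharded_dim=1, max_norm=2.5
        )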
def _run_sharded_embedding_bag(
    self, spec, input_size, num_embeddings, embedding_dim, mode, offset_size=None
):
    # Use same seed.
    torch.manual_seed(0)
    local_embedding_bag = torch.nn.EmbeddingBag(
        num_embeddings, embedding_dim, mode=mode
    ).cuda(self.rank)
    sharded_embedding_bag = torch.nn.EmbeddingBag(
        num_embeddings, embedding_dim, mode=mode
    )

    # Copy the weights from local embedding bag.
    sharded_embedding_bag.weight = local_embedding_bag.weight

    # Shard the parameter.
    shard_parameter(sharded_embedding_bag, "weight", spec)

    # Run sharded computation
    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)
    per_sample_weights = None
    if mode == "sum":
        per_sample_weights = torch.rand(*input_size).cuda(self.rank)
    offsets = None
    if len(input_size) == 1:
        # We need to generate an offsets tensor of a fixed length on each rank.
        # The current implementation and dist API do not support offsets with
        # different lengths across ranks. Since input_size[0] >> offset_size,
        # the while loop will not run for too long.
        while offsets is None or (offsets.size(0) != offset_size):
            offsets = torch.randint(input_size[0], (offset_size,))
            offsets[0] = 0
            offsets = (
                torch.unique(offsets, sorted=True).contiguous().cuda(self.rank)
            )

    sharded_output = sharded_embedding_bag(
        inp, offsets=offsets, per_sample_weights=per_sample_weights
    )

    # Run local computation
    local_output = local_embedding_bag(
        inp, offsets=offsets, per_sample_weights=per_sample_weights
    )

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.embedding_bag version.
    local_output = torch.nn.functional.embedding_bag(
        inp,
        local_embedding_bag.weight,
        offsets=offsets,
        mode=mode,
        per_sample_weights=per_sample_weights,
    )
    sharded_output = torch.nn.functional.embedding_bag(
        inp,
        sharded_embedding_bag.weight,
        offsets=offsets,
        mode=mode,
        per_sample_weights=per_sample_weights,
    )
    self.assertEqual(local_output, sharded_output)
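# A minimal sketch of a driver for _run_sharded_embedding_bag above. The test
# name, sizes, and modes are illustrative assumptions; the 1-D input case must
# pass offset_size since that path builds an offsets tensor, while 2-D inputs
# run without offsets.
def test_sharded_embedding_bag_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        self._run_sharded_embedding_bag(spec, [5, 4], 17, 12, "sum")
        self._run_sharded_embedding_bag(spec, [34], 17, 12, "mean", offset_size=5)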