# Imports these test methods rely on. The paths below match the
# torch.distributed._shard layout of recent PyTorch releases; older releases
# used the torch.distributed._sharded_tensor naming instead.
import copy

import torch
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharded_tensor import empty
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    generate_chunk_sharding_specs_for_test,
)


def test_sharded_embedding_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        self._run_sharded_embedding(spec, [5, 4], 17, 12)
        self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 23, 13)
        self._run_sharded_embedding(spec, [8, 6, 5, 4, 7], 23, 16)
        self._run_sharded_embedding(spec, [4], 15, 14)
        self._run_sharded_embedding(spec, [34], 15, 14, padding_idx=10)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 23, 13, padding_idx=12)
        self._run_sharded_embedding(
            spec, [4, 5, 6], 23, 13, max_norm=2.5, sharded_dim=1
        )
        self._run_sharded_embedding(
            spec, [12, 7, 16], 23, 13, max_norm=2.5, sharded_dim=1
        )
        self._run_sharded_embedding(
            spec, [8, 16, 20], 12, 12, max_norm=1.25, norm_type=1.0, sharded_dim=1
        )
        self._run_sharded_embedding(spec, [30], 15, 14, max_norm=2.0, sharded_dim=1)
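# The `_run_sharded_embedding` helper exercised above is defined elsewhere in
# this test class. The sketch below is a hypothetical reconstruction, not the
# actual implementation: the name `_run_sharded_embedding_sketch` and every
# detail of the body are assumptions. The pattern it illustrates is the usual
# one, though: build a local nn.Embedding, deep-copy it, shard the copy's
# weight with `spec`, run both on identical indices, and assert the outputs
# match.
def _run_sharded_embedding_sketch(
    self,
    spec,
    input_size,
    num_embeddings,
    embedding_dim,
    sharded_dim=None,
    max_norm=None,
    norm_type=2.0,
    padding_idx=None,
):
    # Local (unsharded) reference module on this rank's GPU.
    local_embedding = torch.nn.Embedding(
        num_embeddings,
        embedding_dim,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    ).cuda(self.rank)

    # Identical copy whose weight gets sharded across ranks. In this sketch
    # `sharded_dim` is accepted only for signature parity; the `spec` already
    # encodes which dim is sharded.
    sharded_embedding = copy.deepcopy(local_embedding)
    shard_parameter(sharded_embedding, "weight", spec)

    # Seed identically on every rank so all ranks feed the same indices.
    torch.manual_seed(0)
    inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)

    # The sharded forward pass should reproduce the local result.
    self.assertEqual(local_embedding(inp), sharded_embedding(inp))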
def test_sharded_linear_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        self._run_sharded_linear(spec, [2, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [8, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [7, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [4, 15], [15, 14], 0)

        # Test multiple input dims.
        self._run_sharded_linear(spec, [10, 2, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [13, 8, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)
def test_sharded_linear_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        # Test even split.
        self._run_sharded_linear(spec, [8, 16], [16, 11], 1)

        # Test uneven split.
        self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [10, 21], [21, 11], 1)

        # Test multiple input dims.
        self._run_sharded_linear(spec, [13, 8, 16], [16, 11], 1)
        self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)

        # Test single input dim.
        self._run_sharded_linear(spec, [16], [16, 11], 1)
        self._run_sharded_linear(spec, [19], [19, 11], 1)
        self._run_sharded_linear(spec, [21], [21, 11], 1)
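# Likewise, `_run_sharded_linear` lives elsewhere in the test class; this is
# a hypothetical sketch (the name and body are assumptions) of the comparison
# it presumably performs: shard the Linear's weight per `spec`, then check
# the sharded forward pass against the plain local one.
def _run_sharded_linear_sketch(self, spec, input_size, linear_size, sharded_dim):
    in_features, out_features = linear_size
    local_linear = torch.nn.Linear(in_features, out_features).cuda(self.rank)

    # Shard the weight of an identical copy. nn.Linear stores its weight as
    # (out_features, in_features); `spec.dim` picks which of those dims the
    # chunks split, and `sharded_dim` here just mirrors that choice.
    sharded_linear = copy.deepcopy(local_linear)
    shard_parameter(sharded_linear, "weight", spec)

    # Same seed everywhere so every rank sees the same input.
    torch.manual_seed(0)
    inp = torch.rand(*input_size).cuda(self.rank)

    self.assertEqual(local_linear(inp), sharded_linear(inp))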
def test_sharded_embedding_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        # Test even split.
        self._run_sharded_embedding(spec, [5, 12], 16, 22)
        self._run_sharded_embedding(spec, [5, 4], 32, 12)
        self._run_sharded_embedding(spec, [6, 7, 6], 64, 11)
        self._run_sharded_embedding(
            spec, [5, 12], 16, 22, max_norm=2.5, sharded_dim=0
        )
        self._run_sharded_embedding(spec, [6, 7, 6], 64, 11, padding_idx=30)
        self._run_sharded_embedding(
            spec, [6, 5, 3], 26, 11, max_norm=2.0, sharded_dim=0
        )

        # Test uneven split.
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 19, 11)
        self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
        self._run_sharded_embedding(spec, [4], 21, 11)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 21, 11, padding_idx=10)
        self._run_sharded_embedding(
            spec, [12, 16, 8], 27, 11, max_norm=2.0, sharded_dim=0
        )
        self._run_sharded_embedding(spec, [4], 14, 11, max_norm=2.5, sharded_dim=0)
def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(
            TypeError, 'input and bias need to be torch.Tensor'
        ):
            fc1(torch.rand(10, 10).cuda(self.rank))

        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Input needs to have at least 1 dim'
        ):
            fc2(torch.tensor(1).cuda(self.rank))

        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Weight needs to have exactly 2 dims'
        ):
            fc3(torch.rand(10, 10).cuda(self.rank))

        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Bias needs to have exactly 1 dim'
        ):
            fc4(torch.rand(10, 10).cuda(self.rank))

        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'
        ):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])
        fc6.weight = empty(enumerable_spec, 10, 10)
        with self.assertRaisesRegex(
            ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'
        ):
            fc6(torch.rand(10, 10).cuda(self.rank))

        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def test_sharded_embedding_bag_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        self._test_sharded_embedding_bag_with_test_cases(spec)
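# `_test_sharded_embedding_bag_with_test_cases` is also defined elsewhere,
# and its exact case list is not shown here. Below is a minimal hypothetical
# sketch of the kind of check it presumably runs; every name and parameter
# in it is an assumption, not the actual implementation.
def _sharded_embedding_bag_sketch(self, spec, num_embeddings=16, embedding_dim=8):
    # Local reference EmbeddingBag with summed bags.
    local_bag = torch.nn.EmbeddingBag(
        num_embeddings, embedding_dim, mode="sum"
    ).cuda(self.rank)

    # Identical copy with a row-wise sharded weight.
    sharded_bag = copy.deepcopy(local_bag)
    shard_parameter(sharded_bag, "weight", spec)

    torch.manual_seed(0)
    # 2D input: each row is one bag of indices, so no offsets are needed.
    inp = torch.randint(0, num_embeddings, (5, 4)).cuda(self.rank)

    self.assertEqual(local_bag(inp), sharded_bag(inp))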