def test_sharded_dropout(self):
    def _reset_random_seed():
        torch.manual_seed(self.rank + 4)

    specs = generate_chunk_sharding_specs_for_test(0) + generate_chunk_sharding_specs_for_test(1)
    for spec in specs:
        self._run_sharded_elementwise_ops(
            spec,
            [12, 17],
            torch.nn.functional.dropout,
            p=0.4,
            reset_seed=_reset_random_seed,
        )
        self._run_sharded_elementwise_ops(
            spec,
            [18, 21],
            torch.nn.functional.dropout,
            p=0.5,
            reset_seed=_reset_random_seed,
        )
        _reset_random_seed()
        dropout = torch.nn.Dropout(p=0.8)
        self._run_sharded_elementwise_ops(
            spec, [17, 23], dropout, reset_seed=_reset_random_seed
        )
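
# Illustrative sketch (single device, not part of the distributed test surface)
# of why `_reset_random_seed` is threaded through the dropout cases above:
# dropout masks are drawn from the global RNG state, so sharded and local runs
# only produce comparable masks when the generator is reset to a known seed
# before each call.
def _dropout_seed_sketch():
    torch.manual_seed(4)
    first = torch.nn.functional.dropout(torch.ones(4, 4), p=0.5)
    torch.manual_seed(4)
    second = torch.nn.functional.dropout(torch.ones(4, 4), p=0.5)
    assert torch.equal(first, second)  # identical seed => identical mask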
def test_sharded_relu(self):
    specs = generate_chunk_sharding_specs_for_test(0) + generate_chunk_sharding_specs_for_test(1)
    for spec in specs:
        self._run_sharded_elementwise_ops(spec, [12, 17], torch.nn.functional.relu)
        self._run_sharded_elementwise_ops(spec, [18, 21], torch.nn.functional.relu)
        self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
        self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)
def test_reshard_to_ddp_sharding_plan(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)[0]
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]

    # Build an output spec to reshard the submodule output back to dim 0.
    output_spec = copy.deepcopy(rowwise_sharding_spec)
    output_spec.placements.sort(key=lambda placement: placement.rank())
    output_spec.dim = 0

    # New module with megatron as submodule.
    class MyModule(nn.Module):
        def __init__(self, rank=None):
            super().__init__()
            self.megatron = SimpleMegatronLM([[17, 12], [12, 29]], rank=rank)
            self.relu = nn.ReLU()

        def forward(self, input):
            return self.relu(self.megatron(input))

    sharding_plan = ShardingPlan(
        plan={
            "megatron.fc1.weight": colwise_sharding_spec,
            "megatron.fc2.weight": rowwise_sharding_spec,
        },
        output_plan={"megatron": output_spec},
        return_local_tensor=["megatron"],
    )

    # Use the same seed so both modules start from identical weights.
    torch.manual_seed(0)
    local_module = MyModule().cuda(self.rank)
    sharded_module = copy.deepcopy(local_module)

    # Shard the module with the provided sharding plan.
    shard_module(sharded_module, sharding_plan)

    # Check to make sure the module has been sharded.
    self.assertTrue(isinstance(sharded_module.megatron.fc1.weight, ShardedTensor))
    self.assertTrue(isinstance(sharded_module.megatron.fc2.weight, ShardedTensor))
    self.assertEqual(
        sharded_module.megatron.fc1.weight.sharding_spec(), colwise_sharding_spec
    )
    self.assertEqual(
        sharded_module.megatron.fc2.weight.sharding_spec(), rowwise_sharding_spec
    )

    # Make sure we can run sharded computation.
    input = torch.rand(22, 17).cuda(self.rank)
    sharded_output = sharded_module(input)
    local_output = local_module(input)

    # Verify that local and sharded outputs match.
    self.assertEqual(local_output, sharded_output)
def test_sharded_embedding_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        self._run_sharded_embedding(spec, [5, 4], 17, 12)
        self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 23, 13)
        self._run_sharded_embedding(spec, [8, 6, 5, 4, 7], 23, 16)
        self._run_sharded_embedding(spec, [4], 15, 14)
        self._run_sharded_embedding(spec, [34], 15, 14, padding_idx=10)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 23, 13, padding_idx=12)
        self._run_sharded_embedding(spec, [4, 5, 6], 23, 13, max_norm=2.5)
        self._run_sharded_embedding(spec, [12, 7, 16], 23, 13, max_norm=2.5)
        self._run_sharded_embedding(
            spec, [8, 16, 20], 12, 12, max_norm=1.25, norm_type=1.0
        )
        self._run_sharded_embedding(spec, [30], 15, 14, max_norm=2.0)
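
# Illustrative sketch (single device, hypothetical helper) of the math behind
# the column-wise case above: chunking the embedding weight on dim 1 gives
# each shard a slice of every embedding vector, so lookups run locally per
# shard and concatenate on the last dim to recover the full result.
def _colwise_embedding_sketch():
    weight = torch.rand(17, 12)  # num_embeddings x embedding_dim
    inp = torch.randint(0, 17, (5, 4))
    full = torch.nn.functional.embedding(inp, weight)
    parts = [
        torch.nn.functional.embedding(inp, shard)
        for shard in weight.chunk(4, dim=1)
    ]
    assert torch.equal(torch.cat(parts, dim=-1), full)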
def test_sharded_bmm_errors(self):
    specs = generate_chunk_sharding_specs_for_test(0)
    st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
    st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
    with self.assertRaisesRegex(
        NotImplementedError,
        'Both st and st2 need to have same placements for bmm',
    ):
        torch.bmm(st_lhs, st_rhs)
    for spec in specs:
        st_lhs = sharded_tensor.rand(spec, (20, 3))
        st_rhs = sharded_tensor.rand(spec, (20, 3))
        with self.assertRaisesRegex(
            TypeError,
            'both st and st2 need to be a 3D ShardedTensor',
        ):
            torch.bmm(st_lhs, st_rhs)
        rhs = torch.rand(15, 5, 6).cuda(self.rank)
        with self.assertRaisesRegex(
            TypeError,
            'st2 needs to be a ShardedTensor for torch.bmm',
        ):
            torch.bmm(st_lhs, rhs)
        spec.dim = 1
        st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
        st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
        with self.assertRaisesRegex(
            NotImplementedError,
            'Only support performing bmm on tensors sharded on dim 0 now',
        ):
            torch.bmm(st_lhs, st_rhs)
def test_sharding_plan_errors(self):
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]
    sharding_plan_wrong_plan = ShardingPlan(
        plan={
            "fc1.weight": torch.randn(3, 4),
        },
        output_plan={
            "": rowwise_sharding_spec
        },
    )

    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported to shard"
    ):
        # Shard the module with a plan whose entry is not a ShardingSpec.
        shard_module(megatron_lm, sharding_plan_wrong_plan)

    sharding_plan_wrong_output_plan = ShardingPlan(
        plan={
            "fc1.weight": rowwise_sharding_spec,
        },
        output_plan={
            "": torch.randn(3, 4)
        },
    )

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported as output_plan"
    ):
        # Shard the module with an output_plan entry that is not a ShardingSpec.
        shard_module(megatron_lm, sharding_plan_wrong_output_plan)

    sharding_plan_wrong_module_path = ShardingPlan(
        plan={
            "fc3.weight": rowwise_sharding_spec,
        },
    )
    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # Shard the module with a plan that points at a nonexistent submodule.
        shard_module(megatron_lm, sharding_plan_wrong_module_path)

    sharding_plan_wrong_param_path = ShardingPlan(
        plan={
            "fc1.biass": rowwise_sharding_spec,
        },
    )
    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # Shard the module with a plan that points at a nonexistent parameter.
        shard_module(megatron_lm, sharding_plan_wrong_param_path)
def test_megatron_two_layer_prototype(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
    for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
        self._run_megatron_linear(spec, [22, 17], [[17, 12], [12, 29]], torch.float16)
        self._run_megatron_linear(spec, [28, 21], [[21, 11], [11, 29]], torch.float32)
        self._run_megatron_linear(spec, [37, 23], [[23, 13], [13, 24]], torch.float64)
        self._run_megatron_linear(spec, [24, 15], [[15, 14], [14, 20]], torch.float16)

        # Test multiple input dims.
        self._run_megatron_linear(spec, [10, 22, 17], [[17, 12], [12, 29]], torch.float32)
        self._run_megatron_linear(spec, [13, 28, 21], [[21, 11], [11, 29]], torch.float16)
        self._run_megatron_linear(spec, [27, 37, 23], [[23, 13], [13, 24]], torch.float32)
        self._run_megatron_linear(spec, [100, 24, 15], [[15, 14], [14, 20]], torch.float64)

        # Test single input dim.
        self._run_megatron_linear(spec, [17], [[17, 12], [12, 29]], torch.float16)
        self._run_megatron_linear(spec, [21], [[21, 11], [11, 29]], torch.float32)
        self._run_megatron_linear(spec, [23], [[23, 13], [13, 24]], torch.float64)
        self._run_megatron_linear(spec, [15], [[15, 14], [14, 20]], torch.float16)
def test_sharded_chunk_error(self):
    chunk_spec = generate_chunk_sharding_specs_for_test(-1)
    with self.assertRaisesRegex(
        NotImplementedError, "Chunk by sharding dim is not supported."
    ):
        st = sharded_tensor.rand(chunk_spec[0], [17, 24])
        torch.chunk(st, 5, dim=-1)

    enumerable_spec = generate_enumerable_sharding_specs_for_test()
    with self.assertRaisesRegex(
        NotImplementedError, "Only ChunkShardingSpec is supported for chunk."
    ):
        st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
        torch.chunk(st, 5, dim=-1)
def test_sharded_chunk(self):
    sharding_dims = [0]
    specs = []
    for dim in sharding_dims:
        specs.extend(generate_chunk_sharding_specs_for_test(dim))
    for spec in specs:
        self._run_sharded_chunk_test([17, 14], spec, 3)
        self._run_sharded_chunk_test([17, 15, 20], spec, 5)
        self._run_sharded_chunk_test([17, 16], spec, 2)
        # Large matrix case.
        self._run_sharded_chunk_test([128, 512], spec, 8)
        self._run_sharded_chunk_test([1024, 2048], spec, 4)
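
# Illustrative sketch (single device) of the invariant `_run_sharded_chunk_test`
# relies on: when the chunk dim differs from the sharding dim, torch.chunk can
# run independently on each local shard, and each local piece is exactly the
# corresponding row-slice of the global chunk result.
def _chunk_commute_sketch():
    tensor = torch.rand(17, 14)
    expected = tensor.chunk(3, dim=1)        # global chunk along dim 1
    offset = 0
    for shard in tensor.chunk(4, dim=0):     # emulate dim-0 sharding on 4 ranks
        rows = shard.size(0)
        for local, glob in zip(shard.chunk(3, dim=1), expected):
            assert torch.equal(local, glob[offset:offset + rows])
        offset += rows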
def test_sharded_bmm(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        lhs = torch.rand(15, 4, 5).cuda(self.rank)
        rhs = torch.rand(15, 5, 6).cuda(self.rank)
        tensor = lhs.bmm(rhs)
        st_lhs = _shard_tensor(lhs, spec)
        st_rhs = _shard_tensor(rhs, spec)
        st_expected = _shard_tensor(tensor, spec)
        # Normalize shard metadata ordering (by offset) so the comparison
        # does not depend on the placement order in the spec.
        st_expected._metadata.shards_metadata.sort(
            key=lambda x: x.shard_offsets[0],
        )
        self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
        self.assertTrue(torch.allclose(st_lhs.bmm(st_rhs), st_expected))
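
# Illustrative sketch (single device) of why bmm on dim-0 sharded tensors
# needs no cross-rank communication: the batch dim is the sharding dim, so
# each rank's local shard of the result is just the bmm of its local shards.
def _bmm_dim0_sketch():
    lhs = torch.rand(15, 4, 5)
    rhs = torch.rand(15, 5, 6)
    full = lhs.bmm(rhs)
    for lhs_shard, rhs_shard, out_shard in zip(
        lhs.chunk(4), rhs.chunk(4), full.chunk(4)
    ):
        assert torch.allclose(lhs_shard.bmm(rhs_shard), out_shard)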
def test_sharded_linear_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(1):
        # Test even split.
        self._run_sharded_linear(spec, [8, 16], [16, 11], 1)

        # Test uneven split.
        self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [10, 21], [21, 11], 1)

        # Test multiple input dims.
        self._run_sharded_linear(spec, [13, 8, 16], [16, 11], 1)
        self._run_sharded_linear(spec, [10, 5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [12, 15, 10, 21], [21, 11], 1)

        # Test single input dim.
        self._run_sharded_linear(spec, [16], [16, 11], 1)
        self._run_sharded_linear(spec, [19], [19, 11], 1)
        self._run_sharded_linear(spec, [21], [21, 11], 1)
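
# Illustrative sketch (single device) of the row-wise linear math exercised
# above: chunking the weight on dim 1 splits the input features, each shard
# yields a partial product, and the full output is the sum of the partials
# (realized as an all-reduce in the distributed implementation).
def _rowwise_linear_sketch():
    inp = torch.rand(8, 16)
    weight = torch.rand(11, 16)  # out_features x in_features, nn.Linear layout
    full = inp @ weight.t()
    partials = [
        inp_shard @ w_shard.t()
        for inp_shard, w_shard in zip(inp.chunk(4, dim=1), weight.chunk(4, dim=1))
    ]
    assert torch.allclose(sum(partials), full)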
def test_sharded_linear_colwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        self._run_sharded_linear(spec, [2, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [8, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [7, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [4, 15], [15, 14], 0)

        # Test multiple input dims.
        self._run_sharded_linear(spec, [10, 2, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [13, 8, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [27, 7, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [100, 12, 4, 15], [15, 14], 0)

        # Test single input dim.
        self._run_sharded_linear(spec, [17], [17, 12], 0)
        self._run_sharded_linear(spec, [21], [21, 11], 0)
        self._run_sharded_linear(spec, [23], [23, 13], 0)
        self._run_sharded_linear(spec, [15], [15, 14], 0)
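
# Companion sketch for the column-wise case: chunking the weight on dim 0
# splits the output features, so each shard computes a column-slice of the
# output and concatenation along the last dim recovers the full result.
def _colwise_linear_sketch():
    inp = torch.rand(2, 17)
    weight = torch.rand(12, 17)  # out_features x in_features, nn.Linear layout
    full = inp @ weight.t()
    parts = [inp @ w_shard.t() for w_shard in weight.chunk(4, dim=0)]
    assert torch.allclose(torch.cat(parts, dim=-1), full)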
def test_sharded_embedding_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        # Test even split.
        self._run_sharded_embedding(spec, [5, 12], 16, 22)
        self._run_sharded_embedding(spec, [5, 4], 32, 12)
        self._run_sharded_embedding(spec, [6, 7, 6], 64, 11)
        self._run_sharded_embedding(spec, [5, 12], 16, 22, max_norm=2.5)
        self._run_sharded_embedding(spec, [6, 7, 6], 64, 11, padding_idx=30)
        self._run_sharded_embedding(spec, [6, 5, 3], 26, 11, max_norm=2.0)

        # Test uneven split.
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 19, 11)
        self._run_sharded_embedding(spec, [6, 7, 6], 21, 11)
        self._run_sharded_embedding(spec, [4], 21, 11)
        self._run_sharded_embedding(spec, [8, 6, 5, 4], 21, 11, padding_idx=10)
        self._run_sharded_embedding(spec, [6, 5, 8], 28, 5, max_norm=2.0)
        self._run_sharded_embedding(spec, [4], 14, 11, max_norm=2.5)
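
# Illustrative sketch (single device) of the row-wise embedding math: chunking
# the weight on dim 0 splits the vocabulary, so each shard can only resolve
# indices in its own range; out-of-range lookups are masked to zero and the
# shard results are summed (an all-reduce in the distributed setting).
def _rowwise_embedding_sketch():
    weight = torch.rand(16, 22)  # num_embeddings x embedding_dim
    inp = torch.randint(0, 16, (5, 12))
    full = torch.nn.functional.embedding(inp, weight)
    out = torch.zeros_like(full)
    offset = 0
    for shard in weight.chunk(4, dim=0):
        rows = shard.size(0)
        in_range = (inp >= offset) & (inp < offset + rows)
        local_inp = (inp - offset).clamp(0, rows - 1)  # keep indices valid
        out += torch.nn.functional.embedding(local_inp, shard) * in_range.unsqueeze(-1)
        offset += rows
    assert torch.allclose(out, full)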
def test_sharding_plan_simple_megatron(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
    for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
        # Test each sharding spec pair and see if we can apply sharding.
        reshard_spec = copy.deepcopy(spec[1])
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        reshard_spec.dim = 0

        sharding_plan = ShardingPlan(
            plan={
                "fc1.weight": spec[0],
                "fc2.weight": spec[1],
            },
            output_plan={
                "": reshard_spec
            },
            return_local_tensor=[""],
        )

        # Use the same seed so both modules start from identical weights.
        torch.manual_seed(0)
        local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
        megatron_lm = copy.deepcopy(local_megatron_lm)

        # Shard the module with the provided sharding plan.
        shard_module(megatron_lm, sharding_plan)

        # Check to make sure the module has been sharded.
        self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
        self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
        self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), spec[0])
        self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), spec[1])

        # Make sure we can run sharded computation.
        input = torch.rand(22, 17).cuda(self.rank)
        sharded_output = megatron_lm(input)
        local_output = local_megatron_lm(input)

        # Verify that local and sharded outputs match.
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        (
            local_weight_grad_fc1,
            local_weight_grad_fc2,
        ) = local_megatron_lm.get_weight_grads()
        local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

        # Verify that weights and biases in both sharded linears have non-None grads.
        (
            sharded_weight_fc1,
            sharded_weight_fc2,
        ) = megatron_lm.get_weights()
        bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # All-reduce the local gradients, then narrow them to this rank's
        # shard so they can be compared against the sharded grads.
        dist.all_reduce(local_weight_grad_fc1)
        dist.all_reduce(local_weight_grad_fc2)
        dist.all_reduce(local_bias_grad_fc1)
        dist.all_reduce(local_bias_grad_fc2)
        local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
        (
            start_pos_fc1,
            chunk_size_fc1,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
        )
        local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        (
            start_pos_fc2,
            chunk_size_fc2,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
        )
        local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test backward gradient calculation.
        self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
        self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
        self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
        self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

        # Test optimizer.
        bias_fc1, bias_fc2 = megatron_lm.get_biases()
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertEqual(bias_fc2, local_bias_fc2)
        self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
        self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)

        previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
        previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
        previous_bias_fc1 = bias_fc1.clone()
        previous_bias_fc2 = bias_fc2.clone()

        optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        optim.step()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(megatron_lm)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()

        local_weight_fc1_narrowed = local_weight_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        local_weight_fc2_narrowed = local_weight_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test weight values after the optimizer step.
        self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
        self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
        self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
        self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
        self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
        self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

        # Test bias values after the optimizer step.
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertNotEqual(previous_bias_fc1, bias_fc1)
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertNotEqual(previous_bias_fc2, bias_fc2)
        self.assertEqual(bias_fc2, local_bias_fc2)
def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
            fc1(torch.rand(10, 10).cuda(self.rank))

        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
            fc2(torch.tensor(1).cuda(self.rank))

        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
            fc3(torch.rand(10, 10).cuda(self.rank))

        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
            fc4(torch.rand(10, 10).cuda(self.rank))

        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
            ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'
        ):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])

        fc6.weight = empty(enumerable_spec, 10, 10)
        # ShardedTensor metadata contains unbalanced parentheses, which breaks
        # re.compile, so the pattern glosses over the args with (?s).* instead.
        error_msg = (
            r"torch function 'linear', with args: (?s).* "
            r"and kwargs: None not supported for ShardedTensor!"
        )
        with self.assertRaisesRegex(RuntimeError, error_msg):
            fc6(torch.rand(10, 10).cuda(self.rank))

        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def test_sharded_embedding_bag_rowwise(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        self._test_sharded_embedding_bag_with_test_cases(spec, 0)
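
# For reference, a minimal sketch of the kind of spec the shared test helper
# generate_chunk_sharding_specs_for_test(dim) is expected to yield on the
# TEST_GPU_NUM=4 setup used throughout this file. The real helper lives in the
# common test utilities; this hypothetical stand-in only illustrates the shape
# of a ChunkShardingSpec.
def _example_chunk_sharding_spec(dim):
    return ChunkShardingSpec(
        dim=dim,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )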