def test_shard_module_sub_process_group(self):
    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank)
    colwise_sharding_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    rowwise_sharding_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    sharding_plan = ShardingPlan(
        plan={
            "fc1.weight": colwise_sharding_spec,
            "fc2.weight": rowwise_sharding_spec,
        }
    )

    # create a sub process group containing only ranks 2 and 3
    pg = dist.new_group([2, 3])

    # only the ranks in the sub process group shard the module
    if self.rank >= 2:
        shard_module(megatron_lm, sharding_plan, process_group=pg)

def test_sharding_plan_errors(self):
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]
    sharding_plan_wrong_plan = ShardingPlan(
        plan={
            "fc1.weight": torch.randn(3, 4),
        },
        output_plan={
            "": rowwise_sharding_spec,
        },
    )

    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported to shard"
    ):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_plan)

    sharding_plan_wrong_output_plan = ShardingPlan(
        plan={
            "fc1.weight": rowwise_sharding_spec,
        },
        output_plan={
            "": torch.randn(3, 4),
        },
    )

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported as output_plan"
    ):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_output_plan)

    sharding_plan_wrong_module_path = ShardingPlan(
        plan={
            "fc3.weight": rowwise_sharding_spec,
        },
    )
    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_module_path)

    sharding_plan_wrong_param_path = ShardingPlan(
        plan={
            "fc1.biass": rowwise_sharding_spec,
        },
    )
    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_param_path)

def test_custom_sharding_planner(self):
    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank).cuda(
        self.rank
    )
    planner = ChunkAllShardingPlanner(device_count=TEST_GPU_NUM)
    sharding_plan = planner.build_plan(megatron_lm)

    shard_module(megatron_lm, sharding_plan)

    # check to make sure the module has already been sharded
    self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
    self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
    self.assertTrue(isinstance(megatron_lm.fc1.bias, ShardedTensor))
    self.assertTrue(isinstance(megatron_lm.fc2.bias, ShardedTensor))

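
# A minimal sketch (illustration only, not the actual helper) of what a planner
# like the `ChunkAllShardingPlanner` used above could look like, assuming the
# `ShardingPlanner` interface from `torch.distributed._shard.sharding_plan`:
# it chunk-shards every parameter of the module across `device_count` GPUs
# along `chunk_dim`. The real test utility may differ in detail.
import torch.nn as nn
from torch.distributed._shard.sharding_plan import ShardingPlan, ShardingPlanner
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


class _ChunkAllShardingPlannerSketch(ShardingPlanner):
    def __init__(self, chunk_dim=0, device_count=0):
        self.dim = chunk_dim
        self.devices = device_count

    def build_plan(self, module: nn.Module) -> ShardingPlan:
        # Place each parameter's shards round-robin on "rank:i/cuda:i".
        placements = [f"rank:{i}/cuda:{i}" for i in range(self.devices)]
        plan = {
            name: ChunkShardingSpec(dim=self.dim, placements=placements)
            for name, _ in module.named_parameters()
        }
        return ShardingPlan(plan=plan)
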
def test_sharding_plan_simple_megatron(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
    for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
        # test each sharding spec pair and see if we can apply sharding
        reshard_spec = copy.deepcopy(spec[1])
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        reshard_spec.dim = 0

        sharding_plan = ShardingPlan(
            plan={
                "fc1.weight": spec[0],
                "fc2.weight": spec[1],
            },
            output_plan={
                "": reshard_spec,
            },
            return_local_tensor=[""],
        )

        # Use same seed.
        torch.manual_seed(0)
        local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
        megatron_lm = copy.deepcopy(local_megatron_lm)

        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan)

        # check to make sure the module has already been sharded
        self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
        self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
        self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), spec[0])
        self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), spec[1])

        # make sure we can run sharded computation
        input = torch.rand(22, 17).cuda(self.rank)
        sharded_output = megatron_lm(input)
        local_output = local_megatron_lm(input)

        # verify that the local and sharded outputs match
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        (
            local_weight_grad_fc1,
            local_weight_grad_fc2,
        ) = local_megatron_lm.get_weight_grads()
        local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

        # Verify that weights in both layers and biases in the sharded linear have non-None grads.
        (
            sharded_weight_fc1,
            sharded_weight_fc2,
        ) = megatron_lm.get_weights()
        bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_weight_grad_fc1)
        dist.all_reduce(local_weight_grad_fc2)
        dist.all_reduce(local_bias_grad_fc1)
        dist.all_reduce(local_bias_grad_fc2)
        local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
        (
            start_pos_fc1,
            chunk_size_fc1,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
        )
        local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        (
            start_pos_fc2,
            chunk_size_fc2,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
        )
        local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test backward gradient calculation.
        self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
        self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
        self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
        self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

        # Test optimizer.
        bias_fc1, bias_fc2 = megatron_lm.get_biases()
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertEqual(bias_fc2, local_bias_fc2)
        self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
        self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
        previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
        previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
        previous_bias_fc1 = bias_fc1.clone()
        previous_bias_fc2 = bias_fc2.clone()
        optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        optim.step()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(megatron_lm)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        local_weight_fc1_narrowed = local_weight_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        local_weight_fc2_narrowed = local_weight_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test weight value after optimizer.
        self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
        self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
        self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
        self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
        self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
        self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

        # Test bias value after optimizer.
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertNotEqual(previous_bias_fc1, bias_fc1)
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertNotEqual(previous_bias_fc2, bias_fc2)
        self.assertEqual(bias_fc2, local_bias_fc2)

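
# Condensed usage sketch (illustration only, mirroring the calls exercised by the
# test above, with the same imports): given column-wise and row-wise
# ChunkShardingSpecs, the ShardingPlan shards fc1/fc2, reshards the module output
# via `output_plan`, and converts it back to a regular local tensor via
# `return_local_tensor`, so callers see a plain torch.Tensor. `module` and
# `world_size` are placeholders.
def _shard_megatron_with_plan_sketch(module, world_size):
    placements = [f"rank:{i}/cuda:{i}" for i in range(world_size)]
    colwise = ChunkShardingSpec(dim=0, placements=placements)
    rowwise = ChunkShardingSpec(dim=1, placements=placements)
    reshard = ChunkShardingSpec(dim=0, placements=placements)
    plan = ShardingPlan(
        plan={"fc1.weight": colwise, "fc2.weight": rowwise},
        output_plan={"": reshard},
        return_local_tensor=[""],
    )
    shard_module(module, plan)
    return module
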
def __init__(self, rank=None):
    super().__init__()
    self.megatron = SimpleMegatronLM([[17, 12], [12, 29]], rank=rank)
    self.relu = nn.ReLU()
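
# Assumed forward for this wrapper (not shown in the original excerpt): given the
# submodules above, a natural definition applies the ReLU on top of the wrapped
# SimpleMegatronLM output. This is a hypothetical completion, not the verified
# original.
def forward(self, input):
    return self.relu(self.megatron(input))
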
def _run_megatron_linear(self, spec, input_size, linear_size, dtype):
    def _weight_override(module_dst, module_src):
        module_dst.fc1.weight = clone_module_parameter(module_src.fc1, "weight")
        module_dst.fc1.bias = clone_module_parameter(module_src.fc1, "bias")
        module_dst.fc2.weight = clone_module_parameter(module_src.fc2, "weight")
        module_dst.fc2.bias = clone_module_parameter(module_src.fc2, "bias")

    def _shard_parameter(module, spec):
        shard_parameter(module.fc1, "weight", spec[0])
        shard_parameter(module.fc2, "weight", spec[1])

    # Use same seed.
    torch.manual_seed(0)
    local_megatron_lm = SimpleMegatronLM(linear_size, rank=self.rank, dtype=dtype)
    sharded_megatron_lm = SimpleMegatronLM(linear_size, dtype=dtype)
    _weight_override(sharded_megatron_lm, local_megatron_lm)

    # Shard the parameters: first col-wise sharding, then row-wise.
    _shard_parameter(sharded_megatron_lm, spec)

    # Setup resharding of output.
    reshard_spec = copy.deepcopy(spec[1])
    reshard_spec.placements.sort(key=lambda placement: placement.rank())
    reshard_spec.dim = 0
    sharded_megatron_lm = _collect_local_shard(
        _reshard_output(sharded_megatron_lm, reshard_spec)
    )

    torch.manual_seed(self.rank)  # inputs different on each rank
    inp = torch.rand(*input_size, requires_grad=True, device=self.rank, dtype=dtype)

    # Run local computation
    local_output = local_megatron_lm(inp)

    # Compute loss and run backward pass.
    local_output.sum().backward()

    # Save and reset input grads.
    local_input_grad = inp.grad
    self.assertIsNotNone(inp.grad)
    inp.grad = None

    # Run sharded computation
    sharded_output = sharded_megatron_lm(inp)

    # Verify local and sharded results
    self.assertEqual(local_output, sharded_output, atol=1e-3, rtol=1e-6)

    sharded_output.sum().backward()
    sharded_input_grad = inp.grad
    self.assertIsNotNone(inp.grad)

    # Verify sharded and local grads.
    self.assertEqual(local_input_grad, sharded_input_grad, atol=1e-3, rtol=1e-6)

    (
        local_weight_grad_fc1,
        local_weight_grad_fc2,
    ) = local_megatron_lm.get_weight_grads()
    local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

    # Verify that weights in both layers and biases in the sharded linear have non-None grads.
    (
        sharded_weight_fc1,
        sharded_weight_fc2,
    ) = sharded_megatron_lm.get_weights()
    bias_grad_fc1, bias_grad_fc2 = sharded_megatron_lm.get_bias_grads()
    self.assertNotEqual(sharded_weight_fc1.grad, None)
    self.assertNotEqual(sharded_weight_fc2.grad, None)
    self.assertNotEqual(bias_grad_fc1, None)
    self.assertNotEqual(bias_grad_fc2, None)

    # Shard the local linear's weight grad so that we can compare.
    dist.all_reduce(local_weight_grad_fc1)
    dist.all_reduce(local_weight_grad_fc2)
    dist.all_reduce(local_bias_grad_fc1)
    dist.all_reduce(local_bias_grad_fc2)
    local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
    (
        start_pos_fc1,
        chunk_size_fc1,
    ) = generate_local_weight_sharding_params_for_test(
        local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
    )
    local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
        0, start_pos_fc1, chunk_size_fc1
    )
    (
        start_pos_fc2,
        chunk_size_fc2,
    ) = generate_local_weight_sharding_params_for_test(
        local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
    )
    local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
        1, start_pos_fc2, chunk_size_fc2
    )

    # Test backward gradient calculation.
    self.assertEdistNorm(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
    self.assertEdistNorm(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
    self.assertEdistNorm(bias_grad_fc1, local_bias_grad_fc1)
    self.assertEdistNorm(bias_grad_fc2, local_bias_grad_fc2)

    # Test optimizer.
    bias_fc1, bias_fc2 = sharded_megatron_lm.get_biases()
    local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
    self.assertEdistNorm(bias_fc1, local_bias_fc1)
    self.assertEdistNorm(bias_fc2, local_bias_fc2)
    self.assertEdistNorm(bias_fc1.grad, local_bias_fc1.grad)
    self.assertEdistNorm(bias_fc2.grad, local_bias_fc2.grad)
    previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
    previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
    previous_bias_fc1 = bias_fc1.clone()
    previous_bias_fc2 = bias_fc2.clone()
    optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
    optim.step()
    sharded_optim = ShardedOptimizer(
        dict(named_params_with_sharded_tensor(sharded_megatron_lm)),
        torch.optim.SGD,
        lr=0.1,
    )
    sharded_optim.step()
    local_weight_fc1_narrowed = local_weight_fc1.narrow(
        0, start_pos_fc1, chunk_size_fc1
    )
    local_weight_fc2_narrowed = local_weight_fc2.narrow(
        1, start_pos_fc2, chunk_size_fc2
    )

    # Test weight value after optimizer.
    self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
    self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
    self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
    self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
    self.assertEdistNorm(sharded_weight_fc1, local_weight_fc1_narrowed)
    self.assertEdistNorm(sharded_weight_fc2, local_weight_fc2_narrowed)

    # Test bias value after optimizer.
    local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
    self.assertNotEqual(previous_bias_fc1, bias_fc1)
    self.assertEdistNorm(bias_fc1, local_bias_fc1)
    self.assertNotEqual(previous_bias_fc2, bias_fc2)
    self.assertEdistNorm(bias_fc2, local_bias_fc2)