def test_custom_sharder_errors(self):
    custom_sharder = CustomSharder(
        devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
        split_sharding_idx=TEST_GPU_NUM // 2,
    )

    sharding_plan = ShardingPlan(
        plan={
            "": custom_sharder,
        })

    sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

    with self.assertRaisesRegex(
        KeyError, "path must not be empty for custom sharder!"
    ):
        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

    # test conflicted sharding plan
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    sharding_plan = ShardingPlan(
        plan={
            "embedding_bags.embedding_bag_0.weight": spec,
            "embedding_bags": custom_sharder,
        })

    with self.assertRaisesRegex(
        RuntimeError, "should not conflict with the submodule tree"
    ):
        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

def test_sharding_plan_errors(self):
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]
    sharding_plan_wrong_plan = ShardingPlan(
        plan={
            "fc1.weight": torch.randn(3, 4),
        },
        output_plan={"": rowwise_sharding_spec},
    )

    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported to shard"
    ):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_plan)

    sharding_plan_wrong_output_plan = ShardingPlan(
        plan={
            "fc1.weight": rowwise_sharding_spec,
        },
        output_plan={"": torch.randn(3, 4)},
    )

    with self.assertRaisesRegex(
        TypeError, "Only `ShardingSpec` is supported as output_plan"
    ):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_output_plan)

    sharding_plan_wrong_module_path = ShardingPlan(
        plan={
            "fc3.weight": rowwise_sharding_spec,
        },
    )

    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_module_path)

    sharding_plan_wrong_param_path = ShardingPlan(
        plan={
            "fc1.biass": rowwise_sharding_spec,
        },
    )

    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan_wrong_param_path)

def build_plan(self, module: nn.Module) -> ShardingPlan:
    named_params = module.named_parameters()
    plan = {}
    for name, param in named_params:
        plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)
    return ShardingPlan(plan=plan)

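# Illustrative usage (a minimal sketch, not part of the original tests): the
# method above follows the `ShardingPlanner.build_plan` interface, so a planner
# that stores `dim` and `devices` at construction time could be driven as below.
# The planner class name `ChunkAllShardingPlanner` is hypothetical here.
#
#   planner = ChunkAllShardingPlanner(dim=0, devices=["rank:0/cuda:0", "rank:1/cuda:1"])
#   plan = planner.build_plan(module)
#   shard_module(module, plan)
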
def _get_toy_module_sharding_plan(world_size):
    """
    The idea behind Megatron-LM is that:
    1. We shard the weight of the first linear layer by dim 0 (col-wise).
    2. We shard the weight of the second linear layer by dim 1 (row-wise).
    3. We aggregate the partial results of the second layer and store them as a
       sharded tensor by dim 0.
    4. Return the final result on the local shard.

    We then create sharding specs for this scheme and compose them into a
    sharding plan.
    """
    colwise_spec, rowwise_spec, output_spec = _generate_sharding_spec(world_size)
    return ShardingPlan(
        # Specify the sharding plan for the component of each module.
        plan={
            "net1.weight": colwise_spec,
            "net2.weight": rowwise_spec,
        },
        # Specify the sharding plan for the output of one particular module,
        # e.g. the output of the second linear layer in the Megatron-LM example.
        output_plan={
            "net2": output_spec,
        },
        # Specify to get the tensor stored on the local shard if the output
        # is a sharded tensor.
        return_local_tensor=["net2"],
    )

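# Illustrative sketch (not part of the original test file): one way the toy plan
# above could be applied. The helper name `_example_apply_toy_plan` is
# hypothetical, and it assumes `module` exposes `net1`/`net2` Linear layers
# matching the plan's parameter paths.
def _example_apply_toy_plan(module, world_size):
    plan = _get_toy_module_sharding_plan(world_size)
    # shard_module shards the planned parameters in place (replacing them with
    # ShardedTensors) and applies output_plan / return_local_tensor to the
    # listed submodules.
    shard_module(module, plan)
    return module
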
def test_shard_module_sub_process_group(self):
    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank)
    colwise_sharding_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    rowwise_sharding_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    sharding_plan = ShardingPlan(
        plan={
            "fc1.weight": colwise_sharding_spec,
            "fc2.weight": rowwise_sharding_spec,
        })

    pg = dist.new_group([2, 3])

    if self.rank >= 2:
        shard_module(megatron_lm, sharding_plan, process_group=pg)

def test_reshard_to_ddp_sharding_plan(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)[0]
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0]

    # test each sharding spec pair and see if we can apply sharding
    output_spec = copy.deepcopy(rowwise_sharding_spec)
    output_spec.placements.sort(key=lambda placement: placement.rank())
    output_spec.dim = 0

    # new module with megatron as submodule
    class MyModule(nn.Module):
        def __init__(self, rank=None):
            super().__init__()
            self.megatron = SimpleMegatronLM([[17, 12], [12, 29]], rank=rank)
            self.relu = nn.ReLU()

        def forward(self, input):
            return self.relu(self.megatron(input))

    sharding_plan = ShardingPlan(
        plan={
            "megatron.fc1.weight": colwise_sharding_spec,
            "megatron.fc2.weight": rowwise_sharding_spec,
        },
        output_plan={"megatron": output_spec},
        return_local_tensor=["megatron"],
    )

    # Use same seed.
    torch.manual_seed(0)
    local_module = MyModule().cuda(self.rank)
    sharded_module = copy.deepcopy(local_module)

    # shard the module with the provided sharding plan
    shard_module(sharded_module, sharding_plan)

    # check to make sure the module has been sharded
    self.assertTrue(isinstance(sharded_module.megatron.fc1.weight, ShardedTensor))
    self.assertTrue(isinstance(sharded_module.megatron.fc2.weight, ShardedTensor))
    self.assertEqual(
        sharded_module.megatron.fc1.weight.sharding_spec(), colwise_sharding_spec
    )
    self.assertEqual(
        sharded_module.megatron.fc2.weight.sharding_spec(), rowwise_sharding_spec
    )

    # make sure we can run sharded computation
    input = torch.rand(22, 17).cuda(self.rank)
    sharded_output = sharded_module(input)
    local_output = local_module(input)

    # verify that local and sharded outputs match
    self.assertEqual(local_output, sharded_output)

def __init__(self, ebc, split_idx, specs):
    super().__init__()
    self.split_idx = split_idx
    row_spec, col_spec = specs
    # create embedding bags based on the spec
    self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

    assert self.split_idx < ebc.num_bags
    for i in range(ebc.num_bags):
        bag_key = f"embedding_bag_{i}"
        if i < self.split_idx:
            shard_module(
                ebc,
                plan=ShardingPlan(
                    plan={f"embedding_bags.{bag_key}.weight": row_spec}
                ),
            )
        else:
            shard_module(
                ebc,
                plan=ShardingPlan(
                    plan={f"embedding_bags.{bag_key}.weight": col_spec}
                ),
            )
        self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]

def test_custom_sharder(self):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

        def forward(self, inputs):
            return self.ebc(inputs)

    custom_sharder = CustomSharder(
        devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
        split_sharding_idx=TEST_GPU_NUM // 2,
    )

    sharding_plan = ShardingPlan(
        plan={
            "ebc": custom_sharder,
        })

    local_model = MyModule().cuda(self.rank)
    sharded_model = copy.deepcopy(local_model)

    # shard the module with the provided sharding plan
    shard_module(sharded_model, sharding_plan)

    # check to make sure the module has been sharded
    emb_bags = sharded_model.ebc.embedding_bags
    self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
    self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
    self.assertEqual(
        emb_bags["embedding_bag_0"].weight.sharding_spec(),
        custom_sharder.rowwise_spec,
    )
    self.assertEqual(
        emb_bags["embedding_bag_9"].weight.sharding_spec(),
        custom_sharder.colwise_spec,
    )

    # make sure we can run sharded computation and compare outputs
    # with the local model version
    input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
    local_output = local_model(input)
    sharded_output = sharded_model(input)
    self.assertEqual(local_output, sharded_output)

def test_sharding_plan_simple_megatron(self):
    colwise_sharding_spec = generate_chunk_sharding_specs_for_test(0)
    rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)
    for spec in zip(colwise_sharding_spec, rowwise_sharding_spec):
        # test each sharding spec pair and see if we can apply sharding
        reshard_spec = copy.deepcopy(spec[1])
        reshard_spec.placements.sort(key=lambda placement: placement.rank())
        reshard_spec.dim = 0

        sharding_plan = ShardingPlan(
            plan={
                "fc1.weight": spec[0],
                "fc2.weight": spec[1],
            },
            output_plan={"": reshard_spec},
            return_local_tensor=[""],
        )

        # Use same seed.
        torch.manual_seed(0)
        local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
        megatron_lm = copy.deepcopy(local_megatron_lm)

        # shard the module with the provided sharding plan
        shard_module(megatron_lm, sharding_plan)

        # check to make sure the module has been sharded
        self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
        self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
        self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), spec[0])
        self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), spec[1])

        # make sure we can run sharded computation
        input = torch.rand(22, 17).cuda(self.rank)
        sharded_output = megatron_lm(input)
        local_output = local_megatron_lm(input)

        # verify that local and sharded outputs match
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        (
            local_weight_grad_fc1,
            local_weight_grad_fc2,
        ) = local_megatron_lm.get_weight_grads()
        local_bias_grad_fc1, local_bias_grad_fc2 = local_megatron_lm.get_bias_grads()

        # Verify that the weights and biases of the sharded linear layers
        # have non-None grads.
        (
            sharded_weight_fc1,
            sharded_weight_fc2,
        ) = megatron_lm.get_weights()
        bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # Shard the local linear's weight grad so that we can compare.
        dist.all_reduce(local_weight_grad_fc1)
        dist.all_reduce(local_weight_grad_fc2)
        dist.all_reduce(local_bias_grad_fc1)
        dist.all_reduce(local_bias_grad_fc2)
        local_weight_fc1, local_weight_fc2 = local_megatron_lm.get_weights()
        (
            start_pos_fc1,
            chunk_size_fc1,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc1, 0, TEST_GPU_NUM, spec[0], self.rank
        )
        local_grad_narrowed_fc1 = local_weight_grad_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        (
            start_pos_fc2,
            chunk_size_fc2,
        ) = generate_local_weight_sharding_params_for_test(
            local_weight_fc2, 1, TEST_GPU_NUM, spec[1], self.rank
        )
        local_grad_narrowed_fc2 = local_weight_grad_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test backward gradient calculation.
        self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1)
        self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2)
        self.assertEqual(bias_grad_fc1, local_bias_grad_fc1)
        self.assertEqual(bias_grad_fc2, local_bias_grad_fc2)

        # Test optimizer.
        bias_fc1, bias_fc2 = megatron_lm.get_biases()
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertEqual(bias_fc2, local_bias_fc2)
        self.assertEqual(bias_fc1.grad, local_bias_fc1.grad)
        self.assertEqual(bias_fc2.grad, local_bias_fc2.grad)
        previous_sharded_weight_fc1 = sharded_weight_fc1.clone()
        previous_sharded_weight_fc2 = sharded_weight_fc2.clone()
        previous_bias_fc1 = bias_fc1.clone()
        previous_bias_fc2 = bias_fc2.clone()
        optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        optim.step()
        sharded_optim = ShardedOptimizer(
            dict(named_params_with_sharded_tensor(megatron_lm)),
            torch.optim.SGD,
            lr=0.1,
        )
        sharded_optim.step()
        local_weight_fc1_narrowed = local_weight_fc1.narrow(
            0, start_pos_fc1, chunk_size_fc1
        )
        local_weight_fc2_narrowed = local_weight_fc2.narrow(
            1, start_pos_fc2, chunk_size_fc2
        )

        # Test weight value after optimizer.
        self.assertEqual(sharded_weight_fc1.size(), local_weight_fc1_narrowed.size())
        self.assertEqual(sharded_weight_fc2.size(), local_weight_fc2_narrowed.size())
        self.assertNotEqual(previous_sharded_weight_fc1, sharded_weight_fc1)
        self.assertNotEqual(previous_sharded_weight_fc2, sharded_weight_fc2)
        self.assertEqual(sharded_weight_fc1, local_weight_fc1_narrowed)
        self.assertEqual(sharded_weight_fc2, local_weight_fc2_narrowed)

        # Test bias value after optimizer.
        local_bias_fc1, local_bias_fc2 = local_megatron_lm.get_biases()
        self.assertNotEqual(previous_bias_fc1, bias_fc1)
        self.assertEqual(bias_fc1, local_bias_fc1)
        self.assertNotEqual(previous_bias_fc2, bias_fc2)
        self.assertEqual(bias_fc2, local_bias_fc2)