def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim, dtype):
    """Compare a weight-sharded ``torch.nn.Linear`` with an unsharded copy.

    Both layers start from the same parameters.  The forward output (module
    call and ``torch.nn.functional.linear``), the weight/bias gradients and
    the parameters after one SGD step (``lr=0.1``) are all compared with
    ``atol=1e-3, rtol=1e-3``.

    Args:
        spec: Sharding spec handed to ``shard_parameter`` for the weight; a
            deep copy of it (``dim`` forced to 0, placements sorted by rank)
            is used to reshard the output.
        input_size: Sizes unpacked into ``torch.rand`` to build the input.
        linear_size: Positional arguments unpacked into ``torch.nn.Linear``.
        sharded_dim: Dimension along which the reference weight and its
            gradient are narrowed before being compared to the local shard.
        dtype: dtype used for both layers and for the input.
    """
    tol = {"atol": 1e-3, "rtol": 1e-3}
    functional = torch.nn.functional

    # Use same seed.
    torch.manual_seed(0)
    reference = torch.nn.Linear(*linear_size, dtype=dtype).cuda(self.rank)
    sharded_linear = torch.nn.Linear(*linear_size, dtype=dtype)

    # The sharded layer starts from copies of the reference parameters.
    sharded_linear.weight = clone_module_parameter(reference, "weight")
    sharded_linear.bias = clone_module_parameter(reference, "bias")

    # Only the weight is sharded.
    shard_parameter(sharded_linear, "weight", spec)

    # Sharded forward pass; the per-rank seed makes inputs differ per rank.
    torch.manual_seed(self.rank)
    inp = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
    reshard_spec = copy.deepcopy(spec)
    reshard_spec.dim = 0
    reshard_spec.placements.sort(key=lambda p: p.rank())
    sharded_linear = _collect_local_shard(
        _reshard_output(sharded_linear, reshard_spec)
    )
    sharded_output = sharded_linear(inp)

    # Reference forward pass.
    local_output = reference(inp)

    self.assertEqual(local_output, sharded_output, **tol)

    # Same check through torch.nn.functional.linear.
    local_output = functional.linear(inp, reference.weight, reference.bias)
    sharded_output = functional.linear(
        inp, sharded_linear.weight, sharded_linear.bias
    )
    sharded_output = sharded_output.reshard(reshard_spec).local_tensor()
    # When local tensor only has one dimension, one more dimension is added
    # for reshard, so it has to be squeezed away manually.
    if inp.dim() == 1:
        sharded_output = sharded_output.squeeze(reshard_spec.dim)
    self.assertEqual(local_output, sharded_output, **tol)

    # Compute loss and run backward pass.
    local_output.sum().backward()
    sharded_output.sum().backward()
    reference_weight_grad = reference.weight.grad

    # Both weight and bias of the sharded linear must have received a grad.
    sharded_weight = sharded_linear.weight.local_tensor()
    self.assertNotEqual(sharded_linear.bias.grad, None)
    self.assertNotEqual(sharded_weight.grad, None)

    # Reduce the reference weight grad across ranks, then cut out the slice
    # matching this rank's shard so the two can be compared.
    dist.all_reduce(reference_weight_grad)
    start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
        reference.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
    )
    reference_grad_slice = reference_weight_grad.narrow(
        sharded_dim, start_pos, chunk_size
    )
    reference_bias_grad = reference.bias.grad
    dist.all_reduce(reference_bias_grad)

    # Test backward gradient calculation.
    self.assertEqual(sharded_linear.bias.grad, reference_bias_grad, **tol)
    self.assertEqual(sharded_weight.grad, reference_grad_slice, **tol)

    # Test optimizer: step the reference layer first.
    weight_before = reference.weight.clone().detach()
    reference_optim = torch.optim.SGD(reference.parameters(), lr=0.1)
    reference_optim.step()
    self.assertNotEqual(weight_before, reference.weight)

    shard_before = sharded_weight.clone()
    bias_before = sharded_linear.bias.clone()
    sharded_params = dict(named_params_with_sharded_tensor(sharded_linear))
    sharded_optim = ShardedOptimizer(sharded_params, torch.optim.SGD, lr=0.1)
    sharded_optim.step()

    sharded_weight = sharded_linear.weight.local_tensor()
    reference_weight_slice = reference.weight.narrow(
        sharded_dim, start_pos, chunk_size
    )
    self.assertEqual(sharded_weight.size(), reference_weight_slice.size())
    self.assertNotEqual(shard_before, sharded_weight)
    self.assertEqual(sharded_weight, reference_weight_slice, **tol)
    self.assertNotEqual(bias_before, sharded_linear.bias)
    self.assertEqual(sharded_linear.bias, reference.bias, **tol)
def test_sharding_plan_simple_megatron(self):
    """Shard a ``SimpleMegatronLM`` via a ``ShardingPlan`` and compare it
    with an unsharded copy.

    For every (dim-0, dim-1) chunk sharding spec pair, ``fc1.weight`` gets
    the first spec and ``fc2.weight`` the second.  Forward outputs, weight
    and bias gradients, and the parameters after one SGD step (``lr=0.1``)
    are compared between the sharded module and the reference module.
    """
    colwise_specs = generate_chunk_sharding_specs_for_test(0)
    rowwise_specs = generate_chunk_sharding_specs_for_test(1)
    for fc1_spec, fc2_spec in zip(colwise_specs, rowwise_specs):
        # Try each spec pair and see whether sharding can be applied.
        reshard_spec = copy.deepcopy(fc2_spec)
        reshard_spec.placements.sort(key=lambda p: p.rank())
        reshard_spec.dim = 0

        sharding_plan = ShardingPlan(
            plan={"fc1.weight": fc1_spec, "fc2.weight": fc2_spec},
            output_plan={"": reshard_spec},
            return_local_tensor=[""],
        )

        # Use same seed.
        torch.manual_seed(0)
        local_megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank)
        megatron_lm = copy.deepcopy(local_megatron_lm)

        # Shard the copy with the provided sharding plan.
        shard_module(megatron_lm, sharding_plan)

        # Both weights must now be ShardedTensors carrying the given specs.
        self.assertTrue(isinstance(megatron_lm.fc1.weight, ShardedTensor))
        self.assertTrue(isinstance(megatron_lm.fc2.weight, ShardedTensor))
        self.assertEqual(megatron_lm.fc1.weight.sharding_spec(), fc1_spec)
        self.assertEqual(megatron_lm.fc2.weight.sharding_spec(), fc2_spec)

        # Make sure sharded computation runs and matches the reference.
        inp = torch.rand(22, 17).cuda(self.rank)
        sharded_output = megatron_lm(inp)
        local_output = local_megatron_lm(inp)
        self.assertEqual(local_output, sharded_output)

        # Compute loss and run backward pass.
        local_output.sum().backward()
        sharded_output.sum().backward()
        ref_wgrad_fc1, ref_wgrad_fc2 = local_megatron_lm.get_weight_grads()
        ref_bgrad_fc1, ref_bgrad_fc2 = local_megatron_lm.get_bias_grads()

        # Weights and biases of the sharded module must have non-None grads.
        sharded_weight_fc1, sharded_weight_fc2 = megatron_lm.get_weights()
        bias_grad_fc1, bias_grad_fc2 = megatron_lm.get_bias_grads()
        self.assertNotEqual(sharded_weight_fc1.grad, None)
        self.assertNotEqual(sharded_weight_fc2.grad, None)
        self.assertNotEqual(bias_grad_fc1, None)
        self.assertNotEqual(bias_grad_fc2, None)

        # Reduce the reference grads across ranks, then narrow the weight
        # grads to this rank's slice so they can be compared.
        dist.all_reduce(ref_wgrad_fc1)
        dist.all_reduce(ref_wgrad_fc2)
        dist.all_reduce(ref_bgrad_fc1)
        dist.all_reduce(ref_bgrad_fc2)
        ref_weight_fc1, ref_weight_fc2 = local_megatron_lm.get_weights()
        start_fc1, length_fc1 = generate_local_weight_sharding_params_for_test(
            ref_weight_fc1, 0, TEST_GPU_NUM, fc1_spec, self.rank
        )
        ref_wgrad_fc1_slice = ref_wgrad_fc1.narrow(0, start_fc1, length_fc1)
        start_fc2, length_fc2 = generate_local_weight_sharding_params_for_test(
            ref_weight_fc2, 1, TEST_GPU_NUM, fc2_spec, self.rank
        )
        ref_wgrad_fc2_slice = ref_wgrad_fc2.narrow(1, start_fc2, length_fc2)

        # Test backward gradient calculation.
        self.assertEqual(sharded_weight_fc1.grad, ref_wgrad_fc1_slice)
        self.assertEqual(sharded_weight_fc2.grad, ref_wgrad_fc2_slice)
        self.assertEqual(bias_grad_fc1, ref_bgrad_fc1)
        self.assertEqual(bias_grad_fc2, ref_bgrad_fc2)

        # Test optimizer: biases and their grads must agree before stepping.
        bias_fc1, bias_fc2 = megatron_lm.get_biases()
        ref_bias_fc1, ref_bias_fc2 = local_megatron_lm.get_biases()
        self.assertEqual(bias_fc1, ref_bias_fc1)
        self.assertEqual(bias_fc2, ref_bias_fc2)
        self.assertEqual(bias_fc1.grad, ref_bias_fc1.grad)
        self.assertEqual(bias_fc2.grad, ref_bias_fc2.grad)

        # Snapshots used to prove that the step actually changed something.
        weight_fc1_before = sharded_weight_fc1.clone()
        weight_fc2_before = sharded_weight_fc2.clone()
        bias_fc1_before = bias_fc1.clone()
        bias_fc2_before = bias_fc2.clone()

        ref_optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
        ref_optim.step()
        sharded_params = dict(named_params_with_sharded_tensor(megatron_lm))
        sharded_optim = ShardedOptimizer(sharded_params, torch.optim.SGD, lr=0.1)
        sharded_optim.step()

        ref_weight_fc1_slice = ref_weight_fc1.narrow(0, start_fc1, length_fc1)
        ref_weight_fc2_slice = ref_weight_fc2.narrow(1, start_fc2, length_fc2)

        # Test weight value after optimizer.
        self.assertEqual(sharded_weight_fc1.size(), ref_weight_fc1_slice.size())
        self.assertEqual(sharded_weight_fc2.size(), ref_weight_fc2_slice.size())
        self.assertNotEqual(weight_fc1_before, sharded_weight_fc1)
        self.assertNotEqual(weight_fc2_before, sharded_weight_fc2)
        self.assertEqual(sharded_weight_fc1, ref_weight_fc1_slice)
        self.assertEqual(sharded_weight_fc2, ref_weight_fc2_slice)

        # Test bias value after optimizer.
        ref_bias_fc1, ref_bias_fc2 = local_megatron_lm.get_biases()
        self.assertNotEqual(bias_fc1_before, bias_fc1)
        self.assertEqual(bias_fc1, ref_bias_fc1)
        self.assertNotEqual(bias_fc2_before, bias_fc2)
        self.assertEqual(bias_fc2, ref_bias_fc2)
def _run_megatron_linear(self, spec, input_size, linear_size, dtype):
    """Compare a two-layer sharded ``SimpleMegatronLM`` with an unsharded one.

    The sharded model starts from clones of the reference parameters, then
    ``fc1.weight`` is sharded with ``spec[0]`` and ``fc2.weight`` with
    ``spec[1]``.  Outputs and input gradients are compared with
    ``atol=1e-3, rtol=1e-6``; parameter gradients and post-SGD-step
    parameters are compared through ``self.assertEdistNorm``.

    Args:
        spec: Indexable pair of sharding specs (``[0]`` for fc1, ``[1]`` for
            fc2); a deep copy of ``spec[1]`` is used to reshard the output.
        input_size: Sizes unpacked into ``torch.rand`` to build the input.
        linear_size: Layer sizes forwarded to ``SimpleMegatronLM``.
        dtype: dtype used for both models and for the input.
    """
    out_tol = {"atol": 1e-3, "rtol": 1e-6}

    # Use same seed.
    torch.manual_seed(0)
    local_megatron_lm = SimpleMegatronLM(linear_size, rank=self.rank, dtype=dtype)
    sharded_megatron_lm = SimpleMegatronLM(linear_size, dtype=dtype)

    # Copy weight and bias of fc1, then fc2, from the reference model.
    for layer_name in ("fc1", "fc2"):
        src_layer = getattr(local_megatron_lm, layer_name)
        dst_layer = getattr(sharded_megatron_lm, layer_name)
        dst_layer.weight = clone_module_parameter(src_layer, "weight")
        dst_layer.bias = clone_module_parameter(src_layer, "bias")

    # Shard the parameters. First col-wise sharding and then row-wise.
    fc1_spec, fc2_spec = spec[0], spec[1]
    shard_parameter(sharded_megatron_lm.fc1, "weight", fc1_spec)
    shard_parameter(sharded_megatron_lm.fc2, "weight", fc2_spec)

    # Setup resharding of output.
    reshard_spec = copy.deepcopy(fc2_spec)
    reshard_spec.placements.sort(key=lambda p: p.rank())
    reshard_spec.dim = 0
    sharded_megatron_lm = _collect_local_shard(
        _reshard_output(sharded_megatron_lm, reshard_spec)
    )

    # The per-rank seed makes inputs differ on each rank.
    torch.manual_seed(self.rank)
    inp = torch.rand(
        *input_size, requires_grad=True, device=self.rank, dtype=dtype
    )

    # Reference forward and backward pass.
    local_output = local_megatron_lm(inp)
    local_output.sum().backward()

    # Save the input grad, then clear it so the sharded pass starts clean.
    local_input_grad = inp.grad
    self.assertIsNotNone(inp.grad)
    inp.grad = None

    # Sharded forward pass.
    sharded_output = sharded_megatron_lm(inp)

    # Verify local and sharded results.
    self.assertEqual(local_output, sharded_output, **out_tol)

    sharded_output.sum().backward()
    sharded_input_grad = inp.grad
    self.assertIsNotNone(inp.grad)

    # Verify sharded and local input grads.
    self.assertEqual(local_input_grad, sharded_input_grad, **out_tol)

    ref_wgrad_fc1, ref_wgrad_fc2 = local_megatron_lm.get_weight_grads()
    ref_bgrad_fc1, ref_bgrad_fc2 = local_megatron_lm.get_bias_grads()

    # Weights and biases of the sharded model must have non-None grads.
    sharded_weight_fc1, sharded_weight_fc2 = sharded_megatron_lm.get_weights()
    bias_grad_fc1, bias_grad_fc2 = sharded_megatron_lm.get_bias_grads()
    self.assertNotEqual(sharded_weight_fc1.grad, None)
    self.assertNotEqual(sharded_weight_fc2.grad, None)
    self.assertNotEqual(bias_grad_fc1, None)
    self.assertNotEqual(bias_grad_fc2, None)

    # Reduce the reference grads across ranks, then narrow the weight grads
    # to this rank's slice so they can be compared.
    dist.all_reduce(ref_wgrad_fc1)
    dist.all_reduce(ref_wgrad_fc2)
    dist.all_reduce(ref_bgrad_fc1)
    dist.all_reduce(ref_bgrad_fc2)
    ref_weight_fc1, ref_weight_fc2 = local_megatron_lm.get_weights()
    start_fc1, length_fc1 = generate_local_weight_sharding_params_for_test(
        ref_weight_fc1, 0, TEST_GPU_NUM, fc1_spec, self.rank
    )
    ref_wgrad_fc1_slice = ref_wgrad_fc1.narrow(0, start_fc1, length_fc1)
    start_fc2, length_fc2 = generate_local_weight_sharding_params_for_test(
        ref_weight_fc2, 1, TEST_GPU_NUM, fc2_spec, self.rank
    )
    ref_wgrad_fc2_slice = ref_wgrad_fc2.narrow(1, start_fc2, length_fc2)

    # Test backward gradient calculation.
    self.assertEdistNorm(sharded_weight_fc1.grad, ref_wgrad_fc1_slice)
    self.assertEdistNorm(sharded_weight_fc2.grad, ref_wgrad_fc2_slice)
    self.assertEdistNorm(bias_grad_fc1, ref_bgrad_fc1)
    self.assertEdistNorm(bias_grad_fc2, ref_bgrad_fc2)

    # Test optimizer: biases and their grads must agree before stepping.
    bias_fc1, bias_fc2 = sharded_megatron_lm.get_biases()
    ref_bias_fc1, ref_bias_fc2 = local_megatron_lm.get_biases()
    self.assertEdistNorm(bias_fc1, ref_bias_fc1)
    self.assertEdistNorm(bias_fc2, ref_bias_fc2)
    self.assertEdistNorm(bias_fc1.grad, ref_bias_fc1.grad)
    self.assertEdistNorm(bias_fc2.grad, ref_bias_fc2.grad)

    # Snapshots used to prove that the step actually changed something.
    weight_fc1_before = sharded_weight_fc1.clone()
    weight_fc2_before = sharded_weight_fc2.clone()
    bias_fc1_before = bias_fc1.clone()
    bias_fc2_before = bias_fc2.clone()

    ref_optim = torch.optim.SGD(local_megatron_lm.parameters(), lr=0.1)
    ref_optim.step()
    sharded_params = dict(named_params_with_sharded_tensor(sharded_megatron_lm))
    sharded_optim = ShardedOptimizer(sharded_params, torch.optim.SGD, lr=0.1)
    sharded_optim.step()

    ref_weight_fc1_slice = ref_weight_fc1.narrow(0, start_fc1, length_fc1)
    ref_weight_fc2_slice = ref_weight_fc2.narrow(1, start_fc2, length_fc2)

    # Test weight value after optimizer.
    self.assertEqual(sharded_weight_fc1.size(), ref_weight_fc1_slice.size())
    self.assertEqual(sharded_weight_fc2.size(), ref_weight_fc2_slice.size())
    self.assertNotEqual(weight_fc1_before, sharded_weight_fc1)
    self.assertNotEqual(weight_fc2_before, sharded_weight_fc2)
    self.assertEdistNorm(sharded_weight_fc1, ref_weight_fc1_slice)
    self.assertEdistNorm(sharded_weight_fc2, ref_weight_fc2_slice)

    # Test bias value after optimizer.
    ref_bias_fc1, ref_bias_fc2 = local_megatron_lm.get_biases()
    self.assertNotEqual(bias_fc1_before, bias_fc1)
    self.assertEdistNorm(bias_fc1, ref_bias_fc1)
    self.assertNotEqual(bias_fc2_before, bias_fc2)
    self.assertEdistNorm(bias_fc2, ref_bias_fc2)
def _run_sharded_embedding(
    self,
    spec,
    input_size,
    num_embeddings,
    embedding_dim,
    max_norm=None,
    norm_type=2.0,
    padding_idx=None,
):
    """Compare a weight-sharded ``torch.nn.Embedding`` with an unsharded copy.

    The sharded embedding starts from a clone of the reference weight.
    Outputs of the module call and of ``torch.nn.functional.embedding`` are
    compared; when ``max_norm`` is set, the (renormalised) weights are
    compared as well.

    Args:
        spec: Sharding spec handed to ``shard_parameter``; ``spec.dim`` is the
            dimension used to narrow the reference weight.
        input_size: Shape of the random index tensor.
        num_embeddings: Number of rows; also the exclusive upper bound of the
            random indices.
        embedding_dim: Size of each embedding vector.
        max_norm: Forwarded to the embedding modules/functions.
        norm_type: Forwarded to the embedding modules/functions.
        padding_idx: Forwarded to the embedding modules/functions.
    """
    renorm_kwargs = {
        "max_norm": max_norm,
        "norm_type": norm_type,
        "padding_idx": padding_idx,
    }
    renorm_enabled = max_norm is not None

    # Use same seed.
    torch.manual_seed(0)
    reference = torch.nn.Embedding(
        num_embeddings, embedding_dim, **renorm_kwargs
    ).cuda(self.rank)
    sharded_embedding = torch.nn.Embedding(
        num_embeddings, embedding_dim, **renorm_kwargs
    )

    # Start from a copy of the reference weight, then shard it.
    sharded_embedding.weight = clone_module_parameter(reference, "weight")
    shard_parameter(sharded_embedding, "weight", spec)

    # Sharded forward pass; the per-rank seed makes inputs differ per rank.
    torch.manual_seed(self.rank)
    inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)
    sharded_output = sharded_embedding(inp)

    # If max_norm is set, the renorm has to be applied to the reference
    # weight for the inputs of all ranks, not just this one.
    if renorm_enabled:
        gathered_inputs = [torch.zeros_like(inp) for _ in range(TEST_GPU_NUM)]
        dist.all_gather(gathered_inputs, inp)
        all_indices = torch.unique(torch.cat(gathered_inputs))
        reference(all_indices)

    # Reference forward pass.
    local_output = reference(inp)

    # Compare the reference weight slice with the local shard to make sure
    # the renorm happened as expected.
    if renorm_enabled:
        sharded_dim = spec.dim
        local_shard = sharded_embedding.weight.local_shards()[0].tensor
        start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
            reference.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
        )
        reference_slice = reference.weight.narrow(
            sharded_dim, start_pos, chunk_size
        )
        self.assertEqual(reference_slice, local_shard)

    self.assertEqual(local_output, sharded_output)

    # Same check through torch.nn.functional.embedding.
    local_output = torch.nn.functional.embedding(
        inp, reference.weight, **renorm_kwargs
    )
    sharded_output = torch.nn.functional.embedding(
        inp, sharded_embedding.weight, **renorm_kwargs
    )
    self.assertEqual(local_output, sharded_output)
def _run_sharded_embedding_bag(
    self,
    spec,
    input_size,
    num_embeddings,
    embedding_dim,
    mode,
    sharded_dim=None,
    include_last_offset=False,
    offset_size=None,
    max_norm=None,
    norm_type=2.0,
    padding_idx=None,
):
    """Compare a weight-sharded ``torch.nn.EmbeddingBag`` with an unsharded copy.

    The sharded module starts from a detached clone of the reference weight.
    Outputs of the module call and of ``torch.nn.functional.embedding_bag``
    are compared; when ``max_norm`` is set, the (renormalised) weights are
    compared as well.

    Args:
        spec: Sharding spec handed to ``shard_parameter``.
        input_size: Shape of the random index tensor.  A 1-D shape also
            triggers generation of an ``offsets`` tensor.
        num_embeddings: Number of rows; also the exclusive upper bound of the
            random indices.
        embedding_dim: Size of each embedding vector.
        mode: EmbeddingBag reduction mode; ``"sum"`` additionally enables
            random ``per_sample_weights``.
        sharded_dim: Dimension used to narrow the reference weight in the
            ``max_norm`` comparison.  When left as ``None`` it falls back to
            ``spec.dim``.
        include_last_offset: Forwarded to EmbeddingBag; when true the last
            generated offset is set to ``input_size[0]``.
        offset_size: Required number of offsets for 1-D input.
        max_norm: Forwarded to the embedding-bag modules/functions.
        norm_type: Forwarded to the embedding-bag modules/functions.
        padding_idx: Forwarded to the embedding-bag modules/functions.
    """
    # Use same seed.
    torch.manual_seed(0)
    local_embedding_bag = torch.nn.EmbeddingBag(
        num_embeddings,
        embedding_dim,
        mode=mode,
        max_norm=max_norm,
        norm_type=norm_type,
        include_last_offset=include_last_offset,
        padding_idx=padding_idx,
    ).cuda(self.rank)
    sharded_embedding_bag = torch.nn.EmbeddingBag(
        num_embeddings,
        embedding_dim,
        mode=mode,
        max_norm=max_norm,
        norm_type=norm_type,
        include_last_offset=include_last_offset,
        padding_idx=padding_idx,
    )

    # Copy the weights from local embedding bag.
    sharded_embedding_bag.weight = torch.nn.Parameter(
        local_embedding_bag.weight.detach().clone()
    )

    # Shard the parameter.
    shard_parameter(sharded_embedding_bag, "weight", spec)

    # Run sharded computation; the per-rank seed makes inputs differ per rank.
    torch.manual_seed(self.rank)
    inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)
    per_sample_weights = None
    if mode == "sum":
        per_sample_weights = torch.rand(*input_size).cuda(self.rank)

    offsets = None
    if len(input_size) == 1:
        # We need to generate an offset tensor of a fixed length for each
        # rank: the current implementation and dist API do not support
        # offsets of different lengths.  torch.unique may drop duplicates,
        # so retry until exactly ``offset_size`` distinct offsets remain.
        # input_size[0] >> offset_size, so the while loop will not run for
        # too long.
        while offsets is None or (offsets.size(0) != offset_size):
            offsets = torch.randint(input_size[0], (offset_size,))
            offsets[0] = 0
            if include_last_offset:
                offsets[-1] = input_size[0]
            offsets = (
                torch.unique(offsets, sorted=True).contiguous().cuda(self.rank)
            )

    # If max_norm is set, we need to ensure that the renorm has been applied
    # across inputs from all ranks.
    if max_norm is not None:
        gathered_inputs = [torch.zeros_like(inp) for _ in range(TEST_GPU_NUM)]
        dist.all_gather(gathered_inputs, inp)
        unique_inp = torch.unique(torch.cat(gathered_inputs))
        offsets_dummy = torch.tensor([len(unique_inp) // 2]).cuda(self.rank)
        local_embedding_bag(unique_inp, offsets=offsets_dummy)

    sharded_output = sharded_embedding_bag(
        inp,
        offsets=offsets,
        per_sample_weights=per_sample_weights,
    )

    # Run local computation
    local_output = local_embedding_bag(
        inp,
        offsets=offsets,
        per_sample_weights=per_sample_weights,
    )

    # Compare local weight and sharded one to ensure the renorm happened as
    # expected.
    if max_norm is not None:
        # ``sharded_dim`` is optional; narrowing along ``None`` would raise,
        # so fall back to the dimension recorded on the sharding spec.
        if sharded_dim is None:
            sharded_dim = spec.dim
        sharded_weight = sharded_embedding_bag.weight.local_shards()[0].tensor
        (start_pos, chunk_size) = generate_local_weight_sharding_params_for_test(
            local_embedding_bag.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
        )
        local_weight_narrowed = local_embedding_bag.weight.narrow(
            sharded_dim, start_pos, chunk_size
        )
        self.assertEqual(local_weight_narrowed, sharded_weight)

    # Verify
    self.assertEqual(local_output, sharded_output)

    # Validate for torch.nn.functional.embedding_bag version.
    local_output = torch.nn.functional.embedding_bag(
        inp,
        local_embedding_bag.weight,
        offsets=offsets,
        mode=mode,
        per_sample_weights=per_sample_weights,
        include_last_offset=include_last_offset,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    )
    sharded_output = torch.nn.functional.embedding_bag(
        inp,
        sharded_embedding_bag.weight,
        offsets=offsets,
        mode=mode,
        per_sample_weights=per_sample_weights,
        include_last_offset=include_last_offset,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_idx,
    )
    self.assertEqual(local_output, sharded_output)
def _run_sharded_linear(self, spec, input_size, linear_size, sharded_dim):
    """Compare a weight-sharded ``torch.nn.Linear`` with an unsharded copy.

    The sharded layer starts from detached clones of the reference weight
    and bias.  Forward outputs (module call and
    ``torch.nn.functional.linear``), gradients and the parameters after one
    SGD step (``lr=0.1``) are compared with ``assertEqual`` using its default
    tolerances.

    Args:
        spec: Sharding spec handed to ``shard_parameter`` for the weight.
        input_size: Sizes unpacked into ``torch.rand`` to build the input.
        linear_size: Positional arguments unpacked into ``torch.nn.Linear``.
        sharded_dim: Dimension along which the reference weight and its
            gradient are narrowed before being compared to the local shard.
    """
    functional = torch.nn.functional

    # Use same seed.
    torch.manual_seed(0)
    reference = torch.nn.Linear(*linear_size).cuda(self.rank)
    sharded_linear = torch.nn.Linear(*linear_size)

    # The sharded layer starts from copies of the reference parameters.
    reference_weight_copy = reference.weight.detach().clone()
    sharded_linear.weight = torch.nn.Parameter(reference_weight_copy)
    reference_bias_copy = reference.bias.detach().clone()
    sharded_linear.bias = torch.nn.Parameter(reference_bias_copy)

    # Only the weight is sharded.
    shard_parameter(sharded_linear, "weight", spec)

    # Sharded forward pass; the per-rank seed makes inputs differ per rank.
    torch.manual_seed(self.rank)
    inp = torch.rand(*input_size).cuda(self.rank)
    sharded_output = sharded_linear(inp)

    # Reference forward pass.
    local_output = reference(inp)

    self.assertEqual(local_output, sharded_output)

    # Same check through torch.nn.functional.linear.
    local_output = functional.linear(inp, reference.weight, reference.bias)
    sharded_output = functional.linear(
        inp, sharded_linear.weight, sharded_linear.bias
    )
    self.assertEqual(local_output, sharded_output)

    # Compute loss and run backward pass.
    local_output.sum().backward()
    sharded_output.sum().backward()
    reference_weight_grad = reference.weight.grad

    # Both weight and bias of the sharded linear must have received a grad.
    sharded_weight = sharded_linear.weight.local_shards()[0].tensor
    self.assertNotEqual(sharded_linear.bias.grad, None)
    self.assertNotEqual(sharded_weight.grad, None)

    # Reduce the reference weight grad across ranks, then cut out the slice
    # matching this rank's shard so the two can be compared.
    dist.all_reduce(reference_weight_grad)
    start_pos, chunk_size = generate_local_weight_sharding_params_for_test(
        reference.weight, sharded_dim, TEST_GPU_NUM, spec, self.rank
    )
    reference_grad_slice = reference_weight_grad.narrow(
        sharded_dim, start_pos, chunk_size
    )

    # Test backward gradient calculation.  The reference bias grad is
    # compared as-is; only the weight grad was all-reduced above.
    self.assertEqual(sharded_linear.bias.grad, reference.bias.grad)
    self.assertEqual(sharded_weight.grad, reference_grad_slice)

    # Test optimizer: step the reference layer first.
    weight_before = reference.weight.clone().detach()
    reference_optim = torch.optim.SGD(reference.parameters(), lr=0.1)
    reference_optim.step()
    self.assertNotEqual(weight_before, reference.weight)

    shard_before = sharded_weight.clone()
    bias_before = sharded_linear.bias.clone()
    sharded_params = dict(named_params_with_sharded_tensor(sharded_linear))
    sharded_optim = ShardedOptimizer(sharded_params, torch.optim.SGD, lr=0.1)
    sharded_optim.step()

    sharded_weight = sharded_linear.weight.local_shards()[0].tensor
    reference_weight_slice = reference.weight.narrow(
        sharded_dim, start_pos, chunk_size
    )
    self.assertEqual(sharded_weight.size(), reference_weight_slice.size())
    self.assertNotEqual(shard_before, sharded_weight)
    self.assertEqual(sharded_weight, reference_weight_slice)
    self.assertNotEqual(bias_before, sharded_linear.bias)
    self.assertEqual(sharded_linear.bias, reference.bias)