def test_shard_module_sub_process_group(self):
    megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank)
    colwise_sharding_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    rowwise_sharding_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0/cuda:2",
            "rank:1/cuda:3",
        ],
    )
    sharding_plan = ShardingPlan(
        plan={
            "fc1.weight": colwise_sharding_spec,
            "fc2.weight": rowwise_sharding_spec,
        })

    pg = dist.new_group([2, 3])

    if self.rank >= 2:
        shard_module(megatron_lm, sharding_plan, process_group=pg)

def _generate_sharding_spec(world_size):
    """
    We first need to create a sharding spec for our sharding work.

    For now, we only support sharding on one dimension, so we use
    ``ChunkShardingSpec`` to chunk the size of the given sharding dim into
    equally split lengths. The behavior is similar to `torch.chunk`.

    We also need to create the output sharding spec for the second nn module,
    because we need to aggregate (reduce) the partial results after the second
    nn layer. So we have a new sharding spec to represent how we store the
    aggregation result in a new sharded tensor.
    """
    placements = [f"rank:{idx}/cuda:{idx}" for idx in range(world_size)]
    # Shard the first nn module's weight by dim 0.
    # (nn.Linear transposes the weight internally so dim 0 actually means column)
    colwise_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    # Shard the second nn module's weight by dim 1.
    rowwise_spec = ChunkShardingSpec(
        dim=1,
        placements=placements,
    )
    # The result from the second nn.linear layer needs aggregation by dim 0.
    output_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    return colwise_spec, rowwise_spec, output_spec

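# A minimal, illustrative sketch (not part of the original tests) of what
# _generate_sharding_spec returns for a world size of 2. The helper name
# _example_specs is made up; it only relies on the ``dim`` field that
# ChunkShardingSpec is constructed with above.
def _example_specs():
    colwise_spec, rowwise_spec, output_spec = _generate_sharding_spec(2)
    # The two weight specs chunk different dims of the same placements.
    assert colwise_spec.dim == 0 and rowwise_spec.dim == 1
    # The aggregated output is re-chunked along dim 0 across the same ranks.
    assert output_spec.dim == 0
    return colwise_spec, rowwise_spec, output_spec
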
def test_get_chunk_sharding_params(self):
    ranks = [
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ]
    spec = ChunkShardingSpec(
        dim=0,
        placements=ranks,
    )
    result = get_chunk_sharding_params(21, 4, spec, 1)
    self.assertEqual(6, result[0])
    self.assertEqual(6, result[1])
    result = get_chunk_sharding_params(21, 4, spec, 3)
    self.assertEqual(18, result[0])
    self.assertEqual(3, result[1])

    ranks[1], ranks[2] = ranks[2], ranks[1]
    ranks[0], ranks[3] = ranks[3], ranks[0]
    spec.placements = ranks
    result = get_chunk_sharding_params(21, 4, spec, 1)
    self.assertEqual(12, result[0])
    self.assertEqual(6, result[1])
    result = get_chunk_sharding_params(21, 4, spec, 3)
    self.assertEqual(0, result[0])
    self.assertEqual(6, result[1])

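# Illustrative stand-in (not the real implementation) for the chunk arithmetic
# the assertions above rely on: a dimension of length 21 split torch.chunk-style
# across 4 placements gives chunks of ceil(21 / 4) = 6, with the last placement
# holding the 3-element remainder. get_chunk_sharding_params additionally maps
# the given rank to its position in spec.placements, which is what the
# reshuffled second half of the test exercises.
def _example_chunk_params(dim_size, world_size, placement_idx):
    full_chunk = -(-dim_size // world_size)  # ceil division
    start = full_chunk * placement_idx
    return start, min(full_chunk, dim_size - start)

# _example_chunk_params(21, 4, 1) == (6, 6)
# _example_chunk_params(21, 4, 3) == (18, 3)
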
def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # Test different ordering. (Case 1)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
        # Test different ordering. (Case 2)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
            ],
        ),
    ]

def test_load_rowwise_to_colwise(self) -> None:
    path = self.get_file_path()
    self.assertEqual(self.world_size, dist.get_world_size())

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    src_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    dst_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    if dist.get_rank() == 0:
        shutil.rmtree(path, ignore_errors=True)
        os.makedirs(path)

    model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
    model_to_save._register_state_dict_hook(state_dict_hook)
    state_dict_to_save = model_to_save.state_dict()

    fs_writer = FileSystemWriter(path=path)
    save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

    model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
    model_to_load._register_state_dict_hook(state_dict_hook)
    state_dict_to_load_to = model_to_load.state_dict()

    fs_reader = FileSystemReader(path=path)
    load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

    # We can't use torch.allclose since each ST has a different sharding spec
    store_tensor = self.load_tensor(model_to_save.sharded_tensor)
    load_tensor = self.load_tensor(model_to_load.sharded_tensor)

    if dist.get_rank() == 0:
        self.assertTrue(torch.allclose(store_tensor, load_tensor))

def test_custom_sharder_errors(self):
    custom_sharder = CustomSharder(
        devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
        split_sharding_idx=TEST_GPU_NUM // 2)

    sharding_plan = ShardingPlan(plan={
        "": custom_sharder,
    })

    sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

    with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"):
        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

    # test conflicted sharding plan
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    sharding_plan = ShardingPlan(
        plan={
            "embedding_bags.embedding_bag_0.weight": spec,
            "embedding_bags": custom_sharder,
        })

    with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"):
        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

def test_named_params_with_sharded_tensor(self):
    rowwise_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
    sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
    param_keys = list(sharded_model_params.keys())
    self.assertEqual(len(param_keys), 2)
    self.assertTrue("param" in param_keys)
    self.assertTrue("sharded_param" in param_keys)

    sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
    sharded_linear.shard_parameter()
    sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
    param_keys = list(sharded_linear_params.keys())
    self.assertEqual(len(param_keys), 4)
    self.assertTrue("linear1.bias" in param_keys)
    self.assertTrue("linear2.bias" in param_keys)
    self.assertTrue("linear1.weight" in param_keys)
    self.assertTrue("linear2.weight" in param_keys)
    self.assertFalse("bias" in param_keys)

def get_spec(self):
    return ChunkShardingSpec(
        dim=0,
        placements=[
            f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
        ]
    )

def test_init_sharded_tensor_with_normal(self):
    """
    Test torch.nn.init.normal_(ShardedTensor, mean, std)
    """
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    h, w = 8, 2
    expected_h = 2
    expected_device = torch.device(f"cuda:{self.rank}")
    mean, std = 10, 5
    seed = 1234
    dtype = torch.double

    st = sharded_tensor.empty(spec, h, w, dtype=dtype)
    self.assertEqual(1, len(st.local_shards()))

    # Clone local tensor to ensure torch.nn.init starts from the same input
    local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
    torch.manual_seed(seed)
    torch.nn.init.normal_(st, mean=mean, std=std)

    torch.manual_seed(seed)
    torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
    self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)

def build_plan(self, module: nn.Module) -> ShardingPlan:
    named_params = module.named_parameters()
    plan = {}
    for name, param in named_params:
        plan[name] = ChunkShardingSpec(self.dim, placements=self.devices)
    return ShardingPlan(plan=plan)

def test_basic_math_ops(self):
    ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    sharded_lhs = sharded_tensor.rand(spec, (12, 3))
    sharded_rhs = sharded_tensor.rand(spec, (12, 3))
    current_rank = dist.get_rank()
    global_lhs = torch.empty((12, 3)) if current_rank == 0 else None
    global_rhs = torch.empty((12, 3)) if current_rank == 0 else None
    sharded_lhs.gather(dst=0, out=global_lhs)
    sharded_rhs.gather(dst=0, out=global_rhs)

    for op in ops:
        binary_op = gen_binary_op_func(op)
        binary_op_ = gen_binary_op_func(op, inplace=True)

        # test basic math ops between ShardedTensors
        sharded_output = binary_op(sharded_lhs, sharded_rhs)
        output = torch.empty((12, 3)) if current_rank == 0 else None
        sharded_output.gather(dst=0, out=output)

        if current_rank == 0:
            global_output = binary_op(global_lhs, global_rhs)
            self.assertEqual(output, global_output)

        # test basic math ops between ShardedTensor and scalar
        scalars = [3, 1.8]
        for scalar in scalars:
            sharded_output_lhs = binary_op(sharded_lhs, scalar)
            sharded_output_lhs_ = binary_op_(sharded_lhs, scalar)
            self.assertTrue(
                torch.allclose(sharded_output_lhs, sharded_output_lhs_))
            output_lhs = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output_lhs.gather(dst=0, out=output_lhs)

            sharded_output_rhs = binary_op(scalar, sharded_lhs)
            output_rhs = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output_rhs.gather(dst=0, out=output_rhs)

            if current_rank == 0:
                global_output_lhs = binary_op(global_lhs, scalar)
                global_output_rhs = binary_op(scalar, global_lhs)
                self.assertEqual(output_lhs, global_output_lhs)
                self.assertEqual(output_rhs, global_output_rhs)

def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor only has one dimension, we add one more dimension
    for resharding and need to manually squeeze it out later on. For example,
    if we have:
        input: size[15]
        weight: size[15, 16]
        world_size: 4

    In each rank, we will have 4 * [4] tensors. We then stack them into a
    [4, 4] tensor and generate a sharded tensor sharded by dim 1.

    In all other cases, we simply concatenate the local tensors; no further
    action is needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current CUDA process.
        local_shard_t: transposed local shard of the col-wise sharded weight,
            used for the local matmul.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with local intermediate results.
    """
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos, chunk_size) = get_chunk_sharding_params(
        bias.size(0), world_size, weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(
        world_size, start_pos, chunk_size, weight, pg, bias)
    results = []
    for inp in gathered_inputs:
        results.append(inp.matmul(local_shard_t) + local_bias)
    # When the local result only has one dimension, we need to make sure
    # it does not shard by dim 0, so reshard can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    st_size = list(result.size())
    st_size[-1] = weight.size(0)
    new_sharding_spec = ChunkShardingSpec(
        dim=-1,
        placements=weight.sharding_spec().placements,
    )
    return ShardedTensor._init_from_local_tensor(
        result,
        new_sharding_spec,
        *st_size,  # type: ignore[arg-type]
        process_group=pg,
    )

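# Plain-tensor sketch (illustration only; the helper name is made up) of the
# stack-vs-cat step in _handle_col_wise_sharding: with a 1-D input every
# gathered partial result is 1-D, so stacking creates the extra dimension
# needed before resharding along dim -1; otherwise the 2-D partials are
# simply concatenated along dim 0.
def _example_stack_vs_cat():
    import torch
    one_dim_results = [torch.randn(4) for _ in range(4)]
    stacked = torch.stack(one_dim_results)        # 4 * [4]    -> [4, 4]
    two_dim_results = [torch.randn(3, 4) for _ in range(4)]
    concatenated = torch.cat(two_dim_results)     # 4 * [3, 4] -> [12, 4]
    return stacked.shape, concatenated.shape
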
def _reshard_spec_for_subgroup(self, rank):
    if rank in [0, 1]:
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )
    else:
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:2",
                "rank:1/cuda:3",
            ],
        )

def spec(self) -> ChunkShardingSpec:
    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    return ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

def test_sharded_optim(self):
    rowwise_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    local_model = MyShardedModel().cuda(self.rank)
    sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)

    # copy the parameters from the local model
    sharded_model.sharded_param.local_shards()[0].tensor = \
        local_model.sharded_param.detach().clone().requires_grad_()

    local_optim = optim.SGD(local_model.parameters(), lr=0.1)
    sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
    sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

    local_optim.zero_grad()
    sharded_optim.zero_grad()

    before_update = deepcopy(sharded_optim.named_params)

    inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

    # run forward
    local_output = local_model(inp)
    sharded_output = sharded_model(inp)
    # backward
    local_output.sum().backward()
    sharded_output.sum().backward()

    # optimizer update
    local_optim.step()
    sharded_optim.step()

    # make sure the parameters (including the sharded param) get updated by
    # the optimizer, and the updated local params match the sharded params
    for key, val in before_update.items():
        new_val = sharded_optim.named_params[key]
        if isinstance(val, sharded_tensor.ShardedTensor):
            self.assertNotEqual(
                val.local_shards()[0].tensor,
                new_val.local_shards()[0].tensor
            )
            self.assertEqual(
                new_val.local_shards()[0].tensor,
                local_model.sharded_param
            )
        else:
            self.assertNotEqual(val, new_val)
            self.assertEqual(new_val, local_model.param)

def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
    spec_list = []
    for i, dim in enumerate(sharding_dims):
        random.Random(seed + i).shuffle(PLACEMENTS)
        spec_list.append(
            ChunkShardingSpec(
                dim=dim,
                placements=copy.deepcopy(PLACEMENTS),
            ))
    return spec_list

def get_gpu_specs(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    alt_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:1/cuda:1",
            "rank:0/cuda:0",
            "rank:3/cuda:3",
            "rank:2/cuda:2",
        ],
    )
    return spec, alt_spec

def _test_common_failures(self, cmp_op):
    spec, alt_spec = self.get_gpu_specs()

    st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
    if self.rank == 0:
        torch.nn.init.uniform_(st1.local_shards()[0].tensor)
    self.assertFalse(cmp_op(st1, st2))

    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 5)
    self.assertFalse(cmp_op(st1, st2))

    st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
    self.assertFalse(cmp_op(st1, st2))

    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.zeros(spec, 10, 10)
    self.assertFalse(cmp_op(st1, st2))

    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
    self.assertFalse(cmp_op(st1, st2))

    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
    self.assertFalse(cmp_op(st1, st2))

    cpu_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cpu",
            "rank:1/cpu",
            "rank:2/cpu",
            "rank:3/cpu",
        ],
    )
    st1 = sharded_tensor.ones(cpu_spec, 10, 10)
    st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
    self.assertFalse(cmp_op(st1, st2))

    pg = dist.new_group([1, 0, 3, 2])
    st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
    with self.assertRaisesRegex(
            RuntimeError,
            "All distributed tensors should use the same ProcessGroup"):
        cmp_op(st1, st2)

    pg = dist.new_group([0, 1, 2, 3])
    st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
    with self.assertRaisesRegex(
            RuntimeError,
            "All distributed tensors should use the same ProcessGroup"):
        cmp_op(st1, st2)

def test_storage_key_mapping(self) -> None:
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0)

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertEqual(bytes_reqs[0].storage_key,
                         metadata.state_dict_metadata['bytes'].storage_key)

        # tensor ordering is unspecified
        if len(tensor_reqs[0].tensor.size()) == 1:
            replicated = tensor_reqs[0]
            shard = tensor_reqs[1]
        else:
            replicated = tensor_reqs[1]
            shard = tensor_reqs[0]

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        self.assertEqual(
            replicated.storage_key,
            metadata.state_dict_metadata['replicated'].storage_key)
    else:
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]

    self.assertTrue('sharded' in metadata.state_dict_metadata)
    shard_keys = [
        sm.storage_key
        for sm in metadata.state_dict_metadata['sharded'].storage_metadata
    ]
    self.assertTrue(shard.storage_key in shard_keys)

def shard_parameter(self):
    rowwise_sharding_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    colwise_sharding_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
    sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec)

def test_clone(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, (12, 5))
    copied_st = st.clone()
    self.assertTrue(type(copied_st) is type(st))
    self.assertEqual(copied_st.local_tensor(), st.local_tensor())
    self.assertFalse(copied_st is st)

def test_tensor_metadata_with_missing_rank_spec(self) -> None:
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:1/cuda:1",
        ],
    )

    st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
    mapping = dict()

    (_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)

    self.assertEqual(1, len(storage_md))
    self.assertEqual(1, len(mapping))

def test_read_write_shard_tensor(self) -> None:
    paths = [tempfile.mkdtemp()]
    dist.broadcast_object_list(paths)

    path = paths[0]

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    model_to_save = MyShardedModel1(spec, init_rrefs=False)

    # Test save
    model_to_save._register_state_dict_hook(state_dict_hook)
    state_dict_to_save = model_to_save.state_dict()

    fs_writer = FileSystemWriter(path=path)
    save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

    dist.barrier()

    # Create a new model
    model_to_load = MyShardedModel1(spec, init_rrefs=False)
    # This is not the correct hook for loading the state dict
    # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
    model_to_load._register_state_dict_hook(state_dict_hook)
    state_dict_to_load_to = model_to_load.state_dict()

    dist.barrier()

    with self.assertRaises(AssertionError):
        assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

    # Test load.
    fs_reader = FileSystemReader(path=path)
    load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

    assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
    dist.barrier()

def _test_sharded_softmax(self, softmax_dim, sharding_dim):
    torch.manual_seed(0)
    local_tensor = torch.rand(10, 10, device=self.rank)
    local_softmax = torch.nn.functional.softmax(local_tensor, softmax_dim)

    spec = ChunkShardingSpec(
        dim=sharding_dim,
        placements=[
            f'rank:{idx}/cuda:{idx}' for idx in range(self.world_size)
        ])
    st = _shard_tensor(local_tensor, spec)
    sharded_softmax = torch.nn.functional.softmax(st, softmax_dim)

    self.assertEqual(
        local_softmax.chunk(self.world_size, dim=sharding_dim)[self.rank],
        sharded_softmax.local_tensor())

def test_math_ops_errors(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    sharded_lhs = sharded_tensor.rand(spec, (20, 3))
    sharded_rhs = sharded_tensor.rand(spec, (12, 3))

    with self.assertRaisesRegex(
        RuntimeError, "Implicit broadcasting not supported"
    ):
        torch.add(sharded_lhs, sharded_rhs)

    spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ]
    )
    st = sharded_tensor.rand(spec, 10, 10)

    with self.assertRaisesRegex(RuntimeError, "not supported"):
        torch.add(st, sharded_rhs)

def test_storage_key_mapping(self) -> None:
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0)

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

        # tensor ordering is unspecified
        if len(tensor_reqs[0].tensor.size()) == 1:
            replicated = tensor_reqs[0]
            shard = tensor_reqs[1]
        else:
            replicated = tensor_reqs[1]
            shard = tensor_reqs[0]

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        storage_key = MetadataIndex('replicated', torch.Size([0]))
        self.assertTrue(storage_key in metadata.storage_data)
        self.assertEqual(metadata.storage_data[storage_key],
                         replicated.storage_key)
    else:
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]

    local_shard = state_dict["sharded"].local_shards()[0]
    self.assertTrue('sharded' in metadata.state_dict_metadata)
    storage_key = MetadataIndex(
        'sharded', torch.Size(local_shard.metadata.shard_offsets))
    self.assertTrue(storage_key in metadata.storage_data)
    self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)

def test_replicated_tensor_inter_op_sharded_tensor(self):
    torch.manual_seed(self.rank)
    local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
    local_tensor2 = torch.ones(12, 3, device=f"cuda:{self.rank}") * 4

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    st = _shard_tensor(local_tensor1, spec, src_rank=0)
    replica_tensor = ReplicatedTensor(local_tensor2)

    ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]

    for op in ops:
        binary_op = gen_binary_op_func(op)
        res = binary_op(st, replica_tensor)
        self.assertIsInstance(res, sharded_tensor.ShardedTensor)
        self.assertNotIsInstance(res, ReplicatedTensor)
        output = torch.empty((12, 3)) if self.rank == 0 else None
        res.gather(dst=0, out=output)

        if self.rank == 0:
            local_output = binary_op(local_tensor1, local_tensor2)
            self.assertEqual(output, local_output)

        # reflective
        reflect_res = binary_op(replica_tensor, st)
        self.assertIsInstance(reflect_res, sharded_tensor.ShardedTensor)
        self.assertNotIsInstance(reflect_res, ReplicatedTensor)
        reflect_output = torch.empty((12, 3)) if self.rank == 0 else None
        reflect_res.gather(dst=0, out=reflect_output)

        if self.rank == 0:
            reflect_local_output = binary_op(local_tensor2, local_tensor1)
            self.assertEqual(reflect_output, reflect_local_output)

def test_detach(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
    local_shards = st.local_shards()
    # the tensor was created with requires_grad=True, so before detaching
    # all local shards should require grad
    for local_shard in local_shards:
        self.assertTrue(local_shard.tensor.requires_grad)

    detached_st = st.detach()
    self.assertFalse(detached_st.requires_grad)
    for local_shard in detached_st.local_shards():
        self.assertFalse(local_shard.tensor.requires_grad)

def test_replicated_tensor_implicit_broadcasting(self):
    # use the same seed
    torch.manual_seed(self.rank)
    # test implicit broadcasting
    local_tensor1 = torch.rand(12, 3, device=f"cuda:{self.rank}") * 4
    # we use size (3,) to trigger the implicit broadcasting logic,
    # and the ops will fail if implicit broadcasting does not happen
    local_tensor2 = torch.ones(3, device=f"cuda:{self.rank}")

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    st = _shard_tensor(local_tensor1, spec, src_rank=0)
    replica_tensor = ReplicatedTensor(local_tensor2)

    ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]

    for op in ops:
        binary_op = gen_binary_op_func(op)
        # the replicated tensor should be broadcast automatically
        res = binary_op(st, replica_tensor)
        self.assertIsInstance(res, sharded_tensor.ShardedTensor)
        output = torch.empty((12, 3)) if self.rank == 0 else None
        res.gather(dst=0, out=output)

        if self.rank == 0:
            local_output = binary_op(local_tensor1, local_tensor2)
            self.assertEqual(output, local_output)

def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
    local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
    replica_tensor = ReplicatedTensor(local_tensor)
    torch.manual_seed(self.rank)
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st1 = sharded_tensor.rand(spec, (20, 3, 3))
    st2 = sharded_tensor.rand(spec, (30, 3, 3))

    with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting'):
        st1 + st2

    with self.assertRaisesRegex(RuntimeError, 'not supported for ShardedTensor'):
        st1 % replica_tensor