Example #1
    def test_basic_math_ops(self):
        ops = [
            "torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*",
            "/"
        ]

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_lhs = sharded_tensor.rand(spec, (12, 3))
        sharded_rhs = sharded_tensor.rand(spec, (12, 3))
        current_rank = dist.get_rank()
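        # gather both operands to rank 0 so sharded results can be compared against plain tensor math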
        global_lhs = torch.empty((12, 3)) if current_rank == 0 else None
        global_rhs = torch.empty((12, 3)) if current_rank == 0 else None
        sharded_lhs.gather(dst=0, out=global_lhs)
        sharded_rhs.gather(dst=0, out=global_rhs)

        for op in ops:
            binary_op = gen_binary_op_func(op)
            binary_op_ = gen_binary_op_func(op, inplace=True)
            # test basic math ops between ShardedTensors
            sharded_output = binary_op(sharded_lhs, sharded_rhs)
            output = torch.empty((12, 3)) if current_rank == 0 else None
            sharded_output.gather(dst=0, out=output)

            if current_rank == 0:
                global_output = binary_op(global_lhs, global_rhs)

                self.assertEqual(output, global_output)

            # test basic math ops between ShardedTensor and scalar
            scalars = [3, 1.8]
            for scalar in scalars:
                sharded_output_lhs = binary_op(sharded_lhs, scalar)

                sharded_output_lhs_ = binary_op_(sharded_lhs, scalar)
                self.assertTrue(
                    torch.allclose(sharded_output_lhs, sharded_output_lhs_))
                output_lhs = torch.empty(
                    (12, 3)) if current_rank == 0 else None
                sharded_output_lhs.gather(dst=0, out=output_lhs)

                sharded_output_rhs = binary_op(scalar, sharded_lhs)
                output_rhs = torch.empty(
                    (12, 3)) if current_rank == 0 else None
                sharded_output_rhs.gather(dst=0, out=output_rhs)

                if current_rank == 0:
                    global_output_lhs = binary_op(global_lhs, scalar)
                    global_output_rhs = binary_op(scalar, global_lhs)

                    self.assertEqual(output_lhs, global_output_lhs)
                    self.assertEqual(output_rhs, global_output_rhs)
Example #2
 def test_sharded_chunk_error(self):
     chunk_spec = generate_chunk_sharding_specs_for_test(-1)
     with self.assertRaisesRegex(NotImplementedError,
                                 "Chunk by sharding dim is not supported."):
         st = sharded_tensor.rand(chunk_spec[0], [17, 24])
         torch.chunk(st, 5, dim=-1)
     enumerable_spec = generate_enumerable_sharding_specs_for_test()
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only ChunkShardingSpec is supported for chunk."):
         st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
         torch.chunk(st, 5, dim=-1)
Example #3
 def test_sharded_tensor_type_as(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
         )
         st_2 = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
         )
         st_3 = st.type_as(st_2)
         self.assertEqual(torch.float, st_3.dtype)
         self.assertEqual(torch.float, st_3.local_tensor().dtype)
         st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
         self.assertEqual(torch.bool, st_3.dtype)
         self.assertEqual(torch.bool, st_3.local_tensor().dtype)
Example #4
 def test_sharded_tensor_reshard_errors(self):
     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
     spec, reshard_spec = specs[0], specs[1]
     enumerable_sharding_spec = EnumerableShardingSpec([
         ShardMetadata(
             shard_offsets=[0, 0],
             shard_sizes=[5, 5],
             placement="rank:0/cuda:0",
         ),
         ShardMetadata(
             shard_offsets=[5, 0],
             shard_sizes=[5, 5],
             placement="rank:1/cuda:1",
         ),
     ])
     st = sharded_tensor.rand(spec, 24, 12)
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only ChunkShardingSpec supported for reshard."):
         st.reshard(enumerable_sharding_spec)
     st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
     with self.assertRaisesRegex(
             NotImplementedError,
             "Only single local shard supported for reshard."):
         st.reshard(reshard_spec)
Example #5
    def get_random_tensors(self,
                           spec1,
                           spec2,
                           *sizes,
                           pg1=None,
                           pg2=None,
                           seed_offset=0):
        pg1 = _get_default_group() if pg1 is None else pg1
        pg2 = _get_default_group() if pg2 is None else pg2
        torch.manual_seed(TestShardedTensorBinaryOps.seed)
        st1 = sharded_tensor.rand(spec1, sizes, process_group=pg1)
        torch.manual_seed(TestShardedTensorBinaryOps.seed + seed_offset)
        st2 = sharded_tensor.rand(spec2, sizes, process_group=pg2)

        TestShardedTensorBinaryOps.seed += 1
        return st1, st2
Example #6
 def __init__(
     self,
     spec: ShardingSpec,
 ) -> None:
     super(MyShardedModel3, self).__init__()
     self.sharded_tensor: ShardedTensor = sharded_tensor.rand(
         spec, 10, 20, init_rrefs=False)
Example #7
    def test_math_ops_errors(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_lhs = sharded_tensor.rand(spec, (20, 3))
        sharded_rhs = sharded_tensor.rand(spec, (12, 3))

        with self.assertRaisesRegex(
            RuntimeError, "Implicit broadcasting not supported"
        ):
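            # (20, 3) and (12, 3) have different sizes and ShardedTensor does not broadcast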
            torch.add(sharded_lhs, sharded_rhs)

        spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                ),
            ]
        )

        st = sharded_tensor.rand(spec, 10, 10)

        with self.assertRaisesRegex(RuntimeError, "not supported"):
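            # math ops involving an EnumerableShardingSpec tensor are rejected as unsupported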
            torch.add(st, sharded_rhs)
Example #8
    def test_dummy_writer_works(self) -> None:
        state_dict = {
            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
            'replicated': torch.rand(10, 10),
            'bytes': [1, 2, 3, 4]
        }

        save_state_dict(state_dict, FaultyStorageWriter({}))
Example #9
 def test_sharded_tensor_contiguous(self):
     specs = _chunk_sharding_specs_list_for_test([0], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
         st = st.transpose(1, 0)
         st = st.contiguous()
         self.assertTrue(st.is_contiguous())
         self.assertTrue(st.local_tensor().is_contiguous())
Example #10
 def __init__(self, spec=None, group=None):
     super(MyShardedModel, self).__init__()
     # Use the same seed on every rank so the randomly initialized parameters match.
     torch.manual_seed(0)
     self.param = torch.nn.Parameter(torch.rand(5, 10))
     if spec is not None:
         self.sharded_param = sharded_tensor.rand(spec, 20, 10, requires_grad=True, process_group=group)
     else:
         self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))
Example #11
 def test_sharded_tensor_transpose_error(self):
     enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
     st = sharded_tensor.rand(
         enumerable_spec, 10, 10, init_rrefs=True, dtype=torch.double
     )
     with self.assertRaisesRegex(
         NotImplementedError,
         "Only ChunkShardingSpec supported for 'transpose'",
     ):
         st.transpose(1, 0)
Example #12
 def _run_sharded_elementwise_ops(self, spec, input_size, op):
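     # seed with the rank so each rank draws its own deterministic shard values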
     torch.manual_seed(self.rank)
     st = sharded_tensor.rand(spec, *input_size)
     new_st = op(st)
     local_shard = st.local_tensor()
     new_st_local_shard = new_st.local_tensor()
     self.assertEqual(
         op(local_shard),
         new_st_local_shard,
     )
Example #13
 def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
     super(MyShardedModel2, self).__init__()
     if spec is not None:
         self.sharded_tensor2 = sharded_tensor.rand(spec,
                                                    10,
                                                    20,
                                                    process_group=group,
                                                    init_rrefs=init_rrefs)
     else:
         self.sharded_tensor2 = None
     self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
Example #14
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (
                10,
                10,
            )),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(
            state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
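            # rank 0 writes the bytes entry and the replicated tensor in addition to its own shard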
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertEqual(bytes_reqs[0].storage_key,
                             metadata.state_dict_metadata['bytes'].storage_key)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            self.assertEqual(
                replicated.storage_key,
                metadata.state_dict_metadata['replicated'].storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            shard_keys = [
                sm.storage_key for sm in
                metadata.state_dict_metadata['sharded'].storage_metadata
            ]
            self.assertTrue(shard.storage_key in shard_keys)
Example #15
 def test_sharded_tensor_transpose_error(self):
     enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
     st = sharded_tensor.rand(enumerable_spec,
                              10,
                              10,
                              init_rrefs=False,
                              dtype=torch.double)
     with self.assertRaisesRegex(
             RuntimeError,
             "not supported",
     ):
         st.transpose(1, 0)
Example #16
 def test_sharded_bmm_errors(self):
     specs = generate_chunk_sharding_specs_for_test(0)
     st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
     st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
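     # specs[0] and specs[1] place shards differently, which bmm is expected to reject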
     with self.assertRaisesRegex(
             NotImplementedError,
             'Both st and st2 need to have same placements for bmm',
     ):
         torch.bmm(st_lhs, st_rhs)
     for spec in specs:
         st_lhs = sharded_tensor.rand(spec, (20, 3))
         st_rhs = sharded_tensor.rand(spec, (20, 3))
         with self.assertRaisesRegex(
                 TypeError,
                 'both st and st2 need to be a 3D ShardedTensor',
         ):
             torch.bmm(st_lhs, st_rhs)
         rhs = torch.rand(15, 5, 6).cuda(self.rank)
         with self.assertRaisesRegex(
                 TypeError,
                 'st2 needs to be a ShardedTensor for torch.bmm',
         ):
             torch.bmm(st_lhs, rhs)
         spec.dim = 1
         st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
         st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
         with self.assertRaisesRegex(
                 NotImplementedError,
                 'Only support performing bmm on tensors sharded on dim 0 now',
         ):
             torch.bmm(st_lhs, st_rhs)
Example #17
 def test_sharded_tensor_softmax_error(self):
     specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
         )
         with self.assertRaisesRegex(
             NotImplementedError,
             "Only support performing softmax on non-sharding dim now.",
         ):
             torch.nn.functional.softmax(
                 st, dim=st.sharding_spec().dim, dtype=torch.float32
             )
Example #18
    def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
        replica_tensor = ReplicatedTensor(local_tensor)

        torch.manual_seed(self.rank)
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        st1 = sharded_tensor.rand(spec, (20, 3, 3))
        st2 = sharded_tensor.rand(spec, (30, 3, 3))

        with self.assertRaisesRegex(RuntimeError, 'Implicit broadcasting'):
            st1 + st2

        with self.assertRaisesRegex(RuntimeError, 'not supported for ShardedTensor'):
            st1 % replica_tensor
Example #19
 def _run_sharded_elementwise_ops(
     self, spec, input_size, op, reset_seed=None, **kwargs
 ):
     torch.manual_seed(self.rank)
     st = sharded_tensor.rand(spec, *input_size)
     if reset_seed:
         reset_seed()
     new_st = op(st, **kwargs)
     local_shard = st.local_tensor()
     new_st_local_shard = new_st.local_tensor()
     if reset_seed:
         reset_seed()
     self.assertEqual(
         op(local_shard, **kwargs),
         new_st_local_shard,
     )
Example #20
    def test_inplace_copy(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        st = sharded_tensor.rand(spec, (12, 5))
        ones_st = sharded_tensor.ones(spec, (12, 5))
        self.assertFalse(torch.equal(ones_st, st))
        st.copy_(ones_st)
        self.assertTrue(torch.equal(st, ones_st))

        # under no_grad, in-place copy should work between tensors with different requires_grad
        st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
        self.assertTrue(st_with_grad.requires_grad)
        self.assertFalse(ones_st.requires_grad)
        with torch.no_grad():
            st_with_grad.copy_(ones_st)
            self.assertEqual(st_with_grad.local_tensor(),
                             ones_st.local_tensor())
Example #21
 def test_clone(self):
     spec = ChunkShardingSpec(
         dim=0,
         placements=[
             "rank:0/cuda:0",
             "rank:1/cuda:1",
             "rank:2/cuda:2",
             "rank:3/cuda:3",
         ],
     )
     st = sharded_tensor.rand(spec, (12, 5))
     copied_st = st.clone()
     self.assertTrue(type(copied_st) is type(st))
     self.assertEqual(copied_st.local_tensor(), st.local_tensor())
     self.assertFalse(copied_st is st)
Example #22
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (10, 10, )),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('replicated', torch.Size([0]))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], replicated.storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]
            local_shard = state_dict["sharded"].local_shards()[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
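            # storage_data is keyed by MetadataIndex, which for a shard includes its offsets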
            storage_key = MetadataIndex('sharded', torch.Size(local_shard.metadata.shard_offsets))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)
Example #23
    def test_load_error_handling(self) -> None:
        state_dict = {
            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
            'replicated': torch.rand(10, 10),
            'bytes': [1, 2, 3, 4]
        }

        self._test_load(state_dict)
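        # each fail_* argument lists the ranks on which that read step is forced to fail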
        self._test_load(state_dict, fail_read_metadata=[0])
        self._test_load(state_dict, fail_read_bytes=[1])
        self._test_load(state_dict, fail_read_bytes_async=[2])
        self._test_load(state_dict, fail_read_tensors=[3])
        self._test_load(state_dict, fail_read_tensors_async=[1])
        # We don't want to depend on the actual exception raised by pickle
        self._test_load(state_dict, fail_deser_bytes=[2], ignore_exception_type=True)

        self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
        self._test_load(state_dict, coordinator=2, fail_read_bytes=[0])
        self._test_load(state_dict, coordinator=3, fail_read_tensors_async=[2])
Example #24
    def test_save_error_handling(self) -> None:
        state_dict = {
            'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
            'replicated': torch.rand(10, 10),
            'bytes': [1, 2, 3, 4]
        }

        self._test_save(state_dict, fail_prepare=[0])
        self._test_save(state_dict, fail_finish=[0])

        self._test_save(state_dict, fail_prepare_storage=[0])
        self._test_save(state_dict, fail_write_tensors_on_ranks=[1])
        self._test_save(state_dict, fail_write_tensors_on_ranks_async=[2])
        self._test_save(state_dict, fail_write_bytes_on_ranks=[3])
        self._test_save(state_dict, fail_write_bytes_on_ranks_async=[1])

        self._test_save(state_dict, fail_write_tensors_on_ranks_async=[1, 3])

        self._test_save(state_dict, coordinator=1, fail_prepare=[1])
        self._test_save(state_dict, coordinator=1, fail_finish=[1])
Example #25
 def test_sharded_tensor_view_error(self):
     for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
         st = sharded_tensor.rand(
             spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
         )
         with self.assertRaisesRegex(
             NotImplementedError,
             "Shape having dim 2 is not supported "
             "for sharded tensor sharded on dim 2.",
         ):
             st.view(35 * 17, 26)
         with self.assertRaisesRegex(
             ValueError,
             r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
         ):
             st.view(5, 7, 35, 17, 26)
         with self.assertRaisesRegex(
             ValueError,
             "Only one dimension can be inferred for sharded view op.",
         ):
             st.view(5, 7, -1, -1)
Example #26
    def test_detach(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        st = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
        local_shards = st.local_shards()
        # since the sharded tensor was created with requires_grad=True, all local shards should require grad
        for local_shard in local_shards:
            self.assertTrue(local_shard.tensor.requires_grad)

        detached_st = st.detach()
        self.assertFalse(detached_st.requires_grad)

        for local_shard in detached_st.local_shards():
            self.assertFalse(local_shard.tensor.requires_grad)
Example #27
 def test_sharded_tensor_masked_fill_error(self):
     specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
     for spec in specs:
         st = sharded_tensor.rand(
             spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
         )
         mask = (
             torch.randint(0, 2, (2, 35, 17, 26))
             .type(torch.BoolTensor)
             .cuda(self.rank)
         )
         with self.assertRaisesRegex(
             ValueError,
             "mask dim must not greater than the dim of the sharded tensor.",
         ):
             st.masked_fill(mask, 25.0)
         mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
         with self.assertRaisesRegex(
             ValueError,
             "The size of mask 0 must match the size of sharded tensor 1 "
             "at non-singleton dimension 0",
         ):
             st.masked_fill(mask, 25.0)
Example #28
    def test_switch_between_sharded_tensor_to_tensor(self) -> None:
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[8],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8],
                    shard_sizes=[tensor_size - 8],
                    placement="rank:0",
                ),
            ]),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[10],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[10],
                    shard_sizes=[tensor_size - 10],
                    placement="rank:1",
                ),
            ]),
        ]

        for save_spec in specs:
            for load_spec in specs:
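                # save 'sharded' as a ShardedTensor and load it back into a plain tensor, and vice versa for 'replicated'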
                save_dict = {
                    'sharded':
                    sharded_tensor.rand(save_spec, tensor_size),
                    'replicated':
                    torch.rand(tensor_size, device=f"cpu:{self.rank}")
                }

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # swap the tensor types between save and load ("Freaky Friday" the tensors)
                load_dict = {
                    'sharded': torch.zeros(tensor_size,
                                           device=f"cpu:{self.rank}"),
                    'replicated': sharded_tensor.zeros(load_spec, tensor_size)
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

                save_dict_sharded = self.load_tensor(save_dict['sharded'])
                load_dict_replicated = self.load_tensor(
                    load_dict['replicated'])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded,
                                       load_dict['sharded']),
                        f"save-spec {save_spec} load-spec {load_spec}")
                    self.assertTrue(
                        torch.allclose(save_dict['replicated'],
                                       load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}")