Example #1
def _load():
    metadata, _, _ = _prepare(state_dict, write_replicated)
    load_state_dict(
        state_dict,
        storage_reader=FaultyStorageReader(metadata, kwargs),
        coordinator_rank=coordinator,
        no_dist=no_dist,
    )
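This snippet only shows a nested helper; the names it uses (state_dict, write_replicated, kwargs, coordinator, no_dist) come from the enclosing test, and _prepare / FaultyStorageReader are test-local fault-injection helpers rather than public API. The public names used throughout these examples would typically be imported roughly as below; this is a sketch based on the pre-torch.distributed.checkpoint API under torch.distributed._shard, so verify against your PyTorch version.

# Likely imports for the examples on this page (assumption: the
# torch.distributed._shard checkpoint and sharding APIs these tests target).
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)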
    def test_load_rowwise_to_colwise(self) -> None:
        path = self.get_file_path()
        self.assertEqual(self.world_size, dist.get_world_size())

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        src_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        dst_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        if dist.get_rank() == 0:
            shutil.rmtree(path, ignore_errors=True)
            os.makedirs(path)

        model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        fs_reader = FileSystemReader(path=path)

        load_state_dict(state_dict=state_dict_to_load_to,
                        storage_reader=fs_reader)

        # We can't compare the ShardedTensors directly since each has a different
        # sharding spec, so gather them to full tensors on rank 0 before torch.allclose.
        store_tensor = self.load_tensor(model_to_save.sharded_tensor)
        load_tensor = self.load_tensor(model_to_load.sharded_tensor)

        if dist.get_rank() == 0:
            self.assertTrue(torch.allclose(store_tensor, load_tensor))
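The self.load_tensor(...) helper used here and in the later tests is not shown on this page; presumably it materializes a ShardedTensor as a full tensor on rank 0 so it can be compared with torch.allclose. A minimal sketch of such a helper, assuming ShardedTensor.gather semantics (a full-size output tensor on the destination rank, None elsewhere) and float tensors:

def load_tensor(self, st):
    # Gather all shards onto rank 0; other ranks pass out=None and get None back.
    out = torch.empty(st.size()) if dist.get_rank() == 0 else None
    st.gather(dst=0, out=out)
    return out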
    def test_read_write_shard_tensor(self) -> None:
        paths = [tempfile.mkdtemp()]
        dist.broadcast_object_list(paths)

        path = paths[0]

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel1(spec, init_rrefs=False)

        # Test save
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        dist.barrier()

        # Create a new model
        model_to_load = MyShardedModel1(spec, init_rrefs=False)
        # This is not the correct hook for loading the state dict
        # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        dist.barrier()

        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to,
                                    state_dict_to_save)

        # Test load.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to,
                        storage_reader=fs_reader)

        assert_state_dict_equal(self, state_dict_to_load_to,
                                state_dict_to_save)
        dist.barrier()
    def test_save_load_bytes(self) -> None:
        path = self.get_file_path()

        state_dict_to_save = {'bytes0': [1], 'bytes1': 'string'}

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        state_dict_to_load = {'bytes0': [2], 'bytes1': 'other'}

        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load,
                        storage_reader=fs_reader)

        self.assertEqual([1], state_dict_to_load['bytes0'])
        self.assertEqual('string', state_dict_to_load['bytes1'])
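Non-tensor values like the list and string above carry no sharding metadata; the checkpoint layer presumably serializes them whole (e.g. by pickling) and overwrites the corresponding entries in the target dict on load, which is why the loaded dict ends up with [1] and 'string' rather than the placeholder values it started with.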
    def test_read_write_only_tensor(self) -> None:
        with tempfile.TemporaryDirectory() as path:
            state_dict_to_save = MyTestModule().state_dict()

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True)

            state_dict_to_load_to = MyTestModule().state_dict()

            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

            # Load from file without any resharding
            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True)

            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
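Passing no_dist=True to both save_state_dict and load_state_dict skips the collective coordination step, so this example can run in a single process without an initialized process group; the other tests on this page rely on an initialized process group and use barriers and broadcasts instead.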
Example #6
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
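Stripped of the test scaffolding, the save/load pattern exercised above is roughly the following; a sketch assuming the same FSDP.state_dict_type context manager and checkpoint API used in the test:

# Save: produce a (possibly sharded) state_dict under the chosen state_dict_type.
with FSDP.state_dict_type(model, state_dict_type):
    state_dict = model.state_dict()
save_state_dict(state_dict, FileSystemWriter(path))

# Load: use the new model's own state_dict as a template with matching sharding,
# fill it in-place from storage, then hand it back to the module.
with FSDP.state_dict_type(new_model, state_dict_type):
    state_dict = new_model.state_dict()
    load_state_dict(state_dict, FileSystemReader(path))
    new_model.load_state_dict(state_dict)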
    def test_switch_between_sharded_tensor_to_tensor(self) -> None:
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[8],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8],
                    shard_sizes=[tensor_size - 8],
                    placement="rank:0",
                ),
            ]),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[10],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[10],
                    shard_sizes=[tensor_size - 10],
                    placement="rank:1",
                ),
            ]),
        ]

        for save_spec in specs:
            for load_spec in specs:
                save_dict = {
                    'sharded':
                    sharded_tensor.rand(save_spec, tensor_size),
                    'replicated':
                    torch.rand(tensor_size, device=f"cpu:{self.rank}")
                }

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # Freaky Friday the tensors: the sharded entry is loaded into a
                # plain tensor and the replicated entry into a sharded tensor
                load_dict = {
                    'sharded': torch.zeros(tensor_size,
                                           device=f"cpu:{self.rank}"),
                    'replicated': sharded_tensor.zeros(load_spec, tensor_size)
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

                save_dict_sharded = self.load_tensor(save_dict['sharded'])
                load_dict_replicated = self.load_tensor(
                    load_dict['replicated'])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded,
                                       load_dict['sharded']),
                        f"save-spec {save_spec} load-spec {load_spec}")
                    self.assertTrue(
                        torch.allclose(save_dict['replicated'],
                                       load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}")
    def test_load_with_different_shard_plan(self) -> None:
        path = self.get_file_path()

        # The specs below hard-code assumptions about how many ranks (and shards) exist
        self.assertEqual(self.world_size, dist.get_world_size())

        specs = [
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[2, 0],
                    shard_sizes=[1, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[3, 0],
                    shard_sizes=[3, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[6, 0],
                    shard_sizes=[3, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[9, 0],
                    shard_sizes=[1, 20],
                    placement="rank:0",
                ),
            ]),
            # This requires the tensors to be [10, 20]
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[8, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
            ]),
        ]

        for s0 in specs:
            for s1 in specs:
                if s0 == s1:
                    continue

                if dist.get_rank() == 0:
                    shutil.rmtree(path, ignore_errors=True)
                    os.makedirs(path)
                dist.barrier()

                model_to_save = MyShardedModel3(s0)
                model_to_save._register_state_dict_hook(state_dict_hook)
                state_dict_to_save = model_to_save.state_dict()

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=state_dict_to_save,
                                storage_writer=fs_writer)

                dist.barrier()

                model_to_load = MyShardedModel3(s1)
                model_to_load._register_state_dict_hook(state_dict_hook)
                state_dict_to_load_to = model_to_load.state_dict()
                dist.barrier()

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=state_dict_to_load_to,
                                storage_reader=fs_reader)

                dist.barrier()
                store_tensor = self.load_tensor(model_to_save.sharded_tensor)
                dist.barrier()
                load_tensor = self.load_tensor(model_to_load.sharded_tensor)

                if dist.get_rank() == 0:
                    self.assertTrue(torch.allclose(store_tensor, load_tensor),
                                    msg=f"{s0} vs {s1}")