예제 #1
0
    def yield_checkpoint_model(
            self, wkld: workload.Workload,
            respond: workload.ResponseFunc) -> workload.Stream:
        """Checkpoint the trial on the chief container and report the result.

        This is a generator. On the container whose rendezvous rank is 0 it
        yields a single ``(workload, [checkpoint path], callback)`` tuple so
        that the consumer can write the checkpoint into the storage path,
        then builds the storage metadata and passes a ``WORKLOAD_COMPLETED``
        message to ``respond``. On every other rank nothing is yielded and
        ``respond`` receives ``workload.Skipped()``.

        Args:
            wkld: The checkpoint workload being executed.
            respond: Callback that receives the final response message.
        """
        start_time = _current_timestamp()

        # Only the chief container (rank 0) checkpoints; everyone else skips.
        if self.rendezvous_info.get_rank() != 0:
            respond(workload.Skipped())
            return

        with self.storage_mgr.store_path() as (storage_id, path):
            # The response coming back for the yielded workload is ignored.
            yield wkld, [pathlib.Path(path)], lambda _: None

            # List the directory while the storage path context is still open.
            resources = storage.StorageManager._list_directory(path)
            metadata = storage.StorageMetadata(storage_id, resources)

        logging.info("Saved trial to checkpoint {}".format(
            metadata.storage_id))
        self.tensorboard_mgr.sync()

        # Label the checkpoint with the identifiers of the workload it came from.
        metadata.labels = {
            "experiment_id": str(wkld.experiment_id),
            "trial_id": str(wkld.trial_id),
            "step_id": str(wkld.step_id),
        }

        completed = {
            "type": "WORKLOAD_COMPLETED",
            "workload": wkld,
            "start_time": start_time,
            "end_time": _current_timestamp(),
            "metrics": metadata,
        }  # type: workload.Response
        respond(completed)
예제 #2
0
def to_delete(request: Any,
              manager: storage.StorageManager) -> List[Dict[str, Any]]:
    """Create ``request.param`` checkpoints and return their metadata as dicts.

    Each checkpoint is written through ``manager.store_path()``. The metadata
    objects are round-tripped through ``util.json_encode`` and
    ``simplejson.loads`` so the caller receives plain JSON-decoded values.
    """
    stored = []
    for _ in range(request.param):
        with manager.store_path() as (storage_id, path):
            storage_util.create_checkpoint(path)
            resources = manager._list_directory(path)
            stored.append(storage.StorageMetadata(storage_id, resources))

    # The base path must contain exactly one entry per checkpoint created.
    entries = os.listdir(manager._base_path)
    assert len(entries) == request.param

    # Lazy generator: each item is encoded and then decoded before the next.
    encoded = (util.json_encode(item) for item in stored)
    return [simplejson.loads(document) for document in encoded]
예제 #3
0
        def _respond(checkpoint_info: workload.Response) -> None:
            checkpoint_info = cast(Dict[str, Any], checkpoint_info)
            metadata = storage.StorageMetadata(
                storage_id,
                storage.StorageManager._list_directory(path),
                checkpoint_info.get("framework", ""),
                checkpoint_info.get("format", ""),
            )

            logging.info("Saved trial to checkpoint {}".format(metadata.storage_id))
            self.tensorboard_mgr.sync()

            nonlocal message
            message = {
                "type": "WORKLOAD_COMPLETED",
                "workload": wkld,
                "start_time": start_time,
                "end_time": _current_timestamp(),
                "metrics": metadata,
            }
예제 #4
0
def test_s3_lifecycle(manager: storage.S3StorageManager) -> None:
    """Store five checkpoints, then restore, validate and delete each one."""
    assert not os.listdir(manager._base_path)

    stored = []
    for _ in range(5):
        with manager.store_path() as (storage_id, path):
            # The base path is expected to be empty on every iteration.
            assert not os.listdir(manager._base_path)
            util.create_checkpoint(path)
            resources = manager._list_directory(path)
            metadata = storage.StorageMetadata(storage_id, resources)
            stored.append(metadata)
            assert set(metadata.resources) == set(util.EXPECTED_FILES.keys())

    for metadata in stored:
        # Every stored checkpoint must restore and validate cleanly.
        with manager.restore_path(metadata) as restored:
            util.validate_checkpoint(restored)

        manager.delete(metadata)

        # After deletion, restoring the same checkpoint must raise KeyError.
        with pytest.raises(KeyError), manager.restore_path(metadata):
            pass
예제 #5
0
def test_checkpoint_lifecycle(manager: storage.SharedFSStorageManager) -> None:
    """Store five checkpoints, then restore and delete them newest-first.

    The number of entries under the manager's base path is checked after
    every store and every delete.
    """
    assert not os.listdir(manager._base_path)

    stored = []
    for already_stored in range(5):
        with manager.store_path() as (storage_id, path):
            # While storing, the base path should hold exactly one entry per
            # earlier iteration.
            assert len(os.listdir(manager._base_path)) == already_stored
            util.create_checkpoint(path)
            resources = manager._list_directory(path)
            metadata = storage.StorageMetadata(storage_id, resources)
            stored.append(metadata)
            assert set(metadata.resources) == set(util.EXPECTED_FILES.keys())

    assert len(os.listdir(manager._base_path)) == 5

    # Walk the checkpoints in reverse; each checkpoint's position equals the
    # number of entries that should remain once it has been deleted.
    for remaining, metadata in reversed(list(enumerate(stored))):
        assert metadata.storage_id in os.listdir(manager._base_path)
        with manager.restore_path(metadata) as restored:
            util.validate_checkpoint(restored)

        manager.delete(metadata)
        assert metadata.storage_id not in os.listdir(manager._base_path)
        assert len(os.listdir(manager._base_path)) == remaining