Example #1
def save_checkpoint(**kwargs) -> None:
    """Checkpoints all keyword arguments to Train as restorable state.

    .. code-block:: python

        import time
        from ray import train
        from ray.train import Trainer

        def train_func():
            for iter in range(100):
                time.sleep(1)
                train.save_checkpoint(epoch=iter)

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    Args:
        **kwargs: Any key-value pair to be checkpointed by Train.
    """
    session = get_session()
    if session is None:
        _warn_session_misuse(save_checkpoint.__name__)
        return
    session.checkpoint(**kwargs)
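Per the guard at the bottom of the function, calling ``save_checkpoint`` outside a Trainer-managed session is a no-op that only emits a misuse warning, which matters when exercising training functions in isolation:

from ray import train

# With no session initialized, _warn_session_misuse fires and nothing
# is checkpointed.
train.save_checkpoint(epoch=0)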
Example #2
@pytest.fixture
def session():
    def f():
        return 1

    init_session(training_func=f, world_rank=0, local_rank=0, world_size=1)
    yield get_session()
    shutdown_session()
Example #3
def report(**kwargs) -> None:
    """Reports all keyword arguments to Train as intermediate results.

    .. code-block:: python

        import time
        from ray import train
        from ray.train import Trainer

        def train_func():
            for iter in range(100):
                time.sleep(1)
                train.report(hello="world")

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    Args:
        **kwargs: Any key-value pair to be reported by Train.
            If callbacks are provided, they are executed on these
            intermediate results.
    """
    session = get_session()
    if session is None:
        _warn_session_misuse(report.__name__)
        return
    session._report_legacy(**kwargs)
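The Args section notes that callbacks consume these intermediate results. A minimal sketch of wiring one up, assuming the ``TrainingCallback`` base class in ``ray.train.callbacks`` from the same Ray release these examples target:

from ray import train
from ray.train import Trainer
from ray.train.callbacks import TrainingCallback

class PrintCallback(TrainingCallback):
    def handle_result(self, results, **info):
        # `results` holds one dict per worker for each report() call.
        print(results)

def train_func():
    for i in range(3):
        train.report(loss=1.0 / (i + 1))

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, callbacks=[PrintCallback()])
trainer.shutdown()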
Example #4
def test_encode_data():
    def train_func():
        save_checkpoint(epoch=0)
        report(epoch=1)

    def encode_checkpoint(checkpoint):
        checkpoint.update({"encoded": True})
        return checkpoint

    def validate_encoded(result_type: TrainingResultType):
        next = session.get_next()
        assert next.type is result_type
        assert next.data["encoded"] is True

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
        encode_data_fn=encode_checkpoint,
    )

    session = get_session()
    session.start()
    # Validate checkpoint is encoded.
    validate_encoded(TrainingResultType.CHECKPOINT)
    # Validate report is encoded.
    validate_encoded(TrainingResultType.REPORT)
    session.finish()
    shutdown_session()
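``encode_data_fn`` is the hook a backend uses to transform checkpoint and report payloads before they leave the worker. A sketch with a serializing encoder in place of the test's marker function; the ``payload`` key and the pickle scheme are illustrative, not any backend's actual format:

import pickle

def train_func():
    save_checkpoint(epoch=0)

def encode_with_pickle(data: dict) -> dict:
    # Serialize the payload before it crosses the process boundary;
    # a matching decode step would run on the receiving side.
    return {"payload": pickle.dumps(data)}

init_session(
    training_func=train_func,
    world_rank=0,
    local_rank=0,
    world_size=1,
    encode_data_fn=encode_with_pickle,
)
session = get_session()
session.start()
assert isinstance(session.get_next().data["payload"], bytes)
session.finish()
shutdown_session()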
Example #5
def load_checkpoint() -> Optional[Dict]:
    """Loads checkpoint data onto the worker.

    .. code-block:: python

        from ray import train
        from ray.train import Trainer

        def train_func():
            checkpoint = train.load_checkpoint()
            for iter in range(checkpoint["epoch"], 5):
                print(iter)

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func, checkpoint={"epoch": 3})
        # 3
        # 4
        trainer.shutdown()

    Returns:
        The most recently saved checkpoint if ``train.save_checkpoint()``
        has been called. Otherwise, the checkpoint that the session was
        originally initialized with. ``None`` if neither exist.
    """
    session = get_session()
    if session is None:
        _warn_session_misuse(load_checkpoint.__name__)
        return
    return session.loaded_checkpoint
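Driver-side resumption is just passing a checkpoint back into ``trainer.run``. A minimal sketch, assuming ``Trainer`` exposes a ``latest_checkpoint`` attribute in this release:

from ray import train
from ray.train import Trainer

def train_func():
    start = (train.load_checkpoint() or {"epoch": 0})["epoch"]
    for epoch in range(start, 5):
        train.save_checkpoint(epoch=epoch)

trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func)  # first run starts from epoch 0
# Resume from the most recent checkpoint of the previous run.
trainer.run(train_func, checkpoint=trainer.latest_checkpoint)
trainer.shutdown()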
Example #6
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
    """Returns the Ray Dataset or DatasetPipeline shard for this worker.

    You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
    it to the appropriate framework-specific Dataset.

    .. code-block:: python

        import ray
        from ray import train
        from ray.train import Trainer

        def train_func():
            model = Net()
            for iter in range(100):
                data_shard = train.get_dataset_shard().to_torch()
                model.train(data_shard)
            return model

        dataset = ray.data.read_csv("train.csv")
        dataset = dataset.filter(...).repeat().random_shuffle()

        trainer = Trainer(backend="torch")
        trainer.start()

        # Trainer will automatically handle sharding.
        train_model = trainer.run(train_func, dataset=dataset)
        trainer.shutdown()

    Args:
        dataset_name: If a dictionary of datasets was passed to ``Trainer``,
            then specifies which dataset shard to return.

    Returns:
        The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
        If no dataset was passed into the Trainer, returns None.
    """
    session = get_session()
    if session is None:
        _warn_session_misuse(get_dataset_shard.__name__)
        return
    shard = session.dataset_shard
    if shard is None:
        warnings.warn("No dataset passed in. Returning None. Make sure to "
                      "pass in a Ray Dataset to Trainer.run to use this "
                      "function.")
    elif isinstance(shard, dict):
        if not dataset_name:
            raise RuntimeError(
                "Multiple datasets were passed into ``Trainer``, "
                "but no ``dataset_name`` is passed into "
                "``get_dataset_shard``. Please specify which "
                "dataset shard to retrieve.")
        return shard.get(dataset_name)
    return shard
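The ``isinstance(shard, dict)`` branch covers the multi-dataset case. A sketch, assuming ``Trainer.run`` accepts a dictionary for its ``dataset`` argument as the ``dataset_name`` docstring implies:

import ray
from ray import train
from ray.train import Trainer

def train_func():
    # With multiple datasets, each shard must be selected by name.
    train_shard = train.get_dataset_shard("train")
    val_shard = train.get_dataset_shard("val")
    ...

trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func, dataset={
    "train": ray.data.read_csv("train.csv"),
    "val": ray.data.read_csv("val.csv"),
})
trainer.shutdown()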
Example #7
def _get_session(method_name: str):
    # Get the session for this worker.
    session = get_session()
    if not session:
        # Session is not initialized yet.
        raise TrainBackendError(f"`{method_name}` has been called "
                                "before `start_training`. Please call "
                                "`start_training` before "
                                f"`{method_name}`.")
    return session
Example #8
def test_checkpoint():
    def train_func():
        for i in range(2):
            save_checkpoint(epoch=i)

    def validate_zero(expected):
        next = session.get_next()
        assert next is not None
        assert next.type == TrainingResultType.CHECKPOINT
        assert next.data["epoch"] == expected

    init_session(training_func=train_func,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    session.start()
    validate_zero(0)
    validate_zero(1)
    session.finish()
    shutdown_session()

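    # Only the rank 0 worker persists checkpoint data; workers with a
    # nonzero world rank report empty checkpoint payloads.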
    def validate_nonzero():
        next = session.get_next()
        assert next is not None
        assert next.type == TrainingResultType.CHECKPOINT
        assert next.data == {}

    init_session(training_func=train_func,
                 world_rank=1,
                 local_rank=1,
                 world_size=1)
    session = get_session()
    session.start()
    validate_nonzero()
    validate_nonzero()
    session.finish()
    shutdown_session()
Example #9
def test_report():
    def train_func():
        for i in range(2):
            report(loss=i)

    init_session(training_func=train_func,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    session.start()
    assert session.get_next().data["loss"] == 0
    assert session.get_next().data["loss"] == 1
    shutdown_session()
Example #10
def test_locking():
    """Tests that report pauses training until fetch_next or finish."""
    def train_1():
        import _thread

        _thread.interrupt_main()

    init_session(training_func=train_1,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    with pytest.raises(KeyboardInterrupt):
        session.start()
    shutdown_session()

    def train_2():
        for i in range(2):
            report(loss=i)
        train_1()

    init_session(training_func=train_2,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    session.start()
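    # Give the training thread time to reach its first report() call and
    # block waiting for the main thread to consume the result.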
    time.sleep(3)

    session.pause_reporting()
    # Releases session.continue_lock to resume the training thread.
    session.get_next()

    with pytest.raises(KeyboardInterrupt):
        session.finish()
    shutdown_session()
Example #11
def test_report_fail():
    def train_func():
        for i in range(2):
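            # report() accepts keyword arguments only, so this positional
            # argument raises a TypeError inside the training thread.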
            report(i)
        return 1

    init_session(training_func=train_func,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    session.start()
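    # The training thread failed before producing a result, so nothing
    # is queued for the main thread.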
    assert session.get_next() is None
    with pytest.raises(TypeError):
        session.finish()
    shutdown_session()
Example #12
def test_load_checkpoint_after_save():
    def train_func():
        for i in range(2):
            save_checkpoint(epoch=i)
            checkpoint = load_checkpoint()
            assert checkpoint["epoch"] == i

    init_session(training_func=train_func,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    session = get_session()
    session.start()
    for i in range(2):
        session.get_next()
    session.finish()
    shutdown_session()
Example #13
def world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. code-block:: python

        from ray import train
        from ray.train import Trainer

        def train_func():
            assert train.world_size() == 4

        trainer = Trainer(backend="torch", num_workers=4)
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()
    """
    session = get_session()
    if session is None:
        return 1
    return session.world_size
Example #14
def local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. code-block:: python

        import torch
        from ray import train
        from ray.train import Trainer

        def train_func():
            if torch.cuda.is_available():
                torch.cuda.set_device(train.local_rank())
            ...

        trainer = Trainer(backend="torch", use_gpu=True)
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    """
    session = get_session()
    if session is None:
        return 0
    return session.local_rank
Example #15
def world_rank() -> int:
    """Get the world rank of this worker.

    .. code-block:: python

        import time
        from ray import train
        from ray.train import Trainer

        def train_func():
            for iter in range(100):
                time.sleep(1)
                if train.world_rank() == 0:
                    print("Worker 0")

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    """
    session = get_session()
    if session is None:
        return 0
    return session.world_rank
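Together with ``world_size()``, the rank helpers support manual data sharding when automatic dataset sharding is not in use. A minimal sketch with an in-memory stand-in dataset:

from ray import train
from ray.train import Trainer

def train_func():
    full_data = list(range(1000))  # stand-in for real training data
    # Strided split: worker i takes items i, i + world_size, ...
    shard = full_data[train.world_rank()::train.world_size()]
    return len(shard)

trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
# run() collects each worker's return value.
assert trainer.run(train_func) == [250, 250, 250, 250]
trainer.shutdown()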
Example #16
def train_async():
    session = get_session()
    session.start()
Example #17
def test_shutdown(session):
    shutdown_session()
    assert not get_session()