Exemplo n.º 1
0
def session():
    """Yield an active single-worker session, then shut it down.

    The session is initialised with a trivial training function that
    returns 1, using world rank 0, local rank 0 and a world size of 1.
    """
    def constant_one():
        return 1

    init_session(
        training_func=constant_one,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    yield get_session()
    # Runs when the generator is resumed after the yield.
    shutdown_session()
Exemplo n.º 2
0
def test_encode_data():
    """Check that ``encode_data_fn`` is applied to checkpoints and reports.

    The training function saves one checkpoint and then makes one report.
    Each result fetched from the session must carry the ``"encoded"`` marker
    that the encoding function adds.
    """
    def train_func():
        save_checkpoint(epoch=0)
        report(epoch=1)

    def mark_encoded(data):
        # Tag the payload so the assertions below can detect that it was
        # passed through the encoding function.
        data.update({"encoded": True})
        return data

    def validate_encoded(result_type: TrainingResultType):
        # ``result`` rather than ``next`` to avoid shadowing the builtin.
        result = session.get_next()
        assert result.type is result_type
        assert result.data["encoded"] is True

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
        encode_data_fn=mark_encoded,
    )
    try:
        session = get_session()
        session.start()
        # Validate checkpoint is encoded.
        validate_encoded(TrainingResultType.CHECKPOINT)
        # Validate report is encoded.
        validate_encoded(TrainingResultType.REPORT)
        session.finish()
    finally:
        # Shut the session down even when an assertion or finish() raises,
        # so a failure here does not leave a session behind for later tests.
        shutdown_session()
Exemplo n.º 3
0
        def end_training():
            """Finish the current session and return its output.

            The session is shut down whether or not ``finish`` succeeds.
            """
            active_session = _get_session("finish_training")
            try:
                # finish() re-raises any exception from training; the
                # ``finally`` clause still runs before it propagates.
                return active_session.finish()
            finally:
                shutdown_session()
Exemplo n.º 4
0
def test_get_dataset_shard():
    """Check that ``get_dataset_shard`` returns the shard given to ``init_session``."""
    expected_shard = ray.data.from_items([1, 2, 3])
    session_kwargs = dict(
        training_func=lambda: 1,
        world_rank=0,
        local_rank=0,
        world_size=1,
        dataset_shard=expected_shard,
    )
    init_session(**session_kwargs)
    assert get_dataset_shard() == expected_shard
    shutdown_session()
Exemplo n.º 5
0
def test_report():
    """Check that reported losses are fetched from the session in order."""
    def train_func():
        for step in range(2):
            report(loss=step)

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    session = get_session()
    session.start()
    # The training function reports loss=0 and then loss=1.
    for expected_loss in (0, 1):
        assert session.get_next().data["loss"] == expected_loss
    shutdown_session()
Exemplo n.º 6
0
def test_report_fail():
    """Check that calling ``report`` with a positional argument fails.

    No result is produced by ``get_next`` and ``finish`` raises ``TypeError``.
    """
    def train_func():
        for step in range(2):
            # Positional argument on purpose: this is the misuse under test.
            report(step)
        return 1

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    session = get_session()
    session.start()
    assert session.get_next() is None
    with pytest.raises(TypeError):
        session.finish()
    shutdown_session()
Exemplo n.º 7
0
def test_load_checkpoint_after_save():
    """Check that ``load_checkpoint`` returns the checkpoint just saved."""
    def train_func():
        for epoch in range(2):
            save_checkpoint(epoch=epoch)
            loaded = load_checkpoint()
            assert loaded["epoch"] == epoch

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    session = get_session()
    session.start()
    # Fetch once per checkpoint saved by the training function.
    for _ in range(2):
        session.get_next()
    session.finish()
    shutdown_session()
Exemplo n.º 8
0
def test_checkpoint():
    """Check what the session yields for checkpoints saved during training.

    With ``world_rank=0`` each fetched checkpoint must contain the saved
    epoch. With ``world_rank=1`` each fetched checkpoint must have empty data.
    """
    def train_func():
        for i in range(2):
            save_checkpoint(epoch=i)

    def validate_rank_zero(expected):
        # ``result`` rather than ``next`` to avoid shadowing the builtin.
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data["epoch"] == expected

    def validate_nonzero_rank():
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data == {}

    init_session(training_func=train_func,
                 world_rank=0,
                 local_rank=0,
                 world_size=1)
    try:
        session = get_session()
        session.start()
        validate_rank_zero(0)
        validate_rank_zero(1)
        session.finish()
    finally:
        # Shut down even on failure so the session is not left behind for
        # the second half of this test or for later tests.
        shutdown_session()

    init_session(training_func=train_func,
                 world_rank=1,
                 local_rank=1,
                 world_size=1)
    try:
        session = get_session()
        session.start()
        validate_nonzero_rank()
        validate_nonzero_rank()
        session.finish()
    finally:
        shutdown_session()
Exemplo n.º 9
0
def test_locking():
    """Tests that report pauses training until fetch_next or finish.

    First, a training function that interrupts the main thread must make
    ``session.start()`` raise ``KeyboardInterrupt``. Second, a training
    function that reports twice before interrupting must make
    ``session.finish()`` raise ``KeyboardInterrupt``.
    """
    def interrupt_main():
        import _thread

        _thread.interrupt_main()

    init_session(
        training_func=interrupt_main,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    active = get_session()
    with pytest.raises(KeyboardInterrupt):
        active.start()
    shutdown_session()

    def report_then_interrupt():
        for step in range(2):
            report(loss=step)
        interrupt_main()

    init_session(
        training_func=report_then_interrupt,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    active = get_session()
    active.start()
    time.sleep(3)

    active.pause_reporting()
    # Releases session.continue_lock to resume the training thread.
    active.get_next()

    with pytest.raises(KeyboardInterrupt):
        active.finish()
    shutdown_session()
Exemplo n.º 10
0
def test_world_size(session):
    """Check ``world_size`` inside a session and its default after shutdown."""
    size_in_session = world_size()
    assert size_in_session == 1
    shutdown_session()
    # With the session shut down, the value must default to 1.
    size_after_shutdown = world_size()
    assert size_after_shutdown == 1
Exemplo n.º 11
0
def test_local_rank(session):
    """Check ``local_rank`` inside a session and its default after shutdown."""
    rank_in_session = local_rank()
    assert rank_in_session == 0
    shutdown_session()
    # With the session shut down, the value must default to 0.
    rank_after_shutdown = local_rank()
    assert rank_after_shutdown == 0
Exemplo n.º 12
0
def test_world_rank(session):
    """Check ``world_rank`` inside a session and its default after shutdown."""
    rank_in_session = world_rank()
    assert rank_in_session == 0
    shutdown_session()
    # With the session shut down, the value must default to 0.
    rank_after_shutdown = world_rank()
    assert rank_after_shutdown == 0
Exemplo n.º 13
0
def test_shutdown(session):
    """Check that ``get_session`` returns a falsy value after shutdown."""
    shutdown_session()
    remaining = get_session()
    assert not remaining