Beispiel #1
0
def test_checkpoint():
    """Check checkpoint results for rank 0 and for a non-zero rank.

    The rank-0 session is expected to surface each saved checkpoint's data.
    A session with ``world_rank=1`` is expected to surface checkpoint
    results whose data is empty. Once no session is active,
    ``save_checkpoint`` must raise ``ValueError``.
    """
    def train():
        for i in range(2):
            save_checkpoint(epoch=i)

    def validate_zero(sess, expected):
        # Rank 0 receives the checkpoint payload itself.
        result = sess.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data["epoch"] == expected

    def validate_nonzero(sess):
        # A non-zero rank still reports a checkpoint event, but with no data.
        result = sess.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data == {}

    init_session(training_func=train, world_rank=0, local_rank=0)
    session = get_session()
    try:
        session.start()
        validate_zero(session, 0)
        validate_zero(session, 1)
        session.finish()
    finally:
        # Always shut down, so a failed assertion does not leave a session
        # active for the rest of this test or for later tests.
        shutdown_session()

    init_session(training_func=train, world_rank=1, local_rank=1)
    session = get_session()
    try:
        session.start()
        validate_nonzero(session)
        validate_nonzero(session)
        session.finish()
    finally:
        shutdown_session()

    # With no active session, saving a checkpoint is an error.
    with pytest.raises(ValueError):
        save_checkpoint(epoch=2)
Beispiel #2
0
def test_locking():
    """Tests that report pauses training until a result is fetched or finish.

    ``train_1`` calls ``_thread.interrupt_main()``, which raises
    ``KeyboardInterrupt`` in the main thread. The test uses that interrupt to
    detect when the training function has reached ``train_1``.
    """
    def train_1():
        # Raise KeyboardInterrupt in the main thread (stdlib behavior of
        # _thread.interrupt_main).
        import _thread
        _thread.interrupt_main()

    # Phase 1: nothing blocks train_1, and the test expects the interrupt to
    # surface from session.start(). Presumably start() runs the training
    # function on a separate thread -- TODO confirm.
    init_session(training_func=train_1, world_rank=0)
    session = get_session()
    with pytest.raises(KeyboardInterrupt):
        session.start()
    shutdown_session()

    def train_2():
        # Nothing fetches these results. The test expects training to block
        # on report() and not reach train_1() on its own.
        for i in range(2):
            report(loss=i)
        train_1()

    # Phase 2: if training were not paused, train_1 would interrupt the main
    # thread during this sleep and the test would error out here.
    init_session(training_func=train_2, world_rank=0)
    session = get_session()
    session.start()
    time.sleep(3)

    # finish() is expected to unblock training. Training then reaches
    # train_1(), which interrupts the main thread.
    with pytest.raises(KeyboardInterrupt):
        session.finish()
    shutdown_session()
Beispiel #3
0
def session():
    """Yield an initialized rank-0 session and shut it down afterwards.

    The training function simply returns 1. This generator presumably backs a
    pytest fixture (no decorator is visible here -- confirm).
    """
    def training_func():
        return 1

    init_session(training_func=training_func, world_rank=0)
    active = get_session()
    yield active
    # Runs when the consumer resumes the generator (fixture teardown).
    shutdown_session()
Beispiel #4
0
        def pause_reporting():
            """Pause result reporting on this worker's session.

            Returns:
                Whatever ``session.pause_reporting()`` returns.

            Raises:
                SGDBackendError: If ``get_session()`` raises ``ValueError``,
                    i.e. no session has been initialized on this worker yet.
            """
            # Get the session for this worker.
            try:
                session = get_session()
            except ValueError as err:
                # Session is not initialized yet. Chain the original error so
                # the root cause stays visible in the traceback.
                # NOTE(review): the message names `finish_training`, which
                # presumably is the caller of this helper -- confirm.
                raise SGDBackendError("`finish_training` has been called "
                                      "before `start_training`. Please call "
                                      "`start_training` before "
                                      "`finish_training`.") from err

            return session.pause_reporting()
Beispiel #5
0
def test_report_fail():
    """Calling ``report`` positionally is expected to fail training.

    ``train`` calls ``report(i)`` with a positional argument. The test
    expects no result to be produced and a ``TypeError`` to be raised from
    ``session.finish()``.
    """
    def train():
        for i in range(2):
            report(i)
        return 1

    init_session(training_func=train, world_rank=0)
    session = get_session()
    try:
        session.start()
        assert session.get_next() is None
        with pytest.raises(TypeError):
            session.finish()
    finally:
        # Shut down even if an expectation above fails, so no session is
        # left active for later tests.
        shutdown_session()
Beispiel #6
0
def test_report():
    """Reported metrics are fetched in order; ``report`` fails without a session."""
    def train():
        for i in range(2):
            report(loss=i)

    init_session(training_func=train, world_rank=0)
    session = get_session()
    try:
        session.start()
        assert session.get_next()["loss"] == 0
        assert session.get_next()["loss"] == 1
    finally:
        # Shut down even if an assertion fails, so no session is left
        # active for later tests.
        shutdown_session()

    # After shutdown there is no active session, so reporting must fail.
    with pytest.raises(ValueError):
        report(loss=2)
Beispiel #7
0
def test_load_checkpoint_after_save():
    """``load_checkpoint`` returns the checkpoint that was just saved.

    The check runs inside the training function. A failure there is
    expected to be re-raised by ``session.finish()``, which raises any
    exceptions from training.
    """
    def train():
        for i in range(2):
            save_checkpoint(epoch=i)
            checkpoint = load_checkpoint()
            assert checkpoint["epoch"] == i

    init_session(training_func=train, world_rank=0, local_rank=0)
    session = get_session()
    try:
        session.start()
        # Consume both checkpoint results before finishing.
        for _ in range(2):
            session.get_next()
        session.finish()
    finally:
        # Shut down even if training failed, so no session is left active
        # for later tests.
        shutdown_session()
Beispiel #8
0
        def end_training():
            """Finish training on this worker and always shut down its session.

            Returns:
                The value returned by ``session.finish()``.

            Raises:
                SGDBackendError: If ``get_session()`` raises ``ValueError``,
                    i.e. no session has been initialized on this worker yet.
                Exception: Any exception from training, which
                    ``session.finish()`` re-raises.
            """
            # Get the session for this worker.
            try:
                session = get_session()
            except ValueError as err:
                # Session is not initialized yet. Chain the original error so
                # the root cause stays visible in the traceback.
                raise SGDBackendError("`finish_training` has been called "
                                      "before `start_training`. Please call "
                                      "`start_training` before "
                                      "`finish_training`.") from err

            try:
                # session.finish raises any Exceptions from training.
                return session.finish()
            finally:
                # Shutdown session even if session.finish() raises an
                # Exception.
                shutdown_session()
Beispiel #9
0
        def get_next():
            """Fetch the next result from this worker's session.

            Returns:
                The value returned by ``session.get_next()``.

            Raises:
                SGDBackendError: If no session has been initialized, or if
                    ``session.get_next()`` raises ``RuntimeError`` (treated as
                    the training thread not having been started yet).
            """
            # Both failure modes report the same user-facing message.
            not_started_msg = ("`fetch_next_result` has been called "
                               "before `start_training`. Please call "
                               "`start_training` before "
                               "`fetch_next_result`.")

            # Get the session for this worker.
            try:
                session = get_session()
            except ValueError as err:
                # Session is not initialized yet.
                raise SGDBackendError(not_started_msg) from err

            try:
                return session.get_next()
            except RuntimeError as err:
                # Training thread has not been started yet.
                raise SGDBackendError(not_started_msg) from err
Beispiel #10
0
 def train_async():
     """Look up this worker's session and call its ``start()`` method."""
     get_session().start()
Beispiel #11
0
def test_get_fail(session):
    """``get_session`` raises ``ValueError`` once the session is shut down.

    ``session`` is requested only for its setup side effect. It is shut
    down early here so that no session is active when ``get_session`` runs.
    """
    shutdown_session()
    # Callable form: invokes get_session() and requires a ValueError.
    pytest.raises(ValueError, get_session)