def test_checkpoint():
    """Check checkpoint results for rank 0 and for a nonzero rank.

    The rank-0 session is expected to yield CHECKPOINT results carrying the
    saved ``epoch``; the rank-1 session is expected to yield CHECKPOINT
    results whose data is an empty dict. After the last session is shut
    down, ``save_checkpoint`` is expected to raise ``ValueError``.
    """

    def train_func():
        for epoch in range(2):
            save_checkpoint(epoch=epoch)

    def next_checkpoint():
        # Reads ``worker`` from the enclosing scope at call time, so it
        # always inspects the most recently created session.
        result = worker.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        return result

    # Rank 0: checkpoint data is passed through.
    init_session(training_func=train_func, world_rank=0, local_rank=0)
    worker = get_session()
    worker.start()
    for expected_epoch in (0, 1):
        assert next_checkpoint().data["epoch"] == expected_epoch
    worker.finish()
    shutdown_session()

    # Rank 1: checkpoint results are still produced, but with empty data.
    init_session(training_func=train_func, world_rank=1, local_rank=1)
    worker = get_session()
    worker.start()
    for _ in range(2):
        assert next_checkpoint().data == {}
    worker.finish()
    shutdown_session()

    # No session is active any more.
    with pytest.raises(ValueError):
        save_checkpoint(epoch=2)
def test_locking():
    """Tests that report pauses training until fetch_next or finish.

    Two scenarios are exercised, both expecting a ``KeyboardInterrupt`` in
    the test thread: one where the training function interrupts the main
    thread immediately, and one where it first reports twice.
    """

    def interrupt_main_thread():
        import _thread
        _thread.interrupt_main()

    def report_then_interrupt():
        for step in range(2):
            report(loss=step)
        interrupt_main_thread()

    def new_session(training_func):
        init_session(
            training_func=training_func,
            world_rank=0,
            local_rank=0,
            world_size=1,
        )
        return get_session()

    # Scenario 1: the training function interrupts right away.
    active = new_session(interrupt_main_thread)
    with pytest.raises(KeyboardInterrupt):
        active.start()
    shutdown_session()

    # Scenario 2: the training function reports before interrupting.
    active = new_session(report_then_interrupt)
    active.start()
    time.sleep(3)

    active.pause_reporting()
    # Releases session.continue_lock to resume the training thread.
    active.get_next()

    with pytest.raises(KeyboardInterrupt):
        active.finish()
    shutdown_session()
def test_encode_data():
    """Check that ``encode_data_fn`` is applied to checkpoints and reports.

    The encoder tags its input with ``"encoded": True``; both the
    CHECKPOINT result and the REPORT result are expected to carry that tag.
    """

    def train_func():
        save_checkpoint(epoch=0)
        report(epoch=1)

    def encode_checkpoint(checkpoint):
        # Mutates the payload in place and hands the same object back.
        checkpoint.update(encoded=True)
        return checkpoint

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
        encode_data_fn=encode_checkpoint,
    )
    active = get_session()
    active.start()

    # The checkpoint result comes first, then the report result.
    expected_types = (TrainingResultType.CHECKPOINT, TrainingResultType.REPORT)
    for expected_type in expected_types:
        result = active.get_next()
        assert result.type is expected_type
        assert result.data["encoded"] is True

    active.finish()
    shutdown_session()
def session():
    """Yield a freshly initialized session, then shut it down.

    The session is created for a single worker (world rank 0, local rank 0,
    world size 1) with a training function that just returns ``1``.

    NOTE(review): tests in this file take a ``session`` argument, so this is
    presumably registered as a pytest fixture; the decorator is not visible
    here -- confirm.
    """

    def constant_train_func():
        return 1

    init_session(
        training_func=constant_train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    active_session = get_session()
    yield active_session
    # Runs when the generator is resumed after the yield.
    shutdown_session()
def _get_session(method_name: str):
    """Return the session for this worker, or fail with a backend error.

    Args:
        method_name: Name of the calling API method; only used to build the
            error message.

    Returns:
        The object returned by ``get_session()`` when it is truthy.

    Raises:
        TrainBackendError: If ``get_session()`` returns a falsy value, i.e.
            the session is not initialized yet.
    """
    worker_session = get_session()
    if worker_session:
        return worker_session

    # Session is not initialized yet.
    raise TrainBackendError(f"`{method_name}` has been called "
                            "before `start_training`. Please call "
                            "`start_training` before "
                            f"`{method_name}`.")
def _get_session(method_name: str):
    """Return the session for this worker, or fail with a backend error.

    Args:
        method_name: Name of the calling API method; only used to build the
            error message.

    Returns:
        The object returned by ``get_session()``.

    Raises:
        TrainBackendError: If ``get_session()`` raises ``ValueError``, i.e.
            the session is not initialized yet.
    """
    try:
        worker_session = get_session()
    except ValueError:
        # Session is not initialized yet.
        raise TrainBackendError(f"`{method_name}` has been called "
                                "before `start_training`. Please call "
                                "`start_training` before "
                                f"`{method_name}`.")
    else:
        return worker_session
def pause_reporting():
    """Pause reporting on this worker's session.

    Returns:
        Whatever the session's ``pause_reporting()`` returns.

    Raises:
        TrainBackendError: If ``get_session()`` raises ``ValueError``, i.e.
            the session is not initialized yet.
    """
    try:
        # Get the session for this worker.
        session = get_session()
    except ValueError:
        # Session is not initialized yet. The message names this operation
        # (it previously said `finish_training`, copied from another helper).
        raise TrainBackendError("`pause_reporting` has been called "
                                "before `start_training`. Please call "
                                "`start_training` before "
                                "`pause_reporting`.")
    return session.pause_reporting()
def test_report_fail():
    """Check that a bad ``report`` call surfaces as ``TypeError`` on finish.

    The training function calls ``report`` with a positional argument. The
    test expects ``get_next()`` to return ``None`` and ``finish()`` to raise
    ``TypeError``.
    """

    def train_func():
        for step in range(2):
            report(step)
        return 1

    init_session(training_func=train_func, world_rank=0, local_rank=0)
    active = get_session()
    active.start()

    first_result = active.get_next()
    assert first_result is None

    with pytest.raises(TypeError):
        active.finish()
    shutdown_session()
def test_report():
    """Check that reported values come back from ``get_next`` in order."""

    def train_func():
        for step in range(2):
            report(loss=step)

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
    )
    active = get_session()
    active.start()

    # One result per report call, in the order they were reported.
    for expected_loss in (0, 1):
        assert active.get_next().data["loss"] == expected_loss

    shutdown_session()
def test_report():
    """Check report ordering, and that reporting without a session fails.

    Reported values are expected back from ``get_next`` in order; after the
    session is shut down, ``report`` is expected to raise ``ValueError``.
    """

    def train_func():
        for step in range(2):
            report(loss=step)

    init_session(training_func=train_func, world_rank=0, local_rank=0)
    active = get_session()
    active.start()

    # One result per report call, in the order they were reported.
    for expected_loss in (0, 1):
        assert active.get_next().data["loss"] == expected_loss

    shutdown_session()

    # No session is active any more.
    with pytest.raises(ValueError):
        report(loss=2)
def test_load_checkpoint_after_save():
    """Check that ``load_checkpoint`` returns the checkpoint just saved.

    The assertion lives inside the training function; the test thread only
    fetches the two results and then calls ``finish()``.
    """

    def train_func():
        for epoch in range(2):
            save_checkpoint(epoch=epoch)
            loaded = load_checkpoint()
            assert loaded["epoch"] == epoch

    init_session(training_func=train_func, world_rank=0, local_rank=0)
    active = get_session()
    active.start()

    # Fetch one result per saved checkpoint.
    for _ in range(2):
        active.get_next()

    active.finish()
    shutdown_session()
def end_training():
    """Finish this worker's session and always shut it down.

    Returns:
        The value returned by the session's ``finish()``.

    Raises:
        TrainBackendError: If ``get_session()`` raises ``ValueError``, i.e.
            the session is not initialized yet.
        Exception: Anything raised by ``finish()`` propagates after the
            session has been shut down.
    """
    try:
        # Get the session for this worker.
        worker_session = get_session()
    except ValueError:
        # Session is not initialized yet.
        raise TrainBackendError("`finish_training` has been called "
                                "before `start_training`. Please call "
                                "`start_training` before "
                                "`finish_training`.")

    try:
        # session.finish raises any Exceptions from training.
        return worker_session.finish()
    finally:
        # Shutdown session even if session.finish() raises an Exception.
        shutdown_session()
def get_next():
    """Fetch the next result from this worker's session.

    Returns:
        The value returned by the session's ``get_next()``.

    Raises:
        TrainBackendError: If ``get_session()`` raises ``ValueError``
            (session not initialized), or if the session's ``get_next()``
            raises ``RuntimeError`` (training thread not started).
    """
    # Both failure modes report the same message.
    not_started_message = ("`fetch_next_result` has been called "
                           "before `start_training`. Please call "
                           "`start_training` before "
                           "`fetch_next_result`.")

    try:
        # Get the session for this worker.
        worker_session = get_session()
    except ValueError:
        # Session is not initialized yet.
        raise TrainBackendError(not_started_message)

    try:
        return worker_session.get_next()
    except RuntimeError:
        # Training thread has not been started yet.
        raise TrainBackendError(not_started_message)
def test_get_fail(session):
    """Check that ``get_session`` raises ``ValueError`` after shutdown.

    The ``session`` argument is not used directly; the test only shuts the
    current session down and then tries to fetch it.
    """
    shutdown_session()

    with pytest.raises(ValueError):
        get_session()
def train_async():
    """Fetch the current session via ``get_session()`` and start it."""
    get_session().start()
def test_shutdown(session):
    """Check that ``get_session`` returns a falsy value after shutdown.

    The ``session`` argument is not used directly; the test only shuts the
    current session down and then inspects what ``get_session`` returns.
    """
    shutdown_session()

    remaining = get_session()
    assert not remaining