def save_checkpoint(**kwargs) -> None:
    """Checkpoints all keyword arguments to Train as restorable state.

    .. code-block:: python

        import time
        from ray import train

        def train_func():
            for iter in range(100):
                time.sleep(1)
                train.save_checkpoint(epoch=iter)

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    Args:
        **kwargs: Any key value pair to be checkpointed by Train.
    """
    session = get_session()
    if session is None:
        # No active Train session: warn and make the call a no-op
        # rather than raising.
        _warn_session_misuse(save_checkpoint.__name__)
        return
    session.checkpoint(**kwargs)
def session():
    """Yields a fresh single-worker session and tears it down afterwards."""

    def train_func():
        # Trivial training function; the tests drive the session directly.
        return 1

    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
    yield get_session()
    shutdown_session()
def report(**kwargs) -> None:
    """Reports all keyword arguments to Train as intermediate results.

    .. code-block:: python

        import time
        from ray import train

        def train_func():
            for iter in range(100):
                time.sleep(1)
                train.report(hello="world")

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()

    Args:
        **kwargs: Any key value pair to be reported by Train.
            If callbacks are provided, they are executed on these
            intermediate results.
    """
    session = get_session()
    if session is None:
        # No active Train session: warn and make the call a no-op
        # rather than raising.
        _warn_session_misuse(report.__name__)
        return
    session._report_legacy(**kwargs)
def test_encode_data():
    """Checks that ``encode_data_fn`` is applied to checkpoints and reports."""

    def train_func():
        save_checkpoint(epoch=0)
        report(epoch=1)

    def encode_checkpoint(checkpoint):
        checkpoint["encoded"] = True
        return checkpoint

    def validate_encoded(result_type: TrainingResultType):
        result = session.get_next()
        assert result.type is result_type
        assert result.data["encoded"] is True

    init_session(
        training_func=train_func,
        world_rank=0,
        local_rank=0,
        world_size=1,
        encode_data_fn=encode_checkpoint,
    )
    session = get_session()
    session.start()

    # Validate checkpoint is encoded.
    validate_encoded(TrainingResultType.CHECKPOINT)
    # Validate report is encoded.
    validate_encoded(TrainingResultType.REPORT)

    session.finish()
    shutdown_session()
def load_checkpoint() -> Optional[Dict]:
    """Loads checkpoint data onto the worker.

    .. code-block:: python

        from ray import train

        def train_func():
            checkpoint = train.load_checkpoint()
            for iter in range(checkpoint["epoch"], 5):
                print(iter)

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func, checkpoint={"epoch": 3})
        # 3
        # 4
        trainer.shutdown()

    Returns:
        The most recently saved checkpoint if ``train.save_checkpoint()``
        has been called. Otherwise, the checkpoint that the session was
        originally initialized with. ``None`` if neither exist.
    """
    session = get_session()
    if session is None:
        # No active Train session: warn and return None rather than raising.
        _warn_session_misuse(load_checkpoint.__name__)
        return
    return session.loaded_checkpoint
def get_dataset_shard(
        dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
    """Returns the Ray Dataset or DatasetPipeline shard for this worker.

    You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
    it to the appropriate framework-specific Dataset.

    .. code-block:: python

        import ray
        from ray import train

        def train_func():
            model = Net()
            for iter in range(100):
                data_shard = train.get_dataset_shard().to_torch()
                model.train(data_shard)
            return model

        dataset = ray.data.read_csv("train.csv")
        dataset.filter(...).repeat().random_shuffle()

        trainer = Trainer(backend="torch")
        trainer.start()
        # Trainer will automatically handle sharding.
        train_model = trainer.run(train_func, dataset=dataset)
        trainer.shutdown()

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``,
            then specifies which dataset shard to return.

    Returns:
        The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.
    """
    session = get_session()
    if session is None:
        # Called outside of a Train session; warn and return None.
        _warn_session_misuse(get_dataset_shard.__name__)
        return

    shard = session.dataset_shard
    if shard is None:
        warnings.warn("No dataset passed in. Returning None. Make sure to "
                      "pass in a Ray Dataset to Trainer.run to use this "
                      "function.")
        return shard

    if isinstance(shard, dict):
        # Multiple datasets were provided, so a name is required to pick one.
        if not dataset_name:
            raise RuntimeError(
                "Multiple datasets were passed into ``Trainer``, "
                "but no ``dataset_name`` is passed into "
                "``get_dataset_shard``. Please specify which "
                "dataset shard to retrieve.")
        return shard.get(dataset_name)

    return shard
def _get_session(method_name: str):
    """Returns the session for this worker, raising if it is not started.

    Args:
        method_name: Name of the public API that was called; used to build
            a helpful error message.

    Raises:
        TrainBackendError: If no session has been initialized yet.
    """
    session = get_session()
    if not session:
        # Session is not initialized yet.
        raise TrainBackendError(f"`{method_name}` has been called "
                                "before `start_training`. Please call "
                                "`start_training` before "
                                f"`{method_name}`.")
    return session
def test_checkpoint():
    """Checkpoint data is reported on rank 0 and empty on nonzero ranks."""

    def train_func():
        for i in range(2):
            save_checkpoint(epoch=i)

    def validate_zero(expected):
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data["epoch"] == expected

    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    session.start()
    validate_zero(0)
    validate_zero(1)
    session.finish()
    shutdown_session()

    def validate_nonzero():
        # Non-rank-0 workers produce checkpoint results with empty data.
        result = session.get_next()
        assert result is not None
        assert result.type == TrainingResultType.CHECKPOINT
        assert result.data == {}

    init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1)
    session = get_session()
    session.start()
    validate_nonzero()
    validate_nonzero()
    session.finish()
    shutdown_session()
def test_report():
    """Reported values come back through ``session.get_next`` in order."""

    def train_func():
        for step in range(2):
            report(loss=step)

    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    session.start()
    assert session.get_next().data["loss"] == 0
    assert session.get_next().data["loss"] == 1
    shutdown_session()
def test_locking():
    """Tests that report pauses training until fetch_next or finish."""

    def train_1():
        # Raise KeyboardInterrupt in the main thread from the training thread.
        import _thread
        _thread.interrupt_main()

    init_session(training_func=train_1, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    with pytest.raises(KeyboardInterrupt):
        session.start()
    shutdown_session()

    def train_2():
        for step in range(2):
            report(loss=step)
        # Would interrupt the main thread if training were allowed to run on.
        train_1()

    init_session(training_func=train_2, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    session.start()
    time.sleep(3)

    session.pause_reporting()
    # Releases session.continue_lock to resume the training thread.
    session.get_next()

    with pytest.raises(KeyboardInterrupt):
        session.finish()
    shutdown_session()
def test_report_fail():
    """A bad ``report`` call surfaces as an exception from ``finish``."""

    def train_func():
        for step in range(2):
            # Intentionally wrong: report only accepts keyword arguments.
            report(step)
        return 1

    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    session.start()
    assert session.get_next() is None
    with pytest.raises(TypeError):
        session.finish()
    shutdown_session()
def test_load_checkpoint_after_save():
    """``load_checkpoint`` returns the most recently saved checkpoint."""

    def train_func():
        for epoch in range(2):
            save_checkpoint(epoch=epoch)
            checkpoint = load_checkpoint()
            assert checkpoint["epoch"] == epoch

    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
    session = get_session()
    session.start()
    # Drain one result per checkpoint so the training thread can proceed.
    for _ in range(2):
        session.get_next()
    session.finish()
    shutdown_session()
def world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. code-block:: python

        import time
        from ray import train

        def train_func():
            assert train.world_size() == 4

        trainer = Trainer(backend="torch", num_workers=4)
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()
    """
    session = get_session()
    # Outside of a Train session, behave as a single-worker run.
    return 1 if session is None else session.world_size
def local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. code-block:: python

        import time
        from ray import train

        def train_func():
            if torch.cuda.is_available():
                torch.cuda.set_device(train.local_rank())
            ...

        trainer = Trainer(backend="torch", use_gpu=True)
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()
    """
    session = get_session()
    # Outside of a Train session, behave as rank 0.
    return 0 if session is None else session.local_rank
def world_rank() -> int:
    """Get the world rank of this worker.

    .. code-block:: python

        import time
        from ray import train

        def train_func():
            for iter in range(100):
                time.sleep(1)
                if train.world_rank() == 0:
                    print("Worker 0")

        trainer = Trainer(backend="torch")
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()
    """
    session = get_session()
    # Outside of a Train session, behave as rank 0.
    return 0 if session is None else session.world_rank
def train_async():
    """Kicks off the current session's training thread."""
    get_session().start()
def test_shutdown(session):
    """After shutdown, no session should be retrievable."""
    shutdown_session()
    assert not get_session()