import os

import pytest

import ray
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import FailureConfig, RunConfig, ScalingConfig
from ray.air.constants import PREPROCESSOR_KEY
from ray.data.preprocessor import Preprocessor
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune import TuneConfig, Tuner

# NOTE: The imports above assume the Ray 2.x AIR-era API; adjust the paths for
# other versions. The `ray_start_4_cpus` fixture is assumed to be provided by
# the suite's conftest. A no-op `TestConfig` backend and the
# `_is_trainable_name_overriden` helper are sketched further below as
# stand-ins for the suite's real definitions.
scale_config = ScalingConfig(num_workers=2)


def test_fit_train(ray_start_4_cpus):
    def train_func():
        session.report({"loss": 1})

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func, scaling_config=scale_config
    )
    assert trainer.fit().metrics["loss"] == 1
def test_datasets(ray_start_4_cpus):
    num_train_data = 10
    num_val_data = 6

    train_dataset = ray.data.range(num_train_data)
    val_dataset = ray.data.range(num_val_data)

    def get_dataset():
        # The train dataset should be sharded evenly across workers.
        train_dataset = session.get_dataset_shard("train")
        assert train_dataset.count() == num_train_data / scale_config.num_workers

        # All other datasets should not be sharded.
        val_dataset = session.get_dataset_shard("val")
        assert val_dataset.count() == num_val_data

    trainer = DataParallelTrainer(
        train_loop_per_worker=get_dataset,
        scaling_config=scale_config,
        datasets={"train": train_dataset, "val": val_dataset},
    )
    trainer.fit()
def test_run(ray_start_4_cpus):
    """Tests that Train can be run without any specific backends."""
    num_workers = 2
    key = "value"
    value = 1
    config = TestConfig()

    def train_func():
        checkpoint = session.get_checkpoint()
        session.report(metrics=checkpoint.to_dict(), checkpoint=checkpoint)
        return checkpoint.to_dict()[key]

    checkpoint = Checkpoint.from_dict(
        {
            # This would be set during checkpoint saving.
            "_current_checkpoint_id": 1,
            key: value,
        }
    )

    trainer = DataParallelTrainer(
        train_func,
        backend_config=config,
        resume_from_checkpoint=checkpoint,
        scaling_config=ScalingConfig(num_workers=num_workers),
    )
    results = trainer.fit()
    assert results.checkpoint.to_dict()[key] == checkpoint.to_dict()[key]
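# Several tests in this file pass ``backend_config=TestConfig()``. The suite's
# real helper lives elsewhere; the minimal no-op backend below is a sketch
# standing in for it, assuming the Ray 2.x ``ray.train.backend`` interfaces.
from ray.train.backend import Backend, BackendConfig


class TestBackend(Backend):
    def on_start(self, worker_group, backend_config):
        # No framework-specific setup is needed for these tests.
        pass

    def on_shutdown(self, worker_group, backend_config):
        pass


class TestConfig(BackendConfig):
    @property
    def backend_cls(self):
        return TestBackend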
def test_fit_train_config(ray_start_4_cpus):
    def train_func(config):
        session.report({"loss": config["x"]})

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scale_config,
        train_loop_config={"x": 100},
    )
    assert trainer.fit().metrics["loss"] == 100
def test_checkpoint(ray_start_4_cpus):
    def train_func():
        for i in range(3):
            session.report(
                {"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i})
            )

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func, scaling_config=scale_config
    )
    result = trainer.fit()
    assert result.checkpoint.to_dict()["model"] == 2
def test_scaling_config(ray_start_4_cpus):
    def train_func():
        # Each of the two workers reserves one CPU and the Trainable actor
        # holds another, leaving one of the four cluster CPUs free.
        assert ray.available_resources()["CPU"] == 1
        session.report({"loss": 1})

    # Before training starts, all four CPUs are available.
    assert ray.available_resources()["CPU"] == 4
    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func,
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()
def test_retry(ray_start_4_cpus):
    def train_func():
        ckpt = session.get_checkpoint()
        restored = bool(ckpt)  # Does a previous checkpoint exist?
        itr = 0
        if ckpt:
            ckpt = ckpt.to_dict()
            itr = ckpt["iter"] + 1

        for i in range(itr, 4):
            if i == 2 and not restored:
                raise Exception("try to fail me")
            session.report(
                dict(test=i, training_iteration=i),
                checkpoint=Checkpoint.from_dict(dict(iter=i)),
            )

    trainer = DataParallelTrainer(
        train_func,
        backend_config=TestConfig(),
        scaling_config=ScalingConfig(num_workers=1),
    )
    tuner = Tuner(
        trainer, run_config=RunConfig(failure_config=FailureConfig(max_failures=3))
    )

    analysis = tuner.fit()._experiment_analysis
    checkpoint_path = analysis.trials[0].checkpoint.dir_or_data
    checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict()
    assert checkpoint["iter"] == 3

    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 4
def test_reuse_checkpoint(ray_start_4_cpus):
    def train_func(config):
        itr = 0
        ckpt = session.get_checkpoint()
        if ckpt is not None:
            ckpt = ckpt.to_dict()
            itr = ckpt["iter"] + 1

        for i in range(itr, config["max_iter"]):
            session.report(
                dict(test=i, training_iteration=i),
                checkpoint=Checkpoint.from_dict(dict(iter=i)),
            )

    trainer = DataParallelTrainer(
        train_func,
        backend_config=TestConfig(),
        scaling_config=ScalingConfig(num_workers=1),
    )
    tuner = Tuner(
        trainer,
        param_space={"train_loop_config": {"max_iter": 5}},
    )
    [trial] = tuner.fit()._experiment_analysis.trials
    checkpoint_path = trial.checkpoint.dir_or_data
    checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict()
    assert checkpoint["iter"] == 4

    # Reuse the checkpoint in a fresh run: resuming from iter=4 means only
    # iterations 5-9 are executed, i.e. five new reports.
    trainer = DataParallelTrainer(
        train_func,
        backend_config=TestConfig(),
        scaling_config=ScalingConfig(num_workers=1),
        resume_from_checkpoint=Checkpoint.from_directory(checkpoint_path),
    )
    tuner = Tuner(
        trainer,
        param_space={"train_loop_config": {"max_iter": 10}},
    )
    analysis = tuner.fit()._experiment_analysis
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 5
def test_bad_return_in_train_loop(ray_start_4_cpus):
    """Checks that returns from the train loop are discarded."""

    # Simulates what happens with, e.g., Torch models, which cannot always be
    # pickled back to the driver.
    class FailOnUnpickle:
        def __reduce__(self):
            raise RuntimeError("Failing")

    def train_loop(config):
        session.report({"loss": 1})
        return FailOnUnpickle()

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_loop, scaling_config=scale_config
    )

    # No exception should be raised here.
    trainer.fit()
def test_preprocessor_in_checkpoint(ray_start_4_cpus):
    class DummyPreprocessor(Preprocessor):
        def __init__(self):
            super().__init__()
            self.is_same = True

    def train_func():
        for i in range(3):
            session.report(
                {"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i})
            )

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scale_config,
        preprocessor=DummyPreprocessor(),
    )
    result = trainer.fit()
    assert result.checkpoint.to_dict()["model"] == 2
    assert result.checkpoint.to_dict()[PREPROCESSOR_KEY].is_same
def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
    def train_func():
        checkpoint = session.get_checkpoint()
        if checkpoint:
            epoch = checkpoint.to_dict()["epoch"]
        else:
            epoch = 0
        for i in range(epoch, epoch + 2):
            session.report(
                {"epoch": i}, checkpoint=Checkpoint.from_dict({"epoch": i})
            )

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func, scaling_config=scale_config
    )
    result = trainer.fit()
    assert result.checkpoint.to_dict()["epoch"] == 1

    # Move the checkpoint to a different directory.
    checkpoint_dict = result.checkpoint.to_dict()
    checkpoint = Checkpoint.from_dict(checkpoint_dict)
    checkpoint_path = checkpoint.to_directory(tmpdir)
    resume_from = Checkpoint.from_directory(checkpoint_path)

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scale_config,
        resume_from_checkpoint=resume_from,
    )
    result = trainer.fit()
    assert result.checkpoint.to_dict()["epoch"] == 2
def test_tune_error(ray_start_4_cpus):
    def train_func(config):
        raise RuntimeError("Error in training function!")

    trainer = DataParallelTrainer(
        train_func,
        backend_config=TestConfig(),
        scaling_config=ScalingConfig(num_workers=1),
    )
    tuner = Tuner(trainer)

    result_grid = tuner.fit()
    with pytest.raises(RuntimeError):
        raise result_grid[0].error
def test_tune(ray_start_4_cpus):
    def train_func(config):
        session.report({"loss": config["x"]})

    trainer = DataParallelTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"x": 100},
        scaling_config=scale_config,
    )
    tuner = Tuner(
        trainer,
        param_space={"train_loop_config": {"x": tune.choice([200, 300])}},
        tune_config=TuneConfig(num_samples=2),
    )
    result_grid = tuner.fit()
    assert result_grid[0].metrics["loss"] in [200, 300]

    # Make sure the original Trainer is not affected.
    assert trainer._train_loop_config["x"] == 100
def test_tune_checkpoint(ray_start_4_cpus):
    def train_func():
        for i in range(9):
            session.report(dict(test=i))
        # Only the final report attaches a checkpoint.
        session.report(
            dict(test=i + 1), checkpoint=Checkpoint.from_dict(dict(hello="world"))
        )

    trainer = DataParallelTrainer(
        train_func,
        backend_config=TestConfig(),
        scaling_config=ScalingConfig(num_workers=1),
    )
    tuner = Tuner(
        trainer,
        param_space={"train_loop_config": {"max_iter": 5}},
    )

    [trial] = tuner.fit()._experiment_analysis.trials
    checkpoint_path = trial.checkpoint.dir_or_data
    assert os.path.exists(checkpoint_path)
    checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict()
    assert checkpoint["hello"] == "world"
def test_trainable_name_is_overriden_data_parallel_trainer(ray_start_4_cpus):
    trainer = DataParallelTrainer(
        lambda x: x, scaling_config=ScalingConfig(num_workers=1)
    )

    _is_trainable_name_overriden(trainer)
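# ``_is_trainable_name_overriden`` normally lives elsewhere in the test suite.
# This is a minimal sketch standing in for it (an assumption, not the suite's
# exact helper): it checks, via the public ``BaseTrainer.as_trainable()`` API,
# that the wrapped Tune trainable reuses the Trainer's class name.
def _is_trainable_name_overriden(trainer):
    trainable_cls = trainer.as_trainable()
    assert trainable_cls.__name__ == trainer.__class__.__name__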
def test_invalid_train_loop(ray_start_4_cpus):
    # The train loop may take at most one argument (the config).
    def train_loop(config, extra_arg):
        pass

    with pytest.raises(ValueError):
        DataParallelTrainer(train_loop_per_worker=train_loop)
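# Conventional entry point so the module can also be run directly with
# ``python``; this mirrors the pytest-main pattern Ray's test files commonly
# use.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", "-x", __file__]))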