def testDurableTrainable(self):
    class TestTrain(DurableTrainable):
        def setup(self, config):
            self.state = {"hi": 1, "iter": 0}

        def step(self):
            self.state["iter"] += 1
            return {"timesteps_this_iter": 1, "done": True}

        def save_checkpoint(self, path):
            return self.state

        def load_checkpoint(self, state):
            self.state = state

    sync_client = mock_storage_client()
    # Patch the cloud sync client so checkpoints go to a local mock dir
    # instead of real cloud storage.
    mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
    with patch(mock_get_client) as mock_get_cloud_sync_client:
        mock_get_cloud_sync_client.return_value = sync_client
        test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
        checkpoint_path = test_trainable.save()
        test_trainable.train()
        # Mutate state, then restore: the checkpointed value should win.
        test_trainable.state["hi"] = 2
        test_trainable.restore(checkpoint_path)
        self.assertEqual(test_trainable.state["hi"], 1)
    self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
def _testDurableTrainable(self, trainable, function=False, cleanup=True):
    sync_client = mock_storage_client()
    mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
    with patch(mock_get_client) as mock_get_cloud_sync_client:
        mock_get_cloud_sync_client.return_value = sync_client
        test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
        result = test_trainable.train()
        self.assertEqual(result["metric"], 1)
        # Checkpoint after the first step, then keep training.
        checkpoint_path = test_trainable.save()
        result = test_trainable.train()
        self.assertEqual(result["metric"], 2)
        result = test_trainable.train()
        self.assertEqual(result["metric"], 3)
        result = test_trainable.train()
        self.assertEqual(result["metric"], 4)
        if not function:
            test_trainable.state["hi"] = 2
            test_trainable.restore(checkpoint_path)
            self.assertEqual(test_trainable.state["hi"], 1)
        else:
            # Cannot re-use function trainable, create new (see the sketch
            # of such a trainable below).
            tune.session.shutdown()
            test_trainable = trainable(
                remote_checkpoint_dir=MOCK_REMOTE_DIR)
            test_trainable.restore(checkpoint_path)
            # After restoring the step-1 checkpoint, training resumes at 2.
            result = test_trainable.train()
            self.assertEqual(result["metric"], 2)
    if cleanup:
        self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR)
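# For reference, a minimal sketch of the kind of function trainable the
# helper above expects in the `function=True` branch: it reports a "metric"
# that counts iterations and checkpoints its state each step. This is an
# assumption, not the test file's actual fixture. It uses the legacy
# function API (tune.checkpoint_dir / tune.report) consistent with the
# tune.session.shutdown() call above, and would still need to be wrapped
# into a Trainable class (e.g. via wrap_function) before the helper could
# instantiate it with remote_checkpoint_dir.
import json
import os

from ray import tune


def durable_train_fn(config, checkpoint_dir=None):
    # Hypothetical fixture: resume the iteration count from a checkpoint
    # if one is provided, otherwise start from zero.
    state = {"metric": 0}
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "state.json")) as f:
            state = json.load(f)
    for _ in range(4):
        state["metric"] += 1
        # Write a checkpoint every step so test_trainable.save() always
        # has something to persist to the (mocked) remote directory.
        with tune.checkpoint_dir(step=state["metric"]) as d:
            with open(os.path.join(d, "state.json"), "w") as f:
                json.dump(state, f)
        tune.report(metric=state["metric"])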
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
    """Tests that run_experiment restoring works on cluster shutdown."""
    cluster = start_connected_cluster
    dirpath = str(tmpdir)

    use_default_sync = trainable_id == "__fake"
    from ray.tune.result import DEFAULT_RESULTS_DIR
    local_dir = DEFAULT_RESULTS_DIR
    upload_dir = None if use_default_sync else MOCK_REMOTE_DIR

    base_dict = dict(
        run=trainable_id,
        stop=dict(training_iteration=3),
        local_dir=local_dir,
        upload_dir=upload_dir,
        sync_to_driver=use_default_sync,
    )

    # Four variants: plain, checkpointed, erroring, and erroring with
    # checkpoints, so resume is exercised across all trial outcomes.
    exp1_args = base_dict
    exp2_args = dict(base_dict.items(), local_dir=dirpath, checkpoint_freq=1)
    exp3_args = dict(base_dict.items(), config=dict(mock_error=True))
    exp4_args = dict(
        base_dict.items(), config=dict(mock_error=True), checkpoint_freq=1)

    all_experiments = {
        "exp1": exp1_args,
        "exp2": exp2_args,
        "exp3": exp3_args,
        "exp4": exp4_args,
    }

    mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
    with patch(mock_get_client) as mock_get_cloud_syncer:
        mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client())
        mock_get_cloud_syncer.return_value = mock_syncer

        tune.run_experiments(all_experiments, raise_on_failed_trial=False)

        # Tear the whole cluster down, bring up a fresh one, and resume.
        ray.shutdown()
        cluster.shutdown()
        cluster = _start_new_cluster()

        trials = tune.run_experiments(
            all_experiments, resume=True, raise_on_failed_trial=False)

    assert len(trials) == 4
    assert all(t.status in [Trial.TERMINATED, Trial.ERROR] for t in trials)
    ray.shutdown()
    cluster.shutdown()
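# _start_new_cluster is assumed to be a module-level helper along these
# lines: it boots a fresh single-node cluster via ray.cluster_utils.Cluster
# and connects the driver to it. The resource settings here are guesses;
# only the overall shape matters for the resume test above.
from ray.cluster_utils import Cluster


def _start_new_cluster():
    cluster = Cluster(
        initialize_head=True,
        connect=True,  # also calls ray.init against the new head node
        head_node_args={"num_cpus": 1},
    )
    return cluster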
def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
    client = mock_storage_client()
    return MockNodeSyncer(local_dir, remote_dir, client)
def _create_trial_syncer(self, trial: "Trial"):
    client = mock_storage_client()
    return MockNodeSyncer(trial.logdir, trial.logdir, client)
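# Both syncer factories above rely on mock_storage_client(), which is
# assumed to behave roughly like this sketch: a command-based SyncClient
# whose "upload" and "download" are plain local copies, so the
# durable-storage tests never touch real cloud storage. get_sync_client
# and its {source}/{target} command templates come from
# ray.tune.sync_client; the template strings themselves are illustrative.
from ray.tune.sync_client import get_sync_client


def _mock_storage_client_sketch():
    sync_template = "mkdir -p {target} && rsync -a {source}/ {target}/"
    delete_template = "rm -rf {target}"
    return get_sync_client(sync_template, delete_template)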