Example #1
0
    def testDurableTrainable(self):
        """Save/restore round-trip of a class-based DurableTrainable.

        The cloud sync client lookup is patched to return a mock storage
        client, with ``MOCK_REMOTE_DIR`` as the remote checkpoint directory.
        The test saves a checkpoint, trains one step, mutates state, restores
        the checkpoint and checks that the saved state comes back.
        """
        # Register cleanup before doing any work so the mock remote directory
        # is removed even if a call or assertion below fails. ignore_errors
        # keeps cleanup quiet when the directory was never created.
        self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR, ignore_errors=True)

        class TestTrain(DurableTrainable):
            def setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def step(self):
                self.state["iter"] += 1
                return {"timesteps_this_iter": 1, "done": True}

            def save_checkpoint(self, path):
                # The whole state dict is the checkpoint payload.
                return self.state

            def load_checkpoint(self, state):
                self.state = state

        sync_client = mock_storage_client()
        mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
        with patch(mock_get_client) as mock_get_cloud_sync_client:
            mock_get_cloud_sync_client.return_value = sync_client
            test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
            checkpoint_path = test_trainable.save()
            test_trainable.train()
            # Mutate state after the checkpoint so the restore is observable.
            test_trainable.state["hi"] = 2
            test_trainable.restore(checkpoint_path)
            self.assertEqual(test_trainable.state["hi"], 1)
Example #2
0
    def _testDurableTrainable(self, trainable, function=False, cleanup=True):
        """Shared save/train/restore checks for durable trainables.

        Args:
            trainable: Trainable class, constructed here with
                ``remote_checkpoint_dir=MOCK_REMOTE_DIR``. Its reported
                ``metric`` is expected to be 1, 2, 3, 4 on successive
                ``train()`` calls.
            function: True for a function-based trainable. Such a trainable
                is not restored in place; the tune session is shut down and a
                fresh instance is created and restored instead. When False,
                the trainable must expose a ``state`` dict whose ``"hi"`` entry
                starts at 1.
            cleanup: Whether to remove ``MOCK_REMOTE_DIR`` when the test ends.
        """
        # Register cleanup up front so the directory is removed even if a
        # check below fails. ignore_errors avoids a cleanup error when the
        # directory is missing (e.g. already removed by an earlier call).
        if cleanup:
            self.addCleanup(shutil.rmtree, MOCK_REMOTE_DIR, ignore_errors=True)

        sync_client = mock_storage_client()
        mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
        with patch(mock_get_client) as mock_get_cloud_sync_client:
            mock_get_cloud_sync_client.return_value = sync_client
            test_trainable = trainable(remote_checkpoint_dir=MOCK_REMOTE_DIR)
            result = test_trainable.train()
            self.assertEqual(result["metric"], 1)
            # Checkpoint taken after exactly one iteration.
            checkpoint_path = test_trainable.save()
            for expected_metric in (2, 3, 4):
                result = test_trainable.train()
                self.assertEqual(result["metric"], expected_metric)

            if not function:
                # Mutate state so that the restore is observable.
                test_trainable.state["hi"] = 2
                test_trainable.restore(checkpoint_path)
                self.assertEqual(test_trainable.state["hi"], 1)
            else:
                # Cannot re-use function trainable, create new
                tune.session.shutdown()
                test_trainable = trainable(
                    remote_checkpoint_dir=MOCK_REMOTE_DIR)
                test_trainable.restore(checkpoint_path)

            # Training resumes from the one-iteration checkpoint.
            result = test_trainable.train()
            self.assertEqual(result["metric"], 2)
Example #3
0
def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
    """Tests that run_experiments restoring works after a cluster shutdown.

    Four experiments are run once, each a variant of one base spec:
    plain, checkpointing into ``tmpdir``, erroring, and erroring with
    checkpoints. The Ray cluster is then torn down and replaced, and the same
    experiments are re-run with ``resume=True``. All four resumed trials must
    end up TERMINATED or ERROR.
    """
    cluster = start_connected_cluster
    dirpath = str(tmpdir)

    # "__fake" uses driver syncing with no upload dir; any other trainable
    # uses MOCK_REMOTE_DIR as its upload dir with driver syncing disabled.
    use_default_sync = trainable_id == "__fake"
    from ray.tune.result import DEFAULT_RESULTS_DIR
    local_dir = DEFAULT_RESULTS_DIR
    upload_dir = None if use_default_sync else MOCK_REMOTE_DIR

    base_dict = dict(
        run=trainable_id,
        stop=dict(training_iteration=3),
        local_dir=local_dir,
        upload_dir=upload_dir,
        sync_to_driver=use_default_sync,
    )

    # Each variant is a shallow copy of base_dict with a few keys overridden.
    all_experiments = {
        "exp1": base_dict,
        "exp2": dict(base_dict, local_dir=dirpath, checkpoint_freq=1),
        "exp3": dict(base_dict, config=dict(mock_error=True)),
        "exp4": dict(base_dict,
                     config=dict(mock_error=True),
                     checkpoint_freq=1),
    }

    mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
    with patch(mock_get_client) as mock_get_cloud_syncer:
        mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client())
        mock_get_cloud_syncer.return_value = mock_syncer

        tune.run_experiments(all_experiments, raise_on_failed_trial=False)

        # Simulate a full cluster outage and bring up a replacement cluster.
        ray.shutdown()
        cluster.shutdown()
        cluster = _start_new_cluster()

        try:
            trials = tune.run_experiments(all_experiments,
                                          resume=True,
                                          raise_on_failed_trial=False)
            assert len(trials) == 4
            assert all(t.status in [Trial.TERMINATED, Trial.ERROR]
                       for t in trials)
        finally:
            # The replacement cluster was started by this test, not by the
            # fixture, so always shut it down here, even when the resumed
            # run or an assertion fails.
            ray.shutdown()
            cluster.shutdown()
Example #4
0
 def mock_get_syncer_fn(local_dir, remote_dir, sync_function):
     """Build a MockNodeSyncer between ``local_dir`` and ``remote_dir``.

     A fresh mock storage client backs every syncer returned. The
     ``sync_function`` argument is accepted but not used (presumably to match
     the signature the caller expects — confirm).
     """
     return MockNodeSyncer(local_dir, remote_dir, mock_storage_client())
Example #5
0
 def _create_trial_syncer(self, trial: "Trial"):
     """Return a MockNodeSyncer for ``trial`` backed by a fresh mock client.

     The trial's ``logdir`` is used as both the local and the remote
     directory.
     """
     storage = mock_storage_client()
     logdir = trial.logdir
     return MockNodeSyncer(logdir, logdir, storage)