Example 1
    def _testPauseAndStart(self, result_buffer_length):
        """Tests that unpausing works for trials being processed."""
        os.environ["TUNE_RESULT_BUFFER_LENGTH"] = f"{result_buffer_length}"
        os.environ["TUNE_RESULT_BUFFER_MIN_TIME_S"] = "1"

        # Need a new trial executor so the ENV vars are parsed again
        self.trial_executor = RayTrialExecutor()

        base = max(result_buffer_length, 1)

        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)
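The assertions above follow from result buffering: with TUNE_RESULT_BUFFER_LENGTH set to N, a single TRAINING_RESULT event is assumed to carry max(N, 1) buffered results, so the iteration counter advances in steps of that size. A minimal sketch of that arithmetic (not part of the test suite):

# Sketch only: mirrors the `base` / `base * 2` assertions in _testPauseAndStart,
# assuming one TRAINING_RESULT event delivers max(N, 1) buffered results.
def expected_iteration(result_buffer_length: int, num_result_events: int) -> int:
    base = max(result_buffer_length, 1)
    return base * num_result_events

assert expected_iteration(0, 2) == 2    # no buffering: one result per event
assert expected_iteration(1, 2) == 2    # trivial buffer behaves the same
assert expected_iteration(8, 2) == 16   # buffered: eight results per event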
Example 2
    def test_overriding_default_resource_request(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 2
        # 3 trials: only 2 can run at a time (num_cpus=6; 3 CPUs needed per trial).
        config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
        config["env"] = "CartPole-v0"
        config["framework"] = "tf"

        # Create an Algorithm with an overridden default_resource_request
        # method that returns a PlacementGroupFactory.

        class MyAlgo(PG):
            @classmethod
            def default_resource_request(cls, config):
                head_bundle = {"CPU": 1, "GPU": 0}
                child_bundle = {"CPU": 1}
                return PlacementGroupFactory(
                    [head_bundle, child_bundle, child_bundle],
                    strategy=config["placement_strategy"],
                )

        tune.register_trainable("my_trainable", MyAlgo)

        global trial_executor
        trial_executor = RayTrialExecutor(reuse_actors=False)

        tune.run(
            "my_trainable",
            config=config,
            stop={"training_iteration": 2},
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            verbose=2,
        )
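The overridden default_resource_request above asks for three 1-CPU bundles per trial (one head, two children), which is why the comment notes that only two trials fit on num_cpus=6. A standalone sketch of the same bundle shape; the import path and strategy value are assumptions rather than values taken from the test:

# Sketch only: the per-trial bundle shape built outside the Algorithm class.
# Import path and strategy are assumptions, not taken from the test above.
from ray.tune.utils.placement_groups import PlacementGroupFactory

pgf = PlacementGroupFactory(
    [{"CPU": 1, "GPU": 0}, {"CPU": 1}, {"CPU": 1}],  # head bundle + 2 child bundles
    strategy="PACK",
)
# 3 CPUs per trial -> at most two such trials fit into a 6-CPU cluster.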
Example 3
    def testCheckpointAtEndNotBuffered(self):
        os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "7"
        os.environ["TUNE_RESULT_BUFFER_MIN_TIME_S"] = "0.5"

        def num_checkpoints(trial):
            return sum(
                item.startswith("checkpoint_")
                for item in os.listdir(trial.logdir))

        ray.init(num_cpus=2)

        trial = Trial(
            "__fake",
            checkpoint_at_end=True,
            stopping_criterion={"training_iteration": 4},
        )
        observer = TrialResultObserver()
        runner = TrialRunner(
            local_checkpoint_dir=self.tmpdir,
            checkpoint_period=0,
            trial_executor=RayTrialExecutor(result_buffer_length=7),
            callbacks=[observer],
        )
        runner.add_trial(trial)

        while not observer.just_received_a_result():
            runner.step()
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 1)
        self.assertEqual(num_checkpoints(trial), 0)

        while True:
            runner.step()
            if observer.just_received_a_result():
                break
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 2)
        self.assertEqual(num_checkpoints(trial), 0)

        while True:
            runner.step()
            if observer.just_received_a_result():
                break
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
        self.assertEqual(num_checkpoints(trial), 0)

        while True:
            runner.step()
            if observer.just_received_a_result():
                break
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 4)

        while not runner.is_finished():
            runner.step()
        self.assertEqual(num_checkpoints(trial), 1)
Example 4
    def testExperimentTagTruncation(self):
        ray.init(num_cpus=2)

        def train(config, reporter):
            reporter(timesteps_total=1)

        trial_executor = RayTrialExecutor()
        register_trainable("f1", train)

        experiments = {
            "foo": {
                "run": "f1",
                "config": {
                    "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
                    "b" * 50: tune.sample_from(lambda spec: "long" * 40),
                },
            }
        }

        for name, spec in experiments.items():
            trial_generator = BasicVariantGenerator()
            trial_generator.add_configurations({name: spec})
            while not trial_generator.is_finished():
                trial = trial_generator.next_trial()
                if not trial:
                    break
                trial_executor.start_trial(trial)
                self.assertLessEqual(len(os.path.basename(trial.logdir)), 200)
                trial_executor.stop_trial(trial)
Example 5
    def test_default_resource_request(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 2
        config["num_cpus_per_worker"] = 2
        # 3 trials: only 1 can run at a time (num_cpus=6; 5 CPUs needed per trial).
        config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
        config["env"] = "CartPole-v0"
        config["framework"] = "torch"
        config["placement_strategy"] = "SPREAD"

        global trial_executor
        trial_executor = RayTrialExecutor(reuse_actors=False)

        tune.run(
            "PG",
            config=config,
            stop={"training_iteration": 2},
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            verbose=2,
        )
Example 6
    def setUp(self):
        self.trial_executor = RayTrialExecutor()
        ray.init(num_cpus=2, ignore_reinit_error=True)
        _register_all()  # Needed for flaky tests
Example 7
class RayTrialExecutorTest(unittest.TestCase):
    def setUp(self):
        self.trial_executor = RayTrialExecutor()
        ray.init(num_cpus=2, ignore_reinit_error=True)
        _register_all()  # Needed for flaky tests

    def tearDown(self):
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _simulate_starting_trial(self, trial):
        future_result = self.trial_executor.get_next_executor_event(
            live_trials={trial}, next_trial_exists=True
        )
        assert future_result.type == _ExecutorEventType.PG_READY
        self.assertTrue(self.trial_executor.start_trial(trial))
        self.assertEqual(Trial.RUNNING, trial.status)

    def _simulate_getting_result(self, trial):
        while True:
            event = self.trial_executor.get_next_executor_event(
                live_trials={trial}, next_trial_exists=False
            )
            if event.type == _ExecutorEventType.TRAINING_RESULT:
                break
        training_result = event.result[_ExecutorEvent.KEY_FUTURE_RESULT]
        if isinstance(training_result, list):
            for r in training_result:
                trial.update_last_result(r)
        else:
            trial.update_last_result(training_result)

    def _simulate_saving(self, trial):
        checkpoint = self.trial_executor.save(trial, CheckpointStorage.PERSISTENT)
        self.assertEqual(checkpoint, trial.saving_to)
        self.assertEqual(trial.checkpoint.dir_or_data, None)
        event = self.trial_executor.get_next_executor_event(
            live_trials={trial}, next_trial_exists=False
        )
        assert event.type == _ExecutorEventType.SAVING_RESULT
        self.process_trial_save(trial, event.result[_ExecutorEvent.KEY_FUTURE_RESULT])
        self.assertEqual(checkpoint, trial.checkpoint)

    def testStartStop(self):
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self.trial_executor.stop_trial(trial)

    def testAsyncSave(self):
        """Tests that saved checkpoint value not immediately set."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)

        self._simulate_saving(trial)

        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSaveRestore(self):
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)

        self._simulate_saving(trial)

        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseResume(self):
        """Tests that pausing works for trials in flight."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        self._simulate_starting_trial(trial)

        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSavePauseResumeErrorRestore(self):
        """Tests that pause checkpoint does not replace restore checkpoint."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)

        # Save
        self._simulate_saving(trial)

        # Train
        self.trial_executor.continue_training(trial)
        self._simulate_getting_result(trial)

        # Pause
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.assertEqual(trial.checkpoint.storage_mode, CheckpointStorage.MEMORY)

        # Resume
        self._simulate_starting_trial(trial)

        # Error
        trial.set_status(Trial.ERROR)

        # Restore
        self.trial_executor.restore(trial)

        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartFailure(self):
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)

    def testPauseResume2(self):
        """Tests that pausing works for trials being processed."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        self._simulate_starting_trial(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def _testPauseAndStart(self, result_buffer_length):
        """Tests that unpausing works for trials being processed."""
        os.environ["TUNE_RESULT_BUFFER_LENGTH"] = f"{result_buffer_length}"
        os.environ["TUNE_RESULT_BUFFER_MIN_TIME_S"] = "1"

        # Need a new trial executor so the ENV vars are parsed again
        self.trial_executor = RayTrialExecutor()

        base = max(result_buffer_length, 1)

        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseAndStartNoBuffer(self):
        self._testPauseAndStart(0)

    def testPauseAndStartTrivialBuffer(self):
        self._testPauseAndStart(1)

    def testPauseAndStartActualBuffer(self):
        self._testPauseAndStart(8)

    def testNoResetTrial(self):
        """Tests that reset handles NotImplemented properly."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
        self.assertEqual(exists, False)
        self.assertEqual(Trial.RUNNING, trial.status)

    def testResetTrial(self):
        """Tests that reset works as expected."""

        class B(Trainable):
            def step(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        trials = self.generate_trials(
            {
                "run": B,
                "config": {"foo": 0},
            },
            "grid_search",
        )
        trial = trials[0]
        self._simulate_starting_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {"hi": 1}, "modified_mock")
        self.assertEqual(exists, True)
        self.assertEqual(trial.config.get("hi"), 1)
        self.assertEqual(trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, trial.status)

    def testTrialCleanup(self):
        class B(Trainable):
            def step(self):
                print("Step start")
                time.sleep(4)
                print("Step done")
                return dict(my_metric=1, timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

            def cleanup(self):
                print("Cleanup start")
                time.sleep(4)
                print("Cleanup done")

        # First check if the trials terminate gracefully by default
        trials = self.generate_trials(
            {
                "run": B,
                "config": {"foo": 0},
            },
            "grid_search",
        )
        trial = trials[0]
        self._simulate_starting_trial(trial)
        time.sleep(1)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # Remaining step sleep (4 - 1) plus cleanup sleep (4), asserted with margin.
        self.assertGreaterEqual(time.time() - start, 6)

        # Check forceful termination. It should run for much less than the
        # sleep periods in the Trainable
        trials = self.generate_trials(
            {
                "run": B,
                "config": {"foo": 0},
            },
            "grid_search",
        )
        trial = trials[0]
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
        self.trial_executor = RayTrialExecutor()
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
        self._simulate_starting_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        # This should be enough time for `trial._default_result_or_future`
        # to return. Otherwise, PID won't show up in `trial.last_result`,
        # which is asserted down below.
        time.sleep(2)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # less than 1 with some margin.
        self.assertLess(time.time() - start, 2.0)

        # also check if auto-filled metrics were returned
        self.assertIn(PID, trial.last_result)
        self.assertIn(TRIAL_ID, trial.last_result)
        self.assertNotIn("my_metric", trial.last_result)

    @staticmethod
    def generate_trials(spec, name):
        suggester = BasicVariantGenerator()
        suggester.add_configurations({name: spec})
        trials = []
        while not suggester.is_finished():
            trial = suggester.next_trial()
            if trial:
                trials.append(trial)
            else:
                break
        return trials

    def process_trial_save(self, trial, checkpoint_value):
        """Simulates trial runner save."""
        checkpoint = trial.saving_to
        checkpoint.dir_or_data = checkpoint_value
        trial.on_checkpoint(checkpoint)
Example 8
    def setUp(self):
        ray.init(local_mode=True)
        self.trial_executor = RayTrialExecutor()
Example 9
    def testHasResourcesForTrialWithCaching(self):
        pgm = _PlacementGroupManager()
        pgf1 = PlacementGroupFactory([{"CPU": self.head_cpus}])
        pgf2 = PlacementGroupFactory([{"CPU": self.head_cpus - 1}])

        executor = RayTrialExecutor(reuse_actors=True)
        executor._pg_manager = pgm
        executor.set_max_pending_trials(1)

        def train(config):
            yield 1
            yield 2
            yield 3
            yield 4

        register_trainable("resettable", train)

        trial1 = Trial("resettable", placement_group_factory=pgf1)
        trial2 = Trial("resettable", placement_group_factory=pgf1)
        trial3 = Trial("resettable", placement_group_factory=pgf2)

        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert executor.has_resources_for_trial(trial3)

        executor._stage_and_update_status([trial1, trial2, trial3])

        while not pgm.has_ready(trial1):
            time.sleep(1)
            executor._stage_and_update_status([trial1, trial2, trial3])

        # Fill staging
        executor._stage_and_update_status([trial1, trial2, trial3])

        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert not executor.has_resources_for_trial(trial3)

        executor._start_trial(trial1)
        executor._stage_and_update_status([trial1, trial2, trial3])
        executor.pause_trial(trial1)  # Caches the PG and removes a PG from staging

        assert len(pgm._staging_futures) == 0

        # This will re-schedule a placement group
        pgm.reconcile_placement_groups([trial1, trial2])

        assert len(pgm._staging_futures) == 1

        assert not pgm.can_stage()

        # We should still have resources for this trial as it has a cached PG
        assert executor.has_resources_for_trial(trial1)
        assert executor.has_resources_for_trial(trial2)
        assert not executor.has_resources_for_trial(trial3)
Example 10
    def testTrialCleanup(self):
        class B(Trainable):
            def step(self):
                print("Step start")
                time.sleep(4)
                print("Step done")
                return dict(my_metric=1, timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

            def cleanup(self):
                print("Cleanup start")
                time.sleep(4)
                print("Cleanup done")

        # First check if the trials terminate gracefully by default
        trials = self.generate_trials(
            {
                "run": B,
                "config": {"foo": 0},
            },
            "grid_search",
        )
        trial = trials[0]
        self._simulate_starting_trial(trial)
        time.sleep(1)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # Remaining step sleep (4 - 1) plus cleanup sleep (4), asserted with margin.
        self.assertGreaterEqual(time.time() - start, 6)

        # Check forceful termination. It should run for much less than the
        # sleep periods in the Trainable
        trials = self.generate_trials(
            {
                "run": B,
                "config": {"foo": 0},
            },
            "grid_search",
        )
        trial = trials[0]
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "1"
        self.trial_executor = RayTrialExecutor()
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
        self._simulate_starting_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        # This should be enough time for `trial._default_result_or_future`
        # to return. Otherwise, PID won't show up in `trial.last_result`,
        # which is asserted down below.
        time.sleep(2)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # less than 1 with some margin.
        self.assertLess(time.time() - start, 2.0)

        # also check if auto-filled metrics were returned
        self.assertIn(PID, trial.last_result)
        self.assertIn(TRIAL_ID, trial.last_result)
        self.assertNotIn("my_metric", trial.last_result)
Example 11
    def testPlacementGroupRequests(self, reuse_actors=False, scheduled=10):
        """In this test we try to start 10 trials but only have resources
        for 2. Placement groups should still be created and PENDING.

        Eventually they should be scheduled sequentially (i.e. two at a time)."""
        # Since we check per-step placement groups, set the reconciliation
        # interval to 0
        os.environ["TUNE_PLACEMENT_GROUP_RECON_INTERVAL"] = "0"

        def train(config):
            time.sleep(1)
            now = time.time()
            tune.report(end=now - config["start_time"])

        head_bundle = {"CPU": 4, "GPU": 0, "custom": 0}
        child_bundle = {"custom": 1}
        # Manually calculated number of parallel trials
        max_num_parallel = 2

        placement_group_factory = PlacementGroupFactory(
            [head_bundle, child_bundle, child_bundle])

        trial_executor = RayTrialExecutor(reuse_actors=reuse_actors)

        this = self

        class _TestCallback(Callback):
            def on_step_end(self, iteration, trials, **info):
                num_finished = len([
                    t for t in trials
                    if t.status == Trial.TERMINATED or t.status == Trial.ERROR
                ])

                num_staging = sum(
                    len(s)
                    for s in trial_executor._pg_manager._staging.values())
                num_ready = sum(
                    len(s) for s in trial_executor._pg_manager._ready.values())
                num_in_use = len(trial_executor._pg_manager._in_use_pgs)
                num_cached = len(trial_executor._pg_manager._cached_pgs)

                total_num_tracked = num_staging + num_ready + num_in_use + num_cached

                # All trials should be scheduled
                this.assertEqual(
                    scheduled,
                    min(scheduled, len(trials)),
                    msg=f"Num trials iter {iteration}",
                )

                # The following two tests were relaxed for reuse_actors=True
                # so that up to `max_num_parallel` more placement groups can
                # exist than we would expect. This is because caching
                # relies on reconciliation for cleanup to avoid overscheduling
                # of new placement groups.
                num_parallel_reuse = int(reuse_actors) * max_num_parallel

                # The number of PGs should decrease when trials finish
                this.assertGreaterEqual(
                    max(scheduled, len(trials)) - num_finished +
                    num_parallel_reuse,
                    total_num_tracked,
                    msg=f"Num tracked iter {iteration}",
                )

        start = time.time()
        out = tune.run(
            train,
            config={"start_time": start},
            resources_per_trial=placement_group_factory,
            num_samples=10,
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            reuse_actors=reuse_actors,
            verbose=2,
        )

        trial_end_times = sorted(t.last_result["end"] for t in out.trials)
        print("Trial end times:", trial_end_times)
        max_diff = trial_end_times[-1] - trial_end_times[0]

        # Not all trials have been run in parallel
        self.assertGreater(max_diff, 3)

        # Some trials should have run in parallel
        # Todo: Re-enable when using buildkite
        # self.assertLess(max_diff, 10)

        self._assertCleanup(trial_executor)
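The bookkeeping in _TestCallback can be restated as an invariant: each unfinished trial accounts for at most one tracked placement group, with up to max_num_parallel extra groups tolerated when reuse_actors=True, because cached groups are only cleaned up during reconciliation. A hedged restatement as a plain helper (illustrative names, not Tune API):

# Sketch only: the inequality asserted by _TestCallback.on_step_end above,
# expressed as a free-standing function.
def pg_count_within_bound(scheduled, num_trials, num_finished,
                          total_num_tracked, reuse_actors, max_num_parallel=2):
    slack = max_num_parallel if reuse_actors else 0
    return total_num_tracked <= max(scheduled, num_trials) - num_finished + slack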
Example 12
    def testPlacementGroupDistributedTraining(self, reuse_actors=False):
        """Run distributed training using placement groups.

        Each trial requests 4 CPUs and starts 4 remote training workers.
        """

        head_bundle = {"CPU": 1, "GPU": 0, "custom": 0}
        child_bundle = {"CPU": 1}

        placement_group_factory = PlacementGroupFactory(
            [head_bundle, child_bundle, child_bundle, child_bundle])

        @ray.remote
        class TrainingActor:
            def train(self, val):
                time.sleep(1)
                return val

        def train(config):
            base = config["base"]
            actors = [TrainingActor.remote() for _ in range(4)]
            futures = [
                actor.train.remote(base + 2 * i)
                for i, actor in enumerate(actors)
            ]
            results = ray.get(futures)

            end = time.time() - config["start_time"]
            tune.report(avg=np.mean(results), end=end)

        trial_executor = RayTrialExecutor(reuse_actors=reuse_actors)

        start = time.time()
        out = tune.run(
            train,
            config={
                "start_time": start,
                "base": tune.grid_search(list(range(0, 100, 10))),
            },
            resources_per_trial=placement_group_factory,
            num_samples=1,
            trial_executor=trial_executor,
            reuse_actors=reuse_actors,
            verbose=2,
        )

        avgs = sorted(t.last_result["avg"] for t in out.trials)
        self.assertSequenceEqual(avgs, list(range(3, 103, 10)))

        trial_end_times = sorted(t.last_result["end"] for t in out.trials)
        print("Trial end times:", trial_end_times)
        max_diff = trial_end_times[-1] - trial_end_times[0]

        # Not all trials have been run in parallel
        self.assertGreater(max_diff, 3)

        # Some trials should have run in parallel
        # Todo: Re-enable when using buildkite
        # self.assertLess(max_diff, 10)

        self._assertCleanup(trial_executor)
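The assertSequenceEqual above is plain arithmetic on the grid: each trial averages base, base + 2, base + 4 and base + 6, i.e. base + 3, and base sweeps 0, 10, ..., 90. A small self-contained check of that expectation:

# Sketch only: reproduces the expected per-trial averages checked above.
bases = list(range(0, 100, 10))            # grid_search values for "base"
expected_avgs = [b + 3 for b in bases]     # mean(b, b+2, b+4, b+6) == b + 3
assert expected_avgs == list(range(3, 103, 10))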
Example 13
def run(
    run_or_experiment: Union[str, Callable, Type],
    name: Optional[str] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping],
                                                    bool]]] = None,
    time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None,
    config: Optional[Dict[str, Any]] = None,
    resources_per_trial: Union[None, Mapping[str, Union[float, int, Mapping]],
                               PlacementGroupFactory] = None,
    num_samples: int = 1,
    local_dir: Optional[str] = None,
    search_alg: Optional[Union[Searcher, SearchAlgorithm, str]] = None,
    scheduler: Optional[Union[TrialScheduler, str]] = None,
    keep_checkpoints_num: Optional[int] = None,
    checkpoint_score_attr: Optional[str] = None,
    checkpoint_freq: int = 0,
    checkpoint_at_end: bool = False,
    verbose: Union[int, Verbosity] = Verbosity.V3_TRIAL_DETAILS,
    progress_reporter: Optional[ProgressReporter] = None,
    log_to_file: bool = False,
    trial_name_creator: Optional[Callable[[Trial], str]] = None,
    trial_dirname_creator: Optional[Callable[[Trial], str]] = None,
    sync_config: Optional[SyncConfig] = None,
    export_formats: Optional[Sequence] = None,
    max_failures: int = 0,
    fail_fast: bool = False,
    restore: Optional[str] = None,
    server_port: Optional[int] = None,
    resume: Union[bool, str] = False,
    reuse_actors: Optional[bool] = None,
    trial_executor: Optional[RayTrialExecutor] = None,
    raise_on_failed_trial: bool = True,
    callbacks: Optional[Sequence[Callback]] = None,
    max_concurrent_trials: Optional[int] = None,
    # == internal only ==
    _experiment_checkpoint_dir: Optional[str] = None,
    _remote: Optional[bool] = None,
) -> ExperimentAnalysis:
    """Executes training.

    When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run
    will gracefully shut down and checkpoint the latest experiment state.
    Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step.

    Many aspects of Tune, such as the frequency of global checkpointing,
    the maximum number of pending placement group trials, and the path of the
    result directory, can be configured through environment variables. Refer to
    :ref:`tune-env-vars` for a list of environment variables available.

    Examples:

    .. code-block:: python

        # Run 10 trials (each trial is one instance of a Trainable). Tune runs
        # in parallel and automatically determines concurrency.
        tune.run(trainable, num_samples=10)

        # Run 1 trial, stop when trial has reached 10 iterations
        tune.run(my_trainable, stop={"training_iteration": 10})

        # automatically retry failed trials up to 3 times
        tune.run(my_trainable, stop={"training_iteration": 10}, max_failures=3)

        # Run 1 trial, search over hyperparameters, stop after 10 iterations.
        space = {"lr": tune.uniform(0, 1), "momentum": tune.uniform(0, 1)}
        tune.run(my_trainable, config=space, stop={"training_iteration": 10})

        # Resumes training if a previous machine crashed
        tune.run(my_trainable, config=space,
                 local_dir=<path/to/dir>, resume=True)

        # Rerun ONLY failed trials after an experiment is finished.
        tune.run(my_trainable, config=space,
                 local_dir=<path/to/dir>, resume="ERRORED_ONLY")

    Args:
        run_or_experiment: If function|class|str, this is the algorithm or
            model to train. This may refer to the name of a built-in algorithm
            (e.g. RLlib's DQN or PPO), a user-defined trainable
            function or class, or the string identifier of a
            trainable function or class registered in the tune registry.
            If Experiment, then Tune will execute training based on
            Experiment.spec. If you want to pass in a Python lambda, you
            will need to first register the function:
            ``tune.register_trainable("lambda_id", lambda x: ...)``. You can
            then use ``tune.run("lambda_id")``.
        metric: Metric to optimize. This metric should be reported
            with `tune.report()`. If set, will be passed to the search
            algorithm and scheduler.
        mode: Must be one of [min, max]. Determines whether objective is
            minimizing or maximizing the metric attribute. If set, will be
            passed to the search algorithm and scheduler.
        name: Name of experiment.
        stop: Stopping criteria. If dict,
            the keys may be any field in the return result of 'train()',
            whichever is reached first. If function, it must take (trial_id,
            result) as arguments and return a boolean (True if trial should be
            stopped, False otherwise). This can also be a subclass of
            ``ray.tune.Stopper``, which allows users to implement
            custom experiment-wide stopping (i.e., stopping an entire Tune
            run based on some time constraint).
        time_budget_s: Global time budget in
            seconds after which all trials are stopped. Can also be a
            ``datetime.timedelta`` object.
        config: Algorithm-specific configuration for Tune variant
            generation (e.g. env, hyperparams). Defaults to empty dict.
            Custom search algorithms may ignore this.
        resources_per_trial: Machine resources
            to allocate per trial, e.g. ``{"cpu": 64, "gpu": 8}``.
            Note that GPUs will not be assigned unless you specify them here.
            Defaults to 1 CPU and 0 GPUs in
            ``Trainable.default_resource_request()``. This can also
            be a PlacementGroupFactory object wrapping arguments to create a
            per-trial placement group.
        num_samples: Number of times to sample from the
            hyperparameter space. Defaults to 1. If `grid_search` is
            provided as an argument, the grid will be repeated
            `num_samples` times. If this is -1, (virtually) infinite
            samples are generated until a stopping condition is met.
        local_dir: Local dir to save training results to.
            Defaults to ``~/ray_results``.
        search_alg: Search algorithm for
            optimization. You can also use the name of the algorithm.
        scheduler: Scheduler for executing
            the experiment. Choose among FIFO (default), MedianStopping,
            AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
            ray.tune.schedulers for more options. You can also use the
            name of the scheduler.
        keep_checkpoints_num: Number of checkpoints to keep. A value of
            `None` keeps all checkpoints. Defaults to `None`. If set, need
            to provide `checkpoint_score_attr`.
        checkpoint_score_attr: Specifies by which attribute to rank the
            best checkpoint. Default is increasing order. If attribute starts
            with `min-` it will rank attribute in decreasing order, i.e.
            `min-validation_loss`.
        checkpoint_freq: How many training iterations between
            checkpoints. A value of 0 (default) disables checkpointing.
            This has no effect when using the Functional Training API.
        checkpoint_at_end: Whether to checkpoint at the end of the
            experiment regardless of the checkpoint_freq. Default is False.
            This has no effect when using the Functional Training API.
        verbose: 0, 1, 2, or 3. Verbosity mode.
            0 = silent, 1 = only status updates, 2 = status and brief trial
            results, 3 = status and detailed trial results. Defaults to 3.
        progress_reporter: Progress reporter for reporting
            intermediate experiment progress. Defaults to CLIReporter if
            running in command-line, or JupyterNotebookReporter if running in
            a Jupyter notebook.
        log_to_file: Log stdout and stderr to files in
            Tune's trial directories. If this is `False` (default), no files
            are written. If `True`, outputs are written to `trialdir/stdout`
            and `trialdir/stderr`, respectively. If this is a single string,
            this is interpreted as a file relative to the trialdir, to which
            both streams are written. If this is a Sequence (e.g. a Tuple),
            it has to have length 2 and the elements indicate the files to
            which stdout and stderr are written, respectively.
        trial_name_creator: Optional function
            for generating the trial string representation.
        trial_dirname_creator: Function
            for generating the trial dirname. This function should take
            in a Trial object and return a string representing the
            name of the directory. The return value cannot be a path.
        sync_config: Configuration object for syncing. See
            tune.SyncConfig.
        export_formats: List of formats that are exported at the end of
            the experiment. Default is None.
        max_failures: Try to recover a trial at least this many times.
            Ray will recover from the latest checkpoint if present.
            Setting to -1 will lead to infinite recovery retries.
            Setting to 0 will disable retries. Defaults to 0.
        fail_fast: Whether to fail upon the first error.
            If fail_fast='raise' is provided, Tune will automatically
            raise the exception received by the Trainable. fail_fast='raise'
            can easily leak resources and should be used with caution (it
            is best used with `ray.init(local_mode=True)`).
        restore: Path to checkpoint. Only makes sense to set if
            running 1 trial. Defaults to None.
        server_port: Port number for launching TuneServer.
        resume: One of "LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY", "AUTO",
            or bool. "LOCAL"/True restores the checkpoint from the
            local experiment directory, determined
            by ``name`` and ``local_dir``. "REMOTE" restores the checkpoint
            from ``upload_dir`` (as passed to ``sync_config``).
            "PROMPT" provides the CLI feedback.
            False forces a new experiment. "ERRORED_ONLY" resets and reruns
            errored trials upon resume - previous trial artifacts will
            be left untouched.
            "AUTO" will attempt to resume from a checkpoint and otherwise
            start a new experiment.
            If resume is set but checkpoint does not exist,
            ValueError will be thrown.
        reuse_actors: Whether to reuse actors between different trials
            when possible. This can drastically speed up experiments that start
            and stop actors often (e.g., PBT in time-multiplexing mode). This
            requires trials to have the same resource requirements.
            Defaults to ``True`` for function trainables and ``False`` for
            class and registered trainables.
        trial_executor: Manage the execution of trials.
        raise_on_failed_trial: Raise TuneError if there is a failed
            trial (in ERROR state) when the experiments complete.
        callbacks: List of callbacks that will be called at different
            times in the training loop. Must be instances of the
            ``ray.tune.callback.Callback`` class. If not passed,
            `LoggerCallback` and `SyncerCallback` callbacks are automatically
            added.
        max_concurrent_trials: Maximum number of trials to run
            concurrently. Must be non-negative. If None or 0, no limit will
            be applied. This is achieved by wrapping the ``search_alg`` in
            a :class:`ConcurrencyLimiter`, and thus setting this argument
            will raise an exception if the ``search_alg`` is already a
            :class:`ConcurrencyLimiter`. Defaults to None.
        _remote: Whether to run the Tune driver in a remote function.
            This is disabled automatically if a custom trial executor is
            passed in. This is enabled by default in Ray client mode.

    Returns:
        ExperimentAnalysis: Object for experiment analysis.

    Raises:
        TuneError: Any trials failed and `raise_on_failed_trial` is True.
    """
    # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
    # remote_run_kwargs must be defined before any other
    # code is run to ensure that at this point,
    # `locals()` is equal to args and kwargs
    remote_run_kwargs = locals().copy()
    remote_run_kwargs.pop("_remote")

    if _remote is None:
        _remote = ray.util.client.ray.is_connected()

    if _remote is True and trial_executor:
        raise ValueError("cannot use custom trial executor")

    if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
        _ray_auto_init()

    if _remote:
        remote_run = ray.remote(num_cpus=0)(run)

        # Make sure tune.run is called on the server node.
        remote_run = force_on_current_node(remote_run)

        set_verbosity(verbose)
        progress_reporter = progress_reporter or detect_reporter()

        # JupyterNotebooks don't work with remote tune runs out of the box
        # (e.g. via Ray client) as they don't have access to the main
        # process stdout. So we introduce a queue here that accepts
        # strings, which will then be displayed on the driver side.
        if isinstance(progress_reporter, RemoteReporterMixin):
            string_queue = Queue(actor_options={
                "num_cpus": 0,
                **force_on_current_node(None)
            })
            progress_reporter.output_queue = string_queue

            def get_next_queue_item():
                try:
                    return string_queue.get(block=False)
                except Empty:
                    return None

        else:
            # If we don't need a queue, use this dummy get fn instead of
            # scheduling an unneeded actor
            def get_next_queue_item():
                return None

        def _handle_string_queue():
            string_item = get_next_queue_item()
            while string_item is not None:
                # This happens on the driver side
                progress_reporter.display(string_item)

                string_item = get_next_queue_item()

        # Override with detected progress reporter
        remote_run_kwargs["progress_reporter"] = progress_reporter
        remote_future = remote_run.remote(_remote=False, **remote_run_kwargs)

        # ray.wait(...)[1] returns futures that are not ready, yet
        while ray.wait([remote_future], timeout=0.2)[1]:
            # Check if we have items to execute
            _handle_string_queue()

        # Handle queue one last time
        _handle_string_queue()

        return ray.get(remote_future)

    del remote_run_kwargs

    all_start = time.time()

    if mode and mode not in ["min", "max"]:
        raise ValueError(
            "The `mode` parameter passed to `tune.run()` has to be one of "
            "['min', 'max']")

    set_verbosity(verbose)

    config = config or {}
    sync_config = sync_config or SyncConfig()
    _validate_upload_dir(sync_config)

    if num_samples == -1:
        num_samples = sys.maxsize

    result_buffer_length = None

    # Create scheduler here as we need access to some of its properties
    if isinstance(scheduler, str):
        # importing at top level causes a recursive dependency
        from ray.tune.schedulers import create_scheduler

        scheduler = create_scheduler(scheduler)
    scheduler = scheduler or FIFOScheduler()

    if not scheduler.supports_buffered_results:
        # Result buffering with e.g. a Hyperband scheduler is a bad idea, as
        # hyperband tries to stop trials when processing brackets. With result
        # buffering, we might trigger this multiple times when evaluating
        # a single trial, which leads to unexpected behavior.
        env_result_buffer_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "")
        if env_result_buffer_length:
            warnings.warn(
                f"You are using a {type(scheduler)} scheduler, but "
                f"TUNE_RESULT_BUFFER_LENGTH is set "
                f"({env_result_buffer_length}). This can lead to undesired "
                f"and faulty behavior, so the buffer length was forcibly set "
                f"to 1 instead.")
        result_buffer_length = 1

    # If reuse_actors is unset, default to False for string and class trainables,
    # and default to True for everything else (i.e. function trainables)
    if reuse_actors is None:
        trainable = (run_or_experiment.run_identifier if isinstance(
            run_or_experiment, Experiment) else run_or_experiment)
        reuse_actors = (
            # Only default to True for function trainables that meet certain conditions
            is_function_trainable(trainable) and not (
                # Changing resources requires restarting actors
                scheduler and isinstance(scheduler, ResourceChangingScheduler))
            and not (
                # If GPUs are requested we could run into problems with device memory
                _check_gpus_in_resources(resources_per_trial))
            and not (
                # If the resource request is overridden, we don't know if GPUs
                # will be requested, yet, so default to False
                _check_default_resources_override(trainable)))

    if (isinstance(scheduler,
                   (PopulationBasedTraining, PopulationBasedTrainingReplay))
            and not reuse_actors):
        warnings.warn(
            "Consider boosting PBT performance by enabling `reuse_actors` as "
            "well as implementing `reset_config` for Trainable.")

    trial_executor = trial_executor or RayTrialExecutor(
        reuse_actors=reuse_actors, result_buffer_length=result_buffer_length)
    if isinstance(run_or_experiment, list):
        experiments = run_or_experiment
    else:
        experiments = [run_or_experiment]

    for i, exp in enumerate(experiments):
        if not isinstance(exp, Experiment):
            experiments[i] = Experiment(
                name=name,
                run=exp,
                stop=stop,
                time_budget_s=time_budget_s,
                config=config,
                resources_per_trial=resources_per_trial,
                num_samples=num_samples,
                local_dir=local_dir,
                _experiment_checkpoint_dir=_experiment_checkpoint_dir,
                sync_config=sync_config,
                trial_name_creator=trial_name_creator,
                trial_dirname_creator=trial_dirname_creator,
                log_to_file=log_to_file,
                checkpoint_freq=checkpoint_freq,
                checkpoint_at_end=checkpoint_at_end,
                keep_checkpoints_num=keep_checkpoints_num,
                checkpoint_score_attr=checkpoint_score_attr,
                export_formats=export_formats,
                max_failures=max_failures,
                restore=restore,
            )
    else:
        logger.debug("Ignoring some parameters passed into tune.run.")

    if fail_fast and max_failures != 0:
        raise ValueError("max_failures must be 0 if fail_fast=True.")

    if isinstance(search_alg, str):
        search_alg = create_searcher(search_alg)

    # Check whether local_mode=True was set during ray.init().
    is_local_mode = ray._private.worker._mode(
    ) == ray._private.worker.LOCAL_MODE

    if is_local_mode:
        max_concurrent_trials = 1

    if not search_alg:
        search_alg = BasicVariantGenerator(
            max_concurrent=max_concurrent_trials or 0)
    elif max_concurrent_trials or is_local_mode:
        if isinstance(search_alg, ConcurrencyLimiter):
            if not is_local_mode:
                if search_alg.max_concurrent != max_concurrent_trials:
                    raise ValueError(
                        "You have specified `max_concurrent_trials="
                        f"{max_concurrent_trials}`, but the `search_alg` is "
                        "already a `ConcurrencyLimiter` with `max_concurrent="
                        f"{search_alg.max_concurrent}. FIX THIS by setting "
                        "`max_concurrent_trials=None`.")
                else:
                    logger.warning(
                        "You have specified `max_concurrent_trials="
                        f"{max_concurrent_trials}`, but the `search_alg` is "
                        "already a `ConcurrencyLimiter`. "
                        "`max_concurrent_trials` will be ignored.")
        else:
            if max_concurrent_trials < 1:
                raise ValueError(
                    "`max_concurrent_trials` must be greater or equal than 1, "
                    f"got {max_concurrent_trials}.")
            if isinstance(search_alg, Searcher):
                search_alg = ConcurrencyLimiter(
                    search_alg, max_concurrent=max_concurrent_trials)
            elif not is_local_mode:
                logger.warning(
                    "You have passed a `SearchGenerator` instance as the "
                    "`search_alg`, but `max_concurrent_trials` requires a "
                    "`Searcher` instance`. `max_concurrent_trials` "
                    "will be ignored.")

    if isinstance(search_alg, Searcher):
        search_alg = SearchGenerator(search_alg)

    if config and not searcher_set_search_props(
            search_alg.set_search_properties,
            metric,
            mode,
            config,
            **experiments[0].public_spec,
    ):
        if has_unresolved_values(config):
            raise ValueError(
                "You passed a `config` parameter to `tune.run()` with "
                "unresolved parameters, but the search algorithm was already "
                "instantiated with a search space. Make sure that `config` "
                "does not contain any more parameter definitions - include "
                "them in the search algorithm's search space if necessary.")

    if not scheduler_set_search_props(scheduler.set_search_properties, metric,
                                      mode, **experiments[0].public_spec):
        raise ValueError(
            "You passed a `metric` or `mode` argument to `tune.run()`, but "
            "the scheduler you are using was already instantiated with their "
            "own `metric` and `mode` parameters. Either remove the arguments "
            "from your scheduler or from your call to `tune.run()`")

    # Create syncer callbacks
    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)

    runner = TrialRunner(
        search_alg=search_alg,
        scheduler=scheduler,
        local_checkpoint_dir=experiments[0].checkpoint_dir,
        remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
        sync_config=sync_config,
        stopper=experiments[0].stopper,
        resume=resume,
        server_port=server_port,
        fail_fast=fail_fast,
        trial_executor=trial_executor,
        callbacks=callbacks,
        metric=metric,
        # Driver should only sync trial checkpoints if
        # checkpoints are not synced to cloud
        driver_sync_trial_checkpoints=not bool(sync_config.upload_dir),
    )

    if not runner.resumed:
        for exp in experiments:
            search_alg.add_configurations([exp])
    else:
        logger.info("TrialRunner resumed, ignoring new add_experiment but "
                    "updating trial resources.")
        if resources_per_trial:
            runner.update_pending_trial_resources(resources_per_trial)

    # Calls setup on callbacks
    runner.setup_experiments(experiments=experiments,
                             total_num_samples=search_alg.total_samples)

    # User Warning for GPUs
    if trial_executor.has_gpus():
        if _check_gpus_in_resources(resources=resources_per_trial):
            # "gpu" is manually set.
            pass
        elif _check_default_resources_override(experiments[0].run_identifier):
            # "default_resources" is manually overridden.
            pass
        else:
            logger.warning("Tune detects GPUs, but no trials are using GPUs. "
                           "To enable trials to use GPUs, set "
                           "tune.run(resources_per_trial={'gpu': 1}...) "
                           "which allows Tune to expose 1 GPU to each trial. "
                           "You can also override "
                           "`Trainable.default_resource_request` if using the "
                           "Trainable API.")

    original_handler = signal.getsignal(signal.SIGINT)
    state = {"signal": None}

    def signal_interrupt_tune_run(sig: int, frame):
        logger.warning(
            "Stop signal received (e.g. via SIGINT/Ctrl+C), ending Ray Tune run. "
            "This will try to checkpoint the experiment state one last time. "
            "Press CTRL+C (or send SIGINT/SIGKILL/SIGTERM) "
            "to skip. ")
        state["signal"] = sig
        # Restore original signal handler to react to future SIGINT signals
        signal.signal(signal.SIGINT, original_handler)

    # We should only install the handler when it is safe to do so.
    # When tune.run() is called from a worker thread, signal.signal will
    # fail.
    allow_signal_catching = True
    if threading.current_thread() != threading.main_thread():
        allow_signal_catching = False

    if allow_signal_catching:
        if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
            signal.signal(signal.SIGINT, signal_interrupt_tune_run)

        # Always register SIGUSR1 if available (not available e.g. on Windows)
        if hasattr(signal, "SIGUSR1"):
            signal.signal(signal.SIGUSR1, signal_interrupt_tune_run)

    progress_reporter = progress_reporter or detect_reporter()

    tune_start = time.time()

    progress_reporter.setup(
        start_time=tune_start,
        total_samples=search_alg.total_samples,
        metric=metric,
        mode=mode,
    )
    while not runner.is_finished() and not state["signal"]:
        runner.step()
        if has_verbosity(Verbosity.V1_EXPERIMENT):
            _report_progress(runner, progress_reporter)
    tune_taken = time.time() - tune_start

    try:
        runner.checkpoint(force=True)
    except Exception as e:
        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")

    if has_verbosity(Verbosity.V1_EXPERIMENT):
        _report_progress(runner, progress_reporter, done=True)

    # Wait for syncing to finish
    for callback in callbacks:
        if isinstance(callback, SyncerCallback):
            try:
                callback.wait_for_all()
            except TuneError as e:
                logger.error(e)

    runner.cleanup()

    incomplete_trials = []
    for trial in runner.get_trials():
        if trial.status != Trial.TERMINATED:
            incomplete_trials += [trial]

    if incomplete_trials:
        if raise_on_failed_trial and not state["signal"]:
            raise TuneError("Trials did not complete", incomplete_trials)
        else:
            logger.error("Trials did not complete: %s", incomplete_trials)

    all_taken = time.time() - all_start
    if has_verbosity(Verbosity.V1_EXPERIMENT):
        logger.info(f"Total run time: {all_taken:.2f} seconds "
                    f"({tune_taken:.2f} seconds for the tuning loop).")

    if state["signal"]:
        logger.warning(
            "Experiment has been interrupted, but the most recent state was "
            "saved. You can continue running this experiment by passing "
            "`resume=True` to `tune.run()`")

    trials = runner.get_trials()
    return ExperimentAnalysis(
        runner.checkpoint_file,
        trials=trials,
        default_metric=metric,
        default_mode=mode,
        sync_config=sync_config,
    )
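To tie the examples together, the sketch below shows how a caller might pass a custom RayTrialExecutor (as several of the tests above do) into tune.run. The import path and the toy trainable are assumptions for illustration, and newer Ray releases have since deprecated passing a custom trial executor:

# Sketch only: a custom executor handed to tune.run(), as in the tests above.
# The import path and the toy trainable are assumptions, not taken from the source.
from ray import tune
from ray.tune.ray_trial_executor import RayTrialExecutor  # path assumed for this Ray version

def toy_trainable(config):
    for i in range(3):
        tune.report(score=i * config["lr"])

analysis = tune.run(
    toy_trainable,
    config={"lr": tune.grid_search([0.01, 0.1])},
    trial_executor=RayTrialExecutor(reuse_actors=True, result_buffer_length=4),
    stop={"training_iteration": 3},
    verbose=1,
)
print(analysis.results_df.shape)  # one row per trial, columns include the reported "score"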