Example 1
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
    """Test checks that trial restarts if checkpoint is lost w/ node fail."""
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 3
        },
        "checkpoint_freq": 2,
        "max_failures": 2
    }

    # Test recovery of trial that has been checkpointed
    t1 = Trial("__fake", **kwargs)
    runner.add_trial(t1)
    runner.step()  # start
    runner.step()  # 1 result
    runner.step()  # 2 result and checkpoint
    assert t1.has_checkpoint()
    cluster.add_node(num_cpus=1)
    cluster.remove_node(node)
    cluster.wait_for_nodes()
    shutil.rmtree(os.path.dirname(t1._checkpoint.value))

    runner.step()  # Recovery step
    for i in range(3):
        runner.step()

    assert t1.status == Trial.TERMINATED
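
For context, the Trial kwargs exercised above correspond to public tune.run arguments. A minimal sketch, assuming a registered trainable named "my_trainable" (the name is illustrative, not from the source):

from ray import tune

tune.run(
    "my_trainable",                  # hypothetical registered trainable
    stop={"training_iteration": 3},  # same stopping criterion as the test
    checkpoint_freq=2,               # checkpoint every 2 iterations
    max_failures=2,                  # retry a failed trial up to 2 times
)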
Example 2
    def restore(cls,
                metadata_checkpoint_dir,
                search_alg=None,
                scheduler=None,
                trial_executor=None):
        """Restores all checkpointed trials from previous run.

        Requires user to manually re-register their objects. Also stops
        all ongoing trials.

        Args:
            metadata_checkpoint_dir (str): Path to metadata checkpoints.
            search_alg (SearchAlgorithm): Search Algorithm. Defaults to
                BasicVariantGenerator.
            scheduler (TrialScheduler): Scheduler for executing
                the experiment.
            trial_executor (TrialExecutor): Manages the execution of trials.

        Returns:
            runner (TrialRunner): A TrialRunner to resume experiments from.
        """

        newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
        with open(newest_ckpt_path, "r") as f:
            runner_state = json.load(f)

        logger.warning("".join([
            "Attempting to resume experiment from {}. ".format(
                metadata_checkpoint_dir), "This feature is experimental, "
            "and may not work with all search algorithms. ",
            "This will ignore any new changes to the specification."
        ]))

        from ray.tune.suggest import BasicVariantGenerator
        runner = TrialRunner(
            search_alg or BasicVariantGenerator(),
            scheduler=scheduler,
            trial_executor=trial_executor)

        runner.__setstate__(runner_state["runner_data"])

        trials = []
        for trial_cp in runner_state["checkpoints"]:
            new_trial = Trial(trial_cp["trainable_name"])
            new_trial.__setstate__(trial_cp)
            trials += [new_trial]
        for trial in sorted(
                trials, key=lambda t: t.last_update_time, reverse=True):
            runner.add_trial(trial)
        return runner
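
A minimal usage sketch for the classmethod above, assuming a metadata checkpoint directory left by an earlier run and the older ray.tune.trial_runner module path; the directory name is illustrative:

from ray.tune.trial_runner import TrialRunner

# Rebuild the runner from the newest experiment checkpoint and keep
# stepping until every restored trial has finished.
runner = TrialRunner.restore("/tmp/ray_results/my_experiment")
while not runner.is_finished():
    runner.step()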
Example 3
    def _generate_trials(self, experiment_spec, output_path=""):
        """Generates trials with configurations from `_suggest`.

        Creates a trial_id that is passed into `_suggest`.

        Yields:
            Trial objects constructed according to `spec`
        """
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))
        for _ in range(experiment_spec.get("num_samples", 1)):
            trial_id = Trial.generate_id()
            while True:
                suggested_config = self._suggest(trial_id)
                if suggested_config is None:
                    yield None
                else:
                    break
            spec = copy.deepcopy(experiment_spec)
            spec["config"] = merge_dicts(spec["config"], suggested_config)
            flattened_config = resolve_nested_dict(spec["config"])
            self._counter += 1
            tag = "{0}_{1}".format(
                str(self._counter), format_vars(flattened_config))
            yield create_trial_from_spec(
                spec,
                output_path,
                self._parser,
                experiment_tag=tag,
                trial_id=trial_id)
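
For reference, a sketch of the experiment_spec shape the generator above expects; the trainable name and config values are illustrative assumptions, not taken from the source:

# Keys mirror those read in _generate_trials: "run" is required,
# "num_samples" defaults to 1, and "config" is merged with each
# suggestion returned by _suggest.
experiment_spec = {
    "run": "my_trainable",      # hypothetical registered trainable
    "num_samples": 4,
    "config": {"lr": 1e-3},     # base config, merged with suggested configs
}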
Example 4
def test_trial_migration(start_connected_emptyhead_cluster):
    """Removing a node while cluster has space should migrate trial.

    The trial state should also be consistent with the checkpoint.
    """
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 3
        },
        "checkpoint_freq": 2,
        "max_failures": 2
    }

    # Test recovery of trial that hasn't been checkpointed
    t = Trial("__fake", **kwargs)
    runner.add_trial(t)
    runner.step()  # start
    runner.step()  # 1 result
    assert t.last_result
    node2 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node)
    cluster.wait_for_nodes()
    runner.step()  # Recovery step

    # TODO(rliaw): This assertion is not critical but will not pass
    #   because checkpoint handling is messy and should be refactored
    #   rather than hotfixed.
    # assert t.last_result is None, "Trial result not restored correctly."
    for i in range(3):
        runner.step()

    assert t.status == Trial.TERMINATED

    # Test recovery of trial that has been checkpointed
    t2 = Trial("__fake", **kwargs)
    runner.add_trial(t2)
    runner.step()  # start
    runner.step()  # 1 result
    runner.step()  # 2 result and checkpoint
    assert t2.has_checkpoint()
    node3 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node2)
    cluster.wait_for_nodes()
    runner.step()  # Recovery step
    assert t2.last_result["training_iteration"] == 2
    for i in range(1):
        runner.step()

    assert t2.status == Trial.TERMINATED

    # Test recovery of trial that won't be checkpointed
    t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}})
    runner.add_trial(t3)
    runner.step()  # start
    runner.step()  # 1 result
    cluster.add_node(num_cpus=1)
    cluster.remove_node(node3)
    cluster.wait_for_nodes()
    runner.step()  # Error handling step
    assert t3.status == Trial.ERROR

    with pytest.raises(TuneError):
        runner.step()
Example 6
    def testStartFailure(self):
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)
Example 7
    def _stop_trial(self,
                    trial: Trial,
                    error=False,
                    error_msg=None,
                    destroy_pg_if_cannot_replace=True):
        """Stops this trial.

        Stops this trial, releasing all allocated resources. If stopping the
        trial fails, the run will be marked as terminated in error, but no
        exception will be thrown.

        If the trial's placement group will be used right away
        (destroy_pg_if_cannot_replace=False), the placement group (or a
        surrogate placement group) is not removed.

        Args:
            trial (Trial): Trial to stop.
            error (bool): Whether to mark this trial as terminated in error.
            error_msg (str): Optional error message.
            destroy_pg_if_cannot_replace (bool): Whether to destroy the
                trial's placement group if it cannot be replaced right away.

        """
        self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED)
        self._trial_just_finished = True
        trial.set_location(Location())

        try:
            trial.write_error_log(error_msg)
            if hasattr(trial, "runner") and trial.runner:
                if (not error and self._reuse_actors
                        and (len(self._cached_actor_pg) <
                             (self._cached_actor_pg.maxlen or float("inf")))):
                    logger.debug("Reusing actor for %s", trial.runner)
                    # Move PG into cache (disassociate from trial)
                    pg = self._pg_manager.cache_trial_pg(trial)
                    if pg or not trial.uses_placement_groups:
                        # True if a placement group was replaced
                        self._cached_actor_pg.append((trial.runner, pg))
                        should_destroy_actor = False
                    else:
                        # False if no placement group was replaced. This should
                        # only be the case if there are no more trials with
                        # this placement group factory to run
                        logger.debug(
                            f"Could not cache actor of trial {trial} for "
                            "reuse, as there are no pending trials "
                            "requiring its resources.")
                        should_destroy_actor = True
                else:
                    should_destroy_actor = True

                if should_destroy_actor:
                    logger.debug("Trial %s: Destroying actor.", trial)

                    # Try to return the placement group for other trials to use
                    self._pg_manager.return_pg(trial,
                                               destroy_pg_if_cannot_replace)

                    with self._change_working_directory(trial):
                        self._trial_cleanup.add(trial, actor=trial.runner)

                if trial in self._staged_trials:
                    self._staged_trials.remove(trial)

        except Exception:
            logger.exception("Trial %s: Error stopping runner.", trial)
            self.set_status(trial, Trial.ERROR)
        finally:
            trial.set_runner(None)
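
The actor-caching branch above is only exercised when actor reuse is enabled at the tune.run level. A minimal sketch of that public knob, using an illustrative Trainable that is not part of the source:

from ray import tune

class MyTrainable(tune.Trainable):
    def step(self):
        return {"metric": 1}

# reuse_actors=True is the setting that ultimately feeds the executor's
# self._reuse_actors, allowing finished trials' actors to be cached.
tune.run(
    MyTrainable,
    num_samples=4,
    reuse_actors=True,
    stop={"training_iteration": 1},
)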
Example 8
def create_next(client):
    '''A stateless API for HPO
    '''
    state = client.get_state()
    setting = client.get_settings_dict()
    if state is None:
        # first time call
        try:
            from ray.tune import (uniform, quniform, choice, randint,
                                  qrandint, randn, qrandn, loguniform,
                                  qloguniform)
            from ray.tune.trial import Trial
        except ImportError:
            from ..tune.sample import (uniform, quniform, choice, randint,
                                       qrandint, randn, qrandn, loguniform,
                                       qloguniform)
            from ..tune.trial import Trial
        method = setting.get('method', 'BlendSearch')
        mode = client.get_optimization_mode()
        if mode == 'minimize':
            mode = 'min'
        elif mode == 'maximize':
            mode = 'max'
        metric = client.get_primary_metric()
        hp_space = client.get_hyperparameter_space_dict()
        space = {}
        for key, value in hp_space.items():
            t = value["type"]
            if t == 'continuous':
                space[key] = uniform(value["min_val"], value["max_val"])
            elif t == 'discrete':
                space[key] = choice(value["values"])
            elif t == 'integral':
                space[key] = randint(value["min_val"], value["max_val"])
            elif t == 'quantized_continuous':
                space[key] = quniform(value["min_val"], value["max_val"],
                                      value["step"])
        init_config = setting.get('init_config', None)
        if init_config:
            points_to_evaluate = [init_config]
        else:
            points_to_evaluate = None
        cat_hp_cost = setting.get('cat_hp_cost', None)

        if method == 'BlendSearch':
            Algo = BlendSearch
        elif method == 'CFO':
            Algo = CFO
        else:
            raise ValueError(f"Unknown method: {method}")
        algo = Algo(
            mode=mode,
            metric=metric,
            space=space,
            points_to_evaluate=points_to_evaluate,
            cat_hp_cost=cat_hp_cost,
        )
        time_budget_s = setting.get('time_budget_s', None)
        if time_budget_s:
            algo._deadline = time_budget_s + time.time()
        config2trialid = {}
    else:
        algo = state['algo']
        config2trialid = state['config2trialid']
    # update finished trials
    trials_completed = []
    for trial in client.get_trials():
        if trial.end_time is not None:
            signature = algo._ls.config_signature(trial.hp_sample)
            if not algo._result[signature]:
                trials_completed.append((trial.end_time, trial))
    trials_completed.sort()
    for t in trials_completed:
        end_time, trial = t
        trial_id = config2trialid[trial.hp_sample]
        result = {}
        result[algo.metric] = trial.metrics[algo.metric].values[-1]
        result[algo.cost_attr] = (end_time - trial.start_time).total_seconds()
        for key, value in trial.hp_sample.items():
            result['config/'+key] = value
        algo.on_trial_complete(trial_id, result=result)
    # propose new trial
    trial_id = Trial.generate_id()
    config = algo.suggest(trial_id)
    if config:
        config2trialid[config] = trial_id
        client.launch_trial(config)
    client.update_state({'algo': algo, 'config2trialid': config2trialid})
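
For clarity, a self-contained sketch of the hyperparameter-space dict shape that the translation loop above consumes; parameter names and bounds are illustrative assumptions:

# "type" values mirror the branches in create_next; each entry is mapped
# to the corresponding ray.tune sampling primitive.
hp_space = {
    "learning_rate": {"type": "continuous", "min_val": 1e-4, "max_val": 1e-1},
    "optimizer": {"type": "discrete", "values": ["adam", "sgd"]},
    "num_layers": {"type": "integral", "min_val": 1, "max_val": 8},
    "dropout": {"type": "quantized_continuous",
                "min_val": 0.0, "max_val": 0.5, "step": 0.1},
}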
Example 9
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The arguments here
            should correspond to the command line flags
            in ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.
    """
    global _cached_pgf

    spec = spec.copy()
    resources = spec.pop("resources_per_trial", None)

    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if resources:
        if isinstance(resources, PlacementGroupFactory):
            trial_kwargs["placement_group_factory"] = resources
        elif callable(resources):
            if resources in _cached_pgf:
                trial_kwargs["placement_group_factory"] = _cached_pgf[
                    resources]
            else:
                pgf = PlacementGroupFactory(resources)
                _cached_pgf[resources] = pgf
                trial_kwargs["placement_group_factory"] = pgf
        else:
            try:
                trial_kwargs["resources"] = json_to_resources(resources)
            except (TuneError, ValueError) as exc:
                raise TuneError("Error parsing resources_per_trial",
                                resources) from exc

    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=args.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        max_failures=args.max_failures,
        **trial_kwargs)
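
A usage sketch for the factory above, assuming the parser comes from ray.tune.config_parser.make_parser; the spec values are illustrative and only show the keys the function reads:

from ray.tune.config_parser import create_trial_from_spec, make_parser

spec = {
    "run": "my_trainable",           # assumes a trainable registered under this name
    "local_dir": "/tmp/ray_results",
    "config": {"lr": 1e-3},
    "stop": {"training_iteration": 10},
}
parser = make_parser()
trial = create_trial_from_spec(spec, "my_experiment", parser)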
Example 10
def make_stub_if_needed(trial: Trial) -> Trial:
    if trial.stub:
        return trial
    trial_copy = Trial(trial.trainable_name, stub=True)
    trial_copy.__setstate__(trial.__getstate__())
    return trial_copy
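
A small usage sketch for the helper above, assuming the "__fake" trainable is registered as in the Tune test suite and that Trial is imported as in the surrounding code; the helper is presumably applied before serializing trials:

trial = Trial("__fake")             # assumes "__fake" is a registered trainable
stub = make_stub_if_needed(trial)   # runner-free copy that is safe to pickle
assert stub.stub and stub.trainable_name == trial.trainable_name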
Example 11
    def testCallbackSteps(self):
        trials = [
            Trial("__fake", trial_id="one"),
            Trial("__fake", trial_id="two")
        ]
        for t in trials:
            self.trial_runner.add_trial(t)

        self.executor.next_trial = trials[0]
        self.trial_runner.step()

        # Trial 1 has been started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
        self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                         "one")

        # None of these events has happened yet
        self.assertTrue(
            all(k not in self.callback.state for k in [
                "trial_restore", "trial_save", "trial_result",
                "trial_complete", "trial_fail"
            ]))

        self.executor.next_trial = trials[1]
        self.trial_runner.step()

        # Iteration not increased yet
        self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)

        # Iteration increased
        self.assertEqual(self.callback.state["step_end"]["iteration"], 2)

        # Second trial has been just started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
        self.assertEqual(self.callback.state["trial_start"]["trial"].trial_id,
                         "two")

        cp = Checkpoint(Checkpoint.PERSISTENT, "__checkpoint",
                        {TRAINING_ITERATION: 0})

        # Let the first trial save a checkpoint
        self.executor.next_trial = trials[0]
        trials[0].saving_to = cp
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
        self.assertEqual(self.callback.state["trial_save"]["trial"].trial_id,
                         "one")

        # Let the second trial send a result
        result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
        self.executor.results[trials[1]] = result
        self.executor.next_trial = trials[1]
        self.assertEqual(trials[1].last_result, {})
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
        self.assertEqual(self.callback.state["trial_result"]["trial"].trial_id,
                         "two")
        self.assertEqual(
            self.callback.state["trial_result"]["result"]["metric"], 800)
        self.assertEqual(trials[1].last_result["metric"], 800)

        # Let the second trial restore from a checkpoint
        trials[1].restoring_from = cp
        self.executor.results[trials[1]] = trials[1].last_result
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
        self.assertEqual(
            self.callback.state["trial_restore"]["trial"].trial_id, "two")

        # Let the second trial finish
        trials[1].restoring_from = None
        self.executor.results[trials[1]] = {
            TRAINING_ITERATION: 2,
            "metric": 900,
            "done": True
        }
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
        self.assertEqual(
            self.callback.state["trial_complete"]["trial"].trial_id, "two")

        # Let the first trial error
        self.executor.failed_trial = trials[0]
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
        self.assertEqual(self.callback.state["trial_fail"]["trial"].trial_id,
                         "one")
Example 12
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
                                      trainable_id):
    """Test checks that trial restarts if checkpoint is lost w/ node fail."""
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    class _SyncerCallback(SyncerCallback):
        def _create_trial_syncer(self, trial: "Trial"):
            client = mock_storage_client()
            return MockNodeSyncer(trial.logdir, trial.logdir, client)

    syncer_callback = _SyncerCallback(None)
    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 4
        },
        "checkpoint_freq": 2,
        "max_failures": 2,
        "remote_checkpoint_dir": MOCK_REMOTE_DIR,
    }

    # The following patches only affect __fake_remote.
    def hide_remote_path(path_function):
        def hidden_path_func(checkpoint_path):
            """Converts back to local path first."""
            if MOCK_REMOTE_DIR in checkpoint_path:
                checkpoint_path = checkpoint_path[len(MOCK_REMOTE_DIR):]
                checkpoint_path = os.path.join("/", checkpoint_path)
            return path_function(checkpoint_path)

        return hidden_path_func

    trainable_util = "ray.tune.ray_trial_executor.TrainableUtil"
    _find_ckpt = trainable_util + ".find_checkpoint_dir"
    find_func = TrainableUtil.find_checkpoint_dir
    _pickle_ckpt = trainable_util + ".pickle_checkpoint"
    pickle_func = TrainableUtil.pickle_checkpoint

    with patch(_find_ckpt) as mock_find, patch(_pickle_ckpt) as mock_pkl_ckpt:
        # __fake_remote trainables save to a separate "remote" directory.
        # TrainableUtil will not check this path unless we mock it.
        mock_find.side_effect = hide_remote_path(find_func)
        mock_pkl_ckpt.side_effect = hide_remote_path(pickle_func)

        # Test recovery of trial that has been checkpointed
        t1 = Trial(trainable_id, **kwargs)
        runner.add_trial(t1)

        # Start trial, process result (x2), process save
        for _ in range(4):
            runner.step()
        assert t1.has_checkpoint()

        cluster.add_node(num_cpus=1)
        cluster.remove_node(node)
        cluster.wait_for_nodes()
        shutil.rmtree(os.path.dirname(t1.checkpoint.value))
        runner.step()  # Collect result 3, kick off + fail result 4
        runner.step()  # Dispatch restore
        runner.step()  # Process restore + step 4
        for _ in range(3):
            if t1.status != Trial.TERMINATED:
                runner.step()
    assert t1.status == Trial.TERMINATED, runner.debug_string()
Example 13
def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
    """Removing a node while cluster has space should migrate trial.

    The trial state should also be consistent with the checkpoint.
    """
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    syncer_callback = _PerTrialSyncerCallback(
        lambda trial: trial.trainable_name == "__fake")
    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 4
        },
        "checkpoint_freq": 2,
        "max_failures": 2,
        "remote_checkpoint_dir": MOCK_REMOTE_DIR,
    }

    # Test recovery of trial that hasn't been checkpointed
    t = Trial(trainable_id, **kwargs)
    runner.add_trial(t)
    runner.step()  # Start trial
    runner.step()  # Process result
    assert t.last_result
    node2 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node)
    cluster.wait_for_nodes()
    # TODO(ujvl): Node failure does not propagate until a step after it
    #  actually should. This is possibly a problem with `Cluster`.
    runner.step()
    runner.step()  # Recovery step

    # TODO(rliaw): This assertion is not critical but will not pass
    #   because checkpoint handling is messy and should be refactored
    #   rather than hotfixed.
    # assert t.last_result is None, "Trial result not restored correctly."

    # Process result (x2), process save, process result (x2), process save
    for _ in range(6):
        runner.step()

    assert t.status == Trial.TERMINATED, runner.debug_string()

    # Test recovery of trial that has been checkpointed
    t2 = Trial(trainable_id, **kwargs)
    runner.add_trial(t2)
    # Start trial, process result (x2), process save
    for _ in range(4):
        runner.step()
    assert t2.has_checkpoint()
    node3 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node2)
    cluster.wait_for_nodes()
    runner.step()  # Process result 3 + start and fail result 4
    runner.step()  # Dispatch restore
    runner.step()  # Process restore
    runner.step()  # Process result 5
    if t2.status != Trial.TERMINATED:
        runner.step()  # Process result 6, dispatch save
        runner.step()  # Process save
    assert t2.status == Trial.TERMINATED, runner.debug_string()

    # Test recovery of trial that won't be checkpointed
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 3
        },
        "remote_checkpoint_dir": MOCK_REMOTE_DIR,
    }
    t3 = Trial(trainable_id, **kwargs)
    runner.add_trial(t3)
    runner.step()  # Start trial
    runner.step()  # Process result 1
    cluster.add_node(num_cpus=1)
    cluster.remove_node(node3)
    cluster.wait_for_nodes()
    runner.step()  # Error handling step
    if t3.status != Trial.ERROR:
        runner.step()
    assert t3.status == Trial.ERROR, runner.debug_string()

    with pytest.raises(TuneError):
        runner.step()