Example #1
class Trial(object):
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state and transition to RUNNING once started.
    On error a trial transitions to ERROR; on success it ends as TERMINATED.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr="",
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_driver_fn=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """

        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and hasattr(trainable_cls,
                                     "default_resource_request"):
            default_resources = trainable_cls.default_resource_request(
                self.config)
            if default_resources:
                if resources:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources
        self.resources = resources or Resources(cpu=1, gpu=0)
        self.stopping_criterion = stopping_criterion or {}
        self.loggers = loggers
        self.sync_to_driver_fn = sync_to_driver_fn
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end

        self.history = []
        self.keep_checkpoints_num = keep_checkpoints_num
        self._cmp_greater = not checkpoint_score_attr.startswith("min-")
        self.best_checkpoint_attr_value = (
            -float("inf") if self._cmp_greater else float("inf"))
        # Strip off "min-" from checkpoint attribute
        self.checkpoint_score_attr = (
            checkpoint_score_attr
            if self._cmp_greater else checkpoint_score_attr[4:])

        self._checkpoint = Checkpoint(storage=Checkpoint.DISK,
                                      value=restore_path)
        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.num_failures = 0
        self.custom_trial_name = None

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "_checkpoint",
            "loggers",
            "sync_to_driver_fn",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @classmethod
    def generate_id(cls):
        return str(uuid.uuid1().hex)[:8]

    @classmethod
    def create_logdir(cls, identifier, local_dir):
        local_dir = os.path.expanduser(local_dir)
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)
        return tempfile.mkdtemp(prefix="{}_{}".format(
            identifier[:MAX_LEN_IDENTIFIER], date_str()),
                                dir=local_dir)

    def init_logger(self):
        """Init logger."""

        if not self.result_logger:
            if not self.logdir:
                self.logdir = Trial.create_logdir(str(self), self.local_dir)
            elif not os.path.exists(self.logdir):
                os.makedirs(self.logdir)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                trial=self,
                loggers=self.loggers,
                sync_function=self.sync_to_driver_fn)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError: If the trial status is RUNNING.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")
        self.resources = Resources(cpu, gpu, **kwargs)

    def sync_logger_to_new_location(self, worker_ip):
        """Updates the logger location.

        Also pushes logdir to worker_ip, allowing for cross-node recovery.
        """
        if self.result_logger:
            self.result_logger.sync_results_to_new_location(worker_ip)

    def close_logger(self):
        """Close logger."""

        if self.result_logger:
            self.result_logger.close()
            self.result_logger = None

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1  # may be moved to outer scope?
            error_file = os.path.join(self.logdir,
                                      "error_{}.txt".format(date_str()))
            with open(error_file, "w") as f:
                f.write(error_msg)
            self.error_file = error_file
            self.error_msg = error_msg

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""

        if result.get(DONE):
            return True

        if callable(self.stopping_criterion):
            return self.stopping_criterion(self.trial_id, result)

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}

        if result.get(DONE) and self.checkpoint_at_end:
            return True

        if self.checkpoint_freq:
            return result.get(TRAINING_ITERATION,
                              0) % self.checkpoint_freq == 0
        else:
            return False

    def has_checkpoint(self):
        return self._checkpoint.value is not None

    def clear_checkpoint(self):
        self._checkpoint.value = None

    def should_recover(self):
        """Returns whether the trial qualifies for retrying.

        A trial qualifies if it has not failed more than max_failures times
        (a negative max_failures means unlimited retries). Note this may
        return True even when there is no checkpoint, either because
        `self.checkpoint_freq` is `0` or because the trial failed before
        a checkpoint was made.
        """
        return self.num_failures < self.max_failures or self.max_failures < 0

    def update_last_result(self, result, terminate=False):
        result.update(trial_id=self.trial_id, done=terminate)
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)
        if self.verbose and (terminate or time.time() - self.last_debug >
                             DEBUG_PRINT_INTERVAL):
            print("Result for {}:".format(self))
            print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            self.last_debug = time.time()
        self.last_result = result
        self.last_update_time = time.time()
        self.result_logger.on_result(self.last_result)

    def compare_checkpoints(self, attr_mean):
        """Compares two checkpoints based on the attribute attr_mean param.
        Greater than is used by default. If  command-line parameter
        checkpoint_score_attr starts with "min-" less than is used.

        Arguments:
            attr_mean: mean of attribute value for the current checkpoint

        Returns:
            True: when attr_mean is greater than previous checkpoint attr_mean
                  and greater than function is selected
                  when attr_mean is less than previous checkpoint attr_mean and
                  less than function is selected
            False: when attr_mean is not in alignment with selected cmp fn
        """
        if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
            return True
        elif (not self._cmp_greater
              and attr_mean < self.best_checkpoint_attr_value):
            return True
        return False

    def get_trainable_cls(self):
        return get_trainable_cls(self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.TERMINATED, Trial.ERROR]

    @property
    def node_ip(self):
        return self.last_result.get("node_ip")

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        identifier += "_" + self.trial_id
        return identifier.replace("/", "_")

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING, and flushes the result logger.
        Note this can only occur if the trial holds a DISK checkpoint.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush()
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)

    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        validate_trainable(self.trainable_name)
        if logger_started:
            self.init_logger()
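
# --- Illustrative usage sketch (not part of the original module) ---
# A standalone restatement of the `should_checkpoint` logic above, showing
# how `checkpoint_freq` and `checkpoint_at_end` interact. The real method
# reads the DONE and TRAINING_ITERATION constants from ray.tune.result; the
# literal "done" and "training_iteration" keys below are assumed purely for
# illustration.


def _should_checkpoint_sketch(last_result, checkpoint_freq,
                              checkpoint_at_end):
    result = last_result or {}
    # A finished trial checkpoints when checkpoint_at_end is set.
    if result.get("done") and checkpoint_at_end:
        return True
    # Otherwise checkpoint every `checkpoint_freq` training iterations.
    if checkpoint_freq:
        return result.get("training_iteration", 0) % checkpoint_freq == 0
    return False


assert _should_checkpoint_sketch({"training_iteration": 10}, 5, False)
assert not _should_checkpoint_sketch({"training_iteration": 7}, 5, False)
assert _should_checkpoint_sketch({"done": True}, 0, True)
assert not _should_checkpoint_sketch({}, 0, False)
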
Example #2
class Trial(object):
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state and transition to RUNNING once started.
    On error a trial transitions to ERROR; on success it ends as TERMINATED.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 export_formats=None,
                 restore_path=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_function=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """

        Trial._registration_check(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.config = config or {}
        self.local_dir = os.path.expanduser(local_dir)
        self.experiment_tag = experiment_tag
        self.resources = (
            resources
            or self._get_trainable_cls().default_resource_request(self.config))
        self.stopping_criterion = stopping_criterion or {}
        self.upload_dir = upload_dir
        self.loggers = loggers
        self.sync_function = sync_function
        validate_sync_function(sync_function)
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self._checkpoint = Checkpoint(
            storage=Checkpoint.DISK, value=restore_path)
        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.error_file = None
        self.num_failures = 0

        self.custom_trial_name = None

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "_checkpoint",
            "config",
            "loggers",
            "sync_function",
            "last_result",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @classmethod
    def _registration_check(cls, trainable_name):
        if not has_trainable(trainable_name):
            # Make sure rllib agents are registered
            from ray import rllib  # noqa: F401
            if not has_trainable(trainable_name):
                raise TuneError("Unknown trainable: " + trainable_name)

    @classmethod
    def generate_id(cls):
        return binary_to_hex(_random_string())[:8]

    def init_logger(self):
        """Init logger."""

        if not self.result_logger:
            if not os.path.exists(self.local_dir):
                os.makedirs(self.local_dir)
            if not self.logdir:
                self.logdir = tempfile.mkdtemp(
                    prefix="{}_{}".format(
                        str(self)[:MAX_LEN_IDENTIFIER], date_str()),
                    dir=self.local_dir)
            elif not os.path.exists(self.logdir):
                os.makedirs(self.logdir)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                upload_uri=self.upload_dir,
                loggers=self.loggers,
                sync_function=self.sync_function)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError: If the trial status is RUNNING.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")
        self.resources = Resources(cpu, gpu, **kwargs)

    def sync_logger_to_new_location(self, worker_ip):
        """Updates the logger location.

        Also pushes logdir to worker_ip, allowing for cross-node recovery.
        """
        if self.result_logger:
            self.result_logger.sync_results_to_new_location(worker_ip)

    def close_logger(self):
        """Close logger."""

        if self.result_logger:
            self.result_logger.close()
            self.result_logger = None

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1  # may be moved to outer scope?
            error_file = os.path.join(self.logdir,
                                      "error_{}.txt".format(date_str()))
            with open(error_file, "w") as f:
                f.write(error_msg)
            self.error_file = error_file

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""

        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            if result[criteria] >= stop_value:
                return True

        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}

        if result.get(DONE) and self.checkpoint_at_end:
            return True

        if self.checkpoint_freq:
            return result.get(TRAINING_ITERATION,
                              0) % self.checkpoint_freq == 0
        else:
            return False

    def progress_string(self):
        """Returns a progress message for printing out to the console."""

        if not self.last_result:
            return self._status_string()

        def location_string(hostname, pid):
            if hostname == os.uname()[1]:
                return 'pid={}'.format(pid)
            else:
                return '{} pid={}'.format(hostname, pid)

        pieces = [
            '{}'.format(self._status_string()), '[{}]'.format(
                self.resources.summary_string()), '[{}]'.format(
                    location_string(
                        self.last_result.get(HOSTNAME),
                        self.last_result.get(PID))), '{} s'.format(
                            int(self.last_result.get(TIME_TOTAL_S)))
        ]

        if self.last_result.get(TRAINING_ITERATION) is not None:
            pieces.append('{} iter'.format(
                self.last_result[TRAINING_ITERATION]))

        if self.last_result.get(TIMESTEPS_TOTAL) is not None:
            pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))

        if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
            pieces.append('{} rew'.format(
                format(self.last_result[EPISODE_REWARD_MEAN], '.3g')))

        if self.last_result.get(MEAN_LOSS) is not None:
            pieces.append('{} loss'.format(
                format(self.last_result[MEAN_LOSS], '.3g')))

        if self.last_result.get(MEAN_ACCURACY) is not None:
            pieces.append('{} acc'.format(
                format(self.last_result[MEAN_ACCURACY], '.3g')))

        return ', '.join(pieces)

    def _status_string(self):
        return "{}{}".format(
            self.status, ", {} failures: {}".format(self.num_failures,
                                                    self.error_file)
            if self.error_file else "")

    def has_checkpoint(self):
        return self._checkpoint.value is not None

    def clear_checkpoint(self):
        self._checkpoint.value = None

    def should_recover(self):
        """Returns whether the trial qualifies for restoring.

        A trial qualifies if a checkpoint frequency is set and it has not
        failed more than max_failures times (a negative max_failures means
        unlimited retries). This may return True even when no checkpoint
        exists yet.
        """
        return (self.checkpoint_freq > 0
                and (self.num_failures < self.max_failures
                     or self.max_failures < 0))

    def update_last_result(self, result, terminate=False):
        if terminate:
            result.update(done=True)
        if self.verbose and (terminate or time.time() - self.last_debug >
                             DEBUG_PRINT_INTERVAL):
            print("Result for {}:".format(self))
            print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            self.last_debug = time.time()
        self.last_result = result
        self.last_update_time = time.time()
        self.result_logger.on_result(self.last_result)

    def _get_trainable_cls(self):
        return ray.tune.registry._global_registry.get(
            ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.TERMINATED, Trial.ERROR]

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        if self.experiment_tag:
            identifier += "_" + self.experiment_tag
        return identifier.replace("/", "_")

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING, and flushes the result logger.
        Note this can only occur if the trial holds a DISK checkpoint.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush()
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)

    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        Trial._registration_check(self.trainable_name)
        if logger_started:
            self.init_logger()
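
# --- Illustrative usage sketch (not part of the original module) ---
# A standalone restatement of the `should_stop` loop above: every key of the
# flat `stopping_criterion` dict must appear in the result, and the trial
# stops as soon as any reported value reaches its threshold. The literal keys
# below stand in for the constants imported from ray.tune.result, and a plain
# ValueError replaces TuneError so the sketch stays self-contained.


def _should_stop_sketch(stopping_criterion, result):
    if result.get("done"):
        return True
    for criteria, stop_value in stopping_criterion.items():
        if criteria not in result:
            raise ValueError(
                "Stopping criteria {} not provided in result {}.".format(
                    criteria, result))
        if result[criteria] >= stop_value:
            return True
    return False


criterion = {"training_iteration": 100, "episode_reward_mean": 200.0}
assert not _should_stop_sketch(
    criterion, {"training_iteration": 50, "episode_reward_mean": 150.0})
assert _should_stop_sketch(
    criterion, {"training_iteration": 100, "episode_reward_mean": 150.0})
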
Example #3
class Trial(object):
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state and transition to RUNNING once started.
    On error a trial transitions to ERROR; on success it ends as TERMINATED.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr="",
                 export_formats=None,
                 restore_path=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_function=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """

        Trial._registration_check(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.config = config or {}
        self.local_dir = os.path.expanduser(local_dir)
        self.experiment_tag = experiment_tag
        self.resources = (
            resources
            or self._get_trainable_cls().default_resource_request(self.config))
        self.stopping_criterion = stopping_criterion or {}
        self.upload_dir = upload_dir
        self.loggers = loggers
        self.sync_function = sync_function
        validate_sync_function(sync_function)
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end

        self.history = []
        self.keep_checkpoints_num = keep_checkpoints_num
        self._cmp_greater = not checkpoint_score_attr.startswith("min-")
        self.best_checkpoint_attr_value = (
            -float("inf") if self._cmp_greater else float("inf"))
        # Strip off "min-" from checkpoint attribute
        self.checkpoint_score_attr = (
            checkpoint_score_attr
            if self._cmp_greater else checkpoint_score_attr[4:])

        self._checkpoint = Checkpoint(storage=Checkpoint.DISK,
                                      value=restore_path)
        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.error_file = None
        self.num_failures = 0
        self.custom_trial_name = None

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "_checkpoint",
            "loggers",
            "sync_function",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @classmethod
    def _registration_check(cls, trainable_name):
        if not has_trainable(trainable_name):
            # Make sure rllib agents are registered
            from ray import rllib  # noqa: F401
            if not has_trainable(trainable_name):
                raise TuneError("Unknown trainable: " + trainable_name)

    @classmethod
    def generate_id(cls):
        return binary_to_hex(_random_string())[:8]

    def init_logger(self):
        """Init logger."""

        if not self.result_logger:
            if not os.path.exists(self.local_dir):
                os.makedirs(self.local_dir)
            if not self.logdir:
                self.logdir = tempfile.mkdtemp(prefix="{}_{}".format(
                    str(self)[:MAX_LEN_IDENTIFIER], date_str()),
                                               dir=self.local_dir)
            elif not os.path.exists(self.logdir):
                os.makedirs(self.logdir)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                upload_uri=self.upload_dir,
                loggers=self.loggers,
                sync_function=self.sync_function)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError: If the trial status is RUNNING.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")
        self.resources = Resources(cpu, gpu, **kwargs)

    def sync_logger_to_new_location(self, worker_ip):
        """Updates the logger location.

        Also pushes logdir to worker_ip, allowing for cross-node recovery.
        """
        if self.result_logger:
            self.result_logger.sync_results_to_new_location(worker_ip)

    def close_logger(self):
        """Close logger."""

        if self.result_logger:
            self.result_logger.close()
            self.result_logger = None

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1  # may be moved to outer scope?
            error_file = os.path.join(self.logdir,
                                      "error_{}.txt".format(date_str()))
            with open(error_file, "w") as f:
                f.write(error_msg)
            self.error_file = error_file

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""

        if result.get(DONE):
            return True

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            if result[criteria] >= stop_value:
                return True

        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}

        if result.get(DONE) and self.checkpoint_at_end:
            return True

        if self.checkpoint_freq:
            return result.get(TRAINING_ITERATION,
                              0) % self.checkpoint_freq == 0
        else:
            return False

    def progress_string(self):
        """Returns a progress message for printing out to the console."""

        if not self.last_result:
            return self._status_string()

        def location_string(hostname, pid):
            if hostname == os.uname()[1]:
                return 'pid={}'.format(pid)
            else:
                return '{} pid={}'.format(hostname, pid)

        pieces = [
            '{}'.format(self._status_string()),
            '[{}]'.format(self.resources.summary_string()), '[{}]'.format(
                location_string(self.last_result.get(HOSTNAME),
                                self.last_result.get(PID))),
            '{} s'.format(int(self.last_result.get(TIME_TOTAL_S)))
        ]

        if self.last_result.get(TRAINING_ITERATION) is not None:
            pieces.append('{} iter'.format(
                self.last_result[TRAINING_ITERATION]))

        if self.last_result.get(TIMESTEPS_TOTAL) is not None:
            pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))

        if self.last_result.get(EPISODE_REWARD_MEAN) is not None:
            pieces.append('{} rew'.format(
                format(self.last_result[EPISODE_REWARD_MEAN], '.3g')))

        if self.last_result.get(MEAN_LOSS) is not None:
            pieces.append('{} loss'.format(
                format(self.last_result[MEAN_LOSS], '.3g')))

        if self.last_result.get(MEAN_ACCURACY) is not None:
            pieces.append('{} acc'.format(
                format(self.last_result[MEAN_ACCURACY], '.3g')))

        return ', '.join(pieces)

    def _status_string(self):
        return "{}{}".format(
            self.status, ", {} failures: {}".format(
                self.num_failures, self.error_file) if self.error_file else "")

    def has_checkpoint(self):
        return self._checkpoint.value is not None

    def clear_checkpoint(self):
        self._checkpoint.value = None

    def should_recover(self):
        """Returns whether the trial qualifies for restoring.

        A trial qualifies if a checkpoint frequency is set and it has not
        failed more than max_failures times (a negative max_failures means
        unlimited retries). This may return True even when no checkpoint
        exists yet.
        """
        return (self.checkpoint_freq > 0
                and (self.num_failures < self.max_failures
                     or self.max_failures < 0))

    def update_last_result(self, result, terminate=False):
        if terminate:
            result.update(done=True)
        if self.verbose and (terminate or time.time() - self.last_debug >
                             DEBUG_PRINT_INTERVAL):
            print("Result for {}:".format(self))
            print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            self.last_debug = time.time()
        self.last_result = result
        self.last_update_time = time.time()
        self.result_logger.on_result(self.last_result)

    def compare_checkpoints(self, attr_mean):
        """Compares two checkpoints based on the attribute attr_mean param.
        Greater than is used by default. If  command-line parameter
        checkpoint_score_attr starts with "min-" less than is used.

        Arguments:
            attr_mean: mean of attribute value for the current checkpoint

        Returns:
            True: when attr_mean is greater than previous checkpoint attr_mean
                  and greater than function is selected
                  when attr_mean is less than previous checkpoint attr_mean and
                  less than function is selected
            False: when attr_mean is not in alignment with selected cmp fn
        """
        if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
            return True
        elif (not self._cmp_greater
              and attr_mean < self.best_checkpoint_attr_value):
            return True
        return False

    def _get_trainable_cls(self):
        return ray.tune.registry._global_registry.get(
            ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.TERMINATED, Trial.ERROR]

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        if self.experiment_tag:
            identifier += "_" + self.experiment_tag
        return identifier.replace("/", "_")

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING, and flushes the result logger.
        Note this can only occur if the trial holds a DISK checkpoint.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush()
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)

    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        Trial._registration_check(self.trainable_name)
        if logger_started:
            self.init_logger()
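
# --- Illustrative usage sketch (not part of the original module) ---
# A standalone restatement of how `checkpoint_score_attr` drives
# `compare_checkpoints` above: without a prefix the tracked attribute is
# maximized, while a "min-" prefix flips the comparison and is stripped from
# the attribute name. This mirrors only the constructor bookkeeping and the
# compare_checkpoints logic; it does not touch Ray.


class _CheckpointScoreSketch(object):
    def __init__(self, checkpoint_score_attr=""):
        self._cmp_greater = not checkpoint_score_attr.startswith("min-")
        self.best_value = (-float("inf")
                           if self._cmp_greater else float("inf"))
        # Strip off "min-" from the attribute name, as in __init__ above.
        self.attr = (checkpoint_score_attr
                     if self._cmp_greater else checkpoint_score_attr[4:])

    def is_better(self, attr_mean):
        if self._cmp_greater:
            return attr_mean > self.best_value
        return attr_mean < self.best_value


maximize = _CheckpointScoreSketch("episode_reward_mean")
assert maximize.is_better(10.0)      # any finite value beats -inf

minimize = _CheckpointScoreSketch("min-mean_loss")
assert minimize.attr == "mean_loss"  # "min-" prefix stripped
assert minimize.is_better(0.5)       # any finite value beats +inf
minimize.best_value = 0.5
assert not minimize.is_better(0.7)   # a larger loss is not an improvement
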